From b80c1a659050650edadced3c13a4a0e44e1eaf16 Mon Sep 17 00:00:00 2001 From: Renan Soares Date: Tue, 26 May 2026 17:41:15 -0300 Subject: [PATCH 1/2] refactor: extract helpers to reduce pylint complexity --- .../data/builtin/datetime_migration.py | 149 ++-- aredis_om/model/migrations/data/migrator.py | 64 +- aredis_om/model/model.py | 807 ++++++++++-------- 3 files changed, 579 insertions(+), 441 deletions(-) diff --git a/aredis_om/model/migrations/data/builtin/datetime_migration.py b/aredis_om/model/migrations/data/builtin/datetime_migration.py index e4c0607c..1ff9f05f 100644 --- a/aredis_om/model/migrations/data/builtin/datetime_migration.py +++ b/aredis_om/model/migrations/data/builtin/datetime_migration.py @@ -363,6 +363,79 @@ async def _save_progress_if_needed(self, current_model: str, total_keys: int): stats=self.stats.get_summary(), ) + async def _collect_hash_keys(self, model_class) -> List[str]: + """Collect all Redis hash keys for a model.""" + key_pattern = model_class.make_key("*") + all_keys = [] + scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH") + + async for key in scan_iter: # type: ignore[misc] + if isinstance(key, bytes): + key = key.decode("utf-8") + all_keys.append(key) + + return all_keys + + @staticmethod + def _normalize_hash_data(hash_data: Dict[Any, Any]) -> Dict[str, Any]: + """Normalize hash payload to string keys and values.""" + if hash_data and isinstance(next(iter(hash_data.keys())), bytes): + return { + k.decode("utf-8"): v.decode("utf-8") + for k, v in hash_data.items() + } + + return hash_data + + async def _process_hash_key( + self, + key: str, + datetime_fields: List[str], + model_name: str, + total_keys: int, + ) -> bool: + """Process a single hash key and return whether it was handled.""" + if key in self.processed_keys_set: + return False + + try: + hash_data = await self.redis.hgetall(key) # type: ignore[misc] + except Exception as e: + log.warning(f"Failed to get hash data from {key}: {e}") + return False + + 
if not hash_data: + return False + + hash_data = self._normalize_hash_data(hash_data) + updates = {} + + for field_name in datetime_fields: + if field_name in hash_data: + value = hash_data[field_name] + converted, success = self._safe_convert_datetime_value( + key, field_name, value + ) + + if success and converted != value: + updates[field_name] = str(converted) + + if updates: + try: + await self.redis.hset(key, mapping=updates) # type: ignore[misc] + except Exception as e: + log.error(f"Failed to update hash {key}: {e}") + if self.failure_mode == ConversionFailureMode.FAIL: + raise DataMigrationError(f"Failed to update hash {key}: {e}") + + self.processed_keys_set.add(key) + self.stats.add_processed_key() + self._processed_keys += 1 + + self._check_error_threshold() + await self._save_progress_if_needed(model_name, total_keys) + return True + async def _clear_progress_on_completion(self): """Clear saved progress when migration completes successfully.""" if self.migration_state: @@ -504,16 +577,7 @@ async def _process_hash_model( self, model_class, datetime_fields: List[str] ) -> None: """Process HashModel instances to convert datetime fields with enhanced error handling.""" - # Get all keys for this model - key_pattern = model_class.make_key("*") - - # Collect all keys first for batch processing - all_keys = [] - scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH") - async for key in scan_iter: # type: ignore[misc] - if isinstance(key, bytes): - key = key.decode("utf-8") - all_keys.append(key) + all_keys = await self._collect_hash_keys(model_class) total_keys = len(all_keys) log.info( @@ -531,64 +595,13 @@ async def _process_hash_model( for key in batch_keys: try: - # Skip if already processed (resume capability) - if key in self.processed_keys_set: - continue - - # Get all fields from the hash - try: - hash_data = await self.redis.hgetall(key) # type: ignore[misc] - except Exception as e: - log.warning(f"Failed to get hash data from {key}: {e}") - 
continue - - if not hash_data: - continue - - # Convert byte keys/values to strings if needed - if hash_data and isinstance(next(iter(hash_data.keys())), bytes): - hash_data = { - k.decode("utf-8"): v.decode("utf-8") - for k, v in hash_data.items() - } - - updates = {} - - # Check each datetime field with safe conversion - for field_name in datetime_fields: - if field_name in hash_data: - value = hash_data[field_name] - converted, success = self._safe_convert_datetime_value( - key, field_name, value - ) - - if success and converted != value: - updates[field_name] = str(converted) - - # Update the hash if we have changes - if updates: - try: - await self.redis.hset(key, mapping=updates) # type: ignore[misc] - except Exception as e: - log.error(f"Failed to update hash {key}: {e}") - if self.failure_mode == ConversionFailureMode.FAIL: - raise DataMigrationError( - f"Failed to update hash {key}: {e}" - ) - - # Mark key as processed - self.processed_keys_set.add(key) - self.stats.add_processed_key() - self._processed_keys += 1 - processed_count += 1 - - # Error threshold checking - self._check_error_threshold() - - # Save progress periodically - await self._save_progress_if_needed( - model_class.__name__, total_keys - ) + if await self._process_hash_key( + key, + datetime_fields, + model_class.__name__, + total_keys, + ): + processed_count += 1 except DataMigrationError: # Re-raise migration errors diff --git a/aredis_om/model/migrations/data/migrator.py b/aredis_om/model/migrations/data/migrator.py index 23456775..6b9ba1f0 100644 --- a/aredis_om/model/migrations/data/migrator.py +++ b/aredis_om/model/migrations/data/migrator.py @@ -329,21 +329,18 @@ async def run_migrations_with_monitoring( "errors": [], } - if verbose: - print(f"Found {len(pending_migrations)} pending migration(s):") - for migration in pending_migrations: - print(f"- {migration.migration_id}: {migration.description}") + self._log_pending_migrations(pending_migrations, verbose) if dry_run: if verbose: 
print("Dry run mode - no changes will be applied.") - return { - "applied_count": len(pending_migrations), - "total_migrations": len(pending_migrations), - "performance_stats": monitor.get_stats(), - "errors": [], - "dry_run": True, - } + return self._build_monitoring_result( + applied_count=len(pending_migrations), + pending_migrations=pending_migrations, + monitor=monitor, + errors=[], + dry_run=True, + ) applied_count = 0 errors = [] @@ -400,6 +397,33 @@ async def run_migrations_with_monitoring( monitor.finish() + result = self._build_monitoring_result( + applied_count=applied_count, + pending_migrations=pending_migrations, + monitor=monitor, + errors=errors, + ) + + self._log_monitoring_summary(result, verbose) + + return result + + def _log_pending_migrations( + self, pending_migrations: List[BaseMigration], verbose: bool + ) -> None: + if verbose: + print(f"Found {len(pending_migrations)} pending migration(s):") + for migration in pending_migrations: + print(f"- {migration.migration_id}: {migration.description}") + + def _build_monitoring_result( + self, + applied_count: int, + pending_migrations: List[BaseMigration], + monitor: PerformanceMonitor, + errors: List[Dict[str, Any]], + dry_run: bool = False, + ) -> Dict[str, Any]: result = { "applied_count": applied_count, "total_migrations": len(pending_migrations), @@ -412,18 +436,28 @@ async def run_migrations_with_monitoring( ), } + if dry_run: + result["dry_run"] = True + + return result + + def _log_monitoring_summary( + self, result: Dict[str, Any], verbose: bool + ) -> None: if verbose: - print(f"Applied {applied_count}/{len(pending_migrations)} migration(s).") + total_migrations = result["total_migrations"] + applied_count = result["applied_count"] + print(f"Applied {applied_count}/{total_migrations} migration(s).") stats = result["performance_stats"] if stats: print(f"Total time: {stats.get('total_time_seconds', 0):.2f}s") if "items_per_second" in stats: # type: ignore - print(f"Performance: 
{stats['items_per_second']:.1f} items/second") # type: ignore + print( + f"Performance: {stats['items_per_second']:.1f} items/second" + ) # type: ignore if "peak_memory_mb" in stats: # type: ignore print(f"Peak memory: {stats['peak_memory_mb']:.1f} MB") # type: ignore - return result - async def rollback_migration( self, migration_id: str, dry_run: bool = False, verbose: bool = False ) -> bool: diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 07f6de14..4cf8b0a2 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1541,6 +1541,161 @@ def expand_tag_value(value): ) return escaper.escape(str(value)) + @staticmethod + def _convert_numeric_value(value): + if isinstance(value, Enum): + value = value.value + if isinstance(value, (datetime.datetime, datetime.date)): + if isinstance(value, datetime.date) and not isinstance( + value, datetime.datetime + ): + value = datetime.datetime.combine( + value, datetime.time.min, tzinfo=datetime.timezone.utc + ) + value = value.timestamp() + return value + + @classmethod + def _resolve_text_value(cls, field_name: str, op: Operators, value: Any) -> str: + result = f"@{field_name}_fts:" + if op is Operators.EQ: + return f'{result}"{value}"' + if op is Operators.NE: + return f'-({result}"{value}")' + if op is Operators.LIKE: + return f"{result}{value}" + raise QueryNotSupportedError( + "Only equals (=), not-equals (!=), and like() " + "comparisons are supported for TEXT fields. 
" + f"Docs: {ERRORS_URL}#E5" + ) + + @classmethod + def _resolve_numeric_value( + cls, field_name: str, op: Operators, value: Any + ) -> str: + if op is Operators.IN: + converted_values = [cls._convert_numeric_value(v) for v in value] + parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] + return "|".join(parts) + if op is Operators.NOT_IN: + converted_values = [cls._convert_numeric_value(v) for v in value] + parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] + return f"-({' | '.join(parts)})" + + value = cls._convert_numeric_value(value) + if op is Operators.EQ: + return f"@{field_name}:[{value} {value}]" + if op is Operators.NE: + return f"-(@{field_name}:[{value} {value}])" + if op is Operators.GT: + return f"@{field_name}:[({value} +inf]" + if op is Operators.LT: + return f"@{field_name}:[-inf ({value}]" + if op is Operators.GE: + return f"@{field_name}:[{value} +inf]" + if op is Operators.LE: + return f"@{field_name}:[-inf {value}]" + return "" + + @classmethod + def _resolve_tag_eq_value( + cls, + field_name: str, + field_info: PydanticFieldInfo, + value: Any, + model_class: Optional[Type["RedisModel"]], + ) -> str: + separator_char = getattr( + field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR + ) + if value == separator_char: + log.warning( + "Your query against the field %s is for a single character, %s, " + "that is used internally by Redis OM Python. We must ignore " + "this portion of the query. 
Please review your query to find " + "an alternative query that uses a string containing more than " + "just the character %s.", + field_name, + separator_char, + separator_char, + ) + return "" + if isinstance(value, bool): + if model_class: + is_hash_model = any( + base.__name__ == "HashModel" for base in model_class.__mro__ + ) + value = ("1" if value else "0") if is_hash_model else value + return "@{field_name}:{{{value}}}".format( + field_name=field_name, value=value + ) + if isinstance(value, int): + return f"@{field_name}:[{value} {value}]" + if separator_char in value: + values: filter = filter(None, value.split(separator_char)) + return "".join( + "@{field_name}:{{{value}}}".format( + field_name=field_name, value=escaper.escape(value) + ) + for value in values + ) + value = escaper.escape(value) + return "@{field_name}:{{{value}}}".format(field_name=field_name, value=value) + + @classmethod + def _resolve_tag_value( + cls, + field_name: str, + field_info: PydanticFieldInfo, + op: Operators, + value: Any, + model_class: Optional[Type["RedisModel"]], + ) -> str: + if op is Operators.EQ: + return cls._resolve_tag_eq_value(field_name, field_info, value, model_class) + if op is Operators.NE: + value = escaper.escape(value) + return "-(@{field_name}:{{{value}}})".format( + field_name=field_name, value=value + ) + if op is Operators.IN: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.NOT_IN: + expanded_value = cls.expand_tag_value(value) + return "-(@{field_name}:{{{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.STARTSWITH: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{{expanded_value}*}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.ENDSWITH: + expanded_value = cls.expand_tag_value(value) + return 
"(@{field_name}:{{*{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.CONTAINS: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{*{expanded_value}*}})".format( + field_name=field_name, expanded_value=expanded_value + ) + return "" + + @staticmethod + def _resolve_geo_value(field_name: str, op: Operators, value: Any) -> str: + if not isinstance(value, GeoFilter): + raise QuerySyntaxError( + "You can only use a GeoFilter object with a GEO field." + ) + if op is Operators.EQ: + return f"@{field_name}:[{value}]" + return "" + @classmethod def resolve_value( cls, @@ -1553,163 +1708,22 @@ def resolve_value( model_class: Optional[Type["RedisModel"]] = None, ) -> str: # The 'field_name' should already include the correct prefix - result = "" if parents: prefix = "_".join([p[0] for p in parents]) field_name = f"{prefix}_{field_name}" if field_type is RediSearchFieldTypes.TEXT: - result = f"@{field_name}_fts:" - if op is Operators.EQ: - result += f'"{value}"' - elif op is Operators.NE: - result = f'-({result}"{value}")' - elif op is Operators.LIKE: - result += value - else: - raise QueryNotSupportedError( - "Only equals (=), not-equals (!=), and like() " - "comparisons are supported for TEXT fields. " - f"Docs: {ERRORS_URL}#E5" - ) + return cls._resolve_text_value(field_name, op, value) elif field_type is RediSearchFieldTypes.NUMERIC: - # Helper to convert a single value for NUMERIC queries - def convert_numeric_value(v): - # Convert Enum to its value (fixes #108) - if isinstance(v, Enum): - v = v.value - # Convert datetime objects to timestamps - if isinstance(v, (datetime.datetime, datetime.date)): - if isinstance(v, datetime.date) and not isinstance( - v, datetime.datetime - ): - # Use UTC midnight so query conversion matches storage conversion. 
- v = datetime.datetime.combine( - v, datetime.time.min, tzinfo=datetime.timezone.utc - ) - v = v.timestamp() - return v - - if op is Operators.IN: - # Handle IN operator for NUMERIC fields (fixes #499) - # Convert each value and create OR of range queries - converted_values = [convert_numeric_value(v) for v in value] - parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] - result += "|".join(parts) - elif op is Operators.NOT_IN: - # Handle NOT_IN operator for NUMERIC fields - converted_values = [convert_numeric_value(v) for v in value] - parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] - result += f"-({' | '.join(parts)})" - else: - value = convert_numeric_value(value) - - if op is Operators.EQ: - result += f"@{field_name}:[{value} {value}]" - elif op is Operators.NE: - result += f"-(@{field_name}:[{value} {value}])" - elif op is Operators.GT: - result += f"@{field_name}:[({value} +inf]" - elif op is Operators.LT: - result += f"@{field_name}:[-inf ({value}]" - elif op is Operators.GE: - result += f"@{field_name}:[{value} +inf]" - elif op is Operators.LE: - result += f"@{field_name}:[-inf {value}]" + return cls._resolve_numeric_value(field_name, op, value) # TODO: How will we know the difference between a multi-value use of a TAG # field and our hidden use of TAG for exact-match queries? elif field_type is RediSearchFieldTypes.TAG: - if op is Operators.EQ: - separator_char = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR - ) - if value == separator_char: - # The value is ONLY the TAG field separator character -- - # this is not going to work. - log.warning( - "Your query against the field %s is for a single character, %s, " - "that is used internally by Redis OM Python. We must ignore " - "this portion of the query. 
Please review your query to find " - "an alternative query that uses a string containing more than " - "just the character %s.", - field_name, - separator_char, - separator_char, - ) - return "" - if isinstance(value, bool): - # For HashModel, convert boolean to "1"/"0" to match storage format - # For JsonModel, keep as boolean since JSON supports native booleans - if model_class: - # Check if this is a HashModel by checking the class hierarchy - is_hash_model = any( - base.__name__ == "HashModel" for base in model_class.__mro__ - ) - bool_value = ("1" if value else "0") if is_hash_model else value - else: - bool_value = value - result = "@{field_name}:{{{value}}}".format( - field_name=field_name, value=bool_value - ) - elif isinstance(value, int): - # This if will hit only if the field is a primary key of type int - result = f"@{field_name}:[{value} {value}]" - elif separator_char in value: - # The value contains the TAG field separator. We can work - # around this by breaking apart the values and unioning them - # with multiple field:{} queries. - values: filter = filter(None, value.split(separator_char)) - for value in values: - value = escaper.escape(value) - result += "@{field_name}:{{{value}}}".format( - field_name=field_name, value=value - ) - else: - value = escaper.escape(value) - result += "@{field_name}:{{{value}}}".format( - field_name=field_name, value=value - ) - elif op is Operators.NE: - value = escaper.escape(value) - result += "-(@{field_name}:{{{value}}})".format( - field_name=field_name, value=value - ) - elif op is Operators.IN: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.NOT_IN: - # TODO: Implement NOT_IN, test this... 
- expanded_value = cls.expand_tag_value(value) - result += "-(@{field_name}:{{{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.STARTSWITH: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{{expanded_value}*}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.ENDSWITH: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{*{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.CONTAINS: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{*{expanded_value}*}})".format( - field_name=field_name, expanded_value=expanded_value - ) + return cls._resolve_tag_value(field_name, field_info, op, value, model_class) elif field_type is RediSearchFieldTypes.GEO: - if not isinstance(value, GeoFilter): - raise QuerySyntaxError( - "You can only use a GeoFilter object with a GEO field." 
- ) + return cls._resolve_geo_value(field_name, op, value) - if op is Operators.EQ: - result += f"@{field_name}:[{value}]" - - return result + return "" def resolve_redisearch_pagination(self): """Resolve pagination options for a query.""" @@ -1835,9 +1849,7 @@ def _resolve_redisearch_query(self, expression: ExpressionOrNegated) -> str: return result - async def execute( - self, exhaust_results=True, return_raw_result=False, return_query_args=False - ): + def _build_execute_args(self) -> tuple[List[Union[str, bytes]], bool]: args: List[Union[str, bytes]] = [ "FT.SEARCH", self.model.Meta.index_name, @@ -1862,29 +1874,22 @@ async def execute( if self.nocontent: args.append("NOCONTENT") - # Check if we have complex fields that RediSearch RETURN clause can't handle - use_full_document_fallback = False - if self.projected_fields: - use_full_document_fallback = self._has_complex_projected_fields() + use_full_document_fallback = bool( + self.projected_fields and self._has_complex_projected_fields() + ) - # Add RETURN clause to the args list, not to the query string - # Skip RETURN clause if we need full documents for complex field projection + # Add RETURN clause to the args list, not to the query string. + # Skip RETURN clause if we need full documents for complex field projection. if self.projected_fields and not use_full_document_fallback: args.extend( ["RETURN", str(len(self.projected_fields))] + self.projected_fields ) - if return_query_args: - return self.model.Meta.index_name, args - - # Reset the cache if we're executing from offset 0. - if self.offset == 0: - self._model_cache.clear() + return args, use_full_document_fallback - # If the offset is greater than 0, we're paginating through a result set, - # so append the new results to results already in the cache. 
+ async def _execute_command(self, args: List[Union[str, bytes]]): try: - raw_result = await self.model.db().execute_command(*args) + return await self.model.db().execute_command(*args) except Exception as e: error_msg = str(e).lower() @@ -1900,32 +1905,55 @@ async def execute( # Re-raise the original exception raise - if return_raw_result: - return raw_result - count = raw_result[0] - # Handle different result processing based on what was requested + async def _parse_execute_results( + self, raw_result: Any, use_full_document_fallback: bool + ): if self.projected_fields and use_full_document_fallback: # Complex field projection - use full document fallback if self.return_as_dict: - results = await self._parse_full_document_projection_as_dict(raw_result) - else: - results = await self._parse_full_document_projection_as_models( - raw_result - ) - elif self.projected_fields and self.return_as_dict: + return await self._parse_full_document_projection_as_dict(raw_result) + return await self._parse_full_document_projection_as_models(raw_result) + + if self.projected_fields and self.return_as_dict: # .values('field1', 'field2') - specific fields as dicts - results = self._parse_projected_results(raw_result) - elif self.projected_fields and not self.return_as_dict: + return self._parse_projected_results(raw_result) + + if self.projected_fields and not self.return_as_dict: # .only('field1', 'field2') - partial model instances - results = self._parse_projected_models(raw_result) - elif self.return_as_dict and not self.projected_fields: + return self._parse_projected_models(raw_result) + + if self.return_as_dict and not self.projected_fields: # .values() - all fields as dicts model_results = self.model.from_redis(raw_result, self.knn) - results = [model.model_dump() for model in model_results] - else: - # Normal query - full model instances - results = self.model.from_redis(raw_result, self.knn) + return [model.model_dump() for model in model_results] + + # Normal query - 
full model instances + return self.model.from_redis(raw_result, self.knn) + + async def execute( + self, exhaust_results=True, return_raw_result=False, return_query_args=False + ): + args, use_full_document_fallback = self._build_execute_args() + + if return_query_args: + return self.model.Meta.index_name, args + + # Reset the cache if we're executing from offset 0. + if self.offset == 0: + self._model_cache.clear() + + # If the offset is greater than 0, we're paginating through a result set, + # so append the new results to results already in the cache. + raw_result = await self._execute_command(args) + + if return_raw_result: + return raw_result + count = raw_result[0] + + results = await self._parse_execute_results( + raw_result, use_full_document_fallback + ) self._model_cache += results if not exhaust_results: @@ -2409,44 +2437,30 @@ class ModelMeta(ModelMetaclass): model_config: RedisOmConfig model_fields: Dict[str, FieldInfo] # type: ignore[assignment] - def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 - meta = attrs.pop("Meta", None) - - # Capture original FieldInfo objects from attrs before Pydantic processes them. - # Pydantic 2.12+ may convert custom FieldInfo subclasses to plain PydanticFieldInfo - # for Annotated types, losing custom attributes like index, sortable, etc. 
+ @staticmethod + def _capture_original_field_infos(attrs): original_field_infos: Dict[str, FieldInfo] = {} if PYDANTIC_V2: for attr_name, attr_value in attrs.items(): if isinstance(attr_value, FieldInfo): original_field_infos[attr_name] = attr_value + return original_field_infos - # Duplicate logic from Pydantic to filter config kwargs because if they are - # passed directly including the registry Pydantic will pass them over to the - # superclass causing an error - allowed_config_kwargs: Set[str] = { + @staticmethod + def _allowed_config_kwargs(): + return { key for key in dir(ConfigDict) - if not ( - key.startswith("__") and key.endswith("__") - ) # skip dunder methods and attributes + if not (key.startswith("__") and key.endswith("__")) } - config_kwargs = { - key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs - } - - new_class: RedisModel = super().__new__( - cls, name, bases, attrs, **config_kwargs - ) - + @staticmethod + def _build_meta(new_class, meta, base_meta): # The fact that there is a Meta field and _meta field is important: a # user may have given us a Meta object with their configuration, while # we might have inherited _meta from a parent class, and should # therefore use some of the inherited fields. 
meta = meta or getattr(new_class, "Meta", None) - base_meta = getattr(new_class, "_meta", None) - if meta and meta != DefaultMeta and meta != base_meta: new_class.Meta = meta new_class._meta = meta @@ -2465,13 +2479,8 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 ) new_class.Meta = new_class._meta - is_indexed = kwargs.get("index", None) is True - - if is_indexed and new_class.model_config.get("index", None) is True: - raise RedisModelError( - f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" - ) - + @staticmethod + def _set_index_config(new_class, is_indexed): if PYDANTIC_V2: new_class.model_config["index"] = is_indexed else: @@ -2485,14 +2494,20 @@ class Config: new_class.Config = Config - # Create proxies for each model field so that we can use the field - # in queries, like Model.get(Model.field_name == 1) - # Only set if the model is has index=True + @staticmethod + def _model_fields(new_class): if PYDANTIC_V2: - model_fields = new_class.model_fields - else: - model_fields = new_class.__fields__ + return new_class.model_fields + return new_class.__fields__ + @staticmethod + def _apply_field_proxies( + new_class, + model_fields, + original_field_infos, + is_indexed, + ): + custom_pk_count = 0 for field_name, field in model_fields.items(): pydantic_field = field # Keep reference to Pydantic's processed field if type(field) is PydanticFieldInfo: @@ -2541,7 +2556,6 @@ class Config: # should set up the PrimaryKeyAccessor. We only do this when there's # exactly one custom primary key. Multiple custom primary keys will be # caught by validate_primary_key() later. 
- custom_pk_count = 0 for field_name, field in model_fields.items(): if field_name == "pk": continue @@ -2557,6 +2571,10 @@ class Config: if getattr(check_field, "primary_key", None) is True: custom_pk_count += 1 + return custom_pk_count + + @staticmethod + def _setup_primary_key_accessor(new_class, model_fields, custom_pk_count): # If there's exactly one custom primary key (not the default 'pk'), set up # a PrimaryKeyAccessor so that .pk always returns the correct value. # This fixes GitHub issue #570. @@ -2572,6 +2590,8 @@ class Config: # Set up PrimaryKeyAccessor descriptor for .pk access setattr(new_class, "pk", PrimaryKeyAccessor()) + @staticmethod + def _apply_embedded_model_rules(new_class): # For embedded models, clear the primary_key from meta since they don't # need primary keys - they're stored as part of their parent document, # not as separate Redis keys. This fixes GitHub issue #496. @@ -2580,6 +2600,8 @@ class Config: if getattr(new_class._meta, "embedded", False): new_class._meta.primary_key = None + @staticmethod + def _apply_meta_defaults(new_class, base_meta): if not getattr(new_class._meta, "global_key_prefix", None): new_class._meta.global_key_prefix = getattr( base_meta, "global_key_prefix", "" @@ -2610,6 +2632,8 @@ class Config: f"{new_class._meta.model_key_prefix}:index" ) + @staticmethod + def _register_indexed_model(new_class, bases): # Model is indexed and not an abstract model class or embedded model, so we should let the # Migrator create indexes for it. if ( @@ -2620,6 +2644,52 @@ class Config: key = f"{new_class.__module__}.{new_class.__qualname__}" model_registry[key] = new_class + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 + meta = attrs.pop("Meta", None) + + # Capture original FieldInfo objects from attrs before Pydantic processes them. + # Pydantic 2.12+ may convert custom FieldInfo subclasses to plain PydanticFieldInfo + # for Annotated types, losing custom attributes like index, sortable, etc. 
+ original_field_infos = cls._capture_original_field_infos(attrs) + + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs = cls._allowed_config_kwargs() + + config_kwargs = {key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs} + + new_class: RedisModel = super().__new__( + cls, name, bases, attrs, **config_kwargs + ) + + base_meta = getattr(new_class, "_meta", None) + cls._build_meta(new_class, meta, base_meta) + + is_indexed = kwargs.get("index", None) is True + + if is_indexed and new_class.model_config.get("index", None) is True: + raise RedisModelError( + f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" + ) + + cls._set_index_config(new_class, is_indexed) + + # Create proxies for each model field so that we can use the field + # in queries, like Model.get(Model.field_name == 1) + # Only set if the model has index=True + model_fields = cls._model_fields(new_class) + custom_pk_count = cls._apply_field_proxies( + new_class, + model_fields, + original_field_infos, + is_indexed, + ) + cls._setup_primary_key_accessor(new_class, model_fields, custom_pk_count) + cls._apply_embedded_model_rules(new_class) + cls._apply_meta_defaults(new_class, base_meta) + cls._register_indexed_model(new_class, bases) + return new_class @@ -3713,161 +3783,182 @@ def schema_for_type( # into the values of the list or the fields of the model to # find any values marked as indexed. 
if is_container_type and not is_vector: - field_type = get_origin(typ) - if field_type == Literal: - path = f"{json_path}.{name}" - return cls.schema_for_type( - path, - name, - name_prefix, - str, - field_info, - parent_type=field_type, - ) - else: - embedded_cls = get_args(typ) - if not embedded_cls: - log.warning( - "Model %s defined an empty list or tuple field: %s", cls, name - ) - return "" - path = f"{json_path}.{name}[*]" - embedded_cls = embedded_cls[0] - return cls.schema_for_type( - path, - name, - name_prefix, - embedded_cls, - field_info, - parent_type=field_type, - ) + return cls._schema_for_container_type( + json_path, name, name_prefix, typ, field_info + ) elif field_is_model: - name_prefix = f"{name_prefix}_{name}" if name_prefix else name - sub_fields = [] - for embedded_name, field in typ.model_fields.items(): - if ( - hasattr(field, "metadata") - and len(field.metadata) > 0 - and isinstance(field.metadata[0], FieldInfo) - ): - field_info = field.metadata[0] - else: - field_info = field - - if parent_is_container_type: - # We'll store this value either as a JavaScript array, so - # the correct JSONPath expression is to refer directly to - # attribute names after the container notation, e.g. - # orders[*].created_date. - path = json_path - else: - # All other fields should use dot notation with both the - # current field name and "embedded" field name, e.g., - # order.address.street_line_1. - path = f"{json_path}.{name}" - sub_fields.append( - cls.schema_for_type( - path, - embedded_name, - name_prefix, - # field.annotation, - get_outer_type(field), - field_info, - parent_type=typ, - ) - ) - return " ".join(filter(None, sub_fields)) + return cls._schema_for_embedded_model( + json_path, name, name_prefix, typ, parent_is_container_type + ) # NOTE: This is the termination point for recursion. We've descended # into models and lists until we found an actual value to index. 
elif should_index: - index_field_name = f"{name_prefix}_{name}" if name_prefix else name - if parent_is_container_type: - # If we're indexing the this field as a JavaScript array, then - # the currently built-up JSONPath expression will be - # "field_name[*]", which is what we want to use. - path = json_path + return cls._schema_for_leaf_type( + json_path, + name, + name_prefix, + typ, + field_info, + is_vector, + vector_options, + parent_is_container_type, + parent_is_model_in_container, + ) + return "" + + @classmethod + def _schema_for_container_type( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Any, + field_info: PydanticFieldInfo, + ) -> str: + field_type = get_origin(typ) + if field_type == Literal: + path = f"{json_path}.{name}" + return cls.schema_for_type( + path, + name, + name_prefix, + str, + field_info, + parent_type=field_type, + ) + + embedded_cls = get_args(typ) + if not embedded_cls: + log.warning("Model %s defined an empty list or tuple field: %s", cls, name) + return "" + + path = f"{json_path}.{name}[*]" + return cls.schema_for_type( + path, + name, + name_prefix, + embedded_cls[0], + field_info, + parent_type=field_type, + ) + + @classmethod + def _schema_for_embedded_model( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Union[Type[RedisModel], Any], + parent_is_container_type: bool, + ) -> str: + name_prefix = f"{name_prefix}_{name}" if name_prefix else name + sub_fields = [] + for embedded_name, field in typ.model_fields.items(): + if ( + hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field_info = field.metadata[0] else: - path = f"{json_path}.{name}" - sortable = getattr(field_info, "sortable", False) - case_sensitive = getattr(field_info, "case_sensitive", False) - full_text_search = getattr(field_info, "full_text_search", False) - - # For more complicated compound validators (e.g. 
PositiveInt), we might get a _GenericAlias rather than - # a proper type, we can pull the type information from the origin of the first argument. - if not isinstance(typ, type): - type_args = typing_get_args(field_info.annotation) - typ = ( - getattr(type_args[0], "__origin__", type_args[0]) - if type_args - else typ - ) + field_info = field - # Get separator from field_info, defaulting to pipe - separator = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR + # Lists store embedded values directly under the container path. + path = json_path if parent_is_container_type else f"{json_path}.{name}" + sub_fields.append( + cls.schema_for_type( + path, + embedded_name, + name_prefix, + get_outer_type(field), + field_info, + parent_type=typ, + ) ) + return " ".join(filter(None, sub_fields)) - if is_vector and vector_options: - schema = f"{path} AS {index_field_name} {vector_options.schema}" - elif parent_is_container_type or parent_is_model_in_container: - if typ is not str: - raise RedisModelError( - "List and tuple fields can only contain strings. " - f"Problem field: {name}. Docs: {ERRORS_URL}#E12" - ) - if full_text_search is True: - raise RedisModelError( - "List and tuple fields cannot be indexed for full-text " - f"search. Problem field: {name}. 
Docs: {ERRORS_URL}#E13" - ) - # List/tuple fields are indexed as TAG fields and can be sortable - schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + @classmethod + def _schema_for_leaf_type( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Union[Type[RedisModel], Any], + field_info: PydanticFieldInfo, + is_vector: bool, + vector_options: Optional[VectorFieldOptions], + parent_is_container_type: bool, + parent_is_model_in_container: bool, + ) -> str: + index_field_name = f"{name_prefix}_{name}" if name_prefix else name + path = json_path if parent_is_container_type else f"{json_path}.{name}" + sortable = getattr(field_info, "sortable", False) + case_sensitive = getattr(field_info, "case_sensitive", False) + full_text_search = getattr(field_info, "full_text_search", False) + + # For more complicated compound validators (e.g. PositiveInt), we might + # get a _GenericAlias rather than a proper type. + if not isinstance(typ, type): + type_args = typing_get_args(field_info.annotation) + typ = getattr(type_args[0], "__origin__", type_args[0]) if type_args else typ + + separator = getattr(field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR) + + if is_vector and vector_options: + schema = f"{path} AS {index_field_name} {vector_options.schema}" + elif parent_is_container_type or parent_is_model_in_container: + if typ is not str: + raise RedisModelError( + "List and tuple fields can only contain strings. " + f"Problem field: {name}. Docs: {ERRORS_URL}#E12" + ) + if full_text_search is True: + raise RedisModelError( + "List and tuple fields cannot be indexed for full-text " + f"search. Problem field: {name}. 
Docs: {ERRORS_URL}#E13" + ) + schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + if sortable is True: + schema += " SORTABLE" + if case_sensitive is True: + schema += " CASESENSITIVE" + elif typ is bool: + schema = f"{path} AS {index_field_name} TAG" + if sortable is True: + schema += " SORTABLE" + elif typ in [CoordinateType, Coordinates]: + schema = f"{path} AS {index_field_name} GEO" + if sortable is True: + schema += " SORTABLE" + elif is_numeric_type(typ): + schema = f"{path} AS {index_field_name} NUMERIC" + if sortable is True: + schema += " SORTABLE" + elif issubclass(typ, str): + if full_text_search is True: + schema = ( + f"{path} AS {index_field_name} TAG SEPARATOR {separator} " + f"{path} AS {index_field_name}_fts TEXT" + ) if sortable is True: + # NOTE: With the current preview release, making a field + # full-text searchable and sortable only makes the TEXT + # field sortable. schema += " SORTABLE" if case_sensitive is True: - schema += " CASESENSITIVE" - elif typ is bool: - schema = f"{path} AS {index_field_name} TAG" - if sortable is True: - schema += " SORTABLE" - elif typ in [CoordinateType, Coordinates]: - schema = f"{path} AS {index_field_name} GEO" - if sortable is True: - schema += " SORTABLE" - elif is_numeric_type(typ): - schema = f"{path} AS {index_field_name} NUMERIC" - if sortable is True: - schema += " SORTABLE" - elif issubclass(typ, str): - if full_text_search is True: - schema = ( - f"{path} AS {index_field_name} TAG SEPARATOR {separator} " - f"{path} AS {index_field_name}_fts TEXT" - ) - if sortable is True: - # NOTE: With the current preview release, making a field - # full-text searchable and sortable only makes the TEXT - # field sortable. This means that results for full-text - # search queries can be sorted, but not exact match - # queries. 
- schema += " SORTABLE" - if case_sensitive is True: - raise RedisModelError("Text fields cannot be case-sensitive.") - else: - # String fields are indexed as TAG fields and can be sortable - schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" - if sortable is True: - schema += " SORTABLE" - if case_sensitive is True: - schema += " CASESENSITIVE" + raise RedisModelError("Text fields cannot be case-sensitive.") else: - # Default to TAG field, which can be sortable schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" if sortable is True: schema += " SORTABLE" + if case_sensitive is True: + schema += " CASESENSITIVE" + else: + schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + if sortable is True: + schema += " SORTABLE" - return schema - return "" + return schema class EmbeddedJsonModel(JsonModel, abc.ABC): From 1c3a4b211378ef014078a034cf551a975d502fb6 Mon Sep 17 00:00:00 2001 From: Renan Soares Date: Fri, 29 May 2026 12:41:07 -0300 Subject: [PATCH 2/2] refactor: reduce instance attributes in migration classes --- .../data/builtin/datetime_migration.py | 8 - aredis_om/model/model.py | 469 ++++++++++++++---- 2 files changed, 385 insertions(+), 92 deletions(-) diff --git a/aredis_om/model/migrations/data/builtin/datetime_migration.py b/aredis_om/model/migrations/data/builtin/datetime_migration.py index 1ff9f05f..f0e02ef0 100644 --- a/aredis_om/model/migrations/data/builtin/datetime_migration.py +++ b/aredis_om/model/migrations/data/builtin/datetime_migration.py @@ -247,7 +247,6 @@ def __init__( self.failure_mode = failure_mode self.batch_size = batch_size self.max_errors = max_errors - self.enable_resume = enable_resume self.progress_save_interval = progress_save_interval self.stats = MigrationStats() self.migration_state = ( @@ -255,10 +254,6 @@ def __init__( ) self.processed_keys_set: Set[str] = set() - # Legacy compatibility - self._processed_keys = 0 - self._converted_fields = 0 - def _safe_convert_datetime_value( 
self, key: str, field_name: str, value: Any ) -> Tuple[Any, bool]: @@ -332,7 +327,6 @@ async def _load_previous_progress(self) -> bool: if progress["processed_keys"]: self.processed_keys_set = set(progress["processed_keys"]) - self._processed_keys = len(self.processed_keys_set) # Restore stats if available if progress.get("stats"): @@ -430,7 +424,6 @@ async def _process_hash_key( self.processed_keys_set.add(key) self.stats.add_processed_key() - self._processed_keys += 1 self._check_error_threshold() await self._save_progress_if_needed(model_name, total_keys) @@ -690,7 +683,6 @@ async def _process_json_model( # Mark key as processed self.processed_keys_set.add(key) self.stats.add_processed_key() - self._processed_keys += 1 processed_count += 1 # Error threshold checking diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 4cf8b0a2..0a8fb2a8 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -11,6 +11,7 @@ from typing import ( Any, Callable, + ClassVar, Dict, List, Literal, @@ -923,6 +924,28 @@ class RediSearchFieldTypes(Enum): DEFAULT_PAGE_SIZE = 1000 +@dataclasses.dataclass +class _FindQueryState: + expressions: Sequence[ExpressionOrNegated] + model: Type["RedisModel"] + knn: Optional[KNNExpression] = None + offset: int = 0 + limit: Optional[int] = None + page_size: int = DEFAULT_PAGE_SIZE + sort_fields: List[str] = dataclasses.field(default_factory=list) + projected_fields: List[str] = dataclasses.field(default_factory=list) + nocontent: bool = False + return_as_dict: bool = False + + +@dataclasses.dataclass +class _FindQueryCache: + expression: Optional[ExpressionOrNegated] = None + query: Optional[str] = None + pagination: List[str] = dataclasses.field(default_factory=list) + model_cache: List["RedisModel"] = dataclasses.field(default_factory=list) + + class FindQuery: def __init__( self, @@ -944,32 +967,130 @@ def __init__( "instance has one of these modules installed." 
) - self.expressions = expressions - self.model = model - self.knn = knn - self.offset = offset - self.limit = limit or (self.knn.k if self.knn else DEFAULT_PAGE_SIZE) - self.page_size = page_size - self.nocontent = nocontent + self._state = _FindQueryState( + expressions=expressions, + model=model, + knn=knn, + offset=offset, + limit=limit if limit is not None else (knn.k if knn else DEFAULT_PAGE_SIZE), + page_size=page_size, + nocontent=nocontent, + ) + self._cache = _FindQueryCache() if sort_fields: - self.sort_fields = self.validate_sort_fields(sort_fields) - elif self.knn: - self.sort_fields = [self.knn.score_field_name] + resolved_sort_fields = self.validate_sort_fields(sort_fields) + elif knn: + resolved_sort_fields = [knn.score_field_name] else: - self.sort_fields = [] + resolved_sort_fields = [] if projected_fields: - self.projected_fields = self.validate_projected_fields(projected_fields) + resolved_projected_fields = self.validate_projected_fields( + projected_fields + ) else: - self.projected_fields = [] + resolved_projected_fields = [] + + self._state.sort_fields = resolved_sort_fields + self._state.projected_fields = resolved_projected_fields + self._state.return_as_dict = return_as_dict + + @property + def expressions(self): + return self._state.expressions + + @expressions.setter + def expressions(self, value): + self._state.expressions = value + + @property + def model(self): + return self._state.model + + @model.setter + def model(self, value): + self._state.model = value + + @property + def knn(self): + return self._state.knn + + @knn.setter + def knn(self, value): + self._state.knn = value + + @property + def offset(self): + return self._state.offset + + @offset.setter + def offset(self, value): + self._state.offset = value + + @property + def limit(self): + return self._state.limit + + @limit.setter + def limit(self, value): + self._state.limit = value + + @property + def page_size(self): + return self._state.page_size + + @page_size.setter + 
def page_size(self, value): + self._state.page_size = value + + @property + def sort_fields(self): + return self._state.sort_fields + + @sort_fields.setter + def sort_fields(self, value): + self._state.sort_fields = value + + @property + def projected_fields(self): + return self._state.projected_fields + + @projected_fields.setter + def projected_fields(self, value): + self._state.projected_fields = value - self.return_as_dict = return_as_dict + @property + def nocontent(self): + return self._state.nocontent + + @nocontent.setter + def nocontent(self, value): + self._state.nocontent = value - self._expression = None - self._query: Optional[str] = None - self._pagination: List[str] = [] - self._model_cache: List[RedisModel] = [] + @property + def return_as_dict(self): + return self._state.return_as_dict + + @return_as_dict.setter + def return_as_dict(self, value): + self._state.return_as_dict = value + + @property + def _expression(self): + return self._cache.expression + + @property + def _query(self): + return self._cache.query + + @property + def _pagination(self): + return self._cache.pagination + + @property + def _model_cache(self): + return self._cache.model_cache def dict(self) -> Dict[str, Any]: return dict( @@ -991,22 +1112,22 @@ def copy(self, **kwargs): @property def pagination(self): - if self._pagination: - return self._pagination - self._pagination = self.resolve_redisearch_pagination() - return self._pagination + if self._cache.pagination: + return self._cache.pagination + self._cache.pagination = self.resolve_redisearch_pagination() + return self._cache.pagination @property def expression(self): - if self._expression: - return self._expression + if self._cache.expression: + return self._cache.expression if self.expressions: - self._expression = reduce(operator.and_, self.expressions) + self._cache.expression = reduce(operator.and_, self.expressions) else: - self._expression = Expression( + self._cache.expression = Expression( left=None, right=None, 
op=Operators.ALL, parents=[] ) - return self._expression + return self._cache.expression @property def query(self): @@ -1016,19 +1137,19 @@ def query(self): NOTE: We cache the resolved query string after generating it. This should be OK because all mutations of FindQuery through public APIs return a new FindQuery instance. """ - if self._query: - return self._query - self._query = self._resolve_redisearch_query(self.expression) + if self._cache.query: + return self._cache.query + self._cache.query = self._resolve_redisearch_query(self.expression) if self.knn: # Always wrap the filter expression in parentheses when combining with KNN, # unless it's the wildcard "*". This ensures OR expressions like # "(A)| (B)" become "((A)| (B))=>[KNN ...]" instead of the invalid # "(A)| (B)=>[KNN ...]" where KNN only applies to the second term. - if self._query != "*": - self._query = f"({self._query})" - self._query += f"=>[{self.knn}]" + if self._cache.query != "*": + self._cache.query = f"({self._cache.query})" + self._cache.query += f"=>[{self.knn}]" # RETURN clause should be added to args, not to the query string - return self._query + return self._cache.query def validate_projected_fields(self, projected_fields: List[str]): for field in projected_fields: @@ -1941,7 +2062,7 @@ async def execute( # Reset the cache if we're executing from offset 0. if self.offset == 0: - self._model_cache.clear() + self._cache.model_cache.clear() # If the offset is greater than 0, we're paginating through a result set, # so append the new results to results already in the cache. @@ -1954,14 +2075,14 @@ async def execute( results = await self._parse_execute_results( raw_result, use_full_document_fallback ) - self._model_cache += results + self._cache.model_cache += results if not exhaust_results: - return self._model_cache + return self._cache.model_cache # The query returned all results, so we have no more work to do. 
if count <= len(results): - return self._model_cache + return self._cache.model_cache # Transparently (to the user) make subsequent requests to paginate # through the results and finally return them all. @@ -1973,8 +2094,8 @@ async def execute( _results = await query.execute(exhaust_results=False) if not _results: break - self._model_cache += _results - return self._model_cache + self._cache.model_cache += _results + return self._cache.model_cache async def get_query(self): query = self.copy() @@ -2073,8 +2194,8 @@ async def delete(self): return 0 async def __aiter__(self): - if self._model_cache: - for m in self._model_cache: + if self._cache.model_cache: + for m in self._cache.model_cache: yield m else: for m in await self.execute(): @@ -2099,8 +2220,8 @@ def __getitem__(self, item: int): "Cannot use [] notation with async code. " "Use FindQuery.get_item() instead." ) - if self._model_cache and len(self._model_cache) >= item: - return self._model_cache[item] + if self._cache.model_cache and len(self._cache.model_cache) >= item: + return self._cache.model_cache[item] query = self.copy(offset=item, limit=1) @@ -2123,8 +2244,8 @@ async def get_item(self, item: int): NOTE: This method is included specifically for async users, who cannot use the notation Model.find()[1000]. 
""" - if self._model_cache and len(self._model_cache) >= item: - return self._model_cache[item] + if self._cache.model_cache and len(self._cache.model_cache) >= item: + return self._cache.model_cache[item] query = self.copy(offset=item, limit=1) result = await query.execute() @@ -2186,14 +2307,80 @@ def __init__(self, default: Any = ..., **kwargs: Any) -> None: expire = kwargs.pop("expire", None) separator = kwargs.pop("separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR) super().__init__(default=default, **kwargs) - self.primary_key = primary_key - self.sortable = sortable - self.case_sensitive = case_sensitive - self.index = index - self.full_text_search = full_text_search - self.vector_options = vector_options - self.expire = expire - self.separator = separator + self._redis_options: Dict[str, Any] = { + "primary_key": primary_key, + "sortable": sortable, + "case_sensitive": case_sensitive, + "index": index, + "full_text_search": full_text_search, + "vector_options": vector_options, + "expire": expire, + "separator": separator, + } + + @property + def primary_key(self): + return self._redis_options["primary_key"] + + @primary_key.setter + def primary_key(self, value): + self._redis_options["primary_key"] = value + + @property + def sortable(self): + return self._redis_options["sortable"] + + @sortable.setter + def sortable(self, value): + self._redis_options["sortable"] = value + + @property + def case_sensitive(self): + return self._redis_options["case_sensitive"] + + @case_sensitive.setter + def case_sensitive(self, value): + self._redis_options["case_sensitive"] = value + + @property + def index(self): + return self._redis_options["index"] + + @index.setter + def index(self, value): + self._redis_options["index"] = value + + @property + def full_text_search(self): + return self._redis_options["full_text_search"] + + @full_text_search.setter + def full_text_search(self, value): + self._redis_options["full_text_search"] = value + + @property + def 
vector_options(self): + return self._redis_options["vector_options"] + + @vector_options.setter + def vector_options(self, value): + self._redis_options["vector_options"] = value + + @property + def expire(self): + return self._redis_options["expire"] + + @expire.setter + def expire(self, value): + self._redis_options["expire"] = value + + @property + def separator(self): + return self._redis_options["separator"] + + @separator.setter + def separator(self, value): + self._redis_options["separator"] = value class RelationshipInfo(Representation): @@ -2208,6 +2395,21 @@ def __init__( @dataclasses.dataclass +class VectorFieldParameters: + # Common optional parameters + initial_cap: Optional[int] = None + + # Optional parameters for FLAT + block_size: Optional[int] = None + + # Optional parameters for HNSW + m: Optional[int] = None + ef_construction: Optional[int] = None + ef_runtime: Optional[int] = None + epsilon: Optional[float] = None + + +@dataclasses.dataclass(init=False) class VectorFieldOptions: class ALGORITHM(Enum): FLAT = "FLAT" @@ -2226,18 +2428,98 @@ class DISTANCE_METRIC(Enum): type: TYPE dimension: int distance_metric: DISTANCE_METRIC + params: Optional[VectorFieldParameters] = None - # Common optional parameters - initial_cap: Optional[int] = None + def __init__( + self, + *, + algorithm: ALGORITHM, + type: TYPE, + dimension: int, + distance_metric: DISTANCE_METRIC, + initial_cap: Optional[int] = None, + block_size: Optional[int] = None, + m: Optional[int] = None, + ef_construction: Optional[int] = None, + ef_runtime: Optional[int] = None, + epsilon: Optional[float] = None, + params: Optional[VectorFieldParameters] = None, + ) -> None: + self.algorithm = algorithm + self.type = type + self.dimension = dimension + self.distance_metric = distance_metric + + if params is None: + params = VectorFieldParameters() + + if initial_cap is not None: + params.initial_cap = initial_cap + if block_size is not None: + params.block_size = block_size + if m is not 
None: + params.m = m + if ef_construction is not None: + params.ef_construction = ef_construction + if ef_runtime is not None: + params.ef_runtime = ef_runtime + if epsilon is not None: + params.epsilon = epsilon + + self.params = params + + def _ensure_params(self) -> VectorFieldParameters: + if self.params is None: + self.params = VectorFieldParameters() + return self.params - # Optional parameters for FLAT - block_size: Optional[int] = None + @property + def initial_cap(self) -> Optional[int]: + return self.params.initial_cap if self.params else None - # Optional parameters for HNSW - m: Optional[int] = None - ef_construction: Optional[int] = None - ef_runtime: Optional[int] = None - epsilon: Optional[float] = None + @initial_cap.setter + def initial_cap(self, value: Optional[int]) -> None: + self._ensure_params().initial_cap = value + + @property + def block_size(self) -> Optional[int]: + return self.params.block_size if self.params else None + + @block_size.setter + def block_size(self, value: Optional[int]) -> None: + self._ensure_params().block_size = value + + @property + def m(self) -> Optional[int]: + return self.params.m if self.params else None + + @m.setter + def m(self, value: Optional[int]) -> None: + self._ensure_params().m = value + + @property + def ef_construction(self) -> Optional[int]: + return self.params.ef_construction if self.params else None + + @ef_construction.setter + def ef_construction(self, value: Optional[int]) -> None: + self._ensure_params().ef_construction = value + + @property + def ef_runtime(self) -> Optional[int]: + return self.params.ef_runtime if self.params else None + + @ef_runtime.setter + def ef_runtime(self, value: Optional[int]) -> None: + self._ensure_params().ef_runtime = value + + @property + def epsilon(self) -> Optional[float]: + return self.params.epsilon if self.params else None + + @epsilon.setter + def epsilon(self, value: Optional[float]) -> None: + self._ensure_params().epsilon = value @staticmethod def 
flat( @@ -2252,8 +2534,10 @@ def flat( type=type, dimension=dimension, distance_metric=distance_metric, - initial_cap=initial_cap, - block_size=block_size, + params=VectorFieldParameters( + initial_cap=initial_cap, + block_size=block_size, + ), ) @staticmethod @@ -2272,19 +2556,24 @@ def hnsw( type=type, dimension=dimension, distance_metric=distance_metric, - initial_cap=initial_cap, - m=m, - ef_construction=ef_construction, - ef_runtime=ef_runtime, - epsilon=epsilon, + params=VectorFieldParameters( + initial_cap=initial_cap, + m=m, + ef_construction=ef_construction, + ef_runtime=ef_runtime, + epsilon=epsilon, + ), ) @property def schema(self): attr = [] - for k, v in vars(self).items(): - if k == "algorithm" or v is None: - continue + base_fields = ( + ("type", self.type), + ("dimension", self.dimension), + ("distance_metric", self.distance_metric), + ) + for k, v in base_fields: attr.extend( [ k.upper() if k != "dimension" else "DIM", @@ -2292,6 +2581,19 @@ def schema(self): ] ) + if self.params is None: + return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr) + + for k, v in dataclasses.asdict(self.params).items(): + if v is None: + continue + attr.extend( + [ + k.upper(), + str(v) if not isinstance(v, Enum) else v.name, + ] + ) + return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr) @@ -2412,23 +2714,22 @@ class BaseMeta(Protocol): encoding: str -@dataclasses.dataclass class DefaultMeta: """A default placeholder Meta object. TODO: Revisit whether this is really necessary, and whether making - these all optional here is the right choice. + these all optional here is the right choice. 
""" - global_key_prefix: Optional[str] = None - model_key_prefix: Optional[str] = None - primary_key_pattern: Optional[str] = None - database: Optional[redis.Redis] = None - primary_key: Optional[PrimaryKey] = None - primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None - index_name: Optional[str] = None - embedded: Optional[bool] = False - encoding: str = "utf-8" + global_key_prefix: ClassVar[Optional[str]] = None + model_key_prefix: ClassVar[Optional[str]] = None + primary_key_pattern: ClassVar[Optional[str]] = None + database: ClassVar[Optional[redis.Redis]] = None + primary_key: ClassVar[Optional[PrimaryKey]] = None + primary_key_creator_cls: ClassVar[Optional[Type[PrimaryKeyCreator]]] = None + index_name: ClassVar[Optional[str]] = None + embedded: ClassVar[bool] = False + encoding: ClassVar[str] = "utf-8" class ModelMeta(ModelMetaclass):