From b80c1a659050650edadced3c13a4a0e44e1eaf16 Mon Sep 17 00:00:00 2001 From: Renan Soares Date: Tue, 26 May 2026 17:41:15 -0300 Subject: [PATCH 1/4] refactor: extract helpers to reduce pylint complexity --- .../data/builtin/datetime_migration.py | 149 ++-- aredis_om/model/migrations/data/migrator.py | 64 +- aredis_om/model/model.py | 807 ++++++++++-------- 3 files changed, 579 insertions(+), 441 deletions(-) diff --git a/aredis_om/model/migrations/data/builtin/datetime_migration.py b/aredis_om/model/migrations/data/builtin/datetime_migration.py index e4c0607c..1ff9f05f 100644 --- a/aredis_om/model/migrations/data/builtin/datetime_migration.py +++ b/aredis_om/model/migrations/data/builtin/datetime_migration.py @@ -363,6 +363,79 @@ async def _save_progress_if_needed(self, current_model: str, total_keys: int): stats=self.stats.get_summary(), ) + async def _collect_hash_keys(self, model_class) -> List[str]: + """Collect all Redis hash keys for a model.""" + key_pattern = model_class.make_key("*") + all_keys = [] + scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH") + + async for key in scan_iter: # type: ignore[misc] + if isinstance(key, bytes): + key = key.decode("utf-8") + all_keys.append(key) + + return all_keys + + @staticmethod + def _normalize_hash_data(hash_data: Dict[Any, Any]) -> Dict[str, Any]: + """Normalize hash payload to string keys and values.""" + if hash_data and isinstance(next(iter(hash_data.keys())), bytes): + return { + k.decode("utf-8"): v.decode("utf-8") + for k, v in hash_data.items() + } + + return hash_data + + async def _process_hash_key( + self, + key: str, + datetime_fields: List[str], + model_name: str, + total_keys: int, + ) -> bool: + """Process a single hash key and return whether it was handled.""" + if key in self.processed_keys_set: + return False + + try: + hash_data = await self.redis.hgetall(key) # type: ignore[misc] + except Exception as e: + log.warning(f"Failed to get hash data from {key}: {e}") + return False + + if not hash_data: + return False + + hash_data = self._normalize_hash_data(hash_data) + updates = {} + + for field_name in datetime_fields: + if field_name in hash_data: + value = hash_data[field_name] + converted, success = self._safe_convert_datetime_value( + key, field_name, value + ) + + if success and converted != value: + updates[field_name] = str(converted) + + if updates: + try: + await self.redis.hset(key, mapping=updates) # type: ignore[misc] + except Exception as e: + log.error(f"Failed to update hash {key}: {e}") + if self.failure_mode == ConversionFailureMode.FAIL: + raise DataMigrationError(f"Failed to update hash {key}: {e}") + + self.processed_keys_set.add(key) + self.stats.add_processed_key() + self._processed_keys += 1 + + self._check_error_threshold() + await self._save_progress_if_needed(model_name, total_keys) + return True + async def _clear_progress_on_completion(self): """Clear saved progress when migration completes successfully.""" if self.migration_state: @@ -504,16 +577,7 @@ async def _process_hash_model( self, model_class, datetime_fields: List[str] ) -> None: """Process HashModel instances to convert datetime fields with enhanced error handling.""" - # Get all keys for this model - key_pattern = model_class.make_key("*") - - # Collect all keys first for batch processing - all_keys = [] - scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH") - async for key in scan_iter: # type: ignore[misc] - if isinstance(key, bytes): - key = key.decode("utf-8") - all_keys.append(key) + all_keys = await self._collect_hash_keys(model_class) total_keys = len(all_keys) log.info( @@ -531,64 +595,13 @@ async def _process_hash_model( for key in batch_keys: try: - # Skip if already processed (resume capability) - if key in self.processed_keys_set: - continue - - # Get all fields from the hash - try: - hash_data = await self.redis.hgetall(key) # type: ignore[misc] - except Exception as e: - log.warning(f"Failed to get hash data from {key}: {e}") - continue - - if not hash_data: - continue - - # Convert byte keys/values to strings if needed - if hash_data and isinstance(next(iter(hash_data.keys())), bytes): - hash_data = { - k.decode("utf-8"): v.decode("utf-8") - for k, v in hash_data.items() - } - - updates = {} - - # Check each datetime field with safe conversion - for field_name in datetime_fields: - if field_name in hash_data: - value = hash_data[field_name] - converted, success = self._safe_convert_datetime_value( - key, field_name, value - ) - - if success and converted != value: - updates[field_name] = str(converted) - - # Update the hash if we have changes - if updates: - try: - await self.redis.hset(key, mapping=updates) # type: ignore[misc] - except Exception as e: - log.error(f"Failed to update hash {key}: {e}") - if self.failure_mode == ConversionFailureMode.FAIL: - raise DataMigrationError( - f"Failed to update hash {key}: {e}" - ) - - # Mark key as processed - self.processed_keys_set.add(key) - self.stats.add_processed_key() - self._processed_keys += 1 - processed_count += 1 - - # Error threshold checking - self._check_error_threshold() - - # Save progress periodically - await self._save_progress_if_needed( - model_class.__name__, total_keys - ) + if await self._process_hash_key( + key, + datetime_fields, + model_class.__name__, + total_keys, + ): + processed_count += 1 except DataMigrationError: # Re-raise migration errors diff --git a/aredis_om/model/migrations/data/migrator.py b/aredis_om/model/migrations/data/migrator.py index 23456775..6b9ba1f0 100644 --- a/aredis_om/model/migrations/data/migrator.py +++ b/aredis_om/model/migrations/data/migrator.py @@ -329,21 +329,18 @@ async def run_migrations_with_monitoring( "errors": [], } - if verbose: - print(f"Found {len(pending_migrations)} pending migration(s):") - for migration in pending_migrations: - print(f"- {migration.migration_id}: {migration.description}") + self._log_pending_migrations(pending_migrations, verbose) if dry_run: if verbose: print("Dry run mode - no changes will be applied.") - return { - "applied_count": len(pending_migrations), - "total_migrations": len(pending_migrations), - "performance_stats": monitor.get_stats(), - "errors": [], - "dry_run": True, - } + return self._build_monitoring_result( + applied_count=len(pending_migrations), + pending_migrations=pending_migrations, + monitor=monitor, + errors=[], + dry_run=True, + ) applied_count = 0 errors = [] @@ -400,6 +397,33 @@ async def run_migrations_with_monitoring( monitor.finish() + result = self._build_monitoring_result( + applied_count=applied_count, + pending_migrations=pending_migrations, + monitor=monitor, + errors=errors, + ) + + self._log_monitoring_summary(result, verbose) + + return result + + def _log_pending_migrations( + self, pending_migrations: List[BaseMigration], verbose: bool + ) -> None: + if verbose: + print(f"Found {len(pending_migrations)} pending migration(s):") + for migration in pending_migrations: + print(f"- {migration.migration_id}: {migration.description}") + + def _build_monitoring_result( + self, + applied_count: int, + pending_migrations: List[BaseMigration], + monitor: PerformanceMonitor, + errors: List[Dict[str, Any]], + dry_run: bool = False, + ) -> Dict[str, Any]: result = { "applied_count": applied_count, "total_migrations": len(pending_migrations), @@ -412,18 +436,28 @@ async def run_migrations_with_monitoring( ), } + if dry_run: + result["dry_run"] = True + + return result + + def _log_monitoring_summary( + self, result: Dict[str, Any], verbose: bool + ) -> None: if verbose: - print(f"Applied {applied_count}/{len(pending_migrations)} migration(s).") + total_migrations = result["total_migrations"] + applied_count = result["applied_count"] + print(f"Applied {applied_count}/{total_migrations} migration(s).") stats = result["performance_stats"] if stats: print(f"Total time: {stats.get('total_time_seconds', 0):.2f}s") if "items_per_second" in stats: # type: ignore - print(f"Performance: {stats['items_per_second']:.1f} items/second") # type: ignore + print( + f"Performance: {stats['items_per_second']:.1f} items/second" + ) # type: ignore if "peak_memory_mb" in stats: # type: ignore print(f"Peak memory: {stats['peak_memory_mb']:.1f} MB") # type: ignore - return result - async def rollback_migration( self, migration_id: str, dry_run: bool = False, verbose: bool = False ) -> bool: diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 07f6de14..4cf8b0a2 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1541,6 +1541,161 @@ def expand_tag_value(value): ) return escaper.escape(str(value)) + @staticmethod + def _convert_numeric_value(value): + if isinstance(value, Enum): + value = value.value + if isinstance(value, (datetime.datetime, datetime.date)): + if isinstance(value, datetime.date) and not isinstance( + value, datetime.datetime + ): + value = datetime.datetime.combine( + value, datetime.time.min, tzinfo=datetime.timezone.utc + ) + value = value.timestamp() + return value + + @classmethod + def _resolve_text_value(cls, field_name: str, op: Operators, value: Any) -> str: + result = f"@{field_name}_fts:" + if op is Operators.EQ: + return f'{result}"{value}"' + if op is Operators.NE: + return f'-({result}"{value}")' + if op is Operators.LIKE: + return f"{result}{value}" + raise QueryNotSupportedError( + "Only equals (=), not-equals (!=), and like() " + "comparisons are supported for TEXT fields. " + f"Docs: {ERRORS_URL}#E5" + ) + + @classmethod + def _resolve_numeric_value( + cls, field_name: str, op: Operators, value: Any + ) -> str: + if op is Operators.IN: + converted_values = [cls._convert_numeric_value(v) for v in value] + parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] + return "|".join(parts) + if op is Operators.NOT_IN: + converted_values = [cls._convert_numeric_value(v) for v in value] + parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] + return f"-({' | '.join(parts)})" + + value = cls._convert_numeric_value(value) + if op is Operators.EQ: + return f"@{field_name}:[{value} {value}]" + if op is Operators.NE: + return f"-(@{field_name}:[{value} {value}])" + if op is Operators.GT: + return f"@{field_name}:[({value} +inf]" + if op is Operators.LT: + return f"@{field_name}:[-inf ({value}]" + if op is Operators.GE: + return f"@{field_name}:[{value} +inf]" + if op is Operators.LE: + return f"@{field_name}:[-inf {value}]" + return "" + + @classmethod + def _resolve_tag_eq_value( + cls, + field_name: str, + field_info: PydanticFieldInfo, + value: Any, + model_class: Optional[Type["RedisModel"]], + ) -> str: + separator_char = getattr( + field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR + ) + if value == separator_char: + log.warning( + "Your query against the field %s is for a single character, %s, " + "that is used internally by Redis OM Python. We must ignore " + "this portion of the query. Please review your query to find " + "an alternative query that uses a string containing more than " + "just the character %s.", + field_name, + separator_char, + separator_char, + ) + return "" + if isinstance(value, bool): + if model_class: + is_hash_model = any( + base.__name__ == "HashModel" for base in model_class.__mro__ + ) + value = ("1" if value else "0") if is_hash_model else value + return "@{field_name}:{{{value}}}".format( + field_name=field_name, value=value + ) + if isinstance(value, int): + return f"@{field_name}:[{value} {value}]" + if separator_char in value: + values: filter = filter(None, value.split(separator_char)) + return "".join( + "@{field_name}:{{{value}}}".format( + field_name=field_name, value=escaper.escape(value) + ) + for value in values + ) + value = escaper.escape(value) + return "@{field_name}:{{{value}}}".format(field_name=field_name, value=value) + + @classmethod + def _resolve_tag_value( + cls, + field_name: str, + field_info: PydanticFieldInfo, + op: Operators, + value: Any, + model_class: Optional[Type["RedisModel"]], + ) -> str: + if op is Operators.EQ: + return cls._resolve_tag_eq_value(field_name, field_info, value, model_class) + if op is Operators.NE: + value = escaper.escape(value) + return "-(@{field_name}:{{{value}}})".format( + field_name=field_name, value=value + ) + if op is Operators.IN: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.NOT_IN: + expanded_value = cls.expand_tag_value(value) + return "-(@{field_name}:{{{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.STARTSWITH: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{{expanded_value}*}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.ENDSWITH: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{*{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.CONTAINS: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{*{expanded_value}*}})".format( + field_name=field_name, expanded_value=expanded_value + ) + return "" + + @staticmethod + def _resolve_geo_value(field_name: str, op: Operators, value: Any) -> str: + if not isinstance(value, GeoFilter): + raise QuerySyntaxError( + "You can only use a GeoFilter object with a GEO field." + ) + if op is Operators.EQ: + return f"@{field_name}:[{value}]" + return "" + @classmethod def resolve_value( cls, @@ -1553,163 +1708,22 @@ def resolve_value( model_class: Optional[Type["RedisModel"]] = None, ) -> str: # The 'field_name' should already include the correct prefix - result = "" if parents: prefix = "_".join([p[0] for p in parents]) field_name = f"{prefix}_{field_name}" if field_type is RediSearchFieldTypes.TEXT: - result = f"@{field_name}_fts:" - if op is Operators.EQ: - result += f'"{value}"' - elif op is Operators.NE: - result = f'-({result}"{value}")' - elif op is Operators.LIKE: - result += value - else: - raise QueryNotSupportedError( - "Only equals (=), not-equals (!=), and like() " - "comparisons are supported for TEXT fields. " - f"Docs: {ERRORS_URL}#E5" - ) + return cls._resolve_text_value(field_name, op, value) elif field_type is RediSearchFieldTypes.NUMERIC: - # Helper to convert a single value for NUMERIC queries - def convert_numeric_value(v): - # Convert Enum to its value (fixes #108) - if isinstance(v, Enum): - v = v.value - # Convert datetime objects to timestamps - if isinstance(v, (datetime.datetime, datetime.date)): - if isinstance(v, datetime.date) and not isinstance( - v, datetime.datetime - ): - # Use UTC midnight so query conversion matches storage conversion. - v = datetime.datetime.combine( - v, datetime.time.min, tzinfo=datetime.timezone.utc - ) - v = v.timestamp() - return v - - if op is Operators.IN: - # Handle IN operator for NUMERIC fields (fixes #499) - # Convert each value and create OR of range queries - converted_values = [convert_numeric_value(v) for v in value] - parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] - result += "|".join(parts) - elif op is Operators.NOT_IN: - # Handle NOT_IN operator for NUMERIC fields - converted_values = [convert_numeric_value(v) for v in value] - parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] - result += f"-({' | '.join(parts)})" - else: - value = convert_numeric_value(value) - - if op is Operators.EQ: - result += f"@{field_name}:[{value} {value}]" - elif op is Operators.NE: - result += f"-(@{field_name}:[{value} {value}])" - elif op is Operators.GT: - result += f"@{field_name}:[({value} +inf]" - elif op is Operators.LT: - result += f"@{field_name}:[-inf ({value}]" - elif op is Operators.GE: - result += f"@{field_name}:[{value} +inf]" - elif op is Operators.LE: - result += f"@{field_name}:[-inf {value}]" + return cls._resolve_numeric_value(field_name, op, value) # TODO: How will we know the difference between a multi-value use of a TAG # field and our hidden use of TAG for exact-match queries? elif field_type is RediSearchFieldTypes.TAG: - if op is Operators.EQ: - separator_char = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR - ) - if value == separator_char: - # The value is ONLY the TAG field separator character -- - # this is not going to work. - log.warning( - "Your query against the field %s is for a single character, %s, " - "that is used internally by Redis OM Python. We must ignore " - "this portion of the query. Please review your query to find " - "an alternative query that uses a string containing more than " - "just the character %s.", - field_name, - separator_char, - separator_char, - ) - return "" - if isinstance(value, bool): - # For HashModel, convert boolean to "1"/"0" to match storage format - # For JsonModel, keep as boolean since JSON supports native booleans - if model_class: - # Check if this is a HashModel by checking the class hierarchy - is_hash_model = any( - base.__name__ == "HashModel" for base in model_class.__mro__ - ) - bool_value = ("1" if value else "0") if is_hash_model else value - else: - bool_value = value - result = "@{field_name}:{{{value}}}".format( - field_name=field_name, value=bool_value - ) - elif isinstance(value, int): - # This if will hit only if the field is a primary key of type int - result = f"@{field_name}:[{value} {value}]" - elif separator_char in value: - # The value contains the TAG field separator. We can work - # around this by breaking apart the values and unioning them - # with multiple field:{} queries. - values: filter = filter(None, value.split(separator_char)) - for value in values: - value = escaper.escape(value) - result += "@{field_name}:{{{value}}}".format( - field_name=field_name, value=value - ) - else: - value = escaper.escape(value) - result += "@{field_name}:{{{value}}}".format( - field_name=field_name, value=value - ) - elif op is Operators.NE: - value = escaper.escape(value) - result += "-(@{field_name}:{{{value}}})".format( - field_name=field_name, value=value - ) - elif op is Operators.IN: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.NOT_IN: - # TODO: Implement NOT_IN, test this... - expanded_value = cls.expand_tag_value(value) - result += "-(@{field_name}:{{{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.STARTSWITH: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{{expanded_value}*}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.ENDSWITH: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{*{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.CONTAINS: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{*{expanded_value}*}})".format( - field_name=field_name, expanded_value=expanded_value - ) + return cls._resolve_tag_value(field_name, field_info, op, value, model_class) elif field_type is RediSearchFieldTypes.GEO: - if not isinstance(value, GeoFilter): - raise QuerySyntaxError( - "You can only use a GeoFilter object with a GEO field." - ) + return cls._resolve_geo_value(field_name, op, value) - if op is Operators.EQ: - result += f"@{field_name}:[{value}]" - - return result + return "" def resolve_redisearch_pagination(self): """Resolve pagination options for a query.""" @@ -1835,9 +1849,7 @@ def _resolve_redisearch_query(self, expression: ExpressionOrNegated) -> str: return result - async def execute( - self, exhaust_results=True, return_raw_result=False, return_query_args=False - ): + def _build_execute_args(self) -> tuple[List[Union[str, bytes]], bool]: args: List[Union[str, bytes]] = [ "FT.SEARCH", self.model.Meta.index_name, @@ -1862,29 +1874,22 @@ async def execute( if self.nocontent: args.append("NOCONTENT") - # Check if we have complex fields that RediSearch RETURN clause can't handle - use_full_document_fallback = False - if self.projected_fields: - use_full_document_fallback = self._has_complex_projected_fields() + use_full_document_fallback = bool( + self.projected_fields and self._has_complex_projected_fields() + ) - # Add RETURN clause to the args list, not to the query string - # Skip RETURN clause if we need full documents for complex field projection + # Add RETURN clause to the args list, not to the query string. + # Skip RETURN clause if we need full documents for complex field projection. if self.projected_fields and not use_full_document_fallback: args.extend( ["RETURN", str(len(self.projected_fields))] + self.projected_fields ) - if return_query_args: - return self.model.Meta.index_name, args - - # Reset the cache if we're executing from offset 0. - if self.offset == 0: - self._model_cache.clear() + return args, use_full_document_fallback - # If the offset is greater than 0, we're paginating through a result set, - # so append the new results to results already in the cache. + async def _execute_command(self, args: List[Union[str, bytes]]): try: - raw_result = await self.model.db().execute_command(*args) + return await self.model.db().execute_command(*args) except Exception as e: error_msg = str(e).lower() @@ -1900,32 +1905,55 @@ async def execute( # Re-raise the original exception raise - if return_raw_result: - return raw_result - count = raw_result[0] - # Handle different result processing based on what was requested + async def _parse_execute_results( + self, raw_result: Any, use_full_document_fallback: bool + ): if self.projected_fields and use_full_document_fallback: # Complex field projection - use full document fallback if self.return_as_dict: - results = await self._parse_full_document_projection_as_dict(raw_result) - else: - results = await self._parse_full_document_projection_as_models( - raw_result - ) - elif self.projected_fields and self.return_as_dict: + return await self._parse_full_document_projection_as_dict(raw_result) + return await self._parse_full_document_projection_as_models(raw_result) + + if self.projected_fields and self.return_as_dict: # .values('field1', 'field2') - specific fields as dicts - results = self._parse_projected_results(raw_result) - elif self.projected_fields and not self.return_as_dict: + return self._parse_projected_results(raw_result) + + if self.projected_fields and not self.return_as_dict: # .only('field1', 'field2') - partial model instances - results = self._parse_projected_models(raw_result) - elif self.return_as_dict and not self.projected_fields: + return self._parse_projected_models(raw_result) + + if self.return_as_dict and not self.projected_fields: # .values() - all fields as dicts model_results = self.model.from_redis(raw_result, self.knn) - results = [model.model_dump() for model in model_results] - else: - # Normal query - full model instances - results = self.model.from_redis(raw_result, self.knn) + return [model.model_dump() for model in model_results] + + # Normal query - full model instances + return self.model.from_redis(raw_result, self.knn) + + async def execute( + self, exhaust_results=True, return_raw_result=False, return_query_args=False + ): + args, use_full_document_fallback = self._build_execute_args() + + if return_query_args: + return self.model.Meta.index_name, args + + # Reset the cache if we're executing from offset 0. + if self.offset == 0: + self._model_cache.clear() + + # If the offset is greater than 0, we're paginating through a result set, + # so append the new results to results already in the cache. + raw_result = await self._execute_command(args) + + if return_raw_result: + return raw_result + count = raw_result[0] + + results = await self._parse_execute_results( + raw_result, use_full_document_fallback + ) self._model_cache += results if not exhaust_results: @@ -2409,44 +2437,30 @@ class ModelMeta(ModelMetaclass): model_config: RedisOmConfig model_fields: Dict[str, FieldInfo] # type: ignore[assignment] - def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 - meta = attrs.pop("Meta", None) - - # Capture original FieldInfo objects from attrs before Pydantic processes them. - # Pydantic 2.12+ may convert custom FieldInfo subclasses to plain PydanticFieldInfo - # for Annotated types, losing custom attributes like index, sortable, etc. + @staticmethod + def _capture_original_field_infos(attrs): original_field_infos: Dict[str, FieldInfo] = {} if PYDANTIC_V2: for attr_name, attr_value in attrs.items(): if isinstance(attr_value, FieldInfo): original_field_infos[attr_name] = attr_value + return original_field_infos - # Duplicate logic from Pydantic to filter config kwargs because if they are - # passed directly including the registry Pydantic will pass them over to the - # superclass causing an error - allowed_config_kwargs: Set[str] = { + @staticmethod + def _allowed_config_kwargs(): + return { key for key in dir(ConfigDict) - if not ( - key.startswith("__") and key.endswith("__") - ) # skip dunder methods and attributes + if not (key.startswith("__") and key.endswith("__")) } - config_kwargs = { - key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs - } - - new_class: RedisModel = super().__new__( - cls, name, bases, attrs, **config_kwargs - ) - + @staticmethod + def _build_meta(new_class, meta, base_meta): # The fact that there is a Meta field and _meta field is important: a # user may have given us a Meta object with their configuration, while # we might have inherited _meta from a parent class, and should # therefore use some of the inherited fields. meta = meta or getattr(new_class, "Meta", None) - base_meta = getattr(new_class, "_meta", None) - if meta and meta != DefaultMeta and meta != base_meta: new_class.Meta = meta new_class._meta = meta @@ -2465,13 +2479,8 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 ) new_class.Meta = new_class._meta - is_indexed = kwargs.get("index", None) is True - - if is_indexed and new_class.model_config.get("index", None) is True: - raise RedisModelError( - f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" - ) - + @staticmethod + def _set_index_config(new_class, is_indexed): if PYDANTIC_V2: new_class.model_config["index"] = is_indexed else: @@ -2485,14 +2494,20 @@ class Config: new_class.Config = Config - # Create proxies for each model field so that we can use the field - # in queries, like Model.get(Model.field_name == 1) - # Only set if the model is has index=True + @staticmethod + def _model_fields(new_class): if PYDANTIC_V2: - model_fields = new_class.model_fields - else: - model_fields = new_class.__fields__ + return new_class.model_fields + return new_class.__fields__ + @staticmethod + def _apply_field_proxies( + new_class, + model_fields, + original_field_infos, + is_indexed, + ): + custom_pk_count = 0 for field_name, field in model_fields.items(): pydantic_field = field # Keep reference to Pydantic's processed field if type(field) is PydanticFieldInfo: @@ -2541,7 +2556,6 @@ class Config: # should set up the PrimaryKeyAccessor. We only do this when there's # exactly one custom primary key. Multiple custom primary keys will be # caught by validate_primary_key() later. - custom_pk_count = 0 for field_name, field in model_fields.items(): if field_name == "pk": continue @@ -2557,6 +2571,10 @@ class Config: if getattr(check_field, "primary_key", None) is True: custom_pk_count += 1 + return custom_pk_count + + @staticmethod + def _setup_primary_key_accessor(new_class, model_fields, custom_pk_count): # If there's exactly one custom primary key (not the default 'pk'), set up # a PrimaryKeyAccessor so that .pk always returns the correct value. # This fixes GitHub issue #570. @@ -2572,6 +2590,8 @@ class Config: # Set up PrimaryKeyAccessor descriptor for .pk access setattr(new_class, "pk", PrimaryKeyAccessor()) + @staticmethod + def _apply_embedded_model_rules(new_class): # For embedded models, clear the primary_key from meta since they don't # need primary keys - they're stored as part of their parent document, # not as separate Redis keys. This fixes GitHub issue #496. @@ -2580,6 +2600,8 @@ class Config: if getattr(new_class._meta, "embedded", False): new_class._meta.primary_key = None + @staticmethod + def _apply_meta_defaults(new_class, base_meta): if not getattr(new_class._meta, "global_key_prefix", None): new_class._meta.global_key_prefix = getattr( base_meta, "global_key_prefix", "" @@ -2610,6 +2632,8 @@ class Config: f"{new_class._meta.model_key_prefix}:index" ) + @staticmethod + def _register_indexed_model(new_class, bases): # Model is indexed and not an abstract model class or embedded model, so we should let the # Migrator create indexes for it. if ( @@ -2620,6 +2644,52 @@ class Config: key = f"{new_class.__module__}.{new_class.__qualname__}" model_registry[key] = new_class + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 + meta = attrs.pop("Meta", None) + + # Capture original FieldInfo objects from attrs before Pydantic processes them. + # Pydantic 2.12+ may convert custom FieldInfo subclasses to plain PydanticFieldInfo + # for Annotated types, losing custom attributes like index, sortable, etc. + original_field_infos = cls._capture_original_field_infos(attrs) + + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs = cls._allowed_config_kwargs() + + config_kwargs = {key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs} + + new_class: RedisModel = super().__new__( + cls, name, bases, attrs, **config_kwargs + ) + + base_meta = getattr(new_class, "_meta", None) + cls._build_meta(new_class, meta, base_meta) + + is_indexed = kwargs.get("index", None) is True + + if is_indexed and new_class.model_config.get("index", None) is True: + raise RedisModelError( + f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" + ) + + cls._set_index_config(new_class, is_indexed) + + # Create proxies for each model field so that we can use the field + # in queries, like Model.get(Model.field_name == 1) + # Only set if the model is has index=True + model_fields = cls._model_fields(new_class) + custom_pk_count = cls._apply_field_proxies( + new_class, + model_fields, + original_field_infos, + is_indexed, + ) + cls._setup_primary_key_accessor(new_class, model_fields, custom_pk_count) + cls._apply_embedded_model_rules(new_class) + cls._apply_meta_defaults(new_class, base_meta) + cls._register_indexed_model(new_class, bases) + return new_class @@ -3713,161 +3783,182 @@ def schema_for_type( # into the values of the list or the fields of the model to # find any values marked as indexed. if is_container_type and not is_vector: - field_type = get_origin(typ) - if field_type == Literal: - path = f"{json_path}.{name}" - return cls.schema_for_type( - path, - name, - name_prefix, - str, - field_info, - parent_type=field_type, - ) - else: - embedded_cls = get_args(typ) - if not embedded_cls: - log.warning( - "Model %s defined an empty list or tuple field: %s", cls, name - ) - return "" - path = f"{json_path}.{name}[*]" - embedded_cls = embedded_cls[0] - return cls.schema_for_type( - path, - name, - name_prefix, - embedded_cls, - field_info, - parent_type=field_type, - ) + return cls._schema_for_container_type( + json_path, name, name_prefix, typ, field_info + ) elif field_is_model: - name_prefix = f"{name_prefix}_{name}" if name_prefix else name - sub_fields = [] - for embedded_name, field in typ.model_fields.items(): - if ( - hasattr(field, "metadata") - and len(field.metadata) > 0 - and isinstance(field.metadata[0], FieldInfo) - ): - field_info = field.metadata[0] - else: - field_info = field - - if parent_is_container_type: - # We'll store this value either as a JavaScript array, so - # the correct JSONPath expression is to refer directly to - # attribute names after the container notation, e.g. - # orders[*].created_date. - path = json_path - else: - # All other fields should use dot notation with both the - # current field name and "embedded" field name, e.g., - # order.address.street_line_1. - path = f"{json_path}.{name}" - sub_fields.append( - cls.schema_for_type( - path, - embedded_name, - name_prefix, - # field.annotation, - get_outer_type(field), - field_info, - parent_type=typ, - ) - ) - return " ".join(filter(None, sub_fields)) + return cls._schema_for_embedded_model( + json_path, name, name_prefix, typ, parent_is_container_type + ) # NOTE: This is the termination point for recursion. We've descended # into models and lists until we found an actual value to index. elif should_index: - index_field_name = f"{name_prefix}_{name}" if name_prefix else name - if parent_is_container_type: - # If we're indexing the this field as a JavaScript array, then - # the currently built-up JSONPath expression will be - # "field_name[*]", which is what we want to use. - path = json_path + return cls._schema_for_leaf_type( + json_path, + name, + name_prefix, + typ, + field_info, + is_vector, + vector_options, + parent_is_container_type, + parent_is_model_in_container, + ) + return "" + + @classmethod + def _schema_for_container_type( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Any, + field_info: PydanticFieldInfo, + ) -> str: + field_type = get_origin(typ) + if field_type == Literal: + path = f"{json_path}.{name}" + return cls.schema_for_type( + path, + name, + name_prefix, + str, + field_info, + parent_type=field_type, + ) + + embedded_cls = get_args(typ) + if not embedded_cls: + log.warning("Model %s defined an empty list or tuple field: %s", cls, name) + return "" + + path = f"{json_path}.{name}[*]" + return cls.schema_for_type( + path, + name, + name_prefix, + embedded_cls[0], + field_info, + parent_type=field_type, + ) + + @classmethod + def _schema_for_embedded_model( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Union[Type[RedisModel], Any], + parent_is_container_type: bool, + ) -> str: + name_prefix = f"{name_prefix}_{name}" if name_prefix else name + sub_fields = [] + for embedded_name, field in typ.model_fields.items(): + if ( + hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field_info = field.metadata[0] else: - path = f"{json_path}.{name}" - sortable = getattr(field_info, "sortable", False) - case_sensitive = getattr(field_info, "case_sensitive", False) - full_text_search = getattr(field_info, "full_text_search", False) - - # For more complicated compound validators (e.g. PositiveInt), we might get a _GenericAlias rather than - # a proper type, we can pull the type information from the origin of the first argument. - if not isinstance(typ, type): - type_args = typing_get_args(field_info.annotation) - typ = ( - getattr(type_args[0], "__origin__", type_args[0]) - if type_args - else typ - ) + field_info = field - # Get separator from field_info, defaulting to pipe - separator = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR + # Lists store embedded values directly under the container path. + path = json_path if parent_is_container_type else f"{json_path}.{name}" + sub_fields.append( + cls.schema_for_type( + path, + embedded_name, + name_prefix, + get_outer_type(field), + field_info, + parent_type=typ, + ) ) + return " ".join(filter(None, sub_fields)) - if is_vector and vector_options: - schema = f"{path} AS {index_field_name} {vector_options.schema}" - elif parent_is_container_type or parent_is_model_in_container: - if typ is not str: - raise RedisModelError( - "List and tuple fields can only contain strings. " - f"Problem field: {name}. Docs: {ERRORS_URL}#E12" - ) - if full_text_search is True: - raise RedisModelError( - "List and tuple fields cannot be indexed for full-text " - f"search. Problem field: {name}. Docs: {ERRORS_URL}#E13" - ) - # List/tuple fields are indexed as TAG fields and can be sortable - schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + @classmethod + def _schema_for_leaf_type( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Union[Type[RedisModel], Any], + field_info: PydanticFieldInfo, + is_vector: bool, + vector_options: Optional[VectorFieldOptions], + parent_is_container_type: bool, + parent_is_model_in_container: bool, + ) -> str: + index_field_name = f"{name_prefix}_{name}" if name_prefix else name + path = json_path if parent_is_container_type else f"{json_path}.{name}" + sortable = getattr(field_info, "sortable", False) + case_sensitive = getattr(field_info, "case_sensitive", False) + full_text_search = getattr(field_info, "full_text_search", False) + + # For more complicated compound validators (e.g. PositiveInt), we might + # get a _GenericAlias rather than a proper type. + if not isinstance(typ, type): + type_args = typing_get_args(field_info.annotation) + typ = getattr(type_args[0], "__origin__", type_args[0]) if type_args else typ + + separator = getattr(field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR) + + if is_vector and vector_options: + schema = f"{path} AS {index_field_name} {vector_options.schema}" + elif parent_is_container_type or parent_is_model_in_container: + if typ is not str: + raise RedisModelError( + "List and tuple fields can only contain strings. " + f"Problem field: {name}. Docs: {ERRORS_URL}#E12" + ) + if full_text_search is True: + raise RedisModelError( + "List and tuple fields cannot be indexed for full-text " + f"search. Problem field: {name}. Docs: {ERRORS_URL}#E13" + ) + schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + if sortable is True: + schema += " SORTABLE" + if case_sensitive is True: + schema += " CASESENSITIVE" + elif typ is bool: + schema = f"{path} AS {index_field_name} TAG" + if sortable is True: + schema += " SORTABLE" + elif typ in [CoordinateType, Coordinates]: + schema = f"{path} AS {index_field_name} GEO" + if sortable is True: + schema += " SORTABLE" + elif is_numeric_type(typ): + schema = f"{path} AS {index_field_name} NUMERIC" + if sortable is True: + schema += " SORTABLE" + elif issubclass(typ, str): + if full_text_search is True: + schema = ( + f"{path} AS {index_field_name} TAG SEPARATOR {separator} " + f"{path} AS {index_field_name}_fts TEXT" + ) if sortable is True: + # NOTE: With the current preview release, making a field + # full-text searchable and sortable only makes the TEXT + # field sortable. schema += " SORTABLE" if case_sensitive is True: - schema += " CASESENSITIVE" - elif typ is bool: - schema = f"{path} AS {index_field_name} TAG" - if sortable is True: - schema += " SORTABLE" - elif typ in [CoordinateType, Coordinates]: - schema = f"{path} AS {index_field_name} GEO" - if sortable is True: - schema += " SORTABLE" - elif is_numeric_type(typ): - schema = f"{path} AS {index_field_name} NUMERIC" - if sortable is True: - schema += " SORTABLE" - elif issubclass(typ, str): - if full_text_search is True: - schema = ( - f"{path} AS {index_field_name} TAG SEPARATOR {separator} " - f"{path} AS {index_field_name}_fts TEXT" - ) - if sortable is True: - # NOTE: With the current preview release, making a field - # full-text searchable and sortable only makes the TEXT - # field sortable. This means that results for full-text - # search queries can be sorted, but not exact match - # queries. - schema += " SORTABLE" - if case_sensitive is True: - raise RedisModelError("Text fields cannot be case-sensitive.") - else: - # String fields are indexed as TAG fields and can be sortable - schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" - if sortable is True: - schema += " SORTABLE" - if case_sensitive is True: - schema += " CASESENSITIVE" + raise RedisModelError("Text fields cannot be case-sensitive.") else: - # Default to TAG field, which can be sortable schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" if sortable is True: schema += " SORTABLE" + if case_sensitive is True: + schema += " CASESENSITIVE" + else: + schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + if sortable is True: + schema += " SORTABLE" - return schema - return "" + return schema class EmbeddedJsonModel(JsonModel, abc.ABC): From 1c3a4b211378ef014078a034cf551a975d502fb6 Mon Sep 17 00:00:00 2001 From: Renan Soares Date: Fri, 29 May 2026 12:41:07 -0300 Subject: [PATCH 2/4] refactor: reduce instance attributes in migration classes --- .../data/builtin/datetime_migration.py | 8 - aredis_om/model/model.py | 469 ++++++++++++++---- 2 files changed, 385 insertions(+), 92 deletions(-) diff --git a/aredis_om/model/migrations/data/builtin/datetime_migration.py b/aredis_om/model/migrations/data/builtin/datetime_migration.py index 1ff9f05f..f0e02ef0 100644 --- a/aredis_om/model/migrations/data/builtin/datetime_migration.py +++ b/aredis_om/model/migrations/data/builtin/datetime_migration.py @@ -247,7 +247,6 @@ def __init__( self.failure_mode = failure_mode self.batch_size = batch_size self.max_errors = max_errors - self.enable_resume = enable_resume self.progress_save_interval = progress_save_interval self.stats = MigrationStats() self.migration_state = ( @@ -255,10 +254,6 @@ def __init__( ) self.processed_keys_set: Set[str] = set() - # Legacy compatibility - self._processed_keys = 0 - self._converted_fields = 0 - def _safe_convert_datetime_value( self, key: str, field_name: str, value: Any ) -> Tuple[Any, bool]: @@ -332,7 +327,6 @@ async def _load_previous_progress(self) -> bool: if progress["processed_keys"]: self.processed_keys_set = set(progress["processed_keys"]) - self._processed_keys = len(self.processed_keys_set) # Restore stats if available if progress.get("stats"): @@ -430,7 +424,6 @@ async def _process_hash_key( self.processed_keys_set.add(key) self.stats.add_processed_key() - self._processed_keys += 1 self._check_error_threshold() await self._save_progress_if_needed(model_name, total_keys) @@ -690,7 +683,6 @@ async def _process_json_model( # Mark key as processed self.processed_keys_set.add(key) self.stats.add_processed_key() - self._processed_keys += 1 processed_count += 1 # Error threshold checking diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 4cf8b0a2..0a8fb2a8 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -11,6 +11,7 @@ from typing import ( Any, Callable, + ClassVar, Dict, List, Literal, @@ -923,6 +924,28 @@ class RediSearchFieldTypes(Enum): DEFAULT_PAGE_SIZE = 1000 +@dataclasses.dataclass +class _FindQueryState: + expressions: Sequence[ExpressionOrNegated] + model: Type["RedisModel"] + knn: Optional[KNNExpression] = None + offset: int = 0 + limit: Optional[int] = None + page_size: int = DEFAULT_PAGE_SIZE + sort_fields: List[str] = dataclasses.field(default_factory=list) + projected_fields: List[str] = dataclasses.field(default_factory=list) + nocontent: bool = False + return_as_dict: bool = False + + +@dataclasses.dataclass +class _FindQueryCache: + expression: Optional[ExpressionOrNegated] = None + query: Optional[str] = None + pagination: List[str] = dataclasses.field(default_factory=list) + model_cache: List["RedisModel"] = dataclasses.field(default_factory=list) + + class FindQuery: def __init__( self, @@ -944,32 +967,130 @@ def __init__( "instance has one of these modules installed." ) - self.expressions = expressions - self.model = model - self.knn = knn - self.offset = offset - self.limit = limit or (self.knn.k if self.knn else DEFAULT_PAGE_SIZE) - self.page_size = page_size - self.nocontent = nocontent + self._state = _FindQueryState( + expressions=expressions, + model=model, + knn=knn, + offset=offset, + limit=limit if limit is not None else (knn.k if knn else DEFAULT_PAGE_SIZE), + page_size=page_size, + nocontent=nocontent, + ) + self._cache = _FindQueryCache() if sort_fields: - self.sort_fields = self.validate_sort_fields(sort_fields) - elif self.knn: - self.sort_fields = [self.knn.score_field_name] + resolved_sort_fields = self.validate_sort_fields(sort_fields) + elif knn: + resolved_sort_fields = [knn.score_field_name] else: - self.sort_fields = [] + resolved_sort_fields = [] if projected_fields: - self.projected_fields = self.validate_projected_fields(projected_fields) + resolved_projected_fields = self.validate_projected_fields( + projected_fields + ) else: - self.projected_fields = [] + resolved_projected_fields = [] + + self._state.sort_fields = resolved_sort_fields + self._state.projected_fields = resolved_projected_fields + self._state.return_as_dict = return_as_dict + + @property + def expressions(self): + return self._state.expressions + + @expressions.setter + def expressions(self, value): + self._state.expressions = value + + @property + def model(self): + return self._state.model + + @model.setter + def model(self, value): + self._state.model = value + + @property + def knn(self): + return self._state.knn + + @knn.setter + def knn(self, value): + self._state.knn = value + + @property + def offset(self): + return self._state.offset + + @offset.setter + def offset(self, value): + self._state.offset = value + + @property + def limit(self): + return self._state.limit + + @limit.setter + def limit(self, value): + self._state.limit = value + + @property + def page_size(self): + return self._state.page_size + + @page_size.setter + def page_size(self, value): + self._state.page_size = value + + @property + def sort_fields(self): + return self._state.sort_fields + + @sort_fields.setter + def sort_fields(self, value): + self._state.sort_fields = value + + @property + def projected_fields(self): + return self._state.projected_fields + + @projected_fields.setter + def projected_fields(self, value): + self._state.projected_fields = value - self.return_as_dict = return_as_dict + @property + def nocontent(self): + return self._state.nocontent + + @nocontent.setter + def nocontent(self, value): + self._state.nocontent = value - self._expression = None - self._query: Optional[str] = None - self._pagination: List[str] = [] - self._model_cache: List[RedisModel] = [] + @property + def return_as_dict(self): + return self._state.return_as_dict + + @return_as_dict.setter + def return_as_dict(self, value): + self._state.return_as_dict = value + + @property + def _expression(self): + return self._cache.expression + + @property + def _query(self): + return self._cache.query + + @property + def _pagination(self): + return self._cache.pagination + + @property + def _model_cache(self): + return self._cache.model_cache def dict(self) -> Dict[str, Any]: return dict( @@ -991,22 +1112,22 @@ def copy(self, **kwargs): @property def pagination(self): - if self._pagination: - return self._pagination - self._pagination = self.resolve_redisearch_pagination() - return self._pagination + if self._cache.pagination: + return self._cache.pagination + self._cache.pagination = self.resolve_redisearch_pagination() + return self._cache.pagination @property def expression(self): - if self._expression: - return self._expression + if self._cache.expression: + return self._cache.expression if self.expressions: - self._expression = reduce(operator.and_, self.expressions) + self._cache.expression = reduce(operator.and_, self.expressions) else: - self._expression = Expression( + self._cache.expression = Expression( left=None, right=None, op=Operators.ALL, parents=[] ) - return self._expression + return self._cache.expression @property def query(self): @@ -1016,19 +1137,19 @@ def query(self): NOTE: We cache the resolved query string after generating it. This should be OK because all mutations of FindQuery through public APIs return a new FindQuery instance. """ - if self._query: - return self._query - self._query = self._resolve_redisearch_query(self.expression) + if self._cache.query: + return self._cache.query + self._cache.query = self._resolve_redisearch_query(self.expression) if self.knn: # Always wrap the filter expression in parentheses when combining with KNN, # unless it's the wildcard "*". This ensures OR expressions like # "(A)| (B)" become "((A)| (B))=>[KNN ...]" instead of the invalid # "(A)| (B)=>[KNN ...]" where KNN only applies to the second term. - if self._query != "*": - self._query = f"({self._query})" - self._query += f"=>[{self.knn}]" + if self._cache.query != "*": + self._cache.query = f"({self._cache.query})" + self._cache.query += f"=>[{self.knn}]" # RETURN clause should be added to args, not to the query string - return self._query + return self._cache.query def validate_projected_fields(self, projected_fields: List[str]): for field in projected_fields: @@ -1941,7 +2062,7 @@ async def execute( # Reset the cache if we're executing from offset 0. if self.offset == 0: - self._model_cache.clear() + self._cache.model_cache.clear() # If the offset is greater than 0, we're paginating through a result set, # so append the new results to results already in the cache. @@ -1954,14 +2075,14 @@ async def execute( results = await self._parse_execute_results( raw_result, use_full_document_fallback ) - self._model_cache += results + self._cache.model_cache += results if not exhaust_results: - return self._model_cache + return self._cache.model_cache # The query returned all results, so we have no more work to do. if count <= len(results): - return self._model_cache + return self._cache.model_cache # Transparently (to the user) make subsequent requests to paginate # through the results and finally return them all. @@ -1973,8 +2094,8 @@ async def execute( _results = await query.execute(exhaust_results=False) if not _results: break - self._model_cache += _results - return self._model_cache + self._cache.model_cache += _results + return self._cache.model_cache async def get_query(self): query = self.copy() @@ -2073,8 +2194,8 @@ async def delete(self): return 0 async def __aiter__(self): - if self._model_cache: - for m in self._model_cache: + if self._cache.model_cache: + for m in self._cache.model_cache: yield m else: for m in await self.execute(): @@ -2099,8 +2220,8 @@ def __getitem__(self, item: int): "Cannot use [] notation with async code. " "Use FindQuery.get_item() instead." ) - if self._model_cache and len(self._model_cache) >= item: - return self._model_cache[item] + if self._cache.model_cache and len(self._cache.model_cache) >= item: + return self._cache.model_cache[item] query = self.copy(offset=item, limit=1) @@ -2123,8 +2244,8 @@ async def get_item(self, item: int): NOTE: This method is included specifically for async users, who cannot use the notation Model.find()[1000]. """ - if self._model_cache and len(self._model_cache) >= item: - return self._model_cache[item] + if self._cache.model_cache and len(self._cache.model_cache) >= item: + return self._cache.model_cache[item] query = self.copy(offset=item, limit=1) result = await query.execute() @@ -2186,14 +2307,80 @@ def __init__(self, default: Any = ..., **kwargs: Any) -> None: expire = kwargs.pop("expire", None) separator = kwargs.pop("separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR) super().__init__(default=default, **kwargs) - self.primary_key = primary_key - self.sortable = sortable - self.case_sensitive = case_sensitive - self.index = index - self.full_text_search = full_text_search - self.vector_options = vector_options - self.expire = expire - self.separator = separator + self._redis_options: Dict[str, Any] = { + "primary_key": primary_key, + "sortable": sortable, + "case_sensitive": case_sensitive, + "index": index, + "full_text_search": full_text_search, + "vector_options": vector_options, + "expire": expire, + "separator": separator, + } + + @property + def primary_key(self): + return self._redis_options["primary_key"] + + @primary_key.setter + def primary_key(self, value): + self._redis_options["primary_key"] = value + + @property + def sortable(self): + return self._redis_options["sortable"] + + @sortable.setter + def sortable(self, value): + self._redis_options["sortable"] = value + + @property + def case_sensitive(self): + return self._redis_options["case_sensitive"] + + @case_sensitive.setter + def case_sensitive(self, value): + self._redis_options["case_sensitive"] = value + + @property + def index(self): + return self._redis_options["index"] + + @index.setter + def index(self, value): + self._redis_options["index"] = value + + @property + def full_text_search(self): + return self._redis_options["full_text_search"] + + @full_text_search.setter + def full_text_search(self, value): + self._redis_options["full_text_search"] = value + + @property + def vector_options(self): + return self._redis_options["vector_options"] + + @vector_options.setter + def vector_options(self, value): + self._redis_options["vector_options"] = value + + @property + def expire(self): + return self._redis_options["expire"] + + @expire.setter + def expire(self, value): + self._redis_options["expire"] = value + + @property + def separator(self): + return self._redis_options["separator"] + + @separator.setter + def separator(self, value): + self._redis_options["separator"] = value class RelationshipInfo(Representation): @@ -2208,6 +2395,21 @@ def __init__( @dataclasses.dataclass +class VectorFieldParameters: + # Common optional parameters + initial_cap: Optional[int] = None + + # Optional parameters for FLAT + block_size: Optional[int] = None + + # Optional parameters for HNSW + m: Optional[int] = None + ef_construction: Optional[int] = None + ef_runtime: Optional[int] = None + epsilon: Optional[float] = None + + +@dataclasses.dataclass(init=False) class VectorFieldOptions: class ALGORITHM(Enum): FLAT = "FLAT" @@ -2226,18 +2428,98 @@ class DISTANCE_METRIC(Enum): type: TYPE dimension: int distance_metric: DISTANCE_METRIC + params: Optional[VectorFieldParameters] = None - # Common optional parameters - initial_cap: Optional[int] = None + def __init__( + self, + *, + algorithm: ALGORITHM, + type: TYPE, + dimension: int, + distance_metric: DISTANCE_METRIC, + initial_cap: Optional[int] = None, + block_size: Optional[int] = None, + m: Optional[int] = None, + ef_construction: Optional[int] = None, + ef_runtime: Optional[int] = None, + epsilon: Optional[float] = None, + params: Optional[VectorFieldParameters] = None, + ) -> None: + self.algorithm = algorithm + self.type = type + self.dimension = dimension + self.distance_metric = distance_metric + + if params is None: + params = VectorFieldParameters() + + if initial_cap is not None: + params.initial_cap = initial_cap + if block_size is not None: + params.block_size = block_size + if m is not None: + params.m = m + if ef_construction is not None: + params.ef_construction = ef_construction + if ef_runtime is not None: + params.ef_runtime = ef_runtime + if epsilon is not None: + params.epsilon = epsilon + + self.params = params + + def _ensure_params(self) -> VectorFieldParameters: + if self.params is None: + self.params = VectorFieldParameters() + return self.params - # Optional parameters for FLAT - block_size: Optional[int] = None + @property + def initial_cap(self) -> Optional[int]: + return self.params.initial_cap if self.params else None - # Optional parameters for HNSW - m: Optional[int] = None - ef_construction: Optional[int] = None - ef_runtime: Optional[int] = None - epsilon: Optional[float] = None + @initial_cap.setter + def initial_cap(self, value: Optional[int]) -> None: + self._ensure_params().initial_cap = value + + @property + def block_size(self) -> Optional[int]: + return self.params.block_size if self.params else None + + @block_size.setter + def block_size(self, value: Optional[int]) -> None: + self._ensure_params().block_size = value + + @property + def m(self) -> Optional[int]: + return self.params.m if self.params else None + + @m.setter + def m(self, value: Optional[int]) -> None: + self._ensure_params().m = value + + @property + def ef_construction(self) -> Optional[int]: + return self.params.ef_construction if self.params else None + + @ef_construction.setter + def ef_construction(self, value: Optional[int]) -> None: + self._ensure_params().ef_construction = value + + @property + def ef_runtime(self) -> Optional[int]: + return self.params.ef_runtime if self.params else None + + @ef_runtime.setter + def ef_runtime(self, value: Optional[int]) -> None: + self._ensure_params().ef_runtime = value + + @property + def epsilon(self) -> Optional[float]: + return self.params.epsilon if self.params else None + + @epsilon.setter + def epsilon(self, value: Optional[float]) -> None: + self._ensure_params().epsilon = value @staticmethod def flat( @@ -2252,8 +2534,10 @@ def flat( type=type, dimension=dimension, distance_metric=distance_metric, - initial_cap=initial_cap, - block_size=block_size, + params=VectorFieldParameters( + initial_cap=initial_cap, + block_size=block_size, + ), ) @staticmethod @@ -2272,19 +2556,24 @@ def hnsw( type=type, dimension=dimension, distance_metric=distance_metric, - initial_cap=initial_cap, - m=m, - ef_construction=ef_construction, - ef_runtime=ef_runtime, - epsilon=epsilon, + params=VectorFieldParameters( + initial_cap=initial_cap, + m=m, + ef_construction=ef_construction, + ef_runtime=ef_runtime, + epsilon=epsilon, + ), ) @property def schema(self): attr = [] - for k, v in vars(self).items(): - if k == "algorithm" or v is None: - continue + base_fields = ( + ("type", self.type), + ("dimension", self.dimension), + ("distance_metric", self.distance_metric), + ) + for k, v in base_fields: attr.extend( [ k.upper() if k != "dimension" else "DIM", @@ -2292,6 +2581,19 @@ def schema(self): ] ) + if self.params is None: + return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr) + + for k, v in dataclasses.asdict(self.params).items(): + if v is None: + continue + attr.extend( + [ + k.upper(), + str(v) if not isinstance(v, Enum) else v.name, + ] + ) + return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr) @@ -2412,23 +2714,22 @@ class BaseMeta(Protocol): encoding: str -@dataclasses.dataclass class DefaultMeta: """A default placeholder Meta object. TODO: Revisit whether this is really necessary, and whether making - these all optional here is the right choice. + these all optional here is the right choice. """ - global_key_prefix: Optional[str] = None - model_key_prefix: Optional[str] = None - primary_key_pattern: Optional[str] = None - database: Optional[redis.Redis] = None - primary_key: Optional[PrimaryKey] = None - primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None - index_name: Optional[str] = None - embedded: Optional[bool] = False - encoding: str = "utf-8" + global_key_prefix: ClassVar[Optional[str]] = None + model_key_prefix: ClassVar[Optional[str]] = None + primary_key_pattern: ClassVar[Optional[str]] = None + database: ClassVar[Optional[redis.Redis]] = None + primary_key: ClassVar[Optional[PrimaryKey]] = None + primary_key_creator_cls: ClassVar[Optional[Type[PrimaryKeyCreator]]] = None + index_name: ClassVar[Optional[str]] = None + embedded: ClassVar[bool] = False + encoding: ClassVar[str] = "utf-8" class ModelMeta(ModelMetaclass): From 1644ae6daa44a3a3836ab263ee38e1dba4389843 Mon Sep 17 00:00:00 2001 From: Renan Soares Date: Fri, 29 May 2026 13:20:58 -0300 Subject: [PATCH 3/4] refactoring too-many-branches --- aredis_om/model/encoders.py | 303 +++++++++++++++++++--------- aredis_om/model/model.py | 312 +++++++++++++---------------- aredis_om/redisvl.py | 106 +++++----- tests/test_encoders.py | 55 +++++ tests/test_redisvl_integration.py | 14 ++ tests/test_timestamp_conversion.py | 44 ++++ 6 files changed, 520 insertions(+), 314 deletions(-) create mode 100644 tests/test_encoders.py create mode 100644 tests/test_timestamp_conversion.py diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index cf493d12..f61a50bc 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -65,114 +65,153 @@ def generate_encoders_by_class_tuples( encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) -def jsonable_encoder( - obj: Any, - include: Optional[Union[SetIntStr, DictIntStrAny]] = None, - exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None, - by_alias: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - custom_encoder: Dict[Any, Callable[[Any], Any]] = {}, - sqlalchemy_safe: bool = True, -) -> Any: +def _normalize_include_exclude( + include: Optional[Union[SetIntStr, DictIntStrAny]], + exclude: Optional[Union[SetIntStr, DictIntStrAny]], +) -> Tuple[Optional[Union[SetIntStr, DictIntStrAny]], Optional[Union[SetIntStr, DictIntStrAny]]]: if include is not None and not isinstance(include, (set, dict)): include = set(include) if exclude is not None and not isinstance(exclude, (set, dict)): exclude = set(exclude) + return include, exclude - if isinstance(obj, BaseModel) and hasattr(obj, "__config__"): - encoder = getattr(obj.__config__, "json_encoders", {}) - if custom_encoder: - encoder.update(custom_encoder) - obj_dict = obj.model_dump( - include=include, # type: ignore # in Pydantic - exclude=exclude, # type: ignore # in Pydantic - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] - return jsonable_encoder( - obj_dict, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - custom_encoder=encoder, - sqlalchemy_safe=sqlalchemy_safe, + +def _encode_pydantic_model( + obj: BaseModel, + include: Optional[Union[SetIntStr, DictIntStrAny]], + exclude: Optional[Union[SetIntStr, DictIntStrAny]], + by_alias: bool, + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, + custom_encoder: Dict[Any, Callable[[Any], Any]], + sqlalchemy_safe: bool, +) -> Any: + encoder = dict(getattr(obj.__config__, "json_encoders", {})) + if custom_encoder: + encoder.update(custom_encoder) + obj_dict = obj.model_dump( + include=include, # type: ignore # in Pydantic + exclude=exclude, # type: ignore # in Pydantic + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + if "__root__" in obj_dict: + obj_dict = obj_dict["__root__"] + return jsonable_encoder( + obj_dict, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + custom_encoder=encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + + +def _encode_dict( + obj: Dict[Any, Any], + include: Optional[Union[SetIntStr, DictIntStrAny]], + exclude: Optional[Union[SetIntStr, DictIntStrAny]], + by_alias: bool, + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, + custom_encoder: Dict[Any, Callable[[Any], Any]], + sqlalchemy_safe: bool, +) -> Dict[Any, Any]: + encoded_dict = {} + for key, value in obj.items(): + should_include = ( + ( + not sqlalchemy_safe + or (not isinstance(key, str)) + or (not key.startswith("_sa")) + ) + and value is not PydanticUndefined + and (value is not None or not exclude_none) + and ((include and key in include) or not exclude or key not in exclude) ) - if dataclasses.is_dataclass(obj): - return dataclasses.asdict(obj) # type: ignore - if isinstance(obj, Enum): - return obj.value - if isinstance(obj, PurePath): - return str(obj) - if isinstance(obj, (str, int, float, type(None))): - return obj - if isinstance(obj, dict): - encoded_dict = {} - for key, value in obj.items(): - if ( - ( - not sqlalchemy_safe - or (not isinstance(key, str)) - or (not key.startswith("_sa")) - ) - and value is not PydanticUndefined - and (value is not None or not exclude_none) - and ((include and key in include) or not exclude or key not in exclude) - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict - if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) + if should_include: + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, ) - return encoded_list + encoded_dict[encoded_key] = encoded_value + return encoded_dict - if custom_encoder: - if type(obj) in custom_encoder: - return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder(obj) +def _encode_iterable( + obj: Union[List[Any], Set[Any], frozenset, GeneratorType, Tuple[Any, ...]], + include: Optional[Union[SetIntStr, DictIntStrAny]], + exclude: Optional[Union[SetIntStr, DictIntStrAny]], + by_alias: bool, + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, + custom_encoder: Dict[Any, Callable[[Any], Any]], + sqlalchemy_safe: bool, +) -> List[Any]: + encoded_list = [] + for item in obj: + encoded_list.append( + jsonable_encoder( + item, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + ) + return encoded_list + + +def _apply_custom_encoder( + obj: Any, + custom_encoder: Dict[Any, Callable[[Any], Any]], +) -> Any: + if type(obj) in custom_encoder: + return custom_encoder[type(obj)](obj) + for encoder_type, encoder in custom_encoder.items(): + if isinstance(obj, encoder_type): + return encoder(obj) + return None + + +def _apply_builtin_encoder(obj: Any) -> Any: if type(obj) in ENCODERS_BY_TYPE: return ENCODERS_BY_TYPE[type(obj)](obj) for encoder, classes_tuple in encoders_by_class_tuples.items(): if isinstance(obj, classes_tuple): return encoder(obj) + return None + +def _encode_fallback_object( + obj: Any, + by_alias: bool, + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, + custom_encoder: Dict[Any, Callable[[Any], Any]], + sqlalchemy_safe: bool, +) -> Any: errors: List[Exception] = [] try: data = dict(obj) @@ -192,3 +231,81 @@ def jsonable_encoder( custom_encoder=custom_encoder, sqlalchemy_safe=sqlalchemy_safe, ) + + +def jsonable_encoder( + obj: Any, + include: Optional[Union[SetIntStr, DictIntStrAny]] = None, + exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + custom_encoder: Dict[Any, Callable[[Any], Any]] = {}, + sqlalchemy_safe: bool = True, +) -> Any: + include, exclude = _normalize_include_exclude(include, exclude) + + if isinstance(obj, BaseModel) and hasattr(obj, "__config__"): + return _encode_pydantic_model( + obj, + include, + exclude, + by_alias, + exclude_unset, + exclude_defaults, + exclude_none, + custom_encoder, + sqlalchemy_safe, + ) + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) # type: ignore + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, PurePath): + return str(obj) + if isinstance(obj, (str, int, float, type(None))): + return obj + if isinstance(obj, dict): + return _encode_dict( + obj, + include, + exclude, + by_alias, + exclude_unset, + exclude_defaults, + exclude_none, + custom_encoder, + sqlalchemy_safe, + ) + if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): + return _encode_iterable( + obj, + include, + exclude, + by_alias, + exclude_unset, + exclude_defaults, + exclude_none, + custom_encoder, + sqlalchemy_safe, + ) + + if custom_encoder: + encoded_obj = _apply_custom_encoder(obj, custom_encoder) + if encoded_obj is not None: + return encoded_obj + + encoded_obj = _apply_builtin_encoder(obj) + if encoded_obj is not None: + return encoded_obj + + return _encode_fallback_object( + obj, + by_alias, + exclude_unset, + exclude_defaults, + exclude_none, + custom_encoder, + sqlalchemy_safe, + ) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 0a8fb2a8..3afc6835 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -126,102 +126,92 @@ def convert_datetime_to_timestamp(obj): return obj +def _unwrap_optional_type(field_type): + """Return the inner type for Optional[T] annotations.""" + if get_origin(field_type) is Union: + non_none_types = [arg for arg in get_args(field_type) if arg is not type(None)] # noqa: E721 + if len(non_none_types) == 1: + return non_none_types[0] + return field_type + + +def _get_model_field_type(field_info): + """Extract the resolved annotation from a model field.""" + return _unwrap_optional_type(getattr(field_info, "annotation", None)) + + +def _is_redis_model_type(field_type): + """Check whether a type looks like a RedisModel subclass.""" + return ( + isinstance(field_type, type) + and hasattr(field_type, "model_fields") + and bool(field_type.model_fields) + ) + + +def _get_list_inner_type(field_type): + """Return the inner type for List[T] annotations.""" + if get_origin(field_type) in (list, List): + args = get_args(field_type) + if args: + return args[0] + return None + + +def _convert_timestamp_scalar(value, field_type): + """Convert a scalar timestamp to datetime or date.""" + if not isinstance(value, (int, float, str)): + return value + + try: + if isinstance(value, str): + value = float(value) + dt = datetime.datetime.fromtimestamp(value, datetime.timezone.utc) + except (ValueError, OSError): + return value + + if field_type is datetime.date: + return dt.date() + return dt + + +def _convert_timestamp_dict(value, field_type): + """Recursively convert a nested dictionary.""" + if _is_redis_model_type(field_type): + return convert_timestamp_to_datetime(value, field_type.model_fields) + return convert_timestamp_to_datetime(value, {}) + + +def _convert_timestamp_list(value, field_type): + """Recursively convert a list field, preserving nested models.""" + inner_type = _get_list_inner_type(field_type) + if _is_redis_model_type(inner_type): + return [ + convert_timestamp_to_datetime(item, inner_type.model_fields) + for item in value + ] + return convert_timestamp_to_datetime(value, {}) + + +def _convert_timestamp_field_value(value, field_type): + """Convert a field value based on the resolved annotation.""" + if field_type in (datetime.datetime, datetime.date): + return _convert_timestamp_scalar(value, field_type) + if isinstance(value, dict): + return _convert_timestamp_dict(value, field_type) + if isinstance(value, list): + return _convert_timestamp_list(value, field_type) + return convert_timestamp_to_datetime(value, {}) + + def convert_timestamp_to_datetime(obj, model_fields): """Convert Unix timestamps back to datetime objects based on model field types.""" if isinstance(obj, dict): result = {} for key, value in obj.items(): - if key in model_fields: - field_info = model_fields[key] - field_type = ( - field_info.annotation if hasattr(field_info, "annotation") else None - ) - - # Handle Optional types - extract the inner type - if hasattr(field_type, "__origin__") and field_type.__origin__ is Union: - # For Optional[T] which is Union[T, None], get the non-None type - args = getattr(field_type, "__args__", ()) - non_none_types = [ - arg - for arg in args - if arg is not type(None) # noqa: E721 - ] - if len(non_none_types) == 1: - field_type = non_none_types[0] - - # Handle direct datetime/date fields - if field_type in (datetime.datetime, datetime.date) and isinstance( - value, (int, float, str) - ): - try: - if isinstance(value, str): - value = float(value) - # Return UTC-aware datetime for consistency. - # Timestamps are always UTC-referenced, so we return - # UTC-aware datetimes. Users can convert to their - # preferred timezone with dt.astimezone(tz). - dt = datetime.datetime.fromtimestamp( - value, datetime.timezone.utc - ) - # If the field is specifically a date, convert to date - if field_type is datetime.date: - result[key] = dt.date() - else: - result[key] = dt - except (ValueError, OSError): - result[key] = value # Keep original value if conversion fails - # Handle nested models - check if it's a RedisModel subclass - elif isinstance(value, dict): - try: - # Check if field_type is a class and subclass of RedisModel - if ( - isinstance(field_type, type) - and hasattr(field_type, "model_fields") - and field_type.model_fields - ): - result[key] = convert_timestamp_to_datetime( - value, field_type.model_fields - ) - else: - result[key] = convert_timestamp_to_datetime(value, {}) - except (TypeError, AttributeError): - result[key] = convert_timestamp_to_datetime(value, {}) - # Handle lists that might contain nested models - elif isinstance(value, list): - # Try to extract the inner type from List[SomeModel] - inner_type = None - if ( - hasattr(field_type, "__origin__") - and field_type.__origin__ in (list, List) - and hasattr(field_type, "__args__") - and field_type.__args__ - ): - inner_type = field_type.__args__[0] - - # Check if the inner type is a nested model - try: - if ( - isinstance(inner_type, type) - and hasattr(inner_type, "model_fields") - and inner_type.model_fields - ): - result[key] = [ - convert_timestamp_to_datetime( - item, inner_type.model_fields - ) - for item in value - ] - else: - result[key] = convert_timestamp_to_datetime(value, {}) - except (TypeError, AttributeError): - result[key] = convert_timestamp_to_datetime(value, {}) - else: - result[key] = convert_timestamp_to_datetime(value, {}) - else: - result[key] = convert_timestamp_to_datetime(value, {}) - else: - # For keys not in model_fields, still recurse but with empty field info - result[key] = convert_timestamp_to_datetime(value, {}) + field_info = model_fields.get(key) + field_type = _get_model_field_type(field_info) if field_info else None + result[key] = _convert_timestamp_field_value(value, field_type) return result elif isinstance(obj, list): return [convert_timestamp_to_datetime(item, model_fields) for item in obj] @@ -248,94 +238,64 @@ def convert_bytes_to_base64(obj): return obj -def convert_base64_to_bytes(obj, model_fields): - """Convert base64-encoded strings back to bytes based on model field types.""" +def _decode_base64_string(value): + """Decode a base64 string, returning the original value on failure.""" import base64 - if isinstance(obj, dict): - result = {} - for key, value in obj.items(): - if key in model_fields: - field_info = model_fields[key] - field_type = ( - field_info.annotation if hasattr(field_info, "annotation") else None - ) + try: + return base64.b64decode(value) + except (ValueError, TypeError): + return value - # Handle Optional types - extract the inner type - if hasattr(field_type, "__origin__") and field_type.__origin__ is Union: - # For Optional[T] which is Union[T, None], get the non-None type - args = getattr(field_type, "__args__", ()) - non_none_types = [ - arg - for arg in args - if arg is not type(None) # noqa: E721 - ] - if len(non_none_types) == 1: - field_type = non_none_types[0] - - # Handle bytes fields - if field_type is bytes and isinstance(value, str): - try: - result[key] = base64.b64decode(value) - except (ValueError, TypeError): - # If it's not valid base64, keep original value - result[key] = value - # Handle nested models - check if it's a model with fields - elif isinstance(value, dict): - try: - if ( - isinstance(field_type, type) - and hasattr(field_type, "model_fields") - and field_type.model_fields - ): - result[key] = convert_base64_to_bytes( - value, field_type.model_fields - ) - else: - result[key] = convert_base64_to_bytes(value, {}) - except (TypeError, AttributeError): - result[key] = convert_base64_to_bytes(value, {}) - # Handle lists that might contain nested models - elif isinstance(value, list): - # Try to extract the inner type from List[SomeModel] - inner_type = None - if ( - hasattr(field_type, "__origin__") - and field_type.__origin__ in (list, List) - and hasattr(field_type, "__args__") - and field_type.__args__ - ): - inner_type = field_type.__args__[0] - - if inner_type is not None: - try: - if ( - isinstance(inner_type, type) - and hasattr(inner_type, "model_fields") - and inner_type.model_fields - ): - result[key] = [ - ( - convert_base64_to_bytes( - item, inner_type.model_fields - ) - if isinstance(item, dict) - else item - ) - for item in value - ] - else: - result[key] = convert_base64_to_bytes(value, {}) - except (TypeError, AttributeError): - result[key] = convert_base64_to_bytes(value, {}) - else: - result[key] = convert_base64_to_bytes(value, {}) - else: - result[key] = convert_base64_to_bytes(value, {}) - else: - # For keys not in model_fields, still recurse but with empty field info - result[key] = convert_base64_to_bytes(value, {}) - return result + +def _get_nested_base64_model_fields(field_type): + """Return model fields for nested Redis models, if available.""" + if _is_redis_model_type(field_type): + return field_type.model_fields + return {} + + +def _convert_base64_list(value, field_type): + """Convert base64 values inside a list.""" + inner_type = _get_list_inner_type(field_type) + if _is_redis_model_type(inner_type): + return [ + convert_base64_to_bytes(item, inner_type.model_fields) + if isinstance(item, dict) + else item + for item in value + ] + return [convert_base64_to_bytes(item, {}) for item in value] + + +def _convert_base64_field_value(value, field_type): + """Convert a single field value based on the resolved annotation.""" + if field_type is bytes and isinstance(value, str): + return _decode_base64_string(value) + if isinstance(value, dict): + return convert_base64_to_bytes(value, _get_nested_base64_model_fields(field_type)) + if isinstance(value, list): + return _convert_base64_list(value, field_type) + return value + + +def _convert_base64_dict(obj, model_fields): + """Recursively convert a dictionary of values from base64 to bytes.""" + result = {} + for key, value in obj.items(): + if key in model_fields: + field_info = model_fields[key] + field_type = _get_model_field_type(field_info) + result[key] = _convert_base64_field_value(value, field_type) + else: + result[key] = convert_base64_to_bytes(value, {}) + return result + + +def convert_base64_to_bytes(obj, model_fields): + """Convert base64-encoded strings back to bytes based on model field types.""" + if isinstance(obj, dict): + return _convert_base64_dict(obj, model_fields) elif isinstance(obj, list): return [convert_base64_to_bytes(item, model_fields) for item in obj] else: diff --git a/aredis_om/redisvl.py b/aredis_om/redisvl.py index 4ab016b3..d17bde17 100644 --- a/aredis_om/redisvl.py +++ b/aredis_om/redisvl.py @@ -56,65 +56,81 @@ def _get_field_type( vector_options: Optional[VectorFieldOptions] = getattr( field_info, "vector_options", None ) - sortable = getattr(field_info, "sortable", False) is True - full_text_search = getattr(field_info, "full_text_search", False) is True - case_sensitive = getattr(field_info, "case_sensitive", False) is True - # Vector field if vector_options: - attrs = { - "dims": vector_options.dimension, - "distance_metric": vector_options.distance_metric.name.lower(), - "algorithm": vector_options.algorithm.name.lower(), - "datatype": vector_options.type.name.lower(), - } - if vector_options.initial_cap: - attrs["initial_cap"] = vector_options.initial_cap - is_flat = vector_options.algorithm.name == "FLAT" - if is_flat and vector_options.block_size: - attrs["block_size"] = vector_options.block_size - if vector_options.algorithm.name == "HNSW": - if vector_options.m: - attrs["m"] = vector_options.m - if vector_options.ef_construction: - attrs["ef_construction"] = vector_options.ef_construction - if vector_options.ef_runtime: - attrs["ef_runtime"] = vector_options.ef_runtime - if vector_options.epsilon: - attrs["epsilon"] = vector_options.epsilon - return {"name": field_name, "type": "vector", "attrs": attrs} - - # Numeric field + return _get_vector_field_def(field_name, vector_options) + if is_numeric_type(field_type): - attrs = {"sortable": sortable} - return {"name": field_name, "type": "numeric", "attrs": attrs} + return _get_numeric_field_def(field_name, field_info) - # Boolean - stored as TAG if field_type is bool: return {"name": field_name, "type": "tag"} - # String field if isinstance(field_type, type) and issubclass(field_type, str): - if full_text_search: - attrs = {"sortable": sortable} - return {"name": field_name, "type": "text", "attrs": attrs} - else: - attrs = {"sortable": sortable, "case_sensitive": case_sensitive} - return {"name": field_name, "type": "tag", "attrs": attrs} + return _get_string_field_def(field_name, field_info) - # List of strings -> TAG if is_supported_container_type(field_type): - from typing import get_args - - inner_types = get_args(field_type) - if inner_types and inner_types[0] is str: - attrs = {"sortable": sortable} - return {"name": field_name, "type": "tag", "attrs": attrs} + return _get_container_field_def(field_name, field_type, field_info) - # Default to tag for unknown types return {"name": field_name, "type": "tag"} +def _get_vector_field_def( + field_name: str, vector_options: VectorFieldOptions +) -> Dict[str, Any]: + attrs = { + "dims": vector_options.dimension, + "distance_metric": vector_options.distance_metric.name.lower(), + "algorithm": vector_options.algorithm.name.lower(), + "datatype": vector_options.type.name.lower(), + } + if vector_options.initial_cap: + attrs["initial_cap"] = vector_options.initial_cap + if vector_options.algorithm.name == "FLAT" and vector_options.block_size: + attrs["block_size"] = vector_options.block_size + if vector_options.algorithm.name == "HNSW": + if vector_options.m: + attrs["m"] = vector_options.m + if vector_options.ef_construction: + attrs["ef_construction"] = vector_options.ef_construction + if vector_options.ef_runtime: + attrs["ef_runtime"] = vector_options.ef_runtime + if vector_options.epsilon: + attrs["epsilon"] = vector_options.epsilon + return {"name": field_name, "type": "vector", "attrs": attrs} + + +def _get_numeric_field_def(field_name: str, field_info: FieldInfo) -> Dict[str, Any]: + sortable = getattr(field_info, "sortable", False) is True + return {"name": field_name, "type": "numeric", "attrs": {"sortable": sortable}} + + +def _get_string_field_def(field_name: str, field_info: FieldInfo) -> Dict[str, Any]: + sortable = getattr(field_info, "sortable", False) is True + full_text_search = getattr(field_info, "full_text_search", False) is True + case_sensitive = getattr(field_info, "case_sensitive", False) is True + + if full_text_search: + return {"name": field_name, "type": "text", "attrs": {"sortable": sortable}} + return { + "name": field_name, + "type": "tag", + "attrs": {"sortable": sortable, "case_sensitive": case_sensitive}, + } + + +def _get_container_field_def( + field_name: str, field_type: Any, field_info: FieldInfo +) -> Optional[Dict[str, Any]]: + from typing import get_args + + sortable = getattr(field_info, "sortable", False) is True + inner_types = get_args(field_type) + if inner_types and inner_types[0] is str: + return {"name": field_name, "type": "tag", "attrs": {"sortable": sortable}} + return None + + def to_redisvl_schema(model_cls: Type[RedisModel]) -> "IndexSchema": """ Convert a Redis OM model to a RedisVL IndexSchema. diff --git a/tests/test_encoders.py b/tests/test_encoders.py new file mode 100644 index 00000000..eedbdf44 --- /dev/null +++ b/tests/test_encoders.py @@ -0,0 +1,55 @@ +import datetime +from dataclasses import dataclass +from enum import Enum +from pathlib import PurePosixPath + +from aredis_om.model.encoders import jsonable_encoder + + +@dataclass +class Address: + city: str + tags: list[str] + + +class Status(Enum): + ACTIVE = "active" + + +def test_jsonable_encoder_handles_nested_standard_types(): + payload = { + "status": Status.ACTIVE, + "path": PurePosixPath("foo/bar"), + "items": [Address("Porto", ["a"]), (1, 2)], + "keep": "value", + "skip_none": None, + "_sa_instance_state": "ignored", + } + + encoded = jsonable_encoder(payload, exclude_none=True) + + assert encoded == { + "status": "active", + "path": "foo/bar", + "items": [{"city": "Porto", "tags": ["a"]}, [1, 2]], + "keep": "value", + } + + +def test_jsonable_encoder_custom_encoder_takes_precedence(): + value = datetime.datetime(2024, 1, 2, 3, 4, 5) + + encoded = jsonable_encoder( + value, + custom_encoder={datetime.date: lambda _: "custom"}, + ) + + assert encoded == "custom" + + +def test_jsonable_encoder_falls_back_to_vars(): + class Payload: + def __init__(self): + self.value = 42 + + assert jsonable_encoder(Payload()) == {"value": 42} diff --git a/tests/test_redisvl_integration.py b/tests/test_redisvl_integration.py index 7788f7ad..dabd5a7d 100644 --- a/tests/test_redisvl_integration.py +++ b/tests/test_redisvl_integration.py @@ -94,6 +94,8 @@ async def test_to_redisvl_schema_json_model(json_model_with_vector): Document = json_model_with_vector schema = to_redisvl_schema(Document) + schema_dict = schema.to_dict() + fields = {field["name"]: field for field in schema_dict["fields"]} assert isinstance(schema, IndexSchema) assert schema.index.name == Document.Meta.index_name @@ -107,6 +109,11 @@ async def test_to_redisvl_schema_json_model(json_model_with_vector): assert "views" in field_names assert "embedding" in field_names + assert fields["title"]["type"] == "tag" + assert fields["content"]["type"] == "text" + assert fields["views"]["type"] == "numeric" + assert fields["embedding"]["type"] == "vector" + @py_test_mark_asyncio async def test_to_redisvl_schema_hash_model(hash_model_indexed): @@ -114,6 +121,8 @@ async def test_to_redisvl_schema_hash_model(hash_model_indexed): Product = hash_model_indexed schema = to_redisvl_schema(Product) + schema_dict = schema.to_dict() + fields = {field["name"]: field for field in schema_dict["fields"]} assert isinstance(schema, IndexSchema) assert schema.index.storage_type.value == "hash" @@ -124,6 +133,11 @@ async def test_to_redisvl_schema_hash_model(hash_model_indexed): assert "price" in field_names assert "in_stock" in field_names + assert fields["name"]["type"] == "tag" + assert fields["description"]["type"] == "text" + assert fields["price"]["type"] == "numeric" + assert fields["in_stock"]["type"] == "tag" + @py_test_mark_asyncio async def test_to_redisvl_schema_non_indexed_raises(non_indexed_model): diff --git a/tests/test_timestamp_conversion.py b/tests/test_timestamp_conversion.py new file mode 100644 index 00000000..3f4c5be5 --- /dev/null +++ b/tests/test_timestamp_conversion.py @@ -0,0 +1,44 @@ +import datetime +from typing import List, Optional + +from aredis_om.model.model import convert_timestamp_to_datetime + + +class DummyField: + def __init__(self, annotation): + self.annotation = annotation + + +class ChildModel: + model_fields = { + "created_at": DummyField(datetime.datetime), + } + + +class ParentModel: + model_fields = { + "child": DummyField(ChildModel), + "children": DummyField(List[ChildModel]), + "published_on": DummyField(Optional[datetime.date]), + } + + +def test_convert_timestamp_to_datetime_handles_nested_models(): + timestamp = datetime.datetime(2024, 1, 1, 12, 30, tzinfo=datetime.timezone.utc).timestamp() + + result = convert_timestamp_to_datetime( + { + "child": {"created_at": timestamp}, + "children": [{"created_at": timestamp}], + "published_on": timestamp, + }, + ParentModel.model_fields, + ) + + assert result["child"]["created_at"] == datetime.datetime( + 2024, 1, 1, 12, 30, tzinfo=datetime.timezone.utc + ) + assert result["children"][0]["created_at"] == datetime.datetime( + 2024, 1, 1, 12, 30, tzinfo=datetime.timezone.utc + ) + assert result["published_on"] == datetime.date(2024, 1, 1) From 3e2a0856681f52bc075200e29368066543ab6a22 Mon Sep 17 00:00:00 2001 From: Rian Lima Date: Sun, 31 May 2026 12:33:59 -0300 Subject: [PATCH 4/4] refactor: remove too-many-branches code smells --- .../data/builtin/datetime_migration.py | 76 +- aredis_om/model/migrations/data/migrator.py | 148 ++- .../migrations/schema/legacy_migrator.py | 169 +-- aredis_om/model/migrations/schema/migrator.py | 37 +- aredis_om/model/model.py | 1141 ++++++++++------- aredis_om/model/render_tree.py | 64 +- tests/test_find_query.py | 25 + tests/test_json_path_projection.py | 59 + tests/test_render_tree.py | 35 + 9 files changed, 1107 insertions(+), 647 deletions(-) create mode 100644 tests/test_json_path_projection.py create mode 100644 tests/test_render_tree.py diff --git a/aredis_om/model/migrations/data/builtin/datetime_migration.py b/aredis_om/model/migrations/data/builtin/datetime_migration.py index f0e02ef0..3f65370f 100644 --- a/aredis_om/model/migrations/data/builtin/datetime_migration.py +++ b/aredis_om/model/migrations/data/builtin/datetime_migration.py @@ -370,6 +370,49 @@ async def _collect_hash_keys(self, model_class) -> List[str]: return all_keys + async def _process_hash_batch( + self, + batch_keys: List[str], + datetime_fields: List[str], + model_name: str, + total_keys: int, + ) -> int: + """Process a batch of hash keys and return the number of handled keys.""" + processed_count = 0 + + for key in batch_keys: + try: + if await self._process_hash_key( + key, + datetime_fields, + model_name, + total_keys, + ): + processed_count += 1 + + except DataMigrationError: + # Re-raise migration errors + raise + except Exception as e: + log.error(f"Unexpected error processing hash key {key}: {e}") + if self.failure_mode == ConversionFailureMode.FAIL: + raise DataMigrationError( + f"Unexpected error processing hash key {key}: {e}" + ) + # Continue with next key for other failure modes + + return processed_count + + def _log_hash_batch_completion( + self, batch_start: int, batch_size_actual: int, batch_time: float + ) -> None: + """Log summary information for a processed hash batch.""" + log.info( + f"Completed batch {batch_start // self.batch_size + 1}: " + f"{batch_size_actual} keys in {batch_time:.2f}s " + f"({batch_size_actual / batch_time:.1f} keys/sec)" + ) + @staticmethod def _normalize_hash_data(hash_data: Dict[Any, Any]) -> Dict[str, Any]: """Normalize hash payload to string keys and values.""" @@ -585,36 +628,17 @@ async def _process_hash_model( batch_keys = all_keys[batch_start:batch_end] batch_start_time = time.time() - - for key in batch_keys: - try: - if await self._process_hash_key( - key, - datetime_fields, - model_class.__name__, - total_keys, - ): - processed_count += 1 - - except DataMigrationError: - # Re-raise migration errors - raise - except Exception as e: - log.error(f"Unexpected error processing hash key {key}: {e}") - if self.failure_mode == ConversionFailureMode.FAIL: - raise DataMigrationError( - f"Unexpected error processing hash key {key}: {e}" - ) - # Continue with next key for other failure modes + processed_count += await self._process_hash_batch( + batch_keys, + datetime_fields, + model_class.__name__, + total_keys, + ) # Log batch completion batch_time = time.time() - batch_start_time batch_size_actual = len(batch_keys) - log.info( - f"Completed batch {batch_start // self.batch_size + 1}: " - f"{batch_size_actual} keys in {batch_time:.2f}s " - f"({batch_size_actual / batch_time:.1f} keys/sec)" - ) + self._log_hash_batch_completion(batch_start, batch_size_actual, batch_time) # Progress reporting self._log_progress(processed_count, total_keys, "HashModel keys") diff --git a/aredis_om/model/migrations/data/migrator.py b/aredis_om/model/migrations/data/migrator.py index 6b9ba1f0..e12ef250 100644 --- a/aredis_om/model/migrations/data/migrator.py +++ b/aredis_om/model/migrations/data/migrator.py @@ -257,40 +257,43 @@ async def run_migrations( applied_count = 0 for migration in pending_migrations: - if verbose: - print(f"Running migration: {migration.migration_id}") - start_time = time.time() + applied_count += await self._run_single_migration(migration, verbose) - # Check if migration can run - if not await migration.can_run(): - if verbose: - print( - f"Skipping migration {migration.migration_id}: can_run() returned False" - ) - continue + if verbose: + print(f"Applied {applied_count} migration(s).") - try: - await migration.up() - await self.mark_migration_applied(migration.migration_id) - applied_count += 1 + return applied_count - if verbose: - end_time = time.time() - print( - f"Applied migration {migration.migration_id} in {end_time - start_time:.2f}s" - ) + async def _run_single_migration( + self, migration: BaseMigration, verbose: bool + ) -> int: + if verbose: + print(f"Running migration: {migration.migration_id}") + start_time = time.time() - except Exception as e: - if verbose: - print(f"Migration {migration.migration_id} failed: {e}") - raise DataMigrationError( - f"Migration {migration.migration_id} failed: {e}" + # Check if migration can run + if not await migration.can_run(): + if verbose: + print( + f"Skipping migration {migration.migration_id}: can_run() returned False" ) + return 0 - if verbose: - print(f"Applied {applied_count} migration(s).") + try: + await migration.up() + await self.mark_migration_applied(migration.migration_id) - return applied_count + if verbose: + end_time = time.time() + print( + f"Applied migration {migration.migration_id} in {end_time - start_time:.2f}s" + ) + + return 1 + except Exception as e: + if verbose: + print(f"Migration {migration.migration_id} failed: {e}") + raise DataMigrationError(f"Migration {migration.migration_id} failed: {e}") async def run_migrations_with_monitoring( self, @@ -314,26 +317,25 @@ async def run_migrations_with_monitoring( monitor = PerformanceMonitor() monitor.start() - pending_migrations = await self.get_pending_migrations() - - if limit: - pending_migrations = pending_migrations[:limit] + pending_migrations = await self._get_pending_migrations(limit) if not pending_migrations: if verbose: print("No pending migrations found.") - return { - "applied_count": 0, - "total_migrations": 0, - "performance_stats": monitor.get_stats(), - "errors": [], - } + monitor.finish() + return self._build_monitoring_result( + applied_count=0, + pending_migrations=[], + monitor=monitor, + errors=[], + ) self._log_pending_migrations(pending_migrations, verbose) if dry_run: if verbose: print("Dry run mode - no changes will be applied.") + monitor.finish() return self._build_monitoring_result( applied_count=len(pending_migrations), pending_migrations=pending_migrations, @@ -343,14 +345,53 @@ async def run_migrations_with_monitoring( ) applied_count = 0 - errors = [] + applied_count, errors = await self._run_pending_migrations_with_monitoring( + pending_migrations=pending_migrations, + monitor=monitor, + verbose=verbose, + progress_callback=progress_callback, + ) + + monitor.finish() + + result = self._build_monitoring_result( + applied_count=applied_count, + pending_migrations=pending_migrations, + monitor=monitor, + errors=errors, + ) + + self._log_monitoring_summary(result, verbose) + + return result + + async def _get_pending_migrations( + self, limit: Optional[int] + ) -> List[BaseMigration]: + pending_migrations = await self.get_pending_migrations() + + if limit: + return pending_migrations[:limit] + + return pending_migrations + + async def _run_pending_migrations_with_monitoring( + self, + pending_migrations: List[BaseMigration], + monitor: PerformanceMonitor, + verbose: bool, + progress_callback: Optional[Callable], + ) -> tuple[int, List[Dict[str, Any]]]: + applied_count = 0 + errors: List[Dict[str, Any]] = [] + total_migrations = len(pending_migrations) for i, migration in enumerate(pending_migrations): batch_start_time = time.time() if verbose: print( - f"Running migration {i + 1}/{len(pending_migrations)}: {migration.migration_id}" + f"Running migration {i + 1}/{total_migrations}: {migration.migration_id}" ) # Check if migration can run @@ -378,16 +419,11 @@ async def run_migrations_with_monitoring( # Call progress callback if provided if progress_callback: progress_callback( - applied_count, len(pending_migrations), migration.migration_id + applied_count, total_migrations, migration.migration_id ) except Exception as e: - error_info = { - "migration_id": migration.migration_id, - "error": str(e), - "timestamp": datetime.now().isoformat(), - } - errors.append(error_info) + errors.append(self._build_migration_error(migration.migration_id, e)) if verbose: print(f"Migration {migration.migration_id} failed: {e}") @@ -395,18 +431,7 @@ async def run_migrations_with_monitoring( # For now, stop on first error - could be made configurable break - monitor.finish() - - result = self._build_monitoring_result( - applied_count=applied_count, - pending_migrations=pending_migrations, - monitor=monitor, - errors=errors, - ) - - self._log_monitoring_summary(result, verbose) - - return result + return applied_count, errors def _log_pending_migrations( self, pending_migrations: List[BaseMigration], verbose: bool @@ -441,6 +466,13 @@ def _build_monitoring_result( return result + def _build_migration_error(self, migration_id: str, error: Exception) -> Dict[str, Any]: + return { + "migration_id": migration_id, + "error": str(error), + "timestamp": datetime.now().isoformat(), + } + def _log_monitoring_summary( self, result: Dict[str, Any], verbose: bool ) -> None: diff --git a/aredis_om/model/migrations/schema/legacy_migrator.py b/aredis_om/model/migrations/schema/legacy_migrator.py index f2de2a57..04f139fd 100644 --- a/aredis_om/model/migrations/schema/legacy_migrator.py +++ b/aredis_om/model/migrations/schema/legacy_migrator.py @@ -135,6 +135,86 @@ def __init__(self, module=None, conn=None): self.conn = conn self.migrations: List[IndexMigration] = [] + def _get_connection(self, cls): + """Return a Redis connection for a model, recovering closed loops.""" + try: + return self.conn or cls.db() + except RuntimeError as e: + if "Event loop is closed" not in str(e): + raise + + from ....connections import get_redis_connection + + return get_redis_connection() + + def _get_schema(self, name, cls): + try: + schema = cls.redisearch_schema() + except NotImplementedError: + log.info("Skipping migrations for %s", name) + return None + + current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec + return schema, current_hash + + async def _index_exists(self, conn, index_name): + try: + await conn.ft(index_name).info() + return conn, True + except RuntimeError as e: + if "Event loop is closed" not in str(e): + raise + + from ....connections import get_redis_connection + + conn = get_redis_connection() + try: + await conn.ft(index_name).info() + return conn, True + except redis.ResponseError: + return conn, False + except redis.ResponseError: + return conn, False + + def _append_create_migration(self, name, cls, schema, current_hash, conn): + self.migrations.append( + IndexMigration( + name, + cls.Meta.index_name, + schema, + current_hash, + MigrationAction.CREATE, + conn, + ) + ) + + def _append_recreate_migrations( + self, name, cls, schema, current_hash, conn, stored_hash + ): + # TODO: Switch out schema with an alias to avoid downtime -- separate migration? + self.migrations.append( + IndexMigration( + name, + cls.Meta.index_name, + schema, + current_hash, + MigrationAction.DROP, + conn, + stored_hash, + ) + ) + self.migrations.append( + IndexMigration( + name, + cls.Meta.index_name, + schema, + current_hash, + MigrationAction.CREATE, + conn, + stored_hash, + ) + ) + async def detect_migrations(self): """Detect schema changes between models and Redis indexes.""" if self.module: @@ -147,92 +227,25 @@ async def detect_migrations(self): for name, cls in model_registry.items(): hash_key = schema_hash_key(cls.Meta.index_name) - # Try to get a connection, but handle event loop issues gracefully - try: - conn = self.conn or cls.db() - except RuntimeError as e: - if "Event loop is closed" in str(e): - # Model connection is bound to closed event loop, create fresh one - from ....connections import get_redis_connection - - conn = get_redis_connection() - else: - raise - - try: - schema = cls.redisearch_schema() - except NotImplementedError: - log.info("Skipping migrations for %s", name) + schema_data = self._get_schema(name, cls) + if schema_data is None: continue - current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec - try: - await conn.ft(cls.Meta.index_name).info() - except RuntimeError as e: - if "Event loop is closed" in str(e): - # Connection had event loop issues, try with a fresh connection - from ....connections import get_redis_connection - - conn = get_redis_connection() - try: - await conn.ft(cls.Meta.index_name).info() - except redis.ResponseError: - # Index doesn't exist, proceed to create it - self.migrations.append( - IndexMigration( - name, - cls.Meta.index_name, - schema, - current_hash, - MigrationAction.CREATE, - conn, - ) - ) - continue - else: - raise - except redis.ResponseError: - self.migrations.append( - IndexMigration( - name, - cls.Meta.index_name, - schema, - current_hash, - MigrationAction.CREATE, - conn, - ) - ) + schema, current_hash = schema_data + conn = self._get_connection(cls) + + conn, exists = await self._index_exists(conn, cls.Meta.index_name) + if not exists: + self._append_create_migration(name, cls, schema, current_hash, conn) continue stored_hash = await conn.get(hash_key) if isinstance(stored_hash, bytes): stored_hash = stored_hash.decode("utf-8") - schema_out_of_date = current_hash != stored_hash - - if schema_out_of_date: - # TODO: Switch out schema with an alias to avoid downtime -- separate migration? - self.migrations.append( - IndexMigration( - name, - cls.Meta.index_name, - schema, - current_hash, - MigrationAction.DROP, - conn, - stored_hash, - ) - ) - self.migrations.append( - IndexMigration( - name, - cls.Meta.index_name, - schema, - current_hash, - MigrationAction.CREATE, - conn, - stored_hash, - ) + if current_hash != stored_hash: + self._append_recreate_migrations( + name, cls, schema, current_hash, conn, stored_hash ) async def run(self): diff --git a/aredis_om/model/migrations/schema/migrator.py b/aredis_om/model/migrations/schema/migrator.py index 121646c1..453f86cb 100644 --- a/aredis_om/model/migrations/schema/migrator.py +++ b/aredis_om/model/migrations/schema/migrator.py @@ -163,6 +163,27 @@ async def rollback( # Don't mark as unapplied if rollback failed for other reasons return False + async def _rollback_migration( + self, + migration_id: str, + migration: BaseSchemaMigration, + verbose: bool = False, + ) -> bool: + try: + if verbose: + print(f"Rolling back: {migration_id}") + await migration.down() + await self.mark_unapplied(migration_id) + return True + except NotImplementedError: + if verbose: + print(f"Migration {migration_id} does not support rollback, stopping") + return False + except Exception as e: + if verbose: + print(f"Rollback failed for {migration_id}: {e}, stopping") + return False + async def downgrade( self, steps: int = 1, dry_run: bool = False, verbose: bool = False ) -> int: @@ -202,21 +223,9 @@ async def downgrade( if verbose: print(f"Warning: Migration {mid} not found on disk, skipping") continue - mig = discovered[mid] - try: - if verbose: - print(f"Rolling back: {mid}") - await mig.down() - await self.mark_unapplied(mid) - count += 1 - except NotImplementedError: - if verbose: - print(f"Migration {mid} does not support rollback, stopping") - break - except Exception as e: - if verbose: - print(f"Rollback failed for {mid}: {e}, stopping") + if not await self._rollback_migration(mid, discovered[mid], verbose): break + count += 1 if verbose: print(f"Rolled back {count} migration(s).") diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 3afc6835..92e2d15d 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1127,73 +1127,70 @@ def _validate_deep_field_path(self, field_path: str): """Validate that a deep field path like 'address__city' exists in the model.""" parts = field_path.split("__") current_model = self.model - current_field_name = parts[0] + current_model = self._validate_deep_field_segment( + current_model, parts[0], field_path, is_root=True + ) + + # Dict fields can accept any nested access, so validation stops there. + if current_model is None: + return + + # Walk through intermediate nested segments; the final segment only needs to exist. + for field_name in parts[1:-1]: + current_model = self._validate_deep_field_segment( + current_model, field_name, field_path + ) + if current_model is None: + return - # Check the first part exists in the model - if current_field_name not in current_model.model_fields: + final_field_name = parts[-1] + if not hasattr(current_model, "model_fields") or final_field_name not in ( + current_model.model_fields + ): raise QueryNotSupportedError( - f"You tried to return the field {field_path}, but the root field " - f"{current_field_name} does not exist on the model {current_model}" + f"You tried to return the field {field_path}, but the nested field " + f"{final_field_name} does not exist on the embedded model {current_model}" ) - # Walk through the nested field path - for i, field_name in enumerate(parts): - if i == 0: - # First part - get the field info - field_info = current_model.model_fields[field_name] - field_type = getattr(field_info, "annotation", None) + def _validate_deep_field_segment( + self, + current_model, + field_name: str, + field_path: str, + is_root: bool = False, + ): + if not hasattr(current_model, "model_fields") or field_name not in ( + current_model.model_fields + ): + if is_root: + raise QueryNotSupportedError( + f"You tried to return the field {field_path}, but the root field " + f"{field_name} does not exist on the model {current_model}" + ) + raise QueryNotSupportedError( + f"You tried to return the field {field_path}, but the nested field " + f"{field_name} does not exist on the embedded model {current_model}" + ) - # Check if it's an embedded model - try: - if isinstance(field_type, type) and issubclass( - field_type, RedisModel - ): - current_model = field_type - elif field_type is dict: - # Dict fields - we can't validate nested paths, just accept them - return - else: - raise QueryNotSupportedError( - f"Deep field path {field_path} requires {field_name} to be an " - f"embedded model or dict, but it is {field_type}" - ) - except TypeError: - raise QueryNotSupportedError( - f"Deep field path {field_path} requires {field_name} to be an " - f"embedded model or dict, but it is {field_type}" - ) - else: - # Nested parts - check they exist in the embedded model - if ( - not hasattr(current_model, "model_fields") - or field_name not in current_model.model_fields - ): - raise QueryNotSupportedError( - f"You tried to return the field {field_path}, but the nested field " - f"{field_name} does not exist on the embedded model {current_model}" - ) + field_info = current_model.model_fields[field_name] + field_type = _get_model_field_type(field_info) - # Update current_model for further nesting if needed - if i < len(parts) - 1: # Not the last part - field_info = current_model.model_fields[field_name] - field_type = getattr(field_info, "annotation", None) - try: - if isinstance(field_type, type) and issubclass( - field_type, RedisModel - ): - current_model = field_type - elif field_type is dict: - return # Can't validate further into dict - else: - raise QueryNotSupportedError( - f"Deep field path {field_path} requires {field_name} to be an " - f"embedded model or dict for further nesting" - ) - except TypeError: - raise QueryNotSupportedError( - f"Deep field path {field_path} requires {field_name} to be an " - f"embedded model or dict for further nesting" - ) + if _is_redis_model_type(field_type): + return field_type + + if field_type is dict: + return None + + if is_root: + raise QueryNotSupportedError( + f"Deep field path {field_path} requires {field_name} to be an " + f"embedded model or dict, but it is {field_type}" + ) + + raise QueryNotSupportedError( + f"Deep field path {field_path} requires {field_name} to be an " + f"embedded model or dict for further nesting" + ) def _parse_projected_results(self, res: Any) -> List[Dict[str, Any]]: """Parse results when using RETURN clause with specific fields.""" @@ -1358,69 +1355,84 @@ async def _parse_json_path_projection_as_dict( self, res: Any ) -> List[Dict[str, Any]]: """Use JSON.GET with JSONPath to efficiently extract deep fields.""" - # Extract document keys from search results - doc_keys = [] - step = 2 # Because the result has content - - for i in range(1, len(res), step): - if i < len(res): - doc_key = res[i] # Document key - if isinstance(doc_key, bytes): - doc_key = doc_key.decode("utf-8") - doc_keys.append(doc_key) + doc_keys = self._extract_document_keys_from_search_results(res) if not doc_keys: return [] - # Convert field names to JSONPath expressions - json_paths = [] - for field_name in self.projected_fields: - if "__" in field_name: - # Deep field: address__city -> $.address.city - json_path = "$." + field_name.replace("__", ".") - else: - # Regular field: name -> $.name - json_path = f"$.{field_name}" - json_paths.append(json_path) - - # Batch get all projected fields for all documents + json_paths = self._build_json_paths(self.projected_fields) projected_results = [] db = self.model.db() for doc_key in doc_keys: - try: - # Get multiple JSONPath expressions in one call - result = await db.json().get(doc_key, *json_paths) + projected_data = await self._fetch_projected_document_data( + db, doc_key, json_paths + ) + if projected_data: + projected_results.append(projected_data) - if result is None: - continue + return projected_results - # Convert JSONPath results back to field names - projected_data = {} - if isinstance(result, dict): - # Multiple paths returned as dict - for json_path, values in result.items(): - # Convert $.address.city back to address__city - field_name = json_path[2:].replace( - ".", "__" - ) # Remove "$." and convert dots to __ - # JSON.GET returns arrays, take first value - if values and len(values) > 0: - projected_data[field_name] = values[0] - else: - # Single path - shouldn't happen with multiple paths, but handle it - if len(json_paths) == 1: - field_name = json_paths[0][2:].replace(".", "__") - if isinstance(result, list) and result: - projected_data[field_name] = result[0] + async def _fetch_projected_document_data( + self, db, doc_key: str, json_paths: List[str] + ) -> Dict[str, Any]: + """Fetch and normalize projected data for a single document.""" + try: + result = await db.json().get(doc_key, *json_paths) + except Exception: # nosec B112 + return {} - projected_results.append(projected_data) + if result is None: + return {} - except Exception: # nosec B112 - # If JSON.GET fails (connection, parsing, etc.), skip this document - continue + if isinstance(result, dict): + return self._normalize_json_dict_result(result) - return projected_results + return self._normalize_json_list_result(result, json_paths) + + def _extract_document_keys_from_search_results(self, res: Any) -> List[str]: + """Extract document keys from FT.SEARCH results.""" + doc_keys = [] + for i in range(1, len(res), 2): + if i < len(res): + doc_key = res[i] + if isinstance(doc_key, bytes): + doc_key = doc_key.decode("utf-8") + doc_keys.append(doc_key) + return doc_keys + + def _build_json_paths(self, field_names: List[str]) -> List[str]: + """Convert projected field names to JSONPath expressions.""" + return [self._field_name_to_json_path(field_name) for field_name in field_names] + + def _field_name_to_json_path(self, field_name: str) -> str: + """Convert a Django-style field path into JSONPath syntax.""" + if "__" in field_name: + return "$." + field_name.replace("__", ".") + return f"$.{field_name}" + + def _normalize_json_dict_result(self, result: Dict[str, Any]) -> Dict[str, Any]: + """Normalize a multi-path JSON.GET result into a flat dict.""" + projected_data = {} + for json_path, values in result.items(): + field_name = self._json_path_to_field_name(json_path) + if values and len(values) > 0: + projected_data[field_name] = values[0] + return projected_data + + def _normalize_json_list_result( + self, result: Any, json_paths: List[str] + ) -> Dict[str, Any]: + """Normalize a single-path JSON.GET result into a flat dict.""" + if len(json_paths) != 1 or not isinstance(result, list) or not result: + return {} + + field_name = self._json_path_to_field_name(json_paths[0]) + return {field_name: result[0]} + + def _json_path_to_field_name(self, json_path: str) -> str: + """Convert a JSONPath expression back to a Django-style field path.""" + return json_path[2:].replace(".", "__") async def _parse_fallback_projection_as_dict( self, res: Any @@ -1777,6 +1789,15 @@ def _resolve_geo_value(field_name: str, op: Operators, value: Any) -> str: return f"@{field_name}:[{value}]" return "" + @staticmethod + def _resolve_prefixed_field_name( + field_name: str, parents: List[Tuple[str, "RedisModel"]] + ) -> str: + if not parents: + return field_name + prefix = "_".join([p[0] for p in parents]) + return f"{prefix}_{field_name}" + @classmethod def resolve_value( cls, @@ -1788,23 +1809,29 @@ def resolve_value( parents: List[Tuple[str, "RedisModel"]], model_class: Optional[Type["RedisModel"]] = None, ) -> str: - # The 'field_name' should already include the correct prefix - if parents: - prefix = "_".join([p[0] for p in parents]) - field_name = f"{prefix}_{field_name}" - if field_type is RediSearchFieldTypes.TEXT: - return cls._resolve_text_value(field_name, op, value) - elif field_type is RediSearchFieldTypes.NUMERIC: - return cls._resolve_numeric_value(field_name, op, value) - # TODO: How will we know the difference between a multi-value use of a TAG - # field and our hidden use of TAG for exact-match queries? - elif field_type is RediSearchFieldTypes.TAG: - return cls._resolve_tag_value(field_name, field_info, op, value, model_class) - - elif field_type is RediSearchFieldTypes.GEO: - return cls._resolve_geo_value(field_name, op, value) + field_name = cls._resolve_prefixed_field_name(field_name, parents) - return "" + resolvers = { + RediSearchFieldTypes.TEXT: lambda: cls._resolve_text_value( + field_name, op, value + ), + RediSearchFieldTypes.NUMERIC: lambda: cls._resolve_numeric_value( + field_name, op, value + ), + # TODO: How will we know the difference between a multi-value use of a + # TAG field and our hidden use of TAG for exact-match queries? + RediSearchFieldTypes.TAG: lambda: cls._resolve_tag_value( + field_name, field_info, op, value, model_class + ), + RediSearchFieldTypes.GEO: lambda: cls._resolve_geo_value( + field_name, op, value + ), + } + + resolver = resolvers.get(field_type) + if resolver is None: + return "" + return resolver() def resolve_redisearch_pagination(self): """Resolve pagination options for a query.""" @@ -1821,6 +1848,96 @@ def resolve_redisearch_sort_fields(self): if self.sort_fields: return ["SORTBY", *fields] + @staticmethod + def _unwrap_negated_expression( + expression: ExpressionOrNegated, + ) -> tuple[Expression, bool]: + if isinstance(expression, NegatedExpression): + return expression.expression, True + return expression, False + + def _resolve_left_query_fragment( + self, expression: Expression + ) -> tuple[ + str, + Optional[str], + Optional[RediSearchFieldTypes], + Optional[PydanticFieldInfo], + ]: + if isinstance(expression.left, (Expression, NegatedExpression)): + return ( + f"({self._resolve_redisearch_query(expression.left)})", + None, + None, + None, + ) + if isinstance(expression.left, FieldInfo): + field_type = self.__class__.resolve_field_type( + expression.left, expression.op + ) + field_name = expression.left.name + field_info = expression.left + if not field_info or not getattr(field_info, "index", None): + raise QueryNotSupportedError( + f"You tried to query by a field ({field_name}) " + f"that isn't indexed. Docs: {ERRORS_URL}#E6" + ) + return "", field_name, field_type, field_info + raise QueryNotSupportedError( + "A query expression should start with either a field " + f"or an expression enclosed in parentheses. Docs: {ERRORS_URL}#E7" + ) + + def _resolve_right_query_fragment( + self, + expression: Expression, + result: str, + field_name: Optional[str], + field_type: Optional[RediSearchFieldTypes], + field_info: Optional[PydanticFieldInfo], + ) -> str: + right = expression.right + + if isinstance(right, (Expression, NegatedExpression)): + if expression.op == Operators.AND: + result += " " + elif expression.op == Operators.OR: + result += "| " + else: + raise QueryNotSupportedError( + "You can only combine two query expressions with" + f"AND (&) or OR (|). Docs: {ERRORS_URL}#E8" + ) + + if isinstance(right, NegatedExpression): + result += "-" + right = right.expression + + return f"{result}({self._resolve_redisearch_query(right)})" + + if not field_name: + raise QuerySyntaxError( + f"Could not resolve field name. Docs: {ERRORS_URL}#E9" + ) + if not field_type: + raise QuerySyntaxError( + f"Could not resolve field type. Docs: {ERRORS_URL}#E10" + ) + if not field_info: + raise QuerySyntaxError( + f"Could not resolve field info. Docs: {ERRORS_URL}#E11" + ) + + return result + self.__class__.resolve_value( + field_name, + field_type, + field_info, + expression.op, + right, + expression.parents, + self.model, + ) + def _resolve_redisearch_query(self, expression: ExpressionOrNegated) -> str: """ Resolve an arbitrarily deep expression into a single RediSearch query string. @@ -1845,15 +1962,9 @@ def _resolve_redisearch_query(self, expression: ExpressionOrNegated) -> str: TODO: When the operator is not IN or NOT_IN, detect a sequence type (other than strings, which are allowed) and raise an exception. """ - field_type = None - field_name = None - field_info = None - encompassing_expression_is_negated = False - result = "" - - if isinstance(expression, NegatedExpression): - encompassing_expression_is_negated = True - expression = expression.expression + expression, encompassing_expression_is_negated = self._unwrap_negated_expression( + expression + ) if expression.op is Operators.ALL: if encompassing_expression_is_negated: @@ -1862,68 +1973,12 @@ def _resolve_redisearch_query(self, expression: ExpressionOrNegated) -> str: ) return "*" - if isinstance(expression.left, Expression) or isinstance( - expression.left, NegatedExpression - ): - result += f"({self._resolve_redisearch_query(expression.left)})" - elif isinstance(expression.left, FieldInfo): - field_type = self.__class__.resolve_field_type( - expression.left, expression.op - ) - field_name = expression.left.name - field_info = expression.left - if not field_info or not getattr(field_info, "index", None): - raise QueryNotSupportedError( - f"You tried to query by a field ({field_name}) " - f"that isn't indexed. Docs: {ERRORS_URL}#E6" - ) - else: - raise QueryNotSupportedError( - "A query expression should start with either a field " - f"or an expression enclosed in parentheses. Docs: {ERRORS_URL}#E7" - ) - - right = expression.right - - if isinstance(right, Expression) or isinstance(right, NegatedExpression): - if expression.op == Operators.AND: - result += " " - elif expression.op == Operators.OR: - result += "| " - else: - raise QueryNotSupportedError( - "You can only combine two query expressions with" - f"AND (&) or OR (|). Docs: {ERRORS_URL}#E8" - ) - - if isinstance(right, NegatedExpression): - result += "-" - right = right.expression - - result += f"({self._resolve_redisearch_query(right)})" - else: - if not field_name: - raise QuerySyntaxError( - f"Could not resolve field name. Docs: {ERRORS_URL}#E9" - ) - elif not field_type: - raise QuerySyntaxError( - f"Could not resolve field type. Docs: {ERRORS_URL}#E10" - ) - elif not field_info: - raise QuerySyntaxError( - f"Could not resolve field info. Docs: {ERRORS_URL}#E11" - ) - else: - result += self.__class__.resolve_value( - field_name, - field_type, - field_info, - expression.op, - right, - expression.parents, - self.model, - ) + result, field_name, field_type, field_info = self._resolve_left_query_fragment( + expression + ) + result = self._resolve_right_query_fragment( + expression, result, field_name, field_type, field_info + ) if encompassing_expression_is_negated: result = f"-({result})" @@ -1987,6 +2042,41 @@ async def _execute_command(self, args: List[Union[str, bytes]]): # Re-raise the original exception raise + def _reset_model_cache_if_needed(self): + if self.offset == 0: + self._cache.model_cache.clear() + + async def _collect_remaining_results(self): + # Keep paginating until Redis stops returning rows. + query = self + while True: + query = query.copy(offset=query.offset + query.page_size) + results = await query.execute(exhaust_results=False) + if not results: + break + self._cache.model_cache += results + return self._cache.model_cache + + async def _handle_execute_results( + self, + raw_result: Any, + use_full_document_fallback: bool, + exhaust_results: bool, + ): + count = raw_result[0] + results = await self._parse_execute_results( + raw_result, use_full_document_fallback + ) + self._cache.model_cache += results + + if not exhaust_results: + return self._cache.model_cache + + if count <= len(results): + return self._cache.model_cache + + return await self._collect_remaining_results() + async def _parse_execute_results( self, raw_result: Any, use_full_document_fallback: bool ): @@ -2020,42 +2110,18 @@ async def execute( if return_query_args: return self.model.Meta.index_name, args - # Reset the cache if we're executing from offset 0. - if self.offset == 0: - self._cache.model_cache.clear() + self._reset_model_cache_if_needed() - # If the offset is greater than 0, we're paginating through a result set, - # so append the new results to results already in the cache. raw_result = await self._execute_command(args) if return_raw_result: return raw_result - count = raw_result[0] - results = await self._parse_execute_results( - raw_result, use_full_document_fallback + return await self._handle_execute_results( + raw_result, + use_full_document_fallback, + exhaust_results, ) - self._cache.model_cache += results - - if not exhaust_results: - return self._cache.model_cache - - # The query returned all results, so we have no more work to do. - if count <= len(results): - return self._cache.model_cache - - # Transparently (to the user) make subsequent requests to paginate - # through the results and finally return them all. - query = self - while True: - # Make a query for each pass of the loop, with a new offset equal to the - # current offset plus `page_size`, until we stop getting results back. - query = query.copy(offset=query.offset + query.page_size) - _results = await query.execute(exhaust_results=False) - if not _results: - break - self._cache.model_cache += _results - return self._cache.model_cache async def get_query(self): query = self.copy() @@ -2905,6 +2971,42 @@ def _register_indexed_model(new_class, bases): key = f"{new_class.__module__}.{new_class.__qualname__}" model_registry[key] = new_class + @classmethod + def _filter_config_kwargs(cls, kwargs): + allowed_config_kwargs = cls._allowed_config_kwargs() + return {key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs} + + @classmethod + def _finalize_new_class( + cls, + new_class, + bases, + base_meta, + original_field_infos, + is_indexed, + ): + # Create proxies for each model field so that we can use the field + # in queries, like Model.get(Model.field_name == 1) + # Only set if the model is has index=True + model_fields = cls._model_fields(new_class) + custom_pk_count = cls._apply_field_proxies( + new_class, + model_fields, + original_field_infos, + is_indexed, + ) + cls._setup_primary_key_accessor(new_class, model_fields, custom_pk_count) + cls._apply_embedded_model_rules(new_class) + cls._apply_meta_defaults(new_class, base_meta) + cls._register_indexed_model(new_class, bases) + + @staticmethod + def _validate_indexed_inheritance(new_class, is_indexed): + if is_indexed and new_class.model_config.get("index", None) is True: + raise RedisModelError( + f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" + ) + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = attrs.pop("Meta", None) @@ -2916,9 +3018,7 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 # Duplicate logic from Pydantic to filter config kwargs because if they are # passed directly including the registry Pydantic will pass them over to the # superclass causing an error - allowed_config_kwargs = cls._allowed_config_kwargs() - - config_kwargs = {key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs} + config_kwargs = cls._filter_config_kwargs(kwargs) new_class: RedisModel = super().__new__( cls, name, bases, attrs, **config_kwargs @@ -2929,27 +3029,16 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 is_indexed = kwargs.get("index", None) is True - if is_indexed and new_class.model_config.get("index", None) is True: - raise RedisModelError( - f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" - ) + cls._validate_indexed_inheritance(new_class, is_indexed) cls._set_index_config(new_class, is_indexed) - - # Create proxies for each model field so that we can use the field - # in queries, like Model.get(Model.field_name == 1) - # Only set if the model is has index=True - model_fields = cls._model_fields(new_class) - custom_pk_count = cls._apply_field_proxies( + cls._finalize_new_class( new_class, - model_fields, + bases, + base_meta, original_field_infos, is_indexed, ) - cls._setup_primary_key_accessor(new_class, model_fields, custom_pk_count) - cls._apply_embedded_model_rules(new_class) - cls._apply_meta_defaults(new_class, base_meta) - cls._register_indexed_model(new_class, bases) return new_class @@ -3319,71 +3408,63 @@ def check(self): class HashModel(RedisModel, abc.ABC): + @classmethod + def _has_vector_options_for_field(cls, field_name: str) -> bool: + """Check whether a field is configured as a vector field.""" + # Check cls.__dict__ first because model_fields may not exist yet during + # class creation. + if field_name in cls.__dict__: + field = cls.__dict__[field_name] + if getattr(field, "vector_options", None) is not None: + return True + + if hasattr(cls, "model_fields") and field_name in cls.model_fields: + field = cls.model_fields[field_name] + if getattr(field, "vector_options", None) is not None: + return True + + return False + + @staticmethod + def _validate_hash_field_type(field_name: str, field_type, allow_vector: bool): + origin = get_origin(field_type) + if origin: + for typ in (Set, Mapping, List): + if isinstance(origin, type) and issubclass(origin, typ): + if allow_vector: + return + raise RedisModelError( + f"HashModels cannot index set, list, or mapping fields. Field: {field_name}" + ) + + if isinstance(field_type, type) and issubclass(field_type, RedisModel): + raise RedisModelError( + f"HashModels cannot index embedded model fields. Field: {field_name}" + ) + + if isinstance(field_type, type) and dataclasses.is_dataclass(field_type): + raise RedisModelError( + f"HashModels cannot index dataclass fields. Field: {field_name}" + ) + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - # Helper to check if a field has vector_options (making it a vector field). - # We check cls.__dict__ because model_fields may not be populated yet - # when __init_subclass__ runs during class creation. - def _has_vector_options(field_name: str) -> bool: - """Check if a field has vector_options set, making it a vector field.""" - # First check cls.__dict__ for the original FieldInfo (before Pydantic processing) - if field_name in cls.__dict__: - field = cls.__dict__[field_name] - if getattr(field, "vector_options", None) is not None: - return True - # Also check model_fields in case it's populated - if hasattr(cls, "model_fields") and field_name in cls.model_fields: - field = cls.model_fields[field_name] - if getattr(field, "vector_options", None) is not None: - return True - return False - if hasattr(cls, "__annotations__"): for name, field_type in cls.__annotations__.items(): - origin = get_origin(field_type) - for typ in (Set, Mapping, List): - if isinstance(origin, type) and issubclass(origin, typ): - # Vector fields are allowed to be lists (list[float]) - if _has_vector_options(name): - continue - raise RedisModelError( - f"HashModels cannot index set, list, " - f"or mapping fields. Field: {name}" - ) - if isinstance(field_type, type) and issubclass(field_type, RedisModel): - raise RedisModelError( - f"HashModels cannot index embedded model fields. Field: {name}" - ) - elif isinstance(field_type, type) and dataclasses.is_dataclass( - field_type - ): - raise RedisModelError( - f"HashModels cannot index dataclass fields. Field: {name}" - ) + cls._validate_hash_field_type( + name, + field_type, + allow_vector=cls._has_vector_options_for_field(name), + ) for name, field in cls.model_fields.items(): outer_type = outer_type_or_annotation(field) - origin = get_origin(outer_type) - if origin: - for typ in (Set, Mapping, List): - if issubclass(origin, typ): - # Vector fields are allowed to be lists (list[float]) - if getattr(field, "vector_options", None) is not None: - continue - raise RedisModelError( - f"HashModels cannot index set, list, " - f"or mapping fields. Field: {name}" - ) - - if issubclass(outer_type, RedisModel): - raise RedisModelError( - f"HashModels cannot index embedded model fields. Field: {name}" - ) - elif dataclasses.is_dataclass(outer_type): - raise RedisModelError( - f"HashModels cannot index dataclass fields. Field: {name}" - ) + cls._validate_hash_field_type( + name, + outer_type, + allow_vector=getattr(field, "vector_options", None) is not None, + ) def _get_field_expirations( self, field_expirations: Optional[Dict[str, int]] = None @@ -3421,6 +3502,62 @@ def _get_field_expirations( return expirations + def _validate_save_options(self, pipeline, nx: bool, xx: bool) -> None: + if nx and xx: + raise ValueError("Cannot specify both nx and xx") + if pipeline and (nx or xx): + raise ValueError( + "Cannot use nx or xx with pipeline for HashModel. " + "Use JsonModel if you need conditional saves with pipelines." + ) + + async def _preserve_field_ttls( + self, + conn, + key: str, + document: Dict[str, Any], + is_pipeline: bool, + ) -> Dict[str, int]: + preserved_ttls: Dict[str, int] = {} + if not supports_hash_field_expiration() or is_pipeline: + return preserved_ttls + + fields_to_check = [f for f in document.keys() if f != "pk"] + if not fields_to_check: + return preserved_ttls + + current_ttls = await conn.httl(key, *fields_to_check) + if not current_ttls: + return preserved_ttls + + for i, field_name in enumerate(fields_to_check): + if current_ttls[i] > 0: # Has a TTL + preserved_ttls[field_name] = current_ttls[i] + + return preserved_ttls + + async def _apply_field_expirations( + self, + conn, + key: str, + document: Dict[str, Any], + preserved_ttls: Dict[str, int], + expirations: Dict[str, int], + ) -> None: + if not supports_hash_field_expiration(): + return + + for field_name in document.keys(): + if field_name == "pk": + continue + # Priority: preserved TTL > explicit field_expirations > Field(expire=N) default + if field_name in preserved_ttls: + # Restore the TTL that was removed by HSET + await conn.hexpire(key, preserved_ttls[field_name], field_name) + elif field_name in expirations: + # Apply new expiration (from Field(expire=N) or field_expirations param) + await conn.hexpire(key, expirations[field_name], field_name) + async def save( self: "Model", pipeline: Optional[Pipeline] = None, @@ -3441,13 +3578,7 @@ async def save( Returns: The saved model, or None if nx/xx conditions weren't met. """ - if nx and xx: - raise ValueError("Cannot specify both nx and xx") - if pipeline and (nx or xx): - raise ValueError( - "Cannot use nx or xx with pipeline for HashModel. " - "Use JsonModel if you need conditional saves with pipelines." - ) + self._validate_save_options(pipeline, nx, xx) self.check() db = self._get_db(pipeline) @@ -3492,32 +3623,22 @@ async def _do_save(conn): # See issue #753: .save() conflicts with TTL on unrelated field # Note: TTL preservation is skipped when using pipelines because # pipeline commands return futures, not actual values - preserved_ttls: Dict[str, int] = {} - if supports_hash_field_expiration() and not is_pipeline: - fields_to_check = [f for f in document.keys() if f != "pk"] - if fields_to_check: - current_ttls = await conn.httl(key, *fields_to_check) - if current_ttls: - for i, field_name in enumerate(fields_to_check): - if current_ttls[i] > 0: # Has a TTL - preserved_ttls[field_name] = current_ttls[i] + preserved_ttls = await self._preserve_field_ttls( + conn, key, document, is_pipeline + ) await conn.hset(key, mapping=document) # Apply field expirations after HSET (requires Redis 7.4+) # When using pipelines, we can still apply default expirations but # can't preserve manually-set TTLs - if supports_hash_field_expiration(): - for field_name in document.keys(): - if field_name == "pk": - continue - # Priority: preserved TTL > explicit field_expirations > Field(expire=N) default - if field_name in preserved_ttls: - # Restore the TTL that was removed by HSET - await conn.hexpire(key, preserved_ttls[field_name], field_name) - elif field_name in expirations: - # Apply new expiration (from Field(expire=N) or field_expirations param) - await conn.hexpire(key, expirations[field_name], field_name) + await self._apply_field_expirations( + conn, + key, + document, + preserved_ttls, + expirations, + ) return self @@ -3666,53 +3787,35 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): # as sortable. # TODO: Abstract string-building logic for each type (TAG, etc.) into # classes that take a field name. + return cls._schema_for_type(name, typ, field_info) + + @staticmethod + def _issubclass_safe(typ: Any, class_or_tuple: Any) -> bool: + try: + return issubclass(typ, class_or_tuple) + except TypeError: + return False + + @classmethod + def _schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): sortable = getattr(field_info, "sortable", False) case_sensitive = getattr(field_info, "case_sensitive", False) if is_supported_container_type(typ): - embedded_cls = get_args(typ) - if not embedded_cls: - # TODO: Test if this can really happen. - log.warning( - "Model %s defined an empty list or tuple field: %s", cls, name - ) - return "" - embedded_cls = embedded_cls[0] - schema = cls.schema_for_type(name, embedded_cls, field_info) + schema = cls._schema_for_container_type(name, typ, field_info) elif typ is bool: schema = f"{name} TAG" elif typ in [CoordinateType, Coordinates]: schema = f"{name} GEO" elif is_numeric_type(typ): - vector_options: Optional[VectorFieldOptions] = getattr( - field_info, "vector_options", None - ) - if vector_options: - schema = f"{name} {vector_options.schema}" - else: - schema = f"{name} NUMERIC" - elif issubclass(typ, str): - separator = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR - ) - if getattr(field_info, "full_text_search", False) is True: - schema = f"{name} TAG SEPARATOR {separator} {name} AS {name}_fts TEXT" - else: - schema = f"{name} TAG SEPARATOR {separator}" - elif issubclass(typ, RedisModel): - sub_fields = [] - for embedded_name, field in typ.model_fields.items(): - sub_fields.append( - cls.schema_for_type( - f"{name}_{embedded_name}", field.outer_type_, field.field_info - ) - ) - schema = " ".join(sub_fields) + schema = cls._schema_for_numeric_type(name, field_info) + elif cls._issubclass_safe(typ, str): + schema = cls._schema_for_string_type(name, field_info) + elif cls._issubclass_safe(typ, RedisModel): + schema = cls._schema_for_redis_model_type(name, typ) else: - separator = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR - ) - schema = f"{name} TAG SEPARATOR {separator}" + schema = cls._schema_for_default_type(name, field_info) + if schema and sortable is True: schema += " SORTABLE" if schema and case_sensitive is True: @@ -3720,6 +3823,53 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): return schema + @classmethod + def _schema_for_container_type( + cls, name: str, typ: Any, field_info: PydanticFieldInfo + ) -> str: + embedded_cls = get_args(typ) + if not embedded_cls: + # TODO: Test if this can really happen. + log.warning("Model %s defined an empty list or tuple field: %s", cls, name) + return "" + return cls._schema_for_type(name, embedded_cls[0], field_info) + + @classmethod + def _schema_for_numeric_type( + cls, name: str, field_info: PydanticFieldInfo + ) -> str: + vector_options: Optional[VectorFieldOptions] = getattr( + field_info, "vector_options", None + ) + if vector_options: + return f"{name} {vector_options.schema}" + return f"{name} NUMERIC" + + @classmethod + def _schema_for_string_type(cls, name: str, field_info: PydanticFieldInfo) -> str: + separator = getattr(field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR) + if getattr(field_info, "full_text_search", False) is True: + return f"{name} TAG SEPARATOR {separator} {name} AS {name}_fts TEXT" + return f"{name} TAG SEPARATOR {separator}" + + @classmethod + def _schema_for_redis_model_type(cls, name: str, typ: Any) -> str: + sub_fields = [] + for embedded_name, field in typ.model_fields.items(): + sub_fields.append( + cls._schema_for_type( + f"{name}_{embedded_name}", field.outer_type_, field.field_info + ) + ) + return " ".join(sub_fields) + + @classmethod + def _schema_for_default_type( + cls, name: str, field_info: PydanticFieldInfo + ) -> str: + separator = getattr(field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR) + return f"{name} TAG SEPARATOR {separator}" + # ========================================================================= # Hash Field Expiration Methods (Redis 7.4+) # ========================================================================= @@ -3935,6 +4085,31 @@ def redisearch_schema(cls): def schema_for_fields(cls): schema_parts = [] json_path = "$" + + for name, field in cls._collect_schema_fields().items(): + _type = get_outer_type(field) + if _type is None: + continue + + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field = field.metadata[0] + + field_info = field + redisearch_field = cls._schema_for_json_field( + json_path, name, _type, field_info + ) + if redisearch_field: + schema_parts.append(redisearch_field) + + return schema_parts + + @classmethod + def _collect_schema_fields(cls): fields = dict() if PYDANTIC_V2: model_fields = cls.model_fields @@ -3959,37 +4134,25 @@ def schema_for_fields(cls): continue fields[name] = PydanticFieldInfo.from_annotation(field) - for name, field in fields.items(): - _type = get_outer_type(field) - if _type is None: - continue + return fields - if ( - not isinstance(field, FieldInfo) - and hasattr(field, "metadata") - and len(field.metadata) > 0 - and isinstance(field.metadata[0], FieldInfo) - ): - field = field.metadata[0] - - field_info = field + @classmethod + def _schema_for_json_field( + cls, + json_path: str, + name: str, + typ: Union[Type[RedisModel], Any], + field_info: PydanticFieldInfo, + ) -> str: + if getattr(field_info, "primary_key", None) is True: + if issubclass(typ, str): + separator = getattr( + field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR + ) + return f"$.{name} AS {name} TAG SEPARATOR {separator}" + return cls.schema_for_type(json_path, name, "", typ, field_info) - if getattr(field_info, "primary_key", None) is True: - if issubclass(_type, str): - separator = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR - ) - redisearch_field = f"$.{name} AS {name} TAG SEPARATOR {separator}" - else: - redisearch_field = cls.schema_for_type( - json_path, name, "", _type, field_info - ) - schema_parts.append(redisearch_field) - continue - schema_parts.append( - cls.schema_for_type(json_path, name, "", _type, field_info) - ) - return schema_parts + return cls.schema_for_type(json_path, name, "", typ, field_info) @classmethod def schema_for_type( @@ -4166,60 +4329,148 @@ def _schema_for_leaf_type( separator = getattr(field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR) if is_vector and vector_options: - schema = f"{path} AS {index_field_name} {vector_options.schema}" - elif parent_is_container_type or parent_is_model_in_container: - if typ is not str: - raise RedisModelError( - "List and tuple fields can only contain strings. " - f"Problem field: {name}. Docs: {ERRORS_URL}#E12" - ) - if full_text_search is True: - raise RedisModelError( - "List and tuple fields cannot be indexed for full-text " - f"search. Problem field: {name}. Docs: {ERRORS_URL}#E13" - ) - schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" - if sortable is True: - schema += " SORTABLE" - if case_sensitive is True: - schema += " CASESENSITIVE" - elif typ is bool: + return cls._schema_for_vector_leaf_type( + path, index_field_name, vector_options.schema + ) + + if parent_is_container_type or parent_is_model_in_container: + return cls._schema_for_container_member_leaf_type( + path, + index_field_name, + name, + typ, + separator, + sortable, + case_sensitive, + full_text_search, + ) + + if typ is bool: + return cls._schema_for_tag_leaf_type( + path, index_field_name, separator=None, sortable=sortable + ) + + if typ in [CoordinateType, Coordinates]: + return cls._schema_for_geo_leaf_type(path, index_field_name, sortable) + + if is_numeric_type(typ): + return cls._schema_for_numeric_leaf_type(path, index_field_name, sortable) + + if cls._issubclass_safe(typ, str): + return cls._schema_for_string_leaf_type( + path, + index_field_name, + separator, + sortable, + case_sensitive, + full_text_search, + ) + + return cls._schema_for_tag_leaf_type( + path, index_field_name, separator, sortable + ) + + @staticmethod + def _append_schema_modifiers( + schema: str, + sortable: bool, + case_sensitive: bool = False, + allow_case_sensitive: bool = True, + ) -> str: + if sortable is True: + schema += " SORTABLE" + if case_sensitive is True: + if not allow_case_sensitive: + raise RedisModelError("Text fields cannot be case-sensitive.") + schema += " CASESENSITIVE" + return schema + + @classmethod + def _schema_for_vector_leaf_type( + cls, path: str, index_field_name: str, schema: str + ) -> str: + return f"{path} AS {index_field_name} {schema}" + + @classmethod + def _schema_for_container_member_leaf_type( + cls, + path: str, + index_field_name: str, + name: str, + typ: Union[Type[RedisModel], Any], + separator: str, + sortable: bool, + case_sensitive: bool, + full_text_search: bool, + ) -> str: + if typ is not str: + raise RedisModelError( + "List and tuple fields can only contain strings. " + f"Problem field: {name}. Docs: {ERRORS_URL}#E12" + ) + if full_text_search is True: + raise RedisModelError( + "List and tuple fields cannot be indexed for full-text " + f"search. Problem field: {name}. Docs: {ERRORS_URL}#E13" + ) + + schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + return cls._append_schema_modifiers(schema, sortable, case_sensitive) + + @classmethod + def _schema_for_tag_leaf_type( + cls, + path: str, + index_field_name: str, + separator: Optional[str], + sortable: bool, + ) -> str: + if separator is None: schema = f"{path} AS {index_field_name} TAG" - if sortable is True: - schema += " SORTABLE" - elif typ in [CoordinateType, Coordinates]: - schema = f"{path} AS {index_field_name} GEO" - if sortable is True: - schema += " SORTABLE" - elif is_numeric_type(typ): - schema = f"{path} AS {index_field_name} NUMERIC" - if sortable is True: - schema += " SORTABLE" - elif issubclass(typ, str): - if full_text_search is True: - schema = ( - f"{path} AS {index_field_name} TAG SEPARATOR {separator} " - f"{path} AS {index_field_name}_fts TEXT" - ) - if sortable is True: - # NOTE: With the current preview release, making a field - # full-text searchable and sortable only makes the TEXT - # field sortable. - schema += " SORTABLE" - if case_sensitive is True: - raise RedisModelError("Text fields cannot be case-sensitive.") - else: - schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" - if sortable is True: - schema += " SORTABLE" - if case_sensitive is True: - schema += " CASESENSITIVE" else: schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" - if sortable is True: - schema += " SORTABLE" + return cls._append_schema_modifiers(schema, sortable) - return schema + @classmethod + def _schema_for_geo_leaf_type( + cls, path: str, index_field_name: str, sortable: bool + ) -> str: + schema = f"{path} AS {index_field_name} GEO" + return cls._append_schema_modifiers(schema, sortable) + + @classmethod + def _schema_for_numeric_leaf_type( + cls, path: str, index_field_name: str, sortable: bool + ) -> str: + schema = f"{path} AS {index_field_name} NUMERIC" + return cls._append_schema_modifiers(schema, sortable) + + @classmethod + def _schema_for_string_leaf_type( + cls, + path: str, + index_field_name: str, + separator: str, + sortable: bool, + case_sensitive: bool, + full_text_search: bool, + ) -> str: + if full_text_search is True: + schema = ( + f"{path} AS {index_field_name} TAG SEPARATOR {separator} " + f"{path} AS {index_field_name}_fts TEXT" + ) + # NOTE: With the current preview release, making a field + # full-text searchable and sortable only makes the TEXT field sortable. + return cls._append_schema_modifiers( + schema, + sortable, + case_sensitive, + allow_case_sensitive=False, + ) + + schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + return cls._append_schema_modifiers(schema, sortable, case_sensitive) class EmbeddedJsonModel(JsonModel, abc.ABC): diff --git a/aredis_om/model/render_tree.py b/aredis_om/model/render_tree.py index 0a1ce71a..34dbe617 100644 --- a/aredis_om/model/render_tree.py +++ b/aredis_om/model/render_tree.py @@ -7,6 +7,38 @@ from typing import Any, Optional +def _name_resolver(current_node: Any, nameattr: str): + if hasattr(current_node, nameattr): + return lambda node: getattr(node, nameattr) # noqa: E731 + return lambda node: str(node) # noqa: E731 + + +def _next_indent(indent: str, last: str, direction: str, name_width: int) -> str: + return "{0}{1}{2}".format( + indent, " " if direction in last else "|", " " * name_width + ) + + +def _start_shape(last: str) -> str: + if last == "up": + return "┌" + if last == "down": + return "└" + if last == "updown": + return " " + return "├" + + +def _end_shape(up: Any, down: Any) -> str: + if up is not None and down is not None: + return "┤" + if up: + return "┘" + if down: + return "┐" + return "" + + def render_tree( current_node: Any, nameattr: str = "name", @@ -25,40 +57,20 @@ def render_tree( """ if buffer is None: buffer = io.StringIO() - if hasattr(current_node, nameattr): - name = lambda node: getattr(node, nameattr) # noqa: E731 - else: - name = lambda node: str(node) # noqa: E731 + name = _name_resolver(current_node, nameattr) up = getattr(current_node, left_child, None) down = getattr(current_node, right_child, None) if up is not None: next_last = "up" - next_indent = "{0}{1}{2}".format( - indent, " " if "up" in last else "|", " " * len(str(name(current_node))) - ) + next_indent = _next_indent(indent, last, "up", len(str(name(current_node)))) render_tree( up, nameattr, left_child, right_child, next_indent, next_last, buffer ) - if last == "up": - start_shape = "┌" - elif last == "down": - start_shape = "└" - elif last == "updown": - start_shape = " " - else: - start_shape = "├" - - if up is not None and down is not None: - end_shape = "┤" - elif up: - end_shape = "┘" - elif down: - end_shape = "┐" - else: - end_shape = "" + start_shape = _start_shape(last) + end_shape = _end_shape(up, down) print( "{0}{1}{2}{3}".format(indent, start_shape, name(current_node), end_shape), @@ -67,8 +79,8 @@ def render_tree( if down is not None: next_last = "down" - next_indent = "{0}{1}{2}".format( - indent, " " if "down" in last else "|", " " * len(str(name(current_node))) + next_indent = _next_indent( + indent, last, "down", len(str(name(current_node))) ) render_tree( down, nameattr, left_child, right_child, next_indent, next_last, buffer diff --git a/tests/test_find_query.py b/tests/test_find_query.py index 22da377b..b5b1b861 100644 --- a/tests/test_find_query.py +++ b/tests/test_find_query.py @@ -360,6 +360,31 @@ async def test_find_query_text_search_not_or_and(m, members): ] +@py_test_mark_asyncio +async def test_validate_deep_field_path_accepts_nested_embedded_fields(m): + query = FindQuery(expressions=[], model=m.Member) + + assert query.validate_projected_fields( + ["address__city", "address__note__description"] + ) + + +@py_test_mark_asyncio +async def test_validate_deep_field_path_rejects_invalid_nested_field(m): + query = FindQuery(expressions=[], model=m.Member) + + with pytest.raises(QueryNotSupportedError, match="nested field missing_field"): + query.validate_projected_fields(["address__missing_field"]) + + +@py_test_mark_asyncio +async def test_validate_deep_field_path_rejects_invalid_root_field(m): + query = FindQuery(expressions=[], model=m.Member) + + with pytest.raises(QueryNotSupportedError, match="root field missing_root"): + query.validate_projected_fields(["missing_root__city"]) + + # text search operators; contains, startswith, endswith, fuzzy @py_test_mark_asyncio async def test_find_query_text_contains(m): diff --git a/tests/test_json_path_projection.py b/tests/test_json_path_projection.py new file mode 100644 index 00000000..fc81ce51 --- /dev/null +++ b/tests/test_json_path_projection.py @@ -0,0 +1,59 @@ +import pytest + +from aredis_om.model.model import FindQuery, _FindQueryState + + +class DummyJsonClient: + def __init__(self, responses): + self.responses = responses + + def json(self): + return self + + async def get(self, doc_key, *json_paths): + return self.responses.get(doc_key) + + +class DummyModel: + def __init__(self, client): + self._client = client + + def db(self): + return self._client + + +def build_query(projected_fields, client): + query = FindQuery.__new__(FindQuery) + query._state = _FindQueryState(expressions=[], model=DummyModel(client)) + query.projected_fields = projected_fields + return query + + +@pytest.mark.asyncio +async def test_parse_json_path_projection_as_dict_handles_dict_results(): + client = DummyJsonClient( + { + "doc:1": { + "$.name": ["Alice"], + "$.address.city": ["Lisbon"], + }, + "doc:2": None, + } + ) + query = build_query(["name", "address__city"], client) + + result = await query._parse_json_path_projection_as_dict( + [2, b"doc:1", None, "doc:2", None] + ) + + assert result == [{"name": "Alice", "address__city": "Lisbon"}] + + +@pytest.mark.asyncio +async def test_parse_json_path_projection_as_dict_handles_single_path_results(): + client = DummyJsonClient({"doc:1": ["Bob"]}) + query = build_query(["name"], client) + + result = await query._parse_json_path_projection_as_dict([1, "doc:1", None]) + + assert result == [{"name": "Bob"}] diff --git a/tests/test_render_tree.py b/tests/test_render_tree.py new file mode 100644 index 00000000..99112ec6 --- /dev/null +++ b/tests/test_render_tree.py @@ -0,0 +1,35 @@ +from aredis_om.model.render_tree import render_tree + + +class Node: + def __init__(self, name, left=None, right=None): + self.name = name + self.left = left + self.right = right + + +def test_render_tree_matches_documented_layout(): + root = Node( + "AND", + Node("NOT EQ", Node("first_name"), Node("Andrew")), + Node( + "OR", + Node("EQ", Node("last_name"), Node("Brookins")), + Node("EQ", Node("last_name"), Node("Smith")), + ), + ) + + assert render_tree(root) == ( + "\n" + " ┌first_name\n" + " ┌NOT EQ┤\n" + " | └Andrew\n" + " AND┤\n" + " | ┌last_name\n" + " | ┌EQ┤\n" + " | | └Brookins\n" + " └OR┤\n" + " | ┌last_name\n" + " └EQ┤\n" + " └Smith\n" + )