diff --git a/aredis_om/model/migrations/data/builtin/datetime_migration.py b/aredis_om/model/migrations/data/builtin/datetime_migration.py index e4c0607c..1ff9f05f 100644 --- a/aredis_om/model/migrations/data/builtin/datetime_migration.py +++ b/aredis_om/model/migrations/data/builtin/datetime_migration.py @@ -363,6 +363,79 @@ async def _save_progress_if_needed(self, current_model: str, total_keys: int): stats=self.stats.get_summary(), ) + async def _collect_hash_keys(self, model_class) -> List[str]: + """Collect all Redis hash keys for a model.""" + key_pattern = model_class.make_key("*") + all_keys = [] + scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH") + + async for key in scan_iter: # type: ignore[misc] + if isinstance(key, bytes): + key = key.decode("utf-8") + all_keys.append(key) + + return all_keys + + @staticmethod + def _normalize_hash_data(hash_data: Dict[Any, Any]) -> Dict[str, Any]: + """Normalize hash payload to string keys and values.""" + if hash_data and isinstance(next(iter(hash_data.keys())), bytes): + return { + k.decode("utf-8"): v.decode("utf-8") + for k, v in hash_data.items() + } + + return hash_data + + async def _process_hash_key( + self, + key: str, + datetime_fields: List[str], + model_name: str, + total_keys: int, + ) -> bool: + """Process a single hash key and return whether it was handled.""" + if key in self.processed_keys_set: + return False + + try: + hash_data = await self.redis.hgetall(key) # type: ignore[misc] + except Exception as e: + log.warning(f"Failed to get hash data from {key}: {e}") + return False + + if not hash_data: + return False + + hash_data = self._normalize_hash_data(hash_data) + updates = {} + + for field_name in datetime_fields: + if field_name in hash_data: + value = hash_data[field_name] + converted, success = self._safe_convert_datetime_value( + key, field_name, value + ) + + if success and converted != value: + updates[field_name] = str(converted) + + if updates: + try: + await self.redis.hset(key, mapping=updates) # type: ignore[misc] + except Exception as e: + log.error(f"Failed to update hash {key}: {e}") + if self.failure_mode == ConversionFailureMode.FAIL: + raise DataMigrationError(f"Failed to update hash {key}: {e}") + + self.processed_keys_set.add(key) + self.stats.add_processed_key() + self._processed_keys += 1 + + self._check_error_threshold() + await self._save_progress_if_needed(model_name, total_keys) + return True + async def _clear_progress_on_completion(self): """Clear saved progress when migration completes successfully.""" if self.migration_state: @@ -504,16 +577,7 @@ async def _process_hash_model( self, model_class, datetime_fields: List[str] ) -> None: """Process HashModel instances to convert datetime fields with enhanced error handling.""" - # Get all keys for this model - key_pattern = model_class.make_key("*") - - # Collect all keys first for batch processing - all_keys = [] - scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH") - async for key in scan_iter: # type: ignore[misc] - if isinstance(key, bytes): - key = key.decode("utf-8") - all_keys.append(key) + all_keys = await self._collect_hash_keys(model_class) total_keys = len(all_keys) log.info( @@ -531,64 +595,13 @@ async def _process_hash_model( for key in batch_keys: try: - # Skip if already processed (resume capability) - if key in self.processed_keys_set: - continue - - # Get all fields from the hash - try: - hash_data = await self.redis.hgetall(key) # type: ignore[misc] - except Exception as e: - log.warning(f"Failed to get hash data from {key}: {e}") - continue - - if not hash_data: - continue - - # Convert byte keys/values to strings if needed - if hash_data and isinstance(next(iter(hash_data.keys())), bytes): - hash_data = { - k.decode("utf-8"): v.decode("utf-8") - for k, v in hash_data.items() - } - - updates = {} - - # Check each datetime field with safe conversion - for field_name in datetime_fields: - if field_name in hash_data: - value = hash_data[field_name] - converted, success = self._safe_convert_datetime_value( - key, field_name, value - ) - - if success and converted != value: - updates[field_name] = str(converted) - - # Update the hash if we have changes - if updates: - try: - await self.redis.hset(key, mapping=updates) # type: ignore[misc] - except Exception as e: - log.error(f"Failed to update hash {key}: {e}") - if self.failure_mode == ConversionFailureMode.FAIL: - raise DataMigrationError( - f"Failed to update hash {key}: {e}" - ) - - # Mark key as processed - self.processed_keys_set.add(key) - self.stats.add_processed_key() - self._processed_keys += 1 - processed_count += 1 - - # Error threshold checking - self._check_error_threshold() - - # Save progress periodically - await self._save_progress_if_needed( - model_class.__name__, total_keys - ) + if await self._process_hash_key( + key, + datetime_fields, + model_class.__name__, + total_keys, + ): + processed_count += 1 except DataMigrationError: # Re-raise migration errors diff --git a/aredis_om/model/migrations/data/migrator.py b/aredis_om/model/migrations/data/migrator.py index 23456775..6b9ba1f0 100644 --- a/aredis_om/model/migrations/data/migrator.py +++ b/aredis_om/model/migrations/data/migrator.py @@ -329,21 +329,18 @@ async def run_migrations_with_monitoring( "errors": [], } - if verbose: - print(f"Found {len(pending_migrations)} pending migration(s):") - for migration in pending_migrations: - print(f"- {migration.migration_id}: {migration.description}") + self._log_pending_migrations(pending_migrations, verbose) if dry_run: if verbose: print("Dry run mode - no changes will be applied.") - return { - "applied_count": len(pending_migrations), - "total_migrations": len(pending_migrations), - "performance_stats": monitor.get_stats(), - "errors": [], - "dry_run": True, - } + return self._build_monitoring_result( + applied_count=len(pending_migrations), + pending_migrations=pending_migrations, + monitor=monitor, + errors=[], + dry_run=True, + ) applied_count = 0 errors = [] @@ -400,6 +397,33 @@ async def run_migrations_with_monitoring( monitor.finish() + result = self._build_monitoring_result( + applied_count=applied_count, + pending_migrations=pending_migrations, + monitor=monitor, + errors=errors, + ) + + self._log_monitoring_summary(result, verbose) + + return result + + def _log_pending_migrations( + self, pending_migrations: List[BaseMigration], verbose: bool + ) -> None: + if verbose: + print(f"Found {len(pending_migrations)} pending migration(s):") + for migration in pending_migrations: + print(f"- {migration.migration_id}: {migration.description}") + + def _build_monitoring_result( + self, + applied_count: int, + pending_migrations: List[BaseMigration], + monitor: PerformanceMonitor, + errors: List[Dict[str, Any]], + dry_run: bool = False, + ) -> Dict[str, Any]: result = { "applied_count": applied_count, "total_migrations": len(pending_migrations), @@ -412,18 +436,28 @@ async def run_migrations_with_monitoring( ), } + if dry_run: + result["dry_run"] = True + + return result + + def _log_monitoring_summary( + self, result: Dict[str, Any], verbose: bool + ) -> None: if verbose: - print(f"Applied {applied_count}/{len(pending_migrations)} migration(s).") + total_migrations = result["total_migrations"] + applied_count = result["applied_count"] + print(f"Applied {applied_count}/{total_migrations} migration(s).") stats = result["performance_stats"] if stats: print(f"Total time: {stats.get('total_time_seconds', 0):.2f}s") if "items_per_second" in stats: # type: ignore - print(f"Performance: {stats['items_per_second']:.1f} items/second") # type: ignore + print( + f"Performance: {stats['items_per_second']:.1f} items/second" + ) # type: ignore if "peak_memory_mb" in stats: # type: ignore print(f"Peak memory: {stats['peak_memory_mb']:.1f} MB") # type: ignore - return result - async def rollback_migration( self, migration_id: str, dry_run: bool = False, verbose: bool = False ) -> bool: diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 07f6de14..4cf8b0a2 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1541,6 +1541,161 @@ def expand_tag_value(value): ) return escaper.escape(str(value)) + @staticmethod + def _convert_numeric_value(value): + if isinstance(value, Enum): + value = value.value + if isinstance(value, (datetime.datetime, datetime.date)): + if isinstance(value, datetime.date) and not isinstance( + value, datetime.datetime + ): + value = datetime.datetime.combine( + value, datetime.time.min, tzinfo=datetime.timezone.utc + ) + value = value.timestamp() + return value + + @classmethod + def _resolve_text_value(cls, field_name: str, op: Operators, value: Any) -> str: + result = f"@{field_name}_fts:" + if op is Operators.EQ: + return f'{result}"{value}"' + if op is Operators.NE: + return f'-({result}"{value}")' + if op is Operators.LIKE: + return f"{result}{value}" + raise QueryNotSupportedError( + "Only equals (=), not-equals (!=), and like() " + "comparisons are supported for TEXT fields. " + f"Docs: {ERRORS_URL}#E5" + ) + + @classmethod + def _resolve_numeric_value( + cls, field_name: str, op: Operators, value: Any + ) -> str: + if op is Operators.IN: + converted_values = [cls._convert_numeric_value(v) for v in value] + parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] + return "|".join(parts) + if op is Operators.NOT_IN: + converted_values = [cls._convert_numeric_value(v) for v in value] + parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] + return f"-({' | '.join(parts)})" + + value = cls._convert_numeric_value(value) + if op is Operators.EQ: + return f"@{field_name}:[{value} {value}]" + if op is Operators.NE: + return f"-(@{field_name}:[{value} {value}])" + if op is Operators.GT: + return f"@{field_name}:[({value} +inf]" + if op is Operators.LT: + return f"@{field_name}:[-inf ({value}]" + if op is Operators.GE: + return f"@{field_name}:[{value} +inf]" + if op is Operators.LE: + return f"@{field_name}:[-inf {value}]" + return "" + + @classmethod + def _resolve_tag_eq_value( + cls, + field_name: str, + field_info: PydanticFieldInfo, + value: Any, + model_class: Optional[Type["RedisModel"]], + ) -> str: + separator_char = getattr( + field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR + ) + if value == separator_char: + log.warning( + "Your query against the field %s is for a single character, %s, " + "that is used internally by Redis OM Python. We must ignore " + "this portion of the query. Please review your query to find " + "an alternative query that uses a string containing more than " + "just the character %s.", + field_name, + separator_char, + separator_char, + ) + return "" + if isinstance(value, bool): + if model_class: + is_hash_model = any( + base.__name__ == "HashModel" for base in model_class.__mro__ + ) + value = ("1" if value else "0") if is_hash_model else value + return "@{field_name}:{{{value}}}".format( + field_name=field_name, value=value + ) + if isinstance(value, int): + return f"@{field_name}:[{value} {value}]" + if separator_char in value: + values: filter = filter(None, value.split(separator_char)) + return "".join( + "@{field_name}:{{{value}}}".format( + field_name=field_name, value=escaper.escape(value) + ) + for value in values + ) + value = escaper.escape(value) + return "@{field_name}:{{{value}}}".format(field_name=field_name, value=value) + + @classmethod + def _resolve_tag_value( + cls, + field_name: str, + field_info: PydanticFieldInfo, + op: Operators, + value: Any, + model_class: Optional[Type["RedisModel"]], + ) -> str: + if op is Operators.EQ: + return cls._resolve_tag_eq_value(field_name, field_info, value, model_class) + if op is Operators.NE: + value = escaper.escape(value) + return "-(@{field_name}:{{{value}}})".format( + field_name=field_name, value=value + ) + if op is Operators.IN: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.NOT_IN: + expanded_value = cls.expand_tag_value(value) + return "-(@{field_name}:{{{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.STARTSWITH: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{{expanded_value}*}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.ENDSWITH: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{*{expanded_value}}})".format( + field_name=field_name, expanded_value=expanded_value + ) + if op is Operators.CONTAINS: + expanded_value = cls.expand_tag_value(value) + return "(@{field_name}:{{*{expanded_value}*}})".format( + field_name=field_name, expanded_value=expanded_value + ) + return "" + + @staticmethod + def _resolve_geo_value(field_name: str, op: Operators, value: Any) -> str: + if not isinstance(value, GeoFilter): + raise QuerySyntaxError( + "You can only use a GeoFilter object with a GEO field." + ) + if op is Operators.EQ: + return f"@{field_name}:[{value}]" + return "" + @classmethod def resolve_value( cls, @@ -1553,163 +1708,22 @@ def resolve_value( model_class: Optional[Type["RedisModel"]] = None, ) -> str: # The 'field_name' should already include the correct prefix - result = "" if parents: prefix = "_".join([p[0] for p in parents]) field_name = f"{prefix}_{field_name}" if field_type is RediSearchFieldTypes.TEXT: - result = f"@{field_name}_fts:" - if op is Operators.EQ: - result += f'"{value}"' - elif op is Operators.NE: - result = f'-({result}"{value}")' - elif op is Operators.LIKE: - result += value - else: - raise QueryNotSupportedError( - "Only equals (=), not-equals (!=), and like() " - "comparisons are supported for TEXT fields. " - f"Docs: {ERRORS_URL}#E5" - ) + return cls._resolve_text_value(field_name, op, value) elif field_type is RediSearchFieldTypes.NUMERIC: - # Helper to convert a single value for NUMERIC queries - def convert_numeric_value(v): - # Convert Enum to its value (fixes #108) - if isinstance(v, Enum): - v = v.value - # Convert datetime objects to timestamps - if isinstance(v, (datetime.datetime, datetime.date)): - if isinstance(v, datetime.date) and not isinstance( - v, datetime.datetime - ): - # Use UTC midnight so query conversion matches storage conversion. - v = datetime.datetime.combine( - v, datetime.time.min, tzinfo=datetime.timezone.utc - ) - v = v.timestamp() - return v - - if op is Operators.IN: - # Handle IN operator for NUMERIC fields (fixes #499) - # Convert each value and create OR of range queries - converted_values = [convert_numeric_value(v) for v in value] - parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] - result += "|".join(parts) - elif op is Operators.NOT_IN: - # Handle NOT_IN operator for NUMERIC fields - converted_values = [convert_numeric_value(v) for v in value] - parts = [f"(@{field_name}:[{v} {v}])" for v in converted_values] - result += f"-({' | '.join(parts)})" - else: - value = convert_numeric_value(value) - - if op is Operators.EQ: - result += f"@{field_name}:[{value} {value}]" - elif op is Operators.NE: - result += f"-(@{field_name}:[{value} {value}])" - elif op is Operators.GT: - result += f"@{field_name}:[({value} +inf]" - elif op is Operators.LT: - result += f"@{field_name}:[-inf ({value}]" - elif op is Operators.GE: - result += f"@{field_name}:[{value} +inf]" - elif op is Operators.LE: - result += f"@{field_name}:[-inf {value}]" + return cls._resolve_numeric_value(field_name, op, value) # TODO: How will we know the difference between a multi-value use of a TAG # field and our hidden use of TAG for exact-match queries? elif field_type is RediSearchFieldTypes.TAG: - if op is Operators.EQ: - separator_char = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR - ) - if value == separator_char: - # The value is ONLY the TAG field separator character -- - # this is not going to work. - log.warning( - "Your query against the field %s is for a single character, %s, " - "that is used internally by Redis OM Python. We must ignore " - "this portion of the query. Please review your query to find " - "an alternative query that uses a string containing more than " - "just the character %s.", - field_name, - separator_char, - separator_char, - ) - return "" - if isinstance(value, bool): - # For HashModel, convert boolean to "1"/"0" to match storage format - # For JsonModel, keep as boolean since JSON supports native booleans - if model_class: - # Check if this is a HashModel by checking the class hierarchy - is_hash_model = any( - base.__name__ == "HashModel" for base in model_class.__mro__ - ) - bool_value = ("1" if value else "0") if is_hash_model else value - else: - bool_value = value - result = "@{field_name}:{{{value}}}".format( - field_name=field_name, value=bool_value - ) - elif isinstance(value, int): - # This if will hit only if the field is a primary key of type int - result = f"@{field_name}:[{value} {value}]" - elif separator_char in value: - # The value contains the TAG field separator. We can work - # around this by breaking apart the values and unioning them - # with multiple field:{} queries. - values: filter = filter(None, value.split(separator_char)) - for value in values: - value = escaper.escape(value) - result += "@{field_name}:{{{value}}}".format( - field_name=field_name, value=value - ) - else: - value = escaper.escape(value) - result += "@{field_name}:{{{value}}}".format( - field_name=field_name, value=value - ) - elif op is Operators.NE: - value = escaper.escape(value) - result += "-(@{field_name}:{{{value}}})".format( - field_name=field_name, value=value - ) - elif op is Operators.IN: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.NOT_IN: - # TODO: Implement NOT_IN, test this... - expanded_value = cls.expand_tag_value(value) - result += "-(@{field_name}:{{{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.STARTSWITH: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{{expanded_value}*}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.ENDSWITH: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{*{expanded_value}}})".format( - field_name=field_name, expanded_value=expanded_value - ) - elif op is Operators.CONTAINS: - expanded_value = cls.expand_tag_value(value) - result += "(@{field_name}:{{*{expanded_value}*}})".format( - field_name=field_name, expanded_value=expanded_value - ) + return cls._resolve_tag_value(field_name, field_info, op, value, model_class) elif field_type is RediSearchFieldTypes.GEO: - if not isinstance(value, GeoFilter): - raise QuerySyntaxError( - "You can only use a GeoFilter object with a GEO field." - ) + return cls._resolve_geo_value(field_name, op, value) - if op is Operators.EQ: - result += f"@{field_name}:[{value}]" - - return result + return "" def resolve_redisearch_pagination(self): """Resolve pagination options for a query.""" @@ -1835,9 +1849,7 @@ def _resolve_redisearch_query(self, expression: ExpressionOrNegated) -> str: return result - async def execute( - self, exhaust_results=True, return_raw_result=False, return_query_args=False - ): + def _build_execute_args(self) -> tuple[List[Union[str, bytes]], bool]: args: List[Union[str, bytes]] = [ "FT.SEARCH", self.model.Meta.index_name, @@ -1862,29 +1874,22 @@ async def execute( if self.nocontent: args.append("NOCONTENT") - # Check if we have complex fields that RediSearch RETURN clause can't handle - use_full_document_fallback = False - if self.projected_fields: - use_full_document_fallback = self._has_complex_projected_fields() + use_full_document_fallback = bool( + self.projected_fields and self._has_complex_projected_fields() + ) - # Add RETURN clause to the args list, not to the query string - # Skip RETURN clause if we need full documents for complex field projection + # Add RETURN clause to the args list, not to the query string. + # Skip RETURN clause if we need full documents for complex field projection. if self.projected_fields and not use_full_document_fallback: args.extend( ["RETURN", str(len(self.projected_fields))] + self.projected_fields ) - if return_query_args: - return self.model.Meta.index_name, args - - # Reset the cache if we're executing from offset 0. - if self.offset == 0: - self._model_cache.clear() + return args, use_full_document_fallback - # If the offset is greater than 0, we're paginating through a result set, - # so append the new results to results already in the cache. + async def _execute_command(self, args: List[Union[str, bytes]]): try: - raw_result = await self.model.db().execute_command(*args) + return await self.model.db().execute_command(*args) except Exception as e: error_msg = str(e).lower() @@ -1900,32 +1905,55 @@ async def execute( # Re-raise the original exception raise - if return_raw_result: - return raw_result - count = raw_result[0] - # Handle different result processing based on what was requested + async def _parse_execute_results( + self, raw_result: Any, use_full_document_fallback: bool + ): if self.projected_fields and use_full_document_fallback: # Complex field projection - use full document fallback if self.return_as_dict: - results = await self._parse_full_document_projection_as_dict(raw_result) - else: - results = await self._parse_full_document_projection_as_models( - raw_result - ) - elif self.projected_fields and self.return_as_dict: + return await self._parse_full_document_projection_as_dict(raw_result) + return await self._parse_full_document_projection_as_models(raw_result) + + if self.projected_fields and self.return_as_dict: # .values('field1', 'field2') - specific fields as dicts - results = self._parse_projected_results(raw_result) - elif self.projected_fields and not self.return_as_dict: + return self._parse_projected_results(raw_result) + + if self.projected_fields and not self.return_as_dict: # .only('field1', 'field2') - partial model instances - results = self._parse_projected_models(raw_result) - elif self.return_as_dict and not self.projected_fields: + return self._parse_projected_models(raw_result) + + if self.return_as_dict and not self.projected_fields: # .values() - all fields as dicts model_results = self.model.from_redis(raw_result, self.knn) - results = [model.model_dump() for model in model_results] - else: - # Normal query - full model instances - results = self.model.from_redis(raw_result, self.knn) + return [model.model_dump() for model in model_results] + + # Normal query - full model instances + return self.model.from_redis(raw_result, self.knn) + + async def execute( + self, exhaust_results=True, return_raw_result=False, return_query_args=False + ): + args, use_full_document_fallback = self._build_execute_args() + + if return_query_args: + return self.model.Meta.index_name, args + + # Reset the cache if we're executing from offset 0. + if self.offset == 0: + self._model_cache.clear() + + # If the offset is greater than 0, we're paginating through a result set, + # so append the new results to results already in the cache. + raw_result = await self._execute_command(args) + + if return_raw_result: + return raw_result + count = raw_result[0] + + results = await self._parse_execute_results( + raw_result, use_full_document_fallback + ) self._model_cache += results if not exhaust_results: @@ -2409,44 +2437,30 @@ class ModelMeta(ModelMetaclass): model_config: RedisOmConfig model_fields: Dict[str, FieldInfo] # type: ignore[assignment] - def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 - meta = attrs.pop("Meta", None) - - # Capture original FieldInfo objects from attrs before Pydantic processes them. - # Pydantic 2.12+ may convert custom FieldInfo subclasses to plain PydanticFieldInfo - # for Annotated types, losing custom attributes like index, sortable, etc. + @staticmethod + def _capture_original_field_infos(attrs): original_field_infos: Dict[str, FieldInfo] = {} if PYDANTIC_V2: for attr_name, attr_value in attrs.items(): if isinstance(attr_value, FieldInfo): original_field_infos[attr_name] = attr_value + return original_field_infos - # Duplicate logic from Pydantic to filter config kwargs because if they are - # passed directly including the registry Pydantic will pass them over to the - # superclass causing an error - allowed_config_kwargs: Set[str] = { + @staticmethod + def _allowed_config_kwargs(): + return { key for key in dir(ConfigDict) - if not ( - key.startswith("__") and key.endswith("__") - ) # skip dunder methods and attributes + if not (key.startswith("__") and key.endswith("__")) } - config_kwargs = { - key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs - } - - new_class: RedisModel = super().__new__( - cls, name, bases, attrs, **config_kwargs - ) - + @staticmethod + def _build_meta(new_class, meta, base_meta): # The fact that there is a Meta field and _meta field is important: a # user may have given us a Meta object with their configuration, while # we might have inherited _meta from a parent class, and should # therefore use some of the inherited fields. meta = meta or getattr(new_class, "Meta", None) - base_meta = getattr(new_class, "_meta", None) - if meta and meta != DefaultMeta and meta != base_meta: new_class.Meta = meta new_class._meta = meta @@ -2465,13 +2479,8 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 ) new_class.Meta = new_class._meta - is_indexed = kwargs.get("index", None) is True - - if is_indexed and new_class.model_config.get("index", None) is True: - raise RedisModelError( - f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" - ) - + @staticmethod + def _set_index_config(new_class, is_indexed): if PYDANTIC_V2: new_class.model_config["index"] = is_indexed else: @@ -2485,14 +2494,20 @@ class Config: new_class.Config = Config - # Create proxies for each model field so that we can use the field - # in queries, like Model.get(Model.field_name == 1) - # Only set if the model is has index=True + @staticmethod + def _model_fields(new_class): if PYDANTIC_V2: - model_fields = new_class.model_fields - else: - model_fields = new_class.__fields__ + return new_class.model_fields + return new_class.__fields__ + @staticmethod + def _apply_field_proxies( + new_class, + model_fields, + original_field_infos, + is_indexed, + ): + custom_pk_count = 0 for field_name, field in model_fields.items(): pydantic_field = field # Keep reference to Pydantic's processed field if type(field) is PydanticFieldInfo: @@ -2541,7 +2556,6 @@ class Config: # should set up the PrimaryKeyAccessor. We only do this when there's # exactly one custom primary key. Multiple custom primary keys will be # caught by validate_primary_key() later. - custom_pk_count = 0 for field_name, field in model_fields.items(): if field_name == "pk": continue @@ -2557,6 +2571,10 @@ class Config: if getattr(check_field, "primary_key", None) is True: custom_pk_count += 1 + return custom_pk_count + + @staticmethod + def _setup_primary_key_accessor(new_class, model_fields, custom_pk_count): # If there's exactly one custom primary key (not the default 'pk'), set up # a PrimaryKeyAccessor so that .pk always returns the correct value. # This fixes GitHub issue #570. @@ -2572,6 +2590,8 @@ class Config: # Set up PrimaryKeyAccessor descriptor for .pk access setattr(new_class, "pk", PrimaryKeyAccessor()) + @staticmethod + def _apply_embedded_model_rules(new_class): # For embedded models, clear the primary_key from meta since they don't # need primary keys - they're stored as part of their parent document, # not as separate Redis keys. This fixes GitHub issue #496. @@ -2580,6 +2600,8 @@ class Config: if getattr(new_class._meta, "embedded", False): new_class._meta.primary_key = None + @staticmethod + def _apply_meta_defaults(new_class, base_meta): if not getattr(new_class._meta, "global_key_prefix", None): new_class._meta.global_key_prefix = getattr( base_meta, "global_key_prefix", "" @@ -2610,6 +2632,8 @@ class Config: f"{new_class._meta.model_key_prefix}:index" ) + @staticmethod + def _register_indexed_model(new_class, bases): # Model is indexed and not an abstract model class or embedded model, so we should let the # Migrator create indexes for it. if ( @@ -2620,6 +2644,52 @@ class Config: key = f"{new_class.__module__}.{new_class.__qualname__}" model_registry[key] = new_class + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 + meta = attrs.pop("Meta", None) + + # Capture original FieldInfo objects from attrs before Pydantic processes them. + # Pydantic 2.12+ may convert custom FieldInfo subclasses to plain PydanticFieldInfo + # for Annotated types, losing custom attributes like index, sortable, etc. + original_field_infos = cls._capture_original_field_infos(attrs) + + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs = cls._allowed_config_kwargs() + + config_kwargs = {key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs} + + new_class: RedisModel = super().__new__( + cls, name, bases, attrs, **config_kwargs + ) + + base_meta = getattr(new_class, "_meta", None) + cls._build_meta(new_class, meta, base_meta) + + is_indexed = kwargs.get("index", None) is True + + if is_indexed and new_class.model_config.get("index", None) is True: + raise RedisModelError( + f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" + ) + + cls._set_index_config(new_class, is_indexed) + + # Create proxies for each model field so that we can use the field + # in queries, like Model.get(Model.field_name == 1) + # Only set if the model is has index=True + model_fields = cls._model_fields(new_class) + custom_pk_count = cls._apply_field_proxies( + new_class, + model_fields, + original_field_infos, + is_indexed, + ) + cls._setup_primary_key_accessor(new_class, model_fields, custom_pk_count) + cls._apply_embedded_model_rules(new_class) + cls._apply_meta_defaults(new_class, base_meta) + cls._register_indexed_model(new_class, bases) + return new_class @@ -3713,161 +3783,182 @@ def schema_for_type( # into the values of the list or the fields of the model to # find any values marked as indexed. if is_container_type and not is_vector: - field_type = get_origin(typ) - if field_type == Literal: - path = f"{json_path}.{name}" - return cls.schema_for_type( - path, - name, - name_prefix, - str, - field_info, - parent_type=field_type, - ) - else: - embedded_cls = get_args(typ) - if not embedded_cls: - log.warning( - "Model %s defined an empty list or tuple field: %s", cls, name - ) - return "" - path = f"{json_path}.{name}[*]" - embedded_cls = embedded_cls[0] - return cls.schema_for_type( - path, - name, - name_prefix, - embedded_cls, - field_info, - parent_type=field_type, - ) + return cls._schema_for_container_type( + json_path, name, name_prefix, typ, field_info + ) elif field_is_model: - name_prefix = f"{name_prefix}_{name}" if name_prefix else name - sub_fields = [] - for embedded_name, field in typ.model_fields.items(): - if ( - hasattr(field, "metadata") - and len(field.metadata) > 0 - and isinstance(field.metadata[0], FieldInfo) - ): - field_info = field.metadata[0] - else: - field_info = field - - if parent_is_container_type: - # We'll store this value either as a JavaScript array, so - # the correct JSONPath expression is to refer directly to - # attribute names after the container notation, e.g. - # orders[*].created_date. - path = json_path - else: - # All other fields should use dot notation with both the - # current field name and "embedded" field name, e.g., - # order.address.street_line_1. - path = f"{json_path}.{name}" - sub_fields.append( - cls.schema_for_type( - path, - embedded_name, - name_prefix, - # field.annotation, - get_outer_type(field), - field_info, - parent_type=typ, - ) - ) - return " ".join(filter(None, sub_fields)) + return cls._schema_for_embedded_model( + json_path, name, name_prefix, typ, parent_is_container_type + ) # NOTE: This is the termination point for recursion. We've descended # into models and lists until we found an actual value to index. elif should_index: - index_field_name = f"{name_prefix}_{name}" if name_prefix else name - if parent_is_container_type: - # If we're indexing the this field as a JavaScript array, then - # the currently built-up JSONPath expression will be - # "field_name[*]", which is what we want to use. - path = json_path + return cls._schema_for_leaf_type( + json_path, + name, + name_prefix, + typ, + field_info, + is_vector, + vector_options, + parent_is_container_type, + parent_is_model_in_container, + ) + return "" + + @classmethod + def _schema_for_container_type( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Any, + field_info: PydanticFieldInfo, + ) -> str: + field_type = get_origin(typ) + if field_type == Literal: + path = f"{json_path}.{name}" + return cls.schema_for_type( + path, + name, + name_prefix, + str, + field_info, + parent_type=field_type, + ) + + embedded_cls = get_args(typ) + if not embedded_cls: + log.warning("Model %s defined an empty list or tuple field: %s", cls, name) + return "" + + path = f"{json_path}.{name}[*]" + return cls.schema_for_type( + path, + name, + name_prefix, + embedded_cls[0], + field_info, + parent_type=field_type, + ) + + @classmethod + def _schema_for_embedded_model( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Union[Type[RedisModel], Any], + parent_is_container_type: bool, + ) -> str: + name_prefix = f"{name_prefix}_{name}" if name_prefix else name + sub_fields = [] + for embedded_name, field in typ.model_fields.items(): + if ( + hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field_info = field.metadata[0] else: - path = f"{json_path}.{name}" - sortable = getattr(field_info, "sortable", False) - case_sensitive = getattr(field_info, "case_sensitive", False) - full_text_search = getattr(field_info, "full_text_search", False) - - # For more complicated compound validators (e.g. PositiveInt), we might get a _GenericAlias rather than - # a proper type, we can pull the type information from the origin of the first argument. - if not isinstance(typ, type): - type_args = typing_get_args(field_info.annotation) - typ = ( - getattr(type_args[0], "__origin__", type_args[0]) - if type_args - else typ - ) + field_info = field - # Get separator from field_info, defaulting to pipe - separator = getattr( - field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR + # Lists store embedded values directly under the container path. + path = json_path if parent_is_container_type else f"{json_path}.{name}" + sub_fields.append( + cls.schema_for_type( + path, + embedded_name, + name_prefix, + get_outer_type(field), + field_info, + parent_type=typ, + ) ) + return " ".join(filter(None, sub_fields)) - if is_vector and vector_options: - schema = f"{path} AS {index_field_name} {vector_options.schema}" - elif parent_is_container_type or parent_is_model_in_container: - if typ is not str: - raise RedisModelError( - "List and tuple fields can only contain strings. " - f"Problem field: {name}. Docs: {ERRORS_URL}#E12" - ) - if full_text_search is True: - raise RedisModelError( - "List and tuple fields cannot be indexed for full-text " - f"search. Problem field: {name}. Docs: {ERRORS_URL}#E13" - ) - # List/tuple fields are indexed as TAG fields and can be sortable - schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + @classmethod + def _schema_for_leaf_type( + cls, + json_path: str, + name: str, + name_prefix: str, + typ: Union[Type[RedisModel], Any], + field_info: PydanticFieldInfo, + is_vector: bool, + vector_options: Optional[VectorFieldOptions], + parent_is_container_type: bool, + parent_is_model_in_container: bool, + ) -> str: + index_field_name = f"{name_prefix}_{name}" if name_prefix else name + path = json_path if parent_is_container_type else f"{json_path}.{name}" + sortable = getattr(field_info, "sortable", False) + case_sensitive = getattr(field_info, "case_sensitive", False) + full_text_search = getattr(field_info, "full_text_search", False) + + # For more complicated compound validators (e.g. PositiveInt), we might + # get a _GenericAlias rather than a proper type. + if not isinstance(typ, type): + type_args = typing_get_args(field_info.annotation) + typ = getattr(type_args[0], "__origin__", type_args[0]) if type_args else typ + + separator = getattr(field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR) + + if is_vector and vector_options: + schema = f"{path} AS {index_field_name} {vector_options.schema}" + elif parent_is_container_type or parent_is_model_in_container: + if typ is not str: + raise RedisModelError( + "List and tuple fields can only contain strings. " + f"Problem field: {name}. Docs: {ERRORS_URL}#E12" + ) + if full_text_search is True: + raise RedisModelError( + "List and tuple fields cannot be indexed for full-text " + f"search. Problem field: {name}. Docs: {ERRORS_URL}#E13" + ) + schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + if sortable is True: + schema += " SORTABLE" + if case_sensitive is True: + schema += " CASESENSITIVE" + elif typ is bool: + schema = f"{path} AS {index_field_name} TAG" + if sortable is True: + schema += " SORTABLE" + elif typ in [CoordinateType, Coordinates]: + schema = f"{path} AS {index_field_name} GEO" + if sortable is True: + schema += " SORTABLE" + elif is_numeric_type(typ): + schema = f"{path} AS {index_field_name} NUMERIC" + if sortable is True: + schema += " SORTABLE" + elif issubclass(typ, str): + if full_text_search is True: + schema = ( + f"{path} AS {index_field_name} TAG SEPARATOR {separator} " + f"{path} AS {index_field_name}_fts TEXT" + ) if sortable is True: + # NOTE: With the current preview release, making a field + # full-text searchable and sortable only makes the TEXT + # field sortable. schema += " SORTABLE" if case_sensitive is True: - schema += " CASESENSITIVE" - elif typ is bool: - schema = f"{path} AS {index_field_name} TAG" - if sortable is True: - schema += " SORTABLE" - elif typ in [CoordinateType, Coordinates]: - schema = f"{path} AS {index_field_name} GEO" - if sortable is True: - schema += " SORTABLE" - elif is_numeric_type(typ): - schema = f"{path} AS {index_field_name} NUMERIC" - if sortable is True: - schema += " SORTABLE" - elif issubclass(typ, str): - if full_text_search is True: - schema = ( - f"{path} AS {index_field_name} TAG SEPARATOR {separator} " - f"{path} AS {index_field_name}_fts TEXT" - ) - if sortable is True: - # NOTE: With the current preview release, making a field - # full-text searchable and sortable only makes the TEXT - # field sortable. This means that results for full-text - # search queries can be sorted, but not exact match - # queries. - schema += " SORTABLE" - if case_sensitive is True: - raise RedisModelError("Text fields cannot be case-sensitive.") - else: - # String fields are indexed as TAG fields and can be sortable - schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" - if sortable is True: - schema += " SORTABLE" - if case_sensitive is True: - schema += " CASESENSITIVE" + raise RedisModelError("Text fields cannot be case-sensitive.") else: - # Default to TAG field, which can be sortable schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" if sortable is True: schema += " SORTABLE" + if case_sensitive is True: + schema += " CASESENSITIVE" + else: + schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" + if sortable is True: + schema += " SORTABLE" - return schema - return "" + return schema class EmbeddedJsonModel(JsonModel, abc.ABC):