Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 80 additions & 75 deletions aredis_om/model/migrations/data/builtin/datetime_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,13 @@ def __init__(
self.failure_mode = failure_mode
self.batch_size = batch_size
self.max_errors = max_errors
self.enable_resume = enable_resume
self.progress_save_interval = progress_save_interval
self.stats = MigrationStats()
self.migration_state = (
MigrationState(self.redis, self.migration_id) if enable_resume else None
)
self.processed_keys_set: Set[str] = set()

# Legacy compatibility
self._processed_keys = 0
self._converted_fields = 0

def _safe_convert_datetime_value(
self, key: str, field_name: str, value: Any
) -> Tuple[Any, bool]:
Expand Down Expand Up @@ -332,7 +327,6 @@ async def _load_previous_progress(self) -> bool:

if progress["processed_keys"]:
self.processed_keys_set = set(progress["processed_keys"])
self._processed_keys = len(self.processed_keys_set)

# Restore stats if available
if progress.get("stats"):
Expand Down Expand Up @@ -363,6 +357,78 @@ async def _save_progress_if_needed(self, current_model: str, total_keys: int):
stats=self.stats.get_summary(),
)

async def _collect_hash_keys(self, model_class) -> List[str]:
    """Return every Redis key of type HASH that matches the model's key pattern.

    The pattern is built with ``model_class.make_key("*")``. Keys that the
    client yields as ``bytes`` are decoded as UTF-8; all other keys are
    passed through unchanged. The full key list is materialized in memory
    before it is returned.
    """
    pattern = model_class.make_key("*")
    return [
        raw_key.decode("utf-8") if isinstance(raw_key, bytes) else raw_key
        async for raw_key in self.redis.scan_iter(  # type: ignore[misc]
            match=pattern, _type="HASH"
        )
    ]

@staticmethod
def _normalize_hash_data(hash_data: Dict[Any, Any]) -> Dict[str, Any]:
    """Return the hash payload with UTF-8 decoded field names and values.

    Only the first key is inspected: when it is ``bytes`` every key and
    every value is decoded; otherwise (including an empty payload) the
    input object itself is returned untouched.
    """
    if not hash_data:
        return hash_data

    first_field = next(iter(hash_data))
    if not isinstance(first_field, bytes):
        return hash_data

    decoded: Dict[str, Any] = {}
    for raw_field, raw_value in hash_data.items():
        decoded[raw_field.decode("utf-8")] = raw_value.decode("utf-8")
    return decoded

async def _process_hash_key(
    self,
    key: str,
    datetime_fields: List[str],
    model_name: str,
    total_keys: int,
) -> bool:
    """Convert the datetime fields of one hash and record the key as processed.

    Args:
        key: Redis key of the hash to migrate.
        datetime_fields: Field names to run through
            ``_safe_convert_datetime_value``.
        model_name: Model name forwarded to ``_save_progress_if_needed``.
        total_keys: Total key count forwarded to ``_save_progress_if_needed``.

    Returns:
        ``True`` when the key was added to ``processed_keys_set`` by this
        call. ``False`` when it was skipped: already processed, the read
        failed, or the hash was empty or missing.

    Raises:
        DataMigrationError: The write to Redis failed and ``failure_mode``
            is ``ConversionFailureMode.FAIL``. The Redis error is attached
            as ``__cause__``.
    """
    # Resume support: keys handled by an earlier run are not touched again.
    if key in self.processed_keys_set:
        return False

    try:
        hash_data = await self.redis.hgetall(key)  # type: ignore[misc]
    except Exception as e:
        # A failed read is not recorded as processed, so a later run can
        # pick the key up again.
        log.warning(f"Failed to get hash data from {key}: {e}")
        return False

    if not hash_data:
        return False

    hash_data = self._normalize_hash_data(hash_data)
    updates = {}

    for field_name in datetime_fields:
        if field_name not in hash_data:
            continue

        value = hash_data[field_name]
        converted, success = self._safe_convert_datetime_value(
            key, field_name, value
        )

        # Only write fields whose stored representation actually changes.
        if success and converted != value:
            updates[field_name] = str(converted)

    if updates:
        try:
            await self.redis.hset(key, mapping=updates)  # type: ignore[misc]
        except Exception as e:
            log.error(f"Failed to update hash {key}: {e}")
            if self.failure_mode == ConversionFailureMode.FAIL:
                # Chain explicitly so the Redis error is kept as the cause.
                raise DataMigrationError(
                    f"Failed to update hash {key}: {e}"
                ) from e
            # In any other failure mode the error is only logged and the
            # key is still marked processed below.

    self.processed_keys_set.add(key)
    self.stats.add_processed_key()

    self._check_error_threshold()
    await self._save_progress_if_needed(model_name, total_keys)
    return True

async def _clear_progress_on_completion(self):
"""Clear saved progress when migration completes successfully."""
if self.migration_state:
Expand Down Expand Up @@ -504,16 +570,7 @@ async def _process_hash_model(
self, model_class, datetime_fields: List[str]
) -> None:
"""Process HashModel instances to convert datetime fields with enhanced error handling."""
# Get all keys for this model
key_pattern = model_class.make_key("*")

# Collect all keys first for batch processing
all_keys = []
scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH")
async for key in scan_iter: # type: ignore[misc]
if isinstance(key, bytes):
key = key.decode("utf-8")
all_keys.append(key)
all_keys = await self._collect_hash_keys(model_class)

total_keys = len(all_keys)
log.info(
Expand All @@ -531,64 +588,13 @@ async def _process_hash_model(

for key in batch_keys:
try:
# Skip if already processed (resume capability)
if key in self.processed_keys_set:
continue

# Get all fields from the hash
try:
hash_data = await self.redis.hgetall(key) # type: ignore[misc]
except Exception as e:
log.warning(f"Failed to get hash data from {key}: {e}")
continue

if not hash_data:
continue

# Convert byte keys/values to strings if needed
if hash_data and isinstance(next(iter(hash_data.keys())), bytes):
hash_data = {
k.decode("utf-8"): v.decode("utf-8")
for k, v in hash_data.items()
}

updates = {}

# Check each datetime field with safe conversion
for field_name in datetime_fields:
if field_name in hash_data:
value = hash_data[field_name]
converted, success = self._safe_convert_datetime_value(
key, field_name, value
)

if success and converted != value:
updates[field_name] = str(converted)

# Update the hash if we have changes
if updates:
try:
await self.redis.hset(key, mapping=updates) # type: ignore[misc]
except Exception as e:
log.error(f"Failed to update hash {key}: {e}")
if self.failure_mode == ConversionFailureMode.FAIL:
raise DataMigrationError(
f"Failed to update hash {key}: {e}"
)

# Mark key as processed
self.processed_keys_set.add(key)
self.stats.add_processed_key()
self._processed_keys += 1
processed_count += 1

# Error threshold checking
self._check_error_threshold()

# Save progress periodically
await self._save_progress_if_needed(
model_class.__name__, total_keys
)
if await self._process_hash_key(
key,
datetime_fields,
model_class.__name__,
total_keys,
):
processed_count += 1

except DataMigrationError:
# Re-raise migration errors
Expand Down Expand Up @@ -677,7 +683,6 @@ async def _process_json_model(
# Mark key as processed
self.processed_keys_set.add(key)
self.stats.add_processed_key()
self._processed_keys += 1
processed_count += 1

# Error threshold checking
Expand Down
64 changes: 49 additions & 15 deletions aredis_om/model/migrations/data/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,21 +329,18 @@ async def run_migrations_with_monitoring(
"errors": [],
}

if verbose:
print(f"Found {len(pending_migrations)} pending migration(s):")
for migration in pending_migrations:
print(f"- {migration.migration_id}: {migration.description}")
self._log_pending_migrations(pending_migrations, verbose)

if dry_run:
if verbose:
print("Dry run mode - no changes will be applied.")
return {
"applied_count": len(pending_migrations),
"total_migrations": len(pending_migrations),
"performance_stats": monitor.get_stats(),
"errors": [],
"dry_run": True,
}
return self._build_monitoring_result(
applied_count=len(pending_migrations),
pending_migrations=pending_migrations,
monitor=monitor,
errors=[],
dry_run=True,
)

applied_count = 0
errors = []
Expand Down Expand Up @@ -400,6 +397,33 @@ async def run_migrations_with_monitoring(

monitor.finish()

result = self._build_monitoring_result(
applied_count=applied_count,
pending_migrations=pending_migrations,
monitor=monitor,
errors=errors,
)

self._log_monitoring_summary(result, verbose)

return result

def _log_pending_migrations(
    self, pending_migrations: List[BaseMigration], verbose: bool
) -> None:
    """Print a count line and one ``- id: description`` line per pending migration.

    Nothing is printed unless ``verbose`` is truthy.
    """
    if not verbose:
        return

    pending_total = len(pending_migrations)
    print(f"Found {pending_total} pending migration(s):")
    for migration in pending_migrations:
        print(f"- {migration.migration_id}: {migration.description}")

def _build_monitoring_result(
self,
applied_count: int,
pending_migrations: List[BaseMigration],
monitor: PerformanceMonitor,
errors: List[Dict[str, Any]],
dry_run: bool = False,
) -> Dict[str, Any]:
result = {
"applied_count": applied_count,
"total_migrations": len(pending_migrations),
Expand All @@ -412,18 +436,28 @@ async def run_migrations_with_monitoring(
),
}

if dry_run:
result["dry_run"] = True

return result

def _log_monitoring_summary(
    self, result: Dict[str, Any], verbose: bool
) -> None:
    """Print a human-readable summary of a monitored migration run.

    Args:
        result: Mapping that provides ``"applied_count"``,
            ``"total_migrations"`` and ``"performance_stats"``.
        verbose: When falsy, nothing is printed.

    The timing line is printed only when ``performance_stats`` is truthy;
    the throughput and memory lines only when their keys are present.
    """
    if not verbose:
        return

    # Everything is read from ``result``: this helper has no access to the
    # caller's ``applied_count`` / ``pending_migrations`` locals.
    total_migrations = result["total_migrations"]
    applied_count = result["applied_count"]
    print(f"Applied {applied_count}/{total_migrations} migration(s).")

    stats = result["performance_stats"]
    if not stats:
        return

    print(f"Total time: {stats.get('total_time_seconds', 0):.2f}s")
    if "items_per_second" in stats:  # type: ignore
        print(
            f"Performance: {stats['items_per_second']:.1f} items/second"
        )  # type: ignore
    if "peak_memory_mb" in stats:  # type: ignore
        print(f"Peak memory: {stats['peak_memory_mb']:.1f} MB")  # type: ignore

return result

async def rollback_migration(
self, migration_id: str, dry_run: bool = False, verbose: bool = False
) -> bool:
Expand Down
Loading