Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 81 additions & 68 deletions aredis_om/model/migrations/data/builtin/datetime_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,79 @@ async def _save_progress_if_needed(self, current_model: str, total_keys: int):
stats=self.stats.get_summary(),
)

async def _collect_hash_keys(self, model_class) -> List[str]:
"""Collect all Redis hash keys for a model."""
key_pattern = model_class.make_key("*")
all_keys = []
scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH")

async for key in scan_iter: # type: ignore[misc]
if isinstance(key, bytes):
key = key.decode("utf-8")
all_keys.append(key)

return all_keys

@staticmethod
def _normalize_hash_data(hash_data: Dict[Any, Any]) -> Dict[str, Any]:
"""Normalize hash payload to string keys and values."""
if hash_data and isinstance(next(iter(hash_data.keys())), bytes):
return {
k.decode("utf-8"): v.decode("utf-8")
for k, v in hash_data.items()
}

return hash_data

async def _process_hash_key(
self,
key: str,
datetime_fields: List[str],
model_name: str,
total_keys: int,
) -> bool:
"""Process a single hash key and return whether it was handled."""
if key in self.processed_keys_set:
return False

try:
hash_data = await self.redis.hgetall(key) # type: ignore[misc]
except Exception as e:
log.warning(f"Failed to get hash data from {key}: {e}")
return False

if not hash_data:
return False

hash_data = self._normalize_hash_data(hash_data)
updates = {}

for field_name in datetime_fields:
if field_name in hash_data:
value = hash_data[field_name]
converted, success = self._safe_convert_datetime_value(
key, field_name, value
)

if success and converted != value:
updates[field_name] = str(converted)

if updates:
try:
await self.redis.hset(key, mapping=updates) # type: ignore[misc]
except Exception as e:
log.error(f"Failed to update hash {key}: {e}")
if self.failure_mode == ConversionFailureMode.FAIL:
raise DataMigrationError(f"Failed to update hash {key}: {e}")

self.processed_keys_set.add(key)
self.stats.add_processed_key()
self._processed_keys += 1

self._check_error_threshold()
await self._save_progress_if_needed(model_name, total_keys)
return True

async def _clear_progress_on_completion(self):
"""Clear saved progress when migration completes successfully."""
if self.migration_state:
Expand Down Expand Up @@ -504,16 +577,7 @@ async def _process_hash_model(
self, model_class, datetime_fields: List[str]
) -> None:
"""Process HashModel instances to convert datetime fields with enhanced error handling."""
# Get all keys for this model
key_pattern = model_class.make_key("*")

# Collect all keys first for batch processing
all_keys = []
scan_iter = self.redis.scan_iter(match=key_pattern, _type="HASH")
async for key in scan_iter: # type: ignore[misc]
if isinstance(key, bytes):
key = key.decode("utf-8")
all_keys.append(key)
all_keys = await self._collect_hash_keys(model_class)

total_keys = len(all_keys)
log.info(
Expand All @@ -531,64 +595,13 @@ async def _process_hash_model(

for key in batch_keys:
try:
# Skip if already processed (resume capability)
if key in self.processed_keys_set:
continue

# Get all fields from the hash
try:
hash_data = await self.redis.hgetall(key) # type: ignore[misc]
except Exception as e:
log.warning(f"Failed to get hash data from {key}: {e}")
continue

if not hash_data:
continue

# Convert byte keys/values to strings if needed
if hash_data and isinstance(next(iter(hash_data.keys())), bytes):
hash_data = {
k.decode("utf-8"): v.decode("utf-8")
for k, v in hash_data.items()
}

updates = {}

# Check each datetime field with safe conversion
for field_name in datetime_fields:
if field_name in hash_data:
value = hash_data[field_name]
converted, success = self._safe_convert_datetime_value(
key, field_name, value
)

if success and converted != value:
updates[field_name] = str(converted)

# Update the hash if we have changes
if updates:
try:
await self.redis.hset(key, mapping=updates) # type: ignore[misc]
except Exception as e:
log.error(f"Failed to update hash {key}: {e}")
if self.failure_mode == ConversionFailureMode.FAIL:
raise DataMigrationError(
f"Failed to update hash {key}: {e}"
)

# Mark key as processed
self.processed_keys_set.add(key)
self.stats.add_processed_key()
self._processed_keys += 1
processed_count += 1

# Error threshold checking
self._check_error_threshold()

# Save progress periodically
await self._save_progress_if_needed(
model_class.__name__, total_keys
)
if await self._process_hash_key(
key,
datetime_fields,
model_class.__name__,
total_keys,
):
processed_count += 1

except DataMigrationError:
# Re-raise migration errors
Expand Down
64 changes: 49 additions & 15 deletions aredis_om/model/migrations/data/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,21 +329,18 @@ async def run_migrations_with_monitoring(
"errors": [],
}

if verbose:
print(f"Found {len(pending_migrations)} pending migration(s):")
for migration in pending_migrations:
print(f"- {migration.migration_id}: {migration.description}")
self._log_pending_migrations(pending_migrations, verbose)

if dry_run:
if verbose:
print("Dry run mode - no changes will be applied.")
return {
"applied_count": len(pending_migrations),
"total_migrations": len(pending_migrations),
"performance_stats": monitor.get_stats(),
"errors": [],
"dry_run": True,
}
return self._build_monitoring_result(
applied_count=len(pending_migrations),
pending_migrations=pending_migrations,
monitor=monitor,
errors=[],
dry_run=True,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dry-run result now includes extra success_rate field

Low Severity

The dry-run path now returns a success_rate key in its result dictionary that wasn't present before the refactoring. The original dry-run result contained only applied_count, total_migrations, performance_stats, errors, and dry_run. By routing both the dry-run and normal paths through _build_monitoring_result, the dry-run result now also includes success_rate, changing the API contract.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit b80c1a6. Configure here.


applied_count = 0
errors = []
Expand Down Expand Up @@ -400,6 +397,33 @@ async def run_migrations_with_monitoring(

monitor.finish()

result = self._build_monitoring_result(
applied_count=applied_count,
pending_migrations=pending_migrations,
monitor=monitor,
errors=errors,
)

self._log_monitoring_summary(result, verbose)

return result

def _log_pending_migrations(
self, pending_migrations: List[BaseMigration], verbose: bool
) -> None:
if verbose:
print(f"Found {len(pending_migrations)} pending migration(s):")
for migration in pending_migrations:
print(f"- {migration.migration_id}: {migration.description}")

def _build_monitoring_result(
self,
applied_count: int,
pending_migrations: List[BaseMigration],
monitor: PerformanceMonitor,
errors: List[Dict[str, Any]],
dry_run: bool = False,
) -> Dict[str, Any]:
result = {
"applied_count": applied_count,
"total_migrations": len(pending_migrations),
Expand All @@ -412,18 +436,28 @@ async def run_migrations_with_monitoring(
),
}

if dry_run:
result["dry_run"] = True

return result

def _log_monitoring_summary(
self, result: Dict[str, Any], verbose: bool
) -> None:
if verbose:
print(f"Applied {applied_count}/{len(pending_migrations)} migration(s).")
total_migrations = result["total_migrations"]
applied_count = result["applied_count"]
print(f"Applied {applied_count}/{total_migrations} migration(s).")
stats = result["performance_stats"]
if stats:
print(f"Total time: {stats.get('total_time_seconds', 0):.2f}s")
if "items_per_second" in stats: # type: ignore
print(f"Performance: {stats['items_per_second']:.1f} items/second") # type: ignore
print(
f"Performance: {stats['items_per_second']:.1f} items/second"
) # type: ignore
if "peak_memory_mb" in stats: # type: ignore
print(f"Peak memory: {stats['peak_memory_mb']:.1f} MB") # type: ignore

return result

async def rollback_migration(
self, migration_id: str, dry_run: bool = False, verbose: bool = False
) -> bool:
Expand Down
Loading