Skip to content

Add configurable anisotropic downsampling support to AutoencoderKL an…#8856

Open
shubham-61969 wants to merge 5 commits into
Project-MONAI:devfrom
shubham-61969:8447-autoencoderkl-kernel-stride
Open

Add configurable anisotropic downsampling support to AutoencoderKL an…#8856
shubham-61969 wants to merge 5 commits into
Project-MONAI:devfrom
shubham-61969:8447-autoencoderkl-kernel-stride

Conversation

@shubham-61969
Copy link
Copy Markdown
Contributor

…d relevant testcases

Fixes #8447.

Description

This PR adds configurable anisotropic downsampling support to AutoencoderKL.

Previously, AutoencoderKL hardcoded:

  • kernel_size=3
  • stride=2
  • isotropic downsampling assumptions
  • asymmetric padding logic coupled to the default configuration

This PR introduces configurable per-level and per-dimension downsampling parameters while preserving backward compatibility and encoder-decoder spatial consistency.

Key changes:

  • Added configurable downsampling parameters for AEKLDownsample

  • Added helper utilities for:

    • parameter normalization
    • validation
    • automatic padding computation
  • Added support for anisotropic configurations such as:

    • stride=(2,2,1)
    • kernel_size=(3,3,1)
  • Removed dependency on hardcoded asymmetric padding for configurable paths

  • Updated decoder upsampling to automatically mirror encoder downsampling configuration

  • Added validation for:

    • odd kernels only
    • valid tuple lengths
    • correct number of downsampling levels
  • Added comprehensive tests covering:

    • backward compatibility
    • anisotropic 2D/3D configurations
    • per-level configurations
    • reconstruction shape consistency
    • non-power-of-two spatial dimensions
    • invalid configuration handling

This is particularly useful for medical imaging workloads with anisotropic voxel spacing such as CT and MRI volumes.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

…d relevant testcases

Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds helpers to validate/normalize per-level kernel/stride/padding and deprecates AsymmetricPad. Refactors AEKLDownsample into a conv layer accepting per-level kernel/stride/padding and adds shape-recording/restoration hooks. Encoder/Decoder accept canonicalized downsample_parameters, the encoder records pre-downsample shapes, and the decoder restores exact sizes via _ShapeRestoringUpsample (with stride-derived fallbacks). AutoencoderKL exposes downsample_parameters and wires encoder-normalized params to the decoder. load_old_state_dict remaps legacy attention and upsample keys. Tests cover anisotropic configs, validation, and reconstruction edge cases.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)

Check name Status Explanation Resolution
Title check ❓ Inconclusive Title is truncated and incomplete, ending with 'an…' rather than stating the full feature. Complete the title to clearly state the primary change, e.g., 'Add configurable anisotropic downsampling support to AutoencoderKL'.
✅ Passed checks (4 passed)
Check name Status Explanation
Description check ✅ Passed Description is comprehensive, covers all required sections, objectives, and rationale with specific examples and test coverage details.
Linked Issues check ✅ Passed All key requirements from #8447 are met: configurable kernel/stride, per-level and per-dimension support, automatic padding, validation, and anisotropic support.
Out of Scope Changes check ✅ Passed Changes focus exclusively on downsampling configurability, padding logic, and related validation—all directly addressing #8447 requirements.
Docstring Coverage ✅ Passed Docstring coverage is 84.62% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/nets/autoencoderkl.py (1)

636-667: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

ConvTranspose path ignores anisotropic stride.

When use_convtranspose=True, the Upsample call doesn't receive the per-level stride and defaults to stride=2. This breaks anisotropic configurations (e.g., stride=(2,2,1)). The upsampling_stride computed on line 638 is unused in this branch, while the nontrainable path correctly applies it as scale_factor.

Proposed fix
             if use_convtranspose:
                 blocks.append(
                     Upsample(
-                        spatial_dims=spatial_dims, mode="deconv", in_channels=block_in_ch, out_channels=block_in_ch
+                        spatial_dims=spatial_dims,
+                        mode="deconv",
+                        in_channels=block_in_ch,
+                        out_channels=block_in_ch,
+                        scale_factor=tuple(float(s) for s in upsampling_stride),
                     )
                 )

Note: Anisotropic stride tests exist but don't exercise the convtranspose path, leaving this bug untested.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/autoencoderkl.py` around lines 636 - 667, The
convtranspose branch for upsampling ignores the computed per-level
upsampling_stride (variable upsampling_stride) and always uses the default
stride, breaking anisotropic cases; modify the use_convtranspose branch in the
loop that builds blocks so the Upsample(...) call for mode="deconv" receives the
per-level scale/stride (e.g., pass scale_factor=tuple(float(s) for s in
upsampling_stride) or the appropriate strides argument accepted by Upsample) so
it uses the anisotropic upsampling_stride instead of the hardcoded default;
update the Upsample(...) invocation in the use_convtranspose True branch (the
block creating Upsample with mode="deconv" and in_channels=block_in_ch) to
include that scale/stride parameter.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/networks/nets/test_autoencoderkl.py`:
- Around line 560-576: The test test_validation_even_kernel_raises_error
currently fails for the wrong reason because the supplied downsample_parameters
list length doesn't match the expected number of downsampling levels for the
provided channels, so the level-count validation triggers before kernel-size
validation; update the test to supply a downsample_parameters list whose length
matches the required number of levels for AutoencoderKL (e.g., for
channels=(4,4,4) provide two dicts) and ensure at least one dict uses an even
"kernel_size" (e.g., 4) so that instantiating AutoencoderKL(...) raises the
intended ValueError about even kernel sizes rather than the level-count
mismatch.
- Around line 578-595: The test
test_validation_invalid_tuple_length_raises_error is failing because the
level-count mismatch validation runs before the tuple-length check; to reach the
tuple-length validation you must provide two downsample parameter dicts in
downsample_parameters so the number of levels matches attention_levels and
channels length, then still include invalid tuple lengths (e.g., kernel_size and
stride with only 2 elements) to trigger ValueError from AutoencoderKL; update
the downsample_params used in the test (referenced variable downsample_params
and class AutoencoderKL) to contain two dicts with the bad tuples so the
tuple-length validation is exercised.

---

Outside diff comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 636-667: The convtranspose branch for upsampling ignores the
computed per-level upsampling_stride (variable upsampling_stride) and always
uses the default stride, breaking anisotropic cases; modify the
use_convtranspose branch in the loop that builds blocks so the Upsample(...)
call for mode="deconv" receives the per-level scale/stride (e.g., pass
scale_factor=tuple(float(s) for s in upsampling_stride) or the appropriate
strides argument accepted by Upsample) so it uses the anisotropic
upsampling_stride instead of the hardcoded default; update the Upsample(...)
invocation in the use_convtranspose True branch (the block creating Upsample
with mode="deconv" and in_channels=block_in_ch) to include that scale/stride
parameter.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8b8b88a6-f7ce-47d8-b866-6f242adcbc65

📥 Commits

Reviewing files that changed from the base of the PR and between 2a98f63 and 0a90773.

📒 Files selected for processing (2)
  • monai/networks/nets/autoencoderkl.py
  • tests/networks/nets/test_autoencoderkl.py

Comment thread tests/networks/nets/test_autoencoderkl.py
Comment thread tests/networks/nets/test_autoencoderkl.py
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/nets/autoencoderkl.py (1)

668-692: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Pass anisotropic stride to deconv branch.

The deconv upsampling ignores upsampling_stride (line 667) and defaults to isotropic ×2, while the nontrainable branch correctly passes scale_factor=tuple(float(s) for s in upsampling_stride). For anisotropic configs like (2, 2, 1), deconv will upscale incorrectly.

Add scale_factor=tuple(float(s) for s in upsampling_stride) to the deconv Upsample call (line 668-673).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/autoencoderkl.py` around lines 668 - 692, The deconv
branch inside the upsampling construction (when use_convtranspose is True)
currently creates an Upsample(mode="deconv", ...) but omits the anisotropic
upsampling factor; pass the same computed scale factor used by the nontrainable
branch by adding scale_factor=tuple(float(s) for s in upsampling_stride) to that
Upsample(...) call so Upsample(mode="deconv", ...) uses the correct anisotropic
upsampling_stride value (refer to the Upsample instantiation, use_convtranspose
flag, and the upsampling_stride variable).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 130-139: The current None-handling for downsample_parameters
normalizes to symmetric padding via
_validate_kernel_stride_parameters/_compute_padding, changing legacy behavior;
revert to special-casing the legacy default when downsample_parameters is None
by returning per-level entries that match the original AsymmetricPad + Conv
semantics (kernel_size=3, stride=2, padding=0) instead of computed symmetric
padding—use the symbols downsample_parameters,
default_kernel_size/default_stride, spatial_dims and num_levels to locate the
branch and ensure each returned dict keeps padding=0 (caller is expected to
apply the AsymmetricPad((0,1)*spatial_dims) behavior externally) so existing
checkpoints keep the same behavior.
- Around line 85-99: The current _compute_padding that returns padding = tuple(k
// 2 for k in kernel_size) produces symmetric padding only and does not preserve
spatial sizes for non-divisible inputs; update the encoder/decoder to record
per-stage spatial outputs (target sizes) during encoding and use those targets
in reconstruct() to compute and apply either per-stage output_padding for
ConvTranspose (based on stride and recorded encoder sizes) or explicit cropping
after upsampling, rather than relying on fixed symmetric padding. Specifically,
modify the code paths around _compute_padding and the encoder forward that
produces ceil(n/stride) to store each intermediate spatial shape, and update the
decoder/ConvTranspose reconstruction logic (where output_padding or cropping is
applied) to use those stored sizes to guarantee exact recovery for
kernels/strides such as kernel=3,stride=2 (also fix the similar logic referenced
at lines 691-692).

---

Outside diff comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 668-692: The deconv branch inside the upsampling construction
(when use_convtranspose is True) currently creates an Upsample(mode="deconv",
...) but omits the anisotropic upsampling factor; pass the same computed scale
factor used by the nontrainable branch by adding scale_factor=tuple(float(s) for
s in upsampling_stride) to that Upsample(...) call so Upsample(mode="deconv",
...) uses the correct anisotropic upsampling_stride value (refer to the Upsample
instantiation, use_convtranspose flag, and the upsampling_stride variable).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 43f2a9ce-100e-4c3b-a9b2-baea9721e870

📥 Commits

Reviewing files that changed from the base of the PR and between 0a90773 and 6bc4ebc.

📒 Files selected for processing (2)
  • monai/networks/nets/autoencoderkl.py
  • tests/networks/nets/test_autoencoderkl.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/networks/nets/test_autoencoderkl.py

Comment thread monai/networks/nets/autoencoderkl.py
Comment thread monai/networks/nets/autoencoderkl.py
…ed some relavent testcases

Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
tests/networks/nets/test_autoencoderkl.py (2)

951-951: 💤 Low value

Unused unpacked variables z_mu, z_sigma.

Same as above - prefix with underscore.

Proposed fix
-            reconstruction, z_mu, z_sigma = net(x)
+            reconstruction, _z_mu, _z_sigma = net(x)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/nets/test_autoencoderkl.py` at line 951, The test unpacks
three values from net(x) into reconstruction, z_mu, z_sigma but z_mu and z_sigma
are unused; update the unpacking in the test to mark those as intentionally
unused (e.g., use _z_mu and _z_sigma or _ , _ ) so linter warnings are
suppressed—locate the unpack in tests/networks/nets/test_autoencoderkl.py where
the call net(x) assigns to reconstruction, z_mu, z_sigma and change the unused
variable names accordingly.

804-804: 💤 Low value

Unused unpacked variables z_mu, z_sigma.

Prefix with underscore to indicate intentionally unused.

Proposed fix
-            reconstruction, z_mu, z_sigma = net(x)
+            reconstruction, _z_mu, _z_sigma = net(x)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/nets/test_autoencoderkl.py` at line 804, The test unpacks
three values from net(x) but never uses z_mu and z_sigma; change the unpacking
in the test from "reconstruction, z_mu, z_sigma = net(x)" to use
underscore-prefixed names (e.g., "reconstruction, _z_mu, _z_sigma = net(x)") so
the intent of intentionally unused values is clear and linter warnings are
suppressed; locate the unpack in tests/networks/nets/test_autoencoderkl.py where
net(x) is called and update the variable names accordingly.
monai/networks/nets/autoencoderkl.py (1)

230-247: 💤 Low value

Missing docstring for __init__ parameters.

Per coding guidelines, docstrings should describe each parameter. The __init__ method lacks an Args section.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/autoencoderkl.py` around lines 230 - 247, Add a docstring
to the __init__ method of the AutoEncoderKL module that documents each parameter
and their types/meaning (spatial_dims, in_channels, out_channels, post_conv,
shape_index, downsample_shapes_ref, scale_factor) and any important notes (e.g.,
downsample_shapes_ref is a shared list reference and not a module); place it
immediately below def __init__ and use an Args section describing each parameter
and any default/optional behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 230-247: Add a docstring to the __init__ method of the
AutoEncoderKL module that documents each parameter and their types/meaning
(spatial_dims, in_channels, out_channels, post_conv, shape_index,
downsample_shapes_ref, scale_factor) and any important notes (e.g.,
downsample_shapes_ref is a shared list reference and not a module); place it
immediately below def __init__ and use an Args section describing each parameter
and any default/optional behavior.

In `@tests/networks/nets/test_autoencoderkl.py`:
- Line 951: The test unpacks three values from net(x) into reconstruction, z_mu,
z_sigma but z_mu and z_sigma are unused; update the unpacking in the test to
mark those as intentionally unused (e.g., use _z_mu and _z_sigma or _ , _ ) so
linter warnings are suppressed—locate the unpack in
tests/networks/nets/test_autoencoderkl.py where the call net(x) assigns to
reconstruction, z_mu, z_sigma and change the unused variable names accordingly.
- Line 804: The test unpacks three values from net(x) but never uses z_mu and
z_sigma; change the unpacking in the test from "reconstruction, z_mu, z_sigma =
net(x)" to use underscore-prefixed names (e.g., "reconstruction, _z_mu, _z_sigma
= net(x)") so the intent of intentionally unused values is clear and linter
warnings are suppressed; locate the unpack in
tests/networks/nets/test_autoencoderkl.py where net(x) is called and update the
variable names accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a0d65d18-a62e-4d0e-846d-b93270e70317

📥 Commits

Reviewing files that changed from the base of the PR and between 6bc4ebc and 9ca1550.

📒 Files selected for processing (2)
  • monai/networks/nets/autoencoderkl.py
  • tests/networks/nets/test_autoencoderkl.py

Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/nets/autoencoderkl.py (1)

956-974: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

downsample_shapes list grows unboundedly across forward passes.

Each encode() call appends shapes to self.encoder.downsample_shapes. The list is never cleared. While the decoder's reverse-indexing (len - 1 - idx) correctly retrieves the most recent shapes, the list accumulates indefinitely.

Clear the list at the start of encoding to prevent memory growth during repeated inference.

Proposed fix
     def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        # Clear recorded shapes from previous forward passes
+        self.encoder.downsample_shapes.clear()
         if self.use_checkpoint:
             h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/autoencoderkl.py` around lines 956 - 974, The encoder
appends to self.encoder.downsample_shapes on each call causing unbounded growth;
modify encode (method encode) to clear or reset self.encoder.downsample_shapes
at the start (e.g., self.encoder.downsample_shapes.clear() or assign [] ) so
repeated calls don't accumulate shapes while keeping the decoder's
reverse-indexing (len(self.encoder.downsample_shapes) - 1 - idx) logic
unchanged.
🧹 Nitpick comments (1)
monai/networks/nets/autoencoderkl.py (1)

210-220: 💤 Low value

Missing docstrings for __init__ and forward methods.

Per guidelines, definitions should have Google-style docstrings. However, this is a private helper class with straightforward semantics.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/autoencoderkl.py` around lines 210 - 220, Add
Google-style docstrings for the private helper class _RecordShapeHook's methods:
document __init__ to describe the shape_list parameter (a list of spatial-shape
tuples collected during encoding) and side effects, and document forward to
describe the input tensor x, the fact it appends tuple(x.shape[2:]) to
self.shape_list, and that it returns the unchanged tensor; keep the docstrings
concise and private-method style since the class is internal.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 956-974: The encoder appends to self.encoder.downsample_shapes on
each call causing unbounded growth; modify encode (method encode) to clear or
reset self.encoder.downsample_shapes at the start (e.g.,
self.encoder.downsample_shapes.clear() or assign [] ) so repeated calls don't
accumulate shapes while keeping the decoder's reverse-indexing
(len(self.encoder.downsample_shapes) - 1 - idx) logic unchanged.

---

Nitpick comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 210-220: Add Google-style docstrings for the private helper class
_RecordShapeHook's methods: document __init__ to describe the shape_list
parameter (a list of spatial-shape tuples collected during encoding) and side
effects, and document forward to describe the input tensor x, the fact it
appends tuple(x.shape[2:]) to self.shape_list, and that it returns the unchanged
tensor; keep the docstrings concise and private-method style since the class is
internal.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 151f4733-5707-471b-8bab-9ec8502849a5

📥 Commits

Reviewing files that changed from the base of the PR and between 9ca1550 and 14b691e.

📒 Files selected for processing (2)
  • monai/networks/nets/autoencoderkl.py
  • tests/networks/nets/test_autoencoderkl.py

Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 976-977: The encoder's recorded downsample_shapes must be
single-use to prevent stale sizes influencing later standalone decode calls:
ensure that after encoding we still clear encoder.downsample_shapes (existing
clear is fine) and additionally consume/clear downsample_shapes_ref when used
during upsampling so decode() and decode_stage_2_outputs() do not retain prior
shapes; update _ShapeRestoringUpsample usage to pop/consume the stored shape (or
clear encoder.downsample_shapes) immediately after it's applied. Add a
regression unit test that runs forward(x1) then standalone decode(z2) where x1
and z2 have different spatial shapes and assert the decoded output uses z2's
expected spatial dimensions (i.e., decode respects scale_factor/fallback when no
stale shapes remain).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9b2df7bd-2d86-446b-9cd2-3a0b1b3daaa7

📥 Commits

Reviewing files that changed from the base of the PR and between 14b691e and d20d5c4.

📒 Files selected for processing (1)
  • monai/networks/nets/autoencoderkl.py

Comment on lines +976 to +977
# Clear shape list before encoding to avoid unbounded growth across forward passes
self.encoder.downsample_shapes.clear()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make recorded shapes single-use for decode.

Clearing downsample_shapes only here still leaves stale encoder sizes behind for later standalone decode() / decode_stage_2_outputs() calls. _ShapeRestoringUpsample prefers downsample_shapes_ref over scale_factor, so a latent decoded after any prior encode() can be resized to the previous input shape instead of using the fallback path.

Minimal fix sketch
 class AutoencoderKL(nn.Module):
     def __init__(self, ...):
         ...
+        self._has_fresh_downsample_shapes = False

     def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         # Clear shape list before encoding to avoid unbounded growth across forward passes
         self.encoder.downsample_shapes.clear()
+        self._has_fresh_downsample_shapes = True
         ...

     def decode(self, z: torch.Tensor) -> torch.Tensor:
+        if not self._has_fresh_downsample_shapes:
+            self.encoder.downsample_shapes.clear()
         z = self.post_quant_conv(z)
-        if self.use_checkpoint:
-            dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
-        else:
-            dec = self.decoder(z)
+        try:
+            if self.use_checkpoint:
+                dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
+            else:
+                dec = self.decoder(z)
+        finally:
+            self._has_fresh_downsample_shapes = False
         return dec

Please add a regression test for forward(x1) followed by standalone decode(z2) from a different spatial shape.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Clear shape list before encoding to avoid unbounded growth across forward passes
self.encoder.downsample_shapes.clear()
class AutoencoderKL(nn.Module):
def __init__(self, ...):
...
self._has_fresh_downsample_shapes = False
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Clear shape list before encoding to avoid unbounded growth across forward passes
self.encoder.downsample_shapes.clear()
self._has_fresh_downsample_shapes = True
...
def decode(self, z: torch.Tensor) -> torch.Tensor:
if not self._has_fresh_downsample_shapes:
self.encoder.downsample_shapes.clear()
z = self.post_quant_conv(z)
try:
if self.use_checkpoint:
dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
else:
dec = self.decoder(z)
finally:
self._has_fresh_downsample_shapes = False
return dec
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/autoencoderkl.py` around lines 976 - 977, The encoder's
recorded downsample_shapes must be single-use to prevent stale sizes influencing
later standalone decode calls: ensure that after encoding we still clear
encoder.downsample_shapes (existing clear is fine) and additionally
consume/clear downsample_shapes_ref when used during upsampling so decode() and
decode_stage_2_outputs() do not retain prior shapes; update
_ShapeRestoringUpsample usage to pop/consume the stored shape (or clear
encoder.downsample_shapes) immediately after it's applied. Add a regression unit
test that runs forward(x1) then standalone decode(z2) where x1 and z2 have
different spatial shapes and assert the decoded output uses z2's expected
spatial dimensions (i.e., decode respects scale_factor/fallback when no stale
shapes remain).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AutoencoderKL does not allow to modify the kernel size and stride

1 participant