Skip to content

Conversation

@getrichthroughcode
Copy link

  • Fix dimension comparison to use spatial dims instead of total dims
  • Add validation for minimum input dimensions
  • Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma)
  • Move spatial dimension validation before unsqueeze operations

The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected.

Fixes #7444

Description

This PR fixes a validation bug in TrainableBilateralFilter that incorrectly rejected valid 3D inputs with shape (B, C, H, W, D).

Root Cause: The forward() method compared self.len_spatial_sigma (spatial dimensions = 3) with len(input_tensor.shape) (total dimensions = 5), causing a dimension mismatch error for valid inputs.

Solution: Calculate spatial_dims = len(input_tensor.shape) - 2 to exclude batch and channel dimensions, then compare against self.len_spatial_sigma.

Example of fixed behavior:

# Previously failed, now works
bf = TrainableBilateralFilter([1.0, 1.0, 1.0], 1.0)
x = torch.randn(1, 1, 10, 10, 10)  # (B, C, H, W, D)
out = bf(x)  # Success!

This fix also improves error messages and adds validation for inputs with insufficient dimensions.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Notes on Testing

The existing unit tests for TrainableBilateralFilter (24 tests) require the C++ extension and were skipped locally (expected behavior with @skip_if_no_cpp_extension decorator). These tests will run automatically in CI.

I verified the fix logic with custom local tests for 1D, 2D, and 3D cases (see examples in description above).

Linting and code formatting checks passed:

./runtests.sh --autofix     # Passed
./runtests.sh --codeformat  # Passed

No new tests were added as the existing 24 unit tests already cover the behavior. No docstring or documentation changes were needed as this is purely a bug fix in validation logic.

- Fix dimension comparison to use spatial dims instead of total dims
- Add validation for minimum input dimensions
- Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma)
- Move spatial dimension validation before unsqueeze operations

The forward() method was incorrectly comparing self.len_spatial_sigma
(number of spatial dimensions) with len(input_tensor.shape) (total
dimensions including batch and channel), causing valid 3D inputs to
be rejected.

Fixes Project-MONAI#7444

Signed-off-by: Abdoulaye Diallo <abdoulayediallo338@gmail.com>
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 1, 2026

📝 Walkthrough

Walkthrough

This change fixes dimension handling in TrainableBilateralFilter and TrainableJointBilateralFilter. The fix corrects a variable reference error in error messages (ken_spatial_sigma → len_spatial_sigma), adds validation requiring minimum 3 input dimensions, and refactors dimension-checking logic to properly distinguish between batch, channel, and spatial dimensions. The changes enable proper support for 3D images by aligning dimension comparisons with spatial dimensionality rather than total input dimensions.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly describes the main fix: correcting 3D input validation in TrainableBilateralFilter by addressing dimension comparison logic.
Description check ✅ Passed Description includes root cause analysis, solution explanation, code examples, and testing notes; follows template structure with objectives clearly stated.
Linked Issues check ✅ Passed PR directly addresses issue #7444: fixes spatial-dimension comparison logic, adds input validation, corrects typo, and enables valid 3D inputs (B,C,H,W,D) to work correctly.
Out of Scope Changes check ✅ Passed All changes are scoped to TrainableBilateralFilter and TrainableJointBilateralFilter validation logic; no extraneous modifications detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/layers/filtering.py (1)

406-430: ⚠️ Potential issue | 🟠 Major

TrainableJointBilateralFilter.forward() not updated with the same fix.

This method still uses len_input directly instead of computing spatial_dims = len_input - 2. It will reject valid 3D inputs just like the original bug in TrainableBilateralFilter. Also missing the minimum dimension validation added to the other class.

Proposed fix
     def forward(self, input_tensor, guidance_tensor):
+        if len(input_tensor.shape) < 3:
+            raise ValueError(
+                f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
+            )
         if input_tensor.shape[1] != 1:
             raise ValueError(
                 f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
                 "Please use multiple parallel filter layers if you want "
                 "to filter multiple channels."
             )
         if input_tensor.shape != guidance_tensor.shape:
             raise ValueError(
                 "Shape of input image must equal shape of guidance image."
                 f"Got {input_tensor.shape} and {guidance_tensor.shape}."
             )

         len_input = len(input_tensor.shape)
+        spatial_dims = len_input - 2

         # C++ extension so far only supports 5-dim inputs.
-        if len_input == 3:
+        if spatial_dims == 1:
             input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
             guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)
-        elif len_input == 4:
+        elif spatial_dims == 2:
             input_tensor = input_tensor.unsqueeze(4)
             guidance_tensor = guidance_tensor.unsqueeze(4)

-        if self.len_spatial_sigma != len_input:
-            raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
+        if self.len_spatial_sigma != spatial_dims:
+            raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).")

         prediction = TrainableJointBilateralFilterFunction.apply(
             input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
         )

         # Make sure to return tensor of the same shape as the input.
-        if len_input == 3:
+        if spatial_dims == 1:
             prediction = prediction.squeeze(4).squeeze(3)
-        elif len_input == 4:
+        elif spatial_dims == 2:
             prediction = prediction.squeeze(4)

         return prediction
🤖 Fix all issues with AI agents
In `@monai/networks/layers/filtering.py`:
- Around line 223-225: The error message uses self.len_spatial_sigma which is
not assigned in the branch; fix by referencing the actual expected spatial
dimension attribute or ensuring self.len_spatial_sigma is initialized before
this check: either assign self.len_spatial_sigma = self.spatial_ndim (or the
class's existing spatial-dimension attribute) earlier in the initializer, or
change the ValueError message to use the computed expected dimension (e.g.,
self.spatial_ndim or len(self.spatial_shape)) instead of self.len_spatial_sigma
so the attribute is defined when raising the error in the spatial_sigma
validation.
- Around line 395-398: The else branch references an undefined attribute
self.len_spatial_sigma; fix it by using a defined value (e.g., compute
len_spatial = len(self.spatial_sigma) or use self.spatial_ndim) when building
the error message in the failing branch of the initializer (same place as
TrainableBilateralFilter.__init__). Replace self.len_spatial_sigma with the
actual computed length (len(self.spatial_sigma) or self.spatial_ndim) so the
ValueError message prints a valid expected-dimension value.

Comment on lines 223 to 225
raise ValueError(
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Bug: self.len_spatial_sigma is undefined when this branch executes.

The else branch runs when len(spatial_sigma) is not 1, 2, or 3. At that point self.len_spatial_sigma has never been assigned, so this raises AttributeError before the intended ValueError.

Proposed fix
         else:
             raise ValueError(
-                f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
+                f"len(spatial_sigma) must be 1, 2, or 3, got {len(spatial_sigma)}."
             )
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 223-225: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@monai/networks/layers/filtering.py` around lines 223 - 225, The error message
uses self.len_spatial_sigma which is not assigned in the branch; fix by
referencing the actual expected spatial dimension attribute or ensuring
self.len_spatial_sigma is initialized before this check: either assign
self.len_spatial_sigma = self.spatial_ndim (or the class's existing
spatial-dimension attribute) earlier in the initializer, or change the
ValueError message to use the computed expected dimension (e.g.,
self.spatial_ndim or len(self.spatial_shape)) instead of self.len_spatial_sigma
so the attribute is defined when raising the error in the spatial_sigma
validation.

Comment on lines 395 to 398
else:
raise ValueError(
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Same bug: self.len_spatial_sigma undefined in else branch.

Identical issue as TrainableBilateralFilter.__init__.

Proposed fix
         else:
             raise ValueError(
-                f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
+                f"len(spatial_sigma) must be 1, 2, or 3, got {len(spatial_sigma)}."
             )
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 396-398: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@monai/networks/layers/filtering.py` around lines 395 - 398, The else branch
references an undefined attribute self.len_spatial_sigma; fix it by using a
defined value (e.g., compute len_spatial = len(self.spatial_sigma) or use
self.spatial_ndim) when building the error message in the failing branch of the
initializer (same place as TrainableBilateralFilter.__init__). Replace
self.len_spatial_sigma with the actual computed length (len(self.spatial_sigma)
or self.spatial_ndim) so the ValueError message prints a valid
expected-dimension value.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

can't use TrainableBilateralFilter for 3d image

1 participant