Generalize TestTimeAugmentation to non-spatial predictions#8715
Generalize TestTimeAugmentation to non-spatial predictions#8715ytl0623 wants to merge 5 commits intoProject-MONAI:devfrom
Conversation
Signed-off-by: ytl0623 <david89062388@gmail.com>
📝 WalkthroughWalkthroughAdds a boolean parameter Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
hi @ytl0623 thanks for this change, I think it's fine in principle. The def _check_transforms(self):
"""Should be at least 1 random transform, and all random transforms should be invertible."""
transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
warns=[]
randoms=[]
for idx, t in transforms:
if isinstance(t, Randomizable):
randoms.append(t)
if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
warns.append(f"Transform {idx} (type {type(t).__name__}) not invertible.")
if len(randoms)==0:
warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")
if len(warns)>0:
warnings.warn("TTA has encountered issues with the given transforms:"+"\n ".join(warns))Please check this logic, it might be that we need to check all transforms for invertibility whether they're random or not, but what I have here is equivalent to the original. |
Signed-off-by: ytl0623 <david89062388@gmail.com>
for more information, see https://pre-commit.ci
|
Hi @ericspod, thanks for the suggestion! |
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@monai/data/test_time_augmentation.py`:
- Around line 67-71: The docstring for test-time augmentation (function/class
using parameter names transform, batch_size, and apply_inverse_to_pred)
incorrectly states "All random transforms must be of type InvertibleTransform";
update the docstring (and the transform type hint if present) to reflect that
non-invertible random transforms are allowed when apply_inverse_to_pred=False
and only need to be invertible when apply_inverse_to_pred=True; change the
wording in both occurrences (the block around the transform description and the
later paragraph at lines ~115-118) to describe this conditional requirement and,
if applicable, broaden the transform type hint to accept non-invertible
Randomizable types when apply_inverse_to_pred is False.
- Around line 174-175: The warning message built from the local variable warns
is missing a newline after the colon and does not set a stacklevel, so update
the warnings.warn call to prepend a newline (e.g., "TTA has encountered issues
with the given transforms:\n " + "\n ".join(warns)) and pass an appropriate
stacklevel (e.g., stacklevel=2) so user stack traces point to the caller; locate
and modify the warnings.warn(...) invocation that uses the warns list in
test_time_augmentation.py.
- Around line 208-213: The non-inverse branch currently returns raw predictions
and skips all Invertd post-processing (to_tensor, output_device, post_func);
update the branch so decollated items still go through the same inverter
pipeline (or a post-processing-only path) before extracting self._pred_key.
Concretely, in the else branch replace outs.extend([i[self._pred_key] for i in
decollate_batch(b)]) with code that calls self.inverter on each
PadListDataCollate.inverse(i) (or calls an Invertd method/flag that runs only
to_tensor/output_device/post_func but not spatial inverse) and then extracts
[self._pred_key]; ensure the call honors to_tensor, output_device and post_func
parameters so behavior matches the apply_inverse_to_pred=True path.
Signed-off-by: ytl0623 <david89062388@gmail.com>
e1552eb to
bf83ab8
Compare
Fixes #8276
Description
apply_inverse_to_pred. Defaults toTrueto preserve backward compatibility. When set toFalse, it skips the inverse transformation step and aggregates the model predictions directly.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.