Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,24 @@ def sliding_window_inference(
for idx in slice_range
]
if sw_batch_size > 1:
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
win_data = torch.cat(
[inputs[tuple(win_slice) if isinstance(win_slice, list) else win_slice] for win_slice in unravel_slice]
).to(sw_device)
if condition is not None:
win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
win_condition = torch.cat(
[
condition[tuple(win_slice) if isinstance(win_slice, list) else win_slice]
for win_slice in unravel_slice
]
).to(sw_device)
kwargs["condition"] = win_condition
else:
win_data = inputs[unravel_slice[0]].to(sw_device)
s0 = unravel_slice[0]
s0_idx = tuple(s0) if isinstance(s0, list) else s0

win_data = inputs[s0_idx].to(sw_device)
if condition is not None:
win_condition = condition[unravel_slice[0]].to(sw_device)
win_condition = condition[s0_idx].to(sw_device)
kwargs["condition"] = win_condition

if with_coord:
Expand All @@ -257,7 +267,7 @@ def sliding_window_inference(
offset = s[buffer_dim + 2].start - c_start
s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
s[0] = slice(0, 1)
sw_device_buffer[0][s] += p * w_t
sw_device_buffer[0][tuple(s) if isinstance(s, list) else s] += p * w_t
b_i += len(unravel_slice)
if b_i < b_slices[b_s][0]:
continue
Expand Down Expand Up @@ -288,10 +298,11 @@ def sliding_window_inference(
o_slice[buffer_dim + 2] = slice(c_start, c_end)
img_b = b_s // n_per_batch # image batch index
o_slice[0] = slice(img_b, img_b + 1)
o_slice_idx = tuple(o_slice) if isinstance(o_slice, list) else o_slice
if non_blocking:
output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking)
else:
output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
output_image_list[0][o_slice_idx] += sw_device_buffer[0].to(device=device)
else:
sw_device_buffer[ss] *= w_t
sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
Expand Down Expand Up @@ -367,7 +378,7 @@ def _compute_coords(coords, z_scale, out, patch):
idx_zm[axis] = slice(
int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2])
)
out[idx_zm] += p
out[tuple(idx_zm)] += p


def _get_scan_interval(
Expand Down
Loading