diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 766486a807..2ef754ddac 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -223,14 +223,24 @@ def sliding_window_inference( for idx in slice_range ] if sw_batch_size > 1: - win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + win_data = torch.cat( + [inputs[tuple(win_slice) if isinstance(win_slice, list) else win_slice] for win_slice in unravel_slice] + ).to(sw_device) if condition is not None: - win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device) + win_condition = torch.cat( + [ + condition[tuple(win_slice) if isinstance(win_slice, list) else win_slice] + for win_slice in unravel_slice + ] + ).to(sw_device) kwargs["condition"] = win_condition else: - win_data = inputs[unravel_slice[0]].to(sw_device) + s0 = unravel_slice[0] + s0_idx = tuple(s0) if isinstance(s0, list) else s0 + + win_data = inputs[s0_idx].to(sw_device) if condition is not None: - win_condition = condition[unravel_slice[0]].to(sw_device) + win_condition = condition[s0_idx].to(sw_device) kwargs["condition"] = win_condition if with_coord: @@ -257,7 +267,7 @@ def sliding_window_inference( offset = s[buffer_dim + 2].start - c_start s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim]) s[0] = slice(0, 1) - sw_device_buffer[0][s] += p * w_t + sw_device_buffer[0][tuple(s) if isinstance(s, list) else s] += p * w_t b_i += len(unravel_slice) if b_i < b_slices[b_s][0]: continue @@ -288,10 +298,11 @@ def sliding_window_inference( o_slice[buffer_dim + 2] = slice(c_start, c_end) img_b = b_s // n_per_batch # image batch index o_slice[0] = slice(img_b, img_b + 1) + o_slice_idx = tuple(o_slice) if isinstance(o_slice, list) else o_slice if non_blocking: - output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking) + output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking) else: - output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) + output_image_list[0][o_slice_idx] += sw_device_buffer[0].to(device=device) else: sw_device_buffer[ss] *= w_t sw_device_buffer[ss] = sw_device_buffer[ss].to(device) @@ -367,7 +378,7 @@ def _compute_coords(coords, z_scale, out, patch): idx_zm[axis] = slice( int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2]) ) - out[idx_zm] += p + out[tuple(idx_zm)] += p def _get_scan_interval(