Skip to content

Commit 3374e90

Browse files
Faster workflow cancelling. (Comfy-Org#10301)
1 parent 51696e3 commit 3374e90

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

comfy/ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import comfy.rmsnorm
2525
import contextlib
2626

27+
def run_every_op():
28+
comfy.model_management.throw_exception_if_processing_interrupted()
2729

2830
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
2931
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
@@ -109,6 +111,7 @@ def forward_comfy_cast_weights(self, input):
109111
return torch.nn.functional.linear(input, weight, bias)
110112

111113
def forward(self, *args, **kwargs):
114+
run_every_op()
112115
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
113116
return self.forward_comfy_cast_weights(*args, **kwargs)
114117
else:
@@ -123,6 +126,7 @@ def forward_comfy_cast_weights(self, input):
123126
return self._conv_forward(input, weight, bias)
124127

125128
def forward(self, *args, **kwargs):
129+
run_every_op()
126130
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
127131
return self.forward_comfy_cast_weights(*args, **kwargs)
128132
else:
@@ -137,6 +141,7 @@ def forward_comfy_cast_weights(self, input):
137141
return self._conv_forward(input, weight, bias)
138142

139143
def forward(self, *args, **kwargs):
144+
run_every_op()
140145
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
141146
return self.forward_comfy_cast_weights(*args, **kwargs)
142147
else:
@@ -151,6 +156,7 @@ def forward_comfy_cast_weights(self, input):
151156
return self._conv_forward(input, weight, bias)
152157

153158
def forward(self, *args, **kwargs):
159+
run_every_op()
154160
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
155161
return self.forward_comfy_cast_weights(*args, **kwargs)
156162
else:
@@ -165,6 +171,7 @@ def forward_comfy_cast_weights(self, input):
165171
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
166172

167173
def forward(self, *args, **kwargs):
174+
run_every_op()
168175
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
169176
return self.forward_comfy_cast_weights(*args, **kwargs)
170177
else:
@@ -183,6 +190,7 @@ def forward_comfy_cast_weights(self, input):
183190
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
184191

185192
def forward(self, *args, **kwargs):
193+
run_every_op()
186194
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
187195
return self.forward_comfy_cast_weights(*args, **kwargs)
188196
else:
@@ -202,6 +210,7 @@ def forward_comfy_cast_weights(self, input):
202210
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
203211

204212
def forward(self, *args, **kwargs):
213+
run_every_op()
205214
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
206215
return self.forward_comfy_cast_weights(*args, **kwargs)
207216
else:
@@ -223,6 +232,7 @@ def forward_comfy_cast_weights(self, input, output_size=None):
223232
output_padding, self.groups, self.dilation)
224233

225234
def forward(self, *args, **kwargs):
235+
run_every_op()
226236
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
227237
return self.forward_comfy_cast_weights(*args, **kwargs)
228238
else:
@@ -244,6 +254,7 @@ def forward_comfy_cast_weights(self, input, output_size=None):
244254
output_padding, self.groups, self.dilation)
245255

246256
def forward(self, *args, **kwargs):
257+
run_every_op()
247258
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
248259
return self.forward_comfy_cast_weights(*args, **kwargs)
249260
else:
@@ -262,6 +273,7 @@ def forward_comfy_cast_weights(self, input, out_dtype=None):
262273
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
263274

264275
def forward(self, *args, **kwargs):
276+
run_every_op()
265277
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
266278
return self.forward_comfy_cast_weights(*args, **kwargs)
267279
else:

0 commit comments

Comments
 (0)