diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py index ae44bb683..14fdfd3be 100644 --- a/diffsynth/diffusion/loss.py +++ b/diffsynth/diffusion/loss.py @@ -13,9 +13,16 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) + if "first_frame_latents" in inputs: + inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"] + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep) + if "first_frame_latents" in inputs: + noise_pred = noise_pred[:, :, 1:] + training_target = training_target[:, :, 1:] + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) loss = loss * pipe.scheduler.training_weight(timestep) return loss