From 8d0df403ca247bf483e28c34d64df13dda158cc8 Mon Sep 17 00:00:00 2001 From: Kared <101872541+huarzone@users.noreply.github.com> Date: Tue, 27 Jan 2026 03:55:36 +0000 Subject: [PATCH] fix wan i2v train bug --- diffsynth/diffusion/loss.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py index ae44bb683..14fdfd3be 100644 --- a/diffsynth/diffusion/loss.py +++ b/diffsynth/diffusion/loss.py @@ -13,9 +13,16 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) + if "first_frame_latents" in inputs: + inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"] + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep) + if "first_frame_latents" in inputs: + noise_pred = noise_pred[:, :, 1:] + training_target = training_target[:, :, 1:] + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) loss = loss * pipe.scheduler.training_weight(timestep) return loss