From 9e00ce5b76ec04be37375310512a443605b95077 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 20 Nov 2025 14:42:46 -0800 Subject: [PATCH 1/5] Make Batch Images node add alpha channel when one of the inputs has it (#10816) * When one Batch Image input has alpha and one does not, add empty alpha channel * Use torch.nn.functional.pad --- nodes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 75e820e663ba..0303716338aa 100644 --- a/nodes.py +++ b/nodes.py @@ -1853,9 +1853,10 @@ def INPUT_TYPES(s): def batch(self, image1, image2): if image1.shape[-1] != image2.shape[-1]: - channels = min(image1.shape[-1], image2.shape[-1]) - image1 = image1[..., :channels] - image2 = image2[..., :channels] + if image1.shape[-1] > image2.shape[-1]: + image2 = torch.nn.functional.pad(image2, (0,1), mode='constant', value=1.0) + else: + image1 = torch.nn.functional.pad(image1, (0,1), mode='constant', value=1.0) if image1.shape[1:] != image2.shape[1:]: image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1) s = torch.cat((image1, image2), dim=0) From 7b8389578e88dcd13b1cf6aea5404047298c9183 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 21 Nov 2025 02:17:47 +0200 Subject: [PATCH 2/5] feat(api-nodes): add Nano Banana Pro (#10814) * feat(api-nodes): add Nano Banana Pro * frontend bump to 1.28.9 --- comfy_api_nodes/apis/gemini_api.py | 5 +- comfy_api_nodes/nodes_gemini.py | 205 ++++++++++++++++++++++++++++- comfy_api_nodes/util/client.py | 13 +- requirements.txt | 2 +- 4 files changed, 215 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index f63e026935b6..710f173f1204 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -68,7 +68,7 @@ class GeminiTextPart(BaseModel): class GeminiContent(BaseModel): - parts: list[GeminiPart] = Field(...) + parts: list[GeminiPart] = Field([]) role: GeminiRole = Field(..., examples=["user"]) @@ -120,7 +120,7 @@ class GeminiGenerationConfig(BaseModel): class GeminiImageConfig(BaseModel): aspectRatio: str | None = Field(None) - resolution: str | None = Field(None) + imageSize: str | None = Field(None) class GeminiImageGenerationConfig(GeminiGenerationConfig): @@ -227,3 +227,4 @@ class GeminiGenerateContentResponse(BaseModel): candidates: list[GeminiCandidate] | None = Field(None) promptFeedback: GeminiPromptFeedback | None = Field(None) usageMetadata: GeminiUsageMetadata | None = Field(None) + modelVersion: str | None = Field(None) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 6e746eebd1cd..be752c885dd2 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -29,11 +29,13 @@ GeminiMimeType, GeminiPart, GeminiRole, + Modality, ) from comfy_api_nodes.util import ( ApiEndpoint, audio_to_base64_string, bytesio_to_image_tensor, + get_number_of_images, sync_op, tensor_to_base64_string, validate_string, @@ -147,6 +149,49 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te return torch.cat(image_tensors, dim=0) +def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None: + if not response.modelVersion: + return None + # Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing + if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"): + input_tokens_price = 1.25 + output_text_tokens_price = 10.0 + output_image_tokens_price = 0.0 + elif response.modelVersion in ( + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-flash", + ): + input_tokens_price = 0.30 + output_text_tokens_price = 2.50 + output_image_tokens_price = 0.0 + elif response.modelVersion in ( + "gemini-2.5-flash-image-preview", + "gemini-2.5-flash-image", + ): + input_tokens_price = 0.30 + output_text_tokens_price = 2.50 + output_image_tokens_price = 30.0 + elif response.modelVersion == "gemini-3-pro-preview": + input_tokens_price = 2 + output_text_tokens_price = 12.0 + output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3-pro-image-preview": + input_tokens_price = 2 + output_text_tokens_price = 12.0 + output_image_tokens_price = 120.0 + else: + return None + final_price = response.usageMetadata.promptTokenCount * input_tokens_price + for i in response.usageMetadata.candidatesTokensDetails: + if i.modality == Modality.IMAGE: + final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models + else: + final_price += output_text_tokens_price * i.tokenCount + if response.usageMetadata.thoughtsTokenCount: + final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount + return final_price / 1_000_000.0 + + class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -314,6 +359,7 @@ async def execute( ] ), response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, ) output_text = get_text_from_response(response) @@ -476,6 +522,13 @@ def define_schema(cls): "or otherwise generates 1:1 squares.", optional=True, ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE+TEXT", "IMAGE"], + tooltip="Choose 'IMAGE' for image-only output, or " + "'IMAGE+TEXT' to return both the generated image and a text response.", + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -498,6 +551,7 @@ async def execute( images: torch.Tensor | None = None, files: list[GeminiPart] | None = None, aspect_ratio: str = "auto", + response_modalities: str = "IMAGE+TEXT", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) parts: list[GeminiPart] = [GeminiPart(text=prompt)] @@ -520,17 +574,16 @@ async def execute( GeminiContent(role=GeminiRole.user, parts=parts), ], generationConfig=GeminiImageGenerationConfig( - responseModalities=["TEXT", "IMAGE"], + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=None if aspect_ratio == "auto" else image_config, ), ), response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, ) - output_image = get_image_from_response(response) output_text = get_text_from_response(response) if output_text: - # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. render_spec = { "node_id": cls.hidden.unique_id, "component": "ChatHistoryWidget", @@ -551,9 +604,150 @@ async def execute( "display_component", render_spec, ) + return IO.NodeOutput(get_image_from_response(response), output_text) + + +class GeminiImage2(IO.ComfyNode): - output_text = output_text or "Empty response from Gemini model..." - return IO.NodeOutput(output_image, output_text) + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiImage2Node", + display_name="Nano Banana Pro (Google Gemini Image)", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["gemini-3-pro-image-preview"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, generates a 1:1 square.", + ), + IO.Combo.Input( + "resolution", + options=["1K", "2K", "4K"], + tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE+TEXT", "IMAGE"], + tooltip="Choose 'IMAGE' for image-only output, or " + "'IMAGE+TEXT' to return both the generated image and a text response.", + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + images: torch.Tensor | None = None, + files: list[GeminiPart] | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(create_image_parts(images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + response = await sync_op( + cls, + ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + ), + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + + output_text = get_text_from_response(response) + if output_text: + render_spec = { + "node_id": cls.hidden.unique_id, + "component": "ChatHistoryWidget", + "props": { + "history": json.dumps( + [ + { + "prompt": prompt, + "response": output_text, + "response_id": str(uuid.uuid4()), + "timestamp": time.time(), + } + ] + ), + }, + } + PromptServer.instance.send_sync( + "display_component", + render_spec, + ) + return IO.NodeOutput(get_image_from_response(response), output_text) class GeminiExtension(ComfyExtension): @@ -562,6 +756,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ GeminiNode, GeminiImage, + GeminiImage2, GeminiInputFiles, ] diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index ad6e3c0d03e3..bf01d7d36a14 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -63,6 +63,7 @@ class _RequestConfig: estimated_total: Optional[int] = None final_label_on_success: Optional[str] = "Completed" progress_origin_ts: Optional[float] = None + price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None @dataclass @@ -87,6 +88,7 @@ async def sync_op( endpoint: ApiEndpoint, *, response_model: Type[M], + price_extractor: Optional[Callable[[M], Optional[float]]] = None, data: Optional[BaseModel] = None, files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, content_type: str = "application/json", @@ -104,6 +106,7 @@ async def sync_op( raw = await sync_op_raw( cls, endpoint, + price_extractor=_wrap_model_extractor(response_model, price_extractor), data=data, files=files, content_type=content_type, @@ -175,6 +178,7 @@ async def sync_op_raw( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, + price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, data: Optional[Union[dict[str, Any], BaseModel]] = None, files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, content_type: str = "application/json", @@ -216,6 +220,7 @@ async def sync_op_raw( estimated_total=estimated_duration, final_label_on_success=final_label_on_success, progress_origin_ts=progress_origin_ts, + price_extractor=price_extractor, ) return await _request_base(cfg, expect_binary=as_binary) @@ -425,7 +430,8 @@ def _display_text( display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: p = f"{float(price):,.4f}".rstrip("0").rstrip(".") - display_lines.append(f"Price: ${p}") + if p != "0": + display_lines.append(f"Price: ${p}") if text is not None: display_lines.append(text) if display_lines: @@ -581,6 +587,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): delay = cfg.retry_delay operation_succeeded: bool = False final_elapsed_seconds: Optional[int] = None + extracted_price: Optional[float] = None while True: attempt += 1 stop_event = asyncio.Event() @@ -768,6 +775,8 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): except json.JSONDecodeError: payload = {"_raw": text} response_content_to_log = payload if isinstance(payload, dict) else text + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) try: @@ -872,7 +881,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): else int(time.monotonic() - start_time) ), estimated_total=cfg.estimated_total, - price=None, + price=extracted_price, is_queued=False, processing_elapsed_seconds=final_elapsed_seconds, ) diff --git a/requirements.txt b/requirements.txt index 36c39f338e41..8c1946f3d80d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.28.8 +comfyui-frontend-package==1.28.9 comfyui-workflow-templates==0.3.1 comfyui-embedded-docs==0.3.1 torch From b75d349f25ccb702895c6f1b8af7aded63a7f7e2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 21 Nov 2025 02:33:54 +0200 Subject: [PATCH 3/5] fix(KlingLipSyncAudioToVideoNode): convert audio to mp3 format (#10811) --- comfy_api_nodes/nodes_kling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 7b23e9cf95c6..36852038b12f 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -518,7 +518,9 @@ async def execute_lipsync( # Upload the audio file to Comfy API and get download URL if audio: - audio_url = await upload_audio_to_comfyapi(cls, audio) + audio_url = await upload_audio_to_comfyapi( + cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" + ) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: audio_url = None From 10e90a5757906ecdb71b84d41173813d7f62c140 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 20 Nov 2025 18:20:52 -0800 Subject: [PATCH 4/5] bump comfyui-workflow-templates for nano banana 2 (#10818) * bump templates * bump templates --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8c1946f3d80d..624aa73622a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.9 -comfyui-workflow-templates==0.3.1 +comfyui-workflow-templates==0.6.0 comfyui-embedded-docs==0.3.1 torch torchsde From 943b3b615d40542ea19bc8ff8ad2950c0a094605 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 20 Nov 2025 19:44:43 -0800 Subject: [PATCH 5/5] HunyuanVideo 1.5 (#10819) * init * update * Update model.py * Update model.py * remove print * Fix text encoding * Prevent empty negative prompt Really doesn't work otherwise * fp16 works * I2V * Update model_base.py * Update nodes_hunyuan.py * Better latent rgb factors * Use the correct sigclip output... * Support HunyuanVideo1.5 SR model * whitespaces... * Proper latent channel count * SR model fixes This also still needs timesteps scheduling based on the noise scale, can be used with two samplers too already * vae_refiner: roll the convolution through temporal Work in progress. Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. * Support HunyuanVideo15 latent resampler * fix * Some cleanup Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Proper hyvid15 I2V channels Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Fix TokenRefiner for fp16 Otherwise x.sum has infs, just in case only casting if input is fp16, I don't know if necessary. * Bugfix for the HunyuanVideo15 SR model * vae_refiner: roll the convolution through temporal II Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. Added support for encoder, lowered to 1 latent frame to save more VRAM, made work for Hunyuan Image 3.0 (as code shared). Fixed names, cleaned up code. * Allow any number of input frames in VAE. * Better VAE encode mem estimation. * Lowvram fix. * Fix hunyuan image 2.1 refiner. * Fix mistake. * Name changes. * Rename. * Whitespace. * Fix. * Fix. --------- Co-authored-by: kijai <40791699+kijai@users.noreply.github.com> Co-authored-by: Rattus --- comfy/latent_formats.py | 60 ++++ comfy/ldm/hunyuan_video/model.py | 54 +++- comfy/ldm/hunyuan_video/upsampler.py | 120 ++++++++ comfy/ldm/hunyuan_video/vae_refiner.py | 288 +++++++++++------- comfy/model_base.py | 91 ++++++ comfy/model_detection.py | 10 + comfy/sd.py | 12 +- comfy/supported_models.py | 50 ++- comfy/text_encoders/hunyuan_video.py | 9 + comfy/text_encoders/qwen_image.py | 4 +- comfy_api/latest/_io.py | 4 + comfy_extras/nodes_hunyuan.py | 201 +++++++++++- folder_paths.py | 2 + .../put_latent_upscale_models_here | 0 nodes.py | 2 +- 15 files changed, 779 insertions(+), 128 deletions(-) create mode 100644 comfy/ldm/hunyuan_video/upsampler.py create mode 100644 models/latent_upscale_models/put_latent_upscale_models_here diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 77e642a94476..204fc048d466 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -611,6 +611,66 @@ class HunyuanImage21Refiner(LatentFormat): latent_dimensions = 3 scale_factor = 1.03682 + def process_in(self, latent): + out = latent * self.scale_factor + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out + + def process_out(self, latent): + z = latent / self.scale_factor + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] + return z + +class HunyuanVideo15(LatentFormat): + latent_rgb_factors = [ + [ 0.0568, -0.0521, -0.0131], + [ 0.0014, 0.0735, 0.0326], + [ 0.0186, 0.0531, -0.0138], + [-0.0031, 0.0051, 0.0288], + [ 0.0110, 0.0556, 0.0432], + [-0.0041, -0.0023, -0.0485], + [ 0.0530, 0.0413, 0.0253], + [ 0.0283, 0.0251, 0.0339], + [ 0.0277, -0.0372, -0.0093], + [ 0.0393, 0.0944, 0.1131], + [ 0.0020, 0.0251, 0.0037], + [-0.0017, 0.0012, 0.0234], + [ 0.0468, 0.0436, 0.0203], + [ 0.0354, 0.0439, -0.0233], + [ 0.0090, 0.0123, 0.0346], + [ 0.0382, 0.0029, 0.0217], + [ 0.0261, -0.0300, 0.0030], + [-0.0088, -0.0220, -0.0283], + [-0.0272, -0.0121, -0.0363], + [-0.0664, -0.0622, 0.0144], + [ 0.0414, 0.0479, 0.0529], + [ 0.0355, 0.0612, -0.0247], + [ 0.0147, 0.0264, 0.0174], + [ 0.0438, 0.0038, 0.0542], + [ 0.0431, -0.0573, -0.0033], + [-0.0162, -0.0211, -0.0406], + [-0.0487, -0.0295, -0.0393], + [ 0.0005, -0.0109, 0.0253], + [ 0.0296, 0.0591, 0.0353], + [ 0.0119, 0.0181, -0.0306], + [-0.0085, -0.0362, 0.0229], + [ 0.0005, -0.0106, 0.0242] + ] + + latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644] + latent_channels = 32 + latent_dimensions = 3 + scale_factor = 1.03682 + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 5132e6c07b75..f75c6e0e1362 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -6,7 +6,6 @@ import comfy.ldm.modules.diffusionmodules.mmdit from comfy.ldm.modules.attention import optimized_attention - from dataclasses import dataclass from einops import repeat @@ -42,6 +41,8 @@ class HunyuanVideoParams: guidance_embed: bool byt5: bool meanflow: bool + use_cond_type_embedding: bool + vision_in_dim: int class SelfAttentionRef(nn.Module): @@ -157,7 +158,10 @@ def forward( t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise - c = x.sum(dim=1) / x.shape[1] + if x.dtype == torch.float16: + c = x.float().sum(dim=1) / x.shape[1] + else: + c = x.sum(dim=1) / x.shape[1] c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) @@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module): def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): super().__init__() self.dtype = dtype + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + params = HunyuanVideoParams(**kwargs) self.params = params self.patch_size = params.patch_size self.in_channels = params.in_channels self.out_channels = params.out_channels + self.use_cond_type_embedding = params.use_cond_type_embedding + self.vision_in_dim = params.vision_in_dim if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" @@ -266,6 +274,18 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) + # HunyuanVideo 1.5 specific modules + if self.vision_in_dim is not None: + from comfy.ldm.wan.model import MLPProj + self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings) + else: + self.vision_in = None + if self.use_cond_type_embedding: + # 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature + self.cond_type_embedding = nn.Embedding(3, self.hidden_size) + else: + self.cond_type_embedding = None + def forward_orig( self, img: Tensor, @@ -276,6 +296,7 @@ def forward_orig( timesteps: Tensor, y: Tensor = None, txt_byt5=None, + clip_fea=None, guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, @@ -331,12 +352,31 @@ def forward_orig( txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options) + if self.cond_type_embedding is not None: + self.cond_type_embedding.to(txt.device) + cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long)) + txt = txt + cond_emb.to(txt.dtype) + if self.byt5_in is not None and txt_byt5 is not None: txt_byt5 = self.byt5_in(txt_byt5) + if self.cond_type_embedding is not None: + cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long)) + txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype) + txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5 + else: + txt = torch.cat((txt, txt_byt5), dim=1) txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) - txt = torch.cat((txt, txt_byt5), dim=1) txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1) + if clip_fea is not None: + txt_vision_states = self.vision_in(clip_fea) + if self.cond_type_embedding is not None: + cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device)) + txt_vision_states = txt_vision_states + cond_emb + txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1) + extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) + txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) + ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -430,14 +470,14 @@ def img_ids_2d(self, x): img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) return repeat(img_ids, "h w c -> b (h w) c", b=bs) - def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): + def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): bs = x.shape[0] if len(self.patch_size) == 3: img_ids = self.img_ids(x) @@ -445,5 +485,5 @@ def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, a else: img_ids = self.img_ids_2d(x) txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py new file mode 100644 index 000000000000..9f5e91a59c6a --- /dev/null +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d +import model_management, model_patcher + +class SRResidualCausalBlock3D(nn.Module): + def __init__(self, channels: int): + super().__init__() + self.block = nn.Sequential( + VideoConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + VideoConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + VideoConv3d(channels, channels, kernel_size=3), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + +class SRModel3DV2(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 64, + num_blocks: int = 6, + global_residual: bool = False, + ): + super().__init__() + self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3) + self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)]) + self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3) + self.global_residual = bool(global_residual) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + y = self.in_conv(x) + for blk in self.blocks: + y = blk(y) + y = self.out_conv(y) + if self.global_residual and (y.shape == residual.shape): + y = y + residual + return y + + +class Upsampler(nn.Module): + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: tuple[int, ...], + num_res_blocks: int = 2, + ): + super().__init__() + self.num_res_blocks = num_res_blocks + self.block_out_channels = block_out_channels + self.z_channels = z_channels + + ch = block_out_channels[0] + self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3) + + self.up = nn.ModuleList() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_shortcut=False, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks + 1)]) + ch = tgt + self.up.append(stage) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3) + + def forward(self, z): + """ + Args: + z: (B, C, T, H, W) + target_shape: (H, W) + """ + # z to block_in + repeats = self.block_out_channels[0] // (self.z_channels) + x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # upsampling + for stage in self.up: + for blk in stage.block: + x = blk(x) + + out = self.conv_out(F.silu(self.norm_out(x))) + return out + +UPSAMPLERS = { + "720p": SRModel3DV2, + "1080p": Upsampler, +} + +class HunyuanVideo15SRModel(): + def __init__(self, model_type, config): + self.load_device = model_management.vae_device() + offload_device = model_management.vae_offload_device() + self.dtype = model_management.vae_dtype(self.load_device) + self.model_class = UPSAMPLERS.get(model_type) + self.model = self.model_class(**config).eval() + + self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=True) + + def get_sd(self): + return self.model.state_dict() + + def resample_latent(self, latent): + model_management.load_model_gpu(self.patcher) + return self.model(latent.to(self.load_device)) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index c2a0b507d4bf..9f750dcc4776 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -4,8 +4,40 @@ from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize import comfy.ops import comfy.ldm.models.autoencoder +import comfy.model_management ops = comfy.ops.disable_weight_init +class NoPadConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + if isinstance(op, NoPadConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') + + out = op(x) + + return out + + class RMS_norm(nn.Module): def __init__(self, dim): super().__init__() @@ -14,7 +46,7 @@ def __init__(self, dim): self.gamma = nn.Parameter(torch.empty(shape)) def forward(self, x): - return F.normalize(x, dim=1) * self.scale * self.gamma + return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) class DnSmpl(nn.Module): def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): @@ -27,11 +59,12 @@ def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): self.tds = tds self.gs = fct * ic // oc - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): r1 = 2 if self.tds else 1 - h = self.conv(x) + h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + + if self.tds and self.refiner_vae and conv_carry_in is None: - if self.tds and self.refiner_vae: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) @@ -39,14 +72,7 @@ def forward(self, x): hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2) hf = torch.cat([hf, hf], dim=1) - hn = h[:, :, 1:, :, :] - b, c, frms, ht, wd = hn.shape - nf = frms // r1 - hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) - hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6) - hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) - - h = torch.cat([hf, hn], dim=2) + h = h[:, :, 1:, :, :] xf = x[:, :, :1, :, :] b, ci, f, ht, wd = xf.shape @@ -54,34 +80,32 @@ def forward(self, x): xf = xf.permute(0, 4, 6, 1, 2, 3, 5) xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2) B, C, T, H, W = xf.shape - xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2) - - xn = x[:, :, 1:, :, :] - b, ci, frms, ht, wd = xn.shape - nf = frms // r1 - xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) - xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6) - xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) - B, C, T, H, W = xn.shape - xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) - sc = torch.cat([xf, xn], dim=2) - else: - b, c, frms, ht, wd = h.shape + xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2) + + x = x[:, :, 1:, :, :] - nf = frms // r1 - h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) - h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) - h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + if h.shape[2] == 0: + return hf + xf - b, ci, frms, ht, wd = x.shape - nf = frms // r1 - sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) - sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6) - sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) - B, C, T, H, W = sc.shape - sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + b, c, frms, ht, wd = h.shape + nf = frms // r1 + h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) + h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) - return h + sc + b, ci, frms, ht, wd = x.shape + nf = frms // r1 + x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + x = x.permute(0, 3, 5, 7, 1, 2, 4, 6) + x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = x.shape + x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + + if self.tds and self.refiner_vae and conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + x = torch.cat([xf, x], dim=2) + + return h + x class UpSmpl(nn.Module): @@ -94,11 +118,11 @@ def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): self.tus = tus self.rp = fct * oc // ic - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): r1 = 2 if self.tus else 1 - h = self.conv(x) + h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) - if self.tus and self.refiner_vae: + if self.tus and self.refiner_vae and conv_carry_in is None: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape nc = c // (2 * 2) @@ -107,14 +131,7 @@ def forward(self, x): hf = hf.reshape(b, nc, f, ht * 2, wd * 2) hf = hf[:, : hf.shape[1] // 2] - hn = h[:, :, 1:, :, :] - b, c, frms, ht, wd = hn.shape - nc = c // (r1 * 2 * 2) - hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd) - hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3) - hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2) - - h = torch.cat([hf, hn], dim=2) + h = h[:, :, 1:, :, :] xf = x[:, :, :1, :, :] b, ci, f, ht, wd = xf.shape @@ -125,29 +142,43 @@ def forward(self, x): xf = xf.permute(0, 3, 4, 5, 1, 6, 2) xf = xf.reshape(b, nc, f, ht * 2, wd * 2) - xn = x[:, :, 1:, :, :] - xn = xn.repeat_interleave(repeats=self.rp, dim=1) - b, c, frms, ht, wd = xn.shape - nc = c // (r1 * 2 * 2) - xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd) - xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3) - xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2) - sc = torch.cat([xf, xn], dim=2) - else: - b, c, frms, ht, wd = h.shape - nc = c // (r1 * 2 * 2) - h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) - h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) - h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) + x = x[:, :, 1:, :, :] - sc = x.repeat_interleave(repeats=self.rp, dim=1) - b, c, frms, ht, wd = sc.shape - nc = c // (r1 * 2 * 2) - sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd) - sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3) - sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2) + b, c, frms, ht, wd = h.shape + nc = c // (r1 * 2 * 2) + h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) + h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) + h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) - return h + sc + x = x.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = x.shape + nc = c // (r1 * 2 * 2) + x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd) + x = x.permute(0, 4, 5, 1, 6, 2, 7, 3) + x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + if self.tus and self.refiner_vae and conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + x = torch.cat([xf, x], dim=2) + + return h + x + +class HunyuanRefinerResnetBlock(ResnetBlock): + def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm): + super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + + def forward(self, x, conv_carry_in=None, conv_carry_out=None): + h = x + h = [ self.swish(self.norm1(x)) ] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) + + h = [ self.dropout(self.swish(self.norm2(h))) ] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x+h class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, @@ -160,7 +191,7 @@ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = VideoConv3d + conv_op = NoPadConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -175,10 +206,9 @@ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - temb_channels=0, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: @@ -188,9 +218,9 @@ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.norm_out = norm_op(ch) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) @@ -201,31 +231,50 @@ def forward(self, x): if not self.refiner_vae and x.shape[2] == 1: x = x.expand(-1, -1, self.ffactor_temporal, -1, -1) - x = self.conv_in(x) + if self.refiner_vae: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.ffactor_temporal: + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2) + x = xl + else: + x = [x] + out = [] - for stage in self.down: - for blk in stage.block: - x = blk(x) - if hasattr(stage, 'downsample'): - x = stage.downsample(x) + conv_carry_in = None - x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + x1 = [ x1 ] + x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for stage in self.down: + for blk in stage.block: + x1 = blk(x1, conv_carry_in, conv_carry_out) + if hasattr(stage, 'downsample'): + x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) + + out.append(x1) + conv_carry_in = conv_carry_out + + if len(out) > 1: + out = torch.cat(out, dim=2) + else: + out = out[0] + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) + del out b, c, t, h, w = x.shape grp = c // (self.z_channels << 1) skip = x.view(b, c // grp, grp, t, h, w).mean(2) - out = self.conv_out(F.silu(self.norm_out(x))) + skip + out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip if self.refiner_vae: out = self.regul(out)[0] - out = torch.cat((out[:, :, :1], out), dim=2) - out = out.permute(0, 2, 1, 3, 4) - b, f_times_2, c, h, w = out.shape - out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) - out = out.permute(0, 2, 1, 3, 4).contiguous() - return out class Decoder(nn.Module): @@ -239,7 +288,7 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = VideoConv3d + conv_op = NoPadConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -249,9 +298,9 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -259,10 +308,9 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - temb_channels=0, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: @@ -275,27 +323,41 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1) def forward(self, z): - if self.refiner_vae: - z = z.permute(0, 2, 1, 3, 4) - b, f, c, h, w = z.shape - z = z.reshape(b, f, 2, c // 2, h, w) - z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) - z = z.permute(0, 2, 1, 3, 4) - z = z[:, :, 1:] - - x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) - for stage in self.up: - for blk in stage.block: - x = blk(x) - if hasattr(stage, 'upsample'): - x = stage.upsample(x) - - out = self.conv_out(F.silu(self.norm_out(x))) + if self.refiner_vae: + x = torch.split(x, 2, dim=2) + else: + x = [ x ] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + for stage in self.up: + for blk in stage.block: + x1 = blk(x1, conv_carry_in, conv_carry_out) + if hasattr(stage, 'upsample'): + x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) + + x1 = [ F.silu(self.norm_out(x1)) ] + x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out) + out.append(x1) + conv_carry_in = conv_carry_out + del x + + if len(out) > 1: + out = torch.cat(out, dim=2) + else: + out = out[0] if not self.refiner_vae: if z.shape[-3] == 1: out = out[:, :, -1:] return out + diff --git a/comfy/model_base.py b/comfy/model_base.py index 7c788d085482..e14b552c5053 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1536,3 +1536,94 @@ def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) out['disable_time_r'] = comfy.conds.CONDConstant(True) return out + +class HunyuanVideo15(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1 + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + latent_dim = self.latent_format.latent_channels + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], latent_dim): + image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_byt5small = kwargs.get("conditioning_byt5small", None) + if conditioning_byt5small is not None: + out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + clip_vision_output = kwargs.get("clip_vision_output", None) + if clip_vision_output is not None: + out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state) + + return out + +class HunyuanVideo15_SR_Distilled(HunyuanVideo15): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + noise_augmentation = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + + if image is None: + image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device()) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + #image = self.process_latent_in(image) # scaling wasn't applied in reference code + image = utils.resize_to_batch_size(image, noise.shape[0]) + lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1) + if noise_augmentation > 0: + generator = torch.Generator(device="cpu") + generator.manual_seed(kwargs.get("seed", 0) - 10) + noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) + image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice] + else: + image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice] + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out['disable_time_r'] = comfy.conds.CONDConstant(False) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3142a7fc388c..0131ca25a51e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -186,6 +186,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys)) dit_config["guidance_embed"] = len(guidance_keys) > 0 + + # HunyuanVideo 1.5 + if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys: + dit_config["use_cond_type_embedding"] = True + else: + dit_config["use_cond_type_embedding"] = False + if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys: + dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] + else: + dit_config["vision_in_dim"] = None return dit_config if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) diff --git a/comfy/sd.py b/comfy/sd.py index 9e5ebbf15ba0..dc0905ada4af 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -441,20 +441,20 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] - self.latent_channels = 64 + self.latent_channels = 32 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) self.latent_dim = 3 - self.not_video = True + self.not_video = False self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"}, encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) - self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True @@ -911,6 +911,7 @@ class CLIPType(Enum): OMNIGEN2 = 17 QWEN_IMAGE = 18 HUNYUAN_IMAGE = 19 + HUNYUAN_VIDEO_15 = 20 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1126,6 +1127,9 @@ class EmptyClass: elif clip_type == CLIPType.HUNYUAN_IMAGE: clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + elif clip_type == CLIPType.HUNYUAN_VIDEO_15: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4064bdae1052..2e64b85e88fb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1374,6 +1374,54 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +class HunyuanVideo15(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vision_in_dim": 1152, + } + + sampling_settings = { + "shift": 7.0, + } + memory_usage_factor = 4.0 #TODO + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + latent_format = latent_formats.HunyuanVideo15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideo15(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + + +class HunyuanVideo15_SR_Distilled(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vision_in_dim": 1152, + "in_channels": 98, + } + + sampling_settings = { + "shift": 2.0, + } + memory_usage_factor = 4.0 #TODO + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + latent_format = latent_formats.HunyuanVideo15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideo15_SR_Distilled(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index b02148b3346d..557094f496b6 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -1,6 +1,7 @@ from comfy import sd1_clip import comfy.model_management import comfy.text_encoders.llama +from .hunyuan_image import HunyuanImageTokenizer from transformers import LlamaTokenizerFast import torch import os @@ -73,6 +74,14 @@ def state_dict(self): return {} +class HunyuanVideo15Tokenizer(HunyuanImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs) + class HunyuanVideoClipModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): super().__init__() diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index 40fa6793739c..c0d32a6ef4ec 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -17,12 +17,14 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs): skip_template = False if text.startswith('<|im_start|>'): skip_template = True if text.startswith('<|start_header_id|>'): skip_template = True + if prevent_empty_text and text == '': + text = ' ' if skip_template: llama_text = text diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 863254ce7fdc..79c0722a9e46 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -629,6 +629,10 @@ class UpscaleModel(ComfyTypeIO): if TYPE_CHECKING: Type = ImageModelDescriptor +@comfytype(io_type="LATENT_UPSCALE_MODEL") +class LatentUpscaleModel(ComfyTypeIO): + Type = Any + @comfytype(io_type="AUDIO") class Audio(ComfyTypeIO): class AudioDict(TypedDict): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index f7c34d0590e9..5a2e8cc61365 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -4,7 +4,8 @@ import comfy.model_management from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +import folder_paths class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -57,6 +58,199 @@ def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: generate = execute # TODO: remove +class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): + @classmethod + def define_schema(cls): + schema = super().define_schema() + schema.node_id = "EmptyHunyuanVideo15Latent" + return schema + + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: + # Using scale factor of 16 instead of 8 + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + generate = execute # TODO: remove + + +class HunyuanVideo15ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + encoded = vae.encode(start_image[:, :, :, :3]) + concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device()) + concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded + + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + +class HunyuanVideo15SuperResolution(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15SuperResolution", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae", optional=True), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Latent.Input("latent"), + io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01), + + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput: + in_latent = latent["samples"] + in_channels = in_latent.shape[1] + cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device()) + cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent + cond_latent[:, 2 * in_channels + 1] = 1 + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded + cond_latent[:, in_channels + 1, 0] = 1 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + return io.NodeOutput(positive, negative, latent) + + +class LatentUpscaleModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentUpscaleModelLoader", + display_name="Load Latent Upscale Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")), + ], + outputs=[ + io.LatentUpscaleModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + if "blocks.0.block.0.conv.weight" in sd: + config = { + "in_channels": sd["in_conv.conv.weight"].shape[1], + "out_channels": sd["out_conv.conv.weight"].shape[0], + "hidden_channels": sd["in_conv.conv.weight"].shape[0], + "num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]), + "global_residual": False, + } + model_type = "720p" + elif "up.0.block.0.conv1.conv.weight" in sd: + sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} + config = { + "z_channels": sd["conv_in.conv.weight"].shape[1], + "out_channels": sd["conv_out.conv.weight"].shape[0], + "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), + } + model_type = "1080p" + + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + + return io.NodeOutput(model) + + +class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15LatentUpscaleWithModel", + display_name="Hunyuan Video 15 Latent Upscale With Model", + category="latent", + inputs=[ + io.LatentUpscaleModel.Input("model"), + io.Latent.Input("samples"), + io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"), + io.Int.Input("width", default=1280, min=0, max=16384, step=8), + io.Int.Input("height", default=720, min=0, max=16384, step=8), + io.Combo.Input("crop", options=["disabled", "center"]), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput: + if width == 0 and height == 0: + return io.NodeOutput(samples) + else: + if width == 0: + height = max(64, height) + width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2])) + elif height == 0: + width = max(64, width) + height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1])) + else: + width = max(64, width) + height = max(64, height) + s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop) + s = model.resample_latent(s) + return io.NodeOutput({"samples": s.cpu().float()}) + + PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " "1. The main content and theme of the video." @@ -210,6 +404,11 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: CLIPTextEncodeHunyuanDiT, TextEncodeHunyuanVideo_ImageToVideo, EmptyHunyuanLatentVideo, + EmptyHunyuanVideo15Latent, + HunyuanVideo15ImageToVideo, + HunyuanVideo15SuperResolution, + HunyuanVideo15LatentUpscaleWithModel, + LatentUpscaleModelLoader, HunyuanImageToVideo, EmptyHunyuanImageLatent, HunyuanRefinerLatent, diff --git a/folder_paths.py b/folder_paths.py index f110d832bb23..ffdc4d020662 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -38,6 +38,8 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +folder_names_and_paths["latent_upscale_models"] = ([os.path.join(models_dir, "latent_upscale_models")], supported_pt_extensions) + folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set()) folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) diff --git a/models/latent_upscale_models/put_latent_upscale_models_here b/models/latent_upscale_models/put_latent_upscale_models_here new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nodes.py b/nodes.py index 0303716338aa..f023ae3b687c 100644 --- a/nodes.py +++ b/nodes.py @@ -957,7 +957,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}),