Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions app/node_replace_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import annotations

from aiohttp import web

from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from comfy_api.latest._io_public import NodeReplace

from comfy_execution.graph_utils import is_link
import nodes

class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
class_type: str
_meta: dict[str, str]

def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
new_node_struct = node_struct.copy()
if empty_inputs:
new_node_struct["inputs"] = {}
else:
new_node_struct["inputs"] = node_struct["inputs"].copy()
new_node_struct["_meta"] = node_struct["_meta"].copy()
return new_node_struct


class NodeReplaceManager:
"""Manages node replacement registrations."""

def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}

def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)

def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)

def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements

def apply_replacements(self, prompt: dict[str, NodeStruct]):
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
if is_link(input_value):
conn_number = input_value[0]
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
for node_number in need_replacement:
node_struct = prompt[node_number]
class_type = node_struct["class_type"]
replacements = self.get_replacement(class_type)
if replacements is None:
continue
# just use the first replacement
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
new_node_struct["class_type"] = new_node_id
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
# second, replace inputs
if replacement.input_mapping is not None:
for input_map in replacement.input_mapping:
if "set_value" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
elif "old_id" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
# finalize input replacement
prompt[node_number] = new_node_struct
# third, replace outputs
if replacement.output_mapping is not None:
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
if node_number in connections:
for conns in connections[node_number]:
conn_node_number, conn_input_id, old_output_idx = conns
for output_map in replacement.output_mapping:
if output_map["old_idx"] == old_output_idx:
new_output_idx = output_map["new_idx"]
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx

def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}

def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(self.as_dict())
1 change: 1 addition & 0 deletions comfy_api/feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
}


Expand Down
13 changes: 11 additions & 2 deletions comfy_api/latest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False

def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()

class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)

class Execution(ProxiedSingleton):
async def set_progress(
self,
Expand Down Expand Up @@ -73,8 +84,6 @@ async def set_progress(
image=to_display,
)

execution: Execution

class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Expand Down
63 changes: 63 additions & 0 deletions comfy_api/latest/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2030,6 +2030,68 @@ def as_dict(self) -> dict:
...


class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str

class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any

InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""

class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int

class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.

Also supports assigning specific values to the input widgets of the new node.

Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping

def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}


__all__ = [
"FolderType",
"UploadType",
Expand Down Expand Up @@ -2121,4 +2183,5 @@ def as_dict(self) -> dict:
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
"NodeReplace",
]
6 changes: 0 additions & 6 deletions comfy_api_nodes/nodes_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
o1 = "o1"
o3 = "o3"
o1_pro = "o1-pro"
gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
Expand Down Expand Up @@ -649,11 +648,6 @@ def define_schema(cls):
"usd": [0.01, 0.04],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4o") ? {
"type": "list_usd",
"usd": [0.0025, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1-nano") ? {
"type": "list_usd",
"usd": [0.0001, 0.0004],
Expand Down
1 change: 1 addition & 0 deletions comfy_extras/nodes_post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,7 @@ def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
batched = batch_masks(values)
return io.NodeOutput(batched)


class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
Expand Down
103 changes: 103 additions & 0 deletions comfy_extras/nodes_replacements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from comfy_api.latest import ComfyExtension, io, ComfyAPI

api = ComfyAPI()


async def register_replacements():
"""Register all built-in node replacements."""
await register_replacements_longeredge()
await register_replacements_batchimages()
await register_replacements_upscaleimage()
await register_replacements_controlnet()
await register_replacements_load3d()
await register_replacements_preview3d()
await register_replacements_svdimg2vid()
await register_replacements_conditioningavg()

async def register_replacements_longeredge():
# No dynamic inputs here
await api.node_replacement.register(io.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
input_mapping=[
{"new_id": "image", "old_id": "images"},
{"new_id": "largest_size", "old_id": "longer_edge"},
{"new_id": "upscale_method", "set_value": "lanczos"},
],
# just to test the frontend output_mapping code, does nothing really here
output_mapping=[{"new_idx": 0, "old_idx": 0}],
))

async def register_replacements_batchimages():
# BatchImages node uses Autogrow
await api.node_replacement.register(io.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
{"new_id": "images.image0", "old_id": "image1"},
{"new_id": "images.image1", "old_id": "image2"},
],
))

async def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
await api.node_replacement.register(io.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
input_mapping=[
{"new_id": "input", "old_id": "image"},
{"new_id": "resize_type", "set_value": "scale by multiplier"},
{"new_id": "resize_type.multiplier", "old_id": "scale_by"},
{"new_id": "scale_method", "old_id": "upscale_method"},
],
))

async def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
await api.node_replacement.register(io.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
{"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
],
))

async def register_replacements_load3d():
# Load3DAnimation merged into Load3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))

async def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))

async def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
await api.node_replacement.register(io.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))

async def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
await api.node_replacement.register(io.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))

class NodeReplacementsExtension(ComfyExtension):
async def on_load(self) -> None:
await register_replacements()

async def get_node_list(self) -> list[type[io.ComfyNode]]:
return []

async def comfy_entrypoint() -> NodeReplacementsExtension:
return NodeReplacementsExtension()
2 changes: 2 additions & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
await extension.on_load()
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
Expand Down Expand Up @@ -2435,6 +2436,7 @@ async def init_builtin_extra_nodes():
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_toolkit.py",
"nodes_replacements.py",
]

import_failed = []
Expand Down
5 changes: 5 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
Expand Down Expand Up @@ -204,6 +205,7 @@ def __init__(self, loop):
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
self.node_replace_manager = NodeReplaceManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self)
Expand Down Expand Up @@ -887,6 +889,8 @@ async def post_prompt(request):
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]

self.node_replace_manager.apply_replacements(prompt)

valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
Expand Down Expand Up @@ -995,6 +999,7 @@ def add_routes(self):
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
self.node_replace_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())

# Prefix every route with /api for easier matching for delegation.
Expand Down
Loading