Skip to content

Commit 6ea8c12

Browse files
authored
Assets Part 2 - add more endpoints (Comfy-Org#12125)
1 parent 6e469a3 commit 6ea8c12

17 files changed

+4347
-25
lines changed

app/assets/api/routes.py

Lines changed: 413 additions & 1 deletion
Large diffs are not rendered by default.

app/assets/api/schemas_in.py

Lines changed: 183 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import uuid
32
from typing import Any, Literal
43

54
from pydantic import (
@@ -8,9 +7,9 @@
87
Field,
98
conint,
109
field_validator,
10+
model_validator,
1111
)
1212

13-
1413
class ListAssetsQuery(BaseModel):
1514
include_tags: list[str] = Field(default_factory=list)
1615
exclude_tags: list[str] = Field(default_factory=list)
@@ -57,6 +56,57 @@ def _parse_metadata_json(cls, v):
5756
return None
5857

5958

59+
class UpdateAssetBody(BaseModel):
60+
name: str | None = None
61+
user_metadata: dict[str, Any] | None = None
62+
63+
@model_validator(mode="after")
64+
def _at_least_one(self):
65+
if self.name is None and self.user_metadata is None:
66+
raise ValueError("Provide at least one of: name, user_metadata.")
67+
return self
68+
69+
70+
class CreateFromHashBody(BaseModel):
71+
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
72+
73+
hash: str
74+
name: str
75+
tags: list[str] = Field(default_factory=list)
76+
user_metadata: dict[str, Any] = Field(default_factory=dict)
77+
78+
@field_validator("hash")
79+
@classmethod
80+
def _require_blake3(cls, v):
81+
s = (v or "").strip().lower()
82+
if ":" not in s:
83+
raise ValueError("hash must be 'blake3:<hex>'")
84+
algo, digest = s.split(":", 1)
85+
if algo != "blake3":
86+
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
87+
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
88+
raise ValueError("hash digest must be lowercase hex")
89+
return s
90+
91+
@field_validator("tags", mode="before")
92+
@classmethod
93+
def _tags_norm(cls, v):
94+
if v is None:
95+
return []
96+
if isinstance(v, list):
97+
out = [str(t).strip().lower() for t in v if str(t).strip()]
98+
seen = set()
99+
dedup = []
100+
for t in out:
101+
if t not in seen:
102+
seen.add(t)
103+
dedup.append(t)
104+
return dedup
105+
if isinstance(v, str):
106+
return [t.strip().lower() for t in v.split(",") if t.strip()]
107+
return []
108+
109+
60110
class TagsListQuery(BaseModel):
61111
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
62112

@@ -75,20 +125,140 @@ def normalize_prefix(cls, v: str | None) -> str | None:
75125
return v.lower() or None
76126

77127

78-
class SetPreviewBody(BaseModel):
79-
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
80-
preview_id: str | None = None
128+
class TagsAdd(BaseModel):
129+
model_config = ConfigDict(extra="ignore")
130+
tags: list[str] = Field(..., min_length=1)
81131

82-
@field_validator("preview_id", mode="before")
132+
@field_validator("tags")
83133
@classmethod
84-
def _norm_uuid(cls, v):
134+
def normalize_tags(cls, v: list[str]) -> list[str]:
135+
out = []
136+
for t in v:
137+
if not isinstance(t, str):
138+
raise TypeError("tags must be strings")
139+
tnorm = t.strip().lower()
140+
if tnorm:
141+
out.append(tnorm)
142+
seen = set()
143+
deduplicated = []
144+
for x in out:
145+
if x not in seen:
146+
seen.add(x)
147+
deduplicated.append(x)
148+
return deduplicated
149+
150+
151+
class TagsRemove(TagsAdd):
152+
pass
153+
154+
155+
class UploadAssetSpec(BaseModel):
156+
"""Upload Asset operation.
157+
- tags: ordered; first is root ('models'|'input'|'output');
158+
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
159+
- name: display name
160+
- user_metadata: arbitrary JSON object (optional)
161+
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
162+
163+
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
164+
and the original extension is preserved when available.
165+
"""
166+
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
167+
168+
tags: list[str] = Field(..., min_length=1)
169+
name: str | None = Field(default=None, max_length=512, description="Display Name")
170+
user_metadata: dict[str, Any] = Field(default_factory=dict)
171+
hash: str | None = Field(default=None)
172+
173+
@field_validator("hash", mode="before")
174+
@classmethod
175+
def _parse_hash(cls, v):
85176
if v is None:
86177
return None
87-
s = str(v).strip()
178+
s = str(v).strip().lower()
88179
if not s:
89180
return None
90-
try:
91-
uuid.UUID(s)
92-
except Exception:
93-
raise ValueError("preview_id must be a UUID")
94-
return s
181+
if ":" not in s:
182+
raise ValueError("hash must be 'blake3:<hex>'")
183+
algo, digest = s.split(":", 1)
184+
if algo != "blake3":
185+
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
186+
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
187+
raise ValueError("hash digest must be lowercase hex")
188+
return f"{algo}:{digest}"
189+
190+
@field_validator("tags", mode="before")
191+
@classmethod
192+
def _parse_tags(cls, v):
193+
"""
194+
Accepts a list of strings (possibly multiple form fields),
195+
where each string can be:
196+
- JSON array (e.g., '["models","loras","foo"]')
197+
- comma-separated ('models, loras, foo')
198+
- single token ('models')
199+
Returns a normalized, deduplicated, ordered list.
200+
"""
201+
items: list[str] = []
202+
if v is None:
203+
return []
204+
if isinstance(v, str):
205+
v = [v]
206+
207+
if isinstance(v, list):
208+
for item in v:
209+
if item is None:
210+
continue
211+
s = str(item).strip()
212+
if not s:
213+
continue
214+
if s.startswith("["):
215+
try:
216+
arr = json.loads(s)
217+
if isinstance(arr, list):
218+
items.extend(str(x) for x in arr)
219+
continue
220+
except Exception:
221+
pass # fallback to CSV parse below
222+
items.extend([p for p in s.split(",") if p.strip()])
223+
else:
224+
return []
225+
226+
# normalize + dedupe
227+
norm = []
228+
seen = set()
229+
for t in items:
230+
tnorm = str(t).strip().lower()
231+
if tnorm and tnorm not in seen:
232+
seen.add(tnorm)
233+
norm.append(tnorm)
234+
return norm
235+
236+
@field_validator("user_metadata", mode="before")
237+
@classmethod
238+
def _parse_metadata_json(cls, v):
239+
if v is None or isinstance(v, dict):
240+
return v or {}
241+
if isinstance(v, str):
242+
s = v.strip()
243+
if not s:
244+
return {}
245+
try:
246+
parsed = json.loads(s)
247+
except Exception as e:
248+
raise ValueError(f"user_metadata must be JSON: {e}") from e
249+
if not isinstance(parsed, dict):
250+
raise ValueError("user_metadata must be a JSON object")
251+
return parsed
252+
return {}
253+
254+
@model_validator(mode="after")
255+
def _validate_order(self):
256+
if not self.tags:
257+
raise ValueError("tags must be provided and non-empty")
258+
root = self.tags[0]
259+
if root not in {"models", "input", "output"}:
260+
raise ValueError("first tag must be one of: models, input, output")
261+
if root == "models":
262+
if len(self.tags) < 2:
263+
raise ValueError("models uploads require a category tag as the second tag")
264+
return self

app/assets/api/schemas_out.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ class AssetsList(BaseModel):
2929
has_more: bool
3030

3131

32+
class AssetUpdated(BaseModel):
33+
id: str
34+
name: str
35+
asset_hash: str | None = None
36+
tags: list[str] = Field(default_factory=list)
37+
user_metadata: dict[str, Any] = Field(default_factory=dict)
38+
updated_at: datetime | None = None
39+
40+
model_config = ConfigDict(from_attributes=True)
41+
42+
@field_serializer("updated_at")
43+
def _ser_updated(self, v: datetime | None, _info):
44+
return v.isoformat() if v else None
45+
46+
3247
class AssetDetail(BaseModel):
3348
id: str
3449
name: str
@@ -48,6 +63,10 @@ def _ser_dt(self, v: datetime | None, _info):
4863
return v.isoformat() if v else None
4964

5065

66+
class AssetCreated(AssetDetail):
67+
created_new: bool
68+
69+
5170
class TagUsage(BaseModel):
5271
name: str
5372
count: int
@@ -58,3 +77,17 @@ class TagsList(BaseModel):
5877
tags: list[TagUsage] = Field(default_factory=list)
5978
total: int
6079
has_more: bool
80+
81+
82+
class TagsAdd(BaseModel):
83+
model_config = ConfigDict(str_strip_whitespace=True)
84+
added: list[str] = Field(default_factory=list)
85+
already_present: list[str] = Field(default_factory=list)
86+
total_tags: list[str] = Field(default_factory=list)
87+
88+
89+
class TagsRemove(BaseModel):
90+
model_config = ConfigDict(str_strip_whitespace=True)
91+
removed: list[str] = Field(default_factory=list)
92+
not_present: list[str] = Field(default_factory=list)
93+
total_tags: list[str] = Field(default_factory=list)

0 commit comments

Comments
 (0)