Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
851b01d
Merge pull request #13 from deruyter92/jaap/add_config_and_registry
xiu-cs Feb 10, 2026
799be4d
Refactor FMPose3D test script to use model_type instead of model_path
xiu-cs Feb 10, 2026
7acba1f
Refactor FMPose3D_main.py to load model using get_model from registry…
xiu-cs Feb 10, 2026
f173935
Update FMPose3D_train.sh to use model_type for model selection instea…
xiu-cs Feb 10, 2026
4f07856
Merge branch 'feat/add_api' into ti_video_demo
xiu-cs Feb 10, 2026
f878e7f
Add model_type argument to opts for model registry selection
xiu-cs Feb 10, 2026
321b28a
Remove unnecessary comment block in HRNet implementation file
xiu-cs Feb 10, 2026
ea1d3f7
Import animal models to ensure their registration in the model registry.
xiu-cs Feb 10, 2026
dd5be5d
Register Model class for animal3D in the model registry and update in…
xiu-cs Feb 10, 2026
a7e25e9
Refactor main_animal3d.py to load model from the registered model reg…
xiu-cs Feb 10, 2026
f1edf44
Update test_animal3d.sh to modify eval_sample_steps, change model_typ…
xiu-cs Feb 10, 2026
51b14ab
Update train_animal3d.sh to modify eval_sample_steps, change model_ty…
xiu-cs Feb 10, 2026
4702481
Refactor vis_animals.py to load model using get_model from the regist…
xiu-cs Feb 10, 2026
b02deda
Update vis_animals.sh to set model_type for FMPose3D and adjust saved…
xiu-cs Feb 10, 2026
16f340c
Update README.md with new download link for pre-trained model
xiu-cs Feb 10, 2026
0512c27
Refactor backup file handling in main_animal3d.py to enable file copy…
xiu-cs Feb 11, 2026
a440a08
Update test_animal3d.sh to change test dataset path and adjust script…
xiu-cs Feb 11, 2026
1f7bfa6
Add default FMPose3DConfig per model_type
deruyter92 Feb 11, 2026
6055786
Update FMPose3DConfig and add SuperAnimalConfig
deruyter92 Feb 11, 2026
f90ce60
expose model registry in main package
deruyter92 Feb 11, 2026
80b2420
update .gitignore: ignore predictions
deruyter92 Feb 11, 2026
8f278ec
Update model_type to "fmpose3d_humans" across configuration and model…
xiu-cs Feb 11, 2026
0138da6
Update model_type to "fmpose3d_humans" in demo and script files for c…
xiu-cs Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ htmlcov/
# Excluded directories
pre_trained_models/
demo/predictions/
demo/images/
demo/images/
**/predictions/
2 changes: 1 addition & 1 deletion animals/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ In this part, the FMPose3D model is trained on [Animal3D](https://xujiacong.gith
This visualization script is designed for single-frame based model, allowing you to easily run 3D animal pose estimation on any single image.

Before testing, make sure you have the pre-trained model ready.
You may either use the model trained by your own or download ours from [here](https://drive.google.com/drive/folders/1fMKVaYziwFkAnFrtQZmoPOTfe7Hkl2at?usp=sharing) and place it in the `./pre_trained_models` directory.
You may either use a model you trained yourself or download ours from [here](https://drive.google.com/drive/folders/1kL4aOyWNq0o9zB0rSTRM8KYgkySVmUTk?usp=drive_link) and place it in the `./pre_trained_models` directory.

Next, put your test images into folder `demo/images`. Then run the visualization script:
```bash
Expand Down
5 changes: 3 additions & 2 deletions animals/demo/vis_animals.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@
spec.loader.exec_module(module)
CFM = getattr(module, "Model")
else:
# Load model from installed fmpose package
from fmpose3d.models import Model as CFM
# Load model from registered model registry
from fmpose3d.models import get_model
CFM = get_model(args.model_type)

from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images

Expand Down
8 changes: 5 additions & 3 deletions animals/demo/vis_animals.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ sh_file='vis_animals.sh'
# n_joints=26
# out_joints=26

model_path='../pre_trained_models/animal3d_pretrained_weights/model_animal3d.py'
saved_model_path='../pre_trained_models/animal3d_pretrained_weights/CFM_154_4403_best.pth'
model_type='fmpose3d_animals'
# model_path='' # set to a local file path to override the registry
saved_model_path='../pre_trained_models/fmpose3d_animals/fmpose3d_animals_pretrained_weights.pth'

# path='./images/image_00068.jpg' # single image
input_images_folder='./images/' # folder containing multiple images
Expand All @@ -17,7 +18,8 @@ python3 vis_animals.py \
--type 'image' \
--path ${input_images_folder} \
--saved_model_path "${saved_model_path}" \
--model_path "${model_path}" \
${model_path:+--model_path "$model_path"} \
--model_type "${model_type}" \
--sample_steps ${sample_steps} \
--batch_size ${batch_size} \
--layers ${layers} \
Expand Down
33 changes: 15 additions & 18 deletions animals/scripts/main_animal3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@
spec.loader.exec_module(module)
CFM = getattr(module, "Model")
else:
# Load model from installed fmpose package
from fmpose3d.animals.models import Model as CFM
# Load model from registered model registry
from fmpose3d.models import get_model
CFM = get_model(args.model_type)

def train(opt, actions, train_loader, model, optimizer, epoch):
    """Run ``step`` with the split set to ``'train'`` and return its result.

    Every argument is forwarded unchanged, in order, to ``step``.
    """
    return step('train', opt, actions, train_loader, model, optimizer, epoch)
Expand Down Expand Up @@ -98,7 +99,6 @@ def step(split, args, actions, dataLoader, model, optimizer=None, epoch=None, st
gt_3D = gt_3D.clone()
gt_3D[:, :, args.root_joint] = 0


# Conditional Flow Matching training
# gt_3D, input_2D shape: (B,F,J,C)
# vis_3D shape: (B,F,J,1) - visibility mask
Expand Down Expand Up @@ -217,21 +217,18 @@ def get_parameter_number(net):
os.makedirs(args.checkpoint)

# backup files
# import shutil
# file_path = os.path.abspath(__file__)
# file_name = os.path.basename(file_path)
# shutil.copyfile(src=file_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name))
# shutil.copyfile(src=os.path.abspath("common/arguments.py"), dst=os.path.join(args.checkpoint, args.create_time + "_arguments.py"))
# # backup the selected model file (from --model_path if provided)
# if getattr(args, 'model_path', ''):
# model_src_path = os.path.abspath(args.model_path)
# model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
# shutil.copyfile(src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name))
# # shutil.copyfile(src="common/utils.py", dst = os.path.join(args.checkpoint, args.create_time + "_utils.py"))
# sh_base = os.path.basename(args.sh_file)
# dst_name = f"{args.create_time}_" + sh_base
# sh_src = os.path.abspath(args.sh_file)
# shutil.copyfile(src=sh_src, dst=os.path.join(args.checkpoint, dst_name))
import shutil
file_path = os.path.abspath(__file__)
file_name = os.path.basename(file_path)
shutil.copyfile(src=file_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name))
if getattr(args, 'model_path', ''):
model_src_path = os.path.abspath(args.model_path)
model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
shutil.copyfile(src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name))
sh_base = os.path.basename(args.sh_file)
dst_name = f"{args.create_time}_" + sh_base
sh_src = os.path.abspath(args.sh_file)
shutil.copyfile(src=sh_src, dst=os.path.join(args.checkpoint, dst_name))

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S', \
filename=os.path.join(args.checkpoint, 'train.log'), level=logging.INFO)
Expand Down
12 changes: 7 additions & 5 deletions animals/scripts/test_animal3d.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ layers=5
batch_size=13
lr=1e-3
gpu_id=0
eval_sample_steps=3
eval_sample_steps=5
num_saved_models=3
frames=1
large_decay_epoch=15
lr_decay_large=0.75
n_joints=26
out_joints=26
epochs=300
# model_path='models/model_animals.py'
model_path='./pre_trained_models/animal3d_pretrained_weights/model_animal3d.py' # when the path is empty, the model will be loaded from the installed fmpose package
saved_model_path='./pre_trained_models/animal3d_pretrained_weights/CFM_154_4403_best.pth'
model_type='fmpose3d_animals'
# model_path='' # set to a local file path to override the registry
saved_model_path='./pre_trained_models/fmpose3d_animals/fmpose3d_animals_pretrained_weights.pth'

# root path denotes the path to the original dataset
root_path="./dataset/"
train_dataset_paths=(
Expand All @@ -24,7 +25,7 @@ test_dataset_paths=(
)

folder_name="TestCtrlAni3D_L${layers}_lr${lr}_B${batch_size}_$(date +%Y%m%d_%H%M%S)"
sh_file='scripts/animals/test_animal3d.sh'
sh_file='scripts/test_animal3d.sh'

python ./scripts/main_animal3d.py \
--root_path ${root_path} \
Expand All @@ -33,6 +34,7 @@ python ./scripts/main_animal3d.py \
--test 1 \
--batch_size ${batch_size} \
--lr ${lr} \
--model_type "${model_type}" \
${model_path:+--model_path "$model_path"} \
--folder_name ${folder_name} \
--layers ${layers} \
Expand Down
10 changes: 4 additions & 6 deletions animals/scripts/train_animal3d.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ layers=5
batch_size=13
lr=1e-3
gpu_id=0
eval_sample_steps=3
eval_sample_steps=5
num_saved_models=3
frames=1
large_decay_epoch=15
lr_decay_large=0.75
n_joints=26
out_joints=26
epochs=300
# model_path='models/model_animals.py'
model_path="" # when the path is empty, the model will be loaded from the installed fmpose package
model_type='fmpose3d_animals'
# model_path="" # set to a local file path to override the registry
# root path denotes the path to the original dataset
root_path="./dataset/"
train_dataset_paths=(
Expand All @@ -32,7 +30,7 @@ python ./scripts/main_animal3d.py \
--test 1 \
--batch_size ${batch_size} \
--lr ${lr} \
${model_path:+--model_path "$model_path"} \
--model_type "${model_type}" \
--folder_name ${folder_name} \
--layers ${layers} \
--gpu ${gpu_id} \
Expand Down
2 changes: 1 addition & 1 deletion demo/vis_in_the_wild.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ sample_steps=3
batch_size=1
sh_file='vis_in_the_wild.sh'

model_type='fmpose3d'
model_type='fmpose3d_humans'
model_weights_path='../pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth'

target_path='./images/' # folder containing multiple images
Expand Down
8 changes: 8 additions & 0 deletions fmpose3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
Source,
)

# Model registry
from .models import BaseModel, register_model, get_model, list_models

# Import 2D pose detection utilities
from .lib.hrnet.gen_kpts import gen_video_kpts
from .lib.hrnet.hrnet import HRNetPose2d
Expand All @@ -59,6 +62,11 @@
"average_aggregation",
"aggregation_select_single_best_hypothesis_by_2D_error",
"aggregation_RPEA_joint_level",
# Model registry
"BaseModel",
"register_model",
"get_model",
"list_models",
# 2D pose detection
"HRNetPose2d",
"gen_video_kpts",
Expand Down
2 changes: 2 additions & 0 deletions fmpose3d/animals/common/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def init(self):
self.parser.add_argument("--model_dir", type=str, default="")
# Optional: load model class from a specific file path
self.parser.add_argument("--model_path", type=str, default="")
# Model registry name (e.g. "fmpose3d_animals"); used instead of --model_path
self.parser.add_argument("--model_type", type=str, default="fmpose3d_animals")

self.parser.add_argument("--post_refine_reload", action="store_true")
self.parser.add_argument("--checkpoint", type=str, default="")
Expand Down
6 changes: 4 additions & 2 deletions fmpose3d/animals/models/model_animal3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from timm.models.layers import DropPath

from fmpose3d.animals.models.graph_frames import Graph
from fmpose3d.models.base_model import BaseModel, register_model

class TimeEmbedding(nn.Module):
def __init__(self, dim: int, hidden_dim: int = 64):
Expand Down Expand Up @@ -207,9 +208,10 @@ def forward(self, x):
x = self.fc5(x)
return x

class Model(nn.Module):
@register_model("fmpose3d_animals")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you happy with the names "fmpose3d" for humans and "fmpose3d_animals" for animals? Or do you want to change "fmpose3d" -> "fmpose3d_humans" or something similar?

This would be a good moment to choose the final names, before these names will be used in other code as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will change the name of the human model to 'fmpose3d_humans'

class Model(BaseModel):
def __init__(self, args):
super().__init__()
super().__init__(args)

self.graph = Graph('animal3d', 'spatial', pad=1)
self.register_buffer('A', torch.tensor(self.graph.A, dtype=torch.float32))
Expand Down
4 changes: 2 additions & 2 deletions fmpose3d/common/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def init(self):
self.parser.add_argument("--model_dir", type=str, default="")
# Optional: load model class from a specific file path
self.parser.add_argument("--model_path", type=str, default="")
# Model registry name (e.g. "fmpose3d"); used instead of --model_path
self.parser.add_argument("--model_type", type=str, default="fmpose3d")
# Model registry name (e.g. "fmpose3d_humans"); used instead of --model_path
self.parser.add_argument("--model_type", type=str, default="fmpose3d_humans")
self.parser.add_argument("--model_weights_path", type=str, default="")

self.parser.add_argument("--post_refine_reload", action="store_true")
Expand Down
93 changes: 80 additions & 13 deletions fmpose3d/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

import math
from dataclasses import dataclass, field, fields, asdict
from typing import List

from typing import Dict, List

# ---------------------------------------------------------------------------
# Dataclass configuration groups
Expand All @@ -20,24 +19,63 @@
@dataclass
class ModelConfig:
"""Model architecture configuration."""
model_type: str = "fmpose3d"
model_type: str = "fmpose3d_humans"


# Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE.
# Also consumed by PipelineConfig.for_model_type to set cross-config
# values (dataset, sample_steps, etc.).
#
# The outer keys are the supported ``model_type`` names; an unknown name makes
# FMPose3DConfig.__post_init__ raise ValueError. That method indexes the inner
# dict by field name with no fallback, so every entry must provide a value for
# each FMPose3DConfig field whose class-level default is INFER_FROM_MODEL_TYPE.
_FMPOSE3D_DEFAULTS: Dict[str, Dict] = {
    "fmpose3d_humans": {
        "n_joints": 17,
        "out_joints": 17,
        "dataset": "h36m",
        "sample_steps": 3,
        "joints_left": [4, 5, 6, 11, 12, 13],
        "joints_right": [1, 2, 3, 14, 15, 16],
        "root_joint": 0,
    },
    "fmpose3d_animals": {
        "n_joints": 26,
        "out_joints": 26,
        "dataset": "animal3d",
        "sample_steps": 5,
        "joints_left": [0, 3, 5, 8, 10, 12, 14, 16, 20, 22],
        "joints_right": [1, 4, 6, 9, 11, 13, 15, 17, 21, 23],
        "root_joint": 7,
    },
}

# Sentinel object for defaults that are inferred from the model type.
# FMPose3DConfig.__post_init__ detects it by identity (``is``), not equality.
INFER_FROM_MODEL_TYPE = object()

@dataclass
class FMPose3DConfig(ModelConfig):
model_type: str = "fmpose3d_humans"
model: str = ""
model_type: str = "fmpose3d"
layers: int = 3
layers: int = 5
channel: int = 512
d_hid: int = 1024
token_dim: int = 256
n_joints: int = 17
out_joints: int = 17
n_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
out_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
joints_left: List[int] = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
joints_right: List[int] = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
root_joint: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment]
Comment on lines +60 to +64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is added now, so you can load an FMPose3DConfig from model_type with the appropriate number of joints etc:

model_cfg = FMPose3DConfig(model_type="fmpose3d_animals")

in_channels: int = 2
out_channels: int = 3
frames: int = 1
"""Optional: load model class from a specific file path."""

def __post_init__(self):
    """Resolve fields left at ``INFER_FROM_MODEL_TYPE`` from the defaults table.

    Looks up ``self.model_type`` in ``_FMPOSE3D_DEFAULTS`` and, for every
    dataclass field whose current value is the ``INFER_FROM_MODEL_TYPE``
    sentinel, assigns that model type's default. Fields the caller set
    explicitly are left untouched.

    Raises
    ------
    ValueError
        If ``self.model_type`` is not a key of ``_FMPOSE3D_DEFAULTS``.
    """
    defaults = _FMPOSE3D_DEFAULTS.get(self.model_type)
    if defaults is None:
        supported = ", ".join(sorted(_FMPOSE3D_DEFAULTS))
        raise ValueError(
            f"Unknown model_type {self.model_type!r}; supported: {supported}"
        )
    for f in fields(self):
        if getattr(self, f.name) is INFER_FROM_MODEL_TYPE:
            value = defaults[f.name]
            # Copy list defaults so each instance owns its own list. Assigning
            # the table's list directly would alias it across all instances,
            # and an in-place edit (e.g. ``cfg.joints_left.append(...)``)
            # would silently change the module-level defaults.
            if isinstance(value, list):
                value = list(value)
            setattr(self, f.name, value)

@dataclass
class DatasetConfig:
Expand Down Expand Up @@ -178,6 +216,33 @@ class HRNetConfig(Pose2DConfig):
hrnet_weights_path: str = ""


@dataclass
class SuperAnimalConfig(Pose2DConfig):
    """DeepLabCut SuperAnimal 2D pose detector configuration.

    Uses the DeepLabCut ``superanimal_analyze_images`` API to detect
    animal keypoints in the quadruped80K format, then maps them to the
    Animal3D 26-keypoint layout expected by the ``fmpose3d_animals``
    3D lifter.

    Attributes
    ----------
    pose2d_model : str
        2D pose backend identifier (default ``"superanimal"``).
        ``PipelineConfig.from_namespace`` selects this config class when
        the namespace's ``pose2d_model`` equals this value.
    superanimal_name : str
        Name of the SuperAnimal model (default ``"superanimal_quadruped"``).
    sa_model_name : str
        Backbone architecture (default ``"hrnet_w32"``).
    detector_name : str
        Object detector used for animal bounding boxes
        (default ``"fasterrcnn_resnet50_fpn_v2"``).
    max_individuals : int
        Maximum number of individuals to detect per image (default 1).
    """
    pose2d_model: str = "superanimal"
    superanimal_name: str = "superanimal_quadruped"
    sa_model_name: str = "hrnet_w32"
    detector_name: str = "fasterrcnn_resnet50_fpn_v2"
    max_individuals: int = 1


@dataclass
class DemoConfig:
"""Demo / inference configuration."""
Expand Down Expand Up @@ -239,8 +304,6 @@ class PipelineConfig:
demo_cfg: DemoConfig = field(default_factory=DemoConfig)
runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig)

# -- construction from argparse namespace ---------------------------------

@classmethod
def from_namespace(cls, ns) -> "PipelineConfig":
"""Build a :class:`PipelineConfig` from an ``argparse.Namespace``
Expand All @@ -258,10 +321,14 @@ def _pick(dc_class, src: dict):

kwargs = {}
for group_name, dc_class in _SUB_CONFIG_CLASSES.items():
if group_name == "model_cfg" and raw.get("model_type", "fmpose3d") == "fmpose3d":
if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d_humans') in _FMPOSE3D_DEFAULTS:
dc_class = FMPose3DConfig
elif group_name == "pose2d_cfg" and raw.get("pose2d_model", "hrnet") == "hrnet":
dc_class = HRNetConfig
elif group_name == "pose2d_cfg":
p2d = raw.get("pose2d_model", "hrnet")
if p2d == "superanimal":
dc_class = SuperAnimalConfig
elif p2d == "hrnet":
dc_class = HRNetConfig
kwargs[group_name] = _pick(dc_class, raw)
return cls(**kwargs)

Expand Down
2 changes: 0 additions & 2 deletions fmpose3d/lib/hrnet/hrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
Licensed under Apache 2.0
"""

"""
FMPose3D – clean HRNet 2D pose estimation API.

Provides :class:`HRNetPose2d`, a self-contained wrapper around the
Expand Down
Loading
Loading