diff --git a/.gitignore b/.gitignore
index 5104331c..90439acb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,3 +45,8 @@ htmlcov/
*.pkl
*.h5
*.ckpt
+
+# Excluded directories
+pre_trained_models/
+demo/predictions/
+demo/images/
\ No newline at end of file
diff --git a/README.md b/README.md
index 7c83438a..2b241f65 100644
--- a/README.md
+++ b/README.md
@@ -2,11 +2,11 @@

[](https://badge.fury.io/py/fmpose3d)
-[](https://www.gnu.org/licenses/apach2.0)
+[](https://www.apache.org/licenses/LICENSE-2.0)
This is the official implementation of the approach described in the preprint:
-[**FMPose3D: monocular 3D pose estimation via flow matching**](http://arxiv.org/abs/2602.05755)
+[**FMPose3D: monocular 3D pose estimation via flow matching**](https://arxiv.org/abs/2602.05755)
Ti Wang, Xiaohang Yu, Mackenzie Weygandt Mathis
@@ -51,7 +51,7 @@ sh vis_in_the_wild.sh
```
The predictions will be saved to folder `demo/predictions`.
-

+
## Training and Inference
@@ -79,7 +79,7 @@ The training logs, checkpoints, and related files of each training time will be
For training on Human3.6M:
```bash
-sh /scripts/FMPose3D_train.sh
+sh ./scripts/FMPose3D_train.sh
```
### Inference
diff --git a/animals/demo/vis_animals.py b/animals/demo/vis_animals.py
index 357cfe80..c9fe4384 100644
--- a/animals/demo/vis_animals.py
+++ b/animals/demo/vis_animals.py
@@ -8,7 +8,6 @@
"""
# SuperAnimal Demo: https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_YOURDATA_SuperAnimal.ipynb
-import sys
import os
import numpy as np
import glob
@@ -25,8 +24,6 @@
from fmpose3d.animals.common.arguments import opts as parse_args
from fmpose3d.common.camera import normalize_screen_coordinates, camera_to_world
-sys.path.append(os.getcwd())
-
args = parse_args().parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
@@ -334,13 +331,15 @@ def get_pose3D(path, output_dir, type='image'):
print(f"args.n_joints: {args.n_joints}, args.out_joints: {args.out_joints}")
## Reload model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
model = {}
- model['CFM'] = CFM(args).cuda()
+ model['CFM'] = CFM(args).to(device)
model_dict = model['CFM'].state_dict()
model_path = args.saved_model_path
print(f"Loading model from: {model_path}")
- pre_dict = torch.load(model_path)
+ pre_dict = torch.load(model_path, map_location=device, weights_only=True)
for name, key in model_dict.items():
model_dict[name] = pre_dict[name]
model['CFM'].load_state_dict(model_dict)
@@ -400,7 +399,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):
input_2D = np.expand_dims(input_2D, axis=0) # (1, J, 2)
# Convert to tensor format matching visualize_animal_poses.py
- input_2D = torch.from_numpy(input_2D.astype('float32')).cuda() # (1, J, 2)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ input_2D = torch.from_numpy(input_2D.astype('float32')).to(device) # (1, J, 2)
input_2D = input_2D.unsqueeze(0) # (1, 1, J, 2)
# Euler sampler for CFM
@@ -418,7 +418,7 @@ def euler_sample(c_2d, y_local, steps, model_3d):
# Single inference without flip augmentation
# Create 3D random noise with shape (1, 1, J, 3)
- y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3).cuda()
+ y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3, device=device)
output_3D = euler_sample(input_2D, y, steps=args.sample_steps, model_3d=model)
output_3D = output_3D[0:, args.pad].unsqueeze(1)
diff --git a/animals/scripts/main_animal3d.py b/animals/scripts/main_animal3d.py
index f93e04cf..04362047 100644
--- a/animals/scripts/main_animal3d.py
+++ b/animals/scripts/main_animal3d.py
@@ -75,7 +75,7 @@ def step(split, args, actions, dataLoader, model, optimizer=None, epoch=None, st
# gt_3D shape: torch.Size([B, J, 4]) (x,y,z + homogeneous coordinate)
gt_3D = gt_3D[:,:,:3] # only use x,y,z for 3D ground truth
- # [input_2D, gt_3D, batch_cam, vis_3D] = get_varialbe(split, [input_2D, gt_3D, batch_cam, vis_3D])
+ # [input_2D, gt_3D, batch_cam, vis_3D] = get_variable(split, [input_2D, gt_3D, batch_cam, vis_3D])
# unsqueeze frame dimension
input_2D = input_2D.unsqueeze(1) # (B,F,J,C)
@@ -264,15 +264,17 @@ def get_parameter_number(net):
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
shuffle=False, num_workers=int(args.workers), pin_memory=True)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
model = {}
- model['CFM'] = CFM(args).cuda()
+ model['CFM'] = CFM(args).to(device)
if args.reload:
model_dict = model['CFM'].state_dict()
# Prefer explicit saved_model_path; otherwise fallback to previous_dir glob
model_path = args.saved_model_path
print(model_path)
- pre_dict = torch.load(model_path)
+ pre_dict = torch.load(model_path, weights_only=True, map_location=device)
for name, key in model_dict.items():
model_dict[name] = pre_dict[name]
model['CFM'].load_state_dict(model_dict)
diff --git a/demo/vis_in_the_wild.py b/demo/vis_in_the_wild.py
index 9ca6f1ee..90b2953c 100755
--- a/demo/vis_in_the_wild.py
+++ b/demo/vis_in_the_wild.py
@@ -7,7 +7,6 @@
Licensed under Apache 2.0
"""
-import sys
import cv2
import os
import numpy as np
@@ -16,8 +15,6 @@
from tqdm import tqdm
import copy
-sys.path.append(os.getcwd())
-
# Auto-download checkpoint files if missing
from fmpose3d.lib.checkpoint.download_checkpoints import ensure_checkpoints
ensure_checkpoints()
@@ -28,17 +25,10 @@
args = parse_args().parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
-if getattr(args, 'model_path', ''):
- import importlib.util
- import pathlib
- model_abspath = os.path.abspath(args.model_path)
- module_name = pathlib.Path(model_abspath).stem
- spec = importlib.util.spec_from_file_location(module_name, model_abspath)
- module = importlib.util.module_from_spec(spec)
- assert spec.loader is not None
- spec.loader.exec_module(module)
- CFM = getattr(module, 'Model')
-
+
+from fmpose3d.models import get_model
+CFM = get_model(args.model_type)
+
from fmpose3d.common.camera import *
import matplotlib
@@ -50,15 +40,27 @@
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
-def show2Dpose(kps, img):
- connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5],
- [5, 6], [0, 7], [7, 8], [8, 9], [9, 10],
- [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]]
+# Shared skeleton definition so 2D/3D segment colors match
+SKELETON_CONNECTIONS = [
+ [0, 1], [1, 2], [2, 3], [0, 4], [4, 5],
+ [5, 6], [0, 7], [7, 8], [8, 9], [9, 10],
+ [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]
+]
+# LR mask for skeleton segments: True -> left color, False -> right color
+SKELETON_LR = np.array(
+ [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ dtype=bool,
+)
- LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=bool)
+def show2Dpose(kps, img):
+ connections = SKELETON_CONNECTIONS
+ LR = SKELETON_LR
lcolor = (255, 0, 0)
rcolor = (0, 0, 255)
+ # lcolor = (240, 176, 0)
+ # rcolor = (240, 176, 0)
+
thickness = 3
for j,c in enumerate(connections):
@@ -67,8 +69,8 @@ def show2Dpose(kps, img):
start = list(start)
end = list(end)
cv2.line(img, (start[0], start[1]), (end[0], end[1]), lcolor if LR[j] else rcolor, thickness)
- cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3)
- cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3)
+ # cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3)
+ # cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3)
return img
@@ -77,11 +79,13 @@ def show3Dpose(vals, ax):
lcolor=(0,0,1)
rcolor=(1,0,0)
-
- I = np.array( [0, 0, 1, 4, 2, 5, 0, 7, 8, 8, 14, 15, 11, 12, 8, 9])
- J = np.array( [1, 4, 2, 5, 3, 6, 7, 8, 14, 11, 15, 16, 12, 13, 9, 10])
-
- LR = np.array([0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0], dtype=bool)
+ # lcolor=(0/255, 176/255, 240/255)
+ # rcolor=(0/255, 176/255, 240/255)
+
+
+ I = np.array([c[0] for c in SKELETON_CONNECTIONS])
+ J = np.array([c[1] for c in SKELETON_CONNECTIONS])
+ LR = SKELETON_LR
for i in np.arange( len(I) ):
x, y, z = [np.array( [vals[I[i], j], vals[J[i], j]] ) for j in range(3)]
@@ -199,7 +203,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):
input_2D = input_2D[np.newaxis, :, :, :, :]
- input_2D = torch.from_numpy(input_2D.astype('float32')).cuda()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ input_2D = torch.from_numpy(input_2D.astype('float32')).to(device)
N = input_2D.size(0)
@@ -215,10 +220,10 @@ def euler_sample(c_2d, y_local, steps, model_3d):
## estimation
- y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda()
+ y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device)
output_3D_non_flip = euler_sample(input_2D[:, 0], y, steps=args.sample_steps, model_3d=model)
- y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda()
+ y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device)
output_3D_flip = euler_sample(input_2D[:, 1], y_flip, steps=args.sample_steps, model_3d=model)
output_3D_flip[:, :, :, 0] *= -1
@@ -266,14 +271,16 @@ def get_pose3D(path, output_dir, type='image'):
# args.type = type
## Reload
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
model = {}
- model['CFM'] = CFM(args).cuda()
+ model['CFM'] = CFM(args).to(device)
# if args.reload:
model_dict = model['CFM'].state_dict()
- model_path = args.saved_model_path
+ model_path = args.model_weights_path
print(model_path)
- pre_dict = torch.load(model_path)
+ pre_dict = torch.load(model_path, map_location=device, weights_only=True)
for name, key in model_dict.items():
model_dict[name] = pre_dict[name]
model['CFM'].load_state_dict(model_dict)
@@ -336,7 +343,7 @@ def get_pose3D(path, output_dir, type='image'):
## save
output_dir_pose = output_dir +'pose/'
os.makedirs(output_dir_pose, exist_ok=True)
- plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.jpg', dpi=200, bbox_inches = 'tight')
+ plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.png', dpi=200, bbox_inches = 'tight')
if __name__ == "__main__":
diff --git a/demo/vis_in_the_wild.sh b/demo/vis_in_the_wild.sh
index 3909bc3d..a6f3a4d2 100755
--- a/demo/vis_in_the_wild.sh
+++ b/demo/vis_in_the_wild.sh
@@ -1,21 +1,22 @@
#Test
layers=5
-gpu_id=1
+gpu_id=0
sample_steps=3
batch_size=1
sh_file='vis_in_the_wild.sh'
-model_path='../pre_trained_models/fmpose_detected2d/model_GAMLP.py'
-saved_model_path='../pre_trained_models/fmpose_detected2d/FMpose_36_4972_best.pth'
+model_type='fmpose3d'
+model_weights_path='../pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth'
-# path='./images/image_00068.jpg' # single image
-input_images_folder='./images/' # folder containing multiple images
+target_path='./images/' # folder containing multiple images
+# target_path='./images/xx.png' # single image
+# target_path='./videos/xxx.mp4' # video path
python3 vis_in_the_wild.py \
--type 'image' \
- --path ${input_images_folder} \
- --saved_model_path "${saved_model_path}" \
- --model_path "${model_path}" \
+ --path ${target_path} \
+ --model_weights_path "${model_weights_path}" \
+ --model_type "${model_type}" \
--sample_steps ${sample_steps} \
--batch_size ${batch_size} \
--layers ${layers} \
diff --git a/fmpose3d/__init__.py b/fmpose3d/__init__.py
index 8a9a4716..563a1402 100644
--- a/fmpose3d/__init__.py
+++ b/fmpose3d/__init__.py
@@ -18,17 +18,49 @@
aggregation_RPEA_joint_level,
)
+# Configuration dataclasses
+from .common.config import (
+ FMPose3DConfig,
+ HRNetConfig,
+ InferenceConfig,
+ ModelConfig,
+ PipelineConfig,
+)
+
+# High-level inference API
+from .fmpose3d import (
+ FMPose3DInference,
+ HRNetEstimator,
+ Pose2DResult,
+ Pose3DResult,
+ Source,
+)
+
# Import 2D pose detection utilities
from .lib.hrnet.gen_kpts import gen_video_kpts
+from .lib.hrnet.hrnet import HRNetPose2d
from .lib.preprocess import h36m_coco_format, revise_kpts
# Make commonly used classes/functions available at package level
__all__ = [
+ # Inference API
+ "FMPose3DInference",
+ "HRNetEstimator",
+ "Pose2DResult",
+ "Pose3DResult",
+ "Source",
+ # Configuration
+ "FMPose3DConfig",
+ "HRNetConfig",
+ "InferenceConfig",
+ "ModelConfig",
+ "PipelineConfig",
# Aggregation methods
"average_aggregation",
"aggregation_select_single_best_hypothesis_by_2D_error",
"aggregation_RPEA_joint_level",
# 2D pose detection
+ "HRNetPose2d",
"gen_video_kpts",
"h36m_coco_format",
"revise_kpts",
diff --git a/fmpose3d/aggregation_methods.py b/fmpose3d/aggregation_methods.py
index 38d022f9..4898a03b 100644
--- a/fmpose3d/aggregation_methods.py
+++ b/fmpose3d/aggregation_methods.py
@@ -166,17 +166,13 @@ def aggregation_RPEA_joint_level(
dist[:, :, 0] = 0.0
# Convert 2D losses to weights using softmax over top-k hypotheses per joint
- tau = float(getattr(args, "weight_softmax_tau", 1.0))
H = dist.size(1)
k = int(getattr(args, "topk", None))
- # print("k:", k)
- # k = int(H//2)+1
k = max(1, min(k, H))
# top-k smallest distances along hypothesis dim
topk_vals, topk_idx = torch.topk(dist, k=k, dim=1, largest=False) # (B,k,J)
- # Weight calculation method ; weight_method = 'exp'
temp = args.exp_temp
max_safe_val = temp * 20
topk_vals_clipped = torch.clamp(topk_vals, max=max_safe_val)
diff --git a/fmpose3d/animals/common/arber_dataset.py b/fmpose3d/animals/common/arber_dataset.py
index c70bb838..27dba171 100644
--- a/fmpose3d/animals/common/arber_dataset.py
+++ b/fmpose3d/animals/common/arber_dataset.py
@@ -12,7 +12,6 @@
import glob
import os
import random
-import sys
import cv2
import matplotlib.pyplot as plt
@@ -23,10 +22,8 @@
from torch.utils.data import Dataset
from tqdm import tqdm
-sys.path.append(os.path.dirname(sys.path[0]))
-
-from common.camera import normalize_screen_coordinates
-from common.lifter3d import load_camera_params, load_h5_keypoints
+from fmpose3d.common.camera import normalize_screen_coordinates
+from fmpose3d.animals.common.lifter3d import load_camera_params, load_h5_keypoints
class ArberDataset(Dataset):
diff --git a/fmpose3d/animals/common/utils.py b/fmpose3d/animals/common/utils.py
index d4496625..cdafd8c1 100755
--- a/fmpose3d/animals/common/utils.py
+++ b/fmpose3d/animals/common/utils.py
@@ -15,7 +15,6 @@
import numpy as np
import torch
-from torch.autograd import Variable
def mpjpe_cal(predicted, target):
@@ -220,18 +219,17 @@ def update(self, val, n=1):
self.avg = self.sum / self.count
-def get_varialbe(split, target):
+def get_variable(split, target):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num = len(target)
var = []
if split == "train":
for i in range(num):
- temp = (
- Variable(target[i], requires_grad=False).contiguous().type(torch.cuda.FloatTensor)
- )
+ temp = target[i].requires_grad_(False).contiguous().float().to(device)
var.append(temp)
else:
for i in range(num):
- temp = Variable(target[i]).contiguous().cuda().type(torch.cuda.FloatTensor)
+ temp = target[i].contiguous().float().to(device)
var.append(temp)
return var
diff --git a/fmpose3d/common/__init__.py b/fmpose3d/common/__init__.py
index 44082829..c7863030 100644
--- a/fmpose3d/common/__init__.py
+++ b/fmpose3d/common/__init__.py
@@ -12,6 +12,22 @@
"""
from .arguments import opts
+from .config import (
+ PipelineConfig,
+ ModelConfig,
+ FMPose3DConfig,
+ HRNetConfig,
+ Pose2DConfig,
+ DatasetConfig,
+ TrainingConfig,
+ InferenceConfig,
+ AggregationConfig,
+ CheckpointConfig,
+ RefinementConfig,
+ OutputConfig,
+ DemoConfig,
+ RuntimeConfig,
+)
from .h36m_dataset import Human36mDataset
from .load_data_hm36 import Fusion
from .utils import (
@@ -22,11 +38,25 @@
save_top_N_models,
test_calculation,
print_error,
- get_varialbe,
+ get_variable,
)
__all__ = [
"opts",
+ "PipelineConfig",
+ "FMPose3DConfig",
+ "HRNetConfig",
+ "Pose2DConfig",
+ "ModelConfig",
+ "DatasetConfig",
+ "TrainingConfig",
+ "InferenceConfig",
+ "AggregationConfig",
+ "CheckpointConfig",
+ "RefinementConfig",
+ "OutputConfig",
+ "DemoConfig",
+ "RuntimeConfig",
"Human36mDataset",
"Fusion",
"mpjpe_cal",
@@ -36,6 +66,6 @@
"save_top_N_models",
"test_calculation",
"print_error",
- "get_varialbe",
+ "get_variable",
]
diff --git a/fmpose3d/common/arguments.py b/fmpose3d/common/arguments.py
index 3777a1fd..b94db985 100755
--- a/fmpose3d/common/arguments.py
+++ b/fmpose3d/common/arguments.py
@@ -53,7 +53,6 @@ def init(self):
self.parser.add_argument("-s", "--stride", default=1, type=int)
self.parser.add_argument("--gpu", default="0", type=str, help="")
self.parser.add_argument("--train", action="store_true")
- # self.parser.add_argument('--test', action='store_true')
self.parser.add_argument("--test", type=int, default=1) #
self.parser.add_argument("--nepoch", type=int, default=41) #
self.parser.add_argument(
@@ -75,13 +74,15 @@ def init(self):
self.parser.add_argument("--model_dir", type=str, default="")
# Optional: load model class from a specific file path
self.parser.add_argument("--model_path", type=str, default="")
+ # Model registry name (e.g. "fmpose3d"); used instead of --model_path
+ self.parser.add_argument("--model_type", type=str, default="fmpose3d")
+ self.parser.add_argument("--model_weights_path", type=str, default="")
self.parser.add_argument("--post_refine_reload", action="store_true")
self.parser.add_argument("--checkpoint", type=str, default="")
self.parser.add_argument(
"--previous_dir", type=str, default="./pre_trained_model/pretrained"
)
- self.parser.add_argument("--saved_model_path", type=str, default="")
self.parser.add_argument("--n_joints", type=int, default=17)
self.parser.add_argument("--out_joints", type=int, default=17)
@@ -148,7 +149,6 @@ def init(self):
# uncertainty-aware aggregation threshold factor
self.parser.add_argument("--topk", type=int, default=3)
- self.parser.add_argument("--weight_softmax_tau", type=float, default=1.0)
self.parser.add_argument("--exp_temp", type=float, default=0.002)
self.parser.add_argument("--mode", type=str, default="exp")
diff --git a/fmpose3d/common/config.py b/fmpose3d/common/config.py
new file mode 100644
index 00000000..b2980e1f
--- /dev/null
+++ b/fmpose3d/common/config.py
@@ -0,0 +1,276 @@
+"""
+FMPose3D: monocular 3D Pose Estimation via Flow Matching
+
+Official implementation of the paper:
+"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
+by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
+Licensed under Apache 2.0
+"""
+
+import math
+from dataclasses import dataclass, field, fields, asdict
+from typing import List
+
+
+# ---------------------------------------------------------------------------
+# Dataclass configuration groups
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ModelConfig:
+ """Model architecture configuration."""
+ model_type: str = "fmpose3d"
+
+
+@dataclass
+class FMPose3DConfig(ModelConfig):
+ model: str = ""
+ model_type: str = "fmpose3d"
+ layers: int = 3
+ channel: int = 512
+ d_hid: int = 1024
+ token_dim: int = 256
+ n_joints: int = 17
+ out_joints: int = 17
+ in_channels: int = 2
+ out_channels: int = 3
+ frames: int = 1
+ """Optional: load model class from a specific file path."""
+
+
+@dataclass
+class DatasetConfig:
+ """Dataset and data loading configuration."""
+
+ dataset: str = "h36m"
+ keypoints: str = "cpn_ft_h36m_dbb"
+ root_path: str = "dataset/"
+ actions: str = "*"
+ downsample: int = 1
+ subset: float = 1.0
+ stride: int = 1
+ crop_uv: int = 0
+ out_all: int = 1
+ train_views: List[int] = field(default_factory=lambda: [0, 1, 2, 3])
+ test_views: List[int] = field(default_factory=lambda: [0, 1, 2, 3])
+
+ # Derived / set during parse based on dataset choice
+ subjects_train: str = "S1,S5,S6,S7,S8"
+ subjects_test: str = "S9,S11"
+ root_joint: int = 0
+ joints_left: List[int] = field(default_factory=list)
+ joints_right: List[int] = field(default_factory=list)
+
+
+@dataclass
+class TrainingConfig:
+ """Training hyperparameters and settings."""
+
+ train: bool = False
+ nepoch: int = 41
+ batch_size: int = 128
+ lr: float = 1e-3
+ lr_decay: float = 0.95
+ lr_decay_large: float = 0.5
+ large_decay_epoch: int = 5
+ workers: int = 8
+ data_augmentation: bool = True
+ reverse_augmentation: bool = False
+ norm: float = 0.01
+
+
+@dataclass
+class InferenceConfig:
+ """Evaluation and testing configuration."""
+
+ test: int = 1
+ test_augmentation: bool = True
+ test_augmentation_flip_hypothesis: bool = False
+ test_augmentation_FlowAug: bool = False
+ sample_steps: int = 3
+ eval_multi_steps: bool = False
+ eval_sample_steps: str = "1,3,5,7,9"
+ num_hypothesis_list: str = "1"
+ hypothesis_num: int = 1
+ guidance_scale: float = 1.0
+
+
+@dataclass
+class AggregationConfig:
+ """Hypothesis aggregation configuration."""
+
+ topk: int = 3
+ exp_temp: float = 0.002
+ mode: str = "exp"
+ opt_steps: int = 2
+
+
+@dataclass
+class CheckpointConfig:
+ """Checkpoint loading and saving configuration."""
+
+ reload: bool = False
+ model_dir: str = ""
+ model_weights_path: str = ""
+ checkpoint: str = ""
+ previous_dir: str = "./pre_trained_model/pretrained"
+ num_saved_models: int = 3
+ previous_best_threshold: float = math.inf
+ previous_name: str = ""
+
+
+@dataclass
+class RefinementConfig:
+ """Post-refinement model configuration."""
+
+ post_refine: bool = False
+ post_refine_reload: bool = False
+ previous_post_refine_name: str = ""
+ lr_refine: float = 1e-5
+ refine: bool = False
+ reload_refine: bool = False
+ previous_refine_name: str = ""
+
+
+@dataclass
+class OutputConfig:
+ """Output, logging, and file management configuration."""
+
+ create_time: str = ""
+ filename: str = ""
+ create_file: int = 1
+ debug: bool = False
+ folder_name: str = ""
+ sh_file: str = ""
+
+
+@dataclass
+class Pose2DConfig:
+ """2D pose estimator configuration."""
+ pose2d_model: str = "hrnet"
+
+
+@dataclass
+class HRNetConfig(Pose2DConfig):
+ """HRNet 2D pose detector configuration.
+
+ Attributes
+ ----------
+ det_dim : int
+ YOLO input resolution for human detection (default 416).
+ num_persons : int
+ Maximum number of persons to estimate per frame (default 1).
+ thred_score : float
+ YOLO object-confidence threshold (default 0.30).
+ hrnet_cfg_file : str
+ Path to the HRNet YAML experiment config. When left empty the
+ bundled ``w48_384x288_adam_lr1e-3.yaml`` is used.
+ hrnet_weights_path : str
+ Path to the HRNet ``.pth`` checkpoint. When left empty the
+ auto-downloaded ``pose_hrnet_w48_384x288.pth`` is used.
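+
+    Example (a minimal sketch)::
+
+        cfg = HRNetConfig(det_dim=512, num_persons=2)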
+ """
+ pose2d_model: str = "hrnet"
+ det_dim: int = 416
+ num_persons: int = 1
+ thred_score: float = 0.30
+ hrnet_cfg_file: str = ""
+ hrnet_weights_path: str = ""
+
+
+@dataclass
+class DemoConfig:
+ """Demo / inference configuration."""
+
+ type: str = "image"
+ """Input type: ``'image'`` or ``'video'``."""
+ path: str = "demo/images/running.png"
+ """Path to input file or directory."""
+
+
+@dataclass
+class RuntimeConfig:
+ """Runtime environment configuration."""
+
+ gpu: str = "0"
+ pad: int = 0 # derived: (frames - 1) // 2
+ single: bool = False
+ reload_3d: bool = False
+
+
+# ---------------------------------------------------------------------------
+# Composite configuration
+# ---------------------------------------------------------------------------
+
+_SUB_CONFIG_CLASSES = {
+ "model_cfg": ModelConfig,
+ "dataset_cfg": DatasetConfig,
+ "training_cfg": TrainingConfig,
+ "inference_cfg": InferenceConfig,
+ "aggregation_cfg": AggregationConfig,
+ "checkpoint_cfg": CheckpointConfig,
+ "refinement_cfg": RefinementConfig,
+ "output_cfg": OutputConfig,
+ "pose2d_cfg": Pose2DConfig,
+ "demo_cfg": DemoConfig,
+ "runtime_cfg": RuntimeConfig,
+}
+
+
+@dataclass
+class PipelineConfig:
+ """Top-level configuration for FMPose3D pipeline.
+
+ Groups related settings into sub-configs::
+
+ config.model_cfg.layers
+ config.training_cfg.lr
+ """
+
+ model_cfg: ModelConfig = field(default_factory=FMPose3DConfig)
+ dataset_cfg: DatasetConfig = field(default_factory=DatasetConfig)
+ training_cfg: TrainingConfig = field(default_factory=TrainingConfig)
+ inference_cfg: InferenceConfig = field(default_factory=InferenceConfig)
+ aggregation_cfg: AggregationConfig = field(default_factory=AggregationConfig)
+ checkpoint_cfg: CheckpointConfig = field(default_factory=CheckpointConfig)
+ refinement_cfg: RefinementConfig = field(default_factory=RefinementConfig)
+ output_cfg: OutputConfig = field(default_factory=OutputConfig)
+ pose2d_cfg: Pose2DConfig = field(default_factory=HRNetConfig)
+ demo_cfg: DemoConfig = field(default_factory=DemoConfig)
+ runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig)
+
+ # -- construction from argparse namespace ---------------------------------
+
+ @classmethod
+ def from_namespace(cls, ns) -> "PipelineConfig":
+ """Build a :class:`PipelineConfig` from an ``argparse.Namespace``
+
+ Example::
+
+ args = opts().parse()
+ cfg = PipelineConfig.from_namespace(args)
+ """
+ raw = vars(ns) if hasattr(ns, "__dict__") else dict(ns)
+
+        def _pick(dc_class, src: dict):
+            names = {f.name for f in fields(dc_class)}
+            return dc_class(**{k: v for k, v in src.items() if k in names})
+
+        kwargs = {}
+        for group_name, dc_class in _SUB_CONFIG_CLASSES.items():
+            if group_name == "model_cfg" and raw.get("model_type", "fmpose3d") == "fmpose3d":
+                dc_class = FMPose3DConfig
+            elif group_name == "pose2d_cfg" and raw.get("pose2d_model", "hrnet") == "hrnet":
+                dc_class = HRNetConfig
+            kwargs[group_name] = _pick(dc_class, raw)
+        return cls(**kwargs)
+
+ # -- utilities ------------------------------------------------------------
+
+ def to_dict(self) -> dict:
+ """Return a flat dictionary of all configuration values."""
+ result = {}
+ for group_name in _SUB_CONFIG_CLASSES:
+ result.update(asdict(getattr(self, group_name)))
+ return result
+
diff --git a/fmpose3d/common/utils.py b/fmpose3d/common/utils.py
index 11cf3747..549ef2bd 100755
--- a/fmpose3d/common/utils.py
+++ b/fmpose3d/common/utils.py
@@ -15,7 +15,44 @@
import numpy as np
import torch
-from torch.autograd import Variable
+
+
+def euler_sample(
+ c_2d: torch.Tensor,
+ y: torch.Tensor,
+ steps: int,
+ model: torch.nn.Module,
+) -> torch.Tensor:
+ """Euler ODE sampler for Conditional Flow Matching at test time.
+
+ Integrates the learned velocity field from *t = 0* to *t = 1* using
+ ``steps`` uniform Euler steps.
+
+ Parameters
+ ----------
+ c_2d : Tensor
+ 2-D conditioning input, shape ``(B, F, J, 2)``.
+ y : Tensor
+ Initial noise sample (same spatial dims as ``c_2d`` but with 3
+ output channels), shape ``(B, F, J, 3)``.
+ steps : int
+ Number of Euler integration steps.
+ model : nn.Module
+ The velocity-prediction network ``v(c_2d, y, t)``.
+
+ Returns
+ -------
+ Tensor
+ The denoised 3-D prediction, same shape as *y*.
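+
+    Example (a minimal sketch, assuming ``net`` is a loaded FMPose3D velocity model)::
+
+        c_2d = torch.randn(1, 1, 17, 2)   # (B, F, J, 2) 2-D conditioning
+        y0 = torch.randn(1, 1, 17, 3)     # initial Gaussian noise
+        pose_3d = euler_sample(c_2d, y0, steps=3, model=net)  # (1, 1, 17, 3)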
+ """
+ dt = 1.0 / steps
+ for s in range(steps):
+ t_s = torch.full(
+ (c_2d.size(0), 1, 1, 1), s * dt, device=c_2d.device, dtype=c_2d.dtype
+ )
+ v_s = model(c_2d, y, t_s)
+ y = y + dt * v_s
+ return y
def deterministic_random(min_value, max_value, data):
digest = hashlib.sha256(data.encode()).digest()
@@ -186,20 +223,17 @@ def update(self, val, n=1):
self.avg = self.sum / self.count
-def get_varialbe(split, target):
+def get_variable(split, target):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num = len(target)
var = []
if split == "train":
for i in range(num):
- temp = (
- Variable(target[i], requires_grad=False)
- .contiguous()
- .type(torch.cuda.FloatTensor)
- )
+ temp = target[i].requires_grad_(False).contiguous().float().to(device)
var.append(temp)
else:
for i in range(num):
- temp = Variable(target[i]).contiguous().cuda().type(torch.cuda.FloatTensor)
+ temp = target[i].contiguous().float().to(device)
var.append(temp)
return var
diff --git a/fmpose3d/fmpose3d.py b/fmpose3d/fmpose3d.py
new file mode 100644
index 00000000..e1992d8e
--- /dev/null
+++ b/fmpose3d/fmpose3d.py
@@ -0,0 +1,658 @@
+"""
+FMPose3D: monocular 3D Pose Estimation via Flow Matching
+
+Official implementation of the paper:
+"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
+by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
+Licensed under Apache 2.0
+"""
+
+
+from __future__ import annotations
+
+import copy
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from fmpose3d.common.camera import camera_to_world, normalize_screen_coordinates
+from fmpose3d.common.utils import euler_sample
+from fmpose3d.common.config import (
+ FMPose3DConfig,
+ HRNetConfig,
+ InferenceConfig,
+ ModelConfig,
+)
+from fmpose3d.models import get_model
+
+#: Progress callback signature: ``(current_step, total_steps) -> None``.
+ProgressCallback = Callable[[int, int], None]
+
+
+# Default camera-to-world rotation quaternion (from the demo script).
+_DEFAULT_CAM_ROTATION = np.array(
+ [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
+ dtype="float32",
+)
+
+
+# ---------------------------------------------------------------------------
+# 2D pose estimator
+# ---------------------------------------------------------------------------
+
+
+class HRNetEstimator:
+ """Default 2D pose estimator: HRNet + YOLO, with COCO→H36M conversion.
+
+    Thin wrapper around :class:`~fmpose3d.lib.hrnet.hrnet.HRNetPose2d` that
+ adds the COCO → H36M keypoint conversion expected by the 3D lifter.
+
+ Parameters
+ ----------
+ cfg : HRNetConfig
+ Estimator settings (``det_dim``, ``num_persons``, …).
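+
+    Example (a minimal sketch; ``frames`` is a ``(N, H, W, C)`` BGR array)::
+
+        estimator = HRNetEstimator(HRNetConfig(num_persons=1))
+        keypoints, scores = estimator.predict(frames)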
+ """
+
+ def __init__(self, cfg: HRNetConfig | None = None) -> None:
+ self.cfg = cfg or HRNetConfig()
+ self._model = None
+
+ def setup_runtime(self) -> None:
+ """Load YOLO + HRNet models (safe to call more than once)."""
+ if self._model is not None:
+ return
+
+ from fmpose3d.lib.hrnet.hrnet import HRNetPose2d
+
+ self._model = HRNetPose2d(
+ det_dim=self.cfg.det_dim,
+ num_persons=self.cfg.num_persons,
+ thred_score=self.cfg.thred_score,
+ hrnet_cfg_file=self.cfg.hrnet_cfg_file,
+ hrnet_weights_path=self.cfg.hrnet_weights_path,
+ )
+ self._model.setup()
+
+ def predict(
+ self, frames: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Estimate 2D keypoints from image frames and return in H36M format.
+
+ Parameters
+ ----------
+ frames : ndarray
+ BGR image frames, shape ``(N, H, W, C)``.
+
+ Returns
+ -------
+ keypoints : ndarray
+ H36M-format 2D keypoints, shape ``(num_persons, N, 17, 2)``.
+ scores : ndarray
+ Per-joint confidence scores, shape ``(num_persons, N, 17)``.
+ """
+ from fmpose3d.lib.preprocess import h36m_coco_format, revise_kpts
+
+ self.setup_runtime()
+
+ keypoints, scores = self._model.predict(frames)
+
+ keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
+ # NOTE: revise_kpts is computed for consistency but is NOT applied
+ # to the returned keypoints, matching the demo script behaviour.
+ _revised = revise_kpts(keypoints, scores, valid_frames) # noqa: F841
+
+ return keypoints, scores
+
+
+# ---------------------------------------------------------------------------
+# Result containers
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class Pose2DResult:
+ """Container returned by :meth:`FMPose3DInference.prepare_2d`."""
+
+ keypoints: np.ndarray
+ """H36M-format 2D keypoints, shape ``(num_persons, num_frames, 17, 2)``."""
+ scores: np.ndarray
+ """Per-joint confidence scores, shape ``(num_persons, num_frames, 17)``."""
+ image_size: tuple[int, int] = (0, 0)
+ """``(height, width)`` of the source frames."""
+
+
+@dataclass
+class Pose3DResult:
+ """Container returned by :meth:`FMPose3DInference.pose_3d`."""
+
+ poses_3d: np.ndarray
+ """Root-relative 3D poses, shape ``(num_frames, 17, 3)``."""
+ poses_3d_world: np.ndarray
+ """World-coordinate 3D poses, shape ``(num_frames, 17, 3)``."""
+
+
+#: Accepted source types for :meth:`FMPose3DInference.predict`.
+#:
+#: * ``str`` or ``Path`` – path to an image file or directory of images.
+#: * ``np.ndarray`` – a single frame ``(H, W, C)`` or batch ``(N, H, W, C)``.
+#: * ``list`` – a list of file paths or a list of ``(H, W, C)`` arrays.
+Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]]
+
+
+@dataclass
+class _IngestedInput:
+ """Normalised result of :meth:`FMPose3DInference._ingest_input`.
+
+ Always contains a batch of BGR frames as a numpy array, regardless
+ of the original source type.
+ """
+
+ frames: np.ndarray
+ """BGR image frames, shape ``(N, H, W, C)``."""
+ image_size: tuple[int, int]
+ """``(height, width)`` of the source frames."""
+
+
+# ---------------------------------------------------------------------------
+# Main inference class
+# ---------------------------------------------------------------------------
+
+
+class FMPose3DInference:
+ """High-level, two-step inference API for FMPose3D.
+
+ Typical workflow::
+
+ api = FMPose3DInference(model_weights_path="weights.pth")
+ result_2d = api.prepare_2d("photo.jpg")
+ result_3d = api.pose_3d(result_2d.keypoints, image_size=(H, W))
+
+ Parameters
+ ----------
+ model_cfg : ModelConfig, optional
+ Model architecture settings (layers, channels, …).
+ Defaults to :class:`~fmpose3d.common.config.FMPose3DConfig` defaults.
+ inference_cfg : InferenceConfig, optional
+ Inference settings (sample_steps, test_augmentation, …).
+ Defaults to :class:`~fmpose3d.common.config.InferenceConfig` defaults.
+ model_weights_path : str
+ Path to a ``.pth`` checkpoint for the 3D lifting model.
+ If empty the model is created but **not** loaded with weights.
+ device : str or torch.device, optional
+ Compute device. ``None`` (default) picks CUDA when available.
+ """
+
+ # H36M joint indices for left / right flip augmentation
+ _JOINTS_LEFT: list[int] = [4, 5, 6, 11, 12, 13]
+ _JOINTS_RIGHT: list[int] = [1, 2, 3, 14, 15, 16]
+
+ _IMAGE_EXTENSIONS: set[str] = {
+ ".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp",
+ }
+
+ # ------------------------------------------------------------------
+ # Initialisation
+ # ------------------------------------------------------------------
+
+ def __init__(
+ self,
+ model_cfg: ModelConfig | None = None,
+ inference_cfg: InferenceConfig | None = None,
+ model_weights_path: str = "",
+ device: str | torch.device | None = None,
+ ) -> None:
+ self.model_cfg = model_cfg or FMPose3DConfig()
+ self.inference_cfg = inference_cfg or InferenceConfig()
+ self.model_weights_path = model_weights_path
+
+ # Resolve device and padding configuration
+        self.device: torch.device = self._resolve_device(device)
+ self._pad: int = self._resolve_pad()
+
+ # Lazy-loaded models (populated by setup_runtime)
+ self._model_3d: torch.nn.Module | None = None
+ self._estimator_2d: HRNetEstimator | None = None
+
+ def setup_runtime(self) -> None:
+ """Initialise all runtime components on first use.
+
+ Called automatically when the API is used for the first time.
+ Loads the 2D estimator, the 3D lifting model, and the model
+ weights in sequence.
+ """
+ self._setup_estimator_2d()
+ self._setup_model()
+ self._load_weights()
+
+ # ------------------------------------------------------------------
+ # Public API
+ # ------------------------------------------------------------------
+
+ @torch.no_grad()
+ def predict(
+ self,
+ source: Source,
+ *,
+ camera_rotation: np.ndarray | None = _DEFAULT_CAM_ROTATION,
+ seed: int | None = None,
+ progress: ProgressCallback | None = None,
+ ) -> Pose3DResult:
+ """End-to-end prediction: 2D pose estimation → 3D lifting.
+
+ Convenience wrapper that calls :meth:`prepare_2d` then
+ :meth:`pose_3d`.
+
+ Parameters
+ ----------
+ source : Source
+ Input to process. Accepts a file path (``str`` / ``Path``),
+ a directory of images, a numpy array ``(H, W, C)`` for a
+ single frame, ``(N, H, W, C)`` for a batch, or a list of
+ paths / arrays. See :data:`Source` for the full type.
+ Video files are **not** supported and will raise
+ :class:`NotImplementedError`.
+ camera_rotation : ndarray or None
+ Length-4 quaternion for the camera-to-world rotation.
+ See :meth:`pose_3d` for details.
+ seed : int or None
+ Deterministic seed for the 3D sampling step.
+ See :meth:`pose_3d` for details.
+ progress : ProgressCallback or None
+ Optional ``(current_step, total_steps)`` callback. Forwarded
+ to the :meth:`pose_3d` step (per-frame reporting).
+
+ Returns
+ -------
+ Pose3DResult
+ Root-relative and world-coordinate 3D poses.
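+
+        Example (a minimal sketch; the weights path is an assumption)::
+
+            api = FMPose3DInference(model_weights_path="weights.pth")
+            out = api.predict("demo/images/running.png", seed=0)
+            out.poses_3d_world.shape   # (1, 17, 3) for a single image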
+ """
+ result_2d = self.prepare_2d(source)
+ return self.pose_3d(
+ result_2d.keypoints,
+ result_2d.image_size,
+ camera_rotation=camera_rotation,
+ seed=seed,
+ progress=progress,
+ )
+
+ @torch.no_grad()
+ def prepare_2d(
+ self,
+ source: Source,
+ progress: ProgressCallback | None = None,
+ ) -> Pose2DResult:
+ """Estimate 2D poses using HRNet + YOLO.
+
+ The estimator is set up lazily by :meth:`setup_runtime` on first
+ call.
+
+ Parameters
+ ----------
+ source : Source
+ Input to process. Accepts a file path (``str`` / ``Path``),
+ a directory of images, a numpy array ``(H, W, C)`` for a
+ single frame, ``(N, H, W, C)`` for a batch, or a list of
+ paths / arrays. See :data:`Source` for the full type.
+ progress : ProgressCallback or None
+ Optional ``(current_step, total_steps)`` callback invoked
+ before and after the 2D estimation step.
+
+ Returns
+ -------
+ Pose2DResult
+ H36M-format 2D keypoints and per-joint scores. The result
+ also carries ``image_size`` so it can be forwarded directly
+ to :meth:`pose_3d`.
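+
+        Example (a minimal sketch)::
+
+            result_2d = api.prepare_2d("demo/images/")   # directory of images
+            result_2d.keypoints.shape                    # (num_persons, num_frames, 17, 2)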
+ """
+ ingested = self._ingest_input(source)
+ self.setup_runtime()
+ if progress:
+ progress(0, 1)
+ keypoints, scores = self._estimator_2d.predict(ingested.frames)
+ if progress:
+ progress(1, 1)
+ return Pose2DResult(
+ keypoints=keypoints,
+ scores=scores,
+ image_size=ingested.image_size,
+ )
+
+ @torch.no_grad()
+ def pose_3d(
+ self,
+ keypoints_2d: np.ndarray,
+ image_size: tuple[int, int],
+ *,
+ camera_rotation: np.ndarray | None = _DEFAULT_CAM_ROTATION,
+ seed: int | None = None,
+ progress: ProgressCallback | None = None,
+ ) -> Pose3DResult:
+ """Lift 2D keypoints to 3D using the flow-matching model.
+
+ The pipeline exactly mirrors ``demo/vis_in_the_wild.py``'s
+ ``get_3D_pose_from_image``: normalise screen coordinates, build a
+ flip-augmented conditioning pair, run two independent Euler ODE
+ integrations (each with its own noise sample), un-flip and average,
+ zero the root joint, then convert to world coordinates.
+
+ Parameters
+ ----------
+ keypoints_2d : ndarray
+ 2D keypoints returned by :meth:`prepare_2d`. Accepted shapes:
+
+ * ``(num_persons, num_frames, 17, 2)`` – first person is used.
+ * ``(num_frames, 17, 2)`` – treated as a single person.
+ image_size : tuple of (int, int)
+ ``(height, width)`` of the source image / video frames.
+ camera_rotation : ndarray or None
+ Length-4 quaternion for the camera-to-world rotation applied
+ to produce ``poses_3d_world``. Defaults to the rotation used
+ in the official demo. Pass ``None`` to skip the transform
+ (``poses_3d_world`` will equal ``poses_3d``).
+ seed : int or None
+ If given, ``torch.manual_seed(seed)`` is called before
+ sampling so that results are fully reproducible. Use the
+ same seed in the demo script (by inserting
+ ``torch.manual_seed(seed)`` before the ``torch.randn`` calls)
+ to obtain bit-identical results.
+ progress : ProgressCallback or None
+ Optional ``(current_step, total_steps)`` callback invoked
+ after each frame is lifted to 3D.
+
+ Returns
+ -------
+ Pose3DResult
+ Root-relative and world-coordinate 3D poses.
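+
+        Example (a minimal sketch; ``kp2d`` and ``(h, w)`` come from :meth:`prepare_2d`)::
+
+            res = api.pose_3d(kp2d, (h, w), camera_rotation=None, seed=0)
+            # with camera_rotation=None, res.poses_3d_world equals res.poses_3d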
+ """
+ self.setup_runtime()
+ model = self._model_3d
+ h, w = image_size
+ steps = self.inference_cfg.sample_steps
+ use_flip = self.inference_cfg.test_augmentation
+ jl = self._JOINTS_LEFT
+ jr = self._JOINTS_RIGHT
+
+ # Optional deterministic seeding
+ if seed is not None:
+ torch.manual_seed(seed)
+
+ # Normalise input shape to (num_frames, 17, 2)
+ if keypoints_2d.ndim == 4:
+ kpts = keypoints_2d[0] # first person
+ elif keypoints_2d.ndim == 3:
+ kpts = keypoints_2d
+ else:
+ raise ValueError(
+ f"Expected keypoints_2d with 3 or 4 dims, got {keypoints_2d.ndim}"
+ )
+
+ num_frames = kpts.shape[0]
+ all_poses_3d: list[np.ndarray] = []
+ all_poses_world: list[np.ndarray] = []
+
+ if progress:
+ progress(0, num_frames)
+
+ for i in range(num_frames):
+ frame_kpts = kpts[i : i + 1] # (1, 17, 2)
+
+ # Normalise to [-1, 1] range (same as demo)
+ normed = normalize_screen_coordinates(frame_kpts, w=w, h=h)
+
+ if use_flip:
+ # -- build flip-augmented conditioning (matches demo exactly) --
+ normed_flip = copy.deepcopy(normed)
+ normed_flip[:, :, 0] *= -1
+ normed_flip[:, jl + jr] = normed_flip[:, jr + jl]
+ input_2d = np.concatenate(
+ (np.expand_dims(normed, axis=0), np.expand_dims(normed_flip, axis=0)),
+ 0,
+ ) # (2, F, J, 2)
+ input_2d = input_2d[np.newaxis, :, :, :, :] # (1, 2, F, J, 2)
+ input_t = torch.from_numpy(input_2d.astype("float32")).to(self.device)
+
+ # -- two independent Euler ODE runs (matches demo exactly) --
+ y = torch.randn(
+ input_t.size(0), input_t.size(2), input_t.size(3), 3,
+ device=self.device,
+ )
+ output_3d_non_flip = euler_sample(
+ input_t[:, 0], y, steps, model,
+ )
+
+ y_flip = torch.randn(
+ input_t.size(0), input_t.size(2), input_t.size(3), 3,
+ device=self.device,
+ )
+ output_3d_flip = euler_sample(
+ input_t[:, 1], y_flip, steps, model,
+ )
+
+ # -- un-flip & average (matches demo exactly) --
+ output_3d_flip[:, :, :, 0] *= -1
+ output_3d_flip[:, :, jl + jr, :] = output_3d_flip[
+ :, :, jr + jl, :
+ ]
+
+ output = (output_3d_non_flip + output_3d_flip) / 2
+ else:
+ input_2d = normed[np.newaxis] # (1, F, J, 2)
+ input_t = torch.from_numpy(input_2d.astype("float32")).to(self.device)
+ y = torch.randn(
+ input_t.size(0), input_t.size(1), input_t.size(2), 3,
+ device=self.device,
+ )
+ output = euler_sample(input_t, y, steps, model)
+
+ # Extract single-frame result → (17, 3) (matches demo exactly)
+ output = output[0:, self._pad].unsqueeze(1)
+ output[:, :, 0, :] = 0 # root-relative
+ pose_np = output[0, 0].cpu().detach().numpy()
+ all_poses_3d.append(pose_np)
+
+ # Camera-to-world transform (matches demo exactly)
+ if camera_rotation is not None:
+ pose_world = camera_to_world(pose_np, R=camera_rotation, t=0)
+ pose_world[:, 2] -= np.min(pose_world[:, 2])
+ else:
+ pose_world = pose_np.copy()
+ all_poses_world.append(pose_world)
+
+ if progress:
+ progress(i + 1, num_frames)
+
+ poses_3d = np.stack(all_poses_3d, axis=0) # (num_frames, 17, 3)
+ poses_world = np.stack(all_poses_world, axis=0) # (num_frames, 17, 3)
+
+ return Pose3DResult(poses_3d=poses_3d, poses_3d_world=poses_world)
+
+ # ------------------------------------------------------------------
+ # Private helpers – device & padding
+ # ------------------------------------------------------------------
+
+    def _resolve_device(self, device) -> torch.device:
+        """Resolve the compute device from the constructor argument."""
+        if device is None:
+            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        return torch.device(device)
+
+ def _resolve_pad(self) -> int:
+ """Derived from frames setting (single-frame models ⇒ pad=0)."""
+ return (self.model_cfg.frames - 1) // 2
+
+ # ------------------------------------------------------------------
+ # Private helpers – model loading
+ # ------------------------------------------------------------------
+
+ def _setup_estimator_2d(self) -> HRNetEstimator:
+ """Initialise the HRNet 2D pose estimator on first use."""
+ if self._estimator_2d is None:
+ self._estimator_2d = HRNetEstimator()
+ return self._estimator_2d
+
+ def _setup_model(self) -> torch.nn.Module:
+ """Initialise the 3D lifting model on first use."""
+ if self._model_3d is None:
+ ModelClass = get_model(self.model_cfg.model_type)
+ self._model_3d = ModelClass(self.model_cfg).to(self.device)
+ self._model_3d.eval()
+ return self._model_3d
+
+ def _load_weights(self) -> None:
+ """Load checkpoint weights into ``self._model_3d``.
+
+ Mirrors the demo's loading strategy: iterate over the model's own
+ state-dict keys and pull matching entries from the checkpoint so that
+ extra keys in the checkpoint are silently ignored.
+ """
+ if not self.model_weights_path:
+ raise ValueError(
+ "No model weights path provided. Pass 'model_weights_path' "
+ "to the FMPose3DInference constructor."
+ )
+ weights = Path(self.model_weights_path)
+ if not weights.exists():
+ raise ValueError(
+ f"Model weights file not found: {weights}. "
+ "Please provide a valid path to a .pth checkpoint file in the "
+ "FMPose3DInference constructor."
+ )
+ if self._model_3d is None:
+ raise ValueError("Model not initialised. Call setup_runtime() first.")
+ pre_dict = torch.load(
+ self.model_weights_path,
+ weights_only=True,
+ map_location=self.device,
+ )
+ model_dict = self._model_3d.state_dict()
+ for name in model_dict:
+ if name in pre_dict:
+ model_dict[name] = pre_dict[name]
+ self._model_3d.load_state_dict(model_dict)
+
+ # ------------------------------------------------------------------
+ # Private helpers – input resolution
+ # ------------------------------------------------------------------
+
+ def _ingest_input(self, source: Source) -> _IngestedInput:
+ """Normalise *source* into a ``(N, H, W, C)`` frames array.
+
+ Accepted *source* values:
+
+ * **str / Path** – path to a single image or a directory of images.
+ * **ndarray (H, W, C)** – a single BGR frame.
+ * **ndarray (N, H, W, C)** – a batch of BGR frames.
+ * **list of str/Path** – multiple image file paths.
+ * **list of ndarray** – multiple ``(H, W, C)`` BGR frames.
+
+ Video files are not yet supported and will raise
+ :class:`NotImplementedError`.
+
+ Parameters
+ ----------
+ source : Source
+ The input to resolve.
+
+ Returns
+ -------
+ _IngestedInput
+ Contains ``frames`` as ``(N, H, W, C)`` and ``image_size``
+ as ``(height, width)``.
+ """
+ import cv2
+
+ # -- numpy array (single frame or batch) ----------------------------
+ if isinstance(source, np.ndarray):
+ if source.ndim == 3:
+ frames = source[np.newaxis] # (1, H, W, C)
+ elif source.ndim == 4:
+ frames = source
+ else:
+ raise ValueError(
+ f"Expected ndarray with 3 (H,W,C) or 4 (N,H,W,C) dims, "
+ f"got {source.ndim}"
+ )
+ h, w = frames.shape[1], frames.shape[2]
+ return _IngestedInput(frames=frames, image_size=(h, w))
+
+ # -- list / sequence ------------------------------------------------
+ if isinstance(source, (list, tuple)):
+ if len(source) == 0:
+ raise ValueError("Empty source list.")
+
+ first = source[0]
+
+ # List of arrays
+ if isinstance(first, np.ndarray):
+ frames = np.stack(list(source), axis=0)
+ h, w = frames.shape[1], frames.shape[2]
+ return _IngestedInput(frames=frames, image_size=(h, w))
+
+ # List of paths
+ if isinstance(first, (str, Path)):
+ loaded = []
+ for p in source:
+ p = Path(p)
+ self._check_not_video(p)
+ img = cv2.imread(str(p))
+ if img is None:
+ raise FileNotFoundError(
+ f"Could not read image: {p}"
+ )
+ loaded.append(img)
+ frames = np.stack(loaded, axis=0)
+ h, w = frames.shape[1], frames.shape[2]
+ return _IngestedInput(frames=frames, image_size=(h, w))
+
+ raise TypeError(
+ f"Unsupported element type in source list: {type(first)}"
+ )
+
+ # -- str / Path (file or directory) ---------------------------------
+ p = Path(source)
+ if not p.exists():
+ raise FileNotFoundError(f"Source path does not exist: {p}")
+
+ self._check_not_video(p)
+
+ if p.is_dir():
+ images = sorted(
+ f for f in p.iterdir()
+ if f.suffix.lower() in self._IMAGE_EXTENSIONS
+ )
+ if not images:
+ raise FileNotFoundError(
+ f"No image files found in directory: {p}"
+ )
+ loaded = []
+ for img_path in images:
+ img = cv2.imread(str(img_path))
+ if img is None:
+ raise FileNotFoundError(
+ f"Could not read image: {img_path}"
+ )
+ loaded.append(img)
+ frames = np.stack(loaded, axis=0)
+ h, w = frames.shape[1], frames.shape[2]
+ return _IngestedInput(frames=frames, image_size=(h, w))
+
+ # Single image file
+ img = cv2.imread(str(p))
+ if img is None:
+ raise FileNotFoundError(f"Could not read image: {p}")
+ frames = img[np.newaxis] # (1, H, W, C)
+ h, w = frames.shape[1], frames.shape[2]
+ return _IngestedInput(frames=frames, image_size=(h, w))
+
+ def _check_not_video(self, p: Path) -> None:
+ """Raise :class:`NotImplementedError` if *p* looks like a video."""
+ _VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv"}
+ if p.is_file() and p.suffix.lower() in _VIDEO_EXTS:
+ raise NotImplementedError(
+ f"Video input is not yet supported (got {p}). "
+ "Please extract frames and pass them as image paths or arrays."
+ )
diff --git a/fmpose3d/lib/hrnet/hrnet.py b/fmpose3d/lib/hrnet/hrnet.py
new file mode 100644
index 00000000..1368b30b
--- /dev/null
+++ b/fmpose3d/lib/hrnet/hrnet.py
@@ -0,0 +1,282 @@
+"""
+FMPose3D: monocular 3D Pose Estimation via Flow Matching
+
+Official implementation of the paper:
+"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
+by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
+Licensed under Apache 2.0
+"""
+
+"""
+FMPose3D – clean HRNet 2D pose estimation API.
+
+Provides :class:`HRNetPose2d`, a self-contained wrapper around the
+HRNet + YOLO detection pipeline that accepts numpy arrays directly
+(no file I/O, no argparse, no global yacs config leaking out).
+
+Usage::
+
+ api = HRNetPose2d(det_dim=416, num_persons=1)
+ api.setup() # loads YOLO + HRNet weights
+ keypoints, scores = api.predict(frames) # (M, N, 17, 2), (M, N, 17)
+"""
+
+from __future__ import annotations
+
+import copy
+import os.path as osp
+from collections import OrderedDict
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+
+from fmpose3d.lib.checkpoint.download_checkpoints import (
+ ensure_checkpoints,
+ get_checkpoint_path,
+)
+
+
+class HRNetPose2d:
+ """Self-contained 2D pose estimator (YOLO detector + HRNet).
+
+    Accepts numpy arrays directly (no file I/O) and serves as an alternative to
+    the gen_video_kpts function in fmpose3d/lib/hrnet/gen_kpts.py,
+ which generates 2D keypoints from a video file.
+
+ Parameters
+ ----------
+ det_dim : int
+ YOLO input resolution (default 416).
+ num_persons : int
+ Maximum number of persons to track per frame (default 1).
+ thred_score : float
+ YOLO object-confidence threshold (default 0.30).
+ hrnet_cfg_file : str
+ Path to the HRNet YAML experiment config. Empty string (default)
+ uses the bundled ``w48_384x288_adam_lr1e-3.yaml``.
+ hrnet_weights_path : str
+ Path to the HRNet ``.pth`` checkpoint. Empty string (default)
+ uses the auto-downloaded ``pose_hrnet_w48_384x288.pth``.
+ """
+
+ def __init__(
+ self,
+ det_dim: int = 416,
+ num_persons: int = 1,
+ thred_score: float = 0.30,
+ hrnet_cfg_file: str = "",
+ hrnet_weights_path: str = "",
+ ) -> None:
+ self.det_dim = det_dim
+ self.num_persons = num_persons
+ self.thred_score = thred_score
+ self.hrnet_cfg_file = hrnet_cfg_file
+ self.hrnet_weights_path = hrnet_weights_path
+
+ # Populated by setup()
+ self._human_model = None
+ self._pose_model = None
+ self._people_sort = None
+ self._hrnet_cfg = None # frozen yacs CfgNode used by PreProcess / get_final_preds
+
+ # ------------------------------------------------------------------
+ # Setup
+ # ------------------------------------------------------------------
+
+ @property
+ def is_ready(self) -> bool:
+ """``True`` once :meth:`setup` has been called."""
+ return self._human_model is not None
+
+ def setup(self) -> "HRNetPose2d":
+ """Load YOLO detector and HRNet pose model.
+
+ Can safely be called more than once (subsequent calls are no-ops).
+
+ Returns ``self`` so you can write ``api = HRNetPose2d().setup()``.
+ """
+ if self.is_ready:
+ return self
+
+ ensure_checkpoints()
+
+ # --- resolve paths ---------------------------------------------------
+ hrnet_cfg_file = self.hrnet_cfg_file
+ if not hrnet_cfg_file:
+ hrnet_cfg_file = osp.join(
+ osp.dirname(osp.abspath(__file__)),
+ "experiments",
+ "w48_384x288_adam_lr1e-3.yaml",
+ )
+
+ hrnet_weights = self.hrnet_weights_path
+ if not hrnet_weights:
+ hrnet_weights = get_checkpoint_path("pose_hrnet_w48_384x288.pth")
+
+ # --- build internal yacs config (kept private) -----------------------
+ from fmpose3d.lib.hrnet.lib.config import cfg as _global_cfg
+ from fmpose3d.lib.hrnet.lib.config import update_config as _update_cfg
+ from types import SimpleNamespace
+
+ _global_cfg.defrost()
+ _update_cfg(
+ _global_cfg,
+ SimpleNamespace(cfg=hrnet_cfg_file, opts=[], modelDir=hrnet_weights),
+ )
+ # Snapshot the frozen cfg so we can pass it to PreProcess / get_final_preds.
+ self._hrnet_cfg = _global_cfg
+
+ # cudnn tuning
+ cudnn.benchmark = self._hrnet_cfg.CUDNN.BENCHMARK
+ cudnn.deterministic = self._hrnet_cfg.CUDNN.DETERMINISTIC
+ cudnn.enabled = self._hrnet_cfg.CUDNN.ENABLED
+
+ # --- load models -----------------------------------------------------
+ from fmpose3d.lib.yolov3.human_detector import load_model as _yolo_load
+ from fmpose3d.lib.sort.sort import Sort
+
+ self._human_model = _yolo_load(inp_dim=self.det_dim)
+ self._pose_model = self._load_hrnet(self._hrnet_cfg)
+ self._people_sort = Sort(min_hits=0)
+
+ return self
+
+ # ------------------------------------------------------------------
+ # Prediction
+ # ------------------------------------------------------------------
+
+ def predict(
+ self, frames: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Estimate 2D keypoints for a batch of BGR frames.
+
+ Parameters
+ ----------
+ frames : ndarray, shape ``(N, H, W, C)``
+ BGR images. A single frame ``(H, W, C)`` is also accepted
+ and will be treated as a batch of one.
+
+ Returns
+ -------
+ keypoints : ndarray, shape ``(num_persons, N, 17, 2)``
+ COCO-format 2D keypoints in pixel coordinates.
+ scores : ndarray, shape ``(num_persons, N, 17)``
+ Per-joint confidence scores.
+ """
+ if not self.is_ready:
+ self.setup()
+
+ if frames.ndim == 3:
+ frames = frames[np.newaxis]
+
+ kpts_result = []
+ scores_result = []
+
+ for i in range(frames.shape[0]):
+ kpts, sc = self._estimate_frame(frames[i])
+ kpts_result.append(kpts)
+ scores_result.append(sc)
+
+ keypoints = np.array(kpts_result) # (N, M, 17, 2)
+ scores = np.array(scores_result) # (N, M, 17)
+
+ # (N, M, 17, 2) → (M, N, 17, 2)
+ keypoints = keypoints.transpose(1, 0, 2, 3)
+ # (N, M, 17) → (M, N, 17)
+ scores = scores.transpose(1, 0, 2)
+
+ return keypoints, scores
+
+ # ------------------------------------------------------------------
+ # Internal helpers
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _load_hrnet(config):
+ """Instantiate HRNet and load checkpoint weights."""
+ from fmpose3d.lib.hrnet.lib.models import pose_hrnet
+
+ model = pose_hrnet.get_pose_net(config, is_train=False)
+ if torch.cuda.is_available():
+ model = model.cuda()
+
+        state_dict = torch.load(config.OUTPUT_DIR, map_location="cpu", weights_only=True)
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ new_state_dict[k] = v
+ model.load_state_dict(new_state_dict)
+ model.eval()
+ return model
+
+ def _estimate_frame(
+ self, frame: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Run detection + pose estimation on a single BGR frame.
+
+ Returns
+ -------
+ kpts : ndarray, shape ``(num_persons, 17, 2)``
+ scores : ndarray, shape ``(num_persons, 17)``
+ """
+ from fmpose3d.lib.yolov3.human_detector import yolo_human_det
+ from fmpose3d.lib.hrnet.lib.utils.utilitys import PreProcess
+ from fmpose3d.lib.hrnet.lib.utils.inference import get_final_preds
+
+ num_persons = self.num_persons
+
+ bboxs, det_scores = yolo_human_det(
+ frame, self._human_model, reso=self.det_dim, confidence=self.thred_score,
+ )
+
+ if bboxs is None or not bboxs.any():
+ # No detection – return zeros
+ kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
+ scores = np.zeros((num_persons, 17), dtype=np.float32)
+ return kpts, scores
+
+ # Track
+ people_track = self._people_sort.update(bboxs)
+
+ if people_track.shape[0] == 1:
+ people_track_ = people_track[-1, :-1].reshape(1, 4)
+ elif people_track.shape[0] >= 2:
+ people_track_ = people_track[-num_persons:, :-1].reshape(num_persons, 4)
+ people_track_ = people_track_[::-1]
+ else:
+ kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
+ scores = np.zeros((num_persons, 17), dtype=np.float32)
+ return kpts, scores
+
+ track_bboxs = []
+ for bbox in people_track_:
+ bbox = [round(i, 2) for i in list(bbox)]
+ track_bboxs.append(bbox)
+
+ with torch.no_grad():
+ inputs, origin_img, center, scale = PreProcess(
+ frame, track_bboxs, self._hrnet_cfg, num_persons,
+ )
+ inputs = inputs[:, [2, 1, 0]]
+
+ if torch.cuda.is_available():
+ inputs = inputs.cuda()
+ output = self._pose_model(inputs)
+
+ preds, maxvals = get_final_preds(
+ self._hrnet_cfg,
+ output.clone().cpu().numpy(),
+ np.asarray(center),
+ np.asarray(scale),
+ )
+
+ kpts = np.zeros((num_persons, 17, 2), dtype=np.float32)
+ scores = np.zeros((num_persons, 17), dtype=np.float32)
+ for i, kpt in enumerate(preds):
+ kpts[i] = kpt
+ for i, score in enumerate(maxvals):
+ scores[i] = score.squeeze()
+
+ return kpts, scores
+
diff --git a/fmpose3d/models/__init__.py b/fmpose3d/models/__init__.py
index 5b94df4d..b9dc64a4 100644
--- a/fmpose3d/models/__init__.py
+++ b/fmpose3d/models/__init__.py
@@ -11,10 +11,16 @@
FMPose3D models.
"""
-from .graph_frames import Graph
-from .model_GAMLP import Model
+from .base_model import BaseModel, register_model, get_model, list_models
+
+# Import model subpackages so their @register_model decorators execute.
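+# For example, once ``fmpose3d.models`` is imported, ``get_model("fmpose3d")``
+# returns the ``Model`` class defined in ``model_GAMLP.py``.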
+from .fmpose3d import Graph, Model
__all__ = [
+ "BaseModel",
+ "register_model",
+ "get_model",
+ "list_models",
"Graph",
"Model",
]
diff --git a/fmpose3d/models/base_model.py b/fmpose3d/models/base_model.py
new file mode 100644
index 00000000..ddf06737
--- /dev/null
+++ b/fmpose3d/models/base_model.py
@@ -0,0 +1,114 @@
+"""
+FMPose3D: monocular 3D Pose Estimation via Flow Matching
+
+Official implementation of the paper:
+"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
+by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
+Licensed under Apache 2.0
+"""
+
+from abc import ABC, abstractmethod
+import warnings
+
+import torch
+import torch.nn as nn
+
+
+# ---------------------------------------------------------------------------
+# Model registry
+# ---------------------------------------------------------------------------
+
+_MODEL_REGISTRY: dict[str, type["BaseModel"]] = {}
+
+
+def register_model(name: str):
+ """Class decorator that registers a model under *name*.
+
+ Usage::
+
+ @register_model("my_model")
+ class MyModel(BaseModel):
+ ...
+
+ The model can then be retrieved with :func:`get_model`.
+ """
+
+ def decorator(cls: type["BaseModel"]) -> type["BaseModel"]:
+ if name in _MODEL_REGISTRY:
+ warnings.warn(
+ f"Model '{name}' is already registered "
+ f"(existing: {_MODEL_REGISTRY[name].__qualname__}, "
+ f"new: {cls.__qualname__})"
+ )
+ _MODEL_REGISTRY[name] = cls
+ return cls
+
+ return decorator
+
+
+def get_model(name: str) -> type["BaseModel"]:
+ """Return the model class registered under *name*.
+
+ Raises :class:`KeyError` with a helpful message when the name is unknown.
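+
+ Example (sketch; ``fmpose3d`` is the name registered by the bundled
+ GAMLP model)::
+
+ ModelCls = get_model("fmpose3d")
+ model = ModelCls(args) # ``args`` is the usual argument namespace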
+ """
+ if name not in _MODEL_REGISTRY:
+ available = ", ".join(sorted(_MODEL_REGISTRY)) or "(none)"
+ raise KeyError(
+ f"Unknown model '{name}'. Available models: {available}"
+ )
+ return _MODEL_REGISTRY[name]
+
+
+def list_models() -> list[str]:
+ """Return a sorted list of all registered model names."""
+ return sorted(_MODEL_REGISTRY)
+
+
+# ---------------------------------------------------------------------------
+# Base model
+# ---------------------------------------------------------------------------
+
+
+class BaseModel(ABC, nn.Module):
+ """Abstract base class for all FMPose3D lifting models.
+
+ Every model must accept a single *args* namespace / object in its
+ constructor and implement :meth:`forward` with the signature below.
+
+ Parameters expected on *args* (at minimum):
+ - ``channel`` – embedding dimension
+ - ``layers`` – number of transformer / GCN blocks
+ - ``d_hid`` – hidden MLP dimension
+ - ``token_dim`` – token dimension
+ - ``n_joints`` – number of body joints
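+
+ A minimal subclass sketch (illustrative names only, not a shipped model)::
+
+ @register_model("toy_lifter")
+ class ToyLifter(BaseModel):
+ def __init__(self, args):
+ super().__init__(args)
+ # 2D pose + noisy 3D pose + scalar time -> per-joint velocity
+ self.net = nn.Linear(2 + 3 + 1, 3)
+
+ def forward(self, pose_2d, y_t, t):
+ t_feat = t.expand_as(y_t[..., :1]) # (B,F,1,1) -> (B,F,J,1)
+ return self.net(torch.cat([pose_2d, y_t, t_feat], dim=-1))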
+ """
+
+ @abstractmethod
+ def __init__(self, args) -> None:
+ super().__init__()
+
+ @abstractmethod
+ def forward(
+ self,
+ pose_2d: torch.Tensor,
+ y_t: torch.Tensor,
+ t: torch.Tensor,
+ ) -> torch.Tensor:
+ """Predict the velocity field for flow matching.
+
+ Args:
+ pose_2d: 2D keypoints, shape ``(B, F, J, 2)``.
+ y_t: Noisy 3D hypothesis at time *t*, shape ``(B, F, J, 3)``.
+ t: Diffusion / flow time, shape ``(B, F, 1, 1)`` with values
+ in ``[0, 1]``.
+
+ Returns:
+ Predicted velocity ``v``, shape ``(B, F, J, 3)``.
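+
+ A sampler integrates this field; a single explicit Euler update is,
+ schematically::
+
+ y_next = y_t + dt * self(pose_2d, y_t, t) # dt = 1 / steps over [0, 1]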
+ """
+ ...
+
diff --git a/fmpose3d/models/fmpose3d/__init__.py b/fmpose3d/models/fmpose3d/__init__.py
new file mode 100644
index 00000000..9cd972b9
--- /dev/null
+++ b/fmpose3d/models/fmpose3d/__init__.py
@@ -0,0 +1,21 @@
+"""
+FMPose3D: monocular 3D Pose Estimation via Flow Matching
+
+Official implementation of the paper:
+"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
+by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
+Licensed under Apache 2.0
+"""
+
+"""
+FMPose3D model subpackage.
+"""
+
+from .graph_frames import Graph
+from .model_GAMLP import Model
+
+__all__ = [
+ "Graph",
+ "Model",
+]
+
diff --git a/fmpose3d/models/graph_frames.py b/fmpose3d/models/fmpose3d/graph_frames.py
old mode 100755
new mode 100644
similarity index 99%
rename from fmpose3d/models/graph_frames.py
rename to fmpose3d/models/fmpose3d/graph_frames.py
index 7aa2c391..0a26b140
--- a/fmpose3d/models/graph_frames.py
+++ b/fmpose3d/models/fmpose3d/graph_frames.py
@@ -207,4 +207,5 @@ def normalize_undigraph(A):
if __name__=="__main__":
graph = Graph('hm36_gt', 'spatial', 1)
print(graph.A.shape)
- # print(graph)
\ No newline at end of file
+ # print(graph)
+
diff --git a/fmpose3d/models/model_GAMLP.py b/fmpose3d/models/fmpose3d/model_GAMLP.py
similarity index 88%
rename from fmpose3d/models/model_GAMLP.py
rename to fmpose3d/models/fmpose3d/model_GAMLP.py
index c0c6b46e..66579001 100644
--- a/fmpose3d/models/model_GAMLP.py
+++ b/fmpose3d/models/fmpose3d/model_GAMLP.py
@@ -7,15 +7,14 @@
Licensed under Apache 2.0
"""
-import sys
-sys.path.append("..")
import torch
import torch.nn as nn
import math
from einops import rearrange
-from fmpose3d.models.graph_frames import Graph
+from fmpose3d.models.fmpose3d.graph_frames import Graph
+from fmpose3d.models.base_model import BaseModel, register_model
from functools import partial
-from einops import rearrange, repeat
+from einops import rearrange
from timm.models.layers import DropPath
class TimeEmbedding(nn.Module):
@@ -36,8 +35,6 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
b, f = t.shape[0], t.shape[1]
half_dim = self.dim // 2
- # Gaussian Fourier features: sin(2π B t), cos(2π B t)
- # t: (B,F,1,1) -> (B,F,1,1,1) -> broadcast with (1,1,1,1,half_dim)
angles = (2 * math.pi) * t.to(torch.float32).unsqueeze(-1) * self.B.view(1, 1, 1, 1, half_dim)
sin = torch.sin(angles)
cos = torch.cos(angles)
@@ -206,27 +203,28 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
self.fc1 = nn.Linear(dim_in, dim_hid)
self.act = act_layer()
self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity()
- self.fc5 = nn.Linear(dim_hid, dim_out)
+ self.fc2 = nn.Linear(dim_hid, dim_out)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
- x = self.fc5(x)
+ x = self.fc2(x)
return x
-class Model(nn.Module):
+@register_model("fmpose3d")
+class Model(BaseModel):
def __init__(self, args):
- super().__init__()
+ super().__init__(args)
## GCN
self.graph = Graph('hm36_gt', 'spatial', pad=1)
# Register as buffer (not parameter) to follow module device automatically
self.register_buffer('A', torch.tensor(self.graph.A, dtype=torch.float32))
- self.t_embed_dim = 16
+ self.t_embed_dim = 32
self.time_embed = TimeEmbedding(self.t_embed_dim, hidden_dim=64)
- self.encoder_pose_2d = encoder(2, args.channel//2, args.channel//2-self.t_embed_dim//2)
- self.encoder_y_t = encoder(3, args.channel//2, args.channel//2-self.t_embed_dim//2)
+
+ self.encoder = encoder(2 + 3 + self.t_embed_dim, args.channel//2, args.channel)
self.FMPose3D = FMPose3D(args.layers, args.channel, args.d_hid, args.token_dim, self.A, length=args.n_joints) # 256
self.pred_mu = decoder(args.channel, args.channel//2, 3)
@@ -235,24 +233,16 @@ def forward(self, pose_2d, y_t, t):
# pose_2d: (B,F,J,2) y_t: (B,F,J,3) t: (B,F,1,1)
b, f, j, _ = pose_2d.shape
- # Ensure t has the correct shape (B,F,1,1)
- if t.shape[1] == 1 and f > 1:
- t = t.expand(b, f, 1, 1).contiguous()
-
# build time embedding
t_emb = self.time_embed(t) # (B,F,t_dim)
t_emb = t_emb.unsqueeze(2).expand(b, f, j, self.t_embed_dim).contiguous() # (B,F,J,t_dim)
- pose_2d_emb = self.encoder_pose_2d(pose_2d)
- y_t_emb = self.encoder_y_t(y_t)
-
- in_emb = torch.cat([pose_2d_emb, y_t_emb, t_emb], dim=-1) # (B,F,J,dim)
- in_emb = rearrange(in_emb, 'b f j c -> (b f) j c').contiguous() # (B*F,J,in)
-
- # encoder -> model -> regression head
- h = self.FMPose3D(in_emb)
- v = self.pred_mu(h) # (B*F,J,3)
+ x_in = torch.cat([pose_2d, y_t, t_emb], dim=-1) # (B,F,J,2+3+t_dim)
+ x_in = rearrange(x_in, 'b f j c -> (b f) j c').contiguous() # (B*F,J,in)
+ in_emb = self.encoder(x_in)
+ features = self.FMPose3D(in_emb)
+ v = self.pred_mu(features) # (B*F,J,3)
v = rearrange(v, '(b f) j c -> b f j c', b=b, f=f).contiguous() # (B,F,J,3)
return v
@@ -276,4 +266,5 @@ class Args:
y_t = torch.randn(1, 17, 17, 3, device=device)
t = torch.randn(1, 1, 1, 1, device=device)
v = model(x, y_t, t)
- print(v.shape)
\ No newline at end of file
+ print(v.shape)
+
diff --git a/images/demo.gif b/images/demo.gif
new file mode 100644
index 00000000..fa1595aa
Binary files /dev/null and b/images/demo.gif differ
diff --git a/images/demo.jpg b/images/demo.jpg
index 328847b0..9dee98d2 100644
Binary files a/images/demo.jpg and b/images/demo.jpg differ
diff --git a/scripts/FMPose3D_main.py b/scripts/FMPose3D_main.py
index 6cd97b1d..e9adf3a7 100644
--- a/scripts/FMPose3D_main.py
+++ b/scripts/FMPose3D_main.py
@@ -78,7 +78,7 @@ def test_multi_hypothesis(
for i, data in enumerate(tqdm(dataLoader, 0)):
batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, cam_ind = data
- [input_2D, gt_3D, batch_cam, scale, bb_box] = get_varialbe(
+ [input_2D, gt_3D, batch_cam, scale, bb_box] = get_variable(
split, [input_2D, gt_3D, batch_cam, scale, bb_box]
)
@@ -165,7 +165,7 @@ def train(opt, train_loader, model, optimizer):
for i, data in enumerate(tqdm(train_loader, 0)):
batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, cam_ind = data
- [input_2D, gt_3D, batch_cam, scale, bb_box] = get_varialbe(
+ [input_2D, gt_3D, batch_cam, scale, bb_box] = get_variable(
split, [input_2D, gt_3D, batch_cam, scale, bb_box]
)
@@ -267,36 +267,29 @@ def print_error_action(action_error_sum, is_train):
args.checkpoint = "./checkpoint/" + folder_name
elif args.train == False:
# create a new folder for the test results
- args.previous_dir = os.path.dirname(args.saved_model_path)
+ args.previous_dir = os.path.dirname(args.model_weights_path)
args.checkpoint = os.path.join(args.previous_dir, folder_name)
if not os.path.exists(args.checkpoint):
os.makedirs(args.checkpoint)
# backup files
- # import shutil
- # file_name = os.path.basename(__file__)
- # shutil.copyfile(
- # src=file_name,
- # dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name),
- # )
- # shutil.copyfile(
- # src="common/arguments.py",
- # dst=os.path.join(args.checkpoint, args.create_time + "_arguments.py"),
- # )
- # if getattr(args, "model_path", ""):
- # model_src_path = os.path.abspath(args.model_path)
- # model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
- # shutil.copyfile(
- # src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name)
- # )
- # shutil.copyfile(
- # src="common/utils.py",
- # dst=os.path.join(args.checkpoint, args.create_time + "_utils.py"),
- # )
- # sh_base = os.path.basename(args.sh_file)
- # dst_name = f"{args.create_time}_" + sh_base
- # shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name))
+ import shutil
+ script_path = os.path.abspath(__file__)
+ script_name = os.path.basename(script_path)
+ shutil.copyfile(
+ src=script_path,
+ dst=os.path.join(args.checkpoint, args.create_time + "_" + script_name),
+ )
+ if getattr(args, "model_path", ""):
+ model_src_path = os.path.abspath(args.model_path)
+ model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
+ shutil.copyfile(
+ src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name)
+ )
+ sh_base = os.path.basename(args.sh_file)
+ dst_name = f"{args.create_time}_" + sh_base
+ shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name))
logging.basicConfig(
format="%(asctime)s %(message)s",
@@ -342,14 +335,16 @@ def print_error_action(action_error_sum, is_train):
pin_memory=True,
)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
model = {}
- model["CFM"] = CFM(args).cuda()
+ model["CFM"] = CFM(args).to(device)
if args.reload:
model_dict = model["CFM"].state_dict()
- model_path = args.saved_model_path
+ model_path = args.model_weights_path
print(model_path)
- pre_dict = torch.load(model_path)
+ pre_dict = torch.load(model_path, map_location=device, weights_only=True)
for name, key in model_dict.items():
model_dict[name] = pre_dict[name]
model["CFM"].load_state_dict(model_dict)
diff --git a/scripts/FMPose3D_test.sh b/scripts/FMPose3D_test.sh
index de7869ea..3d2a6152 100755
--- a/scripts/FMPose3D_test.sh
+++ b/scripts/FMPose3D_test.sh
@@ -2,7 +2,6 @@
layers=5
batch_size=1024
sh_file='scripts/FMPose3D_test.sh'
-weight_softmax_tau=1.0
num_hypothesis_list=1
eval_multi_steps=3
topk=8
@@ -11,17 +10,16 @@ mode='exp'
exp_temp=0.005
folder_name=test_s${eval_multi_steps}_${mode}_h${num_hypothesis_list}_$(date +%Y%m%d_%H%M%S)
-model_path='pre_trained_models/fmpose_detected2d/model_GAMLP.py'
-saved_model_path='pre_trained_models/fmpose_detected2d/FMpose_36_4972_best.pth'
+model_path='./pre_trained_models/fmpose3d_h36m/model_GAMLP.py'
+model_weights_path='./pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth'
-#Test CFM
+#Test
python3 scripts/FMPose3D_main.py \
--reload \
--topk ${topk} \
--exp_temp ${exp_temp} \
---weight_softmax_tau ${weight_softmax_tau} \
--folder_name ${folder_name} \
---saved_model_path "${saved_model_path}" \
+--model_weights_path "${model_weights_path}" \
--model_path "${model_path}" \
--eval_sample_steps ${eval_multi_steps} \
--test_augmentation True \
diff --git a/scripts/FMPose3D_train.sh b/scripts/FMPose3D_train.sh
index cb3e0289..939658c6 100755
--- a/scripts/FMPose3D_train.sh
+++ b/scripts/FMPose3D_train.sh
@@ -11,10 +11,10 @@ epochs=80
num_saved_models=3
frames=1
channel_dim=512
-model_path="" # when the path is empty, the model will be loaded from the installed fmpose package
-# model_path='./fmpose/models/model_GAMLP.py' # when the path is not empty, the model will be loaded from the local file path
+model_path="" # when the path is empty, the model will be loaded from the installed fmpose3d package
+# model_path='./fmpose3d/models/fmpose3d/model_GAMLP.py' # when the path is not empty, the model will be loaded from the local file path
sh_file='scripts/FMPose3D_train.sh'
-folder_name=FMPose3D_Publish_layers${layers}_$(date +%Y%m%d_%H%M%S)
+folder_name=FMPose3D_layers${layers}_$(date +%Y%m%d_%H%M%S)
python3 scripts/FMPose3D_main.py \
--train \
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 00000000..0919b4e5
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,403 @@
+"""
+FMPose3D: monocular 3D Pose Estimation via Flow Matching
+
+Official implementation of the paper:
+"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
+by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
+Licensed under Apache 2.0
+"""
+
+import argparse
+import math
+
+import pytest
+
+from fmpose3d.common.config import (
+ PipelineConfig,
+ FMPose3DConfig,
+ DatasetConfig,
+ TrainingConfig,
+ InferenceConfig,
+ AggregationConfig,
+ CheckpointConfig,
+ RefinementConfig,
+ OutputConfig,
+ DemoConfig,
+ RuntimeConfig,
+ _SUB_CONFIG_CLASSES,
+)
+
+
+# ---------------------------------------------------------------------------
+# Sub-config defaults
+# ---------------------------------------------------------------------------
+
+
+class TestFMPose3DConfig:
+ def test_defaults(self):
+ cfg = FMPose3DConfig()
+ assert cfg.layers == 3
+ assert cfg.channel == 512
+ assert cfg.d_hid == 1024
+ assert cfg.n_joints == 17
+ assert cfg.out_joints == 17
+ assert cfg.frames == 1
+
+ def test_custom_values(self):
+ cfg = FMPose3DConfig(layers=5, channel=256, n_joints=26)
+ assert cfg.layers == 5
+ assert cfg.channel == 256
+ assert cfg.n_joints == 26
+
+
+class TestDatasetConfig:
+ def test_defaults(self):
+ cfg = DatasetConfig()
+ assert cfg.dataset == "h36m"
+ assert cfg.keypoints == "cpn_ft_h36m_dbb"
+ assert cfg.root_path == "dataset/"
+ assert cfg.train_views == [0, 1, 2, 3]
+ assert cfg.joints_left == []
+ assert cfg.joints_right == []
+
+ def test_list_defaults_are_independent(self):
+ """Each instance should get its own list, not a shared reference."""
+ a = DatasetConfig()
+ b = DatasetConfig()
+ a.joints_left.append(99)
+ assert 99 not in b.joints_left
+
+ def test_custom_values(self):
+ cfg = DatasetConfig(
+ dataset="rat7m",
+ root_path="Rat7M_data/",
+ joints_left=[8, 10, 11],
+ joints_right=[9, 14, 15],
+ )
+ assert cfg.dataset == "rat7m"
+ assert cfg.root_path == "Rat7M_data/"
+ assert cfg.joints_left == [8, 10, 11]
+
+
+class TestTrainingConfig:
+ def test_defaults(self):
+ cfg = TrainingConfig()
+ assert cfg.train is False
+ assert cfg.nepoch == 41
+ assert cfg.batch_size == 128
+ assert cfg.lr == pytest.approx(1e-3)
+ assert cfg.lr_decay == pytest.approx(0.95)
+ assert cfg.data_augmentation is True
+
+ def test_custom_values(self):
+ cfg = TrainingConfig(lr=5e-4, nepoch=100)
+ assert cfg.lr == pytest.approx(5e-4)
+ assert cfg.nepoch == 100
+
+
+class TestInferenceConfig:
+ def test_defaults(self):
+ cfg = InferenceConfig()
+ assert cfg.test == 1
+ assert cfg.test_augmentation is True
+ assert cfg.sample_steps == 3
+ assert cfg.eval_sample_steps == "1,3,5,7,9"
+ assert cfg.hypothesis_num == 1
+ assert cfg.guidance_scale == pytest.approx(1.0)
+
+
+class TestAggregationConfig:
+ def test_defaults(self):
+ cfg = AggregationConfig()
+ assert cfg.topk == 3
+ assert cfg.exp_temp == pytest.approx(0.002)
+ assert cfg.mode == "exp"
+ assert cfg.opt_steps == 2
+
+
+class TestCheckpointConfig:
+ def test_defaults(self):
+ cfg = CheckpointConfig()
+ assert cfg.reload is False
+ assert cfg.model_weights_path == ""
+ assert cfg.previous_dir == "./pre_trained_model/pretrained"
+ assert cfg.num_saved_models == 3
+ assert cfg.previous_best_threshold == math.inf
+
+ def test_mutability(self):
+ cfg = CheckpointConfig()
+ cfg.previous_best_threshold = 42.5
+ cfg.previous_name = "best_model.pth"
+ assert cfg.previous_best_threshold == pytest.approx(42.5)
+ assert cfg.previous_name == "best_model.pth"
+
+
+class TestRefinementConfig:
+ def test_defaults(self):
+ cfg = RefinementConfig()
+ assert cfg.post_refine is False
+ assert cfg.lr_refine == pytest.approx(1e-5)
+ assert cfg.refine is False
+
+
+class TestOutputConfig:
+ def test_defaults(self):
+ cfg = OutputConfig()
+ assert cfg.create_time == ""
+ assert cfg.create_file == 1
+ assert cfg.debug is False
+ assert cfg.folder_name == ""
+
+
+class TestDemoConfig:
+ def test_defaults(self):
+ cfg = DemoConfig()
+ assert cfg.type == "image"
+ assert cfg.path == "demo/images/running.png"
+
+
+class TestRuntimeConfig:
+ def test_defaults(self):
+ cfg = RuntimeConfig()
+ assert cfg.gpu == "0"
+ assert cfg.pad == 0
+ assert cfg.single is False
+ assert cfg.reload_3d is False
+
+
+# ---------------------------------------------------------------------------
+# PipelineConfig
+# ---------------------------------------------------------------------------
+
+
+class TestPipelineConfig:
+ def test_default_construction(self):
+ """All sub-configs are initialised with their defaults."""
+ cfg = PipelineConfig()
+ assert isinstance(cfg.model_cfg, FMPose3DConfig)
+ assert isinstance(cfg.dataset_cfg, DatasetConfig)
+ assert isinstance(cfg.training_cfg, TrainingConfig)
+ assert isinstance(cfg.inference_cfg, InferenceConfig)
+ assert isinstance(cfg.aggregation_cfg, AggregationConfig)
+ assert isinstance(cfg.checkpoint_cfg, CheckpointConfig)
+ assert isinstance(cfg.refinement_cfg, RefinementConfig)
+ assert isinstance(cfg.output_cfg, OutputConfig)
+ assert isinstance(cfg.demo_cfg, DemoConfig)
+ assert isinstance(cfg.runtime_cfg, RuntimeConfig)
+
+ def test_partial_construction(self):
+ """Supplying only some sub-configs leaves the rest at defaults."""
+ cfg = PipelineConfig(
+ model_cfg=FMPose3DConfig(layers=5),
+ training_cfg=TrainingConfig(lr=2e-4),
+ )
+ assert cfg.model_cfg.layers == 5
+ assert cfg.training_cfg.lr == pytest.approx(2e-4)
+ # Others keep defaults
+ assert cfg.dataset_cfg.dataset == "h36m"
+ assert cfg.runtime_cfg.gpu == "0"
+
+ def test_sub_config_mutation(self):
+ """Mutating a sub-config field is reflected on the config."""
+ cfg = PipelineConfig()
+ cfg.training_cfg.lr = 0.01
+ assert cfg.training_cfg.lr == pytest.approx(0.01)
+
+ def test_sub_config_replacement(self):
+ """Replacing an entire sub-config works."""
+ cfg = PipelineConfig()
+ cfg.model_cfg = FMPose3DConfig(layers=10, channel=1024)
+ assert cfg.model_cfg.layers == 10
+ assert cfg.model_cfg.channel == 1024
+
+ # -- to_dict --------------------------------------------------------------
+
+ def test_to_dict_returns_flat_dict(self):
+ cfg = PipelineConfig()
+ d = cfg.to_dict()
+ assert isinstance(d, dict)
+ # Spot-check keys from different groups
+ assert "layers" in d
+ assert "dataset" in d
+ assert "lr" in d
+ assert "topk" in d
+ assert "gpu" in d
+
+ def test_to_dict_reflects_custom_values(self):
+ cfg = PipelineConfig(
+ model_cfg=FMPose3DConfig(layers=7),
+ aggregation_cfg=AggregationConfig(topk=5),
+ )
+ d = cfg.to_dict()
+ assert d["layers"] == 7
+ assert d["topk"] == 5
+
+ def test_to_dict_no_duplicate_keys(self):
+ """Every field name should be unique across all sub-configs."""
+ cfg = PipelineConfig()
+ d = cfg.to_dict()
+ all_field_names = []
+ from dataclasses import fields as dc_fields
+ for dc_class in _SUB_CONFIG_CLASSES.values():
+ all_field_names.extend(f.name for f in dc_fields(dc_class))
+ assert len(all_field_names) == len(set(all_field_names)), (
+ "Duplicate field names across sub-configs"
+ )
+
+ # -- from_namespace -------------------------------------------------------
+
+ def test_from_namespace_basic(self):
+ ns = argparse.Namespace(
+ # FMPose3DConfig
+ model="test_model",
+ model_type="fmpose3d",
+ layers=5,
+ channel=256,
+ d_hid=512,
+ token_dim=128,
+ n_joints=20,
+ out_joints=20,
+ in_channels=2,
+ out_channels=3,
+ frames=3,
+ # DatasetConfig
+ dataset="rat7m",
+ keypoints="cpn",
+ root_path="Rat7M_data/",
+ actions="*",
+ downsample=1,
+ subset=1.0,
+ stride=1,
+ crop_uv=0,
+ out_all=1,
+ train_views=[0, 1],
+ test_views=[2, 3],
+ subjects_train="S1",
+ subjects_test="S2",
+ root_joint=4,
+ joints_left=[8, 10],
+ joints_right=[9, 14],
+ # TrainingConfig
+ train=True,
+ nepoch=100,
+ batch_size=64,
+ lr=5e-4,
+ lr_decay=0.99,
+ lr_decay_large=0.5,
+ large_decay_epoch=10,
+ workers=4,
+ data_augmentation=False,
+ reverse_augmentation=False,
+ norm=0.01,
+ # InferenceConfig
+ test=1,
+ test_augmentation=False,
+ test_augmentation_flip_hypothesis=False,
+ test_augmentation_FlowAug=False,
+ sample_steps=5,
+ eval_multi_steps=True,
+ eval_sample_steps="1,3,5",
+ num_hypothesis_list="1,3",
+ hypothesis_num=3,
+ guidance_scale=1.5,
+ # AggregationConfig
+ topk=5,
+ exp_temp=0.001,
+ mode="softmax",
+ opt_steps=3,
+ # CheckpointConfig
+ reload=True,
+ model_dir="/tmp",
+ model_weights_path="/tmp/weights.pth",
+ checkpoint="/tmp/ckpt",
+ previous_dir="./pre_trained",
+ num_saved_models=5,
+ previous_best_threshold=50.0,
+ previous_name="best.pth",
+ # RefinementConfig
+ post_refine=False,
+ post_refine_reload=False,
+ previous_post_refine_name="",
+ lr_refine=1e-5,
+ refine=False,
+ reload_refine=False,
+ previous_refine_name="",
+ # OutputConfig
+ create_time="250101",
+ filename="run1",
+ create_file=1,
+ debug=True,
+ folder_name="exp1",
+ sh_file="train.sh",
+ # DemoConfig
+ type="video",
+ path="/tmp/video.mp4",
+ # RuntimeConfig
+ gpu="1",
+ pad=1,
+ single=True,
+ reload_3d=False,
+ )
+ cfg = PipelineConfig.from_namespace(ns)
+
+ # Verify a sample from each group
+ assert cfg.model_cfg.layers == 5
+ assert cfg.model_cfg.channel == 256
+ assert cfg.dataset_cfg.dataset == "rat7m"
+ assert cfg.dataset_cfg.joints_left == [8, 10]
+ assert cfg.training_cfg.train is True
+ assert cfg.training_cfg.nepoch == 100
+ assert cfg.inference_cfg.sample_steps == 5
+ assert cfg.inference_cfg.guidance_scale == pytest.approx(1.5)
+ assert cfg.aggregation_cfg.topk == 5
+ assert cfg.checkpoint_cfg.reload is True
+ assert cfg.checkpoint_cfg.previous_best_threshold == pytest.approx(50.0)
+ assert cfg.refinement_cfg.lr_refine == pytest.approx(1e-5)
+ assert cfg.output_cfg.debug is True
+ assert cfg.demo_cfg.type == "video"
+ assert cfg.runtime_cfg.gpu == "1"
+
+ def test_from_namespace_ignores_unknown_fields(self):
+ """Extra attributes in the namespace that don't match any field are ignored."""
+ ns = argparse.Namespace(
+ layers=3, channel=512, unknown_field="should_be_ignored",
+ )
+ cfg = PipelineConfig.from_namespace(ns)
+ assert cfg.model_cfg.layers == 3
+ assert cfg.model_cfg.channel == 512
+ assert not hasattr(cfg, "unknown_field")
+
+ def test_from_namespace_partial_namespace(self):
+ """A namespace missing some fields uses dataclass defaults for those."""
+ ns = argparse.Namespace(layers=10, gpu="2")
+ cfg = PipelineConfig.from_namespace(ns)
+ assert cfg.model_cfg.layers == 10
+ assert cfg.runtime_cfg.gpu == "2"
+ # Unset fields keep defaults
+ assert cfg.model_cfg.channel == 512
+ assert cfg.training_cfg.lr == pytest.approx(1e-3)
+
+ # -- round-trip: from_namespace ↔ to_dict ---------------------------------
+
+ def test_roundtrip_from_namespace_to_dict(self):
+ """Values fed via from_namespace appear identically in to_dict."""
+ ns = argparse.Namespace(
+ layers=8, channel=1024, dataset="animal3d", lr=2e-4, topk=7, gpu="3",
+ )
+ cfg = PipelineConfig.from_namespace(ns)
+ d = cfg.to_dict()
+ assert d["layers"] == 8
+ assert d["channel"] == 1024
+ assert d["dataset"] == "animal3d"
+ assert d["lr"] == pytest.approx(2e-4)
+ assert d["topk"] == 7
+ assert d["gpu"] == "3"
+
+ def test_to_dict_after_mutation(self):
+ """to_dict reflects in-place mutations on sub-configs."""
+ cfg = PipelineConfig()
+ cfg.training_cfg.lr = 0.123
+ cfg.model_cfg.layers = 99
+ d = cfg.to_dict()
+ assert d["lr"] == pytest.approx(0.123)
+ assert d["layers"] == 99