diff --git a/.gitignore b/.gitignore index 5104331c..90439acb 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,8 @@ htmlcov/ *.pkl *.h5 *.ckpt + +# Excluded directories +pre_trained_models/ +demo/predictions/ +demo/images/ \ No newline at end of file diff --git a/README.md b/README.md index 7c83438a..2b241f65 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ ![Version](https://img.shields.io/badge/python_version-3.10-purple) [![PyPI version](https://badge.fury.io/py/fmpose3d.svg?icon=si%3Apython)](https://badge.fury.io/py/fmpose3d) -[![License: LApache 2.0](https://img.shields.io/badge/License-Apache2.0-blue.svg)](https://www.gnu.org/licenses/apach2.0) +[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0) This is the official implementation of the approach described in the preprint: -[**FMPose3D: monocular 3D pose estimation via flow matching**](http://arxiv.org/abs/2602.05755) +[**FMPose3D: monocular 3D pose estimation via flow matching**](https://arxiv.org/abs/2602.05755) Ti Wang, Xiaohang Yu, Mackenzie Weygandt Mathis @@ -51,7 +51,7 @@ sh vis_in_the_wild.sh ``` The predictions will be saved to folder `demo/predictions`. -

+

## Training and Inference @@ -79,7 +79,7 @@ The training logs, checkpoints, and related files of each training time will be For training on Human3.6M: ```bash -sh /scripts/FMPose3D_train.sh +sh ./scripts/FMPose3D_train.sh ``` ### Inference diff --git a/animals/demo/vis_animals.py b/animals/demo/vis_animals.py index 357cfe80..c9fe4384 100644 --- a/animals/demo/vis_animals.py +++ b/animals/demo/vis_animals.py @@ -8,7 +8,6 @@ """ # SuperAnimal Demo: https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_YOURDATA_SuperAnimal.ipynb -import sys import os import numpy as np import glob @@ -25,8 +24,6 @@ from fmpose3d.animals.common.arguments import opts as parse_args from fmpose3d.common.camera import normalize_screen_coordinates, camera_to_world -sys.path.append(os.getcwd()) - args = parse_args().parse() os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu @@ -334,13 +331,15 @@ def get_pose3D(path, output_dir, type='image'): print(f"args.n_joints: {args.n_joints}, args.out_joints: {args.out_joints}") ## Reload model + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = {} - model['CFM'] = CFM(args).cuda() + model['CFM'] = CFM(args).to(device) model_dict = model['CFM'].state_dict() model_path = args.saved_model_path print(f"Loading model from: {model_path}") - pre_dict = torch.load(model_path) + pre_dict = torch.load(model_path, map_location=device, weights_only=True) for name, key in model_dict.items(): model_dict[name] = pre_dict[name] model['CFM'].load_state_dict(model_dict) @@ -400,7 +399,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir): input_2D = np.expand_dims(input_2D, axis=0) # (1, J, 2) # Convert to tensor format matching visualize_animal_poses.py - input_2D = torch.from_numpy(input_2D.astype('float32')).cuda() # (1, J, 2) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + input_2D = torch.from_numpy(input_2D.astype('float32')).to(device) # (1, J, 2) input_2D = input_2D.unsqueeze(0) # (1, 1, J, 2) # Euler sampler for CFM @@ -418,7 +418,7 @@ def euler_sample(c_2d, y_local, steps, model_3d): # Single inference without flip augmentation # Create 3D random noise with shape (1, 1, J, 3) - y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3).cuda() + y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3, device=device) output_3D = euler_sample(input_2D, y, steps=args.sample_steps, model_3d=model) output_3D = output_3D[0:, args.pad].unsqueeze(1) diff --git a/animals/scripts/main_animal3d.py b/animals/scripts/main_animal3d.py index f93e04cf..04362047 100644 --- a/animals/scripts/main_animal3d.py +++ b/animals/scripts/main_animal3d.py @@ -75,7 +75,7 @@ def step(split, args, actions, dataLoader, model, optimizer=None, epoch=None, st # gt_3D shape: torch.Size([B, J, 4]) (x,y,z + homogeneous coordinate) gt_3D = gt_3D[:,:,:3] # only use x,y,z for 3D ground truth - # [input_2D, gt_3D, batch_cam, vis_3D] = get_varialbe(split, [input_2D, gt_3D, batch_cam, vis_3D]) + # [input_2D, gt_3D, batch_cam, vis_3D] = get_variable(split, [input_2D, gt_3D, batch_cam, vis_3D]) # unsqueeze frame dimension input_2D = input_2D.unsqueeze(1) # (B,F,J,C) @@ -264,15 +264,17 @@ def get_parameter_number(net): test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=int(args.workers), pin_memory=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = {} - model['CFM'] = CFM(args).cuda() + model['CFM'] = 
CFM(args).to(device) if args.reload: model_dict = model['CFM'].state_dict() # Prefer explicit saved_model_path; otherwise fallback to previous_dir glob model_path = args.saved_model_path print(model_path) - pre_dict = torch.load(model_path) + pre_dict = torch.load(model_path, weights_only=True, map_location=device) for name, key in model_dict.items(): model_dict[name] = pre_dict[name] model['CFM'].load_state_dict(model_dict) diff --git a/demo/vis_in_the_wild.py b/demo/vis_in_the_wild.py index 9ca6f1ee..90b2953c 100755 --- a/demo/vis_in_the_wild.py +++ b/demo/vis_in_the_wild.py @@ -7,7 +7,6 @@ Licensed under Apache 2.0 """ -import sys import cv2 import os import numpy as np @@ -16,8 +15,6 @@ from tqdm import tqdm import copy -sys.path.append(os.getcwd()) - # Auto-download checkpoint files if missing from fmpose3d.lib.checkpoint.download_checkpoints import ensure_checkpoints ensure_checkpoints() @@ -28,17 +25,10 @@ args = parse_args().parse() os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu -if getattr(args, 'model_path', ''): - import importlib.util - import pathlib - model_abspath = os.path.abspath(args.model_path) - module_name = pathlib.Path(model_abspath).stem - spec = importlib.util.spec_from_file_location(module_name, model_abspath) - module = importlib.util.module_from_spec(spec) - assert spec.loader is not None - spec.loader.exec_module(module) - CFM = getattr(module, 'Model') - + +from fmpose3d.models import get_model +CFM = get_model(args.model_type) + from fmpose3d.common.camera import * import matplotlib @@ -50,15 +40,27 @@ matplotlib.rcParams['pdf.fonttype'] = 42 matplotlib.rcParams['ps.fonttype'] = 42 -def show2Dpose(kps, img): - connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], - [5, 6], [0, 7], [7, 8], [8, 9], [9, 10], - [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]] +# Shared skeleton definition so 2D/3D segment colors match +SKELETON_CONNECTIONS = [ + [0, 1], [1, 2], [2, 3], [0, 4], [4, 5], + [5, 6], [0, 7], [7, 8], [8, 9], [9, 10], + [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16] +] +# LR mask for skeleton segments: True -> left color, False -> right color +SKELETON_LR = np.array( + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + dtype=bool, +) - LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=bool) +def show2Dpose(kps, img): + connections = SKELETON_CONNECTIONS + LR = SKELETON_LR lcolor = (255, 0, 0) rcolor = (0, 0, 255) + # lcolor = (240, 176, 0) + # rcolor = (240, 176, 0) + thickness = 3 for j,c in enumerate(connections): @@ -67,8 +69,8 @@ def show2Dpose(kps, img): start = list(start) end = list(end) cv2.line(img, (start[0], start[1]), (end[0], end[1]), lcolor if LR[j] else rcolor, thickness) - cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3) - cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3) + # cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3) + # cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3) return img @@ -77,11 +79,13 @@ def show3Dpose(vals, ax): lcolor=(0,0,1) rcolor=(1,0,0) - - I = np.array( [0, 0, 1, 4, 2, 5, 0, 7, 8, 8, 14, 15, 11, 12, 8, 9]) - J = np.array( [1, 4, 2, 5, 3, 6, 7, 8, 14, 11, 15, 16, 12, 13, 9, 10]) - - LR = np.array([0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0], dtype=bool) + # lcolor=(0/255, 176/255, 240/255) + # rcolor=(0/255, 176/255, 240/255) + + + I = np.array([c[0] for c in SKELETON_CONNECTIONS]) + J = np.array([c[1] for c in SKELETON_CONNECTIONS]) + LR = SKELETON_LR 
for i in np.arange( len(I) ): x, y, z = [np.array( [vals[I[i], j], vals[J[i], j]] ) for j in range(3)] @@ -199,7 +203,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir): input_2D = input_2D[np.newaxis, :, :, :, :] - input_2D = torch.from_numpy(input_2D.astype('float32')).cuda() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + input_2D = torch.from_numpy(input_2D.astype('float32')).to(device) N = input_2D.size(0) @@ -215,10 +220,10 @@ def euler_sample(c_2d, y_local, steps, model_3d): ## estimation - y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda() + y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device) output_3D_non_flip = euler_sample(input_2D[:, 0], y, steps=args.sample_steps, model_3d=model) - y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda() + y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device) output_3D_flip = euler_sample(input_2D[:, 1], y_flip, steps=args.sample_steps, model_3d=model) output_3D_flip[:, :, :, 0] *= -1 @@ -266,14 +271,16 @@ def get_pose3D(path, output_dir, type='image'): # args.type = type ## Reload + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = {} - model['CFM'] = CFM(args).cuda() + model['CFM'] = CFM(args).to(device) # if args.reload: model_dict = model['CFM'].state_dict() - model_path = args.saved_model_path + model_path = args.model_weights_path print(model_path) - pre_dict = torch.load(model_path) + pre_dict = torch.load(model_path, map_location=device, weights_only=True) for name, key in model_dict.items(): model_dict[name] = pre_dict[name] model['CFM'].load_state_dict(model_dict) @@ -336,7 +343,7 @@ def get_pose3D(path, output_dir, type='image'): ## save output_dir_pose = output_dir +'pose/' os.makedirs(output_dir_pose, exist_ok=True) - plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.jpg', dpi=200, bbox_inches = 'tight') + plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.png', dpi=200, bbox_inches = 'tight') if __name__ == "__main__": diff --git a/demo/vis_in_the_wild.sh b/demo/vis_in_the_wild.sh index 3909bc3d..a6f3a4d2 100755 --- a/demo/vis_in_the_wild.sh +++ b/demo/vis_in_the_wild.sh @@ -1,21 +1,22 @@ #Test layers=5 -gpu_id=1 +gpu_id=0 sample_steps=3 batch_size=1 sh_file='vis_in_the_wild.sh' -model_path='../pre_trained_models/fmpose_detected2d/model_GAMLP.py' -saved_model_path='../pre_trained_models/fmpose_detected2d/FMpose_36_4972_best.pth' +model_type='fmpose3d' +model_weights_path='../pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth' -# path='./images/image_00068.jpg' # single image -input_images_folder='./images/' # folder containing multiple images +target_path='./images/' # folder containing multiple images +# target_path='./images/xx.png' # single image +# target_path='./videos/xxx.mp4' # video path python3 vis_in_the_wild.py \ --type 'image' \ - --path ${input_images_folder} \ - --saved_model_path "${saved_model_path}" \ - --model_path "${model_path}" \ + --path ${target_path} \ + --model_weights_path "${model_weights_path}" \ + --model_type "${model_type}" \ --sample_steps ${sample_steps} \ --batch_size ${batch_size} \ --layers ${layers} \ diff --git a/fmpose3d/__init__.py b/fmpose3d/__init__.py index 8a9a4716..563a1402 100644 --- a/fmpose3d/__init__.py +++ b/fmpose3d/__init__.py @@ -18,17 +18,49 @@ aggregation_RPEA_joint_level, ) +# Configuration dataclasses +from .common.config import ( + 
FMPose3DConfig, + HRNetConfig, + InferenceConfig, + ModelConfig, + PipelineConfig, +) + +# High-level inference API +from .fmpose3d import ( + FMPose3DInference, + HRNetEstimator, + Pose2DResult, + Pose3DResult, + Source, +) + # Import 2D pose detection utilities from .lib.hrnet.gen_kpts import gen_video_kpts +from .lib.hrnet.hrnet import HRNetPose2d from .lib.preprocess import h36m_coco_format, revise_kpts # Make commonly used classes/functions available at package level __all__ = [ + # Inference API + "FMPose3DInference", + "HRNetEstimator", + "Pose2DResult", + "Pose3DResult", + "Source", + # Configuration + "FMPose3DConfig", + "HRNetConfig", + "InferenceConfig", + "ModelConfig", + "PipelineConfig", # Aggregation methods "average_aggregation", "aggregation_select_single_best_hypothesis_by_2D_error", "aggregation_RPEA_joint_level", # 2D pose detection + "HRNetPose2d", "gen_video_kpts", "h36m_coco_format", "revise_kpts", diff --git a/fmpose3d/aggregation_methods.py b/fmpose3d/aggregation_methods.py index 38d022f9..4898a03b 100644 --- a/fmpose3d/aggregation_methods.py +++ b/fmpose3d/aggregation_methods.py @@ -166,17 +166,13 @@ def aggregation_RPEA_joint_level( dist[:, :, 0] = 0.0 # Convert 2D losses to weights using softmax over top-k hypotheses per joint - tau = float(getattr(args, "weight_softmax_tau", 1.0)) H = dist.size(1) k = int(getattr(args, "topk", None)) - # print("k:", k) - # k = int(H//2)+1 k = max(1, min(k, H)) # top-k smallest distances along hypothesis dim topk_vals, topk_idx = torch.topk(dist, k=k, dim=1, largest=False) # (B,k,J) - # Weight calculation method ; weight_method = 'exp' temp = args.exp_temp max_safe_val = temp * 20 topk_vals_clipped = torch.clamp(topk_vals, max=max_safe_val) diff --git a/fmpose3d/animals/common/arber_dataset.py b/fmpose3d/animals/common/arber_dataset.py index c70bb838..27dba171 100644 --- a/fmpose3d/animals/common/arber_dataset.py +++ b/fmpose3d/animals/common/arber_dataset.py @@ -12,7 +12,6 @@ import glob import os import random -import sys import cv2 import matplotlib.pyplot as plt @@ -23,10 +22,8 @@ from torch.utils.data import Dataset from tqdm import tqdm -sys.path.append(os.path.dirname(sys.path[0])) - -from common.camera import normalize_screen_coordinates -from common.lifter3d import load_camera_params, load_h5_keypoints +from fmpose3d.common.camera import normalize_screen_coordinates +from fmpose3d.animals.common.lifter3d import load_camera_params, load_h5_keypoints class ArberDataset(Dataset): diff --git a/fmpose3d/animals/common/utils.py b/fmpose3d/animals/common/utils.py index d4496625..cdafd8c1 100755 --- a/fmpose3d/animals/common/utils.py +++ b/fmpose3d/animals/common/utils.py @@ -15,7 +15,6 @@ import numpy as np import torch -from torch.autograd import Variable def mpjpe_cal(predicted, target): @@ -220,18 +219,17 @@ def update(self, val, n=1): self.avg = self.sum / self.count -def get_varialbe(split, target): +def get_variable(split, target): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") num = len(target) var = [] if split == "train": for i in range(num): - temp = ( - Variable(target[i], requires_grad=False).contiguous().type(torch.cuda.FloatTensor) - ) + temp = target[i].requires_grad_(False).contiguous().float().to(device) var.append(temp) else: for i in range(num): - temp = Variable(target[i]).contiguous().cuda().type(torch.cuda.FloatTensor) + temp = target[i].contiguous().float().to(device) var.append(temp) return var diff --git a/fmpose3d/common/__init__.py b/fmpose3d/common/__init__.py index 
44082829..c7863030 100644 --- a/fmpose3d/common/__init__.py +++ b/fmpose3d/common/__init__.py @@ -12,6 +12,22 @@ """ from .arguments import opts +from .config import ( + PipelineConfig, + ModelConfig, + FMPose3DConfig, + HRNetConfig, + Pose2DConfig, + DatasetConfig, + TrainingConfig, + InferenceConfig, + AggregationConfig, + CheckpointConfig, + RefinementConfig, + OutputConfig, + DemoConfig, + RuntimeConfig, +) from .h36m_dataset import Human36mDataset from .load_data_hm36 import Fusion from .utils import ( @@ -22,11 +38,25 @@ save_top_N_models, test_calculation, print_error, - get_varialbe, + get_variable, ) __all__ = [ "opts", + "PipelineConfig", + "FMPose3DConfig", + "HRNetConfig", + "Pose2DConfig", + "ModelConfig", + "DatasetConfig", + "TrainingConfig", + "InferenceConfig", + "AggregationConfig", + "CheckpointConfig", + "RefinementConfig", + "OutputConfig", + "DemoConfig", + "RuntimeConfig", "Human36mDataset", "Fusion", "mpjpe_cal", @@ -36,6 +66,6 @@ "save_top_N_models", "test_calculation", "print_error", - "get_varialbe", + "get_variable", ] diff --git a/fmpose3d/common/arguments.py b/fmpose3d/common/arguments.py index 3777a1fd..b94db985 100755 --- a/fmpose3d/common/arguments.py +++ b/fmpose3d/common/arguments.py @@ -53,7 +53,6 @@ def init(self): self.parser.add_argument("-s", "--stride", default=1, type=int) self.parser.add_argument("--gpu", default="0", type=str, help="") self.parser.add_argument("--train", action="store_true") - # self.parser.add_argument('--test', action='store_true') self.parser.add_argument("--test", type=int, default=1) # self.parser.add_argument("--nepoch", type=int, default=41) # self.parser.add_argument( @@ -75,13 +74,15 @@ def init(self): self.parser.add_argument("--model_dir", type=str, default="") # Optional: load model class from a specific file path self.parser.add_argument("--model_path", type=str, default="") + # Model registry name (e.g. 
"fmpose3d"); used instead of --model_path + self.parser.add_argument("--model_type", type=str, default="fmpose3d") + self.parser.add_argument("--model_weights_path", type=str, default="") self.parser.add_argument("--post_refine_reload", action="store_true") self.parser.add_argument("--checkpoint", type=str, default="") self.parser.add_argument( "--previous_dir", type=str, default="./pre_trained_model/pretrained" ) - self.parser.add_argument("--saved_model_path", type=str, default="") self.parser.add_argument("--n_joints", type=int, default=17) self.parser.add_argument("--out_joints", type=int, default=17) @@ -148,7 +149,6 @@ def init(self): # uncertainty-aware aggregation threshold factor self.parser.add_argument("--topk", type=int, default=3) - self.parser.add_argument("--weight_softmax_tau", type=float, default=1.0) self.parser.add_argument("--exp_temp", type=float, default=0.002) self.parser.add_argument("--mode", type=str, default="exp") diff --git a/fmpose3d/common/config.py b/fmpose3d/common/config.py new file mode 100644 index 00000000..b2980e1f --- /dev/null +++ b/fmpose3d/common/config.py @@ -0,0 +1,276 @@ +""" +FMPose3D: monocular 3D Pose Estimation via Flow Matching + +Official implementation of the paper: +"FMPose3D: monocular 3D Pose Estimation via Flow Matching" +by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis +Licensed under Apache 2.0 +""" + +import math +from dataclasses import dataclass, field, fields, asdict +from typing import List + + +# --------------------------------------------------------------------------- +# Dataclass configuration groups +# --------------------------------------------------------------------------- + + +@dataclass +class ModelConfig: + """Model architecture configuration.""" + model_type: str = "fmpose3d" + + +@dataclass +class FMPose3DConfig(ModelConfig): + model: str = "" + model_type: str = "fmpose3d" + layers: int = 3 + channel: int = 512 + d_hid: int = 1024 + token_dim: int = 256 + n_joints: int = 17 + out_joints: int = 17 + in_channels: int = 2 + out_channels: int = 3 + frames: int = 1 + """Optional: load model class from a specific file path.""" + + +@dataclass +class DatasetConfig: + """Dataset and data loading configuration.""" + + dataset: str = "h36m" + keypoints: str = "cpn_ft_h36m_dbb" + root_path: str = "dataset/" + actions: str = "*" + downsample: int = 1 + subset: float = 1.0 + stride: int = 1 + crop_uv: int = 0 + out_all: int = 1 + train_views: List[int] = field(default_factory=lambda: [0, 1, 2, 3]) + test_views: List[int] = field(default_factory=lambda: [0, 1, 2, 3]) + + # Derived / set during parse based on dataset choice + subjects_train: str = "S1,S5,S6,S7,S8" + subjects_test: str = "S9,S11" + root_joint: int = 0 + joints_left: List[int] = field(default_factory=list) + joints_right: List[int] = field(default_factory=list) + + +@dataclass +class TrainingConfig: + """Training hyperparameters and settings.""" + + train: bool = False + nepoch: int = 41 + batch_size: int = 128 + lr: float = 1e-3 + lr_decay: float = 0.95 + lr_decay_large: float = 0.5 + large_decay_epoch: int = 5 + workers: int = 8 + data_augmentation: bool = True + reverse_augmentation: bool = False + norm: float = 0.01 + + +@dataclass +class InferenceConfig: + """Evaluation and testing configuration.""" + + test: int = 1 + test_augmentation: bool = True + test_augmentation_flip_hypothesis: bool = False + test_augmentation_FlowAug: bool = False + sample_steps: int = 3 + eval_multi_steps: bool = False + eval_sample_steps: str = "1,3,5,7,9" + 
num_hypothesis_list: str = "1" + hypothesis_num: int = 1 + guidance_scale: float = 1.0 + + +@dataclass +class AggregationConfig: + """Hypothesis aggregation configuration.""" + + topk: int = 3 + exp_temp: float = 0.002 + mode: str = "exp" + opt_steps: int = 2 + + +@dataclass +class CheckpointConfig: + """Checkpoint loading and saving configuration.""" + + reload: bool = False + model_dir: str = "" + model_weights_path: str = "" + checkpoint: str = "" + previous_dir: str = "./pre_trained_model/pretrained" + num_saved_models: int = 3 + previous_best_threshold: float = math.inf + previous_name: str = "" + + +@dataclass +class RefinementConfig: + """Post-refinement model configuration.""" + + post_refine: bool = False + post_refine_reload: bool = False + previous_post_refine_name: str = "" + lr_refine: float = 1e-5 + refine: bool = False + reload_refine: bool = False + previous_refine_name: str = "" + + +@dataclass +class OutputConfig: + """Output, logging, and file management configuration.""" + + create_time: str = "" + filename: str = "" + create_file: int = 1 + debug: bool = False + folder_name: str = "" + sh_file: str = "" + + +@dataclass +class Pose2DConfig: + """2D pose estimator configuration.""" + pose2d_model: str = "hrnet" + + +@dataclass +class HRNetConfig(Pose2DConfig): + """HRNet 2D pose detector configuration. + + Attributes + ---------- + det_dim : int + YOLO input resolution for human detection (default 416). + num_persons : int + Maximum number of persons to estimate per frame (default 1). + thred_score : float + YOLO object-confidence threshold (default 0.30). + hrnet_cfg_file : str + Path to the HRNet YAML experiment config. When left empty the + bundled ``w48_384x288_adam_lr1e-3.yaml`` is used. + hrnet_weights_path : str + Path to the HRNet ``.pth`` checkpoint. When left empty the + auto-downloaded ``pose_hrnet_w48_384x288.pth`` is used. + """ + pose2d_model: str = "hrnet" + det_dim: int = 416 + num_persons: int = 1 + thred_score: float = 0.30 + hrnet_cfg_file: str = "" + hrnet_weights_path: str = "" + + +@dataclass +class DemoConfig: + """Demo / inference configuration.""" + + type: str = "image" + """Input type: ``'image'`` or ``'video'``.""" + path: str = "demo/images/running.png" + """Path to input file or directory.""" + + +@dataclass +class RuntimeConfig: + """Runtime environment configuration.""" + + gpu: str = "0" + pad: int = 0 # derived: (frames - 1) // 2 + single: bool = False + reload_3d: bool = False + + +# --------------------------------------------------------------------------- +# Composite configuration +# --------------------------------------------------------------------------- + +_SUB_CONFIG_CLASSES = { + "model_cfg": ModelConfig, + "dataset_cfg": DatasetConfig, + "training_cfg": TrainingConfig, + "inference_cfg": InferenceConfig, + "aggregation_cfg": AggregationConfig, + "checkpoint_cfg": CheckpointConfig, + "refinement_cfg": RefinementConfig, + "output_cfg": OutputConfig, + "pose2d_cfg": Pose2DConfig, + "demo_cfg": DemoConfig, + "runtime_cfg": RuntimeConfig, +} + + +@dataclass +class PipelineConfig: + """Top-level configuration for FMPose3D pipeline. 
+ + Groups related settings into sub-configs:: + + config.model_cfg.layers + config.training_cfg.lr + """ + + model_cfg: ModelConfig = field(default_factory=FMPose3DConfig) + dataset_cfg: DatasetConfig = field(default_factory=DatasetConfig) + training_cfg: TrainingConfig = field(default_factory=TrainingConfig) + inference_cfg: InferenceConfig = field(default_factory=InferenceConfig) + aggregation_cfg: AggregationConfig = field(default_factory=AggregationConfig) + checkpoint_cfg: CheckpointConfig = field(default_factory=CheckpointConfig) + refinement_cfg: RefinementConfig = field(default_factory=RefinementConfig) + output_cfg: OutputConfig = field(default_factory=OutputConfig) + pose2d_cfg: Pose2DConfig = field(default_factory=HRNetConfig) + demo_cfg: DemoConfig = field(default_factory=DemoConfig) + runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig) + + # -- construction from argparse namespace --------------------------------- + + @classmethod + def from_namespace(cls, ns) -> "PipelineConfig": + """Build a :class:`PipelineConfig` from an ``argparse.Namespace`` + + Example:: + + args = opts().parse() + cfg = PipelineConfig.from_namespace(args) + """ + raw = vars(ns) if hasattr(ns, "__dict__") else dict(ns) + + def _pick(dc_class, src: dict): + names = {f.name for f in fields(dc_class)} + return dc_class(**{k: v for k, v in src.items() if k in names}) + + kwargs = {} + for group_name, dc_class in _SUB_CONFIG_CLASSES.items(): + if group_name == "model_cfg" and raw.get("model_type", "fmpose3d") == "fmpose3d": + dc_class = FMPose3DConfig + elif group_name == "pose2d_cfg" and raw.get("pose2d_model", "hrnet") == "hrnet": + dc_class = HRNetConfig + kwargs[group_name] = _pick(dc_class, raw) + return cls(**kwargs) + + # -- utilities ------------------------------------------------------------ + + def to_dict(self) -> dict: + """Return a flat dictionary of all configuration values.""" + result = {} + for group_name in _SUB_CONFIG_CLASSES: + result.update(asdict(getattr(self, group_name))) + return result + diff --git a/fmpose3d/common/utils.py b/fmpose3d/common/utils.py index 11cf3747..549ef2bd 100755 --- a/fmpose3d/common/utils.py +++ b/fmpose3d/common/utils.py @@ -15,7 +15,44 @@ import numpy as np import torch -from torch.autograd import Variable + + +def euler_sample( + c_2d: torch.Tensor, + y: torch.Tensor, + steps: int, + model: torch.nn.Module, +) -> torch.Tensor: + """Euler ODE sampler for Conditional Flow Matching at test time. + + Integrates the learned velocity field from *t = 0* to *t = 1* using + ``steps`` uniform Euler steps. + + Parameters + ---------- + c_2d : Tensor + 2-D conditioning input, shape ``(B, F, J, 2)``. + y : Tensor + Initial noise sample (same spatial dims as ``c_2d`` but with 3 + output channels), shape ``(B, F, J, 3)``. + steps : int + Number of Euler integration steps. + model : nn.Module + The velocity-prediction network ``v(c_2d, y, t)``. + + Returns + ------- + Tensor + The denoised 3-D prediction, same shape as *y*. 
+ """ + dt = 1.0 / steps + for s in range(steps): + t_s = torch.full( + (c_2d.size(0), 1, 1, 1), s * dt, device=c_2d.device, dtype=c_2d.dtype + ) + v_s = model(c_2d, y, t_s) + y = y + dt * v_s + return y def deterministic_random(min_value, max_value, data): digest = hashlib.sha256(data.encode()).digest() @@ -186,20 +223,17 @@ def update(self, val, n=1): self.avg = self.sum / self.count -def get_varialbe(split, target): +def get_variable(split, target): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") num = len(target) var = [] if split == "train": for i in range(num): - temp = ( - Variable(target[i], requires_grad=False) - .contiguous() - .type(torch.cuda.FloatTensor) - ) + temp = target[i].requires_grad_(False).contiguous().float().to(device) var.append(temp) else: for i in range(num): - temp = Variable(target[i]).contiguous().cuda().type(torch.cuda.FloatTensor) + temp = target[i].contiguous().float().to(device) var.append(temp) return var diff --git a/fmpose3d/fmpose3d.py b/fmpose3d/fmpose3d.py new file mode 100644 index 00000000..e1992d8e --- /dev/null +++ b/fmpose3d/fmpose3d.py @@ -0,0 +1,658 @@ +""" +FMPose3D: monocular 3D Pose Estimation via Flow Matching + +Official implementation of the paper: +"FMPose3D: monocular 3D Pose Estimation via Flow Matching" +by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis +Licensed under Apache 2.0 +""" + + +from __future__ import annotations + +import copy +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Sequence, Tuple, Union + +import numpy as np +import torch + +from fmpose3d.common.camera import camera_to_world, normalize_screen_coordinates +from fmpose3d.common.utils import euler_sample +from fmpose3d.common.config import ( + FMPose3DConfig, + HRNetConfig, + InferenceConfig, + ModelConfig, +) +from fmpose3d.models import get_model + +#: Progress callback signature: ``(current_step, total_steps) -> None``. +ProgressCallback = Callable[[int, int], None] + + +# Default camera-to-world rotation quaternion (from the demo script). +_DEFAULT_CAM_ROTATION = np.array( + [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088], + dtype="float32", +) + + +# --------------------------------------------------------------------------- +# 2D pose estimator +# --------------------------------------------------------------------------- + + +class HRNetEstimator: + """Default 2D pose estimator: HRNet + YOLO, with COCO→H36M conversion. + + Thin wrapper around :class:`~fmpose3d.lib.hrnet.api.HRNetPose2d` that + adds the COCO → H36M keypoint conversion expected by the 3D lifter. + + Parameters + ---------- + cfg : HRNetConfig + Estimator settings (``det_dim``, ``num_persons``, …). + """ + + def __init__(self, cfg: HRNetConfig | None = None) -> None: + self.cfg = cfg or HRNetConfig() + self._model = None + + def setup_runtime(self) -> None: + """Load YOLO + HRNet models (safe to call more than once).""" + if self._model is not None: + return + + from fmpose3d.lib.hrnet.hrnet import HRNetPose2d + + self._model = HRNetPose2d( + det_dim=self.cfg.det_dim, + num_persons=self.cfg.num_persons, + thred_score=self.cfg.thred_score, + hrnet_cfg_file=self.cfg.hrnet_cfg_file, + hrnet_weights_path=self.cfg.hrnet_weights_path, + ) + self._model.setup() + + def predict( + self, frames: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """Estimate 2D keypoints from image frames and return in H36M format. 
+ + Parameters + ---------- + frames : ndarray + BGR image frames, shape ``(N, H, W, C)``. + + Returns + ------- + keypoints : ndarray + H36M-format 2D keypoints, shape ``(num_persons, N, 17, 2)``. + scores : ndarray + Per-joint confidence scores, shape ``(num_persons, N, 17)``. + """ + from fmpose3d.lib.preprocess import h36m_coco_format, revise_kpts + + self.setup_runtime() + + keypoints, scores = self._model.predict(frames) + + keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores) + # NOTE: revise_kpts is computed for consistency but is NOT applied + # to the returned keypoints, matching the demo script behaviour. + _revised = revise_kpts(keypoints, scores, valid_frames) # noqa: F841 + + return keypoints, scores + + +# --------------------------------------------------------------------------- +# Result containers +# --------------------------------------------------------------------------- + + +@dataclass +class Pose2DResult: + """Container returned by :meth:`FMPose3DInference.prepare_2d`.""" + + keypoints: np.ndarray + """H36M-format 2D keypoints, shape ``(num_persons, num_frames, 17, 2)``.""" + scores: np.ndarray + """Per-joint confidence scores, shape ``(num_persons, num_frames, 17)``.""" + image_size: tuple[int, int] = (0, 0) + """``(height, width)`` of the source frames.""" + + +@dataclass +class Pose3DResult: + """Container returned by :meth:`FMPose3DInference.pose_3d`.""" + + poses_3d: np.ndarray + """Root-relative 3D poses, shape ``(num_frames, 17, 3)``.""" + poses_3d_world: np.ndarray + """World-coordinate 3D poses, shape ``(num_frames, 17, 3)``.""" + + +#: Accepted source types for :meth:`FMPose3DInference.predict`. +#: +#: * ``str`` or ``Path`` – path to an image file or directory of images. +#: * ``np.ndarray`` – a single frame ``(H, W, C)`` or batch ``(N, H, W, C)``. +#: * ``list`` – a list of file paths or a list of ``(H, W, C)`` arrays. +Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]] + + +@dataclass +class _IngestedInput: + """Normalised result of :meth:`FMPose3DInference._ingest_input`. + + Always contains a batch of BGR frames as a numpy array, regardless + of the original source type. + """ + + frames: np.ndarray + """BGR image frames, shape ``(N, H, W, C)``.""" + image_size: tuple[int, int] + """``(height, width)`` of the source frames.""" + + +# --------------------------------------------------------------------------- +# Main inference class +# --------------------------------------------------------------------------- + + +class FMPose3DInference: + """High-level, two-step inference API for FMPose3D. + + Typical workflow:: + + api = FMPose3DInference(model_weights_path="weights.pth") + result_2d = api.prepare_2d("photo.jpg") + result_3d = api.pose_3d(result_2d.keypoints, image_size=(H, W)) + + Parameters + ---------- + model_cfg : ModelConfig, optional + Model architecture settings (layers, channels, …). + Defaults to :class:`~fmpose3d.common.config.FMPose3DConfig` defaults. + inference_cfg : InferenceConfig, optional + Inference settings (sample_steps, test_augmentation, …). + Defaults to :class:`~fmpose3d.common.config.InferenceConfig` defaults. + model_weights_path : str + Path to a ``.pth`` checkpoint for the 3D lifting model. + If empty the model is created but **not** loaded with weights. + device : str or torch.device, optional + Compute device. ``None`` (default) picks CUDA when available. 
+ """ + + # H36M joint indices for left / right flip augmentation + _JOINTS_LEFT: list[int] = [4, 5, 6, 11, 12, 13] + _JOINTS_RIGHT: list[int] = [1, 2, 3, 14, 15, 16] + + _IMAGE_EXTENSIONS: set[str] = { + ".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp", + } + + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ + + def __init__( + self, + model_cfg: ModelConfig | None = None, + inference_cfg: InferenceConfig | None = None, + model_weights_path: str = "", + device: str | torch.device | None = None, + ) -> None: + self.model_cfg = model_cfg or FMPose3DConfig() + self.inference_cfg = inference_cfg or InferenceConfig() + self.model_weights_path = model_weights_path + + # Resolve device and padding configuration + self._device: torch.device | None = self._resolve_device(device) + self._pad: int = self._resolve_pad() + + # Lazy-loaded models (populated by setup_runtime) + self._model_3d: torch.nn.Module | None = None + self._estimator_2d: HRNetEstimator | None = None + + def setup_runtime(self) -> None: + """Initialise all runtime components on first use. + + Called automatically when the API is used for the first time. + Loads the 2D estimator, the 3D lifting model, and the model + weights in sequence. + """ + self._setup_estimator_2d() + self._setup_model() + self._load_weights() + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @torch.no_grad() + def predict( + self, + source: Source, + *, + camera_rotation: np.ndarray | None = _DEFAULT_CAM_ROTATION, + seed: int | None = None, + progress: ProgressCallback | None = None, + ) -> Pose3DResult: + """End-to-end prediction: 2D pose estimation → 3D lifting. + + Convenience wrapper that calls :meth:`prepare_2d` then + :meth:`pose_3d`. + + Parameters + ---------- + source : Source + Input to process. Accepts a file path (``str`` / ``Path``), + a directory of images, a numpy array ``(H, W, C)`` for a + single frame, ``(N, H, W, C)`` for a batch, or a list of + paths / arrays. See :data:`Source` for the full type. + Video files are **not** supported and will raise + :class:`NotImplementedError`. + camera_rotation : ndarray or None + Length-4 quaternion for the camera-to-world rotation. + See :meth:`pose_3d` for details. + seed : int or None + Deterministic seed for the 3D sampling step. + See :meth:`pose_3d` for details. + progress : ProgressCallback or None + Optional ``(current_step, total_steps)`` callback. Forwarded + to the :meth:`pose_3d` step (per-frame reporting). + + Returns + ------- + Pose3DResult + Root-relative and world-coordinate 3D poses. + """ + result_2d = self.prepare_2d(source) + return self.pose_3d( + result_2d.keypoints, + result_2d.image_size, + camera_rotation=camera_rotation, + seed=seed, + progress=progress, + ) + + @torch.no_grad() + def prepare_2d( + self, + source: Source, + progress: ProgressCallback | None = None, + ) -> Pose2DResult: + """Estimate 2D poses using HRNet + YOLO. + + The estimator is set up lazily by :meth:`setup_runtime` on first + call. + + Parameters + ---------- + source : Source + Input to process. Accepts a file path (``str`` / ``Path``), + a directory of images, a numpy array ``(H, W, C)`` for a + single frame, ``(N, H, W, C)`` for a batch, or a list of + paths / arrays. See :data:`Source` for the full type. 
+ progress : ProgressCallback or None + Optional ``(current_step, total_steps)`` callback invoked + before and after the 2D estimation step. + + Returns + ------- + Pose2DResult + H36M-format 2D keypoints and per-joint scores. The result + also carries ``image_size`` so it can be forwarded directly + to :meth:`pose_3d`. + """ + ingested = self._ingest_input(source) + self.setup_runtime() + if progress: + progress(0, 1) + keypoints, scores = self._estimator_2d.predict(ingested.frames) + if progress: + progress(1, 1) + return Pose2DResult( + keypoints=keypoints, + scores=scores, + image_size=ingested.image_size, + ) + + @torch.no_grad() + def pose_3d( + self, + keypoints_2d: np.ndarray, + image_size: tuple[int, int], + *, + camera_rotation: np.ndarray | None = _DEFAULT_CAM_ROTATION, + seed: int | None = None, + progress: ProgressCallback | None = None, + ) -> Pose3DResult: + """Lift 2D keypoints to 3D using the flow-matching model. + + The pipeline exactly mirrors ``demo/vis_in_the_wild.py``'s + ``get_3D_pose_from_image``: normalise screen coordinates, build a + flip-augmented conditioning pair, run two independent Euler ODE + integrations (each with its own noise sample), un-flip and average, + zero the root joint, then convert to world coordinates. + + Parameters + ---------- + keypoints_2d : ndarray + 2D keypoints returned by :meth:`prepare_2d`. Accepted shapes: + + * ``(num_persons, num_frames, 17, 2)`` – first person is used. + * ``(num_frames, 17, 2)`` – treated as a single person. + image_size : tuple of (int, int) + ``(height, width)`` of the source image / video frames. + camera_rotation : ndarray or None + Length-4 quaternion for the camera-to-world rotation applied + to produce ``poses_3d_world``. Defaults to the rotation used + in the official demo. Pass ``None`` to skip the transform + (``poses_3d_world`` will equal ``poses_3d``). + seed : int or None + If given, ``torch.manual_seed(seed)`` is called before + sampling so that results are fully reproducible. Use the + same seed in the demo script (by inserting + ``torch.manual_seed(seed)`` before the ``torch.randn`` calls) + to obtain bit-identical results. + progress : ProgressCallback or None + Optional ``(current_step, total_steps)`` callback invoked + after each frame is lifted to 3D. + + Returns + ------- + Pose3DResult + Root-relative and world-coordinate 3D poses. 
+ """ + self.setup_runtime() + model = self._model_3d + h, w = image_size + steps = self.inference_cfg.sample_steps + use_flip = self.inference_cfg.test_augmentation + jl = self._JOINTS_LEFT + jr = self._JOINTS_RIGHT + + # Optional deterministic seeding + if seed is not None: + torch.manual_seed(seed) + + # Normalise input shape to (num_frames, 17, 2) + if keypoints_2d.ndim == 4: + kpts = keypoints_2d[0] # first person + elif keypoints_2d.ndim == 3: + kpts = keypoints_2d + else: + raise ValueError( + f"Expected keypoints_2d with 3 or 4 dims, got {keypoints_2d.ndim}" + ) + + num_frames = kpts.shape[0] + all_poses_3d: list[np.ndarray] = [] + all_poses_world: list[np.ndarray] = [] + + if progress: + progress(0, num_frames) + + for i in range(num_frames): + frame_kpts = kpts[i : i + 1] # (1, 17, 2) + + # Normalise to [-1, 1] range (same as demo) + normed = normalize_screen_coordinates(frame_kpts, w=w, h=h) + + if use_flip: + # -- build flip-augmented conditioning (matches demo exactly) -- + normed_flip = copy.deepcopy(normed) + normed_flip[:, :, 0] *= -1 + normed_flip[:, jl + jr] = normed_flip[:, jr + jl] + input_2d = np.concatenate( + (np.expand_dims(normed, axis=0), np.expand_dims(normed_flip, axis=0)), + 0, + ) # (2, F, J, 2) + input_2d = input_2d[np.newaxis, :, :, :, :] # (1, 2, F, J, 2) + input_t = torch.from_numpy(input_2d.astype("float32")).to(self.device) + + # -- two independent Euler ODE runs (matches demo exactly) -- + y = torch.randn( + input_t.size(0), input_t.size(2), input_t.size(3), 3, + device=self.device, + ) + output_3d_non_flip = euler_sample( + input_t[:, 0], y, steps, model, + ) + + y_flip = torch.randn( + input_t.size(0), input_t.size(2), input_t.size(3), 3, + device=self.device, + ) + output_3d_flip = euler_sample( + input_t[:, 1], y_flip, steps, model, + ) + + # -- un-flip & average (matches demo exactly) -- + output_3d_flip[:, :, :, 0] *= -1 + output_3d_flip[:, :, jl + jr, :] = output_3d_flip[ + :, :, jr + jl, : + ] + + output = (output_3d_non_flip + output_3d_flip) / 2 + else: + input_2d = normed[np.newaxis] # (1, F, J, 2) + input_t = torch.from_numpy(input_2d.astype("float32")).to(self.device) + y = torch.randn( + input_t.size(0), input_t.size(1), input_t.size(2), 3, + device=self.device, + ) + output = euler_sample(input_t, y, steps, model) + + # Extract single-frame result → (17, 3) (matches demo exactly) + output = output[0:, self._pad].unsqueeze(1) + output[:, :, 0, :] = 0 # root-relative + pose_np = output[0, 0].cpu().detach().numpy() + all_poses_3d.append(pose_np) + + # Camera-to-world transform (matches demo exactly) + if camera_rotation is not None: + pose_world = camera_to_world(pose_np, R=camera_rotation, t=0) + pose_world[:, 2] -= np.min(pose_world[:, 2]) + else: + pose_world = pose_np.copy() + all_poses_world.append(pose_world) + + if progress: + progress(i + 1, num_frames) + + poses_3d = np.stack(all_poses_3d, axis=0) # (num_frames, 17, 3) + poses_world = np.stack(all_poses_world, axis=0) # (num_frames, 17, 3) + + return Pose3DResult(poses_3d=poses_3d, poses_3d_world=poses_world) + + # ------------------------------------------------------------------ + # Private helpers – device & padding + # ------------------------------------------------------------------ + + def _resolve_device(self, device) -> None: + """Set ``self.device`` from the constructor argument.""" + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = torch.device(device) + + def _resolve_pad(self) -> int: + """Derived 
from frames setting (single-frame models ⇒ pad=0).""" + return (self.model_cfg.frames - 1) // 2 + + # ------------------------------------------------------------------ + # Private helpers – model loading + # ------------------------------------------------------------------ + + def _setup_estimator_2d(self) -> HRNetEstimator: + """Initialise the HRNet 2D pose estimator on first use.""" + if self._estimator_2d is None: + self._estimator_2d = HRNetEstimator() + return self._estimator_2d + + def _setup_model(self) -> torch.nn.Module: + """Initialise the 3D lifting model on first use.""" + if self._model_3d is None: + ModelClass = get_model(self.model_cfg.model_type) + self._model_3d = ModelClass(self.model_cfg).to(self.device) + self._model_3d.eval() + return self._model_3d + + def _load_weights(self) -> None: + """Load checkpoint weights into ``self._model_3d``. + + Mirrors the demo's loading strategy: iterate over the model's own + state-dict keys and pull matching entries from the checkpoint so that + extra keys in the checkpoint are silently ignored. + """ + if not self.model_weights_path: + raise ValueError( + "No model weights path provided. Pass 'model_weights_path' " + "to the FMPose3DInference constructor." + ) + weights = Path(self.model_weights_path) + if not weights.exists(): + raise ValueError( + f"Model weights file not found: {weights}. " + "Please provide a valid path to a .pth checkpoint file in the " + "FMPose3DInference constructor." + ) + if self._model_3d is None: + raise ValueError("Model not initialised. Call setup_runtime() first.") + pre_dict = torch.load( + self.model_weights_path, + weights_only=True, + map_location=self.device, + ) + model_dict = self._model_3d.state_dict() + for name in model_dict: + if name in pre_dict: + model_dict[name] = pre_dict[name] + self._model_3d.load_state_dict(model_dict) + + # ------------------------------------------------------------------ + # Private helpers – input resolution + # ------------------------------------------------------------------ + + def _ingest_input(self, source: Source) -> _IngestedInput: + """Normalise *source* into a ``(N, H, W, C)`` frames array. + + Accepted *source* values: + + * **str / Path** – path to a single image or a directory of images. + * **ndarray (H, W, C)** – a single BGR frame. + * **ndarray (N, H, W, C)** – a batch of BGR frames. + * **list of str/Path** – multiple image file paths. + * **list of ndarray** – multiple ``(H, W, C)`` BGR frames. + + Video files are not yet supported and will raise + :class:`NotImplementedError`. + + Parameters + ---------- + source : Source + The input to resolve. + + Returns + ------- + _IngestedInput + Contains ``frames`` as ``(N, H, W, C)`` and ``image_size`` + as ``(height, width)``. 
+ """ + import cv2 + + # -- numpy array (single frame or batch) ---------------------------- + if isinstance(source, np.ndarray): + if source.ndim == 3: + frames = source[np.newaxis] # (1, H, W, C) + elif source.ndim == 4: + frames = source + else: + raise ValueError( + f"Expected ndarray with 3 (H,W,C) or 4 (N,H,W,C) dims, " + f"got {source.ndim}" + ) + h, w = frames.shape[1], frames.shape[2] + return _IngestedInput(frames=frames, image_size=(h, w)) + + # -- list / sequence ------------------------------------------------ + if isinstance(source, (list, tuple)): + if len(source) == 0: + raise ValueError("Empty source list.") + + first = source[0] + + # List of arrays + if isinstance(first, np.ndarray): + frames = np.stack(list(source), axis=0) + h, w = frames.shape[1], frames.shape[2] + return _IngestedInput(frames=frames, image_size=(h, w)) + + # List of paths + if isinstance(first, (str, Path)): + loaded = [] + for p in source: + p = Path(p) + self._check_not_video(p) + img = cv2.imread(str(p)) + if img is None: + raise FileNotFoundError( + f"Could not read image: {p}" + ) + loaded.append(img) + frames = np.stack(loaded, axis=0) + h, w = frames.shape[1], frames.shape[2] + return _IngestedInput(frames=frames, image_size=(h, w)) + + raise TypeError( + f"Unsupported element type in source list: {type(first)}" + ) + + # -- str / Path (file or directory) --------------------------------- + p = Path(source) + if not p.exists(): + raise FileNotFoundError(f"Source path does not exist: {p}") + + self._check_not_video(p) + + if p.is_dir(): + images = sorted( + f for f in p.iterdir() + if f.suffix.lower() in self._IMAGE_EXTENSIONS + ) + if not images: + raise FileNotFoundError( + f"No image files found in directory: {p}" + ) + loaded = [] + for img_path in images: + img = cv2.imread(str(img_path)) + if img is None: + raise FileNotFoundError( + f"Could not read image: {img_path}" + ) + loaded.append(img) + frames = np.stack(loaded, axis=0) + h, w = frames.shape[1], frames.shape[2] + return _IngestedInput(frames=frames, image_size=(h, w)) + + # Single image file + img = cv2.imread(str(p)) + if img is None: + raise FileNotFoundError(f"Could not read image: {p}") + frames = img[np.newaxis] # (1, H, W, C) + h, w = frames.shape[1], frames.shape[2] + return _IngestedInput(frames=frames, image_size=(h, w)) + + def _check_not_video(self, p: Path) -> None: + """Raise :class:`NotImplementedError` if *p* looks like a video.""" + _VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv"} + if p.is_file() and p.suffix.lower() in _VIDEO_EXTS: + raise NotImplementedError( + f"Video input is not yet supported (got {p}). " + "Please extract frames and pass them as image paths or arrays." + ) diff --git a/fmpose3d/lib/hrnet/hrnet.py b/fmpose3d/lib/hrnet/hrnet.py new file mode 100644 index 00000000..1368b30b --- /dev/null +++ b/fmpose3d/lib/hrnet/hrnet.py @@ -0,0 +1,282 @@ +""" +FMPose3D: monocular 3D Pose Estimation via Flow Matching + +Official implementation of the paper: +"FMPose3D: monocular 3D Pose Estimation via Flow Matching" +by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis +Licensed under Apache 2.0 +""" + +""" +FMPose3D – clean HRNet 2D pose estimation API. + +Provides :class:`HRNetPose2d`, a self-contained wrapper around the +HRNet + YOLO detection pipeline that accepts numpy arrays directly +(no file I/O, no argparse, no global yacs config leaking out). 
+ +Usage:: + + api = HRNetPose2d(det_dim=416, num_persons=1) + api.setup() # loads YOLO + HRNet weights + keypoints, scores = api.predict(frames) # (M, N, 17, 2), (M, N, 17) +""" + +from __future__ import annotations + +import copy +import os.path as osp +from collections import OrderedDict +from typing import Tuple + +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +from fmpose3d.lib.checkpoint.download_checkpoints import ( + ensure_checkpoints, + get_checkpoint_path, +) + + +class HRNetPose2d: + """Self-contained 2D pose estimator (YOLO detector + HRNet). + + A self-contained HRNet 2D pose estimator that accepts numpy arrays directly. + It serves as alternative to the gen_video_kpts function in fmpose3d/lib/hrnet/gen_kpts.py, + which generates 2D keypoints from a video file. + + Parameters + ---------- + det_dim : int + YOLO input resolution (default 416). + num_persons : int + Maximum number of persons to track per frame (default 1). + thred_score : float + YOLO object-confidence threshold (default 0.30). + hrnet_cfg_file : str + Path to the HRNet YAML experiment config. Empty string (default) + uses the bundled ``w48_384x288_adam_lr1e-3.yaml``. + hrnet_weights_path : str + Path to the HRNet ``.pth`` checkpoint. Empty string (default) + uses the auto-downloaded ``pose_hrnet_w48_384x288.pth``. + """ + + def __init__( + self, + det_dim: int = 416, + num_persons: int = 1, + thred_score: float = 0.30, + hrnet_cfg_file: str = "", + hrnet_weights_path: str = "", + ) -> None: + self.det_dim = det_dim + self.num_persons = num_persons + self.thred_score = thred_score + self.hrnet_cfg_file = hrnet_cfg_file + self.hrnet_weights_path = hrnet_weights_path + + # Populated by setup() + self._human_model = None + self._pose_model = None + self._people_sort = None + self._hrnet_cfg = None # frozen yacs CfgNode used by PreProcess / get_final_preds + + # ------------------------------------------------------------------ + # Setup + # ------------------------------------------------------------------ + + @property + def is_ready(self) -> bool: + """``True`` once :meth:`setup` has been called.""" + return self._human_model is not None + + def setup(self) -> "HRNetPose2d": + """Load YOLO detector and HRNet pose model. + + Can safely be called more than once (subsequent calls are no-ops). + + Returns ``self`` so you can write ``api = HRNetPose2d().setup()``. + """ + if self.is_ready: + return self + + ensure_checkpoints() + + # --- resolve paths --------------------------------------------------- + hrnet_cfg_file = self.hrnet_cfg_file + if not hrnet_cfg_file: + hrnet_cfg_file = osp.join( + osp.dirname(osp.abspath(__file__)), + "experiments", + "w48_384x288_adam_lr1e-3.yaml", + ) + + hrnet_weights = self.hrnet_weights_path + if not hrnet_weights: + hrnet_weights = get_checkpoint_path("pose_hrnet_w48_384x288.pth") + + # --- build internal yacs config (kept private) ----------------------- + from fmpose3d.lib.hrnet.lib.config import cfg as _global_cfg + from fmpose3d.lib.hrnet.lib.config import update_config as _update_cfg + from types import SimpleNamespace + + _global_cfg.defrost() + _update_cfg( + _global_cfg, + SimpleNamespace(cfg=hrnet_cfg_file, opts=[], modelDir=hrnet_weights), + ) + # Snapshot the frozen cfg so we can pass it to PreProcess / get_final_preds. 
+ self._hrnet_cfg = _global_cfg + + # cudnn tuning + cudnn.benchmark = self._hrnet_cfg.CUDNN.BENCHMARK + cudnn.deterministic = self._hrnet_cfg.CUDNN.DETERMINISTIC + cudnn.enabled = self._hrnet_cfg.CUDNN.ENABLED + + # --- load models ----------------------------------------------------- + from fmpose3d.lib.yolov3.human_detector import load_model as _yolo_load + from fmpose3d.lib.sort.sort import Sort + + self._human_model = _yolo_load(inp_dim=self.det_dim) + self._pose_model = self._load_hrnet(self._hrnet_cfg) + self._people_sort = Sort(min_hits=0) + + return self + + # ------------------------------------------------------------------ + # Prediction + # ------------------------------------------------------------------ + + def predict( + self, frames: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """Estimate 2D keypoints for a batch of BGR frames. + + Parameters + ---------- + frames : ndarray, shape ``(N, H, W, C)`` + BGR images. A single frame ``(H, W, C)`` is also accepted + and will be treated as a batch of one. + + Returns + ------- + keypoints : ndarray, shape ``(num_persons, N, 17, 2)`` + COCO-format 2D keypoints in pixel coordinates. + scores : ndarray, shape ``(num_persons, N, 17)`` + Per-joint confidence scores. + """ + if not self.is_ready: + self.setup() + + if frames.ndim == 3: + frames = frames[np.newaxis] + + kpts_result = [] + scores_result = [] + + for i in range(frames.shape[0]): + kpts, sc = self._estimate_frame(frames[i]) + kpts_result.append(kpts) + scores_result.append(sc) + + keypoints = np.array(kpts_result) # (N, M, 17, 2) + scores = np.array(scores_result) # (N, M, 17) + + # (N, M, 17, 2) → (M, N, 17, 2) + keypoints = keypoints.transpose(1, 0, 2, 3) + # (N, M, 17) → (M, N, 17) + scores = scores.transpose(1, 0, 2) + + return keypoints, scores + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _load_hrnet(config): + """Instantiate HRNet and load checkpoint weights.""" + from fmpose3d.lib.hrnet.lib.models import pose_hrnet + + model = pose_hrnet.get_pose_net(config, is_train=False) + if torch.cuda.is_available(): + model = model.cuda() + + state_dict = torch.load(config.OUTPUT_DIR, weights_only=True) + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_state_dict[k] = v + model.load_state_dict(new_state_dict) + model.eval() + return model + + def _estimate_frame( + self, frame: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """Run detection + pose estimation on a single BGR frame. 
+ + Returns + ------- + kpts : ndarray, shape ``(num_persons, 17, 2)`` + scores : ndarray, shape ``(num_persons, 17)`` + """ + from fmpose3d.lib.yolov3.human_detector import yolo_human_det + from fmpose3d.lib.hrnet.lib.utils.utilitys import PreProcess + from fmpose3d.lib.hrnet.lib.utils.inference import get_final_preds + + num_persons = self.num_persons + + bboxs, det_scores = yolo_human_det( + frame, self._human_model, reso=self.det_dim, confidence=self.thred_score, + ) + + if bboxs is None or not bboxs.any(): + # No detection – return zeros + kpts = np.zeros((num_persons, 17, 2), dtype=np.float32) + scores = np.zeros((num_persons, 17), dtype=np.float32) + return kpts, scores + + # Track + people_track = self._people_sort.update(bboxs) + + if people_track.shape[0] == 1: + people_track_ = people_track[-1, :-1].reshape(1, 4) + elif people_track.shape[0] >= 2: + people_track_ = people_track[-num_persons:, :-1].reshape(num_persons, 4) + people_track_ = people_track_[::-1] + else: + kpts = np.zeros((num_persons, 17, 2), dtype=np.float32) + scores = np.zeros((num_persons, 17), dtype=np.float32) + return kpts, scores + + track_bboxs = [] + for bbox in people_track_: + bbox = [round(i, 2) for i in list(bbox)] + track_bboxs.append(bbox) + + with torch.no_grad(): + inputs, origin_img, center, scale = PreProcess( + frame, track_bboxs, self._hrnet_cfg, num_persons, + ) + inputs = inputs[:, [2, 1, 0]] + + if torch.cuda.is_available(): + inputs = inputs.cuda() + output = self._pose_model(inputs) + + preds, maxvals = get_final_preds( + self._hrnet_cfg, + output.clone().cpu().numpy(), + np.asarray(center), + np.asarray(scale), + ) + + kpts = np.zeros((num_persons, 17, 2), dtype=np.float32) + scores = np.zeros((num_persons, 17), dtype=np.float32) + for i, kpt in enumerate(preds): + kpts[i] = kpt + for i, score in enumerate(maxvals): + scores[i] = score.squeeze() + + return kpts, scores + diff --git a/fmpose3d/models/__init__.py b/fmpose3d/models/__init__.py index 5b94df4d..b9dc64a4 100644 --- a/fmpose3d/models/__init__.py +++ b/fmpose3d/models/__init__.py @@ -11,10 +11,16 @@ FMPose3D models. """ -from .graph_frames import Graph -from .model_GAMLP import Model +from .base_model import BaseModel, register_model, get_model, list_models + +# Import model subpackages so their @register_model decorators execute. +from .fmpose3d import Graph, Model __all__ = [ + "BaseModel", + "register_model", + "get_model", + "list_models", "Graph", "Model", ] diff --git a/fmpose3d/models/base_model.py b/fmpose3d/models/base_model.py new file mode 100644 index 00000000..ddf06737 --- /dev/null +++ b/fmpose3d/models/base_model.py @@ -0,0 +1,114 @@ +""" +FMPose3D: monocular 3D Pose Estimation via Flow Matching + +Official implementation of the paper: +"FMPose3D: monocular 3D Pose Estimation via Flow Matching" +by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis +Licensed under Apache 2.0 +""" + +from abc import ABC, abstractmethod +import warnings + +import torch +import torch.nn as nn + + +# --------------------------------------------------------------------------- +# Model registry +# --------------------------------------------------------------------------- + +_MODEL_REGISTRY: dict[str, type["BaseModel"]] = {} + + +def register_model(name: str): + """Class decorator that registers a model under *name*. + + Usage:: + + @register_model("my_model") + class MyModel(BaseModel): + ... + + The model can then be retrieved with :func:`get_model`. 
+ """ + + def decorator(cls: type["BaseModel"]) -> type["BaseModel"]: + if name in _MODEL_REGISTRY: + warnings.warn( + f"Model '{name}' is already registered " + f"(existing: {_MODEL_REGISTRY[name].__qualname__}, " + f"new: {cls.__qualname__})" + ) + # raise ValueError( + # f"Model '{name}' is already registered " + # f"(existing: {_MODEL_REGISTRY[name].__qualname__}, " + # f"new: {cls.__qualname__})" + # ) + _MODEL_REGISTRY[name] = cls + return cls + + return decorator + + +def get_model(name: str) -> type["BaseModel"]: + """Return the model class registered under *name*. + + Raises :class:`KeyError` with a helpful message when the name is unknown. + """ + if name not in _MODEL_REGISTRY: + available = ", ".join(sorted(_MODEL_REGISTRY)) or "(none)" + raise KeyError( + f"Unknown model '{name}'. Available models: {available}" + ) + return _MODEL_REGISTRY[name] + + +def list_models() -> list[str]: + """Return a sorted list of all registered model names.""" + return sorted(_MODEL_REGISTRY) + + +# --------------------------------------------------------------------------- +# Base model +# --------------------------------------------------------------------------- + + +class BaseModel(ABC, nn.Module): + """Abstract base class for all FMPose3D lifting models. + + Every model must accept a single *args* namespace / object in its + constructor and implement :meth:`forward` with the signature below. + + Parameters expected on *args* (at minimum): + - ``channel`` – embedding dimension + - ``layers`` – number of transformer / GCN blocks + - ``d_hid`` – hidden MLP dimension + - ``token_dim`` – token dimension + - ``n_joints`` – number of body joints + """ + + @abstractmethod + def __init__(self, args) -> None: + super().__init__() + + @abstractmethod + def forward( + self, + pose_2d: torch.Tensor, + y_t: torch.Tensor, + t: torch.Tensor, + ) -> torch.Tensor: + """Predict the velocity field for flow matching. + + Args: + pose_2d: 2D keypoints, shape ``(B, F, J, 2)``. + y_t: Noisy 3D hypothesis at time *t*, shape ``(B, F, J, 3)``. + t: Diffusion / flow time, shape ``(B, F, 1, 1)`` with values + in ``[0, 1]``. + + Returns: + Predicted velocity ``v``, shape ``(B, F, J, 3)``. + """ + ... + diff --git a/fmpose3d/models/fmpose3d/__init__.py b/fmpose3d/models/fmpose3d/__init__.py new file mode 100644 index 00000000..9cd972b9 --- /dev/null +++ b/fmpose3d/models/fmpose3d/__init__.py @@ -0,0 +1,21 @@ +""" +FMPose3D: monocular 3D Pose Estimation via Flow Matching + +Official implementation of the paper: +"FMPose3D: monocular 3D Pose Estimation via Flow Matching" +by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis +Licensed under Apache 2.0 +""" + +""" +FMPose3D model subpackage. 
+""" + +from .graph_frames import Graph +from .model_GAMLP import Model + +__all__ = [ + "Graph", + "Model", +] + diff --git a/fmpose3d/models/graph_frames.py b/fmpose3d/models/fmpose3d/graph_frames.py old mode 100755 new mode 100644 similarity index 99% rename from fmpose3d/models/graph_frames.py rename to fmpose3d/models/fmpose3d/graph_frames.py index 7aa2c391..0a26b140 --- a/fmpose3d/models/graph_frames.py +++ b/fmpose3d/models/fmpose3d/graph_frames.py @@ -207,4 +207,5 @@ def normalize_undigraph(A): if __name__=="__main__": graph = Graph('hm36_gt', 'spatial', 1) print(graph.A.shape) - # print(graph) \ No newline at end of file + # print(graph) + diff --git a/fmpose3d/models/model_GAMLP.py b/fmpose3d/models/fmpose3d/model_GAMLP.py similarity index 88% rename from fmpose3d/models/model_GAMLP.py rename to fmpose3d/models/fmpose3d/model_GAMLP.py index c0c6b46e..66579001 100644 --- a/fmpose3d/models/model_GAMLP.py +++ b/fmpose3d/models/fmpose3d/model_GAMLP.py @@ -7,15 +7,14 @@ Licensed under Apache 2.0 """ -import sys -sys.path.append("..") import torch import torch.nn as nn import math from einops import rearrange -from fmpose3d.models.graph_frames import Graph +from fmpose3d.models.fmpose3d.graph_frames import Graph +from fmpose3d.models.base_model import BaseModel, register_model from functools import partial -from einops import rearrange, repeat +from einops import rearrange from timm.models.layers import DropPath class TimeEmbedding(nn.Module): @@ -36,8 +35,6 @@ def forward(self, t: torch.Tensor) -> torch.Tensor: b, f = t.shape[0], t.shape[1] half_dim = self.dim // 2 - # Gaussian Fourier features: sin(2π B t), cos(2π B t) - # t: (B,F,1,1) -> (B,F,1,1,1) -> broadcast with (1,1,1,1,half_dim) angles = (2 * math.pi) * t.to(torch.float32).unsqueeze(-1) * self.B.view(1, 1, 1, 1, half_dim) sin = torch.sin(angles) cos = torch.cos(angles) @@ -206,27 +203,28 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay self.fc1 = nn.Linear(dim_in, dim_hid) self.act = act_layer() self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity() - self.fc5 = nn.Linear(dim_hid, dim_out) + self.fc2= nn.Linear(dim_hid, dim_out) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) - x = self.fc5(x) + x = self.fc2(x) return x -class Model(nn.Module): +@register_model("fmpose3d") +class Model(BaseModel): def __init__(self, args): - super().__init__() + super().__init__(args) ## GCN self.graph = Graph('hm36_gt', 'spatial', pad=1) # Register as buffer (not parameter) to follow module device automatically self.register_buffer('A', torch.tensor(self.graph.A, dtype=torch.float32)) - self.t_embed_dim = 16 + self.t_embed_dim = 32 self.time_embed = TimeEmbedding(self.t_embed_dim, hidden_dim=64) - self.encoder_pose_2d = encoder(2, args.channel//2, args.channel//2-self.t_embed_dim//2) - self.encoder_y_t = encoder(3, args.channel//2, args.channel//2-self.t_embed_dim//2) + + self.encoder = encoder(2 + 3 + self.t_embed_dim, args.channel//2, args.channel) self.FMPose3D = FMPose3D(args.layers, args.channel, args.d_hid, args.token_dim, self.A, length=args.n_joints) # 256 self.pred_mu = decoder(args.channel, args.channel//2, 3) @@ -235,24 +233,16 @@ def forward(self, pose_2d, y_t, t): # pose_2d: (B,F,J,2) y_t: (B,F,J,3) t: (B,F,1,1) b, f, j, _ = pose_2d.shape - # Ensure t has the correct shape (B,F,1,1) - if t.shape[1] == 1 and f > 1: - t = t.expand(b, f, 1, 1).contiguous() - # build time embedding t_emb = self.time_embed(t) # (B,F,t_dim) t_emb = t_emb.unsqueeze(2).expand(b, f, j, 
self.t_embed_dim).contiguous() # (B,F,J,t_dim) - pose_2d_emb = self.encoder_pose_2d(pose_2d) - y_t_emb = self.encoder_y_t(y_t) - - in_emb = torch.cat([pose_2d_emb, y_t_emb, t_emb], dim=-1) # (B,F,J,dim) - in_emb = rearrange(in_emb, 'b f j c -> (b f) j c').contiguous() # (B*F,J,in) - - # encoder -> model -> regression head - h = self.FMPose3D(in_emb) - v = self.pred_mu(h) # (B*F,J,3) + x_in = torch.cat([pose_2d, y_t, t_emb], dim=-1) # (B,F,J,2+3+t_dim) + x_in = rearrange(x_in, 'b f j c -> (b f) j c').contiguous() # (B*F,J,in) + in_emb = self.encoder(x_in) + features = self.FMPose3D(in_emb) + v = self.pred_mu(features) # (B*F,J,3) v = rearrange(v, '(b f) j c -> b f j c', b=b, f=f).contiguous() # (B,F,J,3) return v @@ -276,4 +266,5 @@ class Args: y_t = torch.randn(1, 17, 17, 3, device=device) t = torch.randn(1, 1, 1, 1, device=device) v = model(x, y_t, t) - print(v.shape) \ No newline at end of file + print(v.shape) + diff --git a/images/demo.gif b/images/demo.gif new file mode 100644 index 00000000..fa1595aa Binary files /dev/null and b/images/demo.gif differ diff --git a/images/demo.jpg b/images/demo.jpg index 328847b0..9dee98d2 100644 Binary files a/images/demo.jpg and b/images/demo.jpg differ diff --git a/scripts/FMPose3D_main.py b/scripts/FMPose3D_main.py index 6cd97b1d..e9adf3a7 100644 --- a/scripts/FMPose3D_main.py +++ b/scripts/FMPose3D_main.py @@ -78,7 +78,7 @@ def test_multi_hypothesis( for i, data in enumerate(tqdm(dataLoader, 0)): batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, cam_ind = data - [input_2D, gt_3D, batch_cam, scale, bb_box] = get_varialbe( + [input_2D, gt_3D, batch_cam, scale, bb_box] = get_variable( split, [input_2D, gt_3D, batch_cam, scale, bb_box] ) @@ -165,7 +165,7 @@ def train(opt, train_loader, model, optimizer): for i, data in enumerate(tqdm(train_loader, 0)): batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, cam_ind = data - [input_2D, gt_3D, batch_cam, scale, bb_box] = get_varialbe( + [input_2D, gt_3D, batch_cam, scale, bb_box] = get_variable( split, [input_2D, gt_3D, batch_cam, scale, bb_box] ) @@ -267,36 +267,29 @@ def print_error_action(action_error_sum, is_train): args.checkpoint = "./checkpoint/" + folder_name elif args.train == False: # create a new folder for the test results - args.previous_dir = os.path.dirname(args.saved_model_path) + args.previous_dir = os.path.dirname(args.model_weights_path) args.checkpoint = os.path.join(args.previous_dir, folder_name) if not os.path.exists(args.checkpoint): os.makedirs(args.checkpoint) # backup files - # import shutil - # file_name = os.path.basename(__file__) - # shutil.copyfile( - # src=file_name, - # dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name), - # ) - # shutil.copyfile( - # src="common/arguments.py", - # dst=os.path.join(args.checkpoint, args.create_time + "_arguments.py"), - # ) - # if getattr(args, "model_path", ""): - # model_src_path = os.path.abspath(args.model_path) - # model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path) - # shutil.copyfile( - # src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name) - # ) - # shutil.copyfile( - # src="common/utils.py", - # dst=os.path.join(args.checkpoint, args.create_time + "_utils.py"), - # ) - # sh_base = os.path.basename(args.sh_file) - # dst_name = f"{args.create_time}_" + sh_base - # shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name)) + import shutil + script_path = os.path.abspath(__file__) + script_name = os.path.basename(script_path) + 
shutil.copyfile( + src=script_path, + dst=os.path.join(args.checkpoint, args.create_time + "_" + script_name), + ) + if getattr(args, "model_path", ""): + model_src_path = os.path.abspath(args.model_path) + model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path) + shutil.copyfile( + src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name) + ) + sh_base = os.path.basename(args.sh_file) + dst_name = f"{args.create_time}_" + sh_base + shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name)) logging.basicConfig( format="%(asctime)s %(message)s", @@ -342,14 +335,16 @@ def print_error_action(action_error_sum, is_train): pin_memory=True, ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = {} - model["CFM"] = CFM(args).cuda() + model["CFM"] = CFM(args).to(device) if args.reload: model_dict = model["CFM"].state_dict() - model_path = args.saved_model_path + model_path = args.model_weights_path print(model_path) - pre_dict = torch.load(model_path) + pre_dict = torch.load(model_path, map_location=device, weights_only=True) for name, key in model_dict.items(): model_dict[name] = pre_dict[name] model["CFM"].load_state_dict(model_dict) diff --git a/scripts/FMPose3D_test.sh b/scripts/FMPose3D_test.sh index de7869ea..3d2a6152 100755 --- a/scripts/FMPose3D_test.sh +++ b/scripts/FMPose3D_test.sh @@ -2,7 +2,6 @@ layers=5 batch_size=1024 sh_file='scripts/FMPose3D_test.sh' -weight_softmax_tau=1.0 num_hypothesis_list=1 eval_multi_steps=3 topk=8 @@ -11,17 +10,16 @@ mode='exp' exp_temp=0.005 folder_name=test_s${eval_multi_steps}_${mode}_h${num_hypothesis_list}_$(date +%Y%m%d_%H%M%S) -model_path='pre_trained_models/fmpose_detected2d/model_GAMLP.py' -saved_model_path='pre_trained_models/fmpose_detected2d/FMpose_36_4972_best.pth' +model_path='./pre_trained_models/fmpose3d_h36m/model_GAMLP.py' +model_weights_path='./pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth' -#Test CFM +#Test python3 scripts/FMPose3D_main.py \ --reload \ --topk ${topk} \ --exp_temp ${exp_temp} \ ---weight_softmax_tau ${weight_softmax_tau} \ --folder_name ${folder_name} \ ---saved_model_path "${saved_model_path}" \ +--model_weights_path "${model_weights_path}" \ --model_path "${model_path}" \ --eval_sample_steps ${eval_multi_steps} \ --test_augmentation True \ diff --git a/scripts/FMPose3D_train.sh b/scripts/FMPose3D_train.sh index cb3e0289..939658c6 100755 --- a/scripts/FMPose3D_train.sh +++ b/scripts/FMPose3D_train.sh @@ -11,10 +11,10 @@ epochs=80 num_saved_models=3 frames=1 channel_dim=512 -model_path="" # when the path is empty, the model will be loaded from the installed fmpose package -# model_path='./fmpose/models/model_GAMLP.py' # when the path is not empty, the model will be loaded from the local file path +model_path="" # when the path is empty, the model will be loaded from the installed fmpose3d package +# model_path='./models/model_GAMLP.py' # when the path is not empty, the model will be loaded from the local file path sh_file='scripts/FMPose3D_train.sh' -folder_name=FMPose3D_Publish_layers${layers}_$(date +%Y%m%d_%H%M%S) +folder_name=FMPose3D_layers${layers}_$(date +%Y%m%d_%H%M%S) python3 scripts/FMPose3D_main.py \ --train \ diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..0919b4e5 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,403 @@ +""" +FMPose3D: monocular 3D Pose Estimation via Flow Matching + +Official implementation of the paper: +"FMPose3D: monocular 3D Pose Estimation via 
Flow Matching" +by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis +Licensed under Apache 2.0 +""" + +import argparse +import math + +import pytest + +from fmpose3d.common.config import ( + PipelineConfig, + FMPose3DConfig, + DatasetConfig, + TrainingConfig, + InferenceConfig, + AggregationConfig, + CheckpointConfig, + RefinementConfig, + OutputConfig, + DemoConfig, + RuntimeConfig, + _SUB_CONFIG_CLASSES, +) + + +# --------------------------------------------------------------------------- +# Sub-config defaults +# --------------------------------------------------------------------------- + + +class TestFMPose3DConfig: + def test_defaults(self): + cfg = FMPose3DConfig() + assert cfg.layers == 3 + assert cfg.channel == 512 + assert cfg.d_hid == 1024 + assert cfg.n_joints == 17 + assert cfg.out_joints == 17 + assert cfg.frames == 1 + + def test_custom_values(self): + cfg = FMPose3DConfig(layers=5, channel=256, n_joints=26) + assert cfg.layers == 5 + assert cfg.channel == 256 + assert cfg.n_joints == 26 + + +class TestDatasetConfig: + def test_defaults(self): + cfg = DatasetConfig() + assert cfg.dataset == "h36m" + assert cfg.keypoints == "cpn_ft_h36m_dbb" + assert cfg.root_path == "dataset/" + assert cfg.train_views == [0, 1, 2, 3] + assert cfg.joints_left == [] + assert cfg.joints_right == [] + + def test_list_defaults_are_independent(self): + """Each instance should get its own list, not a shared reference.""" + a = DatasetConfig() + b = DatasetConfig() + a.joints_left.append(99) + assert 99 not in b.joints_left + + def test_custom_values(self): + cfg = DatasetConfig( + dataset="rat7m", + root_path="Rat7M_data/", + joints_left=[8, 10, 11], + joints_right=[9, 14, 15], + ) + assert cfg.dataset == "rat7m" + assert cfg.root_path == "Rat7M_data/" + assert cfg.joints_left == [8, 10, 11] + + +class TestTrainingConfig: + def test_defaults(self): + cfg = TrainingConfig() + assert cfg.train is False + assert cfg.nepoch == 41 + assert cfg.batch_size == 128 + assert cfg.lr == pytest.approx(1e-3) + assert cfg.lr_decay == pytest.approx(0.95) + assert cfg.data_augmentation is True + + def test_custom_values(self): + cfg = TrainingConfig(lr=5e-4, nepoch=100) + assert cfg.lr == pytest.approx(5e-4) + assert cfg.nepoch == 100 + + +class TestInferenceConfig: + def test_defaults(self): + cfg = InferenceConfig() + assert cfg.test == 1 + assert cfg.test_augmentation is True + assert cfg.sample_steps == 3 + assert cfg.eval_sample_steps == "1,3,5,7,9" + assert cfg.hypothesis_num == 1 + assert cfg.guidance_scale == pytest.approx(1.0) + + +class TestAggregationConfig: + def test_defaults(self): + cfg = AggregationConfig() + assert cfg.topk == 3 + assert cfg.exp_temp == pytest.approx(0.002) + assert cfg.mode == "exp" + assert cfg.opt_steps == 2 + + +class TestCheckpointConfig: + def test_defaults(self): + cfg = CheckpointConfig() + assert cfg.reload is False + assert cfg.model_weights_path == "" + assert cfg.previous_dir == "./pre_trained_model/pretrained" + assert cfg.num_saved_models == 3 + assert cfg.previous_best_threshold == math.inf + + def test_mutability(self): + cfg = CheckpointConfig() + cfg.previous_best_threshold = 42.5 + cfg.previous_name = "best_model.pth" + assert cfg.previous_best_threshold == pytest.approx(42.5) + assert cfg.previous_name == "best_model.pth" + + +class TestRefinementConfig: + def test_defaults(self): + cfg = RefinementConfig() + assert cfg.post_refine is False + assert cfg.lr_refine == pytest.approx(1e-5) + assert cfg.refine is False + + +class TestOutputConfig: + def 
test_defaults(self): + cfg = OutputConfig() + assert cfg.create_time == "" + assert cfg.create_file == 1 + assert cfg.debug is False + assert cfg.folder_name == "" + + +class TestDemoConfig: + def test_defaults(self): + cfg = DemoConfig() + assert cfg.type == "image" + assert cfg.path == "demo/images/running.png" + + +class TestRuntimeConfig: + def test_defaults(self): + cfg = RuntimeConfig() + assert cfg.gpu == "0" + assert cfg.pad == 0 + assert cfg.single is False + assert cfg.reload_3d is False + + +# --------------------------------------------------------------------------- +# PipelineConfig +# --------------------------------------------------------------------------- + + +class TestPipelineConfig: + def test_default_construction(self): + """All sub-configs are initialised with their defaults.""" + cfg = PipelineConfig() + assert isinstance(cfg.model_cfg, FMPose3DConfig) + assert isinstance(cfg.dataset_cfg, DatasetConfig) + assert isinstance(cfg.training_cfg, TrainingConfig) + assert isinstance(cfg.inference_cfg, InferenceConfig) + assert isinstance(cfg.aggregation_cfg, AggregationConfig) + assert isinstance(cfg.checkpoint_cfg, CheckpointConfig) + assert isinstance(cfg.refinement_cfg, RefinementConfig) + assert isinstance(cfg.output_cfg, OutputConfig) + assert isinstance(cfg.demo_cfg, DemoConfig) + assert isinstance(cfg.runtime_cfg, RuntimeConfig) + + def test_partial_construction(self): + """Supplying only some sub-configs leaves the rest at defaults.""" + cfg = PipelineConfig( + model_cfg=FMPose3DConfig(layers=5), + training_cfg=TrainingConfig(lr=2e-4), + ) + assert cfg.model_cfg.layers == 5 + assert cfg.training_cfg.lr == pytest.approx(2e-4) + # Others keep defaults + assert cfg.dataset_cfg.dataset == "h36m" + assert cfg.runtime_cfg.gpu == "0" + + def test_sub_config_mutation(self): + """Mutating a sub-config field is reflected on the config.""" + cfg = PipelineConfig() + cfg.training_cfg.lr = 0.01 + assert cfg.training_cfg.lr == pytest.approx(0.01) + + def test_sub_config_replacement(self): + """Replacing an entire sub-config works.""" + cfg = PipelineConfig() + cfg.model_cfg = FMPose3DConfig(layers=10, channel=1024) + assert cfg.model_cfg.layers == 10 + assert cfg.model_cfg.channel == 1024 + + # -- to_dict -------------------------------------------------------------- + + def test_to_dict_returns_flat_dict(self): + cfg = PipelineConfig() + d = cfg.to_dict() + assert isinstance(d, dict) + # Spot-check keys from different groups + assert "layers" in d + assert "dataset" in d + assert "lr" in d + assert "topk" in d + assert "gpu" in d + + def test_to_dict_reflects_custom_values(self): + cfg = PipelineConfig( + model_cfg=FMPose3DConfig(layers=7), + aggregation_cfg=AggregationConfig(topk=5), + ) + d = cfg.to_dict() + assert d["layers"] == 7 + assert d["topk"] == 5 + + def test_to_dict_no_duplicate_keys(self): + """Every field name should be unique across all sub-configs.""" + cfg = PipelineConfig() + d = cfg.to_dict() + all_field_names = [] + for dc_class in _SUB_CONFIG_CLASSES.values(): + from dataclasses import fields as dc_fields + all_field_names.extend(f.name for f in dc_fields(dc_class)) + assert len(all_field_names) == len(set(all_field_names)), ( + "Duplicate field names across sub-configs" + ) + + # -- from_namespace ------------------------------------------------------- + + def test_from_namespace_basic(self): + ns = argparse.Namespace( + # FMPose3DConfig + model="test_model", + model_type="fmpose3d", + layers=5, + channel=256, + d_hid=512, + token_dim=128, + n_joints=20, 
+ out_joints=20, + in_channels=2, + out_channels=3, + frames=3, + # DatasetConfig + dataset="rat7m", + keypoints="cpn", + root_path="Rat7M_data/", + actions="*", + downsample=1, + subset=1.0, + stride=1, + crop_uv=0, + out_all=1, + train_views=[0, 1], + test_views=[2, 3], + subjects_train="S1", + subjects_test="S2", + root_joint=4, + joints_left=[8, 10], + joints_right=[9, 14], + # TrainingConfig + train=True, + nepoch=100, + batch_size=64, + lr=5e-4, + lr_decay=0.99, + lr_decay_large=0.5, + large_decay_epoch=10, + workers=4, + data_augmentation=False, + reverse_augmentation=False, + norm=0.01, + # InferenceConfig + test=1, + test_augmentation=False, + test_augmentation_flip_hypothesis=False, + test_augmentation_FlowAug=False, + sample_steps=5, + eval_multi_steps=True, + eval_sample_steps="1,3,5", + num_hypothesis_list="1,3", + hypothesis_num=3, + guidance_scale=1.5, + # AggregationConfig + topk=5, + exp_temp=0.001, + mode="softmax", + opt_steps=3, + # CheckpointConfig + reload=True, + model_dir="/tmp", + model_weights_path="/tmp/weights.pth", + checkpoint="/tmp/ckpt", + previous_dir="./pre_trained", + num_saved_models=5, + previous_best_threshold=50.0, + previous_name="best.pth", + # RefinementConfig + post_refine=False, + post_refine_reload=False, + previous_post_refine_name="", + lr_refine=1e-5, + refine=False, + reload_refine=False, + previous_refine_name="", + # OutputConfig + create_time="250101", + filename="run1", + create_file=1, + debug=True, + folder_name="exp1", + sh_file="train.sh", + # DemoConfig + type="video", + path="/tmp/video.mp4", + # RuntimeConfig + gpu="1", + pad=1, + single=True, + reload_3d=False, + ) + cfg = PipelineConfig.from_namespace(ns) + + # Verify a sample from each group + assert cfg.model_cfg.layers == 5 + assert cfg.model_cfg.channel == 256 + assert cfg.dataset_cfg.dataset == "rat7m" + assert cfg.dataset_cfg.joints_left == [8, 10] + assert cfg.training_cfg.train is True + assert cfg.training_cfg.nepoch == 100 + assert cfg.inference_cfg.sample_steps == 5 + assert cfg.inference_cfg.guidance_scale == pytest.approx(1.5) + assert cfg.aggregation_cfg.topk == 5 + assert cfg.checkpoint_cfg.reload is True + assert cfg.checkpoint_cfg.previous_best_threshold == pytest.approx(50.0) + assert cfg.refinement_cfg.lr_refine == pytest.approx(1e-5) + assert cfg.output_cfg.debug is True + assert cfg.demo_cfg.type == "video" + assert cfg.runtime_cfg.gpu == "1" + + def test_from_namespace_ignores_unknown_fields(self): + """Extra attributes in the namespace that don't match any field are ignored.""" + ns = argparse.Namespace( + layers=3, channel=512, unknown_field="should_be_ignored", + ) + cfg = PipelineConfig.from_namespace(ns) + assert cfg.model_cfg.layers == 3 + assert cfg.model_cfg.channel == 512 + assert not hasattr(cfg, "unknown_field") + + def test_from_namespace_partial_namespace(self): + """A namespace missing some fields uses dataclass defaults for those.""" + ns = argparse.Namespace(layers=10, gpu="2") + cfg = PipelineConfig.from_namespace(ns) + assert cfg.model_cfg.layers == 10 + assert cfg.runtime_cfg.gpu == "2" + # Unset fields keep defaults + assert cfg.model_cfg.channel == 512 + assert cfg.training_cfg.lr == pytest.approx(1e-3) + + # -- round-trip: from_namespace ↔ to_dict --------------------------------- + + def test_roundtrip_from_namespace_to_dict(self): + """Values fed via from_namespace appear identically in to_dict.""" + ns = argparse.Namespace( + layers=8, channel=1024, dataset="animal3d", lr=2e-4, topk=7, gpu="3", + ) + cfg = 
PipelineConfig.from_namespace(ns) + d = cfg.to_dict() + assert d["layers"] == 8 + assert d["channel"] == 1024 + assert d["dataset"] == "animal3d" + assert d["lr"] == pytest.approx(2e-4) + assert d["topk"] == 7 + assert d["gpu"] == "3" + + def test_to_dict_after_mutation(self): + """to_dict reflects in-place mutations on sub-configs.""" + cfg = PipelineConfig() + cfg.training_cfg.lr = 0.123 + cfg.model_cfg.layers = 99 + d = cfg.to_dict() + assert d["lr"] == pytest.approx(0.123) + assert d["layers"] == 99
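
The registry in `base_model.py`, together with the velocity-field contract on `BaseModel.forward`, keeps sampling model-agnostic. Below is a minimal sketch of driving a registered model with a plain Euler integrator; it assumes the standard flow-matching update `y <- y + v * dt` with `t` running from 0 to 1 and a Gaussian noise initialisation. The `args` object is a placeholder for a fully populated argument namespace, and `euler_sample_3d` is an illustrative name, not part of the package API.

```python
import torch

from fmpose3d.models import get_model, list_models


def euler_sample_3d(model, pose_2d: torch.Tensor, steps: int = 3) -> torch.Tensor:
    """Integrate the predicted velocity field from t=0 to t=1.

    The model is assumed to follow the BaseModel contract:
    forward(pose_2d, y_t, t) -> v, with shapes (B,F,J,2), (B,F,J,3),
    (B,F,1,1) -> (B,F,J,3).
    """
    b, f, j, _ = pose_2d.shape
    y = torch.randn(b, f, j, 3, device=pose_2d.device)  # noise initialisation
    dt = 1.0 / steps
    for k in range(steps):
        t = torch.full((b, f, 1, 1), k * dt, device=pose_2d.device)
        v = model(pose_2d, y, t)  # predicted velocity, (B, F, J, 3)
        y = y + v * dt            # Euler step towards the 3D pose
    return y


if __name__ == "__main__":
    print(list_models())  # e.g. ['fmpose3d']
    # `args` must carry channel, layers, d_hid, token_dim, n_joints, ...
    # as documented on BaseModel; it is only a placeholder here.
    # model = get_model("fmpose3d")(args).eval()
    # with torch.no_grad():
    #     pose_3d = euler_sample_3d(model, pose_2d, steps=3)
```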