30 changes: 20 additions & 10 deletions README.md
@@ -1,17 +1,25 @@
-# FMPose: 3D Pose Estimation via Flow Matching
+# FMPose3D: Monocular 3D Pose Estimation via Flow Matching

This is the official implementation of the approach described in the paper:

-> [**FMPose: 3D Pose Estimation via Flow Matching**](xxx)
+> [**FMPose3D: Monocular 3D Pose Estimation via Flow Matching**](xxx)
> Ti Wang, Xiaohang Yu, Mackenzie Weygandt Mathis

<!-- <p align="center"><img src="./images/Frame 4.jpg" width="50%" alt="" /></p> -->

<p align="center"><img src="./images/predictions.jpg" width="95%" alt="" /></p>

-## Set up a environment

-Make sure you have Python 3.10. You can set this up with:
+## News!

+- [X] Feb 2026: FMPose3D code and arXiv paper are released - check out the demos here or on our [project page](https://xiu-cs.github.io/FMPose3D/)
+- [ ] Planned: This method will be integrated into [DeepLabCut](https://www.mackenziemathislab.org/deeplabcut)

+## Installation

+### Set up an environment

+Make sure you have Python 3.10+. You can set this up with:
```bash
conda create -n fmpose_3d python=3.10
conda activate fmpose_3d
```
@@ -20,9 +28,9 @@ conda activate fmpose_3d
```bash
git clone xxxx.git # clone this repo
# TestPyPI (pre-release/testing build)
-pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ fmpose==0.0.5
+pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ fmpose3d==0.0.7
# Future Official PyPI release
-# pip install fmpose
+# pip install fmpose3d
```
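
After installing, a quick sanity check (a minimal sketch; it only assumes the TestPyPI build above) is to import the package and print its version:

```python
import fmpose3d

# __version__ and the aggregation helpers are exposed in fmpose3d/__init__.py.
print(fmpose3d.__version__)  # expected: 0.0.7 for the TestPyPI build above
```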

## Demo
@@ -68,7 +76,7 @@ The training logs, checkpoints, and related files of each training time will be

For training on Human3.6M:
```bash
-sh /scripts/FMPose_train.sh
+sh ./scripts/FMPose3D_train.sh
```

### Inference
@@ -78,16 +86,18 @@ First, download the folder with pre-trained model from [here](https://drive.goog
To run inference on Human3.6M:

```bash
-sh ./scripts/FMPose_test.sh
+sh ./scripts/FMPose3D_test.sh
```

## Experiments: Animals

For animal training/testing and demo scripts, see [animals/README.md](animals/README.md).

-## Acknowledgement
+## Acknowledgements

We thank the Swiss National Science Foundation (SNSF Project # 320030-227871) and the Kavli Foundation for providing financial support for this project.

-Our code is extended from the following repositories. We thank the authors for releasing the codes.
+Our code is extended from the following repositories. We thank the authors for releasing the code.

- [MHFormer](https://github.com/Vegetebird/MHFormer)
- [StridedTransformer-Pose3D](https://github.com/Vegetebird/StridedTransformer-Pose3D)
4 changes: 2 additions & 2 deletions animals/README.md
@@ -2,7 +2,7 @@

# Animals

-In this part, the FMPose model is trained on [Animal3D](https://xujiacong.github.io/Animal3D/) dataset and [Control_Animal3D](https://luoxue-star.github.io/AniMer_project_page/) dataset.
+In this part, the FMPose3D model is trained on the [Animal3D](https://xujiacong.github.io/Animal3D/) and [Control_Animal3D](https://luoxue-star.github.io/AniMer_project_page/) datasets.
## Demo

### Testing on in-the-wild images (animals)
@@ -62,4 +62,4 @@ Download the pretrained model from [here](https://drive.google.com/drive/folders
```bash
cd animals # the current path is: ./animals
bash ./scripts/test_animal3d.sh
-```
\ No newline at end of file
+```
6 changes: 3 additions & 3 deletions animals/demo/vis_animals.py
@@ -22,8 +22,8 @@
from PIL import Image
import matplotlib.gridspec as gridspec
import imageio
-from fmpose.animals.common.arguments import opts as parse_args
-from fmpose.common.camera import normalize_screen_coordinates, camera_to_world
+from fmpose3d.animals.common.arguments import opts as parse_args
+from fmpose3d.common.camera import normalize_screen_coordinates, camera_to_world

sys.path.append(os.getcwd())

@@ -46,7 +46,7 @@
CFM = getattr(module, "Model")
else:
# Load model from installed fmpose package
-from fmpose.models import Model as CFM
+from fmpose3d.models import Model as CFM

from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images

2 changes: 1 addition & 1 deletion animals/demo/vis_animals.sh
@@ -7,7 +7,7 @@ sh_file='vis_animals.sh'
# n_joints=26
# out_joints=26

-model_path='../../fmpose/animals/models/model_animal3d.py'
+model_path='../pre_trained_models/animal3d_pretrained_weights/model_animal3d.py'
saved_model_path='../pre_trained_models/animal3d_pretrained_weights/CFM_154_4403_best.pth'

# path='./images/image_00068.jpg' # single image
8 changes: 4 additions & 4 deletions animals/models/model_animals.py
@@ -15,7 +15,7 @@
from einops import rearrange
from timm.models.layers import DropPath

-from fmpose.animals.models.graph_frames import Graph
+from fmpose3d.animals.models.graph_frames import Graph


class TimeEmbedding(nn.Module):
@@ -148,7 +148,7 @@ def forward(self, x):
x = res2 + self.drop_path(x)
return x

-class FMPose(nn.Module):
+class FMPose3D(nn.Module):
def __init__(self, depth, embed_dim, channels_dim, tokens_dim, adj, drop_rate=0.10, length=27):
super().__init__()
drop_path_rate = 0.2
@@ -220,7 +220,7 @@ def __init__(self, args):
self.encoder_pose_2d = encoder(2, args.channel//2, args.channel//2-self.t_embed_dim//2)
self.encoder_y_t = encoder(3, args.channel//2, args.channel//2-self.t_embed_dim//2)

-self.FMPose = FMPose(args.layers, args.channel, args.d_hid, args.token_dim, self.A, length=args.n_joints) # 256
+self.FMPose3D = FMPose3D(args.layers, args.channel, args.d_hid, args.token_dim, self.A, length=args.n_joints) # 256
self.pred_mu = decoder(args.channel, args.channel//2, 3)

def forward(self, pose_2d, y_t, t):
Expand All @@ -239,7 +239,7 @@ def forward(self, pose_2d, y_t, t):
in_emb = rearrange(in_emb, 'b f j c -> (b f) j c').contiguous() # (B*F,J,in)

# encoder -> model -> regression head
-h = self.FMPose(in_emb)
+h = self.FMPose3D(in_emb)
v = self.pred_mu(h) # (B*F,J,3)

v = rearrange(v, '(b f) j c -> b f j c', b=b, f=f).contiguous() # (B,F,J,3)
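For orientation: `Model.forward(pose_2d, y_t, t)` above returns a per-joint velocity `v` for the noisy pose `y_t` at flow time `t`. A minimal, hypothetical sketch of integrating such a velocity field at inference time with plain Euler steps (Gaussian noise at t=0 to a pose estimate at t=1); the step count, shapes, and function name are illustrative and not the repository's actual sampler:

```python
import torch

def sample_pose_3d(model, pose_2d, n_steps=10):
    """Hypothetical Euler sampler for a flow-matching pose model."""
    b, f, j, _ = pose_2d.shape                           # (batch, frames, joints, 2)
    y = torch.randn(b, f, j, 3, device=pose_2d.device)   # start from noise at t = 0
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((b,), i * dt, device=pose_2d.device)
        v = model(pose_2d, y, t)                         # predicted velocity (B, F, J, 3)
        y = y + dt * v                                   # Euler step along the learned flow
    return y                                             # approximate 3D pose at t = 1
```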
8 changes: 4 additions & 4 deletions animals/scripts/main_animal3d.py
@@ -15,9 +15,9 @@
import numpy as np
from tqdm import tqdm
import torch.optim as optim
-from fmpose.animals.common.arguments import opts as parse_args
-from fmpose.animals.common.utils import *
-from fmpose.animals.common.animal3d_dataset import TrainDataset
+from fmpose3d.animals.common.arguments import opts as parse_args
+from fmpose3d.animals.common.utils import *
+from fmpose3d.animals.common.animal3d_dataset import TrainDataset
import time

args = parse_args().parse()
@@ -39,7 +39,7 @@
CFM = getattr(module, "Model")
else:
# Load model from installed fmpose package
-from fmpose.animals.models import Model as CFM
+from fmpose3d.animals.models import Model as CFM

def train(opt, actions, train_loader, model, optimizer, epoch):
return step('train', opt, actions, train_loader, model, optimizer, epoch)
2 changes: 1 addition & 1 deletion animals/scripts/test_animal3d.sh
@@ -11,7 +11,7 @@ n_joints=26
out_joints=26
epochs=300
# model_path='models/model_animals.py'
model_path="" # when the path is empty, the model will be loaded from the installed fmpose package
model_path='./pre_trained_models/animal3d_pretrained_weights/model_animal3d.py' # when the path is empty, the model will be loaded from the installed fmpose package
saved_model_path='./pre_trained_models/animal3d_pretrained_weights/CFM_154_4403_best.pth'
# root path denotes the path to the original dataset
root_path="./dataset/"
1 change: 1 addition & 0 deletions dataset/readme.txt
@@ -0,0 +1 @@
+Download the preprocessed datasets from [here](https://drive.google.com/drive/folders/112GPdRC9IEcwcJRyrLJeYw9_YV4wLdKC?usp=sharing) and place them in this folder.
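One convenient way to fetch that Drive folder from a script (illustrative only; this assumes the third-party `gdown` package, which is not a dependency of this repo):

```python
# Hypothetical download helper, not part of the repository.
# pip install gdown
import gdown

url = "https://drive.google.com/drive/folders/112GPdRC9IEcwcJRyrLJeYw9_YV4wLdKC"
gdown.download_folder(url, output="./dataset", quiet=False)  # saves into ./dataset
```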
10 changes: 5 additions & 5 deletions demo/vis_in_the_wild.py
@@ -19,12 +19,12 @@
sys.path.append(os.getcwd())

# Auto-download checkpoint files if missing
-from fmpose.lib.checkpoint.download_checkpoints import ensure_checkpoints
+from fmpose3d.lib.checkpoint.download_checkpoints import ensure_checkpoints
ensure_checkpoints()

-from fmpose.lib.preprocess import h36m_coco_format, revise_kpts
-from fmpose.lib.hrnet.gen_kpts import gen_video_kpts as hrnet_pose
-from fmpose.common.arguments import opts as parse_args
+from fmpose3d.lib.preprocess import h36m_coco_format, revise_kpts
+from fmpose3d.lib.hrnet.gen_kpts import gen_video_kpts as hrnet_pose
+from fmpose3d.common.arguments import opts as parse_args

args = parse_args().parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
@@ -39,7 +39,7 @@
spec.loader.exec_module(module)
CFM = getattr(module, 'Model')

-from fmpose.common.camera import *
+from fmpose3d.common.camera import *

import matplotlib
import matplotlib.pyplot as plt
17 changes: 0 additions & 17 deletions fmpose/animals/__init__.py

This file was deleted.

8 changes: 4 additions & 4 deletions fmpose/__init__.py → fmpose3d/__init__.py
@@ -7,15 +7,15 @@
Licensed under Apache 2.0
"""

__version__ = "0.0.5"
__version__ = "0.0.7"
__author__ = "Ti Wang, Xiaohang Yu, Mackenzie Weygandt Mathis"
__license__ = "MIT"
__license__ = "Apache 2.0"

# Import key components for easy access
from .aggregation_methods import (
average_aggregation,
aggregation_select_single_best_hypothesis_by_2D_error,
-aggregation_RPEA_weighted_by_2D_error,
+aggregation_RPEA_joint_level,
)

# Import 2D pose detection utilities
@@ -27,7 +27,7 @@
# Aggregation methods
"average_aggregation",
"aggregation_select_single_best_hypothesis_by_2D_error",
"aggregation_RPEA_weighted_by_2D_error",
"aggregation_RPEA_joint_level",
# 2D pose detection
"gen_video_kpts",
"h36m_coco_format",
fmpose/aggregation_methods.py → fmpose3d/aggregation_methods.py
@@ -8,7 +8,7 @@
"""

import torch
-from fmpose.common.utils import project_to_2d
+from fmpose3d.common.utils import project_to_2d

def average_aggregation(list_hypothesis):
return torch.mean(torch.stack(list_hypothesis), dim=0)
@@ -96,7 +96,7 @@ def aggregation_select_single_best_hypothesis_by_2D_error(args,
return agg


-def aggregation_RPEA_weighted_by_2D_error(
+def aggregation_RPEA_joint_level(
args, list_hypothesis, batch_cam, input_2D, gt_3D, topk=3
):
"""
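For context, `average_aggregation` above simply stacks K sampled hypotheses and takes their element-wise mean. A quick illustrative usage (shapes are hypothetical; 17 joints as in Human3.6M):

```python
import torch
from fmpose3d import average_aggregation  # re-exported in fmpose3d/__init__.py

# Five hypothetical sampled hypotheses, each of shape (batch, frames, joints, 3).
hypotheses = [torch.randn(1, 1, 17, 3) for _ in range(5)]
pose = average_aggregation(hypotheses)  # element-wise mean over the 5 samples
print(pose.shape)                       # torch.Size([1, 1, 17, 3])
```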
8 changes: 8 additions & 0 deletions fmpose3d/animals/__init__.py
@@ -0,0 +1,8 @@
"""
Animal-specific components for FMPose3D.
"""

__all__ = [
"common",
]

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -15,7 +15,7 @@
from einops import rearrange
from timm.models.layers import DropPath

-from fmpose.animals.models.graph_frames import Graph
+from fmpose3d.animals.models.graph_frames import Graph

class TimeEmbedding(nn.Module):
def __init__(self, dim: int, hidden_dim: int = 64):
@@ -148,7 +148,7 @@ def forward(self, x):
x = res2 + self.drop_path(x)
return x

-class FMPose(nn.Module):
+class FMPose3D(nn.Module):
def __init__(self, depth, embed_dim, channels_dim, tokens_dim, adj, drop_rate=0.10, length=27):
super().__init__()
drop_path_rate = 0.2
@@ -219,7 +219,7 @@ def __init__(self, args):
self.encoder_pose_2d = encoder(2, args.channel//2, args.channel//2-self.t_embed_dim//2)
self.encoder_y_t = encoder(3, args.channel//2, args.channel//2-self.t_embed_dim//2)

-self.FMPose = FMPose(args.layers, args.channel, args.d_hid, args.token_dim, self.A, length=args.n_joints) # 256
+self.FMPose3D = FMPose3D(args.layers, args.channel, args.d_hid, args.token_dim, self.A, length=args.n_joints) # 256
self.pred_mu = decoder(args.channel, args.channel//2, 3)

def forward(self, pose_2d, y_t, t):
@@ -238,7 +238,7 @@ def forward(self, pose_2d, y_t, t):
in_emb = rearrange(in_emb, 'b f j c -> (b f) j c').contiguous() # (B*F,J,in)

# encoder -> model -> regression head
-h = self.FMPose(in_emb)
+h = self.FMPose3D(in_emb)
v = self.pred_mu(h) # (B*F,J,3)

v = rearrange(v, '(b f) j c -> b f j c', b=b, f=f).contiguous() # (B,F,J,3)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -11,9 +11,9 @@

import numpy as np

-from fmpose.common.camera import normalize_screen_coordinates
-from fmpose.common.mocap_dataset import MocapDataset
-from fmpose.common.skeleton import Skeleton
+from fmpose3d.common.camera import normalize_screen_coordinates
+from fmpose3d.common.mocap_dataset import MocapDataset
+from fmpose3d.common.skeleton import Skeleton

h36m_skeleton = Skeleton(
parents=[
@@ -10,9 +10,9 @@
import numpy as np
import torch.utils.data as data

-from fmpose.common.camera import normalize_screen_coordinates, world_to_camera
-from fmpose.common.generator import ChunkedGenerator
-from fmpose.common.utils import deterministic_random
+from fmpose3d.common.camera import normalize_screen_coordinates, world_to_camera
+from fmpose3d.common.generator import ChunkedGenerator
+from fmpose3d.common.utils import deterministic_random


class Fusion(data.Dataset):
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
18 changes: 9 additions & 9 deletions fmpose/lib/hrnet/gen_kpts.py → fmpose3d/lib/hrnet/gen_kpts.py
@@ -24,22 +24,22 @@
import cv2
import copy

-from fmpose.lib.hrnet.lib.utils.utilitys import plot_keypoint, PreProcess, write, load_json
-from fmpose.lib.hrnet.lib.config import cfg, update_config
-from fmpose.lib.hrnet.lib.utils.transforms import *
-from fmpose.lib.hrnet.lib.utils.inference import get_final_preds
-from fmpose.lib.hrnet.lib.models import pose_hrnet
+from fmpose3d.lib.hrnet.lib.utils.utilitys import plot_keypoint, PreProcess, write, load_json
+from fmpose3d.lib.hrnet.lib.config import cfg, update_config
+from fmpose3d.lib.hrnet.lib.utils.transforms import *
+from fmpose3d.lib.hrnet.lib.utils.inference import get_final_preds
+from fmpose3d.lib.hrnet.lib.models import pose_hrnet

cfg_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'experiments') + '/'

# Auto-download checkpoints if missing and get checkpoint paths
-from fmpose.lib.checkpoint.download_checkpoints import ensure_checkpoints, get_checkpoint_path
+from fmpose3d.lib.checkpoint.download_checkpoints import ensure_checkpoints, get_checkpoint_path
ensure_checkpoints()

# Loading human detector model
-from fmpose.lib.yolov3.human_detector import load_model as yolo_model
-from fmpose.lib.yolov3.human_detector import yolo_human_det as yolo_det
-from fmpose.lib.sort.sort import Sort
+from fmpose3d.lib.yolov3.human_detector import load_model as yolo_model
+from fmpose3d.lib.yolov3.human_detector import yolo_human_det as yolo_det
+from fmpose3d.lib.sort.sort import Sort

def parse_args():
parser = argparse.ArgumentParser(description='Train keypoints network')
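The imports above wire up the in-the-wild 2D pipeline: YOLOv3 for person detection, SORT for tracking, and HRNet for keypoints. A hedged usage sketch (argument names follow the MHFormer-style demos this code extends; treat the exact signature as an assumption and verify against gen_kpts.py):

```python
from fmpose3d.lib.hrnet.gen_kpts import gen_video_kpts
from fmpose3d.lib.preprocess import h36m_coco_format

# Detect and track one person in a video, then map COCO keypoints to H36M ordering.
# 'num_peroson' (sic) mirrors the misspelled parameter in the MHFormer-derived code.
keypoints, scores = gen_video_kpts('input.mp4', det_dim=416, num_peroson=1, gen_output=True)
keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
```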
File renamed without changes.
@@ -12,9 +12,9 @@
import torch
import json
import torchvision.transforms as transforms
-from fmpose.lib.hrnet.lib.utils.transforms import *
+from fmpose3d.lib.hrnet.lib.utils.transforms import *

-from fmpose.lib.hrnet.lib.utils.coco_h36m import coco_h36m
+from fmpose3d.lib.hrnet.lib.utils.coco_h36m import coco_h36m
import numpy as np

joint_pairs = [[0, 1], [1, 3], [0, 2], [2, 4],
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.