Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 76 additions & 30 deletions fmpose3d/inference_api/fmpose3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def predict(
images=paths,
max_individuals=cfg.max_individuals,
out_folder=tmpdir,
progress_bar=False
)
# predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}}
# Iterate in input order to keep frame alignment stable.
Expand Down Expand Up @@ -937,23 +938,11 @@ def predict(

# 3D pose lifting
result_3d = self.pose_3d(
result_2d.keypoints,
result_2d.image_size,
result_2d,
camera_rotation=camera_rotation,
seed=seed,
progress=progress,
)

# Propagate 2D result status and validity mask to 3D pose result
result_3d.status_hint = f"2D pose status is {status.value}: {status_msg}"
result_3d.valid_frames_mask = result_2d.valid_frames_mask

# Apply result masking for partial results (set NaN for invalid frames)
if status == ResultStatus.PARTIAL:
invalid = ~result_3d.valid_frames_mask
if np.any(invalid):
result_3d.poses_3d[invalid] = np.nan
result_3d.poses_3d_world[invalid] = np.nan
return result_3d

@torch.no_grad()
Expand Down Expand Up @@ -1006,8 +995,8 @@ def prepare_2d(
@torch.no_grad()
def pose_3d(
self,
keypoints_2d: np.ndarray,
image_size: tuple[int, int],
keypoints_2d: Pose2DResult | np.ndarray,
image_size: tuple[int, int] | None = None,
*,
camera_rotation: np.ndarray | None = _DEFAULT_CAM_ROTATION,
seed: int | None = None,
Expand All @@ -1027,13 +1016,17 @@ def pose_3d(

Parameters
----------
keypoints_2d : ndarray
2D keypoints returned by :meth:`prepare_2d`. Accepted shapes:
keypoints_2d : Pose2DResult or ndarray
2D keypoints returned by :meth:`prepare_2d`, either as a full
:class:`Pose2DResult` or as a raw ndarray. Accepted ndarray shapes:

* ``(num_persons, num_frames, J, 2)`` -- first person is used.
* ``(num_frames, J, 2)`` -- treated as a single person.
image_size : tuple of (int, int)
image_size : tuple of (int, int) or None
``(height, width)`` of the source image / video frames.
Required when ``keypoints_2d`` is an ndarray. Optional when
``keypoints_2d`` is a :class:`Pose2DResult`; if provided, it must
match ``Pose2DResult.image_size``.
camera_rotation : ndarray or None
Length-4 quaternion for the camera-to-world rotation applied
to produce ``poses_3d_world``. Defaults to the rotation used
Expand All @@ -1053,25 +1046,25 @@ def pose_3d(
Pose3DResult
Root-relative and post-processed 3D poses.
"""
result_2d: Pose2DResult = self._normalize_3d_input(
keypoints_2d,
image_size=image_size
)
status, status_msg = result_2d.get_status_info()
if status in {ResultStatus.EMPTY, ResultStatus.INVALID}:
raise ValueError(f"2D pose estimation is not usable for 3D lifting: {status.value}. {status_msg}")
# Just use the first person's keypoints for now.
kpts = result_2d.keypoints[0]
h, w = result_2d.image_size

self.setup_runtime()
model = self._model_3d
h, w = image_size
steps = self.inference_cfg.sample_steps

# Optional deterministic seeding
if seed is not None:
torch.manual_seed(seed)

# Normalise input shape to (num_frames, J, 2)
if keypoints_2d.ndim == 4:
kpts = keypoints_2d[0] # first person
elif keypoints_2d.ndim == 3:
kpts = keypoints_2d
else:
raise ValueError(
f"Expected keypoints_2d with 3 or 4 dims, got {keypoints_2d.ndim}"
)

num_frames = kpts.shape[0]
all_poses_3d: list[np.ndarray] = []
all_poses_world: list[np.ndarray] = []
Expand All @@ -1091,11 +1084,64 @@ def pose_3d(
if progress:
progress(i + 1, num_frames)

return Pose3DResult(
result_3d = Pose3DResult(
poses_3d=np.stack(all_poses_3d, axis=0),
poses_3d_world=np.stack(all_poses_world, axis=0),
)

# Mask invalid frames in 3D output for partial 2D predictions.
result_3d.status_hint = f"2D pose status is {status.value}: {status_msg}"
result_3d.valid_frames_mask = result_2d.valid_frames_mask
if status == ResultStatus.PARTIAL and result_3d.valid_frames_mask is not None:
invalid = ~result_3d.valid_frames_mask
if np.any(invalid):
result_3d.poses_3d[invalid] = np.nan
result_3d.poses_3d_world[invalid] = np.nan
return result_3d

def _normalize_3d_input(
self,
keypoints_2d: Pose2DResult | np.ndarray,
*,
image_size: tuple[int, int] | None,
) -> Pose2DResult:
"""Normalise pose_3d inputs into a Pose2DResult instance."""
if isinstance(keypoints_2d, Pose2DResult):
if image_size is not None and image_size != keypoints_2d.image_size:
raise ValueError(
f"Image size mismatch: Pose2DResult.image_size={keypoints_2d.image_size}, "
f"image_size={image_size}. Please provide either a Pose2DResult (containing "
f"image_size), or keypoints_2d as a numpy ndarray together with "
f"image_size={image_size}."
)
return keypoints_2d

if not isinstance(keypoints_2d, np.ndarray):
raise ValueError("keypoints_2d must be a Pose2DResult or a numpy ndarray.")
if image_size is None:
raise ValueError(
"image_size is required when keypoints_2d is provided as an ndarray."
)

if keypoints_2d.ndim == 4:
keypoints = keypoints_2d
elif keypoints_2d.ndim == 3:
# Treat 3D input as a single-person sequence for consistency.
keypoints = keypoints_2d[np.newaxis]
else:
raise ValueError(
f"Expected keypoints_2d with 3 or 4 dims, got {keypoints_2d.ndim}"
)

scores = np.full(keypoints.shape[:-1], np.nan, dtype=np.float32)
return Pose2DResult(
keypoints=keypoints,
scores=scores,
image_size=image_size,
valid_frames_mask=None,
)


# ------------------------------------------------------------------
# Private helpers – sampling & post-processing
# ------------------------------------------------------------------
Expand Down
Loading