Refactor model configuration with a model registry #15
Merged
23 commits
851b01d
Merge pull request #13 from deruyter92/jaap/add_config_and_registry
xiu-cs 799be4d
Refactor FMPose3D test script to use model_type instead of model_path
xiu-cs 7acba1f
Refactor FMPose3D_main.py to load model using get_model from registry…
xiu-cs f173935
Update FMPose3D_train.sh to use model_type for model selection instea…
xiu-cs 4f07856
Merge branch 'feat/add_api' into ti_video_demo
xiu-cs f878e7f
Add model_type argument to opts for model registry selection
xiu-cs 321b28a
Remove unnecessary comment block in HRNet implementation file
xiu-cs ea1d3f7
Import animal models to ensure their registration in the model registry.
xiu-cs dd5be5d
Register Model class for animal3D in the model registry and update in…
xiu-cs a7e25e9
Refactor main_animal3d.py to load model from the registered model reg…
xiu-cs f1edf44
Update test_animal3d.sh to modify eval_sample_steps, change model_typ…
xiu-cs 51b14ab
Update train_animal3d.sh to modify eval_sample_steps, change model_ty…
xiu-cs 4702481
Refactor vis_animals.py to load model using get_model from the regist…
xiu-cs b02deda
Update vis_animals.sh to set model_type for FMPose3D and adjust saved…
xiu-cs 16f340c
Update README.md with new download link for pre-trained model
xiu-cs 0512c27
Refactor backup file handling in main_animal3d.py to enable file copy…
xiu-cs a440a08
Update test_animal3d.sh to change test dataset path and adjust script…
xiu-cs 1f7bfa6
Add default FMPose3DConfig per model_type
deruyter92 6055786
Update FMPose3DConfig and add SuperAnimalConfig
deruyter92 f90ce60
expose model registry in main package
deruyter92 80b2420
update .gitignore: ignore predictions
deruyter92 8f278ec
Update model_type to "fmpose3d_humans" across configuration and model…
xiu-cs 0138da6
Update model_type to "fmpose3d_humans" in demo and script files for c…
xiu-cs
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,4 +49,5 @@ htmlcov/ | |
| # Excluded directories | ||
| pre_trained_models/ | ||
| demo/predictions/ | ||
| demo/images/ | ||
| demo/images/ | ||
| **/predictions/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,8 +9,7 @@ | |
|
|
||
| import math | ||
| from dataclasses import dataclass, field, fields, asdict | ||
| from typing import List | ||
|
|
||
| from typing import Dict, List | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Dataclass configuration groups | ||
|
|
@@ -20,24 +19,63 @@ | |
| @dataclass | ||
| class ModelConfig: | ||
| """Model architecture configuration.""" | ||
| model_type: str = "fmpose3d" | ||
| model_type: str = "fmpose3d_humans" | ||
|
|
||
|
|
||
| # Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE. | ||
| # Also consumed by PipelineConfig.for_model_type to set cross-config | ||
| # values (dataset, sample_steps, etc.). | ||
| _FMPOSE3D_DEFAULTS: Dict[str, Dict] = { | ||
| "fmpose3d_humans": { | ||
| "n_joints": 17, | ||
| "out_joints": 17, | ||
| "dataset": "h36m", | ||
| "sample_steps": 3, | ||
| "joints_left": [4, 5, 6, 11, 12, 13], | ||
| "joints_right": [1, 2, 3, 14, 15, 16], | ||
| "root_joint": 0, | ||
| }, | ||
| "fmpose3d_animals": { | ||
| "n_joints": 26, | ||
| "out_joints": 26, | ||
| "dataset": "animal3d", | ||
| "sample_steps": 5, | ||
| "joints_left": [0, 3, 5, 8, 10, 12, 14, 16, 20, 22], | ||
| "joints_right": [1, 4, 6, 9, 11, 13, 15, 17, 21, 23], | ||
| "root_joint": 7, | ||
| }, | ||
| } | ||
|
|
||
| # Sentinel object for defaults that are inferred from the model type. | ||
| INFER_FROM_MODEL_TYPE = object() | ||
|
|
||
| @dataclass | ||
| class FMPose3DConfig(ModelConfig): | ||
| model_type: str = "fmpose3d_humans" | ||
| model: str = "" | ||
| model_type: str = "fmpose3d" | ||
| layers: int = 3 | ||
| layers: int = 5 | ||
| channel: int = 512 | ||
| d_hid: int = 1024 | ||
| token_dim: int = 256 | ||
| n_joints: int = 17 | ||
| out_joints: int = 17 | ||
| n_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment] | ||
| out_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment] | ||
| joints_left: List[int] = INFER_FROM_MODEL_TYPE # type: ignore[assignment] | ||
| joints_right: List[int] = INFER_FROM_MODEL_TYPE # type: ignore[assignment] | ||
| root_joint: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment] | ||
|
Review comment (Collaborator) on lines +60 to +64: this is added now, so you can load an FMPose3DConfig from model_type with the appropriate number of joints, etc.
||
| in_channels: int = 2 | ||
| out_channels: int = 3 | ||
| frames: int = 1 | ||
| """Optional: load model class from a specific file path.""" | ||
|
|
||
| def __post_init__(self): | ||
| defaults = _FMPOSE3D_DEFAULTS.get(self.model_type) | ||
| if defaults is None: | ||
| supported = ", ".join(sorted(_FMPOSE3D_DEFAULTS)) | ||
| raise ValueError( | ||
| f"Unknown model_type {self.model_type!r}; supported: {supported}" | ||
| ) | ||
| for f in fields(self): | ||
| if getattr(self, f.name) is INFER_FROM_MODEL_TYPE: | ||
| setattr(self, f.name, defaults[f.name]) | ||
|
|
||
| @dataclass | ||
| class DatasetConfig: | ||
|
|
@@ -178,6 +216,33 @@ class HRNetConfig(Pose2DConfig): | |
| hrnet_weights_path: str = "" | ||
|
|
||
|
|
||
| @dataclass | ||
| class SuperAnimalConfig(Pose2DConfig): | ||
| """DeepLabCut SuperAnimal 2D pose detector configuration. | ||
|
|
||
| Uses the DeepLabCut ``superanimal_analyze_images`` API to detect | ||
| animal keypoints in the quadruped80K format, then maps them to the | ||
| Animal3D 26-keypoint layout expected by the ``fmpose3d_animals`` | ||
| 3D lifter. | ||
|
|
||
| Attributes | ||
| ---------- | ||
| superanimal_name : str | ||
| Name of the SuperAnimal model (default ``"superanimal_quadruped"``). | ||
| sa_model_name : str | ||
| Backbone architecture (default ``"hrnet_w32"``). | ||
| detector_name : str | ||
| Object detector used for animal bounding boxes. | ||
| max_individuals : int | ||
| Maximum number of individuals to detect per image (default 1). | ||
| """ | ||
| pose2d_model: str = "superanimal" | ||
| superanimal_name: str = "superanimal_quadruped" | ||
| sa_model_name: str = "hrnet_w32" | ||
| detector_name: str = "fasterrcnn_resnet50_fpn_v2" | ||
| max_individuals: int = 1 | ||
|
|
||
|
|
||
| @dataclass | ||
| class DemoConfig: | ||
| """Demo / inference configuration.""" | ||
|
|
@@ -239,8 +304,6 @@ class PipelineConfig: | |
| demo_cfg: DemoConfig = field(default_factory=DemoConfig) | ||
| runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig) | ||
|
|
||
| # -- construction from argparse namespace --------------------------------- | ||
|
|
||
| @classmethod | ||
| def from_namespace(cls, ns) -> "PipelineConfig": | ||
| """Build a :class:`PipelineConfig` from an ``argparse.Namespace`` | ||
|
|
@@ -258,10 +321,14 @@ def _pick(dc_class, src: dict): | |
|
|
||
| kwargs = {} | ||
| for group_name, dc_class in _SUB_CONFIG_CLASSES.items(): | ||
| if group_name == "model_cfg" and raw.get("model_type", "fmpose3d") == "fmpose3d": | ||
| if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d_humans') in _FMPOSE3D_DEFAULTS: | ||
| dc_class = FMPose3DConfig | ||
| elif group_name == "pose2d_cfg" and raw.get("pose2d_model", "hrnet") == "hrnet": | ||
| dc_class = HRNetConfig | ||
| elif group_name == "pose2d_cfg": | ||
| p2d = raw.get("pose2d_model", "hrnet") | ||
| if p2d == "superanimal": | ||
| dc_class = SuperAnimalConfig | ||
| elif p2d == "hrnet": | ||
| dc_class = HRNetConfig | ||
| kwargs[group_name] = _pick(dc_class, raw) | ||
| return cls(**kwargs) | ||
|
|
||
|
|
||
Are you happy with the names "fmpose3d" for humans and "fmpose3d_animals" for animals? Or would you prefer to rename "fmpose3d" to "fmpose3d_humans" or something similar?
This would be a good moment to choose the final names, before they are used in other code as well.
Yes, I will change the name of the human model to "fmpose3d_humans".