diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 242ec58..28d8b6c 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -43,7 +43,7 @@ jobs: df -h - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v4 # uses: actions/checkout@v6 - name: Install uv uses: astral-sh/setup-uv@v6 @@ -53,9 +53,9 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install the project - run: uv sync --no-cache --all-extras --dev + run: uv sync --all-extras --dev shell: bash - + - name: Install ffmpeg run: | if [ "$RUNNER_OS" == "Linux" ]; then @@ -67,9 +67,27 @@ jobs: choco install ffmpeg fi shell: bash - - - name: Run DLC Live Tests + + - name: Run Model Benchmark Test run: uv run dlc-live-test --nodisplay - - name: Run Functional Benchmark Test + - name: Run DLC Live Unit Tests run: uv run pytest + # - name: Run DLC Live Unit Tests + # run: uv run pytest --cov=dlclive --cov-report=xml --cov-report=term-missing + + # - name: Coverage Report + # uses: codecov/codecov-action@v5 + # with: + # files: ./coverage.xml + # flags: ${{ matrix.os }}-py${{ matrix.python-version }} + # name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} + # - name: Add coverage to job summary + # if: always() + # shell: bash + # run: | + # uv run python -m coverage report -m > coverage.txt + # echo "## Coverage (dlclive)" >> "$GITHUB_STEP_SUMMARY" + # echo '```' >> "$GITHUB_STEP_SUMMARY" + # cat coverage.txt >> "$GITHUB_STEP_SUMMARY" + # echo '```' >> "$GITHUB_STEP_SUMMARY" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2e5ed7b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-docstring-first + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/asottile/setup-cfg-fmt + rev: v3.2.0 + hooks: + - id: setup-cfg-fmt + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.10 + hooks: + # Run the formatter. + - id: ruff-format + # Run the linter. 
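+      # Note: --unsafe-fixes also applies fixes Ruff marks as unsafe (they may not preserve behavior).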
+ - id: ruff-check + args: [--fix,--unsafe-fixes] diff --git a/dlclive/core/inferenceutils.py b/dlclive/core/inferenceutils.py index 81d9d43..7b76b27 100644 --- a/dlclive/core/inferenceutils.py +++ b/dlclive/core/inferenceutils.py @@ -8,6 +8,10 @@ # # Licensed under GNU Lesser General Public License v3.0 # + + +# NOTE - DUPLICATED @C-Achard 2026-01-26: Copied from the original DeepLabCut codebase +# from deeplabcut/core/inferenceutils.py from __future__ import annotations import heapq @@ -17,9 +21,10 @@ import pickle import warnings from collections import defaultdict +from collections.abc import Iterable from dataclasses import dataclass from math import erf, sqrt -from typing import Any, Iterable, Tuple +from typing import Any import networkx as nx import numpy as np @@ -41,7 +46,7 @@ def _conv_square_to_condensed_indices(ind_row, ind_col, n): return n * ind_col - ind_col * (ind_col + 1) // 2 + ind_row - 1 - ind_col -Position = Tuple[float, float] +Position = tuple[float, float] @dataclass(frozen=True) @@ -155,7 +160,7 @@ def soft_identity(self): unq, idx, cnt = np.unique(data[:, 3], return_inverse=True, return_counts=True) avg = np.bincount(idx, weights=data[:, 2]) / cnt soft = softmax(avg) - return dict(zip(unq.astype(int), soft)) + return dict(zip(unq.astype(int), soft, strict=False)) @property def affinity(self): @@ -262,7 +267,8 @@ def __init__( self._has_identity = "identity" in self[0] if identity_only and not self._has_identity: warnings.warn( - "The network was not trained with identity; setting `identity_only` to False." + "The network was not trained with identity; setting `identity_only` to False.", + stacklevel=2, ) self.identity_only = identity_only & self._has_identity self.nan_policy = nan_policy @@ -344,7 +350,9 @@ def calibrate(self, train_data_file): pass n_bpts = len(df.columns.get_level_values("bodyparts").unique()) if n_bpts == 1: - warnings.warn("There is only one keypoint; skipping calibration...") + warnings.warn( + "There is only one keypoint; skipping calibration...", stacklevel=2 + ) return xy = df.to_numpy().reshape((-1, n_bpts, 2)) @@ -352,7 +360,9 @@ def calibrate(self, train_data_file): # Only keeps skeletons that are more than 90% complete xy = xy[frac_valid >= 0.9] if not xy.size: - warnings.warn("No complete poses were found. Skipping calibration...") + warnings.warn( + "No complete poses were found. Skipping calibration...", stacklevel=2 + ) return # TODO Normalize dists by longest length? @@ -369,7 +379,8 @@ def calibrate(self, train_data_file): except np.linalg.LinAlgError: # Covariance matrix estimation fails due to numerical singularities warnings.warn( - "The assembler could not be robustly calibrated. Continuing without it..." + "The assembler could not be robustly calibrated. 
Continuing without it...", + stacklevel=2, ) def calc_assembly_mahalanobis_dist( @@ -428,10 +439,12 @@ def _flatten_detections(data_dict): ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence] else: ids = [arr.argmax(axis=1) for arr in ids] - for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids)): + for i, (coords, conf, id_) in enumerate( + zip(coordinates, confidence, ids, strict=False) + ): if not np.any(coords): continue - for xy, p, g in zip(coords, conf, id_): + for xy, p, g in zip(coords, conf, id_, strict=False): joint = Joint(tuple(xy), p.item(), i, ind, g) ind += 1 yield joint @@ -474,13 +487,13 @@ def extract_best_links(self, joints_dict, costs, trees=None): (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity) ) candidates = sorted( - zip(rows, cols, aff[rows, cols], lengths[rows, cols]), + zip(rows, cols, aff[rows, cols], lengths[rows, cols], strict=False), key=lambda x: x[2], reverse=True, ) i_seen = set() j_seen = set() - for i, j, w, l in candidates: + for i, j, w, _l in candidates: if i not in i_seen and j not in j_seen: i_seen.add(i) j_seen.add(j) @@ -502,7 +515,7 @@ def extract_best_links(self, joints_dict, costs, trees=None): ] aff = aff[np.ix_(keep_s, keep_t)] rows, cols = linear_sum_assignment(aff, maximize=True) - for row, col in zip(rows, cols): + for row, col in zip(rows, cols, strict=False): w = aff[row, col] if w >= self.min_affinity: links.append(Link(dets_s[keep_s[row]], dets_t[keep_t[col]], w)) @@ -548,9 +561,9 @@ def push_to_stack(i): d = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy) if d < d_old: push_to_stack(new_ind) - if tabu: - _, _, link = heapq.heappop(tabu) - heapq.heappush(stack, (-link.affinity, next(counter), link)) + if tabu: + _, _, link = heapq.heappop(tabu) + heapq.heappush(stack, (-link.affinity, next(counter), link)) else: heapq.heappush(tabu, (d - d_old, next(counter), best)) assembly.__dict__.update(assembly._dict) @@ -665,7 +678,7 @@ def build_assemblies(self, links): for idx in store[j]._idx: store[idx] = store[i] except KeyError: - # Some links may reference indices that were never added to `store`; + # Some links may reference indices that were never added to `store`; # in that case we intentionally skip merging for this link pass @@ -791,7 +804,7 @@ def _assemble(self, data_dict, ind_frame): ] else: scores = [ass._affinity for ass in assemblies] - lst = list(zip(scores, assemblies)) + lst = list(zip(scores, assemblies, strict=False)) assemblies = [] while lst: temp = max(lst, key=lambda x: x[0]) @@ -1074,7 +1087,7 @@ def match_assemblies( if ~np.isnan(oks): mat[i, j] = oks rows, cols = linear_sum_assignment(mat, maximize=True) - for row, col in zip(rows, cols): + for row, col in zip(rows, cols, strict=False): matched[row].ground_truth = ground_truth[col] matched[row].oks = mat[row, col] _ = inds_true.remove(col) @@ -1087,7 +1100,7 @@ def parse_ground_truth_data_file(h5_file): try: df.drop("single", axis=1, level="individuals", inplace=True) except KeyError: - # Ignore if the "single" individual column is absent + # Ignore if the "single" individual column is absent pass # Cast columns of dtype 'object' to float to avoid TypeError # further down in _parse_ground_truth_data. 
@@ -1128,7 +1141,7 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)): for frame_ind, assemblies in dict_of_assemblies.items(): for assembly in assemblies: tuples.append((frame_ind, getattr(assembly, criterion))) - frame_inds, vals = zip(*tuples) + frame_inds, vals = zip(*tuples, strict=False) vals = np.asarray(vals) lo, up = np.percentile(vals, qs, interpolation="nearest") inds = np.flatnonzero((vals < lo) | (vals > up)).tolist() @@ -1246,12 +1259,14 @@ def evaluate_assembly( ass_pred_dict, ass_true_dict, oks_sigma=0.072, - oks_thresholds=np.linspace(0.5, 0.95, 10), + oks_thresholds=None, margin=0, symmetric_kpts=None, greedy_matching=False, with_tqdm: bool = True, ): + if oks_thresholds is None: + oks_thresholds = np.linspace(0.5, 0.95, 10) if greedy_matching: return evaluate_assembly_greedy( ass_true_dict, diff --git a/dlclive/display.py b/dlclive/display.py index 0d1c924..42abab4 100644 --- a/dlclive/display.py +++ b/dlclive/display.py @@ -7,7 +7,9 @@ try: from tkinter import Label, Tk + from PIL import ImageTk + _TKINTER_AVAILABLE = True except ImportError: _TKINTER_AVAILABLE = False @@ -59,7 +61,9 @@ def set_display(self, im_size, bodyparts): self.lab.pack() all_colors = getattr(cc, self.cmap) - self.colors = all_colors[:: int(len(all_colors) / bodyparts)] + # Avoid 0 step + step = max(1, int(len(all_colors) / bodyparts)) + self.colors = all_colors[::step] def display_frame(self, frame, pose=None): """ @@ -75,10 +79,10 @@ def display_frame(self, frame, pose=None): """ if not _TKINTER_AVAILABLE: raise ImportError("tkinter is not available. Cannot display frames.") - + im_size = (frame.shape[1], frame.shape[0]) + img = Image.fromarray(frame) # avoid undefined image if pose is None if pose is not None: - img = Image.fromarray(frame) draw = ImageDraw.Draw(img) if len(pose.shape) == 2: @@ -91,33 +95,16 @@ def display_frame(self, frame, pose=None): for j in range(pose.shape[1]): if pose[i, j, 2] > self.pcutoff: try: - x0 = ( - pose[i, j, 0] - self.radius - if pose[i, j, 0] - self.radius > 0 - else 0 - ) - x1 = ( - pose[i, j, 0] + self.radius - if pose[i, j, 0] + self.radius < im_size[0] - else im_size[1] - ) - y0 = ( - pose[i, j, 1] - self.radius - if pose[i, j, 1] - self.radius > 0 - else 0 - ) - y1 = ( - pose[i, j, 1] + self.radius - if pose[i, j, 1] + self.radius < im_size[1] - else im_size[0] - ) + x0 = max(0, pose[i, j, 0] - self.radius) + x1 = min(im_size[0], pose[i, j, 0] + self.radius) + y0 = max(0, pose[i, j, 1] - self.radius) + y1 = min(im_size[1], pose[i, j, 1] + self.radius) coords = [x0, y0, x1, y1] draw.ellipse( coords, fill=self.colors[j], outline=self.colors[j] ) except Exception as e: print(e) - img_tk = ImageTk.PhotoImage(image=img, master=self.window) self.lab.configure(image=img_tk) self.window.update() diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index 1dcb88f..965aab6 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -4,6 +4,7 @@ Licensed under GNU Lesser General Public License v3.0 """ + from __future__ import annotations from pathlib import Path @@ -197,12 +198,12 @@ def __init__( self.processor = processor self.convert2rgb = convert2rgb + self.pose: np.ndarray | None = None + if isinstance(display, Display): self.display = display elif display: - self.display = Display( - pcutoff=pcutoff, radius=display_radius, cmap=display_cmap - ) + self.display = Display(pcutoff=pcutoff, radius=display_radius, cmap=display_cmap) else: self.display = None @@ -250,9 +251,7 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray: 
processed frame: convert type, crop, convert color """ if self.cropping: - frame = frame[ - self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1] - ] + frame = frame[self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1]] if self.dynamic[0]: if self.pose is not None: @@ -263,9 +262,7 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray: elif len(self.pose) == 1: pose = self.pose[0] else: - raise ValueError( - "Cannot use Dynamic Cropping - more than 1 individual found" - ) + raise ValueError("Cannot use Dynamic Cropping - more than 1 individual found") else: pose = self.pose diff --git a/dlclive/modelzoo/resolve_config.py b/dlclive/modelzoo/resolve_config.py index bea25f5..cf11f3b 100644 --- a/dlclive/modelzoo/resolve_config.py +++ b/dlclive/modelzoo/resolve_config.py @@ -1,9 +1,10 @@ """ -Helper function to deal with default values in the model configuration. +Helper function to deal with default values in the model configuration. For instance, "num_bodyparts x 2" is replaced with the number of bodyparts multiplied by 2. """ -# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. +# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase +# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py import copy @@ -78,8 +79,7 @@ def get_updated_value(variable: str) -> int | list[int]: var_name = var_parts[0] if updated_values[var_name] is None: raise ValueError( - f"Found {variable} in the configuration file, but there is no default " - f"value for this variable." + f"Found {variable} in the configuration file, but there is no default value for this variable." ) if len(var_parts) == 1: @@ -133,4 +133,4 @@ def get_updated_value(variable: str) -> int | list[int]: ): config[k] = get_updated_value(config[k]) - return config \ No newline at end of file + return config diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index 1aab80b..3857d14 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -5,55 +5,53 @@ # This should be removed once a solution is found to address duplicate code. 
import copy -from pathlib import Path import logging +from pathlib import Path +from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model from ruamel.yaml import YAML -from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model from dlclive.modelzoo.resolve_config import update_config -_MODELZOO_PATH = Path(__file__).parent +_MODELZOO_PATH = Path(__file__).parent def get_super_animal_model_config_path(model_name: str) -> Path: """Get the path to the model configuration file for a model and validate choice of model""" - cfg_path = _MODELZOO_PATH / 'model_configs' / f"{model_name}.yaml" + cfg_path = _MODELZOO_PATH / "model_configs" / f"{model_name}.yaml" if not cfg_path.exists(): raise FileNotFoundError( - f"Modelzoo model configuration file not found: {cfg_path} " - f"Available models: {list_available_models()}" + f"Modelzoo model configuration file not found: {cfg_path} Available models: {list_available_models()}" ) return cfg_path def get_super_animal_project_config_path(super_animal: str) -> Path: """Get the path to the project configuration file for a project and validate choice of project""" - cfg_path = _MODELZOO_PATH / 'project_configs' / f"{super_animal}.yaml" + cfg_path = _MODELZOO_PATH / "project_configs" / f"{super_animal}.yaml" if not cfg_path.exists(): raise FileNotFoundError( - f"Modelzoo project configuration file not found: {cfg_path}" - f"Available projects: {list_available_projects()}" + f"Modelzoo project configuration file not found: {cfg_path} Available projects: {list_available_projects()}" ) return cfg_path def get_snapshot_folder_path() -> Path: - return _MODELZOO_PATH / 'snapshots' + return _MODELZOO_PATH / "snapshots" def list_available_models() -> list[str]: - return [p.stem for p in _MODELZOO_PATH.glob('model_configs/*.yaml')] + return [p.stem for p in _MODELZOO_PATH.glob("model_configs/*.yaml")] def list_available_projects() -> list[str]: - return [p.stem for p in _MODELZOO_PATH.glob('project_configs/*.yaml')] + return [p.stem for p in _MODELZOO_PATH.glob("project_configs/*.yaml")] def list_available_combinations() -> list[str]: models = list_available_models() projects = list_available_projects() - combinations = ['_'.join([p, m]) for p in projects for m in models] + combinations = ["_".join([p, m]) for p in projects for m in models] return combinations @@ -65,14 +63,18 @@ def read_config_as_dict(config_path: str | Path) -> dict: Returns: The configuration file with pure Python classes """ - with open(config_path, "r") as f: - cfg = YAML(typ='safe', pure=True).load(f) + with open(config_path) as f: + cfg = YAML(typ="safe", pure=True).load(f) return cfg -# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. 
-def add_metadata(project_config: dict, config: dict,) -> dict: +# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase +# from deeplabcut/pose_estimation_pytorch/config/make_pose_config.py +def add_metadata( + project_config: dict, + config: dict, +) -> dict: """Adds metadata to a pytorch pose configuration Args: @@ -87,7 +89,8 @@ def add_metadata(project_config: dict, config: dict,) -> dict: config["metadata"] = { "project_path": project_config["project_path"], "pose_config_path": "", - "bodyparts": project_config.get("multianimalbodyparts") or project_config["bodyparts"], + "bodyparts": project_config.get("multianimalbodyparts") + or project_config["bodyparts"], "unique_bodyparts": project_config.get("uniquebodyparts", []), "individuals": project_config.get("individuals", ["animal"]), "with_identity": project_config.get("identity", False), @@ -95,7 +98,8 @@ def add_metadata(project_config: dict, config: dict,) -> dict: return config -# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. +# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase +# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py def load_super_animal_config( super_animal: str, model_name: str, @@ -167,9 +171,9 @@ def download_super_animal_snapshot(dataset: str, model_name: str) -> Path: raise RuntimeError(f"Failed to download {model_name} to {model_path}") except Exception as e: - logging.error(f"Failed to download superanimal snapshot {model_name} to {model_path}: {e}") + logging.error( + f"Failed to download superanimal snapshot {model_name} to {model_path}: {e}" + ) raise e return model_path - - diff --git a/dlclive/pose_estimation_pytorch/dynamic_cropping.py b/dlclive/pose_estimation_pytorch/dynamic_cropping.py index ae5991f..4572634 100644 --- a/dlclive/pose_estimation_pytorch/dynamic_cropping.py +++ b/dlclive/pose_estimation_pytorch/dynamic_cropping.py @@ -7,13 +7,16 @@ # https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS # # Licensed under GNU Lesser General Public License v3.0 -# + +# NOTE DUPLICATED @C-Achard 2026-01-26: Duplication between this file +# and deeplabcut/pose_estimation_pytorch/runners/dynamic_cropping.py +# NOTE Testing already exists at deeplabcut/tests/pose_estimation_pytorch/runners/test_dynamic_cropper.py """Modules to dynamically crop individuals out of videos to improve video analysis""" + from __future__ import annotations import math from dataclasses import dataclass, field -from typing import Optional import torch import torchvision.transforms.functional as F @@ -80,8 +83,7 @@ def crop(self, image: torch.Tensor) -> torch.Tensor: """ if len(image) != 1: raise RuntimeError( - "DynamicCropper can only be used with batch size 1 (found image " - f"shape: {image.shape})" + f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})" ) if self._shape is None: @@ -114,7 +116,7 @@ def update(self, pose: torch.Tensor) -> torch.Tensor: The pose, with coordinates updated to the full image space. 
""" if self._shape is None: - raise RuntimeError(f"You must call `crop` before calling `update`.") + raise RuntimeError("You must call `crop` before calling `update`.") # offset the pose to the original image space offset_x, offset_y = 0, 0 @@ -153,9 +155,7 @@ def reset(self) -> None: self._crop = None @staticmethod - def build( - dynamic: bool, threshold: float, margin: int - ) -> Optional["DynamicCropper"]: + def build(dynamic: bool, threshold: float, margin: int) -> DynamicCropper | None: """Builds the DynamicCropper based on the given parameters Args: @@ -310,8 +310,7 @@ def crop(self, image: torch.Tensor) -> torch.Tensor: """ if len(image) != 1: raise RuntimeError( - "DynamicCropper can only be used with batch size 1 (found image " - f"shape: {image.shape})" + f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})" ) if self._shape is None: @@ -349,7 +348,7 @@ def update(self, pose: torch.Tensor) -> torch.Tensor: The pose, with coordinates updated to the full image space. """ if self._shape is None: - raise RuntimeError(f"You must call `crop` before calling `update`.") + raise RuntimeError("You must call `crop` before calling `update`.") # check whether this was a patched crop batch_size = pose.shape[0] @@ -534,7 +533,7 @@ def split_array(size: int, n: int, overlap: int) -> list[tuple[int, int]]: segment_size = (padded_size // n) + (padded_size % n > 0) segments = [] end = overlap - for i in range(n): + for _i in range(n): start = end - overlap end = start + segment_size if end > size: diff --git a/pyproject.toml b/pyproject.toml index 951646b..8ce8c13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,11 +61,13 @@ tf = [ [dependency-groups] dev = [ "pytest", + "pytest-cov", + "hypothesis", "black", "ruff", ] -# Keep only for backward compatibility with Poetry +# Keep only for backward compatibility with Poetry # (without this section, Poetry assumes the wrong root directory of the project) [tool.poetry] packages = [ @@ -87,4 +89,14 @@ include-package-data = true include = ["dlclive*"] [tool.setuptools.package-data] -dlclive = ["check_install/*"] \ No newline at end of file +dlclive = ["check_install/*"] + +# [tool.ruff] +# lint.select = ["E", "F", "B", "I", "UP"] +# lint.ignore = ["E741"] +# target-version = "py310" +# fix = true +# line-length = 120 + +# [tool.ruff.lint.pydocstyle] +# convention = "google" diff --git a/pytest.ini b/pytest.ini index c878400..701bc93 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,9 +1,10 @@ [pytest] markers = - functional: functional tests + functional: Functional tests (high-level) + slow: Slow tests filterwarnings = # Suppress NumPy deprecation warning from Keras/TensorFlow about np.object ignore::FutureWarning:keras.* ignore::FutureWarning:tensorflow.* - ignore:In the future `np.object` will be defined:FutureWarning \ No newline at end of file + ignore:In the future `np.object` will be defined:FutureWarning diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6c4866a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +import copy +from collections.abc import Callable +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from dlclive.core.inferenceutils import Assembler + + +# -------------------------------------------------------------------------------------- +# Headless display fixture +# 
-------------------------------------------------------------------------------------- +@pytest.fixture +def headless_display_env(monkeypatch): + """ + Patch dlclive.display so tkinter + ImageTk are replaced by MagicMocks. + + Returns an object with: + - mod: the imported dlclive.display module + - tk_ctor: MagicMock constructor for Tk + - tk: MagicMock instance for the window + - label_ctor: MagicMock constructor for Label + - label: MagicMock instance for the label widget + - photo_ctor: MagicMock function for ImageTk.PhotoImage + - photo: MagicMock instance representing created image + """ + import dlclive.display as display_mod + + # Ensure display path is enabled + monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True, raising=False) + + # Tk / Label mocks + tk = MagicMock(name="TkInstance") + tk_ctor = MagicMock(name="Tk", return_value=tk) + + label = MagicMock(name="LabelInstance") + label_ctor = MagicMock(name="Label", return_value=label) + + # ImageTk.PhotoImage mock + photo = MagicMock(name="PhotoImageInstance") + photo_ctor = MagicMock(name="PhotoImage", return_value=photo) + + class FakeImageTkModule: + PhotoImage = photo_ctor + + monkeypatch.setattr(display_mod, "Tk", tk_ctor, raising=False) + monkeypatch.setattr(display_mod, "Label", label_ctor, raising=False) + monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule, raising=False) + + return SimpleNamespace( + mod=display_mod, + tk_ctor=tk_ctor, + tk=tk, + label_ctor=label_ctor, + label=label, + photo_ctor=photo_ctor, + photo=photo, + ) + + +# -------------------------------------------------------------------------------------- +# Assembler/assembly test fixtures +# -------------------------------------------------------------------------------------- +@pytest.fixture +def assembler_graph_and_pafs() -> SimpleNamespace: + """Standard 2‑joint graph used throughout the test suite.""" + graph = [(0, 1)] + paf_inds = [0] + return SimpleNamespace(graph=graph, paf_inds=paf_inds) + + +@pytest.fixture +def make_assembler_metadata() -> Callable[..., dict[str, Any]]: + """Return a factory that builds minimal Assembler metadata dictionaries.""" + + def _factory(graph, paf_inds, n_bodyparts, frame_keys): + return { + "metadata": { + "all_joints_names": [f"b{i}" for i in range(n_bodyparts)], + "PAFgraph": graph, + "PAFinds": paf_inds, + }, + **{k: {} for k in frame_keys}, + } + + return _factory + + +@pytest.fixture +def make_assembler_frame() -> Callable[..., dict[str, Any]]: + """Return a factory that builds a frame dict compatible with _flatten_detections.""" + + def _factory( + coordinates_per_label, + confidence_per_label, + identity_per_label=None, + costs=None, + ): + frame = { + "coordinates": [coordinates_per_label], + "confidence": confidence_per_label, + "costs": costs or {}, + } + if identity_per_label is not None: + frame["identity"] = identity_per_label + return frame + + return _factory + + +@pytest.fixture +def simple_two_label_scene(make_assembler_frame) -> dict[str, Any]: + """Deterministic scene with predictable affinities for testing.""" + coords0 = np.array([[0.0, 0.0], [100.0, 100.0]]) + coords1 = np.array([[5.0, 0.0], [110.0, 100.0]]) + conf0 = np.array([0.9, 0.6]) + conf1 = np.array([0.8, 0.7]) + + aff = np.array([[0.95, 0.1], [0.05, 0.9]]) + + lens = np.array( + [ + [np.hypot(*(coords1[0] - coords0[0])), np.hypot(*(coords1[1] - coords0[0]))], + [np.hypot(*(coords1[0] - coords0[1])), np.hypot(*(coords1[1] - coords0[1]))], + ] + ) + + return make_assembler_frame( + coordinates_per_label=[coords0, 
coords1], + confidence_per_label=[conf0, conf1], + identity_per_label=None, + costs={0: {"distance": lens, "m1": aff}}, + ) + + +@pytest.fixture +def scene_copy(simple_two_label_scene) -> dict[str, Any]: + """Return a deep copy of the simple_two_label_scene fixture.""" + return copy.deepcopy(simple_two_label_scene) + + +@pytest.fixture +def assembler_data( + assembler_graph_and_pafs, + make_assembler_metadata, + simple_two_label_scene, +) -> SimpleNamespace: + """Full metadata + two identical frames ('0', '1').""" + paf = assembler_graph_and_pafs + data = make_assembler_metadata(paf.graph, paf.paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) + data["0"] = simple_two_label_scene + data["1"] = simple_two_label_scene + return SimpleNamespace(data=data, graph=paf.graph, paf_inds=paf.paf_inds) + + +@pytest.fixture +def assembler_data_single_frame( + assembler_graph_and_pafs, + make_assembler_metadata, + simple_two_label_scene, +) -> SimpleNamespace: + """Metadata + a single frame ('0'). Used by most tests.""" + paf = assembler_graph_and_pafs + data = make_assembler_metadata(paf.graph, paf.paf_inds, n_bodyparts=2, frame_keys=["0"]) + data["0"] = simple_two_label_scene + return SimpleNamespace(data=data, graph=paf.graph, paf_inds=paf.paf_inds) + + +@pytest.fixture +def assembler_data_two_frames_nudged( + assembler_graph_and_pafs, + make_assembler_metadata, + simple_two_label_scene, +) -> SimpleNamespace: + """Two frames where frame '1' is a nudged copy of frame '0'.""" + paf = assembler_graph_and_pafs + data = make_assembler_metadata(paf.graph, paf.paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) + + frame0 = simple_two_label_scene + frame1 = copy.deepcopy(simple_two_label_scene) + frame1["coordinates"][0][0] += np.array([[1.0, 0.0], [1.0, 0.0]]) + frame1["coordinates"][0][1] += np.array([[1.0, 0.0], [1.0, 0.0]]) + + data["0"] = frame0 + data["1"] = frame1 + return SimpleNamespace(data=data, graph=paf.graph, paf_inds=paf.paf_inds) + + +@pytest.fixture +def assembler_data_no_detections( + assembler_graph_and_pafs, + make_assembler_metadata, + make_assembler_frame, +) -> SimpleNamespace: + """Metadata + a single frame ('0') with zero detections for both labels.""" + paf = assembler_graph_and_pafs + data = make_assembler_metadata(paf.graph, paf.paf_inds, n_bodyparts=2, frame_keys=["0"]) + + frame = make_assembler_frame( + coordinates_per_label=[np.zeros((0, 2)), np.zeros((0, 2))], + confidence_per_label=[np.zeros((0,)), np.zeros((0,))], + identity_per_label=None, + costs={}, + ) + data["0"] = frame + return SimpleNamespace(data=data, graph=paf.graph, paf_inds=paf.paf_inds) + + +@pytest.fixture +def make_assembler() -> Callable[..., Assembler]: + """ + Factory to create an Assembler with sensible defaults for this test suite. + Override any parameter per-test via kwargs.
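+ Defaults: max_n_individuals=2, n_multibodyparts=2, min_n_links=1, pcutoff=0.1, min_affinity=0.05.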
+ """ + + def _factory(data: dict[str, Any], **overrides) -> Assembler: + defaults = dict( + max_n_individuals=2, + n_multibodyparts=2, + min_n_links=1, + pcutoff=0.1, + min_affinity=0.05, + ) + defaults.update(overrides) + return Assembler(data, **defaults) + + return _factory + + +# -------------------------------------------------------------------------------------- +# Assembly / Joint / Link test fixtures +# -------------------------------------------------------------------------------------- +from dlclive.core.inferenceutils import Assembly, Joint, Link # noqa: E402 + + +@pytest.fixture +def make_assembly() -> Callable[..., Assembly]: + """Factory to create an Assembly with the given size.""" + + def _factory(size: int) -> Assembly: + return Assembly(size=size) + + return _factory + + +@pytest.fixture +def make_joint() -> Callable[..., Joint]: + """Factory to create a Joint with sensible defaults.""" + + def _factory( + pos=(0.0, 0.0), + confidence: float = 1.0, + label: int = 0, + idx: int = 0, + group: int = -1, + ) -> Joint: + return Joint(pos=pos, confidence=confidence, label=label, idx=idx, group=group) + + return _factory + + +@pytest.fixture +def make_link() -> Callable[..., Link]: + """Factory to create a Link between two joints.""" + + def _factory(j1: Joint, j2: Joint, affinity: float = 1.0) -> Link: + return Link(j1, j2, affinity=affinity) + + return _factory + + +@pytest.fixture +def two_overlap_assemblies(make_assembly) -> tuple[Assembly, Assembly]: + """Two assemblies with partial overlap used by intersection tests.""" + assemb1 = make_assembly(2) + assemb1.data[0, :2] = [0, 0] + assemb1.data[1, :2] = [10, 10] + assemb1._visible.update({0, 1}) + + assemb2 = make_assembly(2) + assemb2.data[0, :2] = [5, 5] + assemb2.data[1, :2] = [15, 15] + assemb2._visible.update({0, 1}) + return assemb1, assemb2 + + +@pytest.fixture +def soft_identity_assembly(make_assembly) -> Assembly: + """Assembly configured for soft_identity tests.""" + assemb = make_assembly(3) + assemb.data[:] = np.nan + assemb.data[0] = [0, 0, 1.0, 0] + assemb.data[1] = [5, 5, 0.5, 0] + assemb.data[2] = [10, 10, 1.0, 1] + assemb._visible = {0, 1, 2} + return assemb + + +@pytest.fixture +def four_joint_chain(make_joint, make_link) -> SimpleNamespace: + """Four joints and two links: (0-1) and (2-3).""" + j0 = make_joint((0, 0), 1.0, label=0, idx=10) + j1 = make_joint((1, 0), 1.0, label=1, idx=11) + j2 = make_joint((2, 0), 1.0, label=2, idx=12) + j3 = make_joint((3, 0), 1.0, label=3, idx=13) + l01 = make_link(j0, j1, affinity=0.5) + l23 = make_link(j2, j3, affinity=0.8) + return SimpleNamespace(j0=j0, j1=j1, j2=j2, j3=j3, l01=l01, l23=l23) diff --git a/tests/test_benchmark_script.py b/tests/test_benchmark_script.py index 58c1533..875e63d 100644 --- a/tests/test_benchmark_script.py +++ b/tests/test_benchmark_script.py @@ -1,5 +1,7 @@ import glob + import pytest + from dlclive import benchmark_videos, download_benchmarking_data from dlclive.engine import Engine @@ -10,7 +12,9 @@ def datafolder(tmp_path): download_benchmarking_data(str(datafolder)) return datafolder + @pytest.mark.functional +@pytest.mark.slow def test_benchmark_script_runs_tf_backend(tmp_path, datafolder): dog_models = glob.glob(str(datafolder / "dog" / "*[!avi]")) dog_video = glob.glob(str(datafolder / "dog" / "*.avi"))[0] @@ -27,11 +31,7 @@ def test_benchmark_script_runs_tf_backend(tmp_path, datafolder): print(f"Running dog model: {model_path}") benchmark_videos( model_path=model_path, - model_type=( - "base" - if 
Engine.from_model_path(model_path) == Engine.TENSORFLOW - else "pytorch" - ), + model_type=("base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch"), video_path=dog_video, output=str(out_dir), n_frames=n_frames, @@ -42,11 +42,7 @@ def test_benchmark_script_runs_tf_backend(tmp_path, datafolder): print(f"Running mouse model: {model_path}") benchmark_videos( model_path=model_path, - model_type=( - "base" - if Engine.from_model_path(model_path) == Engine.TENSORFLOW - else "pytorch" - ), + model_type=("base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch"), video_path=mouse_video, output=str(out_dir), n_frames=n_frames, @@ -58,6 +54,7 @@ def test_benchmark_script_runs_tf_backend(tmp_path, datafolder): @pytest.mark.parametrize("model_name", ["hrnet_w32", "resnet_50"]) @pytest.mark.functional +@pytest.mark.slow def test_benchmark_script_with_torch_modelzoo(tmp_path, datafolder, model_name): from dlclive import modelzoo @@ -107,4 +104,4 @@ def test_benchmark_script_with_torch_modelzoo(tmp_path, datafolder, model_name): # Assertions: verify output files were created output_files = list(out_dir.iterdir()) assert len(output_files) > 0, "No output files were created by benchmark_videos" - assert any(f.suffix == ".pickle" for f in output_files), "No pickle files found in output directory" \ No newline at end of file + assert any(f.suffix == ".pickle" for f in output_files), "No pickle files found in output directory" diff --git a/tests/test_display.py b/tests/test_display.py new file mode 100644 index 0000000..aa6821a --- /dev/null +++ b/tests/test_display.py @@ -0,0 +1,126 @@ +from unittest.mock import ANY, MagicMock + +import numpy as np +import pytest + + +def test_display_init_raises_when_tk_unavailable(monkeypatch): + import dlclive.display as display_mod + + monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", False, raising=False) + + with pytest.raises(ImportError): + display_mod.Display() + + +def test_display_frame_creates_window_and_updates(headless_display_env): + env = headless_display_env + display_mod = env.mod + disp = display_mod.Display(radius=3, pcutoff=0.5) + + frame = np.zeros((100, 120, 3), dtype=np.uint8) + pose = np.array([[[10, 10, 0.9], [50, 50, 0.2]]]) # 1 animal, 2 bodyparts + + disp.display_frame(frame, pose) + + # Window created and initialized + env.tk_ctor.assert_called_once_with() + env.tk.title.assert_called_once_with("DLC Live") + + # Label created and packed + env.label_ctor.assert_called_once_with(env.tk) + env.label.pack.assert_called_once() + + # PhotoImage created with correct master + image passed + env.photo_ctor.assert_called_once_with(image=ANY, master=env.tk) + + # Image configured on label and window updated + env.label.configure.assert_called_once_with(image=env.photo) + env.tk.update.assert_called_once_with() + + +def test_display_draws_only_points_above_cutoff_with_clamping( + headless_display_env, monkeypatch +): + env = headless_display_env + display_mod = env.mod + disp = display_mod.Display(radius=3, pcutoff=0.5) + r = disp.radius + + # Fake colors + class FakeCC: + bmy = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] + + monkeypatch.setattr(display_mod, "cc", FakeCC) + + frame = np.zeros((50, 50, 3), dtype=np.uint8) + h, w = frame.shape[:2] + + pose = np.array( + [ + [ + [-1, -1, 0.9], # top-left offscreen + [48, 48, 0.9], # bottom-right edge + [25, 25, 0.4], # below cutoff + ] + ], + dtype=float, + ) + + draw = MagicMock() + monkeypatch.setattr(display_mod.ImageDraw, "Draw", MagicMock(return_value=draw)) + + 
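# ImageDraw.Draw is patched above, so the clamped ellipse coordinates below can be asserted directly. +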
disp.display_frame(frame, pose) + + assert draw.ellipse.call_count == 2 + calls = draw.ellipse.call_args_list + + def expected_coords(x, y): + return [ + max(0, x - r), + max(0, y - r), + min(w, x + r), + min(h, y + r), + ] + + # First point + assert calls[0].args[0] == expected_coords(-1, -1) + + # Second point + assert calls[1].args[0] == expected_coords(48, 48) + + +def test_destroy_calls_window_destroy(headless_display_env): + env = headless_display_env + display_mod = env.mod + disp = display_mod.Display() + + frame = np.zeros((10, 10, 3), dtype=np.uint8) + pose = np.array([[[5, 5, 0.9]]]) + + disp.display_frame(frame, pose) + disp.destroy() + + env.tk.destroy.assert_called_once_with() + + +def test_set_display_color_sampling_safe(headless_display_env, monkeypatch): + env = headless_display_env + display_mod = env.mod + + class FakeCC: + bmy = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0), (0, 1, 1), (1, 0, 1)] + + monkeypatch.setattr(display_mod, "cc", FakeCC) + + disp = display_mod.Display(cmap="bmy") + disp.set_display(im_size=(100, 100), bodyparts=3) + + assert disp.colors is not None + assert len(disp.colors) >= 3 + + # Also verify window setup calls happened + env.tk_ctor.assert_called_once_with() + env.tk.title.assert_called_once_with("DLC Live") + env.label_ctor.assert_called_once_with(env.tk) + env.label.pack.assert_called_once() diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py index 1a7e6db..c2a0d70 100644 --- a/tests/test_modelzoo.py +++ b/tests/test_modelzoo.py @@ -1,9 +1,9 @@ -# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. - +# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase +# from deeplabcut/tests/pose_estimation_pytorch/modelzoo/test_modelzoo_utils.py import os -import pytest import dlclibrary +import pytest from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS from dlclive import modelzoo @@ -48,4 +48,4 @@ def test_download_huggingface_wrong_model(): @pytest.mark.skip(reason="slow") @pytest.mark.parametrize("model", MODELOPTIONS) def test_download_all_models(tmp_path_factory, model): - test_download_huggingface_model(tmp_path_factory, model) \ No newline at end of file + test_download_huggingface_model(tmp_path_factory, model) diff --git a/tests/tests_core/test_assembler.py b/tests/tests_core/test_assembler.py new file mode 100644 index 0000000..fe96cac --- /dev/null +++ b/tests/tests_core/test_assembler.py @@ -0,0 +1,453 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import pandas as pd +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays + +from dlclive.core.inferenceutils import Assembler, Assembly, Joint, Link, _conv_square_to_condensed_indices + +HYPOTHESIS_SETTINGS = settings(max_examples=300, deadline=None) + + +def _bag_from_frame(frame: dict) -> dict[int, list]: + """Build joints bag {label: [Joint, ...]} from a single frame.""" + bag: dict[int, list] = {} + for j in Assembler._flatten_detections(frame): + bag.setdefault(j.label, []).append(j) + return bag + + +# _conv_square_to_condensed_indices +@HYPOTHESIS_SETTINGS +@given( + n=st.integers(min_value=2, max_value=50), + i=st.integers(min_value=0, max_value=49), + j=st.integers(min_value=0, max_value=49), +) +def test_condensed_index_properties(n, i, j): + i = i % n + j = j % n + + if i == j: + with pytest.raises(ValueError): + _conv_square_to_condensed_indices(i, 
j, n) + return + + k1 = _conv_square_to_condensed_indices(i, j, n) + k2 = _conv_square_to_condensed_indices(j, i, n) + + assert k1 == k2 + assert 0 <= k1 < (n * (n - 1)) // 2 + + +# -------------------------------------------------------------------------------------- +# Basic metadata and __getitem__ +# -------------------------------------------------------------------------------------- + + +def test_parse_metadata_and_getitem(assembler_data, make_assembler): + adat = assembler_data + # Parsing + asm = make_assembler( + adat.data, + max_n_individuals=2, + n_multibodyparts=2, + ) + + assert asm.metadata["num_joints"] == 2 + assert asm.metadata["paf_graph"] == adat.graph + assert list(asm.metadata["paf"]) == adat.paf_inds + assert set(asm.metadata["imnames"]) == {"0", "1"} + # __getitem__ + assert "coordinates" in asm[0] + assert "confidence" in asm[0] + assert "costs" in asm[0] + + +def test_empty_classmethod(assembler_graph_and_pafs): + paf = assembler_graph_and_pafs + empty = Assembler.empty( + max_n_individuals=1, + n_multibodyparts=1, + n_uniquebodyparts=0, + graph=paf.graph, + paf_inds=paf.paf_inds, + ) + assert isinstance(empty, Assembler) + assert empty.n_keypoints == 1 + + +# -------------------------------------------------------------------------------------- +# _flatten_detections +# -------------------------------------------------------------------------------------- +def test_flatten_detections_no_identity(simple_two_label_scene): + frame = simple_two_label_scene + joints = list(Assembler._flatten_detections(frame)) + + assert len(joints) == 4 + assert sorted(j.label for j in joints) == [0, 0, 1, 1] + assert set(j.group for j in joints) == {-1} + + +def test_flatten_detections_with_identity(scene_copy): + frame = scene_copy + id0 = np.array([[10.0, 0.0], [0.0, 10.0]]) + id1 = np.array([[10.0, 0.0], [0.0, 10.0]]) + frame["identity"] = [id0, id1] + + joints = list(Assembler._flatten_detections(frame)) + groups = [j.group for j in joints] + + assert set(groups) == {0, 1} + assert groups.count(0) == 2 + assert groups.count(1) == 2 + + +@st.composite +def coords_and_conf(draw, max_n=5): + n = draw(st.integers(1, max_n)) + coords = draw( + arrays( + dtype=np.float64, + shape=(n, 2), + elements=st.floats(min_value=0.1, max_value=1000, allow_nan=False, allow_infinity=False), + ) + ) + conf = draw( + arrays( + dtype=np.float64, + shape=(n,), + elements=st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False), + ) + ) + return coords, conf + + +@HYPOTHESIS_SETTINGS +@given( + c0=coords_and_conf(), + c1=coords_and_conf(), +) +def test_flatten_detections_counts(c0, c1): + coords0, conf0 = c0 + coords1, conf1 = c1 + + frame = { + "coordinates": [[coords0, coords1]], + "confidence": [conf0, conf1], + "costs": {}, + } + + joints = list(Assembler._flatten_detections(frame)) + + # Should yield exactly one Joint per detection + assert len(joints) == (len(coords0) + len(coords1)) + assert sum(j.label == 0 for j in joints) == len(coords0) + assert sum(j.label == 1 for j in joints) == len(coords1) + + +# -------------------------------------------------------------------------------------- +# extract_best_links +# -------------------------------------------------------------------------------------- +def test_extract_best_links_optimal_assignment(assembler_data_single_frame, make_assembler): + sframe_data = assembler_data_single_frame + asm = make_assembler( + sframe_data.data, + greedy=False, # use Hungarian (maximize) + min_n_links=1, + ) + + frame0 = 
sframe_data.data["0"] + bag = _bag_from_frame(frame0) + + links = asm.extract_best_links(bag, frame0["costs"], trees=None) + assert len(links) == 2 + + endpoints = [{tuple(l.j1.pos), tuple(l.j2.pos)} for l in links] + assert {(0.0, 0.0), (5.0, 0.0)} in endpoints + assert {(100.0, 100.0), (110.0, 100.0)} in endpoints + + vals = sorted((l.affinity for l in links), reverse=True) + assert vals[0] == pytest.approx(0.95, rel=1e-6) + assert vals[1] == pytest.approx(0.90, rel=1e-6) + + +def test_extract_best_links_greedy_with_thresholds(assembler_data_single_frame, make_assembler): + sframe_data = assembler_data_single_frame + asm = make_assembler( + sframe_data.data, + max_n_individuals=1, # greedy will stop after 1 disjoint pair chosen + greedy=True, + pcutoff=0.5, # conf product must exceed 0.25 + min_affinity=0.5, # low-affinity pairs excluded + min_n_links=1, + ) + + frame0 = sframe_data.data["0"] + bag = _bag_from_frame(frame0) + + links = asm.extract_best_links(bag, frame0["costs"], trees=None) + assert len(links) == 1 + + s = {tuple(links[0].j1.pos), tuple(links[0].j2.pos)} + assert s in ( + {(0.0, 0.0), (5.0, 0.0)}, + {(100.0, 100.0), (110.0, 100.0)}, + ) + + +@HYPOTHESIS_SETTINGS +@given( + n=st.integers(min_value=1, max_value=4), + pcutoff=st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False), + min_aff=st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False), + conf0=st.lists(st.floats(0.0, 1.0, allow_nan=False, allow_infinity=False), min_size=1, max_size=4), + conf1=st.lists(st.floats(0.0, 1.0, allow_nan=False, allow_infinity=False), min_size=1, max_size=4), +) +def test_extract_best_links_greedy_invariants_with_threshold_gates(n, pcutoff, min_aff, conf0, conf1): + # Normalize confidences to exactly n items + conf0 = (conf0 + [0.0] * n)[:n] + conf1 = (conf1 + [0.0] * n)[:n] + conf0 = np.array(conf0, dtype=float) + conf1 = np.array(conf1, dtype=float) + + # Random-ish affinity matrix (still stable), in [0,1] + rng = np.random.default_rng(0) # deterministic noise + aff = rng.random((n, n)) # uniform [0,1) + # Ensure at least one "good" candidate sometimes; otherwise test is vacuously true. + # We'll only assert gated properties on returned links anyway. 
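+ # Hypothesis tends to shrink confidences and affinities toward 0, which probes the gate boundaries.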
+ # But for better coverage, bias the diagonal upward a bit: + np.fill_diagonal(aff, np.maximum(np.diag(aff), 0.8)) + dist = np.ones((n, n), dtype=float) + + graph = [(0, 1)] + paf_inds = [0] + data = { + "metadata": {"all_joints_names": ["b0", "b1"], "PAFgraph": graph, "PAFinds": paf_inds}, + "0": {}, + } + + asm = Assembler( + data, + max_n_individuals=n, + n_multibodyparts=2, + greedy=True, + pcutoff=pcutoff, + min_affinity=min_aff, + min_n_links=1, + method="m1", + ) + + dets0 = [Joint((float(i), 0.0), float(conf0[i]), label=0, idx=i) for i in range(n)] + dets1 = [Joint((float(i), 1.0), float(conf1[i]), label=1, idx=100 + i) for i in range(n)] + joints_dict = {0: dets0, 1: dets1} + costs = {0: {"distance": dist, "m1": aff}} + + links = asm.extract_best_links(joints_dict, costs, trees=None) + + assert len(links) <= n + + used_src = set() + used_tgt = set() + + for link in links: + # Invariant 1: affinity gate + assert link.affinity >= min_aff + + # Invariant 2: pcutoff gate (confidence product) + assert link.j1.confidence * link.j2.confidence >= pcutoff * pcutoff + + # Invariant 3: disjointness in greedy selection + assert link.j1.idx not in used_src + assert link.j2.idx not in used_tgt + used_src.add(link.j1.idx) + used_tgt.add(link.j2.idx) + + +# -------------------------------------------------------------------------------------- +# build_assemblies +# -------------------------------------------------------------------------------------- + + +def test_build_assemblies_from_links(assembler_data_single_frame, make_assembler): + sframe_data = assembler_data_single_frame + asm = make_assembler(sframe_data.data, greedy=False, min_n_links=1) + + frame0 = sframe_data.data["0"] + bag = _bag_from_frame(frame0) + + links = asm.extract_best_links(bag, frame0["costs"]) + assemblies, _ = asm.build_assemblies(links) + + assert len(assemblies) == 2 + for a in assemblies: + assert a.n_links == 1 + assert len(a) == 2 + assert a.affinity == pytest.approx(a._affinity / a.n_links) + + +# -------------------------------------------------------------------------------------- +# _assemble (per-frame) – main path +# -------------------------------------------------------------------------------------- + + +def test__assemble_main_no_calibration_returns_two_assemblies(assembler_data_single_frame, make_assembler): + sframe_data = assembler_data_single_frame + asm = make_assembler( + sframe_data.data, + greedy=False, + min_n_links=1, + max_overlap=0.99, + window_size=0, + ) + + assemblies, unique = asm._assemble(sframe_data.data["0"], 0) + assert unique is None + assert len(assemblies) == 2 + assert all(len(a) == 2 for a in assemblies) + + +def test__assemble_returns_none_when_no_detections(assembler_data_no_detections, make_assembler): + nodet_data = assembler_data_no_detections + asm = make_assembler(nodet_data.data, max_n_individuals=2, n_multibodyparts=2) + + assemblies, unique = asm._assemble(nodet_data.data["0"], 0) + assert assemblies is None and unique is None + + +# -------------------------------------------------------------------------------------- +# assemble() over multiple frames + KD-tree caching +# -------------------------------------------------------------------------------------- + + +def test_assemble_across_frames_updates_temporal_trees(assembler_data_two_frames_nudged, make_assembler): + twofr_data = assembler_data_two_frames_nudged + asm = make_assembler( + twofr_data.data, + window_size=1, # enable temporal memory + min_n_links=1, + ) + + asm.assemble(chunk_size=0) + + assert 0 in 
asm._trees or 1 in asm._trees + assert isinstance(asm.assemblies, dict) + assert set(asm.assemblies.keys()).issubset({0, 1}) + + +# -------------------------------------------------------------------------------------- +# identity_only=True branch +# -------------------------------------------------------------------------------------- + + +def test_identity_only_branch_groups_by_identity(assembler_data_single_frame, scene_copy, make_assembler): + sframe_data = assembler_data_single_frame + + base = scene_copy + id0 = np.array([[4.0, 1.0], [1.0, 4.0]]) + id1 = np.array([[4.0, 1.0], [1.0, 4.0]]) + base["identity"] = [id0, id1] + sframe_data.data["0"] = base + + asm = make_assembler( + sframe_data.data, + max_n_individuals=3, + identity_only=True, + pcutoff=0.1, + ) + + assemblies, _ = asm._assemble(sframe_data.data["0"], 0) + assert assemblies is not None + assert all(len(a) >= 1 for a in assemblies) + + +# -------------------------------------------------------------------------------------- +# Mahalanobis & link probability +# -------------------------------------------------------------------------------------- + + +@dataclass +class _FakeKDE: + mean: np.ndarray + inv_cov: np.ndarray + covariance: np.ndarray + d: int + + +def test_calc_assembly_mahalanobis_and_link_probability_with_fake_kde(assembler_data_single_frame, make_assembler): + sframe_data = assembler_data_single_frame + asm = make_assembler(sframe_data.data, min_n_links=1) + + j0 = Joint((0.0, 0.0), 1.0, 0, 0) + j1 = Joint((3.0, 4.0), 1.0, 1, 1) + link = Link(j0, j1, 1.0) + + a = Assembly(size=2) + a.add_link(link) + + fake = _FakeKDE( + mean=np.array([25.0]), + inv_cov=np.array([[1.0]]), + covariance=np.array([[1.0]]), + d=1, + ) + asm._kde = fake + asm.safe_edge = True + + d = asm.calc_assembly_mahalanobis_dist(a) + assert d == pytest.approx(0.0, abs=1e-6) + + p = asm.calc_link_probability(link) + assert p == pytest.approx(1.0, rel=1e-6) + + +# -------------------------------------------------------------------------------------- +# I/O: pickle / h5 +# -------------------------------------------------------------------------------------- + + +def test_to_pickle_and_from_pickle(tmp_path, assembler_data_single_frame, make_assembler, assembler_graph_and_pafs): + sframe_data = assembler_data_single_frame + asm = make_assembler(sframe_data.data, min_n_links=1) + assemblies, _ = asm._assemble(sframe_data.data["0"], 0) + asm.assemblies = {0: assemblies} + + pkl = tmp_path / "assemb.pkl" + asm.to_pickle(str(pkl)) + + new_asm = Assembler.empty( + max_n_individuals=2, + n_multibodyparts=2, + n_uniquebodyparts=0, + graph=sframe_data.graph, + paf_inds=sframe_data.paf_inds, + ) + new_asm.from_pickle(str(pkl)) + + assert 0 in new_asm.assemblies + assert isinstance(new_asm.assemblies[0], list) + + +def test_to_h5_roundtrip(tmp_path, assembler_data_single_frame, make_assembler): + # Skip only this test (not the whole module at collection time) when PyTables is missing. + pytest.importorskip("tables", reason="PyTables required for HDF5") + sframe_data = assembler_data_single_frame + + asm = make_assembler(sframe_data.data, min_n_links=1) + assemblies, _ = asm._assemble(sframe_data.data["0"], 0) + asm.assemblies = {0: assemblies} + + h5 = tmp_path / "assemb.h5" + asm.to_h5(str(h5)) + + df = pd.read_hdf(str(h5), key="ass") + assert df.shape[0] == 1 + assert df.columns.nlevels == 4 + assert set(df.columns.get_level_values("coords")) == {"x", "y", "likelihood"} diff --git a/tests/tests_core/test_assembly.py b/tests/tests_core/test_assembly.py new file mode 100644 index
0000000..55ffa08 --- /dev/null +++ b/tests/tests_core/test_assembly.py @@ -0,0 +1,268 @@ +import numpy as np +import pytest +from hypothesis import assume, given, settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays + +from dlclive.core.inferenceutils import Assembly + +HYPOTHESIS_SETTINGS = settings(max_examples=200, deadline=None) + + +# --------------------------- +# Basic construction +# --------------------------- +def test_assembly_init(make_assembly): + assemb = make_assembly(size=5) + assert assemb.data.shape == (5, 4) + + # col 0,1,3 are NaN, col 2 is confidence=0 + # this is due to confidence being initialized to 0 + # by default in Assembly.__init__ + assert np.isnan(assemb.data[:, 0]).all() + assert np.isnan(assemb.data[:, 1]).all() + assert (assemb.data[:, 2] == 0).all() + assert np.isnan(assemb.data[:, 3]).all() + + assert assemb._affinity == 0 + assert assemb._links == [] + assert assemb._visible == set() + assert assemb._idx == set() + + +# --------------------------- +# from_array +# --------------------------- +@HYPOTHESIS_SETTINGS +@given( + st.integers(min_value=1, max_value=30).flatmap( + lambda n_rows: st.sampled_from([2, 3]).flatmap( + lambda n_cols: arrays( + dtype=np.float64, + shape=(n_rows, n_cols), + elements=st.floats(allow_infinity=False, allow_nan=True, width=32), + ) + ) + ) +) +def test_from_array_invariants(arr): + n_rows, n_cols = arr.shape + + assemb = Assembly.from_array(arr.copy()) + assert assemb.data.shape == (n_rows, 4) + + # Row is "valid" iff it has no NaN among the provided columns + row_valid = ~np.isnan(arr).any(axis=1) + visible = set(np.flatnonzero(row_valid).tolist()) + assert assemb._visible == visible + + out = assemb.data + + # For invalid rows: x/y must be NaN + assert np.all(np.isnan(out[~row_valid, 0])) + assert np.all(np.isnan(out[~row_valid, 1])) + + # Confidence behavior differs depending on number of columns + if n_cols == 2: + # XY-only input: confidence starts at 0 and is set to 1 for visible rows only + assert np.all(out[row_valid, 2] == pytest.approx(1.0)) + assert np.all(out[~row_valid, 2] == pytest.approx(0.0)) + else: + # XY+confidence input: confidence is preserved for visible rows + assert np.allclose(out[row_valid, 2], arr[row_valid, 2], equal_nan=False) + # Invalid rows become NaN in all provided columns, including confidence + assert np.all(np.isnan(out[~row_valid, 2])) + + # Visible rows preserve xy + assert np.allclose(out[row_valid, :2], arr[row_valid, :2], equal_nan=False) + + +def test_assembly_from_array_with_nans(): + arr = np.array( + [ + [10.0, 20.0, 0.9], + [np.nan, 5.0, 0.8], # one NaN → entire row becomes NaN + ] + ) + assemb = Assembly.from_array(arr.copy()) + + assert np.allclose(assemb.data[0], [10.0, 20.0, 0.9, np.nan], equal_nan=True) + assert np.isnan(assemb.data[1]).all() + + # visible only includes fully non-NaN rows + assert assemb._visible == {0} + + +# --------------------------- +# extent, area, xy +# --------------------------- +@HYPOTHESIS_SETTINGS +@given( + coords=arrays( + dtype=np.float64, + shape=st.tuples(st.integers(1, 30), st.just(2)), + elements=st.floats(allow_nan=True, allow_infinity=False, width=32), + ) +) +def test_extent_matches_visible_points(coords): + xy = coords.copy() + # Ensure rows with any NaN are fully NaN, + # matching Assembly's from_array behavior + xy[np.isnan(xy).any(axis=1)] = np.nan + a = Assembly(size=xy.shape[0]) + a.data[:] = np.nan + a.data[:, :2] = xy + a._visible = 
set(np.flatnonzero(~np.isnan(xy).any(axis=1)).tolist()) + + visible_mask = ~np.isnan(coords).any(axis=1) + assume(visible_mask.any()) + + expected = np.array( + [ + coords[visible_mask, 0].min(), + coords[visible_mask, 1].min(), + coords[visible_mask, 0].max(), + coords[visible_mask, 1].max(), + ] + ) + assert np.allclose(a.extent, expected) + assert a.area >= 0 + + +# --------------------------- +# add_joint / remove_joint +# --------------------------- +def test_add_joint_and_remove_joint(make_assembly, make_joint): + assemb = make_assembly(size=3) + j0 = make_joint(pos=(1.0, 2.0), confidence=0.5, label=0, idx=10) + j1 = make_joint(pos=(3.0, 4.0), confidence=0.8, label=1, idx=11) + + # adding first joint + assert assemb.add_joint(j0) is True + assert assemb._visible == {0} + assert assemb._idx == {10} + assert np.allclose(assemb.data[0], [1.0, 2.0, 0.5, j0.group]) + + # adding second joint + assert assemb.add_joint(j1) is True + assert assemb._visible == {0, 1} + assert assemb._idx == {10, 11} + + # adding same joint label again → ignored + assert assemb.add_joint(j0) is False + + # removing joint + assert assemb.remove_joint(j1) is True + assert assemb._visible == {0} + assert assemb._idx == {10} + assert np.isnan(assemb.data[1]).all() + + # remove nonexistent → False + assert assemb.remove_joint(j1) is False + + +# --------------------------- +# add_link (simple) +# --------------------------- +def test_add_link_adds_joints_and_affinity(make_assembly, make_joint, make_link): + assemb = make_assembly(size=3) + + j0 = make_joint(pos=(0.0, 0.0), confidence=1.0, label=0, idx=100) + j1 = make_joint(pos=(1.0, 0.0), confidence=1.0, label=1, idx=101) + link = make_link(j0, j1, affinity=0.7) + + # New link → adds both joints + result = assemb.add_link(link) + assert result is True + assert assemb.n_links == 1 + assert assemb._affinity == pytest.approx(0.7) + assert assemb._visible == {0, 1} + assert assemb._idx == {100, 101} + + # Add same link again → both idx already present → only increases affinity, no new joints + result = assemb.add_link(link) + assert result is False # as per code path + assert assemb.n_links == 2 # link appended again + assert assemb._affinity == pytest.approx(1.4) # 0.7 + 0.7 + + +# --------------------------- +# intersection_with +# --------------------------- +def test_intersection_with_partial_overlap(two_overlap_assemblies): + ass1, ass2 = two_overlap_assemblies + assert ass1.intersection_with(ass2) == pytest.approx(0.5) + + +# --------------------------- +# confidence property +# --------------------------- +def test_confidence_property(make_assembly): + assemb = make_assembly(size=3) + assemb.data[:] = np.nan + assemb.data[:, 2] = [0.2, 0.4, np.nan] # mean of finite = (0.2+0.4)/2 = 0.3 + assert assemb.confidence == pytest.approx(0.3) + + assemb.confidence = 0.9 + assert np.allclose(assemb.data[:, 2], [0.9, 0.9, 0.9], equal_nan=True) + + +# --------------------------- +# soft_identity +# --------------------------- +def test_soft_identity_simple(soft_identity_assembly): + assemb = soft_identity_assembly + soft = assemb.soft_identity + assert set(soft.keys()) == {0, 1} + s0, s1 = soft[0], soft[1] + assert pytest.approx(s0 + s1) == 1.0 + assert s1 > s0 + + +# --------------------------- +# intersection operator: __contains__ +# --------------------------- +def test_contains_checks_shared_idx(make_assembly, make_joint): + ass1 = make_assembly(size=3) + ass2 = make_assembly(size=3) + + j0 = make_joint((0, 0), confidence=1.0, label=0, idx=10) + j1 = make_joint((1, 1), 
confidence=1.0, label=1, idx=99) + + ass1.add_joint(j0) + ass2.add_joint(j1) + + # different idx sets → no intersection + assert (ass2 in ass1) is False + + ass2.add_joint(j0) + # now share idx=10 + assert (ass2 in ass1) is True + + +# --------------------------- +# assembly addition (__add__) +# --------------------------- +def test_assembly_addition_combines_links(make_assembly, four_joint_chain): + a1 = make_assembly(size=4) + a2 = make_assembly(size=4) + + chain = four_joint_chain + + a1.add_link(chain.l01) + a2.add_link(chain.l23) + + # now they share NO joints → addition should succeed + result = a1 + a2 + + assert result.n_links == 2 + assert result._affinity == pytest.approx(1.3) + + # original assemblies unchanged + assert a1.n_links == 1 + assert a2.n_links == 1 + + # now purposely make them share a joint → should raise + a2.add_joint(chain.j0) + with pytest.raises(ArithmeticError): + _ = a1 + a2
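Reviewer note: a standalone sketch of the `Assembly.__add__` semantics that the final test above pins down (a minimal illustration, assuming `dlclive` is importable; `Joint`, `Link`, and `Assembly` signatures as used in `tests/conftest.py`):

    from dlclive.core.inferenceutils import Assembly, Joint, Link

    # Joint(pos, confidence, label, idx): idx identifies the underlying detection.
    j0 = Joint((0.0, 0.0), 1.0, 0, 10)
    j1 = Joint((1.0, 0.0), 1.0, 1, 11)
    j2 = Joint((2.0, 0.0), 1.0, 2, 12)
    j3 = Joint((3.0, 0.0), 1.0, 3, 13)

    a1, a2 = Assembly(size=4), Assembly(size=4)
    a1.add_link(Link(j0, j1, affinity=0.5))  # adds both joints; affinity accumulates
    a2.add_link(Link(j2, j3, affinity=0.8))

    merged = a1 + a2          # disjoint idx sets -> combined links, affinity 1.3
    assert merged.n_links == 2

    a2.add_joint(j0)          # a1 and a2 now share idx=10
    # a1 + a2 raises ArithmeticError once the assemblies share a joint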