From 81f9c9e7a1d151246345a46277cc19cc04bdb9fb Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 20 Jan 2026 14:34:12 +0100 Subject: [PATCH 01/20] Add pre-commit configuration file Introduces .pre-commit-config.yaml to automate code formatting, linting, and basic checks using pre-commit hooks for improved code quality and consistency. --- .pre-commit-config.yaml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2e5ed7b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-docstring-first + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/asottile/setup-cfg-fmt + rev: v3.2.0 + hooks: + - id: setup-cfg-fmt + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.10 + hooks: + # Run the formatter. + - id: ruff-format + # Run the linter. + - id: ruff-check + args: [--fix,--unsafe-fixes] From cb7316c523fd339dec6ac908549fa200c85bcc83 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 20 Jan 2026 14:35:08 +0100 Subject: [PATCH 02/20] Update CI workflow and add pytest-cov to dev dependencies Upgraded GitHub Actions to use newer versions and improved test steps by separating model benchmark and unit tests, adding coverage reporting with codecov. Added pytest-cov to dev dependencies and configured Ruff linter settings in pyproject.toml. --- .github/workflows/testing.yml | 17 +++++++++++------ pyproject.toml | 12 ++++++++++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 242ec58..17ee47d 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -43,7 +43,7 @@ jobs: df -h - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install uv uses: astral-sh/setup-uv@v6 @@ -55,7 +55,7 @@ jobs: - name: Install the project run: uv sync --no-cache --all-extras --dev shell: bash - + - name: Install ffmpeg run: | if [ "$RUNNER_OS" == "Linux" ]; then @@ -67,9 +67,14 @@ jobs: choco install ffmpeg fi shell: bash - - - name: Run DLC Live Tests + + - name: Run Model Benchmark Test run: uv run dlc-live-test --nodisplay - - name: Run Functional Benchmark Test - run: uv run pytest + - name: Run DLC Live Unit Tests + run: uv run pytest --cov=dlclive --cov-report=xml + + - name: Coverage Report + uses: codecov/codecov-action@v5 + with: + files: ./coverage.xml diff --git a/pyproject.toml b/pyproject.toml index 951646b..a4fdd80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,11 +61,12 @@ tf = [ [dependency-groups] dev = [ "pytest", + "pytest-cov", "black", "ruff", ] -# Keep only for backward compatibility with Poetry +# Keep only for backward compatibility with Poetry # (without this section, Poetry assumes the wrong root directory of the project) [tool.poetry] packages = [ @@ -87,4 +88,11 @@ include-package-data = true include = ["dlclive*"] [tool.setuptools.package-data] -dlclive = ["check_install/*"] \ No newline at end of file +dlclive = ["check_install/*"] + +[tool.ruff] +lint.select = ["E", "F", "B", "I", "UP"] +lint.ignore = ["E741"] +target-version = "py310" +fix = true +line-length = 120 From eaaac972a6da87caf2ff002bd645e96b99da148e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 20 Jan 2026 14:56:34 +0100 Subject: [PATCH 03/20] Update pyproject.toml --- pyproject.toml | 
5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a4fdd80..9217c8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,8 +91,11 @@ include = ["dlclive*"] dlclive = ["check_install/*"] [tool.ruff] -lint.select = ["E", "F", "B", "I", "UP"] +lint.select = ["D", "E", "F", "B", "I", "UP"] lint.ignore = ["E741"] target-version = "py310" fix = true line-length = 120 + +[tool.ruff.pydocstyle] +convention = "google" From 7b8113e5a4672b74b097a077f7ee779bbbbc1c05 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Wed, 21 Jan 2026 09:53:25 +0100 Subject: [PATCH 04/20] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9217c8f..b448150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ include = ["dlclive*"] dlclive = ["check_install/*"] [tool.ruff] -lint.select = ["D", "E", "F", "B", "I", "UP"] +lint.select = ["E", "F", "B", "I", "UP"] lint.ignore = ["E741"] target-version = "py310" fix = true From 7ec008f7909d65b8e2ba98256f3cc4e928f64da2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 22 Jan 2026 09:01:21 +0100 Subject: [PATCH 05/20] Add comprehensive tests for Assembler and Assembly classes Introduces detailed unit tests for the Assembler, Assembly, Joint, and Link classes in dlclive.core.inferenceutils. The new tests cover metadata parsing, detection flattening, link extraction, assembly building, Mahalanobis distance calculation, I/O helpers, and various Assembly operations, improving test coverage and reliability. --- dlclive/core/inferenceutils.py | 149 +++------ tests/tests_core/test_assembler.py | 516 +++++++++++++++++++++++++++++ tests/tests_core/test_assembly.py | 271 +++++++++++++++ 3 files changed, 838 insertions(+), 98 deletions(-) create mode 100644 tests/tests_core/test_assembler.py create mode 100644 tests/tests_core/test_assembly.py diff --git a/dlclive/core/inferenceutils.py b/dlclive/core/inferenceutils.py index 81d9d43..dc2f1f5 100644 --- a/dlclive/core/inferenceutils.py +++ b/dlclive/core/inferenceutils.py @@ -8,6 +8,9 @@ # # Licensed under GNU Lesser General Public License v3.0 # + + +# NOTE DUPLICATED from deeplabcut/core/inferenceutils.py from __future__ import annotations import heapq @@ -17,9 +20,10 @@ import pickle import warnings from collections import defaultdict +from collections.abc import Iterable from dataclasses import dataclass from math import erf, sqrt -from typing import Any, Iterable, Tuple +from typing import Any import networkx as nx import numpy as np @@ -41,7 +45,7 @@ def _conv_square_to_condensed_indices(ind_row, ind_col, n): return n * ind_col - ind_col * (ind_col + 1) // 2 + ind_row - 1 - ind_col -Position = Tuple[float, float] +Position = tuple[float, float] @dataclass(frozen=True) @@ -61,9 +65,7 @@ def __init__(self, j1, j2, affinity=1): self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2) def __repr__(self): - return ( - f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}" - ) + return f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}" @property def confidence(self): @@ -155,7 +157,7 @@ def soft_identity(self): unq, idx, cnt = np.unique(data[:, 3], return_inverse=True, return_counts=True) avg = np.bincount(idx, weights=data[:, 2]) / cnt soft = softmax(avg) - return dict(zip(unq.astype(int), soft)) + return dict(zip(unq.astype(int), soft, strict=False)) @property def affinity(self): @@ -261,9 +263,7 @@ def __init__( 
self.max_overlap = max_overlap self._has_identity = "identity" in self[0] if identity_only and not self._has_identity: - warnings.warn( - "The network was not trained with identity; setting `identity_only` to False." - ) + warnings.warn("The network was not trained with identity; setting `identity_only` to False.", stacklevel=2) self.identity_only = identity_only & self._has_identity self.nan_policy = nan_policy self.force_fusion = force_fusion @@ -344,7 +344,7 @@ def calibrate(self, train_data_file): pass n_bpts = len(df.columns.get_level_values("bodyparts").unique()) if n_bpts == 1: - warnings.warn("There is only one keypoint; skipping calibration...") + warnings.warn("There is only one keypoint; skipping calibration...", stacklevel=2) return xy = df.to_numpy().reshape((-1, n_bpts, 2)) @@ -352,7 +352,7 @@ def calibrate(self, train_data_file): # Only keeps skeletons that are more than 90% complete xy = xy[frac_valid >= 0.9] if not xy.size: - warnings.warn("No complete poses were found. Skipping calibration...") + warnings.warn("No complete poses were found. Skipping calibration...", stacklevel=2) return # TODO Normalize dists by longest length? @@ -368,13 +368,9 @@ def calibrate(self, train_data_file): self.safe_edge = True except np.linalg.LinAlgError: # Covariance matrix estimation fails due to numerical singularities - warnings.warn( - "The assembler could not be robustly calibrated. Continuing without it..." - ) + warnings.warn("The assembler could not be robustly calibrated. Continuing without it...", stacklevel=2) - def calc_assembly_mahalanobis_dist( - self, assembly, return_proba=False, nan_policy="little" - ): + def calc_assembly_mahalanobis_dist(self, assembly, return_proba=False, nan_policy="little"): if self._kde is None: raise ValueError("Assembler should be calibrated first with training data.") @@ -428,10 +424,10 @@ def _flatten_detections(data_dict): ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence] else: ids = [arr.argmax(axis=1) for arr in ids] - for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids)): + for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids, strict=False)): if not np.any(coords): continue - for xy, p, g in zip(coords, conf, id_): + for xy, p, g in zip(coords, conf, id_, strict=False): joint = Joint(tuple(xy), p.item(), i, ind, g) ind += 1 yield joint @@ -453,9 +449,7 @@ def extract_best_links(self, joints_dict, costs, trees=None): aff[np.isnan(aff)] = 0 if trees: - vecs = np.vstack( - [[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t] - ) + vecs = np.vstack([[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t]) dists = [] for n, tree in enumerate(trees, start=1): d, _ = tree.query(vecs) @@ -464,23 +458,16 @@ def extract_best_links(self, joints_dict, costs, trees=None): aff *= w.reshape(aff.shape) if self.greedy: - conf = np.asarray( - [ - [det_s.confidence * det_t.confidence for det_t in dets_t] - for det_s in dets_s - ] - ) - rows, cols = np.where( - (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity) - ) + conf = np.asarray([[det_s.confidence * det_t.confidence for det_t in dets_t] for det_s in dets_s]) + rows, cols = np.where((conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity)) candidates = sorted( - zip(rows, cols, aff[rows, cols], lengths[rows, cols]), + zip(rows, cols, aff[rows, cols], lengths[rows, cols], strict=False), key=lambda x: x[2], reverse=True, ) i_seen = set() j_seen = set() - for i, j, w, l in candidates: + for i, j, w, _l in 
candidates: if i not in i_seen and j not in j_seen: i_seen.add(i) j_seen.add(j) @@ -488,21 +475,17 @@ def extract_best_links(self, joints_dict, costs, trees=None): if len(i_seen) == self.max_n_individuals: break else: # Optimal keypoint pairing - inds_s = sorted( - range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True - )[: self.max_n_individuals] - inds_t = sorted( - range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True - )[: self.max_n_individuals] - keep_s = [ - ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff + inds_s = sorted(range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True)[ + : self.max_n_individuals ] - keep_t = [ - ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff + inds_t = sorted(range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True)[ + : self.max_n_individuals ] + keep_s = [ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff] + keep_t = [ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff] aff = aff[np.ix_(keep_s, keep_t)] rows, cols = linear_sum_assignment(aff, maximize=True) - for row, col in zip(rows, cols): + for row, col in zip(rows, cols, strict=False): w = aff[row, col] if w >= self.min_affinity: links.append(Link(dets_s[keep_s[row]], dets_t[keep_t[col]], w)) @@ -538,9 +521,7 @@ def push_to_stack(i): if new_ind in assembled: continue if safe_edge: - d_old = self.calc_assembly_mahalanobis_dist( - assembly, nan_policy=nan_policy - ) + d_old = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy) success = assembly.add_link(best, store_dict=True) if not success: assembly._dict = dict() @@ -548,9 +529,9 @@ def push_to_stack(i): d = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy) if d < d_old: push_to_stack(new_ind) - if tabu: - _, _, link = heapq.heappop(tabu) - heapq.heappush(stack, (-link.affinity, next(counter), link)) + if tabu: + _, _, link = heapq.heappop(tabu) + heapq.heappush(stack, (-link.affinity, next(counter), link)) else: heapq.heappush(tabu, (d - d_old, next(counter), best)) assembly.__dict__.update(assembly._dict) @@ -593,9 +574,7 @@ def build_assemblies(self, links): continue assembly = Assembly(self.n_multibodyparts) assembly.add_link(link) - self._fill_assembly( - assembly, lookup, assembled, self.safe_edge, self.nan_policy - ) + self._fill_assembly(assembly, lookup, assembled, self.safe_edge, self.nan_policy) for assembly_link in assembly._links: i, j = assembly_link.idx lookup[i].pop(j) @@ -607,10 +586,7 @@ def build_assemblies(self, links): n_extra = len(assemblies) - self.max_n_individuals if n_extra > 0: if self.safe_edge: - ds_old = [ - self.calc_assembly_mahalanobis_dist(assembly) - for assembly in assemblies - ] + ds_old = [self.calc_assembly_mahalanobis_dist(assembly) for assembly in assemblies] while len(assemblies) > self.max_n_individuals: ds = [] for i, j in itertools.combinations(range(len(assemblies)), 2): @@ -665,7 +641,7 @@ def build_assemblies(self, links): for idx in store[j]._idx: store[idx] = store[i] except KeyError: - # Some links may reference indices that were never added to `store`; + # Some links may reference indices that were never added to `store`; # in that case we intentionally skip merging for this link pass @@ -742,10 +718,7 @@ def _assemble(self, data_dict, ind_frame): for _, group in groups: ass = Assembly(self.n_multibodyparts) for joint in sorted(group, key=lambda x: x.confidence, reverse=True): - if ( - joint.confidence >= self.pcutoff - and joint.label < 
self.n_multibodyparts - ): + if joint.confidence >= self.pcutoff and joint.label < self.n_multibodyparts: ass.add_joint(joint) if len(ass): assemblies.append(ass) @@ -774,11 +747,7 @@ def _assemble(self, data_dict, ind_frame): assembled.update(assembled_) # Remove invalid assemblies - discarded = set( - joint - for joint in joints - if joint.idx not in assembled and np.isfinite(joint.confidence) - ) + discarded = set(joint for joint in joints if joint.idx not in assembled and np.isfinite(joint.confidence)) for assembly in assemblies[::-1]: if 0 < assembly.n_links < self.min_n_links or not len(assembly): for link in assembly._links: @@ -786,12 +755,10 @@ def _assemble(self, data_dict, ind_frame): assemblies.remove(assembly) if 0 < self.max_overlap < 1: # Non-maximum pose suppression if self._kde is not None: - scores = [ - -self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies - ] + scores = [-self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies] else: scores = [ass._affinity for ass in assemblies] - lst = list(zip(scores, assemblies)) + lst = list(zip(scores, assemblies, strict=False)) assemblies = [] while lst: temp = max(lst, key=lambda x: x[0]) @@ -857,9 +824,7 @@ def wrapped(i): n_frames = len(self.metadata["imnames"]) with multiprocessing.Pool(n_processes) as p: with tqdm(total=n_frames) as pbar: - for i, (assemblies, unique) in p.imap_unordered( - wrapped, range(n_frames), chunksize=chunk_size - ): + for i, (assemblies, unique) in p.imap_unordered(wrapped, range(n_frames), chunksize=chunk_size): if assemblies: self.assemblies[i] = assemblies if unique is not None: @@ -878,9 +843,7 @@ def parse_metadata(data): params["joint_names"] = data["metadata"]["all_joints_names"] params["num_joints"] = len(params["joint_names"]) params["paf_graph"] = data["metadata"]["PAFgraph"] - params["paf"] = data["metadata"].get( - "PAFinds", np.arange(len(params["joint_names"])) - ) + params["paf"] = data["metadata"].get("PAFinds", np.arange(len(params["joint_names"]))) params["bpts"] = params["ibpts"] = range(params["num_joints"]) params["imnames"] = [fn for fn in list(data) if fn != "metadata"] return params @@ -970,11 +933,7 @@ def calc_object_keypoint_similarity( else: oks = [] xy_preds = [xy_pred] - combos = ( - pair - for l in range(len(symmetric_kpts)) - for pair in itertools.combinations(symmetric_kpts, l + 1) - ) + combos = (pair for l in range(len(symmetric_kpts)) for pair in itertools.combinations(symmetric_kpts, l + 1)) for pairs in combos: # Swap corresponding keypoints tmp = xy_pred.copy() @@ -1011,9 +970,7 @@ def match_assemblies( num_ground_truth = len(ground_truth) # Sort predictions by score - inds_pred = np.argsort( - [ins.affinity if ins.n_links else ins.confidence for ins in predictions] - )[::-1] + inds_pred = np.argsort([ins.affinity if ins.n_links else ins.confidence for ins in predictions])[::-1] predictions = np.asarray(predictions)[inds_pred] # indices of unmatched ground truth assemblies @@ -1074,7 +1031,7 @@ def match_assemblies( if ~np.isnan(oks): mat[i, j] = oks rows, cols = linear_sum_assignment(mat, maximize=True) - for row, col in zip(rows, cols): + for row, col in zip(rows, cols, strict=False): matched[row].ground_truth = ground_truth[col] matched[row].oks = mat[row, col] _ = inds_true.remove(col) @@ -1087,7 +1044,7 @@ def parse_ground_truth_data_file(h5_file): try: df.drop("single", axis=1, level="individuals", inplace=True) except KeyError: - # Ignore if the "single" individual column is absent + # Ignore if the "single" individual column is absent 
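+            # ("single" is the pseudo-individual under which multi-animal DLC
+            # ground-truth files store unique bodyparts; it is scored separately
+            # and therefore excluded from assembly evaluation)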
pass # Cast columns of dtype 'object' to float to avoid TypeError # further down in _parse_ground_truth_data. @@ -1120,15 +1077,13 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)): raise ValueError(f"Invalid criterion {criterion}.") if len(qs) != 2: - raise ValueError( - "Two percentiles (for lower and upper bounds) should be given." - ) + raise ValueError("Two percentiles (for lower and upper bounds) should be given.") tuples = [] for frame_ind, assemblies in dict_of_assemblies.items(): for assembly in assemblies: tuples.append((frame_ind, getattr(assembly, criterion))) - frame_inds, vals = zip(*tuples) + frame_inds, vals = zip(*tuples, strict=False) vals = np.asarray(vals) lo, up = np.percentile(vals, qs, interpolation="nearest") inds = np.flatnonzero((vals < lo) | (vals > up)).tolist() @@ -1226,9 +1181,7 @@ def evaluate_assembly_greedy( oks = np.asarray([match.oks for match in all_matched])[sorted_pred_indices] # Compute prediction and recall - p, r = _compute_precision_and_recall( - total_gt_assemblies, oks, oks_t, recall_thresholds - ) + p, r = _compute_precision_and_recall(total_gt_assemblies, oks, oks_t, recall_thresholds) precisions.append(p) recalls.append(r) @@ -1246,12 +1199,14 @@ def evaluate_assembly( ass_pred_dict, ass_true_dict, oks_sigma=0.072, - oks_thresholds=np.linspace(0.5, 0.95, 10), + oks_thresholds=None, margin=0, symmetric_kpts=None, greedy_matching=False, with_tqdm: bool = True, ): + if oks_thresholds is None: + oks_thresholds = np.linspace(0.5, 0.95, 10) if greedy_matching: return evaluate_assembly_greedy( ass_true_dict, @@ -1299,9 +1254,7 @@ def evaluate_assembly( precisions = [] recalls = [] for t in oks_thresholds: - p, r = _compute_precision_and_recall( - total_gt_assemblies, oks, t, recall_thresholds - ) + p, r = _compute_precision_and_recall(total_gt_assemblies, oks, t, recall_thresholds) precisions.append(p) recalls.append(r) diff --git a/tests/tests_core/test_assembler.py b/tests/tests_core/test_assembler.py new file mode 100644 index 0000000..3b9620d --- /dev/null +++ b/tests/tests_core/test_assembler.py @@ -0,0 +1,516 @@ +from dataclasses import dataclass + +import numpy as np +import pandas as pd +import pytest + +from dlclive.core.inferenceutils import Assembler, Assembly, Joint, Link + +# -------------------------------------------------------------------------------------- +# Helpers +# -------------------------------------------------------------------------------------- + + +def make_metadata(graph, paf_inds, n_bodyparts, frame_keys): + """Create a minimal DLC-like metadata structure for Assembler.""" + return { + "metadata": { + "all_joints_names": [f"b{i}" for i in range(n_bodyparts)], + "PAFgraph": graph, + "PAFinds": paf_inds, + }, + **{k: {} for k in frame_keys}, + } + + +def make_frame(coordinates_per_label, confidence_per_label, identity_per_label=None, costs=None): + """ + Build a single frame dict with the structure Assembler._flatten_detections expects. + + coordinates_per_label: list of np.ndarray[(n_dets, 2)] + confidence_per_label: list of np.ndarray[(n_dets, )] + identity_per_label: list of np.ndarray[(n_dets, n_groups)] or None + costs: dict or None. 
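+        (layout assumed from how Assembler.extract_best_links consumes it: keys
+        are PAF edge indices, each mapping to "distance" and "m1" affinity
+        matrices of shape (n_detections_source, n_detections_target))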
Example: + { + 0: { + "distance": np.array([[...]]), # rows: label s detections, cols: label t + "m1": np.array([[...]]), + } + } + """ + frame = { + # NOTE: Assembler expects coordinates under key "coordinates"[0] + "coordinates": [coordinates_per_label], + "confidence": confidence_per_label, + "costs": costs or {}, + } + if identity_per_label is not None: + frame["identity"] = identity_per_label + return frame + + +def simple_two_label_scene(): + """ + Build a very small, deterministic scene with 2 bodyparts (0 ↔ 1), each with 2 detections. + We design affinities so that the intended pairs are (A↔C) and (B↔D). + """ + # Label 0 detections: A (near origin), B (far) + coords0 = np.array([[0.0, 0.0], [100.0, 100.0]]) + conf0 = np.array([0.9, 0.6]) + + # Label 1 detections: C near A, D near B + coords1 = np.array([[5.0, 0.0], [110.0, 100.0]]) + conf1 = np.array([0.8, 0.7]) + + # Affinities: strong on diagonal pairs AC and BD, weak elsewhere + aff = np.array([[0.95, 0.1], [0.05, 0.9]]) + # Lengths: finite distances (not used for assignment, but must be finite) + lens = np.array( + [ + [np.hypot(*(coords1[0] - coords0[0])), np.hypot(*(coords1[1] - coords0[0]))], + [np.hypot(*(coords1[0] - coords0[1])), np.hypot(*(coords1[1] - coords0[1]))], + ] + ) + + # Build frame expected by Assembler + frame0 = make_frame( + coordinates_per_label=[coords0, coords1], + confidence_per_label=[conf0, conf1], + identity_per_label=None, + costs={0: {"distance": lens, "m1": aff}}, + ) + return frame0 + + +# -------------------------------------------------------------------------------------- +# Basic metadata and __getitem__ +# -------------------------------------------------------------------------------------- + + +def test_parse_metadata_and_getitem_and_empty_classmethod(): + graph = [(0, 1)] + paf_inds = [0] + data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) + # Fill frames so __getitem__ returns something non-empty later + data["0"] = simple_two_label_scene() + data["1"] = simple_two_label_scene() + + asm = Assembler( + data, + max_n_individuals=2, + n_multibodyparts=2, + ) + + # parse_metadata applied in ctor: + assert asm.metadata["num_joints"] == 2 + assert asm.metadata["paf_graph"] == graph + assert list(asm.metadata["paf"]) == paf_inds + assert set(asm.metadata["imnames"]) == {"0", "1"} + + # __getitem__ returns the per-frame dict + assert "coordinates" in asm[0] + assert "confidence" in asm[0] + assert "costs" in asm[0] + + # empty() convenience + empty = Assembler.empty( + max_n_individuals=1, + n_multibodyparts=1, + n_uniquebodyparts=0, + graph=graph, + paf_inds=paf_inds, + ) + assert isinstance(empty, Assembler) + assert empty.n_keypoints == 1 + + +# -------------------------------------------------------------------------------------- +# _flatten_detections +# -------------------------------------------------------------------------------------- + + +def test_flatten_detections_no_identity(): + frame = simple_two_label_scene() + joints = list(Assembler._flatten_detections(frame)) + # 2 labels * 2 detections + assert len(joints) == 4 + + # label IDs and groups + labels = sorted([j.label for j in joints]) + assert labels == [0, 0, 1, 1] + # identity absent → group = -1 + assert set(j.group for j in joints) == {-1} + + +def test_flatten_detections_with_identity(): + frame = simple_two_label_scene() + + # Add identity logits so that argmax → [0, 1] for both labels + id0 = np.array([[10.0, 0.0], [0.0, 10.0]]) # for label 0 + id1 = np.array([[10.0, 0.0], [0.0, 10.0]]) # for 
label 1 + frame["identity"] = [id0, id1] + + joints = list(Assembler._flatten_detections(frame)) + groups = [j.group for j in joints] + # we expect groups [0,1,0,1] in some order (2 per label) + assert set(groups) == {0, 1} + assert groups.count(0) == 2 + assert groups.count(1) == 2 + + +# -------------------------------------------------------------------------------------- +# extract_best_links +# -------------------------------------------------------------------------------------- + + +def test_extract_best_links_optimal_assignment(): + graph = [(0, 1)] + paf_inds = [0] + data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) + data["0"] = simple_two_label_scene() + + asm = Assembler( + data, + max_n_individuals=2, + n_multibodyparts=2, + greedy=False, # use Hungarian (maximize) + pcutoff=0.1, + min_affinity=0.05, + min_n_links=1, # avoid pruning 1-link assemblies in later steps + ) + + # Build joints_dict like _assemble does + joints = list(Assembler._flatten_detections(data["0"])) + bag = {} + for j in joints: + bag.setdefault(j.label, []).append(j) + + links = asm.extract_best_links(bag, data["0"]["costs"], trees=None) + # Expect 2 high-quality links: (coords0[0] ↔ coords1[0]) and (coords0[1] ↔ coords1[1]) + assert len(links) == 2 + + # Check that each link connects matching pairs (by position) + endpoints = [{tuple(l.j1.pos), tuple(l.j2.pos)} for l in links] + assert {(0.0, 0.0), (5.0, 0.0)} in endpoints + assert {(100.0, 100.0), (110.0, 100.0)} in endpoints + + # Affinity should be the matrix diagonal values ~0.95 and ~0.9 + vals = sorted([l.affinity for l in links], reverse=True) + assert vals[0] == pytest.approx(0.95, rel=1e-6) + assert vals[1] == pytest.approx(0.90, rel=1e-6) + + +def test_extract_best_links_greedy_with_thresholds(): + graph = [(0, 1)] + paf_inds = [0] + data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) + data["0"] = simple_two_label_scene() + + asm = Assembler( + data, + max_n_individuals=1, # greedy will stop after 1 disjoint pair chosen + n_multibodyparts=2, + greedy=True, + pcutoff=0.5, # conf product must exceed 0.25 + min_affinity=0.5, # low-affinity pairs excluded + min_n_links=1, + ) + + joints = list(Assembler._flatten_detections(data["0"])) + bag = {} + for j in joints: + bag.setdefault(j.label, []).append(j) + + links = asm.extract_best_links(bag, data["0"]["costs"], trees=None) + # Expect exactly 1 link due to max_n_individuals=1 in greedy picking + assert len(links) == 1 + s = {tuple(links[0].j1.pos), tuple(links[0].j2.pos)} + assert s == {(0.0, 0.0), (5.0, 0.0)} or s == {(100.0, 100.0), (110.0, 100.0)} + + +# -------------------------------------------------------------------------------------- +# build_assemblies +# -------------------------------------------------------------------------------------- + + +def test_build_assemblies_from_links(): + graph = [(0, 1)] + paf_inds = [0] + data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) + data["0"] = simple_two_label_scene() + + asm = Assembler( + data, + max_n_individuals=2, + n_multibodyparts=2, + greedy=False, + pcutoff=0.1, + min_affinity=0.05, + min_n_links=1, + ) + + joints = list(Assembler._flatten_detections(data["0"])) + bag = {} + for j in joints: + bag.setdefault(j.label, []).append(j) + + links = asm.extract_best_links(bag, data["0"]["costs"]) + assemblies, _ = asm.build_assemblies(links) + + # We expect two disjoint 2-joint assemblies + assert len(assemblies) == 2 + for a in assemblies: + assert a.n_links == 1 + assert len(a) == 2 
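+        # each assembly should wrap exactly one of the two high-affinity links
+        # extracted above (A-C at 0.95 and B-D at 0.90)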
+ # affinity is the sum of link affinities for the assembly + assert a.affinity == pytest.approx(a._affinity / a.n_links) + + +# -------------------------------------------------------------------------------------- +# _assemble (per-frame) – main path without calibration +# -------------------------------------------------------------------------------------- + + +def test__assemble_main_no_calibration_returns_two_assemblies(): + graph = [(0, 1)] + paf_inds = [0] + data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) + data["0"] = simple_two_label_scene() + + asm = Assembler( + data, + max_n_individuals=2, + n_multibodyparts=2, + greedy=False, + pcutoff=0.1, + min_affinity=0.05, + min_n_links=1, + max_overlap=0.99, + window_size=0, + ) + + assemblies, unique = asm._assemble(data["0"], ind_frame=0) + assert unique is None # no unique bodyparts in this setting + assert len(assemblies) == 2 + assert all(len(a) == 2 for a in assemblies) + + +def test__assemble_returns_none_when_no_detections(): + graph = [(0, 1)] + paf_inds = [0] + data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) + + # Frame with zero coords (→ skipped by _flatten_detections) + coords0 = np.zeros((0, 2)) + conf0 = np.zeros((0,)) + coords1 = np.zeros((0, 2)) + conf1 = np.zeros((0,)) + frame = make_frame([coords0, coords1], [conf0, conf1], identity_per_label=None, costs={}) + data["0"] = frame + + asm = Assembler( + data, + max_n_individuals=2, + n_multibodyparts=2, + ) + assemblies, unique = asm._assemble(data["0"], ind_frame=0) + assert assemblies is None and unique is None + + +# -------------------------------------------------------------------------------------- +# assemble() over multiple frames + window_size KD-tree caching +# -------------------------------------------------------------------------------------- + + +def test_assemble_across_frames_updates_temporal_trees(): + graph = [(0, 1)] + paf_inds = [0] + data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) + + # Frame 0: baseline + frame0 = simple_two_label_scene() + + # Frame 1: nudge coordinates slightly, keep affinities similar + f1 = simple_two_label_scene() + f1["coordinates"][0][0] = f1["coordinates"][0][0] + np.array([[1.0, 0.0], [1.0, 0.0]]) + f1["coordinates"][0][1] = f1["coordinates"][0][1] + np.array([[1.0, 0.0], [1.0, 0.0]]) + + data["0"] = frame0 + data["1"] = f1 + + asm = Assembler( + data, + max_n_individuals=2, + n_multibodyparts=2, + window_size=1, # enable temporal memory + min_n_links=1, + ) + + # Use serial path to avoid multiprocessing in tests + asm.assemble(chunk_size=0) + + # KD-trees should be recorded for frames that had links + # Presence of keys 0 and 1 depends on links creation + assert 0 in asm._trees or 1 in asm._trees + assert isinstance(asm.assemblies, dict) + assert set(asm.assemblies.keys()).issubset({0, 1}) + + +# -------------------------------------------------------------------------------------- +# identity_only=True branch +# -------------------------------------------------------------------------------------- + + +def test_identity_only_branch_groups_by_identity(): + graph = [(0, 1)] + paf_inds = [0] + + # Build a frame with identities. Two groups (0 and 1). 
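+    # (with identity_only=True the PAF linking step should be bypassed entirely:
+    # joints are grouped by the argmax of their identity logits, as computed in
+    # Assembler._flatten_detections)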
+    base = simple_two_label_scene()
+    # identity logits such that each label has two detections belonging to groups 0 and 1
+    id0 = np.array([[4.0, 1.0], [1.0, 4.0]])  # label 0 → group 0 then 1
+    id1 = np.array([[4.0, 1.0], [1.0, 4.0]])  # label 1 → group 0 then 1
+    base["identity"] = [id0, id1]
+
+    data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"])
+    data["0"] = base
+
+    # identity_only=True should be enabled since "identity" is present in frame 0
+    asm = Assembler(
+        data,
+        max_n_individuals=3,
+        n_multibodyparts=2,
+        identity_only=True,
+        pcutoff=0.1,
+    )
+
+    assemblies, unique = asm._assemble(data["0"], ind_frame=0)
+    # We expect at least one assembly created by grouping (one per identity that has both labels observed)
+    assert assemblies is not None
+    assert all(len(a) >= 1 for a in assemblies)
+
+
+# --------------------------------------------------------------------------------------
+# Mahalanobis distance and link probability with a mocked KDE
+# --------------------------------------------------------------------------------------
+
+
+@dataclass
+class _FakeKDE:
+    mean: np.ndarray
+    inv_cov: np.ndarray
+    covariance: np.ndarray
+    d: int  # dimension
+
+
+def test_calc_assembly_mahalanobis_and_link_probability_with_fake_kde():
+    # 2 multibody parts → pdist length = 1
+    graph = [(0, 1)]
+    paf_inds = [0]
+    data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"])
+    data["0"] = simple_two_label_scene()
+
+    asm = Assembler(
+        data,
+        max_n_individuals=2,
+        n_multibodyparts=2,
+        min_n_links=1,
+    )
+
+    # Build a simple assembly with two joints
+    j0 = Joint((0.0, 0.0), 1.0, label=0, idx=0)
+    j1 = Joint((3.0, 4.0), 1.0, label=1, idx=1)
+    link = Link(j0, j1, affinity=1.0)
+
+    a = Assembly(size=2)
+    a.add_link(link)
+
+    # Fake KDE: one-dimensional (pairwise sq distance only)
+    # distance^2 = 5^2 = 25; set mean=25, identity covariance
+    fake = _FakeKDE(
+        mean=np.array([25.0]),
+        inv_cov=np.array([[1.0]]),
+        covariance=np.array([[1.0]]),
+        d=1,
+    )
+    asm._kde = fake
+    asm.safe_edge = True
+
+    # Mahalanobis should be finite and small because dist == mean
+    d = asm.calc_assembly_mahalanobis_dist(a)
+    assert np.isfinite(d)
+    assert d == pytest.approx(0.0, abs=1e-6)
+
+    # Link probability depends on squared length vs mean; here z=0 → high prob
+    p = asm.calc_link_probability(link)
+    assert 0.0 <= p <= 1.0
+    assert p == pytest.approx(1.0, rel=1e-6)
+
+
+# --------------------------------------------------------------------------------------
+# I/O helpers: to_pickle / from_pickle / to_h5 (optional)
+# --------------------------------------------------------------------------------------
+
+
+def test_to_pickle_and_from_pickle(tmp_path):
+    graph = [(0, 1)]
+    paf_inds = [0]
+    data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"])
+    data["0"] = simple_two_label_scene()
+
+    asm = Assembler(
+        data,
+        max_n_individuals=2,
+        n_multibodyparts=2,
+        min_n_links=1,
+    )
+
+    # build assemblies for frame 0
+    assemblies, _ = asm._assemble(data["0"], 0)
+    asm.assemblies = {0: assemblies}
+
+    pkl = tmp_path / "ass.pkl"
+    asm.to_pickle(str(pkl))
+
+    # Load into a new Assembler (empty schema is sufficient)
+    new_asm = Assembler.empty(
+        max_n_individuals=2,
+        n_multibodyparts=2,
+        n_uniquebodyparts=0,
+        graph=graph,
+        paf_inds=paf_inds,
+    )
+    new_asm.from_pickle(str(pkl))
+    assert 0 in new_asm.assemblies
+    assert isinstance(new_asm.assemblies[0], list)
+    # presence is enough here; checking non-emptiness avoids the vacuous
+    # `x == y or True` pattern, which can never fail
+    assert len(new_asm.assemblies[0]) > 0
+
+
+def test_to_h5_roundtrip(tmp_path):
+    # `importorskip` runs at test time, so a missing PyTables skips only this
+    # test; a `skipif(pytest.importorskip(...) is None)` decorator would be
+    # evaluated at collection time and skip the entire module instead.
+    pytest.importorskip("tables", reason="PyTables required for HDF5")
+    graph = [(0, 1)]
+    paf_inds = [0]
+    data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"])
+    data["0"] = simple_two_label_scene()
+
+    asm = Assembler(
+        data,
+        max_n_individuals=2,
+        n_multibodyparts=2,
+        min_n_links=1,
+    )
+    assemblies, _ = asm._assemble(data["0"], 0)
+    asm.assemblies = {0: assemblies}
+
+    h5 = tmp_path / "ass.h5"
+    asm.to_h5(str(h5))
+
+    # Read back and perform basic structural assertions
+    df = pd.read_hdf(str(h5), key="ass")
+    # one frame, 2 individuals, 2 bodyparts, coords {x,y,likelihood}
+    # df shape will be (frames, 2*2*3)
+    assert df.shape[0] == 1
+    assert df.columns.nlevels == 4
+    assert set(df.columns.get_level_values("coords")) == {"x", "y", "likelihood"}
diff --git a/tests/tests_core/test_assembly.py b/tests/tests_core/test_assembly.py
new file mode 100644
index 0000000..9471f81
--- /dev/null
+++ b/tests/tests_core/test_assembly.py
@@ -0,0 +1,271 @@
+import numpy as np
+import pytest
+
+from dlclive.core.inferenceutils import Assembly, Joint, Link
+
+# ---------------------------
+# Basic construction
+# ---------------------------
+
+
+def test_assembly_init():
+    assemb = Assembly(size=5)
+    assert assemb.data.shape == (5, 4)
+
+    # col 0,1,3 are NaN, col 2 is confidence=0
+    # this is due to confidence being initialized to 0
+    # by default in Assembly.__init__
+    assert np.isnan(assemb.data[:, 0]).all()
+    assert np.isnan(assemb.data[:, 1]).all()
+    assert (assemb.data[:, 2] == 0).all()
+    assert np.isnan(assemb.data[:, 3]).all()
+
+    assert assemb._affinity == 0
+    assert assemb._links == []
+    assert assemb._visible == set()
+    assert assemb._idx == set()
+
+
+# ---------------------------
+# from_array
+# ---------------------------
+
+
+def test_assembly_from_array_basic_xy_only():
+    arr = np.array(
+        [
+            [10.0, 20.0],
+            [30.0, 40.0],
+        ]
+    )
+    assemb = Assembly.from_array(arr.copy())
+
+    # full shape (n_bodyparts, 4)
+    assert assemb.data.shape == (2, 4)
+
+    # xy preserved
+    assert np.allclose(assemb.data[:, :2], arr)
+
+    # confidence auto-set to 1 where xy is present
+    assert np.allclose(assemb.data[:, 2], np.array([1.0, 1.0]))
+
+    # labels visible
+    assert assemb._visible == {0, 1}
+
+
+def test_assembly_from_array_with_nans():
+    arr = np.array(
+        [
+            [10.0, 20.0, 0.9],
+            [np.nan, 5.0, 0.8],  # one NaN → entire row becomes NaN
+        ]
+    )
+    assemb = Assembly.from_array(arr.copy())
+
+    assert np.allclose(assemb.data[0], [10.0, 20.0, 0.9, np.nan], equal_nan=True)
+    assert np.isnan(assemb.data[1]).all()
+
+    # visible only includes fully non-NaN rows
+    assert assemb._visible == {0}
+
+
+# ---------------------------
+# add_joint / remove_joint
+# ---------------------------
+
+
+def test_add_joint_and_remove_joint():
+    assemb = Assembly(size=3)
+    j0 = Joint(pos=(1.0, 2.0), confidence=0.5, label=0, idx=10)
+    j1 = Joint(pos=(3.0, 4.0), confidence=0.8, label=1, idx=11)
+
+    # adding first joint
+    assert assemb.add_joint(j0) is True
+    assert assemb._visible == {0}
+    assert assemb._idx == {10}
+    assert np.allclose(assemb.data[0], [1.0, 2.0, 0.5, j0.group])
+
+    # adding second joint
+    assert assemb.add_joint(j1) is True
+    assert assemb._visible == {0, 1}
+    assert assemb._idx == {10, 11}
+
+    # adding same joint label again → ignored
+    assert assemb.add_joint(j0) is False
+
+    # removing joint
+    assert assemb.remove_joint(j1) is True
+    assert assemb._visible == {0}
+    assert assemb._idx == {10}
+    assert
np.isnan(assemb.data[1]).all() + + # remove nonexistent → False + assert assemb.remove_joint(j1) is False + + +# --------------------------- +# add_link (simple) +# --------------------------- + + +def test_add_link_adds_joints_and_affinity(): + assemb = Assembly(size=3) + + j0 = Joint(pos=(0.0, 0.0), confidence=1.0, label=0, idx=100) + j1 = Joint(pos=(1.0, 0.0), confidence=1.0, label=1, idx=101) + link = Link(j0, j1, affinity=0.7) + + # New link → adds both joints + result = assemb.add_link(link) + assert result is True + assert assemb.n_links == 1 + assert assemb._affinity == pytest.approx(0.7) + assert assemb._visible == {0, 1} + assert assemb._idx == {100, 101} + + # Add same link again → both idx already present → only increases affinity, no new joints + result = assemb.add_link(link) + assert result is False # as per code path + assert assemb.n_links == 2 # link appended again + assert assemb._affinity == pytest.approx(1.4) # 0.7 + 0.7 + + +# --------------------------- +# extent, area, xy +# --------------------------- + + +def test_extent_and_area(): + assemb = Assembly(size=3) + # manually set data: [x, y, conf, group] + assemb.data[:] = np.nan + assemb.data[0, :2] = [10, 10] + assemb.data[1, :2] = [20, 40] + assemb._visible.update({0, 1}) + + # extent = (min_x, min_y, max_x, max_y) + assert np.allclose(assemb.extent, [10, 10, 20, 40]) + + # area = dx * dy = (20-10) * (40-10) = 10 * 30 + assert assemb.area == pytest.approx(300) + + +# --------------------------- +# intersection_with +# --------------------------- + + +def test_intersection_with_partial_overlap(): + ass1 = Assembly(size=2) + ass1.data[0, :2] = [0, 0] + ass1.data[1, :2] = [10, 10] + ass1._visible.update({0, 1}) + + ass2 = Assembly(size=2) + ass2.data[0, :2] = [5, 5] + ass2.data[1, :2] = [15, 15] + ass2._visible.update({0, 1}) + + # They overlap in a square of area 5x5 around (5,5)-(10,10). + # Each assembly has 2 points. Points inside overlap: + # ass1: both (0,0) no, (10,10) yes → 1 / 2 = 0.5 + # ass2: (5,5) yes, (15,15) no → 1 / 2 = 0.5 + assert ass1.intersection_with(ass2) == pytest.approx(0.5) + + +# --------------------------- +# confidence property +# --------------------------- + + +def test_confidence_property(): + assemb = Assembly(size=3) + assemb.data[:] = np.nan + assemb.data[:, 2] = [0.2, 0.4, np.nan] # mean of finite = (0.2+0.4)/2 = 0.3 + assert assemb.confidence == pytest.approx(0.3) + + assemb.confidence = 0.9 + assert np.allclose(assemb.data[:, 2], [0.9, 0.9, 0.9], equal_nan=True) + + +# --------------------------- +# soft_identity +# --------------------------- + + +def test_soft_identity_simple(): + # data format: x, y, conf, group + assemb = Assembly(size=3) + assemb.data[:] = np.nan + assemb.data[0] = [0, 0, 1.0, 0] + assemb.data[1] = [5, 5, 0.5, 0] + assemb.data[2] = [10, 10, 1.0, 1] + assemb._visible = {0, 1, 2} + + # groups: 0 → weights 1.0 and 0.5 (avg=0.75) + # 1 → weight 1.0 + # softmax([0.75, 1.0]) ≈ [...] 
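+    # numerically: exp(0.75) ≈ 2.12 and exp(1.0) ≈ 2.72, so the softmax weights
+    # come out to roughly [0.44, 0.56]; group 1 should carry slightly more weight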
+ soft = assemb.soft_identity + assert set(soft.keys()) == {0, 1} + s0, s1 = soft[0], soft[1] + assert pytest.approx(s0 + s1) == 1.0 + assert s1 > s0 + + +# --------------------------- +# intersection operator: __contains__ +# --------------------------- + + +def test_contains_checks_shared_idx(): + ass1 = Assembly(size=3) + ass2 = Assembly(size=3) + + j0 = Joint((0, 0), confidence=1.0, label=0, idx=10) + j1 = Joint((1, 1), confidence=1.0, label=1, idx=99) + + ass1.add_joint(j0) + ass2.add_joint(j1) + + # different idx sets → no intersection + assert (ass2 in ass1) is False + + ass2.add_joint(j0) + # now share idx=10 + assert (ass2 in ass1) is True + + +# --------------------------- +# assembly addition (__add__) +# --------------------------- + + +def test_assembly_addition_combines_links(): + a1 = Assembly(size=4) + a2 = Assembly(size=4) + + j0 = Joint((0, 0), 1.0, label=0, idx=10) + j1 = Joint((1, 0), 1.0, label=1, idx=11) + j2 = Joint((2, 0), 1.0, label=2, idx=12) + j3 = Joint((3, 0), 1.0, label=3, idx=13) + + l01 = Link(j0, j1, affinity=0.5) + l23 = Link(j2, j3, affinity=0.8) + + a1.add_link(l01) + a2.add_link(l23) + + # now they share NO joints → addition should succeed + result = a1 + a2 + + assert result.n_links == 2 + assert result._affinity == pytest.approx(1.3) + + # original assemblies unchanged + assert a1.n_links == 1 + assert a2.n_links == 1 + + # now purposely make them share a joint → should raise + a2.add_joint(j0) + with pytest.raises(ArithmeticError): + _ = a1 + a2 From 92fc40624bc4ed5183304ddf6b10660ba2373e36 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 26 Jan 2026 11:41:53 +0100 Subject: [PATCH 06/20] Update code duplication marker --- dlclive/core/inferenceutils.py | 3 +- dlclive/modelzoo/resolve_config.py | 19 +++++------- dlclive/modelzoo/utils.py | 49 ++++++++++++++---------------- pyproject.toml | 2 +- tests/test_modelzoo.py | 12 +++----- 5 files changed, 38 insertions(+), 47 deletions(-) diff --git a/dlclive/core/inferenceutils.py b/dlclive/core/inferenceutils.py index dc2f1f5..e69ff5c 100644 --- a/dlclive/core/inferenceutils.py +++ b/dlclive/core/inferenceutils.py @@ -10,7 +10,8 @@ # -# NOTE DUPLICATED from deeplabcut/core/inferenceutils.py +# NOTE - DUPLICATED @C-Achard 2026-26-01: Copied from the original DeepLabCut codebase +# from deeplabcut/core/inferenceutils.py from __future__ import annotations import heapq diff --git a/dlclive/modelzoo/resolve_config.py b/dlclive/modelzoo/resolve_config.py index bea25f5..4508eb0 100644 --- a/dlclive/modelzoo/resolve_config.py +++ b/dlclive/modelzoo/resolve_config.py @@ -1,9 +1,10 @@ """ -Helper function to deal with default values in the model configuration. +Helper function to deal with default values in the model configuration. For instance, "num_bodyparts x 2" is replaced with the number of bodyparts multiplied by 2. """ -# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. +# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase +# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py import copy @@ -78,8 +79,7 @@ def get_updated_value(variable: str) -> int | list[int]: var_name = var_parts[0] if updated_values[var_name] is None: raise ValueError( - f"Found {variable} in the configuration file, but there is no default " - f"value for this variable." + f"Found {variable} in the configuration file, but there is no default value for this variable." 
) if len(var_parts) == 1: @@ -99,9 +99,7 @@ def get_updated_value(variable: str) -> int | list[int]: else: raise ValueError(f"Unknown operator for variable: {variable}") - raise ValueError( - f"Found {variable} in the configuration file, but cannot parse it." - ) + raise ValueError(f"Found {variable} in the configuration file, but cannot parse it.") updated_values = { "num_bodyparts": num_bodyparts, @@ -127,10 +125,7 @@ def get_updated_value(variable: str) -> int | list[int]: backbone_output_channels, **kwargs, ) - elif ( - isinstance(config[k], str) - and config[k].strip().split(" ")[0] in updated_values.keys() - ): + elif isinstance(config[k], str) and config[k].strip().split(" ")[0] in updated_values.keys(): config[k] = get_updated_value(config[k]) - return config \ No newline at end of file + return config diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index 1aab80b..341bd5c 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -5,55 +5,53 @@ # This should be removed once a solution is found to address duplicate code. import copy -from pathlib import Path import logging +from pathlib import Path +from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model from ruamel.yaml import YAML -from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model from dlclive.modelzoo.resolve_config import update_config -_MODELZOO_PATH = Path(__file__).parent +_MODELZOO_PATH = Path(__file__).parent def get_super_animal_model_config_path(model_name: str) -> Path: """Get the path to the model configuration file for a model and validate choice of model""" - cfg_path = _MODELZOO_PATH / 'model_configs' / f"{model_name}.yaml" + cfg_path = _MODELZOO_PATH / "model_configs" / f"{model_name}.yaml" if not cfg_path.exists(): raise FileNotFoundError( - f"Modelzoo model configuration file not found: {cfg_path} " - f"Available models: {list_available_models()}" + f"Modelzoo model configuration file not found: {cfg_path} Available models: {list_available_models()}" ) return cfg_path def get_super_animal_project_config_path(super_animal: str) -> Path: """Get the path to the project configuration file for a project and validate choice of project""" - cfg_path = _MODELZOO_PATH / 'project_configs' / f"{super_animal}.yaml" + cfg_path = _MODELZOO_PATH / "project_configs" / f"{super_animal}.yaml" if not cfg_path.exists(): raise FileNotFoundError( - f"Modelzoo project configuration file not found: {cfg_path}" - f"Available projects: {list_available_projects()}" + f"Modelzoo project configuration file not found: {cfg_path}Available projects: {list_available_projects()}" ) return cfg_path def get_snapshot_folder_path() -> Path: - return _MODELZOO_PATH / 'snapshots' + return _MODELZOO_PATH / "snapshots" def list_available_models() -> list[str]: - return [p.stem for p in _MODELZOO_PATH.glob('model_configs/*.yaml')] + return [p.stem for p in _MODELZOO_PATH.glob("model_configs/*.yaml")] def list_available_projects() -> list[str]: - return [p.stem for p in _MODELZOO_PATH.glob('project_configs/*.yaml')] + return [p.stem for p in _MODELZOO_PATH.glob("project_configs/*.yaml")] def list_available_combinations() -> list[str]: models = list_available_models() projects = list_available_projects() - combinations = ['_'.join([p, m]) for p in projects for m in models] + combinations = ["_".join([p, m]) for p in projects for m in models] return combinations @@ -65,14 +63,18 @@ def read_config_as_dict(config_path: str | Path) -> dict: Returns: The configuration file 
with pure Python classes """ - with open(config_path, "r") as f: - cfg = YAML(typ='safe', pure=True).load(f) + with open(config_path) as f: + cfg = YAML(typ="safe", pure=True).load(f) return cfg -# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. -def add_metadata(project_config: dict, config: dict,) -> dict: +# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase +# from deeplabcut/pose_estimation_pytorch/config/make_pose_config.py +def add_metadata( + project_config: dict, + config: dict, +) -> dict: """Adds metadata to a pytorch pose configuration Args: @@ -95,7 +97,8 @@ def add_metadata(project_config: dict, config: dict,) -> dict: return config -# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. +# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase +# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py def load_super_animal_config( super_animal: str, model_name: str, @@ -128,9 +131,7 @@ def load_super_animal_config( else: model_config["method"] = "TD" if super_animal != "superanimal_humanbody": - detector_cfg_path = get_super_animal_model_config_path( - model_name=detector_name - ) + detector_cfg_path = get_super_animal_model_config_path(model_name=detector_name) detector_cfg = read_config_as_dict(detector_cfg_path) model_config["detector"] = detector_cfg return model_config @@ -159,9 +160,7 @@ def download_super_animal_snapshot(dataset: str, model_name: str) -> Path: return model_path try: - download_huggingface_model( - model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename - ) + download_huggingface_model(model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename) if not model_path.exists(): raise RuntimeError(f"Failed to download {model_name} to {model_path}") @@ -171,5 +170,3 @@ def download_super_animal_snapshot(dataset: str, model_name: str) -> Path: raise e return model_path - - diff --git a/pyproject.toml b/pyproject.toml index b448150..2456749 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,5 +97,5 @@ target-version = "py310" fix = true line-length = 120 -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py index 1a7e6db..997b0ba 100644 --- a/tests/test_modelzoo.py +++ b/tests/test_modelzoo.py @@ -1,17 +1,15 @@ -# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. 
- +# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase +# from deeplabcut/tests/pose_estimation_pytorch/modelzoo/test_modelzoo_utils.py import os -import pytest import dlclibrary +import pytest from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS from dlclive import modelzoo -@pytest.mark.parametrize( - "super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"] -) +@pytest.mark.parametrize("super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"]) @pytest.mark.parametrize("model_name", ["hrnet_w32"]) @pytest.mark.parametrize("detector_name", [None, "fasterrcnn_resnet50_fpn_v2"]) def test_get_config_model_paths(super_animal, model_name, detector_name): @@ -48,4 +46,4 @@ def test_download_huggingface_wrong_model(): @pytest.mark.skip(reason="slow") @pytest.mark.parametrize("model", MODELOPTIONS) def test_download_all_models(tmp_path_factory, model): - test_download_huggingface_model(tmp_path_factory, model) \ No newline at end of file + test_download_huggingface_model(tmp_path_factory, model) From 131fdef026ab8b023987e799e448ea8bf14373ce Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 26 Jan 2026 13:51:21 +0100 Subject: [PATCH 07/20] Display testing & fixes Corrects color sampling in Display to avoid zero step and ensures image is always defined in display_frame. Adds comprehensive tests for Display, including headless environment setup, frame display, cutoff logic, window destruction, and color sampling safety. --- dlclive/display.py | 43 +++++---------- tests/conftest.py | 22 ++++++++ tests/test_display.py | 120 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 30 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_display.py diff --git a/dlclive/display.py b/dlclive/display.py index 0d1c924..0304552 100644 --- a/dlclive/display.py +++ b/dlclive/display.py @@ -7,7 +7,9 @@ try: from tkinter import Label, Tk + from PIL import ImageTk + _TKINTER_AVAILABLE = True except ImportError: _TKINTER_AVAILABLE = False @@ -33,9 +35,7 @@ class Display: def __init__(self, cmap="bmy", radius=3, pcutoff=0.5): if not _TKINTER_AVAILABLE: - raise ImportError( - "tkinter is not available. Display functionality requires tkinter. " - ) + raise ImportError("tkinter is not available. Display functionality requires tkinter. ") self.cmap = cmap self.colors = None self.radius = radius @@ -59,7 +59,9 @@ def set_display(self, im_size, bodyparts): self.lab.pack() all_colors = getattr(cc, self.cmap) - self.colors = all_colors[:: int(len(all_colors) / bodyparts)] + # Avoid 0 step + step = max(1, int(len(all_colors) / bodyparts)) + self.colors = all_colors[::step] def display_frame(self, frame, pose=None): """ @@ -75,10 +77,10 @@ def display_frame(self, frame, pose=None): """ if not _TKINTER_AVAILABLE: raise ImportError("tkinter is not available. 
Cannot display frames.") - + im_size = (frame.shape[1], frame.shape[0]) + img = Image.fromarray(frame) # avoid undefined image if pose is None if pose is not None: - img = Image.fromarray(frame) draw = ImageDraw.Draw(img) if len(pose.shape) == 2: @@ -91,33 +93,14 @@ def display_frame(self, frame, pose=None): for j in range(pose.shape[1]): if pose[i, j, 2] > self.pcutoff: try: - x0 = ( - pose[i, j, 0] - self.radius - if pose[i, j, 0] - self.radius > 0 - else 0 - ) - x1 = ( - pose[i, j, 0] + self.radius - if pose[i, j, 0] + self.radius < im_size[0] - else im_size[1] - ) - y0 = ( - pose[i, j, 1] - self.radius - if pose[i, j, 1] - self.radius > 0 - else 0 - ) - y1 = ( - pose[i, j, 1] + self.radius - if pose[i, j, 1] + self.radius < im_size[1] - else im_size[0] - ) + x0 = pose[i, j, 0] - self.radius if pose[i, j, 0] - self.radius > 0 else 0 + x1 = pose[i, j, 0] + self.radius if pose[i, j, 0] + self.radius < im_size[0] else im_size[1] + y0 = pose[i, j, 1] - self.radius if pose[i, j, 1] - self.radius > 0 else 0 + y1 = pose[i, j, 1] + self.radius if pose[i, j, 1] + self.radius < im_size[1] else im_size[0] coords = [x0, y0, x1, y1] - draw.ellipse( - coords, fill=self.colors[j], outline=self.colors[j] - ) + draw.ellipse(coords, fill=self.colors[j], outline=self.colors[j]) except Exception as e: print(e) - img_tk = ImageTk.PhotoImage(image=img, master=self.window) self.lab.configure(image=img_tk) self.window.update() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ebee7e3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,22 @@ +import pytest + + +@pytest.fixture +def headless_display_env(monkeypatch): + # Import module under test + from test_display import FakeLabel, FakePhotoImage, FakeTk + + import dlclive.display as display_mod + + # Force tkinter availability and patch UI components + monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True, raising=False) + monkeypatch.setattr(display_mod, "Tk", FakeTk, raising=False) + monkeypatch.setattr(display_mod, "Label", FakeLabel, raising=False) + + # Patch ImageTk.PhotoImage + class FakeImageTkModule: + PhotoImage = FakePhotoImage + + monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule, raising=False) + + return display_mod diff --git a/tests/test_display.py b/tests/test_display.py new file mode 100644 index 0000000..0f6b555 --- /dev/null +++ b/tests/test_display.py @@ -0,0 +1,120 @@ +import numpy as np +import pytest + + +class FakeTk: + def __init__(self): + self.titles = [] + self.updated = 0 + self.destroyed = False + + def title(self, text): + self.titles.append(text) + + def update(self): + self.updated += 1 + + def destroy(self): + self.destroyed = True + + +class FakeLabel: + def __init__(self, window): + self.window = window + self.packed = False + self.configured = {} + + def pack(self): + self.packed = True + + def configure(self, **kwargs): + self.configured.update(kwargs) + + +class FakePhotoImage: + def __init__(self, image=None, master=None): + self.image = image + self.master = master + + +def test_display_init_raises_when_tk_unavailable(monkeypatch): + import dlclive.display as display_mod + + monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", False, raising=False) + + with pytest.raises(ImportError): + display_mod.Display() + + +def test_display_frame_creates_window_and_updates(headless_display_env): + display_mod = headless_display_env + disp = display_mod.Display(radius=3, pcutoff=0.5) + + frame = np.zeros((100, 120, 3), dtype=np.uint8) + pose = np.array([[[10, 10, 0.9], [50, 50, 
0.2]]]) # 1 animal, 2 bodyparts + + disp.display_frame(frame, pose) + + assert disp.window is not None + assert disp.lab is not None + assert disp.lab.packed is True + assert disp.window.updated == 1 + assert "image" in disp.lab.configured # configured with PhotoImage + + +def test_display_draws_only_points_above_cutoff(headless_display_env, monkeypatch): + display_mod = headless_display_env + disp = display_mod.Display(radius=3, pcutoff=0.5) + + frame = np.zeros((100, 100, 3), dtype=np.uint8) + pose = np.array( + [ + [ + [10, 10, 0.9], # draw + [20, 20, 0.49], # don't draw + [30, 30, 0.5001], # draw (>=) + ] + ], + dtype=float, + ) + + ellipses = [] + + class DrawRecorder: + def ellipse(self, coords, fill=None, outline=None): + ellipses.append((coords, fill, outline)) + + monkeypatch.setattr(display_mod.ImageDraw, "Draw", lambda img: DrawRecorder()) + + disp.display_frame(frame, pose) + + assert len(ellipses) == 2 + + +def test_destroy_calls_window_destroy(headless_display_env): + display_mod = headless_display_env + disp = display_mod.Display() + + frame = np.zeros((10, 10, 3), dtype=np.uint8) + pose = np.array([[[5, 5, 0.9]]]) + + disp.display_frame(frame, pose) + disp.destroy() + + assert disp.window.destroyed is True + + +def test_set_display_color_sampling_safe(headless_display_env, monkeypatch): + display_mod = headless_display_env + + # Provide a fixed colormap list + class FakeCC: + bmy = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0), (0, 1, 1), (1, 0, 1)] + + monkeypatch.setattr(display_mod, "cc", FakeCC) + + disp = display_mod.Display(cmap="bmy") + disp.set_display(im_size=(100, 100), bodyparts=3) + + assert disp.colors is not None + assert len(disp.colors) >= 3 From aa53622fff2a86b0947a7521a8f77bfb99aec92d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 26 Jan 2026 14:09:12 +0100 Subject: [PATCH 08/20] Update CI workflow and add duplication note to dynamic_cropping.py Removed the --no-cache flag from 'uv sync' in the testing workflow, enhanced pytest coverage reporting, and added a step to summarize coverage in the GitHub Actions job summary. Added a note in dynamic_cropping.py about duplication with another file and referenced existing tests. 
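As a sketch only (not part of this patch): the duplication note could later
be backed by a drift check between the two modules. The test below is
hypothetical; it assumes deeplabcut is importable in the test environment
and skips otherwise, and it takes the upstream module path from the note
itself. Comparing parameter names keeps the check robust to formatting-only
changes such as the ones made here.

    import importlib.util
    import inspect

    import pytest


    @pytest.mark.skipif(
        importlib.util.find_spec("deeplabcut") is None,
        reason="deeplabcut not installed; cannot compare duplicated modules",
    )
    def test_dynamic_cropper_api_matches_upstream():
        import dlclive.pose_estimation_pytorch.dynamic_cropping as local
        from deeplabcut.pose_estimation_pytorch.runners import dynamic_cropping as upstream

        # Parameter names survive reformatting, so this only flags real API drift.
        local_params = list(inspect.signature(local.DynamicCropper.build).parameters)
        upstream_params = list(inspect.signature(upstream.DynamicCropper.build).parameters)
        assert local_params == upstream_params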
---
 .github/workflows/testing.yml | 15 ++++++-
 .../dynamic_cropping.py       | 39 +++++++------
 2 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index 17ee47d..dc3b765 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -53,7 +53,7 @@ jobs:
         python-version: ${{ matrix.python-version }}

     - name: Install the project
-      run: uv sync --no-cache --all-extras --dev
+      run: uv sync --all-extras --dev
       shell: bash

     - name: Install ffmpeg
@@ -72,9 +72,20 @@ jobs:
       run: uv run dlc-live-test --nodisplay

     - name: Run DLC Live Unit Tests
-      run: uv run pytest --cov=dlclive --cov-report=xml
+      run: uv run pytest --cov=dlclive --cov-report=xml --cov-report=term-missing

     - name: Coverage Report
       uses: codecov/codecov-action@v5
       with:
         files: ./coverage.xml
+        flags: ${{ matrix.os }}-py${{ matrix.python-version }}
+        name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}
+
+    - name: Add coverage to job summary
+      if: always()
+      shell: bash
+      run: |
+        uv run python -m coverage report -m > coverage.txt
+        echo "## Coverage (dlclive)" >> "$GITHUB_STEP_SUMMARY"
+        echo '```' >> "$GITHUB_STEP_SUMMARY"
+        cat coverage.txt >> "$GITHUB_STEP_SUMMARY"
+        echo '```' >> "$GITHUB_STEP_SUMMARY"
diff --git a/dlclive/pose_estimation_pytorch/dynamic_cropping.py b/dlclive/pose_estimation_pytorch/dynamic_cropping.py
index ae5991f..27a1348 100644
--- a/dlclive/pose_estimation_pytorch/dynamic_cropping.py
+++ b/dlclive/pose_estimation_pytorch/dynamic_cropping.py
@@ -7,13 +7,16 @@
 # https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
 #
 # Licensed under GNU Lesser General Public License v3.0
-#
+
+# NOTE DUPLICATED @C-Achard 2026-01-26: Duplication between this file
+# and deeplabcut/pose_estimation_pytorch/runners/dynamic_cropping.py
+# NOTE Testing already exists at deeplabcut/tests/pose_estimation_pytorch/runners/test_dynamic_cropper.py
 """Modules to dynamically crop individuals out of videos to improve video analysis"""
+
 from __future__ import annotations

 import math
 from dataclasses import dataclass, field
-from typing import Optional

 import torch
 import torchvision.transforms.functional as F
@@ -79,10 +82,7 @@ def crop(self, image: torch.Tensor) -> torch.Tensor:
             height.
         """
         if len(image) != 1:
-            raise RuntimeError(
-                "DynamicCropper can only be used with batch size 1 (found image "
-                f"shape: {image.shape})"
-            )
+            raise RuntimeError(f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})")

         if self._shape is None:
             self._shape = image.shape[3], image.shape[2]
@@ -114,7 +114,7 @@ def update(self, pose: torch.Tensor) -> torch.Tensor:
             The pose, with coordinates updated to the full image space.
         """
         if self._shape is None:
-            raise RuntimeError(f"You must call `crop` before calling `update`.")
+            raise RuntimeError("You must call `crop` before calling `update`.")

         # offset the pose to the original image space
         offset_x, offset_y = 0, 0
@@ -153,9 +153,7 @@ def reset(self) -> None:
         self._crop = None

     @staticmethod
-    def build(
-        dynamic: bool, threshold: float, margin: int
-    ) -> Optional["DynamicCropper"]:
+    def build(dynamic: bool, threshold: float, margin: int) -> DynamicCropper | None:
         """Builds the DynamicCropper based on the given parameters

         Args:
@@ -309,10 +307,7 @@ def crop(self, image: torch.Tensor) -> torch.Tensor:
             `crop` was previously called with an image of a different W or H.
""" if len(image) != 1: - raise RuntimeError( - "DynamicCropper can only be used with batch size 1 (found image " - f"shape: {image.shape})" - ) + raise RuntimeError(f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})") if self._shape is None: self._shape = image.shape[3], image.shape[2] @@ -349,7 +344,7 @@ def update(self, pose: torch.Tensor) -> torch.Tensor: The pose, with coordinates updated to the full image space. """ if self._shape is None: - raise RuntimeError(f"You must call `crop` before calling `update`.") + raise RuntimeError("You must call `crop` before calling `update`.") # check whether this was a patched crop batch_size = pose.shape[0] @@ -399,9 +394,7 @@ def update(self, pose: torch.Tensor) -> torch.Tensor: return pose - def _prepare_bounding_box( - self, x1: int, y1: int, x2: int, y2: int - ) -> tuple[int, int, int, int]: + def _prepare_bounding_box(self, x1: int, y1: int, x2: int, y2: int) -> tuple[int, int, int, int]: """Prepares the bounding box for cropping. Adds a margin around the bounding box, then transforms it into the target aspect @@ -498,12 +491,8 @@ def generate_patches(self) -> list[tuple[int, int, int, int]]: Returns: A list of patch coordinates as tuples (x0, y0, x1, y1). """ - patch_xs = self.split_array( - self._shape[0], self._patch_counts[0], self._patch_overlap - ) - patch_ys = self.split_array( - self._shape[1], self._patch_counts[1], self._patch_overlap - ) + patch_xs = self.split_array(self._shape[0], self._patch_counts[0], self._patch_overlap) + patch_ys = self.split_array(self._shape[1], self._patch_counts[1], self._patch_overlap) patches = [] for y0, y1 in patch_ys: @@ -534,7 +523,7 @@ def split_array(size: int, n: int, overlap: int) -> list[tuple[int, int]]: segment_size = (padded_size // n) + (padded_size % n > 0) segments = [] end = overlap - for i in range(n): + for _i in range(n): start = end - overlap end = start + segment_size if end > size: From fd05cefb220041097d8098b8daf4d340e219b2a6 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 26 Jan 2026 14:54:16 +0100 Subject: [PATCH 09/20] Add 'slow' marker to pytest config and tests --- pytest.ini | 5 +++-- tests/test_benchmark_script.py | 19 ++++++++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/pytest.ini b/pytest.ini index c878400..701bc93 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,9 +1,10 @@ [pytest] markers = - functional: functional tests + functional: Functional tests (high-level) + slow: Slow tests filterwarnings = # Suppress NumPy deprecation warning from Keras/TensorFlow about np.object ignore::FutureWarning:keras.* ignore::FutureWarning:tensorflow.* - ignore:In the future `np.object` will be defined:FutureWarning \ No newline at end of file + ignore:In the future `np.object` will be defined:FutureWarning diff --git a/tests/test_benchmark_script.py b/tests/test_benchmark_script.py index 58c1533..875e63d 100644 --- a/tests/test_benchmark_script.py +++ b/tests/test_benchmark_script.py @@ -1,5 +1,7 @@ import glob + import pytest + from dlclive import benchmark_videos, download_benchmarking_data from dlclive.engine import Engine @@ -10,7 +12,9 @@ def datafolder(tmp_path): download_benchmarking_data(str(datafolder)) return datafolder + @pytest.mark.functional +@pytest.mark.slow def test_benchmark_script_runs_tf_backend(tmp_path, datafolder): dog_models = glob.glob(str(datafolder / "dog" / "*[!avi]")) dog_video = glob.glob(str(datafolder / "dog" / "*.avi"))[0] @@ -27,11 +31,7 @@ def 
test_benchmark_script_runs_tf_backend(tmp_path, datafolder):
         print(f"Running dog model: {model_path}")
         benchmark_videos(
             model_path=model_path,
-            model_type=(
-                "base"
-                if Engine.from_model_path(model_path) == Engine.TENSORFLOW
-                else "pytorch"
-            ),
+            model_type=("base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch"),
             video_path=dog_video,
             output=str(out_dir),
             n_frames=n_frames,
@@ -42,11 +42,7 @@ def test_benchmark_script_runs_tf_backend(tmp_path, datafolder):
         print(f"Running mouse model: {model_path}")
         benchmark_videos(
             model_path=model_path,
-            model_type=(
-                "base"
-                if Engine.from_model_path(model_path) == Engine.TENSORFLOW
-                else "pytorch"
-            ),
+            model_type=("base" if Engine.from_model_path(model_path) == Engine.TENSORFLOW else "pytorch"),
             video_path=mouse_video,
             output=str(out_dir),
             n_frames=n_frames,
@@ -58,6 +54,7 @@

 @pytest.mark.parametrize("model_name", ["hrnet_w32", "resnet_50"])
 @pytest.mark.functional
+@pytest.mark.slow
 def test_benchmark_script_with_torch_modelzoo(tmp_path, datafolder, model_name):
     from dlclive import modelzoo

@@ -107,4 +104,4 @@ def test_benchmark_script_with_torch_modelzoo(tmp_path, datafolder, model_name):
     # Assertions: verify output files were created
     output_files = list(out_dir.iterdir())
     assert len(output_files) > 0, "No output files were created by benchmark_videos"
-    assert any(f.suffix == ".pickle" for f in output_files), "No pickle files found in output directory"
\ No newline at end of file
+    assert any(f.suffix == ".pickle" for f in output_files), "No pickle files found in output directory"

From 7f6773412380bb988ae4631917c5091474d1fd01 Mon Sep 17 00:00:00 2001
From: C-Achard
Date: Tue, 27 Jan 2026 09:01:16 +0100
Subject: [PATCH 10/20] Add pose attribute to DLCLive class

Introduces a new 'pose' attribute on the DLCLive class, annotated as
np.ndarray | None and initialized to None. This prepares the class for
caching the most recent pose prediction.
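A rough usage sketch of where this is headed; the model path is a
placeholder, and the final wiring that assigns get_pose() results to the
attribute is assumed to land in a follow-up commit (this patch only
declares it):

    import numpy as np

    from dlclive import DLCLive

    # dynamic=(True, 0.5, 10): dynamic cropping, pcutoff 0.5, 10 px margin
    dlc = DLCLive("/path/to/exported_model", dynamic=(True, 0.5, 10))

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in camera frame
    dlc.init_inference(frame)       # first frame: self.pose is None, full frame used
    dlc.pose = dlc.get_pose(frame)  # cache the latest prediction on the instance

    # process_frame() now sees self.pose is not None and can crop subsequent
    # frames around the detected keypoints before running inference.
    dlc.get_pose(frame)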
--- dlclive/dlclive.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index 1dcb88f..965aab6 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -4,6 +4,7 @@ Licensed under GNU Lesser General Public License v3.0 """ + from __future__ import annotations from pathlib import Path @@ -197,12 +198,12 @@ def __init__( self.processor = processor self.convert2rgb = convert2rgb + self.pose: np.ndarray | None = None + if isinstance(display, Display): self.display = display elif display: - self.display = Display( - pcutoff=pcutoff, radius=display_radius, cmap=display_cmap - ) + self.display = Display(pcutoff=pcutoff, radius=display_radius, cmap=display_cmap) else: self.display = None @@ -250,9 +251,7 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray: processed frame: convert type, crop, convert color """ if self.cropping: - frame = frame[ - self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1] - ] + frame = frame[self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1]] if self.dynamic[0]: if self.pose is not None: @@ -263,9 +262,7 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray: elif len(self.pose) == 1: pose = self.pose[0] else: - raise ValueError( - "Cannot use Dynamic Cropping - more than 1 individual found" - ) + raise ValueError("Cannot use Dynamic Cropping - more than 1 individual found") else: pose = self.pose From 287cf2f242684d74b899e75ad38adc3fc9282b4c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 2 Feb 2026 10:10:46 +0100 Subject: [PATCH 11/20] Refactor assembler tests to use fixtures Move test helpers into tests/conftest.py and introduce a suite of reusable pytest fixtures for assembler testing (headless_display_env, assembler graph/paf fixtures, scene factories, assembler/assembly/joint/link factories, and various canned dataset variants). Refactor tests/tests_core/test_assembler.py and tests/tests_core/test_assembly.py to consume the new fixtures, remove duplicated setup code, and simplify assertions. Also adjust serialization test filenames and tidy up identity/affinity-related test logic. 
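The core pattern, condensed (the real fixtures in tests/conftest.py below
follow it): a fixture returns a factory callable, so a test requests one
fixture and builds as many tailored variants as it needs. The snippet
mirrors make_assembler_metadata from this patch; the test is illustrative
only.

    import pytest


    @pytest.fixture
    def make_assembler_metadata():
        def _factory(graph, paf_inds, n_bodyparts, frame_keys):
            return {
                "metadata": {
                    "all_joints_names": [f"b{i}" for i in range(n_bodyparts)],
                    "PAFgraph": graph,
                    "PAFinds": paf_inds,
                },
                **{k: {} for k in frame_keys},
            }

        return _factory


    def test_metadata_layout(make_assembler_metadata):
        data = make_assembler_metadata([(0, 1)], [0], n_bodyparts=2, frame_keys=["0"])
        assert data["metadata"]["PAFgraph"] == [(0, 1)]
        assert "0" in data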
--- tests/conftest.py | 273 ++++++++++++++++++++- tests/tests_core/test_assembler.py | 380 +++++++---------------------- tests/tests_core/test_assembly.py | 86 +++---- 3 files changed, 387 insertions(+), 352 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ebee7e3..ff19a2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,279 @@ +from __future__ import annotations + +import copy +from collections.abc import Callable +from typing import Any + +import numpy as np import pytest +from dlclive.core.inferenceutils import Assembler + +# -------------------------------------------------------------------------------------- +# Headless display fixture +# -------------------------------------------------------------------------------------- @pytest.fixture def headless_display_env(monkeypatch): - # Import module under test + """Patch dlclive.display so tkinter is replaced with fake, non-GUI-safe objects.""" from test_display import FakeLabel, FakePhotoImage, FakeTk import dlclive.display as display_mod - # Force tkinter availability and patch UI components - monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True, raising=False) - monkeypatch.setattr(display_mod, "Tk", FakeTk, raising=False) - monkeypatch.setattr(display_mod, "Label", FakeLabel, raising=False) + monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True) + monkeypatch.setattr(display_mod, "Tk", FakeTk) + monkeypatch.setattr(display_mod, "Label", FakeLabel) - # Patch ImageTk.PhotoImage class FakeImageTkModule: PhotoImage = FakePhotoImage - monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule, raising=False) - + monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule) return display_mod + + +# -------------------------------------------------------------------------------------- +# Assembler/assembly test fixtures +# -------------------------------------------------------------------------------------- +@pytest.fixture +def assembler_graph_and_pafs() -> tuple[list[tuple[int, int]], list[int]]: + """Standard 2‑joint graph used throughout the test suite.""" + return ([(0, 1)], [0]) + + +@pytest.fixture +def make_assembler_metadata() -> Callable[..., dict[str, Any]]: + """Return a factory that builds minimal Assembler metadata dictionaries.""" + + def _factory(graph, paf_inds, n_bodyparts, frame_keys): + return { + "metadata": { + "all_joints_names": [f"b{i}" for i in range(n_bodyparts)], + "PAFgraph": graph, + "PAFinds": paf_inds, + }, + **{k: {} for k in frame_keys}, + } + + return _factory + + +@pytest.fixture +def make_assembler_frame() -> Callable[..., dict[str, Any]]: + """Return a factory that builds a frame dict compatible with _flatten_detections.""" + + def _factory( + coordinates_per_label, + confidence_per_label, + identity_per_label=None, + costs=None, + ): + frame = { + "coordinates": [coordinates_per_label], + "confidence": confidence_per_label, + "costs": costs or {}, + } + if identity_per_label is not None: + frame["identity"] = identity_per_label + return frame + + return _factory + + +@pytest.fixture +def simple_two_label_scene(make_assembler_frame) -> dict[str, Any]: + """Deterministic scene with predictable affinities for testing.""" + coords0 = np.array([[0.0, 0.0], [100.0, 100.0]]) + coords1 = np.array([[5.0, 0.0], [110.0, 100.0]]) + conf0 = np.array([0.9, 0.6]) + conf1 = np.array([0.8, 0.7]) + + aff = np.array([[0.95, 0.1], [0.05, 0.9]]) + + lens = np.array( + [ + [np.hypot(*(coords1[0] - coords0[0])), np.hypot(*(coords1[1] - coords0[0]))], + 
[np.hypot(*(coords1[0] - coords0[1])), np.hypot(*(coords1[1] - coords0[1]))], + ] + ) + + return make_assembler_frame( + coordinates_per_label=[coords0, coords1], + confidence_per_label=[conf0, conf1], + identity_per_label=None, + costs={0: {"distance": lens, "m1": aff}}, + ) + + +@pytest.fixture +def scene_copy(simple_two_label_scene) -> dict[str, Any]: + """Return a deep copy of the simple_two_label_scene fixture.""" + return copy.deepcopy(simple_two_label_scene) + + +@pytest.fixture +def assembler_data( + assembler_graph_and_pafs, + make_assembler_metadata, + simple_two_label_scene, +) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]: + """Full metadata + two identical frames ('0', '1').""" + graph, paf_inds = assembler_graph_and_pafs + data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) + data["0"] = simple_two_label_scene + data["1"] = simple_two_label_scene + return data, graph, paf_inds + + +@pytest.fixture +def assembler_data_single_frame( + assembler_graph_and_pafs, + make_assembler_metadata, + simple_two_label_scene, +) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]: + """Metadata + a single frame ('0'). Used by most tests.""" + graph, paf_inds = assembler_graph_and_pafs + data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) + data["0"] = simple_two_label_scene + return data, graph, paf_inds + + +@pytest.fixture +def assembler_data_two_frames_nudged( + assembler_graph_and_pafs, + make_assembler_metadata, + simple_two_label_scene, +) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]: + """Two frames where frame '1' is a nudged copy of frame '0'.""" + graph, paf_inds = assembler_graph_and_pafs + data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) + + frame0 = simple_two_label_scene + frame1 = copy.deepcopy(simple_two_label_scene) + frame1["coordinates"][0][0] += np.array([[1.0, 0.0], [1.0, 0.0]]) + frame1["coordinates"][0][1] += np.array([[1.0, 0.0], [1.0, 0.0]]) + + data["0"] = frame0 + data["1"] = frame1 + return data, graph, paf_inds + + +@pytest.fixture +def assembler_data_no_detections( + assembler_graph_and_pafs, + make_assembler_metadata, + make_assembler_frame, +) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]: + """Metadata + a single frame ('0') with zero detections for both labels.""" + graph, paf_inds = assembler_graph_and_pafs + data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) + + frame = make_assembler_frame( + coordinates_per_label=[np.zeros((0, 2)), np.zeros((0, 2))], + confidence_per_label=[np.zeros((0,)), np.zeros((0,))], + identity_per_label=None, + costs={}, + ) + data["0"] = frame + return data, graph, paf_inds + + +@pytest.fixture +def make_assembler() -> Callable[..., Assembler]: + """ + Factory to create an Assembler with sensible defaults for this test suite. + Override any parameter per-test via kwargs. 
+ """ + + def _factory(data: dict[str, Any], **overrides) -> Assembler: + defaults = dict( + max_n_individuals=2, + n_multibodyparts=2, + min_n_links=1, + pcutoff=0.1, + min_affinity=0.05, + ) + defaults.update(overrides) + return Assembler(data, **defaults) + + return _factory + + +# -------------------------------------------------------------------------------------- +# Assembly / Joint / Link test fixtures +# -------------------------------------------------------------------------------------- +from dlclive.core.inferenceutils import Assembly, Joint, Link # noqa: E402 + + +@pytest.fixture +def make_assembly() -> Callable[..., Assembly]: + """Factory to create an Assembly with the given size.""" + + def _factory(size: int) -> Assembly: + return Assembly(size=size) + + return _factory + + +@pytest.fixture +def make_joint() -> Callable[..., Joint]: + """Factory to create a Joint with sensible defaults.""" + + def _factory( + pos=(0.0, 0.0), + confidence: float = 1.0, + label: int = 0, + idx: int = 0, + group: int = -1, + ) -> Joint: + return Joint(pos=pos, confidence=confidence, label=label, idx=idx, group=group) + + return _factory + + +@pytest.fixture +def make_link() -> Callable[..., Link]: + """Factory to create a Link between two joints.""" + + def _factory(j1: Joint, j2: Joint, affinity: float = 1.0) -> Link: + return Link(j1, j2, affinity=affinity) + + return _factory + + +@pytest.fixture +def two_overlap_assemblies(make_assembly) -> tuple[Assembly, Assembly]: + """Two assemblies with partial overlap used by intersection tests.""" + ass1 = make_assembly(2) + ass1.data[0, :2] = [0, 0] + ass1.data[1, :2] = [10, 10] + ass1._visible.update({0, 1}) + + ass2 = make_assembly(2) + ass2.data[0, :2] = [5, 5] + ass2.data[1, :2] = [15, 15] + ass2._visible.update({0, 1}) + return ass1, ass2 + + +@pytest.fixture +def soft_identity_assembly(make_assembly) -> Assembly: + """Assembly configured for soft_identity tests.""" + assemb = make_assembly(3) + assemb.data[:] = np.nan + assemb.data[0] = [0, 0, 1.0, 0] + assemb.data[1] = [5, 5, 0.5, 0] + assemb.data[2] = [10, 10, 1.0, 1] + assemb._visible = {0, 1, 2} + return assemb + + +@pytest.fixture +def four_joint_chain(make_joint, make_link) -> tuple[Joint, Joint, Joint, Joint, Link, Link]: + """Four joints and two links: (0-1) and (2-3).""" + j0 = make_joint((0, 0), 1.0, label=0, idx=10) + j1 = make_joint((1, 0), 1.0, label=1, idx=11) + j2 = make_joint((2, 0), 1.0, label=2, idx=12) + j3 = make_joint((3, 0), 1.0, label=3, idx=13) + l01 = make_link(j0, j1, affinity=0.5) + l23 = make_link(j2, j3, affinity=0.8) + return j0, j1, j2, j3, l01, l23 diff --git a/tests/tests_core/test_assembler.py b/tests/tests_core/test_assembler.py index 3b9620d..6feef72 100644 --- a/tests/tests_core/test_assembler.py +++ b/tests/tests_core/test_assembler.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass import numpy as np @@ -6,80 +8,13 @@ from dlclive.core.inferenceutils import Assembler, Assembly, Joint, Link -# -------------------------------------------------------------------------------------- -# Helpers -# -------------------------------------------------------------------------------------- - -def make_metadata(graph, paf_inds, n_bodyparts, frame_keys): - """Create a minimal DLC-like metadata structure for Assembler.""" - return { - "metadata": { - "all_joints_names": [f"b{i}" for i in range(n_bodyparts)], - "PAFgraph": graph, - "PAFinds": paf_inds, - }, - **{k: {} for k in frame_keys}, - } - - -def 
make_frame(coordinates_per_label, confidence_per_label, identity_per_label=None, costs=None): - """ - Build a single frame dict with the structure Assembler._flatten_detections expects. - - coordinates_per_label: list of np.ndarray[(n_dets, 2)] - confidence_per_label: list of np.ndarray[(n_dets, )] - identity_per_label: list of np.ndarray[(n_dets, n_groups)] or None - costs: dict or None. Example: - { - 0: { - "distance": np.array([[...]]), # rows: label s detections, cols: label t - "m1": np.array([[...]]), - } - } - """ - frame = { - # NOTE: Assembler expects coordinates under key "coordinates"[0] - "coordinates": [coordinates_per_label], - "confidence": confidence_per_label, - "costs": costs or {}, - } - if identity_per_label is not None: - frame["identity"] = identity_per_label - return frame - - -def simple_two_label_scene(): - """ - Build a very small, deterministic scene with 2 bodyparts (0 ↔ 1), each with 2 detections. - We design affinities so that the intended pairs are (A↔C) and (B↔D). - """ - # Label 0 detections: A (near origin), B (far) - coords0 = np.array([[0.0, 0.0], [100.0, 100.0]]) - conf0 = np.array([0.9, 0.6]) - - # Label 1 detections: C near A, D near B - coords1 = np.array([[5.0, 0.0], [110.0, 100.0]]) - conf1 = np.array([0.8, 0.7]) - - # Affinities: strong on diagonal pairs AC and BD, weak elsewhere - aff = np.array([[0.95, 0.1], [0.05, 0.9]]) - # Lengths: finite distances (not used for assignment, but must be finite) - lens = np.array( - [ - [np.hypot(*(coords1[0] - coords0[0])), np.hypot(*(coords1[1] - coords0[0]))], - [np.hypot(*(coords1[0] - coords0[1])), np.hypot(*(coords1[1] - coords0[1]))], - ] - ) - - # Build frame expected by Assembler - frame0 = make_frame( - coordinates_per_label=[coords0, coords1], - confidence_per_label=[conf0, conf1], - identity_per_label=None, - costs={0: {"distance": lens, "m1": aff}}, - ) - return frame0 +def _bag_from_frame(frame: dict) -> dict[int, list]: + """Build joints bag {label: [Joint, ...]} from a single frame.""" + bag: dict[int, list] = {} + for j in Assembler._flatten_detections(frame): + bag.setdefault(j.label, []).append(j) + return bag # -------------------------------------------------------------------------------------- @@ -87,32 +22,27 @@ def simple_two_label_scene(): # -------------------------------------------------------------------------------------- -def test_parse_metadata_and_getitem_and_empty_classmethod(): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) - # Fill frames so __getitem__ returns something non-empty later - data["0"] = simple_two_label_scene() - data["1"] = simple_two_label_scene() - - asm = Assembler( +def test_parse_metadata_and_getitem(assembler_data, make_assembler): + data, graph, paf_inds = assembler_data + # Parsing + asm = make_assembler( data, max_n_individuals=2, n_multibodyparts=2, ) - # parse_metadata applied in ctor: assert asm.metadata["num_joints"] == 2 assert asm.metadata["paf_graph"] == graph assert list(asm.metadata["paf"]) == paf_inds assert set(asm.metadata["imnames"]) == {"0", "1"} - - # __getitem__ returns the per-frame dict + # __getitem__ assert "coordinates" in asm[0] assert "confidence" in asm[0] assert "costs" in asm[0] - # empty() convenience + +def test_empty_classmethod(assembler_graph_and_pafs): + graph, paf_inds = assembler_graph_and_pafs empty = Assembler.empty( max_n_individuals=1, n_multibodyparts=1, @@ -129,30 +59,24 @@ def test_parse_metadata_and_getitem_and_empty_classmethod(): # 
-------------------------------------------------------------------------------------- -def test_flatten_detections_no_identity(): - frame = simple_two_label_scene() +def test_flatten_detections_no_identity(simple_two_label_scene): + frame = simple_two_label_scene joints = list(Assembler._flatten_detections(frame)) - # 2 labels * 2 detections - assert len(joints) == 4 - # label IDs and groups - labels = sorted([j.label for j in joints]) - assert labels == [0, 0, 1, 1] - # identity absent → group = -1 + assert len(joints) == 4 + assert sorted(j.label for j in joints) == [0, 0, 1, 1] assert set(j.group for j in joints) == {-1} -def test_flatten_detections_with_identity(): - frame = simple_two_label_scene() - - # Add identity logits so that argmax → [0, 1] for both labels - id0 = np.array([[10.0, 0.0], [0.0, 10.0]]) # for label 0 - id1 = np.array([[10.0, 0.0], [0.0, 10.0]]) # for label 1 +def test_flatten_detections_with_identity(scene_copy): + frame = scene_copy + id0 = np.array([[10.0, 0.0], [0.0, 10.0]]) + id1 = np.array([[10.0, 0.0], [0.0, 10.0]]) frame["identity"] = [id0, id1] joints = list(Assembler._flatten_detections(frame)) groups = [j.group for j in joints] - # we expect groups [0,1,0,1] in some order (2 per label) + assert set(groups) == {0, 1} assert groups.count(0) == 2 assert groups.count(1) == 2 @@ -163,69 +87,51 @@ def test_flatten_detections_with_identity(): # -------------------------------------------------------------------------------------- -def test_extract_best_links_optimal_assignment(): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) - data["0"] = simple_two_label_scene() - - asm = Assembler( +def test_extract_best_links_optimal_assignment(assembler_data_single_frame, make_assembler): + data, _, _ = assembler_data_single_frame + asm = make_assembler( data, - max_n_individuals=2, - n_multibodyparts=2, greedy=False, # use Hungarian (maximize) - pcutoff=0.1, - min_affinity=0.05, - min_n_links=1, # avoid pruning 1-link assemblies in later steps + min_n_links=1, ) - # Build joints_dict like _assemble does - joints = list(Assembler._flatten_detections(data["0"])) - bag = {} - for j in joints: - bag.setdefault(j.label, []).append(j) + frame0 = data["0"] + bag = _bag_from_frame(frame0) - links = asm.extract_best_links(bag, data["0"]["costs"], trees=None) - # Expect 2 high-quality links: (coords0[0] ↔ coords1[0]) and (coords0[1] ↔ coords1[1]) + links = asm.extract_best_links(bag, frame0["costs"], trees=None) assert len(links) == 2 - # Check that each link connects matching pairs (by position) endpoints = [{tuple(l.j1.pos), tuple(l.j2.pos)} for l in links] assert {(0.0, 0.0), (5.0, 0.0)} in endpoints assert {(100.0, 100.0), (110.0, 100.0)} in endpoints - # Affinity should be the matrix diagonal values ~0.95 and ~0.9 - vals = sorted([l.affinity for l in links], reverse=True) + vals = sorted((l.affinity for l in links), reverse=True) assert vals[0] == pytest.approx(0.95, rel=1e-6) assert vals[1] == pytest.approx(0.90, rel=1e-6) -def test_extract_best_links_greedy_with_thresholds(): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) - data["0"] = simple_two_label_scene() - - asm = Assembler( +def test_extract_best_links_greedy_with_thresholds(assembler_data_single_frame, make_assembler): + data, _, _ = assembler_data_single_frame + asm = make_assembler( data, max_n_individuals=1, # greedy will stop after 1 disjoint pair chosen - n_multibodyparts=2, greedy=True, 
pcutoff=0.5, # conf product must exceed 0.25 min_affinity=0.5, # low-affinity pairs excluded min_n_links=1, ) - joints = list(Assembler._flatten_detections(data["0"])) - bag = {} - for j in joints: - bag.setdefault(j.label, []).append(j) + frame0 = data["0"] + bag = _bag_from_frame(frame0) - links = asm.extract_best_links(bag, data["0"]["costs"], trees=None) - # Expect exactly 1 link due to max_n_individuals=1 in greedy picking + links = asm.extract_best_links(bag, frame0["costs"], trees=None) assert len(links) == 1 + s = {tuple(links[0].j1.pos), tuple(links[0].j2.pos)} - assert s == {(0.0, 0.0), (5.0, 0.0)} or s == {(100.0, 100.0), (110.0, 100.0)} + assert s in ( + {(0.0, 0.0), (5.0, 0.0)}, + {(100.0, 100.0), (110.0, 100.0)}, + ) # -------------------------------------------------------------------------------------- @@ -233,124 +139,67 @@ def test_extract_best_links_greedy_with_thresholds(): # -------------------------------------------------------------------------------------- -def test_build_assemblies_from_links(): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) - data["0"] = simple_two_label_scene() - - asm = Assembler( - data, - max_n_individuals=2, - n_multibodyparts=2, - greedy=False, - pcutoff=0.1, - min_affinity=0.05, - min_n_links=1, - ) +def test_build_assemblies_from_links(assembler_data_single_frame, make_assembler): + data, _, _ = assembler_data_single_frame + asm = make_assembler(data, greedy=False, min_n_links=1) - joints = list(Assembler._flatten_detections(data["0"])) - bag = {} - for j in joints: - bag.setdefault(j.label, []).append(j) + frame0 = data["0"] + bag = _bag_from_frame(frame0) - links = asm.extract_best_links(bag, data["0"]["costs"]) + links = asm.extract_best_links(bag, frame0["costs"]) assemblies, _ = asm.build_assemblies(links) - # We expect two disjoint 2-joint assemblies assert len(assemblies) == 2 for a in assemblies: assert a.n_links == 1 assert len(a) == 2 - # affinity is the sum of link affinities for the assembly assert a.affinity == pytest.approx(a._affinity / a.n_links) # -------------------------------------------------------------------------------------- -# _assemble (per-frame) – main path without calibration +# _assemble (per-frame) – main path # -------------------------------------------------------------------------------------- -def test__assemble_main_no_calibration_returns_two_assemblies(): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) - data["0"] = simple_two_label_scene() - - asm = Assembler( +def test__assemble_main_no_calibration_returns_two_assemblies(assembler_data_single_frame, make_assembler): + data, _, _ = assembler_data_single_frame + asm = make_assembler( data, - max_n_individuals=2, - n_multibodyparts=2, greedy=False, - pcutoff=0.1, - min_affinity=0.05, min_n_links=1, max_overlap=0.99, window_size=0, ) - assemblies, unique = asm._assemble(data["0"], ind_frame=0) - assert unique is None # no unique bodyparts in this setting + assemblies, unique = asm._assemble(data["0"], 0) + assert unique is None assert len(assemblies) == 2 assert all(len(a) == 2 for a in assemblies) -def test__assemble_returns_none_when_no_detections(): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) +def test__assemble_returns_none_when_no_detections(assembler_data_no_detections, make_assembler): + data, _, _ = assembler_data_no_detections + asm = make_assembler(data, 
max_n_individuals=2, n_multibodyparts=2) - # Frame with zero coords (→ skipped by _flatten_detections) - coords0 = np.zeros((0, 2)) - conf0 = np.zeros((0,)) - coords1 = np.zeros((0, 2)) - conf1 = np.zeros((0,)) - frame = make_frame([coords0, coords1], [conf0, conf1], identity_per_label=None, costs={}) - data["0"] = frame - - asm = Assembler( - data, - max_n_individuals=2, - n_multibodyparts=2, - ) - assemblies, unique = asm._assemble(data["0"], ind_frame=0) + assemblies, unique = asm._assemble(data["0"], 0) assert assemblies is None and unique is None # -------------------------------------------------------------------------------------- -# assemble() over multiple frames + window_size KD-tree caching +# assemble() over multiple frames + KD-tree caching # -------------------------------------------------------------------------------------- -def test_assemble_across_frames_updates_temporal_trees(): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) - - # Frame 0: baseline - frame0 = simple_two_label_scene() - - # Frame 1: nudge coordinates slightly, keep affinities similar - f1 = simple_two_label_scene() - f1["coordinates"][0][0] = f1["coordinates"][0][0] + np.array([[1.0, 0.0], [1.0, 0.0]]) - f1["coordinates"][0][1] = f1["coordinates"][0][1] + np.array([[1.0, 0.0], [1.0, 0.0]]) - - data["0"] = frame0 - data["1"] = f1 - - asm = Assembler( +def test_assemble_across_frames_updates_temporal_trees(assembler_data_two_frames_nudged, make_assembler): + data, _, _ = assembler_data_two_frames_nudged + asm = make_assembler( data, - max_n_individuals=2, - n_multibodyparts=2, window_size=1, # enable temporal memory min_n_links=1, ) - # Use serial path to avoid multiprocessing in tests asm.assemble(chunk_size=0) - # KD-trees should be recorded for frames that had links - # Presence of keys 0 and 1 depends on links creation assert 0 in asm._trees or 1 in asm._trees assert isinstance(asm.assemblies, dict) assert set(asm.assemblies.keys()).issubset({0, 1}) @@ -361,37 +210,29 @@ def test_assemble_across_frames_updates_temporal_trees(): # -------------------------------------------------------------------------------------- -def test_identity_only_branch_groups_by_identity(): - graph = [(0, 1)] - paf_inds = [0] +def test_identity_only_branch_groups_by_identity(assembler_data_single_frame, scene_copy, make_assembler): + data, _, _ = assembler_data_single_frame - # Build a frame with identities. Two groups (0 and 1). 
- base = simple_two_label_scene() - # identity logits such that each label has two detections belonging to groups 0 and 1 - id0 = np.array([[4.0, 1.0], [1.0, 4.0]]) # label 0 → group 0 then 1 - id1 = np.array([[4.0, 1.0], [1.0, 4.0]]) # label 1 → group 0 then 1 + base = scene_copy + id0 = np.array([[4.0, 1.0], [1.0, 4.0]]) + id1 = np.array([[4.0, 1.0], [1.0, 4.0]]) base["identity"] = [id0, id1] - - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) data["0"] = base - # identity_only=True should be enabled since "identity" is present in frame 0 - asm = Assembler( + asm = make_assembler( data, max_n_individuals=3, - n_multibodyparts=2, identity_only=True, pcutoff=0.1, ) - assemblies, unique = asm._assemble(data["0"], ind_frame=0) - # We expect at least one assembly created by grouping (one per identity that has both labels observed) + assemblies, _ = asm._assemble(data["0"], 0) assert assemblies is not None assert all(len(a) >= 1 for a in assemblies) # -------------------------------------------------------------------------------------- -# Mahalanobis distance and link probability with a mocked KDE +# Mahalanobis & link probability # -------------------------------------------------------------------------------------- @@ -400,33 +241,20 @@ class _FakeKDE: mean: np.ndarray inv_cov: np.ndarray covariance: np.ndarray - d: int # dimension + d: int -def test_calc_assembly_mahalanobis_and_link_probability_with_fake_kde(): - # 2 multibody parts → pdist length = 1 - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) - data["0"] = simple_two_label_scene() +def test_calc_assembly_mahalanobis_and_link_probability_with_fake_kde(assembler_data_single_frame, make_assembler): + data, _, _ = assembler_data_single_frame + asm = make_assembler(data, min_n_links=1) - asm = Assembler( - data, - max_n_individuals=2, - n_multibodyparts=2, - min_n_links=1, - ) - - # Build a simple assembly with two joints - j0 = Joint((0.0, 0.0), 1.0, label=0, idx=0) - j1 = Joint((3.0, 4.0), 1.0, label=1, idx=1) - link = Link(j0, j1, affinity=1.0) + j0 = Joint((0.0, 0.0), 1.0, 0, 0) + j1 = Joint((3.0, 4.0), 1.0, 1, 1) + link = Link(j0, j1, 1.0) a = Assembly(size=2) a.add_link(link) - # Fake KDE: one-dimensional (pairwise sq distance only) - # distance^2 = 5^2 = 25; set mean=25, identity covariance fake = _FakeKDE( mean=np.array([25.0]), inv_cov=np.array([[1.0]]), @@ -436,43 +264,29 @@ def test_calc_assembly_mahalanobis_and_link_probability_with_fake_kde(): asm._kde = fake asm.safe_edge = True - # Mahalanobis should be finite and small because dist == mean d = asm.calc_assembly_mahalanobis_dist(a) - assert np.isfinite(d) assert d == pytest.approx(0.0, abs=1e-6) - # Link probability depends on squared length vs mean; here z=0 → high prob p = asm.calc_link_probability(link) - assert 0.0 <= p <= 1.0 assert p == pytest.approx(1.0, rel=1e-6) # -------------------------------------------------------------------------------------- -# I/O helpers: to_pickle / from_pickle / to_h5 (optional) +# I/O: pickle / h5 # -------------------------------------------------------------------------------------- -def test_to_pickle_and_from_pickle(tmp_path): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) - data["0"] = simple_two_label_scene() +def test_to_pickle_and_from_pickle(tmp_path, assembler_data_single_frame, make_assembler, assembler_graph_and_pafs): + data, _, _ = assembler_data_single_frame + graph, paf_inds 
= assembler_graph_and_pafs - asm = Assembler( - data, - max_n_individuals=2, - n_multibodyparts=2, - min_n_links=1, - ) - - # build assemblies for frame 0 + asm = make_assembler(data, min_n_links=1) assemblies, _ = asm._assemble(data["0"], 0) asm.assemblies = {0: assemblies} - pkl = tmp_path / "ass.pkl" + pkl = tmp_path / "assemb.pkl" asm.to_pickle(str(pkl)) - # Load into a new Assembler (empty schema is sufficient) new_asm = Assembler.empty( max_n_individuals=2, n_multibodyparts=2, @@ -481,36 +295,26 @@ def test_to_pickle_and_from_pickle(tmp_path): paf_inds=paf_inds, ) new_asm.from_pickle(str(pkl)) + assert 0 in new_asm.assemblies assert isinstance(new_asm.assemblies[0], list) - assert new_asm.assemblies[0][0].shape == (2, 4) or True # presence is enough @pytest.mark.skipif( - pytest.importorskip("tables", reason="PyTables required for HDF5") is None, reason="requires PyTables" + pytest.importorskip("tables", reason="PyTables required for HDF5") is None, + reason="requires PyTables", ) -def test_to_h5_roundtrip(tmp_path): - graph = [(0, 1)] - paf_inds = [0] - data = make_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"]) - data["0"] = simple_two_label_scene() +def test_to_h5_roundtrip(tmp_path, assembler_data_single_frame, make_assembler): + data, _, _ = assembler_data_single_frame - asm = Assembler( - data, - max_n_individuals=2, - n_multibodyparts=2, - min_n_links=1, - ) + asm = make_assembler(data, min_n_links=1) assemblies, _ = asm._assemble(data["0"], 0) asm.assemblies = {0: assemblies} - h5 = tmp_path / "ass.h5" + h5 = tmp_path / "assemb.h5" asm.to_h5(str(h5)) - # Read back and perform basic structural assertions df = pd.read_hdf(str(h5), key="ass") - # one frame, 2 individuals, 2 bodyparts, coords {x,y,likelihood} - # df shape will be (frames, 2*2*3) assert df.shape[0] == 1 assert df.columns.nlevels == 4 assert set(df.columns.get_level_values("coords")) == {"x", "y", "likelihood"} diff --git a/tests/tests_core/test_assembly.py b/tests/tests_core/test_assembly.py index 9471f81..3cc3c84 100644 --- a/tests/tests_core/test_assembly.py +++ b/tests/tests_core/test_assembly.py @@ -1,15 +1,14 @@ import numpy as np import pytest -from dlclive.core.inferenceutils import Assembly, Joint, Link +from dlclive.core.inferenceutils import Assembly + # --------------------------- # Basic construction # --------------------------- - - -def test_assembly_init(): - assemb = Assembly(size=5) +def test_assembly_init(make_assembly): + assemb = make_assembly(size=5) assert assemb.data.shape == (5, 4) # col 0,1,3 are NaN, col 2 is confidence=0 @@ -29,8 +28,6 @@ def test_assembly_init(): # --------------------------- # from_array # --------------------------- - - def test_assembly_from_array_basic_xy_only(): arr = np.array( [ @@ -72,12 +69,10 @@ def test_assembly_from_array_with_nans(): # --------------------------- # add_joint / remove_joint # --------------------------- - - -def test_add_joint_and_remove_joint(): - assemb = Assembly(size=3) - j0 = Joint(pos=(1.0, 2.0), confidence=0.5, label=0, idx=10) - j1 = Joint(pos=(3.0, 4.0), confidence=0.8, label=1, idx=11) +def test_add_joint_and_remove_joint(make_assembly, make_joint): + assemb = make_assembly(size=3) + j0 = make_joint(pos=(1.0, 2.0), confidence=0.5, label=0, idx=10) + j1 = make_joint(pos=(3.0, 4.0), confidence=0.8, label=1, idx=11) # adding first joint assert assemb.add_joint(j0) is True @@ -106,14 +101,12 @@ def test_add_joint_and_remove_joint(): # --------------------------- # add_link (simple) # --------------------------- +def 
test_add_link_adds_joints_and_affinity(make_assembly, make_joint, make_link): + assemb = make_assembly(size=3) - -def test_add_link_adds_joints_and_affinity(): - assemb = Assembly(size=3) - - j0 = Joint(pos=(0.0, 0.0), confidence=1.0, label=0, idx=100) - j1 = Joint(pos=(1.0, 0.0), confidence=1.0, label=1, idx=101) - link = Link(j0, j1, affinity=0.7) + j0 = make_joint(pos=(0.0, 0.0), confidence=1.0, label=0, idx=100) + j1 = make_joint(pos=(1.0, 0.0), confidence=1.0, label=1, idx=101) + link = make_link(j0, j1, affinity=0.7) # New link → adds both joints result = assemb.add_link(link) @@ -135,8 +128,8 @@ def test_add_link_adds_joints_and_affinity(): # --------------------------- -def test_extent_and_area(): - assemb = Assembly(size=3) +def test_extent_and_area(make_assembly): + assemb = make_assembly(size=3) # manually set data: [x, y, conf, group] assemb.data[:] = np.nan assemb.data[0, :2] = [10, 10] @@ -155,16 +148,8 @@ def test_extent_and_area(): # --------------------------- -def test_intersection_with_partial_overlap(): - ass1 = Assembly(size=2) - ass1.data[0, :2] = [0, 0] - ass1.data[1, :2] = [10, 10] - ass1._visible.update({0, 1}) - - ass2 = Assembly(size=2) - ass2.data[0, :2] = [5, 5] - ass2.data[1, :2] = [15, 15] - ass2._visible.update({0, 1}) +def test_intersection_with_partial_overlap(two_overlap_assemblies): + ass1, ass2 = two_overlap_assemblies # They overlap in a square of area 5x5 around (5,5)-(10,10). # Each assembly has 2 points. Points inside overlap: @@ -178,8 +163,8 @@ def test_intersection_with_partial_overlap(): # --------------------------- -def test_confidence_property(): - assemb = Assembly(size=3) +def test_confidence_property(make_assembly): + assemb = make_assembly(size=3) assemb.data[:] = np.nan assemb.data[:, 2] = [0.2, 0.4, np.nan] # mean of finite = (0.2+0.4)/2 = 0.3 assert assemb.confidence == pytest.approx(0.3) @@ -193,14 +178,9 @@ def test_confidence_property(): # --------------------------- -def test_soft_identity_simple(): +def test_soft_identity_simple(soft_identity_assembly): # data format: x, y, conf, group - assemb = Assembly(size=3) - assemb.data[:] = np.nan - assemb.data[0] = [0, 0, 1.0, 0] - assemb.data[1] = [5, 5, 0.5, 0] - assemb.data[2] = [10, 10, 1.0, 1] - assemb._visible = {0, 1, 2} + assemb = soft_identity_assembly # groups: 0 → weights 1.0 and 0.5 (avg=0.75) # 1 → weight 1.0 @@ -217,12 +197,12 @@ def test_soft_identity_simple(): # --------------------------- -def test_contains_checks_shared_idx(): - ass1 = Assembly(size=3) - ass2 = Assembly(size=3) +def test_contains_checks_shared_idx(make_assembly, make_joint): + ass1 = make_assembly(size=3) + ass2 = make_assembly(size=3) - j0 = Joint((0, 0), confidence=1.0, label=0, idx=10) - j1 = Joint((1, 1), confidence=1.0, label=1, idx=99) + j0 = make_joint((0, 0), confidence=1.0, label=0, idx=10) + j1 = make_joint((1, 1), confidence=1.0, label=1, idx=99) ass1.add_joint(j0) ass2.add_joint(j1) @@ -240,17 +220,11 @@ def test_contains_checks_shared_idx(): # --------------------------- -def test_assembly_addition_combines_links(): - a1 = Assembly(size=4) - a2 = Assembly(size=4) - - j0 = Joint((0, 0), 1.0, label=0, idx=10) - j1 = Joint((1, 0), 1.0, label=1, idx=11) - j2 = Joint((2, 0), 1.0, label=2, idx=12) - j3 = Joint((3, 0), 1.0, label=3, idx=13) +def test_assembly_addition_combines_links(make_assembly, four_joint_chain): + a1 = make_assembly(size=4) + a2 = make_assembly(size=4) - l01 = Link(j0, j1, affinity=0.5) - l23 = Link(j2, j3, affinity=0.8) + j0, _, _, _, l01, l23 = four_joint_chain 
a1.add_link(l01) a2.add_link(l23) From f5e6f6cd73c0e1359dbe1e0b4581011b434a7a05 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 2 Feb 2026 10:32:58 +0100 Subject: [PATCH 12/20] Use MagicMocks for headless display tests Replace handcrafted FakeTk/Label/PhotoImage classes with unittest.mock MagicMocks in the headless_display_env fixture. The fixture now returns a SimpleNamespace with mock constructors/instances (tk, label, photo) and sets display module attributes with monkeypatch (raising=False). Update tests to use the env mocks, assert on mock calls (title, pack, configure, update, destroy), mock ImageDraw.Draw with a MagicMock, and make colormap sampling deterministic for assertions. This simplifies test setup and enables precise call-based assertions instead of inspecting custom fake object state. --- tests/conftest.py | 52 ++++++++++++++++++----- tests/test_display.py | 97 +++++++++++++++++++------------------------ 2 files changed, 84 insertions(+), 65 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ff19a2b..17499e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,9 @@ import copy from collections.abc import Callable +from types import SimpleNamespace from typing import Any +from unittest.mock import MagicMock import numpy as np import pytest @@ -15,20 +17,50 @@ # -------------------------------------------------------------------------------------- @pytest.fixture def headless_display_env(monkeypatch): - """Patch dlclive.display so tkinter is replaced with fake, non-GUI-safe objects.""" - from test_display import FakeLabel, FakePhotoImage, FakeTk - + """ + Patch dlclive.display so tkinter + ImageTk are replaced by MagicMocks. + + Returns an object with: + - mod: the imported dlclive.display module + - tk_ctor: MagicMock constructor for Tk + - tk: MagicMock instance for the window + - label_ctor: MagicMock constructor for Label + - label: MagicMock instance for the label widget + - photo_ctor: MagicMock function for ImageTk.PhotoImage + - photo: MagicMock instance representing created image + """ import dlclive.display as display_mod - monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True) - monkeypatch.setattr(display_mod, "Tk", FakeTk) - monkeypatch.setattr(display_mod, "Label", FakeLabel) + # Ensure display path is enabled + monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True, raising=False) - class FakeImageTkModule: - PhotoImage = FakePhotoImage + # Tk / Label mocks + tk = MagicMock(name="TkInstance") + tk_ctor = MagicMock(name="Tk", return_value=tk) - monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule) - return display_mod + label = MagicMock(name="LabelInstance") + label_ctor = MagicMock(name="Label", return_value=label) + + # ImageTk.PhotoImage mock + photo = MagicMock(name="PhotoImageInstance") + photo_ctor = MagicMock(name="PhotoImage", return_value=photo) + + class FakeImageTkModule: + PhotoImage = photo_ctor + + monkeypatch.setattr(display_mod, "Tk", tk_ctor, raising=False) + monkeypatch.setattr(display_mod, "Label", label_ctor, raising=False) + monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule, raising=False) + + return SimpleNamespace( + mod=display_mod, + tk_ctor=tk_ctor, + tk=tk, + label_ctor=label_ctor, + label=label, + photo_ctor=photo_ctor, + photo=photo, + ) # -------------------------------------------------------------------------------------- diff --git a/tests/test_display.py b/tests/test_display.py index 0f6b555..e554555 100644 --- a/tests/test_display.py +++ b/tests/test_display.py 
@@ -1,42 +1,9 @@ +from unittest.mock import ANY, MagicMock + import numpy as np import pytest -class FakeTk: - def __init__(self): - self.titles = [] - self.updated = 0 - self.destroyed = False - - def title(self, text): - self.titles.append(text) - - def update(self): - self.updated += 1 - - def destroy(self): - self.destroyed = True - - -class FakeLabel: - def __init__(self, window): - self.window = window - self.packed = False - self.configured = {} - - def pack(self): - self.packed = True - - def configure(self, **kwargs): - self.configured.update(kwargs) - - -class FakePhotoImage: - def __init__(self, image=None, master=None): - self.image = image - self.master = master - - def test_display_init_raises_when_tk_unavailable(monkeypatch): import dlclive.display as display_mod @@ -47,7 +14,8 @@ def test_display_init_raises_when_tk_unavailable(monkeypatch): def test_display_frame_creates_window_and_updates(headless_display_env): - display_mod = headless_display_env + env = headless_display_env + display_mod = env.mod disp = display_mod.Display(radius=3, pcutoff=0.5) frame = np.zeros((100, 120, 3), dtype=np.uint8) @@ -55,44 +23,57 @@ def test_display_frame_creates_window_and_updates(headless_display_env): disp.display_frame(frame, pose) - assert disp.window is not None - assert disp.lab is not None - assert disp.lab.packed is True - assert disp.window.updated == 1 - assert "image" in disp.lab.configured # configured with PhotoImage + # Window created and initialized + env.tk_ctor.assert_called_once_with() + env.tk.title.assert_called_once_with("DLC Live") + + # Label created and packed + env.label_ctor.assert_called_once_with(env.tk) + env.label.pack.assert_called_once() + + # PhotoImage created with correct master + image passed + env.photo_ctor.assert_called_once_with(image=ANY, master=env.tk) + + # Image configured on label and window updated + env.label.configure.assert_called_once_with(image=env.photo) + env.tk.update.assert_called_once_with() def test_display_draws_only_points_above_cutoff(headless_display_env, monkeypatch): - display_mod = headless_display_env + env = headless_display_env + display_mod = env.mod disp = display_mod.Display(radius=3, pcutoff=0.5) + # Patch colormap so color sampling is deterministic and always long enough + class FakeCC: + bmy = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0)] + + monkeypatch.setattr(display_mod, "cc", FakeCC) + frame = np.zeros((100, 100, 3), dtype=np.uint8) pose = np.array( [ [ [10, 10, 0.9], # draw [20, 20, 0.49], # don't draw - [30, 30, 0.5001], # draw (>=) + [30, 30, 0.5001], # draw (> pcutoff) ] ], dtype=float, ) - ellipses = [] - - class DrawRecorder: - def ellipse(self, coords, fill=None, outline=None): - ellipses.append((coords, fill, outline)) - - monkeypatch.setattr(display_mod.ImageDraw, "Draw", lambda img: DrawRecorder()) + draw = MagicMock(name="DrawInstance") + monkeypatch.setattr(display_mod.ImageDraw, "Draw", MagicMock(return_value=draw)) disp.display_frame(frame, pose) - assert len(ellipses) == 2 + # Two points above cutoff => two ellipse calls + assert draw.ellipse.call_count == 2 def test_destroy_calls_window_destroy(headless_display_env): - display_mod = headless_display_env + env = headless_display_env + display_mod = env.mod disp = display_mod.Display() frame = np.zeros((10, 10, 3), dtype=np.uint8) @@ -101,13 +82,13 @@ def test_destroy_calls_window_destroy(headless_display_env): disp.display_frame(frame, pose) disp.destroy() - assert disp.window.destroyed is True + env.tk.destroy.assert_called_once_with() def 
test_set_display_color_sampling_safe(headless_display_env, monkeypatch): - display_mod = headless_display_env + env = headless_display_env + display_mod = env.mod - # Provide a fixed colormap list class FakeCC: bmy = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0), (0, 1, 1), (1, 0, 1)] @@ -118,3 +99,9 @@ class FakeCC: assert disp.colors is not None assert len(disp.colors) >= 3 + + # Also verify window setup calls happened + env.tk_ctor.assert_called_once_with() + env.tk.title.assert_called_once_with("DLC Live") + env.label_ctor.assert_called_once_with(env.tk) + env.label.pack.assert_called_once() From 6905cf3ee5456837f96d554f1fa83c8898eff84f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 2 Feb 2026 12:03:27 +0100 Subject: [PATCH 13/20] Use SimpleNamespace in tests; add Hypothesis Replace tuple-based assembler fixtures with SimpleNamespace objects (providing .data, .graph, .paf_inds) and update all tests to access those attributes. Add Hypothesis to dev dependencies and introduce property-based tests for Assembly (from_array invariants and extent/area checks) --- pyproject.toml | 1 + tests/conftest.py | 61 +++++++------ tests/tests_core/test_assembler.py | 78 ++++++++--------- tests/tests_core/test_assembly.py | 136 +++++++++++++++++------------ 4 files changed, 149 insertions(+), 127 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2456749..54efe23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ tf = [ dev = [ "pytest", "pytest-cov", + "hypothesis", "black", "ruff", ] diff --git a/tests/conftest.py b/tests/conftest.py index 17499e9..6c4866a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -67,9 +67,11 @@ class FakeImageTkModule: # Assembler/assembly test fixtures # -------------------------------------------------------------------------------------- @pytest.fixture -def assembler_graph_and_pafs() -> tuple[list[tuple[int, int]], list[int]]: +def assembler_graph_and_pafs() -> SimpleNamespace: """Standard 2‑joint graph used throughout the test suite.""" - return ([(0, 1)], [0]) + graph = [(0, 1)] + paf_inds = [0] + return SimpleNamespace(graph=graph, paf_inds=paf_inds) @pytest.fixture @@ -147,13 +149,13 @@ def assembler_data( assembler_graph_and_pafs, make_assembler_metadata, simple_two_label_scene, -) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]: +) -> SimpleNamespace: """Full metadata + two identical frames ('0', '1').""" - graph, paf_inds = assembler_graph_and_pafs - data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) + paf = assembler_graph_and_pafs + data = make_assembler_metadata(paf.graph, paf.paf_inds, n_bodyparts=2, frame_keys=["0", "1"]) data["0"] = simple_two_label_scene data["1"] = simple_two_label_scene - return data, graph, paf_inds + return SimpleNamespace(data=data, graph=paf.graph, paf_inds=paf.paf_inds) @pytest.fixture @@ -161,12 +163,12 @@ def assembler_data_single_frame( assembler_graph_and_pafs, make_assembler_metadata, simple_two_label_scene, -) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]: +) -> SimpleNamespace: """Metadata + a single frame ('0'). 
Used by most tests."""
-    graph, paf_inds = assembler_graph_and_pafs
-    data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"])
+    paf = assembler_graph_and_pafs
+    data = make_assembler_metadata(paf.graph, paf.paf_inds, n_bodyparts=2, frame_keys=["0"])
     data["0"] = simple_two_label_scene
-    return data, graph, paf_inds
+    return SimpleNamespace(data=data, graph=paf.graph, paf_inds=paf.paf_inds)


 @pytest.fixture
@@ -174,10 +176,10 @@ def assembler_data_two_frames_nudged(
     assembler_graph_and_pafs,
     make_assembler_metadata,
     simple_two_label_scene,
-) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]:
+) -> SimpleNamespace:
     """Two frames where frame '1' is a nudged copy of frame '0'."""
-    graph, paf_inds = assembler_graph_and_pafs
-    data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"])
+    paf = assembler_graph_and_pafs
+    data = make_assembler_metadata(paf.graph, paf.paf_inds, n_bodyparts=2, frame_keys=["0", "1"])

     frame0 = simple_two_label_scene
     frame1 = copy.deepcopy(simple_two_label_scene)
@@ -186,7 +188,7 @@
     data["0"] = frame0
     data["1"] = frame1

-    return data, graph, paf_inds
+    return SimpleNamespace(data=data, graph=paf.graph, paf_inds=paf.paf_inds)


 @pytest.fixture
@@ -194,10 +196,10 @@ def assembler_data_no_detections(
     assembler_graph_and_pafs,
     make_assembler_metadata,
     make_assembler_frame,
-) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]:
+) -> SimpleNamespace:
     """Metadata + a single frame ('0') with zero detections for both labels."""
-    graph, paf_inds = assembler_graph_and_pafs
-    data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"])
+    paf = assembler_graph_and_pafs
+    data = make_assembler_metadata(paf.graph, paf.paf_inds, n_bodyparts=2, frame_keys=["0"])

     frame = make_assembler_frame(
         coordinates_per_label=[np.zeros((0, 2)), np.zeros((0, 2))],
@@ -206,7 +208,7 @@
         costs={},
     )
     data["0"] = frame
-    return data, graph, paf_inds
+    return SimpleNamespace(data=data, graph=paf.graph, paf_inds=paf.paf_inds)


 @pytest.fixture
@@ -275,16 +278,16 @@ def _factory(j1: Joint, j2: Joint, affinity: float = 1.0) -> Link:

 @pytest.fixture
 def two_overlap_assemblies(make_assembly) -> tuple[Assembly, Assembly]:
     """Two assemblies with partial overlap used by intersection tests."""
-    ass1 = make_assembly(2)
-    ass1.data[0, :2] = [0, 0]
-    ass1.data[1, :2] = [10, 10]
-    ass1._visible.update({0, 1})
+    assemb1 = make_assembly(2)
+    assemb1.data[0, :2] = [0, 0]
+    assemb1.data[1, :2] = [10, 10]
+    assemb1._visible.update({0, 1})

-    ass2 = make_assembly(2)
-    ass2.data[0, :2] = [5, 5]
-    ass2.data[1, :2] = [15, 15]
-    ass2._visible.update({0, 1})
-    return ass1, ass2
+    assemb2 = make_assembly(2)
+    assemb2.data[0, :2] = [5, 5]
+    assemb2.data[1, :2] = [15, 15]
+    assemb2._visible.update({0, 1})
+    return assemb1, assemb2


 @pytest.fixture
@@ -300,7 +303,7 @@ def soft_identity_assembly(make_assembly) -> Assembly:


 @pytest.fixture
-def four_joint_chain(make_joint, make_link) -> tuple[Joint, Joint, Joint, Joint, Link, Link]:
+def four_joint_chain(make_joint, make_link) -> SimpleNamespace:
     """Four joints and two links: (0-1) and (2-3)."""
     j0 = make_joint((0, 0), 1.0, label=0, idx=10)
     j1 = make_joint((1, 0), 1.0, label=1, idx=11)
     j2 = make_joint((2, 0), 1.0, label=2, idx=12)
     j3 = make_joint((3, 0), 1.0, label=3, idx=13)
     l01 = make_link(j0, j1, affinity=0.5)
     l23 = make_link(j2, j3, 
affinity=0.8) - return j0, j1, j2, j3, l01, l23 + return SimpleNamespace(j0=j0, j1=j1, j2=j2, j3=j3, l01=l01, l23=l23) diff --git a/tests/tests_core/test_assembler.py b/tests/tests_core/test_assembler.py index 6feef72..1604cbc 100644 --- a/tests/tests_core/test_assembler.py +++ b/tests/tests_core/test_assembler.py @@ -23,17 +23,17 @@ def _bag_from_frame(frame: dict) -> dict[int, list]: def test_parse_metadata_and_getitem(assembler_data, make_assembler): - data, graph, paf_inds = assembler_data + adat = assembler_data # Parsing asm = make_assembler( - data, + adat.data, max_n_individuals=2, n_multibodyparts=2, ) assert asm.metadata["num_joints"] == 2 - assert asm.metadata["paf_graph"] == graph - assert list(asm.metadata["paf"]) == paf_inds + assert asm.metadata["paf_graph"] == adat.graph + assert list(asm.metadata["paf"]) == adat.paf_inds assert set(asm.metadata["imnames"]) == {"0", "1"} # __getitem__ assert "coordinates" in asm[0] @@ -42,13 +42,13 @@ def test_parse_metadata_and_getitem(assembler_data, make_assembler): def test_empty_classmethod(assembler_graph_and_pafs): - graph, paf_inds = assembler_graph_and_pafs + paf = assembler_graph_and_pafs empty = Assembler.empty( max_n_individuals=1, n_multibodyparts=1, n_uniquebodyparts=0, - graph=graph, - paf_inds=paf_inds, + graph=paf.graph, + paf_inds=paf.paf_inds, ) assert isinstance(empty, Assembler) assert empty.n_keypoints == 1 @@ -88,14 +88,14 @@ def test_flatten_detections_with_identity(scene_copy): def test_extract_best_links_optimal_assignment(assembler_data_single_frame, make_assembler): - data, _, _ = assembler_data_single_frame + sframe_data = assembler_data_single_frame asm = make_assembler( - data, + sframe_data.data, greedy=False, # use Hungarian (maximize) min_n_links=1, ) - frame0 = data["0"] + frame0 = sframe_data.data["0"] bag = _bag_from_frame(frame0) links = asm.extract_best_links(bag, frame0["costs"], trees=None) @@ -111,9 +111,9 @@ def test_extract_best_links_optimal_assignment(assembler_data_single_frame, make def test_extract_best_links_greedy_with_thresholds(assembler_data_single_frame, make_assembler): - data, _, _ = assembler_data_single_frame + sframe_data = assembler_data_single_frame asm = make_assembler( - data, + sframe_data.data, max_n_individuals=1, # greedy will stop after 1 disjoint pair chosen greedy=True, pcutoff=0.5, # conf product must exceed 0.25 @@ -121,7 +121,7 @@ def test_extract_best_links_greedy_with_thresholds(assembler_data_single_frame, min_n_links=1, ) - frame0 = data["0"] + frame0 = sframe_data.data["0"] bag = _bag_from_frame(frame0) links = asm.extract_best_links(bag, frame0["costs"], trees=None) @@ -140,10 +140,10 @@ def test_extract_best_links_greedy_with_thresholds(assembler_data_single_frame, def test_build_assemblies_from_links(assembler_data_single_frame, make_assembler): - data, _, _ = assembler_data_single_frame - asm = make_assembler(data, greedy=False, min_n_links=1) + sframe_data = assembler_data_single_frame + asm = make_assembler(sframe_data.data, greedy=False, min_n_links=1) - frame0 = data["0"] + frame0 = sframe_data.data["0"] bag = _bag_from_frame(frame0) links = asm.extract_best_links(bag, frame0["costs"]) @@ -162,26 +162,26 @@ def test_build_assemblies_from_links(assembler_data_single_frame, make_assembler def test__assemble_main_no_calibration_returns_two_assemblies(assembler_data_single_frame, make_assembler): - data, _, _ = assembler_data_single_frame + sframe_data = assembler_data_single_frame asm = make_assembler( - data, + sframe_data.data, greedy=False, min_n_links=1, 
max_overlap=0.99, window_size=0, ) - assemblies, unique = asm._assemble(data["0"], 0) + assemblies, unique = asm._assemble(sframe_data.data["0"], 0) assert unique is None assert len(assemblies) == 2 assert all(len(a) == 2 for a in assemblies) def test__assemble_returns_none_when_no_detections(assembler_data_no_detections, make_assembler): - data, _, _ = assembler_data_no_detections - asm = make_assembler(data, max_n_individuals=2, n_multibodyparts=2) + nodet_data = assembler_data_no_detections + asm = make_assembler(nodet_data.data, max_n_individuals=2, n_multibodyparts=2) - assemblies, unique = asm._assemble(data["0"], 0) + assemblies, unique = asm._assemble(nodet_data.data["0"], 0) assert assemblies is None and unique is None @@ -191,9 +191,9 @@ def test__assemble_returns_none_when_no_detections(assembler_data_no_detections, def test_assemble_across_frames_updates_temporal_trees(assembler_data_two_frames_nudged, make_assembler): - data, _, _ = assembler_data_two_frames_nudged + twofr_data = assembler_data_two_frames_nudged asm = make_assembler( - data, + twofr_data.data, window_size=1, # enable temporal memory min_n_links=1, ) @@ -211,22 +211,22 @@ def test_assemble_across_frames_updates_temporal_trees(assembler_data_two_frames def test_identity_only_branch_groups_by_identity(assembler_data_single_frame, scene_copy, make_assembler): - data, _, _ = assembler_data_single_frame + sframe_data = assembler_data_single_frame base = scene_copy id0 = np.array([[4.0, 1.0], [1.0, 4.0]]) id1 = np.array([[4.0, 1.0], [1.0, 4.0]]) base["identity"] = [id0, id1] - data["0"] = base + sframe_data.data["0"] = base asm = make_assembler( - data, + sframe_data.data, max_n_individuals=3, identity_only=True, pcutoff=0.1, ) - assemblies, _ = asm._assemble(data["0"], 0) + assemblies, _ = asm._assemble(sframe_data.data["0"], 0) assert assemblies is not None assert all(len(a) >= 1 for a in assemblies) @@ -245,8 +245,8 @@ class _FakeKDE: def test_calc_assembly_mahalanobis_and_link_probability_with_fake_kde(assembler_data_single_frame, make_assembler): - data, _, _ = assembler_data_single_frame - asm = make_assembler(data, min_n_links=1) + sframe_data = assembler_data_single_frame + asm = make_assembler(sframe_data.data, min_n_links=1) j0 = Joint((0.0, 0.0), 1.0, 0, 0) j1 = Joint((3.0, 4.0), 1.0, 1, 1) @@ -277,11 +277,9 @@ def test_calc_assembly_mahalanobis_and_link_probability_with_fake_kde(assembler_ def test_to_pickle_and_from_pickle(tmp_path, assembler_data_single_frame, make_assembler, assembler_graph_and_pafs): - data, _, _ = assembler_data_single_frame - graph, paf_inds = assembler_graph_and_pafs - - asm = make_assembler(data, min_n_links=1) - assemblies, _ = asm._assemble(data["0"], 0) + sframe_data = assembler_data_single_frame + asm = make_assembler(sframe_data.data, min_n_links=1) + assemblies, _ = asm._assemble(sframe_data.data["0"], 0) asm.assemblies = {0: assemblies} pkl = tmp_path / "assemb.pkl" @@ -291,8 +289,8 @@ def test_to_pickle_and_from_pickle(tmp_path, assembler_data_single_frame, make_a max_n_individuals=2, n_multibodyparts=2, n_uniquebodyparts=0, - graph=graph, - paf_inds=paf_inds, + graph=sframe_data.graph, + paf_inds=sframe_data.paf_inds, ) new_asm.from_pickle(str(pkl)) @@ -305,10 +303,10 @@ def test_to_pickle_and_from_pickle(tmp_path, assembler_data_single_frame, make_a reason="requires PyTables", ) def test_to_h5_roundtrip(tmp_path, assembler_data_single_frame, make_assembler): - data, _, _ = assembler_data_single_frame + sframe_data = assembler_data_single_frame - asm = 
make_assembler(data, min_n_links=1) - assemblies, _ = asm._assemble(data["0"], 0) + asm = make_assembler(sframe_data.data, min_n_links=1) + assemblies, _ = asm._assemble(sframe_data.data["0"], 0) asm.assemblies = {0: assemblies} h5 = tmp_path / "assemb.h5" diff --git a/tests/tests_core/test_assembly.py b/tests/tests_core/test_assembly.py index 3cc3c84..6e7aec6 100644 --- a/tests/tests_core/test_assembly.py +++ b/tests/tests_core/test_assembly.py @@ -1,8 +1,13 @@ import numpy as np import pytest +from hypothesis import assume, given, settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays from dlclive.core.inferenceutils import Assembly +HYPOTHESIS_SETTINGS = settings(max_examples=200, deadline=None) + # --------------------------- # Basic construction @@ -28,26 +33,48 @@ def test_assembly_init(make_assembly): # --------------------------- # from_array # --------------------------- -def test_assembly_from_array_basic_xy_only(): - arr = np.array( - [ - [10.0, 20.0], - [30.0, 40.0], - ] +@HYPOTHESIS_SETTINGS +@given( + st.integers(min_value=1, max_value=30).flatmap( + lambda n_rows: st.sampled_from([2, 3]).flatmap( + lambda n_cols: arrays( + dtype=np.float64, + shape=(n_rows, n_cols), + elements=st.floats(allow_infinity=False, allow_nan=True, width=32), + ) + ) ) +) +def test_from_array_invariants(arr): + n_rows, n_cols = arr.shape + assemb = Assembly.from_array(arr.copy()) + assert assemb.data.shape == (n_rows, 4) - # full shape (n_bodyparts, 4) - assert assemb.data.shape == (2, 4) + # Row is "valid" iff it has no NaN among the provided columns + row_valid = ~np.isnan(arr).any(axis=1) + visible = set(np.flatnonzero(row_valid).tolist()) + assert assemb._visible == visible - # xy preserved - assert np.allclose(assemb.data[:, :2], arr) + out = assemb.data - # confidence auto-set to 1 where xy is present - assert np.allclose(assemb.data[:, 2], np.array([1.0, 1.0])) + # For invalid rows: x/y must be NaN + assert np.all(np.isnan(out[~row_valid, 0])) + assert np.all(np.isnan(out[~row_valid, 1])) - # labels visible - assert assemb._visible == {0, 1} + # Confidence behavior differs depending on number of columns + if n_cols == 2: + # XY-only input: confidence starts at 0 and is set to 1 for visible rows only + assert np.all(out[row_valid, 2] == pytest.approx(1.0)) + assert np.all(out[~row_valid, 2] == pytest.approx(0.0)) + else: + # XY+confidence input: confidence is preserved for visible rows + assert np.allclose(out[row_valid, 2], arr[row_valid, 2], equal_nan=False) + # Invalid rows become NaN in all provided columns, including confidence + assert np.all(np.isnan(out[~row_valid, 2])) + + # Visible rows preserve xy + assert np.allclose(out[row_valid, :2], arr[row_valid, :2], equal_nan=False) def test_assembly_from_array_with_nans(): @@ -66,6 +93,39 @@ def test_assembly_from_array_with_nans(): assert assemb._visible == {0} +# --------------------------- +# extent, area, xy +# --------------------------- +@HYPOTHESIS_SETTINGS +@given( + coords=arrays( + dtype=np.float64, + shape=st.tuples(st.integers(1, 30), st.just(2)), + elements=st.floats(allow_nan=True, allow_infinity=False, width=32), + ) +) +def test_extent_matches_visible_points(coords): + xy = coords.copy() + a = Assembly(size=xy.shape[0]) + a.data[:] = np.nan + a.data[:, :2] = xy + a._visible = set(np.flatnonzero(~np.isnan(xy).any(axis=1)).tolist()) + + visible_mask = ~np.isnan(coords).any(axis=1) + assume(visible_mask.any()) + + expected = np.array( + [ + coords[visible_mask, 0].min(), + 
coords[visible_mask, 1].min(), + coords[visible_mask, 0].max(), + coords[visible_mask, 1].max(), + ] + ) + assert np.allclose(a.extent, expected) + assert a.area >= 0 + + # --------------------------- # add_joint / remove_joint # --------------------------- @@ -123,46 +183,17 @@ def test_add_link_adds_joints_and_affinity(make_assembly, make_joint, make_link) assert assemb._affinity == pytest.approx(1.4) # 0.7 + 0.7 -# --------------------------- -# extent, area, xy -# --------------------------- - - -def test_extent_and_area(make_assembly): - assemb = make_assembly(size=3) - # manually set data: [x, y, conf, group] - assemb.data[:] = np.nan - assemb.data[0, :2] = [10, 10] - assemb.data[1, :2] = [20, 40] - assemb._visible.update({0, 1}) - - # extent = (min_x, min_y, max_x, max_y) - assert np.allclose(assemb.extent, [10, 10, 20, 40]) - - # area = dx * dy = (20-10) * (40-10) = 10 * 30 - assert assemb.area == pytest.approx(300) - - # --------------------------- # intersection_with # --------------------------- - - def test_intersection_with_partial_overlap(two_overlap_assemblies): ass1, ass2 = two_overlap_assemblies - - # They overlap in a square of area 5x5 around (5,5)-(10,10). - # Each assembly has 2 points. Points inside overlap: - # ass1: both (0,0) no, (10,10) yes → 1 / 2 = 0.5 - # ass2: (5,5) yes, (15,15) no → 1 / 2 = 0.5 assert ass1.intersection_with(ass2) == pytest.approx(0.5) # --------------------------- # confidence property # --------------------------- - - def test_confidence_property(make_assembly): assemb = make_assembly(size=3) assemb.data[:] = np.nan @@ -176,15 +207,8 @@ def test_confidence_property(make_assembly): # --------------------------- # soft_identity # --------------------------- - - def test_soft_identity_simple(soft_identity_assembly): - # data format: x, y, conf, group assemb = soft_identity_assembly - - # groups: 0 → weights 1.0 and 0.5 (avg=0.75) - # 1 → weight 1.0 - # softmax([0.75, 1.0]) ≈ [...] 
soft = assemb.soft_identity assert set(soft.keys()) == {0, 1} s0, s1 = soft[0], soft[1] @@ -195,8 +219,6 @@ def test_soft_identity_simple(soft_identity_assembly): # --------------------------- # intersection operator: __contains__ # --------------------------- - - def test_contains_checks_shared_idx(make_assembly, make_joint): ass1 = make_assembly(size=3) ass2 = make_assembly(size=3) @@ -218,16 +240,14 @@ def test_contains_checks_shared_idx(make_assembly, make_joint): # --------------------------- # assembly addition (__add__) # --------------------------- - - def test_assembly_addition_combines_links(make_assembly, four_joint_chain): a1 = make_assembly(size=4) a2 = make_assembly(size=4) - j0, _, _, _, l01, l23 = four_joint_chain + chain = four_joint_chain - a1.add_link(l01) - a2.add_link(l23) + a1.add_link(chain.l01) + a2.add_link(chain.l23) # now they share NO joints → addition should succeed result = a1 + a2 @@ -240,6 +260,6 @@ def test_assembly_addition_combines_links(make_assembly, four_joint_chain): assert a2.n_links == 1 # now purposely make them share a joint → should raise - a2.add_joint(j0) + a2.add_joint(chain.j0) with pytest.raises(ArithmeticError): _ = a1 + a2 From 8bcf58bdff52f89505d2209171b5f667e1b7d170 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 2 Feb 2026 12:30:07 +0100 Subject: [PATCH 14/20] Add Hypothesis tests for assembler Introduce property-based tests for assembler utilities using Hypothesis: add test_condensed_index_properties for _conv_square_to_condensed_indices, a composite coords_and_conf strategy, test_flatten_detections_counts to validate _flatten_detections output counts, and a property test for extract_best_links (greedy) that asserts affinity, confidence-product (pcutoff) and disjointness invariants. Add Hypothesis imports, settings and numpy-array strategies. Also adjust test_assembly to ensure rows containing any NaN are set fully to NaN so the test matches Assembly.from_array behavior. 
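For reference, the symmetry and range properties asserted by
test_condensed_index_properties pin down the standard square-to-condensed
index mapping used for condensed distance matrices (as in
scipy.spatial.distance.squareform). A minimal illustrative sketch of such a
mapping, assuming the scipy convention (the actual helper in
dlclive.core.inferenceutils may differ in its details):

    def condensed_index(i: int, j: int, n: int) -> int:
        # Diagonal entries have no condensed counterpart.
        if i == j:
            raise ValueError("i and j must differ")
        # The mapping is symmetric, so normalize the pair order first.
        i, j = min(i, j), max(i, j)
        # Offset of row i in the condensed vector, plus the column offset.
        return n * i - i * (i + 1) // 2 + (j - i - 1)

Under this convention every off-diagonal pair maps to a unique index in
[0, n * (n - 1) / 2), which is exactly what the property test asserts.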
--- tests/tests_core/test_assembler.py | 145 ++++++++++++++++++++++++++++- tests/tests_core/test_assembly.py | 3 + 2 files changed, 143 insertions(+), 5 deletions(-) diff --git a/tests/tests_core/test_assembler.py b/tests/tests_core/test_assembler.py index 1604cbc..fe96cac 100644 --- a/tests/tests_core/test_assembler.py +++ b/tests/tests_core/test_assembler.py @@ -5,8 +5,13 @@ import numpy as np import pandas as pd import pytest +from hypothesis import given, settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays -from dlclive.core.inferenceutils import Assembler, Assembly, Joint, Link +from dlclive.core.inferenceutils import Assembler, Assembly, Joint, Link, _conv_square_to_condensed_indices + +HYPOTHESIS_SETTINGS = settings(max_examples=300, deadline=None) def _bag_from_frame(frame: dict) -> dict[int, list]: @@ -17,6 +22,29 @@ def _bag_from_frame(frame: dict) -> dict[int, list]: return bag +# _conv_square_to_condensed_indices +@HYPOTHESIS_SETTINGS +@given( + n=st.integers(min_value=2, max_value=50), + i=st.integers(min_value=0, max_value=49), + j=st.integers(min_value=0, max_value=49), +) +def test_condensed_index_properties(n, i, j): + i = i % n + j = j % n + + if i == j: + with pytest.raises(ValueError): + _conv_square_to_condensed_indices(i, j, n) + return + + k1 = _conv_square_to_condensed_indices(i, j, n) + k2 = _conv_square_to_condensed_indices(j, i, n) + + assert k1 == k2 + assert 0 <= k1 < (n * (n - 1)) // 2 + + # -------------------------------------------------------------------------------------- # Basic metadata and __getitem__ # -------------------------------------------------------------------------------------- @@ -57,8 +85,6 @@ def test_empty_classmethod(assembler_graph_and_pafs): # -------------------------------------------------------------------------------------- # _flatten_detections # -------------------------------------------------------------------------------------- - - def test_flatten_detections_no_identity(simple_two_label_scene): frame = simple_two_label_scene joints = list(Assembler._flatten_detections(frame)) @@ -82,11 +108,52 @@ def test_flatten_detections_with_identity(scene_copy): assert groups.count(1) == 2 +@st.composite +def coords_and_conf(draw, max_n=5): + n = draw(st.integers(1, max_n)) + coords = draw( + arrays( + dtype=np.float64, + shape=(n, 2), + elements=st.floats(min_value=0.1, max_value=1000, allow_nan=False, allow_infinity=False), + ) + ) + conf = draw( + arrays( + dtype=np.float64, + shape=(n,), + elements=st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False), + ) + ) + return coords, conf + + +@HYPOTHESIS_SETTINGS +@given( + c0=coords_and_conf(), + c1=coords_and_conf(), +) +def test_flatten_detections_counts(c0, c1): + coords0, conf0 = c0 + coords1, conf1 = c1 + + frame = { + "coordinates": [[coords0, coords1]], + "confidence": [conf0, conf1], + "costs": {}, + } + + joints = list(Assembler._flatten_detections(frame)) + + # Should yield exactly one Joint per detection + assert len(joints) == (len(coords0) + len(coords1)) + assert sum(j.label == 0 for j in joints) == len(coords0) + assert sum(j.label == 1 for j in joints) == len(coords1) + + # -------------------------------------------------------------------------------------- # extract_best_links # -------------------------------------------------------------------------------------- - - def test_extract_best_links_optimal_assignment(assembler_data_single_frame, make_assembler): sframe_data = assembler_data_single_frame 
asm = make_assembler( @@ -134,6 +201,74 @@ def test_extract_best_links_greedy_with_thresholds(assembler_data_single_frame, ) +@HYPOTHESIS_SETTINGS +@given( + n=st.integers(min_value=1, max_value=4), + pcutoff=st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False), + min_aff=st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False), + conf0=st.lists(st.floats(0.0, 1.0, allow_nan=False, allow_infinity=False), min_size=1, max_size=4), + conf1=st.lists(st.floats(0.0, 1.0, allow_nan=False, allow_infinity=False), min_size=1, max_size=4), +) +def test_extract_best_links_greedy_invariants_with_threshold_gates(n, pcutoff, min_aff, conf0, conf1): + # Normalize confidences to exactly n items + conf0 = (conf0 + [0.0] * n)[:n] + conf1 = (conf1 + [0.0] * n)[:n] + conf0 = np.array(conf0, dtype=float) + conf1 = np.array(conf1, dtype=float) + + # Random-ish affinity matrix (still stable), in [0,1] + rng = np.random.default_rng(0) # deterministic noise + aff = rng.random((n, n)) # uniform [0,1) + # Ensure at least one "good" candidate sometimes; otherwise test is vacuously true. + # We'll only assert gated properties on returned links anyway. + # But for better coverage, bias the diagonal upward a bit: + np.fill_diagonal(aff, np.maximum(np.diag(aff), 0.8)) + dist = np.ones((n, n), dtype=float) + + graph = [(0, 1)] + paf_inds = [0] + data = { + "metadata": {"all_joints_names": ["b0", "b1"], "PAFgraph": graph, "PAFinds": paf_inds}, + "0": {}, + } + + asm = Assembler( + data, + max_n_individuals=n, + n_multibodyparts=2, + greedy=True, + pcutoff=pcutoff, + min_affinity=min_aff, + min_n_links=1, + method="m1", + ) + + dets0 = [Joint((float(i), 0.0), float(conf0[i]), label=0, idx=i) for i in range(n)] + dets1 = [Joint((float(i), 1.0), float(conf1[i]), label=1, idx=100 + i) for i in range(n)] + joints_dict = {0: dets0, 1: dets1} + costs = {0: {"distance": dist, "m1": aff}} + + links = asm.extract_best_links(joints_dict, costs, trees=None) + + assert len(links) <= n + + used_src = set() + used_tgt = set() + + for link in links: + # Invariant 1: affinity gate + assert link.affinity >= min_aff + + # Invariant 2: pcutoff gate (confidence product) + assert link.j1.confidence * link.j2.confidence >= pcutoff * pcutoff + + # Invariant 3: disjointness in greedy selection + assert link.j1.idx not in used_src + assert link.j2.idx not in used_tgt + used_src.add(link.j1.idx) + used_tgt.add(link.j2.idx) + + # -------------------------------------------------------------------------------------- # build_assemblies # -------------------------------------------------------------------------------------- diff --git a/tests/tests_core/test_assembly.py b/tests/tests_core/test_assembly.py index 6e7aec6..55ffa08 100644 --- a/tests/tests_core/test_assembly.py +++ b/tests/tests_core/test_assembly.py @@ -106,6 +106,9 @@ def test_assembly_from_array_with_nans(): ) def test_extent_matches_visible_points(coords): xy = coords.copy() + # Ensure rows with any NaN are fully NaN, + # matching Assembly's from_array behavior + xy[np.isnan(xy).any(axis=1)] = np.nan a = Assembly(size=xy.shape[0]) a.data[:] = np.nan a.data[:, :2] = xy From 80027cdb2039d09ce82c48a10478fb097a29ee31 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 2 Feb 2026 15:43:16 +0100 Subject: [PATCH 15/20] Revert CI changes for now (commented for later) --- .github/workflows/testing.yml | 36 ++++++++++++++++++----------------- pyproject.toml | 16 ++++++++-------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git 
a/.github/workflows/testing.yml b/.github/workflows/testing.yml index dc3b765..28d8b6c 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -43,7 +43,7 @@ jobs: df -h - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v4 # uses: actions/checkout@v6 - name: Install uv uses: astral-sh/setup-uv@v6 @@ -72,20 +72,22 @@ jobs: run: uv run dlc-live-test --nodisplay - name: Run DLC Live Unit Tests - run: uv run pytest --cov=dlclive --cov-report=xml --cov-report=term-missing + run: uv run pytest + # - name: Run DLC Live Unit Tests + # run: uv run pytest --cov=dlclive --cov-report=xml --cov-report=term-missing - - name: Coverage Report - uses: codecov/codecov-action@v5 - with: - files: ./coverage.xml - flags: ${{ matrix.os }}-py${{ matrix.python-version }} - name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} - - name: Add coverage to job summary - if: always() - shell: bash - run: | - uv run python -m coverage report -m > coverage.txt - echo "## Coverage (dlclive)" >> "$GITHUB_STEP_SUMMARY" - echo '```' >> "$GITHUB_STEP_SUMMARY" - cat coverage.txt >> "$GITHUB_STEP_SUMMARY" - echo '```' >> "$GITHUB_STEP_SUMMARY" + # - name: Coverage Report + # uses: codecov/codecov-action@v5 + # with: + # files: ./coverage.xml + # flags: ${{ matrix.os }}-py${{ matrix.python-version }} + # name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} + # - name: Add coverage to job summary + # if: always() + # shell: bash + # run: | + # uv run python -m coverage report -m > coverage.txt + # echo "## Coverage (dlclive)" >> "$GITHUB_STEP_SUMMARY" + # echo '```' >> "$GITHUB_STEP_SUMMARY" + # cat coverage.txt >> "$GITHUB_STEP_SUMMARY" + # echo '```' >> "$GITHUB_STEP_SUMMARY" diff --git a/pyproject.toml b/pyproject.toml index 54efe23..8ce8c13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,12 +91,12 @@ include = ["dlclive*"] [tool.setuptools.package-data] dlclive = ["check_install/*"] -[tool.ruff] -lint.select = ["E", "F", "B", "I", "UP"] -lint.ignore = ["E741"] -target-version = "py310" -fix = true -line-length = 120 +# [tool.ruff] +# lint.select = ["E", "F", "B", "I", "UP"] +# lint.ignore = ["E741"] +# target-version = "py310" +# fix = true +# line-length = 120 -[tool.ruff.lint.pydocstyle] -convention = "google" +# [tool.ruff.lint.pydocstyle] +# convention = "google" From cfba3264cd5e99a342185860cd4c9926bfbd885c Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 3 Feb 2026 09:19:35 +0100 Subject: [PATCH 16/20] Update utils.py --- dlclive/modelzoo/utils.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index 341bd5c..e7224f8 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -31,7 +31,7 @@ def get_super_animal_project_config_path(super_animal: str) -> Path: cfg_path = _MODELZOO_PATH / "project_configs" / f"{super_animal}.yaml" if not cfg_path.exists(): raise FileNotFoundError( - f"Modelzoo project configuration file not found: {cfg_path}Available projects: {list_available_projects()}" + f"Modelzoo project configuration file not found: {cfg_path} Available projects: {list_available_projects()}" ) return cfg_path @@ -89,7 +89,8 @@ def add_metadata( config["metadata"] = { "project_path": project_config["project_path"], "pose_config_path": "", - "bodyparts": project_config.get("multianimalbodyparts") or project_config["bodyparts"], + "bodyparts": project_config.get("multianimalbodyparts") + or project_config["bodyparts"], "unique_bodyparts": 
project_config.get("uniquebodyparts", []), "individuals": project_config.get("individuals", ["animal"]), "with_identity": project_config.get("identity", False), @@ -131,7 +132,9 @@ def load_super_animal_config( else: model_config["method"] = "TD" if super_animal != "superanimal_humanbody": - detector_cfg_path = get_super_animal_model_config_path(model_name=detector_name) + detector_cfg_path = get_super_animal_model_config_path( + model_name=detector_name + ) detector_cfg = read_config_as_dict(detector_cfg_path) model_config["detector"] = detector_cfg return model_config @@ -160,13 +163,17 @@ def download_super_animal_snapshot(dataset: str, model_name: str) -> Path: return model_path try: - download_huggingface_model(model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename) + download_huggingface_model( + model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename + ) if not model_path.exists(): raise RuntimeError(f"Failed to download {model_name} to {model_path}") except Exception as e: - logging.error(f"Failed to download superanimal snapshot {model_name} to {model_path}: {e}") + logging.error( + f"Failed to download superanimal snapshot {model_name} to {model_path}: {e}" + ) raise e return model_path From 9a30d985e2aba9c73cf104b37acf178a7244764d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 3 Feb 2026 12:07:30 +0100 Subject: [PATCH 17/20] Fix image size indexing when drawing keypoints Corrects swapped im_size indexing when clamping ellipse coordinates in Display.draw. The previous code used im_size[1] for x1 and im_size[0] for y1, which reversed width/height fallbacks and could produce incorrect bounds; x1 now uses im_size[0] (width) and y1 uses im_size[1] (height). This prevents drawing outside the image or using wrong coordinates. --- dlclive/display.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/dlclive/display.py b/dlclive/display.py index 0304552..9d4de43 100644 --- a/dlclive/display.py +++ b/dlclive/display.py @@ -35,7 +35,9 @@ class Display: def __init__(self, cmap="bmy", radius=3, pcutoff=0.5): if not _TKINTER_AVAILABLE: - raise ImportError("tkinter is not available. Display functionality requires tkinter. ") + raise ImportError( + "tkinter is not available. Display functionality requires tkinter. 
" + ) self.cmap = cmap self.colors = None self.radius = radius @@ -93,12 +95,30 @@ def display_frame(self, frame, pose=None): for j in range(pose.shape[1]): if pose[i, j, 2] > self.pcutoff: try: - x0 = pose[i, j, 0] - self.radius if pose[i, j, 0] - self.radius > 0 else 0 - x1 = pose[i, j, 0] + self.radius if pose[i, j, 0] + self.radius < im_size[0] else im_size[1] - y0 = pose[i, j, 1] - self.radius if pose[i, j, 1] - self.radius > 0 else 0 - y1 = pose[i, j, 1] + self.radius if pose[i, j, 1] + self.radius < im_size[1] else im_size[0] + x0 = ( + pose[i, j, 0] - self.radius + if pose[i, j, 0] - self.radius > 0 + else 0 + ) + x1 = ( + pose[i, j, 0] + self.radius + if pose[i, j, 0] + self.radius < im_size[0] + else im_size[0] + ) + y0 = ( + pose[i, j, 1] - self.radius + if pose[i, j, 1] - self.radius > 0 + else 0 + ) + y1 = ( + pose[i, j, 1] + self.radius + if pose[i, j, 1] + self.radius < im_size[1] + else im_size[1] + ) coords = [x0, y0, x1, y1] - draw.ellipse(coords, fill=self.colors[j], outline=self.colors[j]) + draw.ellipse( + coords, fill=self.colors[j], outline=self.colors[j] + ) except Exception as e: print(e) img_tk = ImageTk.PhotoImage(image=img, master=self.window) From 679919d546b5b7c9a55dcd40c9c24e88e8fa90c1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 3 Feb 2026 12:07:45 +0100 Subject: [PATCH 18/20] Update test_display.py --- tests/test_display.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/tests/test_display.py b/tests/test_display.py index e554555..aa6821a 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -39,36 +39,55 @@ def test_display_frame_creates_window_and_updates(headless_display_env): env.tk.update.assert_called_once_with() -def test_display_draws_only_points_above_cutoff(headless_display_env, monkeypatch): +def test_display_draws_only_points_above_cutoff_with_clamping( + headless_display_env, monkeypatch +): env = headless_display_env display_mod = env.mod disp = display_mod.Display(radius=3, pcutoff=0.5) + r = disp.radius - # Patch colormap so color sampling is deterministic and always long enough + # Fake colors class FakeCC: - bmy = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0)] + bmy = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] monkeypatch.setattr(display_mod, "cc", FakeCC) - frame = np.zeros((100, 100, 3), dtype=np.uint8) + frame = np.zeros((50, 50, 3), dtype=np.uint8) + h, w = frame.shape[:2] + pose = np.array( [ [ - [10, 10, 0.9], # draw - [20, 20, 0.49], # don't draw - [30, 30, 0.5001], # draw (> pcutoff) + [-1, -1, 0.9], # top-left offscreen + [48, 48, 0.9], # bottom-right edge + [25, 25, 0.4], # below cutoff ] ], dtype=float, ) - draw = MagicMock(name="DrawInstance") + draw = MagicMock() monkeypatch.setattr(display_mod.ImageDraw, "Draw", MagicMock(return_value=draw)) disp.display_frame(frame, pose) - # Two points above cutoff => two ellipse calls assert draw.ellipse.call_count == 2 + calls = draw.ellipse.call_args_list + + def expected_coords(x, y): + return [ + max(0, x - r), + max(0, y - r), + min(w, x + r), + min(h, y + r), + ] + + # First point + assert calls[0].args[0] == expected_coords(-1, -1) + + # Second point + assert calls[1].args[0] == expected_coords(48, 48) def test_destroy_calls_window_destroy(headless_display_env): From 03fe25ac5c09c048b69b6106f034ebf120b62a20 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 3 Feb 2026 12:56:24 +0100 Subject: [PATCH 19/20] Fix dates in comments --- dlclive/core/inferenceutils.py | 121 +++++++++++++----- dlclive/modelzoo/resolve_config.py 
| 11 +- dlclive/modelzoo/utils.py | 4 +- .../dynamic_cropping.py | 22 +++- tests/test_modelzoo.py | 6 +- 5 files changed, 121 insertions(+), 43 deletions(-) diff --git a/dlclive/core/inferenceutils.py b/dlclive/core/inferenceutils.py index e69ff5c..7b76b27 100644 --- a/dlclive/core/inferenceutils.py +++ b/dlclive/core/inferenceutils.py @@ -10,7 +10,7 @@ # -# NOTE - DUPLICATED @C-Achard 2026-26-01: Copied from the original DeepLabCut codebase +# NOTE - DUPLICATED @C-Achard 2026-01-26: Copied from the original DeepLabCut codebase # from deeplabcut/core/inferenceutils.py from __future__ import annotations @@ -66,7 +66,9 @@ def __init__(self, j1, j2, affinity=1): self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2) def __repr__(self): - return f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}" + return ( + f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}" + ) @property def confidence(self): @@ -264,7 +266,10 @@ def __init__( self.max_overlap = max_overlap self._has_identity = "identity" in self[0] if identity_only and not self._has_identity: - warnings.warn("The network was not trained with identity; setting `identity_only` to False.", stacklevel=2) + warnings.warn( + "The network was not trained with identity; setting `identity_only` to False.", + stacklevel=2, + ) self.identity_only = identity_only & self._has_identity self.nan_policy = nan_policy self.force_fusion = force_fusion @@ -345,7 +350,9 @@ def calibrate(self, train_data_file): pass n_bpts = len(df.columns.get_level_values("bodyparts").unique()) if n_bpts == 1: - warnings.warn("There is only one keypoint; skipping calibration...", stacklevel=2) + warnings.warn( + "There is only one keypoint; skipping calibration...", stacklevel=2 + ) return xy = df.to_numpy().reshape((-1, n_bpts, 2)) @@ -353,7 +360,9 @@ def calibrate(self, train_data_file): # Only keeps skeletons that are more than 90% complete xy = xy[frac_valid >= 0.9] if not xy.size: - warnings.warn("No complete poses were found. Skipping calibration...", stacklevel=2) + warnings.warn( + "No complete poses were found. Skipping calibration...", stacklevel=2 + ) return # TODO Normalize dists by longest length? @@ -369,9 +378,14 @@ def calibrate(self, train_data_file): self.safe_edge = True except np.linalg.LinAlgError: # Covariance matrix estimation fails due to numerical singularities - warnings.warn("The assembler could not be robustly calibrated. Continuing without it...", stacklevel=2) + warnings.warn( + "The assembler could not be robustly calibrated. 
Continuing without it...", + stacklevel=2, + ) - def calc_assembly_mahalanobis_dist(self, assembly, return_proba=False, nan_policy="little"): + def calc_assembly_mahalanobis_dist( + self, assembly, return_proba=False, nan_policy="little" + ): if self._kde is None: raise ValueError("Assembler should be calibrated first with training data.") @@ -425,7 +439,9 @@ def _flatten_detections(data_dict): ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence] else: ids = [arr.argmax(axis=1) for arr in ids] - for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids, strict=False)): + for i, (coords, conf, id_) in enumerate( + zip(coordinates, confidence, ids, strict=False) + ): if not np.any(coords): continue for xy, p, g in zip(coords, conf, id_, strict=False): @@ -450,7 +466,9 @@ def extract_best_links(self, joints_dict, costs, trees=None): aff[np.isnan(aff)] = 0 if trees: - vecs = np.vstack([[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t]) + vecs = np.vstack( + [[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t] + ) dists = [] for n, tree in enumerate(trees, start=1): d, _ = tree.query(vecs) @@ -459,8 +477,15 @@ def extract_best_links(self, joints_dict, costs, trees=None): aff *= w.reshape(aff.shape) if self.greedy: - conf = np.asarray([[det_s.confidence * det_t.confidence for det_t in dets_t] for det_s in dets_s]) - rows, cols = np.where((conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity)) + conf = np.asarray( + [ + [det_s.confidence * det_t.confidence for det_t in dets_t] + for det_s in dets_s + ] + ) + rows, cols = np.where( + (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity) + ) candidates = sorted( zip(rows, cols, aff[rows, cols], lengths[rows, cols], strict=False), key=lambda x: x[2], @@ -476,14 +501,18 @@ def extract_best_links(self, joints_dict, costs, trees=None): if len(i_seen) == self.max_n_individuals: break else: # Optimal keypoint pairing - inds_s = sorted(range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True)[ - : self.max_n_individuals + inds_s = sorted( + range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True + )[: self.max_n_individuals] + inds_t = sorted( + range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True + )[: self.max_n_individuals] + keep_s = [ + ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff ] - inds_t = sorted(range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True)[ - : self.max_n_individuals + keep_t = [ + ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff ] - keep_s = [ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff] - keep_t = [ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff] aff = aff[np.ix_(keep_s, keep_t)] rows, cols = linear_sum_assignment(aff, maximize=True) for row, col in zip(rows, cols, strict=False): @@ -522,7 +551,9 @@ def push_to_stack(i): if new_ind in assembled: continue if safe_edge: - d_old = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy) + d_old = self.calc_assembly_mahalanobis_dist( + assembly, nan_policy=nan_policy + ) success = assembly.add_link(best, store_dict=True) if not success: assembly._dict = dict() @@ -575,7 +606,9 @@ def build_assemblies(self, links): continue assembly = Assembly(self.n_multibodyparts) assembly.add_link(link) - self._fill_assembly(assembly, lookup, assembled, self.safe_edge, self.nan_policy) + self._fill_assembly( + assembly, lookup, assembled, self.safe_edge, self.nan_policy + ) for 
assembly_link in assembly._links: i, j = assembly_link.idx lookup[i].pop(j) @@ -587,7 +620,10 @@ def build_assemblies(self, links): n_extra = len(assemblies) - self.max_n_individuals if n_extra > 0: if self.safe_edge: - ds_old = [self.calc_assembly_mahalanobis_dist(assembly) for assembly in assemblies] + ds_old = [ + self.calc_assembly_mahalanobis_dist(assembly) + for assembly in assemblies + ] while len(assemblies) > self.max_n_individuals: ds = [] for i, j in itertools.combinations(range(len(assemblies)), 2): @@ -719,7 +755,10 @@ def _assemble(self, data_dict, ind_frame): for _, group in groups: ass = Assembly(self.n_multibodyparts) for joint in sorted(group, key=lambda x: x.confidence, reverse=True): - if joint.confidence >= self.pcutoff and joint.label < self.n_multibodyparts: + if ( + joint.confidence >= self.pcutoff + and joint.label < self.n_multibodyparts + ): ass.add_joint(joint) if len(ass): assemblies.append(ass) @@ -748,7 +787,11 @@ def _assemble(self, data_dict, ind_frame): assembled.update(assembled_) # Remove invalid assemblies - discarded = set(joint for joint in joints if joint.idx not in assembled and np.isfinite(joint.confidence)) + discarded = set( + joint + for joint in joints + if joint.idx not in assembled and np.isfinite(joint.confidence) + ) for assembly in assemblies[::-1]: if 0 < assembly.n_links < self.min_n_links or not len(assembly): for link in assembly._links: @@ -756,7 +799,9 @@ def _assemble(self, data_dict, ind_frame): assemblies.remove(assembly) if 0 < self.max_overlap < 1: # Non-maximum pose suppression if self._kde is not None: - scores = [-self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies] + scores = [ + -self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies + ] else: scores = [ass._affinity for ass in assemblies] lst = list(zip(scores, assemblies, strict=False)) @@ -825,7 +870,9 @@ def wrapped(i): n_frames = len(self.metadata["imnames"]) with multiprocessing.Pool(n_processes) as p: with tqdm(total=n_frames) as pbar: - for i, (assemblies, unique) in p.imap_unordered(wrapped, range(n_frames), chunksize=chunk_size): + for i, (assemblies, unique) in p.imap_unordered( + wrapped, range(n_frames), chunksize=chunk_size + ): if assemblies: self.assemblies[i] = assemblies if unique is not None: @@ -844,7 +891,9 @@ def parse_metadata(data): params["joint_names"] = data["metadata"]["all_joints_names"] params["num_joints"] = len(params["joint_names"]) params["paf_graph"] = data["metadata"]["PAFgraph"] - params["paf"] = data["metadata"].get("PAFinds", np.arange(len(params["joint_names"]))) + params["paf"] = data["metadata"].get( + "PAFinds", np.arange(len(params["joint_names"])) + ) params["bpts"] = params["ibpts"] = range(params["num_joints"]) params["imnames"] = [fn for fn in list(data) if fn != "metadata"] return params @@ -934,7 +983,11 @@ def calc_object_keypoint_similarity( else: oks = [] xy_preds = [xy_pred] - combos = (pair for l in range(len(symmetric_kpts)) for pair in itertools.combinations(symmetric_kpts, l + 1)) + combos = ( + pair + for l in range(len(symmetric_kpts)) + for pair in itertools.combinations(symmetric_kpts, l + 1) + ) for pairs in combos: # Swap corresponding keypoints tmp = xy_pred.copy() @@ -971,7 +1024,9 @@ def match_assemblies( num_ground_truth = len(ground_truth) # Sort predictions by score - inds_pred = np.argsort([ins.affinity if ins.n_links else ins.confidence for ins in predictions])[::-1] + inds_pred = np.argsort( + [ins.affinity if ins.n_links else ins.confidence for ins in predictions] + )[::-1] 
predictions = np.asarray(predictions)[inds_pred] # indices of unmatched ground truth assemblies @@ -1078,7 +1133,9 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)): raise ValueError(f"Invalid criterion {criterion}.") if len(qs) != 2: - raise ValueError("Two percentiles (for lower and upper bounds) should be given.") + raise ValueError( + "Two percentiles (for lower and upper bounds) should be given." + ) tuples = [] for frame_ind, assemblies in dict_of_assemblies.items(): @@ -1182,7 +1239,9 @@ def evaluate_assembly_greedy( oks = np.asarray([match.oks for match in all_matched])[sorted_pred_indices] # Compute prediction and recall - p, r = _compute_precision_and_recall(total_gt_assemblies, oks, oks_t, recall_thresholds) + p, r = _compute_precision_and_recall( + total_gt_assemblies, oks, oks_t, recall_thresholds + ) precisions.append(p) recalls.append(r) @@ -1255,7 +1314,9 @@ def evaluate_assembly( precisions = [] recalls = [] for t in oks_thresholds: - p, r = _compute_precision_and_recall(total_gt_assemblies, oks, t, recall_thresholds) + p, r = _compute_precision_and_recall( + total_gt_assemblies, oks, t, recall_thresholds + ) precisions.append(p) recalls.append(r) diff --git a/dlclive/modelzoo/resolve_config.py b/dlclive/modelzoo/resolve_config.py index 4508eb0..cf11f3b 100644 --- a/dlclive/modelzoo/resolve_config.py +++ b/dlclive/modelzoo/resolve_config.py @@ -3,7 +3,7 @@ For instance, "num_bodyparts x 2" is replaced with the number of bodyparts multiplied by 2. """ -# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase +# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase # from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py import copy @@ -99,7 +99,9 @@ def get_updated_value(variable: str) -> int | list[int]: else: raise ValueError(f"Unknown operator for variable: {variable}") - raise ValueError(f"Found {variable} in the configuration file, but cannot parse it.") + raise ValueError( + f"Found {variable} in the configuration file, but cannot parse it." 
+ ) updated_values = { "num_bodyparts": num_bodyparts, @@ -125,7 +127,10 @@ def get_updated_value(variable: str) -> int | list[int]: backbone_output_channels, **kwargs, ) - elif isinstance(config[k], str) and config[k].strip().split(" ")[0] in updated_values.keys(): + elif ( + isinstance(config[k], str) + and config[k].strip().split(" ")[0] in updated_values.keys() + ): config[k] = get_updated_value(config[k]) return config diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index e7224f8..3857d14 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -69,7 +69,7 @@ def read_config_as_dict(config_path: str | Path) -> dict: return cfg -# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase +# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase # from deeplabcut/pose_estimation_pytorch/config/make_pose_config.py def add_metadata( project_config: dict, @@ -98,7 +98,7 @@ def add_metadata( return config -# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase +# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase # from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py def load_super_animal_config( super_animal: str, diff --git a/dlclive/pose_estimation_pytorch/dynamic_cropping.py b/dlclive/pose_estimation_pytorch/dynamic_cropping.py index 27a1348..4572634 100644 --- a/dlclive/pose_estimation_pytorch/dynamic_cropping.py +++ b/dlclive/pose_estimation_pytorch/dynamic_cropping.py @@ -8,7 +8,7 @@ # # Licensed under GNU Lesser General Public License v3.0 -# NOTE DUPLICATED @C-Achard 2026-26-01: Duplication between this file +# NOTE DUPLICATED @C-Achard 2026-01-26: Duplication between this file # and deeplabcut/pose_estimation_pytorch/runners/dynamic_cropping.py # NOTE Testing already exists at deeplabcut/tests/pose_estimation_pytorch/runners/test_dynamic_cropper.py """Modules to dynamically crop individuals out of videos to improve video analysis""" @@ -82,7 +82,9 @@ def crop(self, image: torch.Tensor) -> torch.Tensor: height. """ if len(image) != 1: - raise RuntimeError(f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})") + raise RuntimeError( + f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})" + ) if self._shape is None: self._shape = image.shape[3], image.shape[2] @@ -307,7 +309,9 @@ def crop(self, image: torch.Tensor) -> torch.Tensor: `crop` was previously called with an image of a different W or H. """ if len(image) != 1: - raise RuntimeError(f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})") + raise RuntimeError( + f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})" + ) if self._shape is None: self._shape = image.shape[3], image.shape[2] @@ -394,7 +398,9 @@ def update(self, pose: torch.Tensor) -> torch.Tensor: return pose - def _prepare_bounding_box(self, x1: int, y1: int, x2: int, y2: int) -> tuple[int, int, int, int]: + def _prepare_bounding_box( + self, x1: int, y1: int, x2: int, y2: int + ) -> tuple[int, int, int, int]: """Prepares the bounding box for cropping. Adds a margin around the bounding box, then transforms it into the target aspect @@ -491,8 +497,12 @@ def generate_patches(self) -> list[tuple[int, int, int, int]]: Returns: A list of patch coordinates as tuples (x0, y0, x1, y1). 
""" - patch_xs = self.split_array(self._shape[0], self._patch_counts[0], self._patch_overlap) - patch_ys = self.split_array(self._shape[1], self._patch_counts[1], self._patch_overlap) + patch_xs = self.split_array( + self._shape[0], self._patch_counts[0], self._patch_overlap + ) + patch_ys = self.split_array( + self._shape[1], self._patch_counts[1], self._patch_overlap + ) patches = [] for y0, y1 in patch_ys: diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py index 997b0ba..c2a0d70 100644 --- a/tests/test_modelzoo.py +++ b/tests/test_modelzoo.py @@ -1,4 +1,4 @@ -# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase +# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase # from deeplabcut/tests/pose_estimation_pytorch/modelzoo/test_modelzoo_utils.py import os @@ -9,7 +9,9 @@ from dlclive import modelzoo -@pytest.mark.parametrize("super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"]) +@pytest.mark.parametrize( + "super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"] +) @pytest.mark.parametrize("model_name", ["hrnet_w32"]) @pytest.mark.parametrize("detector_name", [None, "fasterrcnn_resnet50_fpn_v2"]) def test_get_config_model_paths(super_animal, model_name, detector_name): From 822f5d4999818e88610f7148eff46c4d9916c30a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 3 Feb 2026 13:30:18 +0100 Subject: [PATCH 20/20] Use max/min for pose coordinate clamping Replace verbose conditional expressions with max/min to clamp pose-based ellipse coordinates (x0, x1, y0, y1) within image bounds. --- dlclive/display.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/dlclive/display.py b/dlclive/display.py index 9d4de43..42abab4 100644 --- a/dlclive/display.py +++ b/dlclive/display.py @@ -95,26 +95,10 @@ def display_frame(self, frame, pose=None): for j in range(pose.shape[1]): if pose[i, j, 2] > self.pcutoff: try: - x0 = ( - pose[i, j, 0] - self.radius - if pose[i, j, 0] - self.radius > 0 - else 0 - ) - x1 = ( - pose[i, j, 0] + self.radius - if pose[i, j, 0] + self.radius < im_size[0] - else im_size[0] - ) - y0 = ( - pose[i, j, 1] - self.radius - if pose[i, j, 1] - self.radius > 0 - else 0 - ) - y1 = ( - pose[i, j, 1] + self.radius - if pose[i, j, 1] + self.radius < im_size[1] - else im_size[1] - ) + x0 = max(0, pose[i, j, 0] - self.radius) + x1 = min(im_size[0], pose[i, j, 0] + self.radius) + y0 = max(0, pose[i, j, 1] - self.radius) + y1 = min(im_size[1], pose[i, j, 1] + self.radius) coords = [x0, y0, x1, y1] draw.ellipse( coords, fill=self.colors[j], outline=self.colors[j]