2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
@@ -69,7 +69,7 @@ jobs:
shell: bash

- name: Run Model Benchmark Test
run: uv run dlc-live-test --nodisplay
run: uv run dlc-live-test

- name: Run DLC Live Unit Tests
run: uv run pytest
21 changes: 15 additions & 6 deletions .pre-commit-config.yaml
@@ -2,18 +2,27 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-docstring-first
- id: check-added-large-files
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: name-tests-test
args: [--pytest-test-first]
- id: trailing-whitespace
- repo: https://github.com/asottile/setup-cfg-fmt
rev: v3.2.0
- id: check-merge-conflict
- repo: https://github.com/tox-dev/pyproject-fmt
rev: v2.15.2
hooks:
- id: setup-cfg-fmt
- id: pyproject-fmt
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.25
hooks:
- id: validate-pyproject
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.10
rev: v0.15.0
hooks:
# Run the formatter.
- id: ruff-format
# Run the linter.
- id: ruff-check
args: [--fix,--unsafe-fixes]
args: [--fix,--unsafe-fixes]
105 changes: 59 additions & 46 deletions dlclive/benchmark.py
@@ -10,8 +10,10 @@
import sys
import time
import warnings
from typing import TYPE_CHECKING
from pathlib import Path

import argparse
import os
import colorcet as cc
import cv2
import numpy as np
@@ -23,10 +25,12 @@

from dlclive import DLCLive
from dlclive import VERSION
from dlclive import __file__ as dlcfile
from dlclive.engine import Engine
from dlclive.utils import decode_fourcc

if TYPE_CHECKING:
import tensorflow # type: ignore


def download_benchmarking_data(
target_dir=".",
@@ -49,17 +53,20 @@ def download_benchmarking_data(
if os.path.exists(zip_path):
print(f"{zip_path} already exists. Skipping download.")
else:

def show_progress(count, block_size, total_size):
pbar.update(block_size)

print(f"Downloading the benchmarking data from {url} ...")
pbar = tqdm(unit="B", total=0, position=0, desc="Downloading")

filename, _ = urllib.request.urlretrieve(url, filename=zip_path, reporthook=show_progress)
filename, _ = urllib.request.urlretrieve(
url, filename=zip_path, reporthook=show_progress
)
pbar.close()

print(f"Extracting {zip_path} to {target_dir} ...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(target_dir)


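Side note: a minimal usage sketch of the download helper above (the target directory is illustrative; the function downloads and extracts the benchmarking archive from its built-in URL):

from dlclive.benchmark import download_benchmarking_data

download_benchmarking_data(target_dir="./dlc-live-benchmark-data")  # hypothetical path
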
@@ -81,6 +88,7 @@ def benchmark_videos(
cmap="bmy",
save_poses=False,
save_video=False,
single_animal=False,
):
"""Analyze videos using DeepLabCut-live exported models.
Analyze multiple videos and/or multiple options for the size of the video
@@ -168,7 +176,7 @@ def benchmark_videos(
im_size_out = []

for i in range(len(resize)):
print(f"\nRun {i+1} / {len(resize)}\n")
print(f"\nRun {i + 1} / {len(resize)}\n")

this_inf_times, this_im_size, meta = benchmark(
model_path=model_path,
@@ -188,6 +196,7 @@
save_poses=save_poses,
save_video=save_video,
save_dir=output,
single_animal=single_animal,
)

inf_times.append(this_inf_times)
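
For context, a hedged sketch of calling benchmark_videos with the new single_animal flag; parameter names are inferred from the call it forwards to benchmark() above and should be treated as assumptions, apart from single_animal, which this PR adds:

from dlclive.benchmark import benchmark_videos

benchmark_videos(
    model_path="/path/to/exported_model",   # placeholder
    video_path="/path/to/video.avi",        # assumed parameter name
    output="./benchmark_results",
    resize=[1.0, 0.5],                      # one run per resize factor
    n_frames=500,
    save_poses=True,
    single_animal=False,                    # flag added in this PR
)
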
@@ -275,9 +284,7 @@ def get_system_info() -> dict:
}


def save_inf_times(
sys_info, inf_times, im_size, model=None, meta=None, output=None
):
def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=None):
"""Save inference time data collected using :function:`benchmark` with system information to a pickle file.
This is primarily used through :function:`benchmark_videos`

@@ -346,6 +353,7 @@ def save_inf_times(

return True


def benchmark(
model_path: str,
model_type: str,
@@ -357,8 +365,8 @@
single_animal: bool = True,
cropping: list[int] | None = None,
dynamic: tuple[bool, float, int] = (False, 0.5, 10),
n_frames: int =1000,
print_rate: bool=False,
n_frames: int = 1000,
print_rate: bool = False,
precision: str = "FP32",
display: bool = True,
pcutoff: float = 0.5,
@@ -434,7 +442,10 @@ def benchmark(
if not cap.isOpened():
print(f"Error: Could not open video file {video_path}")
return
im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
im_size = (
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
)

if pixels is not None:
resize = np.sqrt(pixels / (im_size[0] * im_size[1]))
@@ -492,9 +503,7 @@

total_n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
n_frames = int(
n_frames
if (n_frames > 0) and n_frames < total_n_frames
else total_n_frames
n_frames if (n_frames > 0) and n_frames < total_n_frames else total_n_frames
)
iterator = range(n_frames) if print_rate or display else tqdm(range(n_frames))
for _ in iterator:
@@ -510,7 +519,7 @@

start_time = time.perf_counter()
if frame_index == 0:
pose = dlc_live.init_inference(frame) # Loads model
pose = dlc_live.init_inference(frame) # Loads model
else:
pose = dlc_live.get_pose(frame)

@@ -519,7 +528,9 @@
times.append(inf_time)

if print_rate:
print("Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True)
print(
"Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True
)

if save_video:
draw_pose_and_write(
@@ -531,19 +542,17 @@
pcutoff=pcutoff,
display_radius=display_radius,
draw_keypoint_names=draw_keypoint_names,
vwriter=vwriter
vwriter=vwriter,
)

frame_index += 1

if print_rate:
print("Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:])))
print(
"Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:]))
)

metadata = _get_metadata(
video_path=video_path,
cap=cap,
dlc_live=dlc_live
)
metadata = _get_metadata(video_path=video_path, cap=cap, dlc_live=dlc_live)

cap.release()

@@ -558,19 +567,21 @@
else:
individuals = []
n_individuals = len(individuals) or 1
save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp)
save_poses_to_files(
video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp
)

return times, im_size, metadata
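
A rough sketch of using these return values (paths and model_type are placeholders; the mean-FPS line mirrors the print above and skips the first frame because it includes model loading):

import numpy as np
from dlclive.benchmark import benchmark

times, im_size, meta = benchmark(
    model_path="/path/to/exported_model",  # placeholder
    model_type="pytorch",                  # placeholder engine name
    video_path="/path/to/video.avi",
    n_frames=200,
    display=False,
    save_poses=False,
)
mean_fps = float(np.mean(1 / np.array(times)[1:]))  # first entry covers init_inference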


def setup_video_writer(
video_path:str,
save_dir:str,
timestamp:str,
num_keypoints:int,
cmap:str,
fps:float,
frame_size:tuple[int, int],
video_path: str,
save_dir: str,
timestamp: str,
num_keypoints: int,
cmap: str,
fps: float,
frame_size: tuple[int, int],
):
# Set colors and convert to RGB
cmap_colors = getattr(cc, cmap)
@@ -582,7 +593,9 @@ def setup_video_writer(
# Define output video path
video_path = Path(video_path)
video_name = video_path.stem # filename without extension
output_video_path = Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
output_video_path = (
Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
)

# Get video writer setup
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -595,6 +608,7 @@ def setup_video_writer(

return colors, vwriter


def draw_pose_and_write(
frame: np.ndarray,
pose: np.ndarray,
@@ -611,7 +625,9 @@

if resize is not None and resize != 1.0:
# Resize the frame
frame = cv2.resize(frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
frame = cv2.resize(
frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
)

# Scale pose coordinates
pose = pose.copy()
@@ -642,15 +658,10 @@
lineType=cv2.LINE_AA,
)


vwriter.write(image=frame)


def _get_metadata(
video_path: str,
cap: cv2.VideoCapture,
dlc_live: DLCLive
):
def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
try:
fourcc = decode_fourcc(cap.get(cv2.CAP_PROP_FOURCC))
except Exception:
@@ -687,7 +698,9 @@ def _get_metadata(
return meta


def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp):
def save_poses_to_files(
video_path, save_dir, n_individuals, bodyparts, poses, timestamp
):
"""
Saves the detected keypoint poses from the video to CSV and HDF5 files.

@@ -725,14 +738,16 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
else:
individuals = [f"individual_{i}" for i in range(n_individuals)]
pdindex = pd.MultiIndex.from_product(
[individuals, bodyparts, ["x", "y", "likelihood"]], names=["individuals", "bodyparts", "coords"]
[individuals, bodyparts, ["x", "y", "likelihood"]],
names=["individuals", "bodyparts", "coords"],
)

pose_df = pd.DataFrame(flattened_poses, columns=pdindex)

pose_df.to_hdf(h5_save_path, key="df_with_missing", mode="w")
pose_df.to_csv(csv_save_path, index=False)
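
As a quick check, the HDF5 written here can be read back with the same key; in the multi-animal branch the columns carry the (individuals, bodyparts, coords) MultiIndex built above (the file name below is a placeholder, the real one is derived from the video name and timestamp):

import pandas as pd

df = pd.read_hdf("some_video_poses.h5", key="df_with_missing")  # placeholder path
x_coords = df.xs("x", level="coords", axis=1)  # per-keypoint x positions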


def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
# Create numpy array with poses:
max_frame = max(p["frame"] for p in poses)
@@ -745,17 +760,15 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
if pose.ndim == 2:
pose = pose[np.newaxis, :, :]
padded_pose = np.full(pose_target_shape, np.nan)
slices = tuple(slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3))
slices = tuple(
slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3)
)
padded_pose[slices] = pose[slices]
poses_array[frame] = padded_pose

return poses_array
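
To illustrate the padding this helper performs, a toy example with assumed shapes: a 2-D single-animal pose is lifted to 3-D and NaN-padded to (n_individuals, n_bodyparts, 3), so undetected individuals stay NaN:

import numpy as np

pose = np.random.rand(5, 3)              # one animal, 5 keypoints, (x, y, likelihood)
target_shape = (2, 5, 3)                 # two expected individuals
if pose.ndim == 2:
    pose = pose[np.newaxis, :, :]
padded = np.full(target_shape, np.nan)
slices = tuple(slice(0, min(pose.shape[i], target_shape[i])) for i in range(3))
padded[slices] = pose[slices]            # second individual remains all-NaN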


import argparse
import os


def main():
"""Provides a command line interface to benchmark_videos function."""
parser = argparse.ArgumentParser(