Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
679 changes: 328 additions & 351 deletions doc_source/notebooks/Matisse/example_matisse.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ect"
version = "1.2.3"
version = "1.2.4"
authors = [
{ name="Liz Munch", email="muncheli@msu.edu" },
]
Expand Down
2 changes: 2 additions & 0 deletions src/ect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .ect import ECT
from .embed_complex import EmbeddedComplex, EmbeddedGraph, EmbeddedCW
from .directions import Directions
from .results import ECTResult
from .sect import SECT
from .dect import DECT
from .utils import examples
Expand All @@ -25,5 +26,6 @@
"EmbeddedGraph",
"EmbeddedCW",
"Directions",
"ECTResult",
"examples",
]
41 changes: 38 additions & 3 deletions src/ect/results.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import numpy as np
from ect.directions import Sampling
from scipy.spatial.distance import cdist
from scipy.spatial.distance import cdist, pdist, squareform
from typing import Union, List, Callable


Expand Down Expand Up @@ -319,7 +319,7 @@ def _plot_ecc(self, theta):
def dist(
self,
other: Union["ECTResult", List["ECTResult"]],
metric: Union[str, Callable] = "cityblock",
metric: Union[str, Callable] = "frobenius",
**kwargs,
):
"""
Expand Down Expand Up @@ -365,7 +365,15 @@ def dist(
f"Shape mismatch at index {i}: {self.shape} vs {ect.shape}"
)

# use ravel to avoid copying the data and compute distances
if isinstance(metric, str) and metric.lower() in ("frobenius", "fro"):
a = np.asarray(self, dtype=np.float64)
if single:
b = np.asarray(other, dtype=np.float64)
return float(np.sqrt(np.sum((a - b) ** 2)))
b = np.stack([np.asarray(ect, dtype=np.float64) for ect in others], axis=0)
diff = b - a
return np.sqrt(np.sum(diff * diff, axis=(1, 2)))

distances = cdist(
self.ravel()[np.newaxis, :],
np.vstack([ect.ravel() for ect in others]),
Expand All @@ -374,3 +382,30 @@ def dist(
)[0]

return distances[0] if single else distances

@classmethod
def dist_matrix(
cls,
results: List["ECTResult"],
metric: Union[str, Callable] = "frobenius",
**kwargs,
) -> np.ndarray:
if not results:
return np.empty((0, 0), dtype=np.float64)

shape0 = results[0].shape
for i, r in enumerate(results):
if r.shape != shape0:
raise ValueError(f"Shape mismatch at index {i}: {shape0} vs {r.shape}")

if isinstance(metric, str) and metric.lower() in ("frobenius", "fro"):
return np.vstack([results[i].dist(results, metric="frobenius") for i in range(len(results))])

if isinstance(metric, str):
X = np.stack([np.asarray(r, dtype=np.float64).ravel() for r in results], axis=0)
try:
return squareform(pdist(X, metric=metric, **kwargs))
except TypeError:
return cdist(X, X, metric=metric, **kwargs)

return np.vstack([results[i].dist(results, metric=metric, **kwargs) for i in range(len(results))])
20 changes: 14 additions & 6 deletions tests/test_ect_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,19 @@ def test_dist_single_ectresult(self):
result2_modified.directions = result2.directions
result2_modified.thresholds = result2.thresholds

# Test L1 distance (default)
dist_l1 = self.result.dist(result2_modified)
expected_l1 = np.abs(self.result - result2_modified).sum()
self.assertAlmostEqual(dist_l1, expected_l1)
self.assertIsInstance(dist_l1, (float, np.floating))
# Test frobenius distance (default)
dist_frobenius = self.result.dist(result2_modified)
expected_frobenius = np.sqrt(
np.sum(
(
np.asarray(self.result, dtype=np.float64)
- np.asarray(result2_modified, dtype=np.float64)
).ravel()
** 2
)
)
self.assertAlmostEqual(dist_frobenius, expected_frobenius)
self.assertIsInstance(dist_frobenius, (float, np.floating))

# Test L2 distance
dist_l2 = self.result.dist(result2_modified, metric="euclidean")
Expand All @@ -119,7 +127,7 @@ def test_dist_list_of_ectresults(self):
r.thresholds = self.result.thresholds

# Test batch distances
distances = self.result.dist([result2, result3, result4])
distances = self.result.dist([result2, result3, result4], metric="cityblock")

# Check return type is array
self.assertIsInstance(distances, np.ndarray)
Expand Down
Loading