Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/ect/dect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .results import ECTResult
from typing import Optional, Union
import numpy as np
from numba import njit
from numba import njit # type: ignore[attr-defined]


class DECT(ECT):
Expand Down Expand Up @@ -86,18 +86,19 @@ def _compute_directional_transform(
def calculate(
self,
graph: Union[EmbeddedGraph, EmbeddedCW],
scale: Optional[float] = None,
theta: Optional[float] = None,
override_bound_radius: Optional[float] = None,
*,
scale: Optional[float] = None,
) -> ECTResult:
"""
Calculate the Differentiable Euler Characteristic Transform (DECT) for a given embedded complex.
Args:
graph (EmbeddedGraph or EmbeddedCW): The embedded complex to analyze.
scale (Optional[float]): Slope parameter for the sigmoid function. If None, uses the instance's scale.
theta (Optional[float]): Specific direction angle to use. If None, uses all directions.
override_bound_radius (Optional[float]): Override for bounding radius in threshold generation.
scale (Optional[float]): Slope parameter for the sigmoid function. If None, uses the instance's scale.
Returns:
ECTResult: Result object containing the DECT matrix, directions, and thresholds.
Expand Down
9 changes: 5 additions & 4 deletions src/ect/ect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from numba import prange, njit
from numba import prange, njit # type: ignore[attr-defined]
from numba.typed import List
from typing import Optional

Expand All @@ -16,7 +16,7 @@ class ECT:
The result is a matrix where entry ``M[i,j]`` is :math:`\chi(K_{a_i})` for the direction :math:`\omega_j`
where :math:`a_i` is the ith entry in ``self.thresholds``, and :math:`\omega_j` is the jth entry in ``self.directions``.
Example:
>>> from ect import ECT, EmbeddedComplex
Expand Down Expand Up @@ -106,14 +106,15 @@ def _ensure_thresholds(self, graph, override_bound_radius=None):
def calculate(
self,
graph: EmbeddedComplex,
theta: float = None,
override_bound_radius: float = None,
theta: Optional[float] = None,
override_bound_radius: Optional[float] = None,
):
self._ensure_directions(graph.dim, theta)
self._ensure_thresholds(graph, override_bound_radius)
directions = (
self.directions if theta is None else Directions.from_angles([theta])
)
assert self.thresholds is not None
ect_matrix = self._compute_ect(graph, directions, self.thresholds, self.dtype)

return ECTResult(ect_matrix, directions, self.thresholds)
Expand Down
16 changes: 7 additions & 9 deletions src/ect/embed_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,8 @@ def pca_projection(self, target_dim=2):
pca = PCA(n_components=target_dim)
self._coord_matrix = pca.fit_transform(self._coord_matrix)


@staticmethod
def validate_plot_parameters(func):
# decorator to check plotting requirements
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
bounding_center_type = kwargs.get("bounding_center_type", "bounding_box")
Expand Down Expand Up @@ -706,7 +705,6 @@ def plot_faces(self, ax=None, **kwargs):

return ax


@validate_plot_parameters
def plot(
self,
Expand All @@ -722,7 +720,7 @@ def plot(
face_color: str = "lightblue",
face_alpha: float = 0.3,
**kwargs,
) -> plt.Axes:
) -> plt.Axes:
"""
Visualize the embedded complex in 2D or 3D
Expand All @@ -739,7 +737,7 @@ def plot(
face_color (str): Color for faces (2-cells)
face_alpha (float): Transparency for faces (2-cells)
**kwargs: Additional keyword arguments for plotting functions
Returns:
matplotlib.axes.Axes: The axes object with the plot.
"""
Expand Down Expand Up @@ -991,20 +989,20 @@ def _build_incidence_csr(self) -> tuple:

cell_vertex_pointers = np.empty(n_cells + 1, dtype=np.int64)
cell_euler_signs = np.empty(n_cells, dtype=np.int32)
cell_vertex_indices_flat = []
list_flat: List[int] = []

cell_vertex_pointers[0] = 0
cell_index = 0
for dim in dimensions:
cells_in_dim = cells_by_dimension[dim]
euler_sign = 1 if (dim % 2 == 0) else -1
for cell_vertices in cells_in_dim:
cell_vertex_indices_flat.extend(cell_vertices)
list_flat.extend(cell_vertices)
cell_euler_signs[cell_index] = euler_sign
cell_index += 1
cell_vertex_pointers[cell_index] = len(cell_vertex_indices_flat)
cell_vertex_pointers[cell_index] = len(list_flat)

cell_vertex_indices_flat = np.asarray(cell_vertex_indices_flat, dtype=np.int32)
cell_vertex_indices_flat = np.asarray(list_flat, dtype=np.int32)
return (
cell_vertex_pointers,
cell_vertex_indices_flat,
Expand Down
35 changes: 25 additions & 10 deletions src/ect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from ect.directions import Sampling
from scipy.spatial.distance import cdist, pdist, squareform
from typing import Union, List, Callable
from typing import Union, List, Callable, cast


# ---------- CSR <-> Dense helpers (prefix-difference over thresholds) ----------
Expand Down Expand Up @@ -352,14 +352,15 @@ def dist(
>>> # Batch distances with custom function
>>> dists = ect1.dist([ect2, ect3, ect4], metric=my_distance)
"""
# normalize input to list
single = isinstance(other, ECTResult)
others = [other] if single else other
others_list: List["ECTResult"] = cast(
List["ECTResult"], [other] if single else other
)

if not others:
if not others_list:
return np.array([])

for i, ect in enumerate(others):
for i, ect in enumerate(others_list):
if ect.shape != self.shape:
raise ValueError(
f"Shape mismatch at index {i}: {self.shape} vs {ect.shape}"
Expand All @@ -370,13 +371,15 @@ def dist(
if single:
b = np.asarray(other, dtype=np.float64)
return float(np.sqrt(np.sum((a - b) ** 2)))
b = np.stack([np.asarray(ect, dtype=np.float64) for ect in others], axis=0)
b = np.stack(
[np.asarray(ect, dtype=np.float64) for ect in others_list], axis=0
)
diff = b - a
return np.sqrt(np.sum(diff * diff, axis=(1, 2)))

distances = cdist(
self.ravel()[np.newaxis, :],
np.vstack([ect.ravel() for ect in others]),
np.vstack([ect.ravel() for ect in others_list]),
metric=metric,
**kwargs,
)[0]
Expand All @@ -399,13 +402,25 @@ def dist_matrix(
raise ValueError(f"Shape mismatch at index {i}: {shape0} vs {r.shape}")

if isinstance(metric, str) and metric.lower() in ("frobenius", "fro"):
return np.vstack([results[i].dist(results, metric="frobenius") for i in range(len(results))])
return np.vstack(
[
results[i].dist(results, metric="frobenius")
for i in range(len(results))
]
)

if isinstance(metric, str):
X = np.stack([np.asarray(r, dtype=np.float64).ravel() for r in results], axis=0)
X = np.stack(
[np.asarray(r, dtype=np.float64).ravel() for r in results], axis=0
)
try:
return squareform(pdist(X, metric=metric, **kwargs))
except TypeError:
return cdist(X, X, metric=metric, **kwargs)

return np.vstack([results[i].dist(results, metric=metric, **kwargs) for i in range(len(results))])
return np.vstack(
[
results[i].dist(results, metric=metric, **kwargs)
for i in range(len(results))
]
)
Loading
Loading