Skip to content

Commit 603c8a8

Browse files
committed
feat(invdes): added symmetry functions
1 parent 66d3db3 commit 603c8a8

File tree

6 files changed

+380
-0
lines changed

6 files changed

+380
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased]
99

1010
### Added
11+
- Added `symmetrize_mirror`, `symmetrize_rotation`, `symmetrize_diagonal` functions to the autograd plugin. They can be used for enforcing symmetries in topology optimization.
1112

1213
### Changed
1314
- Removed validator that would warn if `PerturbationMedium` values could become numerically unstable, since an error will anyway be raised if this actually happens when the medium is converted using actual perturbation data.

docs/api/plugins/autograd.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,7 @@ Inverse Design
8585
tidy3d.plugins.autograd.invdes.ramp_projection
8686
tidy3d.plugins.autograd.invdes.tanh_projection
8787
tidy3d.plugins.autograd.invdes.smoothed_projection
88+
tidy3d.plugins.autograd.invdes.symmetrize_mirror
89+
tidy3d.plugins.autograd.invdes.symmetrize_rotation
90+
tidy3d.plugins.autograd.invdes.symmetrize_diagonal
8891

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from __future__ import annotations
2+
3+
import autograd.numpy as np
4+
import numpy as onp
5+
import pytest
6+
from autograd.test_util import check_grads
7+
8+
from tidy3d.plugins.autograd.invdes.symmetries import (
9+
symmetrize_diagonal,
10+
symmetrize_mirror,
11+
symmetrize_rotation,
12+
)
13+
14+
# --- Helper Fixtures ---
15+
16+
17+
@pytest.fixture
18+
def square_array():
19+
"""Returns a random 5x5 array for square tests."""
20+
return np.random.randn(5, 5)
21+
22+
23+
@pytest.fixture
24+
def rect_array():
25+
"""Returns a random 4x6 array for non-square tests."""
26+
return np.random.randn(4, 6)
27+
28+
29+
# --- Symmetrize Mirror Tests ---
30+
31+
32+
@pytest.mark.parametrize("axis", [0, 1, (0, 1)])
33+
def test_mirror_gradients(axis):
34+
"""
35+
Verifies that the gradient calculation through symmetrize_mirror is correct
36+
using finite difference checks provided by autograd.
37+
"""
38+
# Create a random array. Size doesn't need to be square.
39+
x = np.random.randn(4, 5)
40+
41+
# We wrap the function to treat 'axis' as a fixed constant,
42+
# testing the gradient only with respect to 'x'.
43+
def fun(x):
44+
return symmetrize_mirror(x, axis=axis)
45+
46+
# check_grads verifies analytical grad vs finite difference
47+
check_grads(fun, modes=["rev"], order=1)(x)
48+
49+
50+
@pytest.mark.parametrize("axis", [0, 1, (0, 1)])
51+
def test_mirror_values(axis):
52+
"""Verifies numerical correctness of mirror symmetry."""
53+
# Simple 2x2 case
54+
# [[1, 2],
55+
# [3, 4]]
56+
arr = np.array([[1.0, 2.0], [3.0, 4.0]])
57+
58+
res = symmetrize_mirror(arr, axis=axis)
59+
60+
if axis == 0:
61+
# Average with vertical flip [[3, 4], [1, 2]]
62+
# ([[1, 2], [3, 4]] + [[3, 4], [1, 2]]) / 2 = [[2, 3], [2, 3]]
63+
expected = np.array([[2.0, 3.0], [2.0, 3.0]])
64+
elif axis == 1:
65+
# Average with horizontal flip [[2, 1], [4, 3]]
66+
# ([[1, 2], [3, 4]] + [[2, 1], [4, 3]]) / 2 = [[1.5, 1.5], [3.5, 3.5]]
67+
expected = np.array([[1.5, 1.5], [3.5, 3.5]])
68+
else: # (0, 1)
69+
# Average of all 4 mirror types implied (linear combination reduces to avg of 4 corners)
70+
# Result should be constant value 2.5 everywhere for this specific linear gradient input
71+
expected = np.full((2, 2), 2.5)
72+
73+
onp.testing.assert_allclose(res, expected)
74+
75+
76+
def test_mirror_shapes_and_errors(rect_array):
77+
"""Test shape constraints and error handling."""
78+
# Should work on rectangular arrays
79+
res = symmetrize_mirror(rect_array, axis=0)
80+
assert res.shape == rect_array.shape
81+
82+
# Error: 3D array
83+
with pytest.raises(ValueError, match="Need 2d array"):
84+
symmetrize_mirror(np.random.randn(2, 2, 2), axis=0)
85+
86+
# Error: Invalid axis
87+
with pytest.raises(ValueError, match="Invalid axis"):
88+
symmetrize_mirror(rect_array, axis=2)
89+
90+
# Error: Invalid tuple
91+
with pytest.raises(ValueError, match="Invalid axis"):
92+
symmetrize_mirror(rect_array, axis=(0, 0))
93+
94+
95+
# --- Symmetrize Rotation Tests ---
96+
97+
98+
def test_rotation_gradients(square_array):
99+
"""Verifies gradients for rotation symmetry."""
100+
check_grads(symmetrize_rotation, modes=["rev"], order=1)(square_array)
101+
102+
103+
def test_rotation_values():
104+
"""Verifies numerical correctness of rotation symmetry."""
105+
# Input with a single 1 in top-left, 0 elsewhere
106+
# [[1, 0],
107+
# [0, 0]]
108+
arr = np.zeros((2, 2))
109+
arr[0, 0] = 1.0
110+
111+
res = symmetrize_rotation(arr)
112+
113+
# The 1 should be distributed to all 4 corners equally
114+
expected = np.full((2, 2), 0.25)
115+
onp.testing.assert_allclose(res, expected)
116+
117+
118+
def test_rotation_invariance(square_array):
119+
"""The output of symmetrize_rotation should be invariant to further 90deg rotations."""
120+
sym = symmetrize_rotation(square_array)
121+
rot = np.rot90(sym)
122+
onp.testing.assert_allclose(sym, rot, err_msg="Output is not rotationally symmetric")
123+
124+
125+
def test_rotation_errors(rect_array):
126+
"""Test shape constraints for rotation."""
127+
# Error: Rectangular array
128+
with pytest.raises(ValueError, match="must be square"):
129+
symmetrize_rotation(rect_array)
130+
131+
132+
# --- Symmetrize Diagonal Tests ---
133+
134+
135+
@pytest.mark.parametrize("anti", [False, True])
136+
def test_diagonal_gradients(square_array, anti):
137+
"""Verifies gradients for diagonal symmetry."""
138+
139+
def fun(x):
140+
return symmetrize_diagonal(x, anti=anti)
141+
142+
check_grads(fun, modes=["rev"], order=1)(square_array)
143+
144+
145+
def test_diagonal_values():
146+
"""Verifies numerical correctness of diagonal symmetry."""
147+
# [[1, 2],
148+
# [3, 4]]
149+
arr = np.array([[1.0, 2.0], [3.0, 4.0]])
150+
151+
# Main diagonal
152+
res_main = symmetrize_diagonal(arr, anti=False)
153+
# Transpose is [[1, 3], [2, 4]]
154+
# Avg: [[1, 2.5], [2.5, 4]]
155+
expected_main = np.array([[1.0, 2.5], [2.5, 4.0]])
156+
onp.testing.assert_allclose(res_main, expected_main)
157+
158+
# Anti diagonal
159+
res_anti = symmetrize_diagonal(arr, anti=True)
160+
# Anti-transpose logic check:
161+
#
162+
# Input:
163+
# 1 2
164+
# 3 4
165+
#
166+
# Anti-Transpose:
167+
# 4 2
168+
# 3 1
169+
#
170+
# Average:
171+
# 2.5 2
172+
# 3 2.5
173+
expected_anti = np.array([[2.5, 2.0], [3.0, 2.5]])
174+
onp.testing.assert_allclose(res_anti, expected_anti)
175+
176+
177+
def test_diagonal_errors(rect_array):
178+
"""Test shape constraints for diagonal."""
179+
# Error: Rectangular array
180+
with pytest.raises(ValueError, match="must be square"):
181+
symmetrize_diagonal(rect_array)

tidy3d/plugins/autograd/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
make_filter_and_project,
3737
make_gaussian_filter,
3838
ramp_projection,
39+
smoothed_projection,
40+
symmetrize_diagonal,
41+
symmetrize_mirror,
42+
symmetrize_rotation,
3943
tanh_projection,
4044
)
4145
from .primitives import gaussian_filter, interpolate_spline
@@ -79,6 +83,10 @@
7983
"scalar_objective",
8084
"smooth_max",
8185
"smooth_min",
86+
"smoothed_projection",
87+
"symmetrize_diagonal",
88+
"symmetrize_mirror",
89+
"symmetrize_rotation",
8290
"tanh_projection",
8391
"threshold",
8492
"trapz",

tidy3d/plugins/autograd/invdes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from .penalties import ErosionDilationPenalty, make_curvature_penalty, make_erosion_dilation_penalty
1919
from .projections import ramp_projection, smoothed_projection, tanh_projection
20+
from .symmetries import symmetrize_diagonal, symmetrize_mirror, symmetrize_rotation
2021

2122
__all__ = [
2223
"CircularFilter",
@@ -35,5 +36,8 @@
3536
"make_gaussian_filter",
3637
"ramp_projection",
3738
"smoothed_projection",
39+
"symmetrize_diagonal",
40+
"symmetrize_mirror",
41+
"symmetrize_rotation",
3842
"tanh_projection",
3943
]

0 commit comments

Comments
 (0)