diff --git a/CHANGELOG.md b/CHANGELOG.md index 6767dfbab1..708c432298 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for `nonlinear_spec` in `CustomMedium` and `CustomDispersiveMedium`. - `tidy3d.plugins.design.DesignSpace.run(..., fn_post=...)` now accepts a `priority` keyword to propagate vGPU queue priority to all automatically batched simulations. - Introduced `BroadbandPulse` for exciting simulations across a wide frequency spectrum. +- Added `user_vjp` and `numerical_structures` to new custom run functions that provide hooks into adjoint for user-defined gradient calculations. ### Breaking Changes - Edge singularity correction at PEC and lossy metal edges defaults to `True`. diff --git a/tests/test_components/autograd/numerical/test_autograd_cm_user_vjp_numerical_structures.py b/tests/test_components/autograd/numerical/test_autograd_cm_user_vjp_numerical_structures.py new file mode 100644 index 0000000000..9fcb2a61f1 --- /dev/null +++ b/tests/test_components/autograd/numerical/test_autograd_cm_user_vjp_numerical_structures.py @@ -0,0 +1,478 @@ +# tests user_vjp and numerical_structures autograd hooks for ComponentModeler and compares to numerically computed finite difference gradients +from __future__ import annotations + +import operator +import sys + +import autograd as ag +import matplotlib.pylab as plt +import numpy as np +import pytest +import trimesh +import xarray as xr + +import tidy3d as td +from tidy3d.plugins.smatrix import ComponentModeler, Port +from tidy3d.plugins.smatrix.run import _run_local +from tidy3d.web.api.autograd.types import NumericalStructureConfig, UserVJPConfig + +PLOT_FD_ADJ_COMPARISON = True +NUM_FINITE_DIFFERENCE = 10 +SAVE_FD_ADJ_DATA = True +SAVE_FD_LOC = 0 +SAVE_ADJ_LOC = 1 +LOCAL_GRADIENT = True +VERBOSE = False +NUMERICAL_RESULTS_DATA_DIR = "./numerical_cm_user_vjp_numerical_structures_test/" 
def make_base_sim(
    mesh_wvl_um,
    adj_wvl_um,
    monitor_bg_index=1.0,
    run_time=2e-11,
):
    """Build the two-port base simulation and its ports.

    Two identical waveguide boxes are placed at ``-0.35`` and ``+0.35`` of the
    simulation x-extent, and one single-mode ``Port`` is centred on each of
    them. The simulation itself carries no sources and no monitors.

    Parameters
    ----------
    mesh_wvl_um
        Sets the simulation extent (through ``get_sim_geometry``) and the
        wavelength handed to the automatic grid.
    adj_wvl_um
        Scales the waveguide cross-section and the size of the port planes.
    monitor_bg_index
        Accepted for signature compatibility; this function does not read it.
    run_time
        Run time forwarded to ``td.Simulation``.

    Returns
    -------
    tuple
        ``(ports, simulation)`` where ``ports`` is ``[left, right]``.
    """
    sim_geometry = get_sim_geometry(mesh_wvl_um)
    sim_size_um = sim_geometry.size
    sim_center_um = sim_geometry.center

    def _waveguide(x_fraction):
        # Both waveguides share size and medium; only the x offset differs.
        return td.Structure(
            geometry=td.Box(
                center=(x_fraction * sim_size_um[0], sim_center_um[1], sim_center_um[2]),
                size=(0.5 * sim_size_um[0], 0.35 * adj_wvl_um, 0.2 * adj_wvl_um),
            ),
            medium=td.Medium(permittivity=3.5**2),
        )

    def _port(waveguide, direction, name):
        # Port plane is normal to x (zero extent along x), centred on the waveguide.
        return Port(
            center=waveguide.geometry.center,
            size=(0.0, adj_wvl_um, adj_wvl_um),
            mode_spec=td.ModeSpec(num_modes=1),
            direction=direction,
            name=name,
        )

    input_waveguide = _waveguide(-0.35)
    output_waveguide = _waveguide(0.35)

    ports = [
        _port(input_waveguide, "+", "left"),
        _port(output_waveguide, "-", "right"),
    ]

    boundary_spec = td.BoundarySpec(
        x=td.Boundary.pml(),
        y=td.Boundary.pml(),
        z=td.Boundary.pml(),
    )

    simulation = td.Simulation(
        center=sim_center_um,
        size=sim_size_um,
        grid_spec=td.GridSpec.auto(
            min_steps_per_wvl=30,
            # Previously hard-coded to 1.5; follow the mesh wavelength argument.
            wavelength=mesh_wvl_um,
        ),
        boundary_spec=boundary_spec,
        sources=[],
        monitors=[],
        structures=[input_waveguide, output_waveguide],
        # Previously hard-coded to 1e-11, which silently ignored the argument.
        run_time=run_time,
    )

    return ports, simulation
def create_ring(params):
    """Create the ring structure described by ``params``.

    ``params[0]``, ``params[1]`` and ``params[2]`` are passed to
    ``trimesh.creation.annulus`` as ``r_min``, ``r_max`` and ``height``. The
    mesh is then rotated by 90 degrees about the y axis, translated by
    ``-0.65`` along x, and wrapped in a ``td.Structure`` whose medium has
    permittivity ``ADJOINT_PERMITTIVITY``.
    """
    annulus = trimesh.creation.annulus(
        r_min=params[0],
        r_max=params[1],
        height=params[2],
        sections=100,
    )

    # Order matters: rotate first, then shift along x.
    placement = (
        trimesh.transformations.rotation_matrix(np.radians(90), [0, 1, 0]),
        trimesh.transformations.translation_matrix([-0.65, 0, 0]),
    )
    for matrix in placement:
        annulus.apply_transform(matrix)

    ring_medium = td.Medium(permittivity=ADJOINT_PERMITTIVITY)
    return td.Structure(geometry=td.TriangleMesh.from_trimesh(annulus), medium=ring_medium)
def create_objective_function(geometry, create_sim_base, adj_wvl_um, sim_path_dir):
    """Return an objective over lists of geometry parameter vectors.

    Each parameter vector is read as ``[cx, cy, cz, radius, *ring_params]``:
    the first four entries define a sphere appended to the base simulation,
    and the entries from index 4 onward are handed to the ring
    ``NumericalStructureConfig``. ``geometry`` and ``sim_path_dir`` are
    accepted but not read by the returned objective.
    """

    def objective(geom_parameters_lists):
        ports, sim_base = create_sim_base()
        task_names = [
            f"numerical_user_vjp_testing_{idx}" for idx in range(len(geom_parameters_lists))
        ]

        # Phase 1: one simulation (base + sphere) per parameter vector.
        simulation_dict = {}
        geom_dict = {}
        for task_name, geom_parameters in zip(task_names, geom_parameters_lists):
            sphere = td.Sphere(center=geom_parameters[0:3], radius=geom_parameters[3])
            sphere_structure = td.Structure(
                geometry=sphere,
                medium=td.Medium(permittivity=ADJOINT_PERMITTIVITY),
            )
            with_sphere = sim_base.updated_copy(
                structures=(*sim_base.structures, sphere_structure)
            )
            simulation_dict[task_name] = with_sphere.copy()
            geom_dict[task_name] = geom_parameters

        # Phase 2: run every modeler with both custom-gradient hooks attached.
        sim_data = {}
        for task_name, simulation in simulation_dict.items():
            modeler = ComponentModeler(
                simulation=simulation,
                ports=ports,
                freqs=[td.C_0 / adj_wvl_um],
            )
            ring_config = NumericalStructureConfig(
                create=create_ring,
                compute_derivatives=vjp_ring,
                parameters=geom_dict[task_name][4:],
                structure_index=0,
            )
            sphere_config = UserVJPConfig(
                structure_index=3,
                compute_derivatives=vjp_sphere,
            )
            sim_data[task_name] = _run_local(
                modeler,
                local_gradient=LOCAL_GRADIENT,
                verbose=VERBOSE,
                user_vjp=sphere_config,
                numerical_structures=ring_config,
            )

        # Phase 3: objective is the summed squared magnitude of the S-matrix.
        objective_vals = [
            np.sum(np.abs(sim_data[task_name].smatrix().values) ** 2) for task_name in task_names
        ]

        # A single parameter vector yields a scalar so autograd can differentiate it.
        if len(geom_parameters_lists) == 1:
            return objective_vals[0]
        return objective_vals

    return objective
@pytest.mark.numerical
@pytest.mark.parametrize(
    "test_parameters, dir_name",
    zip(
        test_parameters,
        ([NUMERICAL_RESULTS_DATA_DIR] if SAVE_FD_ADJ_DATA else [None]) * len(test_parameters),
    ),
    indirect=["dir_name"],
)
def test_finite_difference_user_vjp(test_parameters, rng, tmp_path, create_directory):
    """Check hook-based adjoint gradients against central finite differences.

    The objective comes from ``create_objective_function`` and is
    differentiated with respect to seven parameters: sphere centre (x, y, z),
    sphere radius, and the three ring parameters. The adjoint gradient from
    ``autograd.value_and_grad`` must lie within
    ``OVERLAP_ERROR_THRESHOLD_DEG`` degrees of the finite-difference gradient.

    ``rng``, ``tmp_path``, ``create_directory`` and the indirect ``dir_name``
    are pytest fixtures defined outside this module.
    """
    mesh_wvl_um, adj_wvl_um, monitor_bg_index, test_number = operator.itemgetter(
        "mesh_wvl_um",
        "adj_wvl_um",
        "monitor_bg_index",
        "test_number",
    )(test_parameters)

    sim_geometry = get_sim_geometry(mesh_wvl_um)

    # Forwarded as the `geometry` argument of create_objective_function.
    block = td.Box(
        center=(sim_geometry.center[0], sim_geometry.center[1], 0),
        size=(mesh_wvl_um, mesh_wvl_um, 0.5 * mesh_wvl_um),
    )

    sim_path_dir = tmp_path / f"test{test_number}"
    sim_path_dir.mkdir()

    objective = create_objective_function(
        block,
        lambda: make_base_sim(
            mesh_wvl_um=mesh_wvl_um,
            adj_wvl_um=adj_wvl_um,
            monitor_bg_index=monitor_bg_index,
        ),
        adj_wvl_um,
        sim_path_dir=str(sim_path_dir),
    )

    obj_val_and_grad = ag.value_and_grad(objective)

    # Sphere: random in-plane offset, z fixed at 0, random radius.
    sphere_init = [
        *rng.uniform(
            low=-SPHERE_OFFSET_MAX_MESH_WVL_FACTOR * mesh_wvl_um,
            high=SPHERE_OFFSET_MAX_MESH_WVL_FACTOR * mesh_wvl_um,
            size=2,
        ),
        0.0,
        *rng.uniform(
            low=SPHERE_MIN_RADIUS_MESH_WVL_FACTOR * mesh_wvl_um,
            high=SPHERE_MAX_RADIUS_MESH_WVL_FACTOR * mesh_wvl_um,
            size=1,
        ),
    ]

    # Ring: fixed (r_min, r_max, height) fractions of the mesh wavelength.
    ring_init = [factor * mesh_wvl_um for factor in (0.15, 0.30, 0.2)]

    geom_init = sphere_init + ring_init
    num_params = len(geom_init)

    _, adj_grad = obj_val_and_grad([geom_init])
    adj_grad = np.squeeze(np.array(adj_grad))

    # empirical step size for finite difference calculation
    fd_step = FD_STEP_MESH_WVL_FACTOR * mesh_wvl_um

    # Interleave (+step, -step) variants so one batched objective call covers them all.
    all_params = []
    for fd_idx in range(num_params):
        geom_up = geom_init.copy()
        geom_down = geom_init.copy()
        geom_up[fd_idx] += fd_step
        geom_down[fd_idx] -= fd_step
        all_params.extend((geom_up, geom_down))

    all_obj = np.asarray(objective(all_params), dtype=float)
    fd_grad = (all_obj[0::2] - all_obj[1::2]) / (2 * fd_step)

    rms_error = np.linalg.norm(fd_grad - adj_grad)
    fd_mag = np.linalg.norm(fd_grad)
    adj_mag = np.linalg.norm(adj_grad)

    # A zero gradient would turn the normalisation below into NaN; fail with a clear message.
    assert fd_mag > 0 and adj_mag > 0, "Gradient magnitude is zero; overlap angle is undefined."

    # Round-off can push the normalised dot product just outside [-1, 1], where
    # arccos returns NaN and the comparison below would fail spuriously.
    dot = np.clip(np.sum((fd_grad / fd_mag) * (adj_grad / adj_mag)), -1.0, 1.0)
    overlap_deg = np.arccos(dot) * 180.0 / np.pi

    print("\n" * 3)
    print("-" * 20)
    print(f"Numerical test #{test_number}")
    print(f"Mesh and adjoint wavelengths: {mesh_wvl_um}, {adj_wvl_um}")
    print(f"Background index for monitor: {monitor_bg_index}")
    print(f"RMS Error: {rms_error}")
    print(f"Gradient overlap (deg): {overlap_deg}")
    print(f"FD, Adj magnitudes: {fd_mag}, {adj_mag}")
    print("-" * 20)
    print("\n" * 3)

    assert overlap_deg < OVERLAP_ERROR_THRESHOLD_DEG, (
        "Adjoint and finite difference gradients misaligned."
    )

    test_results = np.zeros((2, num_params))
    test_results[SAVE_FD_LOC, :] = fd_grad
    test_results[SAVE_ADJ_LOC, :] = adj_grad

    if PLOT_FD_ADJ_COMPARISON:
        plt.plot(adj_grad, color="g", linewidth=2.0)
        plt.plot(fd_grad, color="b", linewidth=1.5, linestyle="--")
        plt.legend(["Adjoint", "Finite difference"])
        plt.xlabel("Sample number")
        plt.ylabel("Gradient value")
        plt.show()

    if SAVE_FD_ADJ_DATA:
        # test_number is no longer incremented first, so the file index matches
        # the "Numerical test #" printed above.
        np.save(f"{NUMERICAL_RESULTS_DATA_DIR}/results_{test_number}.npy", test_results)
def get_sim_geometry(mesh_wvl_um):
    """Return the origin-centred ``td.Box`` used as the simulation domain.

    The first two extents are ``SIMULATION_SIZE_MESH_WVL_FACTOR * mesh_wvl_um``
    and the third is ``SIMULATION_HEIGHT_WVL_FACTOR * mesh_wvl_um``.
    """
    lateral_extent = SIMULATION_SIZE_MESH_WVL_FACTOR * mesh_wvl_um
    vertical_extent = SIMULATION_HEIGHT_WVL_FACTOR * mesh_wvl_um
    domain_size = (lateral_extent, lateral_extent, vertical_extent)
    return td.Box(size=domain_size, center=(0, 0, 0))
def create_objective_function(geometry, create_sim_base, eval_fn, run_fn, sim_path_dir):
    """Return an objective over lists of ring parameter vectors.

    Parameters
    ----------
    geometry
        Accepted but not read by the returned objective.
    create_sim_base
        Zero-argument callable returning the base simulation.
    eval_fn
        Maps one simulation result to a scalar objective value.
    run_fn
        ``"run_custom"`` (one call per simulation) or ``"run_async_custom"``
        (one batched call).
    sim_path_dir
        Forwarded as ``path_dir`` to ``run_async_custom`` only.

    Returns
    -------
    callable
        ``objective(ring_parameters_lists)``; returns a scalar when given one
        parameter vector, otherwise a list with one value per vector.
    """

    def objective(ring_parameters_lists):
        sim_base = create_sim_base()

        task_names = [
            f"numerical_numerical_structures_testing_{idx}"
            for idx in range(len(ring_parameters_lists))
        ]
        simulation_dict = {task_name: sim_base.copy() for task_name in task_names}

        assert run_fn in ("run_custom", "run_async_custom"), "Unrecognized run function!"

        # One ring config per task, shared by both run paths.
        numerical_structures_dict = {
            task_name: NumericalStructureConfig(
                create=create_ring,
                compute_derivatives=vjp_ring,
                parameters=ring_parameters,
                structure_index=0,
            )
            for task_name, ring_parameters in zip(task_names, ring_parameters_lists)
        }

        if run_fn == "run_custom":
            sim_data = {
                task_name: run_custom(
                    simulation,
                    local_gradient=LOCAL_GRADIENT,
                    verbose=VERBOSE,
                    numerical_structures=numerical_structures_dict[task_name],
                )
                for task_name, simulation in simulation_dict.items()
            }
        else:
            # The unused per-task user_vjp dictionary built here before was never
            # passed to run_async_custom and has been removed.
            sim_data = run_async_custom(
                simulation_dict,
                path_dir=sim_path_dir,
                local_gradient=LOCAL_GRADIENT,
                verbose=VERBOSE,
                numerical_structures=numerical_structures_dict,
            )

        objective_vals = [eval_fn(sim_data[task_name]) for task_name in task_names]

        # A single parameter vector yields a scalar so autograd can differentiate it.
        if len(ring_parameters_lists) == 1:
            return objective_vals[0]
        return objective_vals

    return objective
@pytest.mark.numerical
@pytest.mark.parametrize(
    "test_parameters, dir_name",
    zip(
        test_parameters,
        ([NUMERICAL_RESULTS_DATA_DIR] if SAVE_FD_ADJ_DATA else [None]) * len(test_parameters),
    ),
    indirect=["dir_name"],
)
def test_finite_difference_numerical_structures(test_parameters, rng, tmp_path, create_directory):
    """Check ring numerical-structure adjoint gradients against finite differences.

    The objective comes from ``create_objective_function`` for the selected
    ``run_fn`` and ``eval_fn`` and is differentiated with respect to the three
    ring parameters. The adjoint gradient from ``autograd.value_and_grad`` must
    lie within ``OVERLAP_ERROR_THRESHOLD_DEG`` degrees of the central
    finite-difference gradient.

    ``rng``, ``tmp_path``, ``create_directory`` and the indirect ``dir_name``
    are pytest fixtures defined outside this module.
    """
    (
        mesh_wvl_um,
        adj_wvl_um,
        monitor_bg_index,
        pw_angle_deg,
        eval_fn,
        eval_fn_name,
        run_fn,
        test_number,
    ) = operator.itemgetter(
        "mesh_wvl_um",
        "adj_wvl_um",
        "monitor_bg_index",
        "pw_angle_deg",
        "eval_fn",
        "eval_fn_name",
        "run_fn",
        "test_number",
    )(test_parameters)

    sim_geometry = get_sim_geometry(mesh_wvl_um)

    # Forwarded as the `geometry` argument of create_objective_function.
    block = td.Box(
        center=(sim_geometry.center[0], sim_geometry.center[1], 0),
        size=(mesh_wvl_um, mesh_wvl_um, 0.5 * mesh_wvl_um),
    )

    sim_path_dir = tmp_path / f"test{test_number}"
    sim_path_dir.mkdir()

    objective = create_objective_function(
        block,
        lambda: make_base_sim(
            mesh_wvl_um=mesh_wvl_um,
            adj_wvl_um=adj_wvl_um,
            pw_angle_deg=pw_angle_deg,
            monitor_bg_index=monitor_bg_index,
        ),
        eval_fn,
        run_fn,
        sim_path_dir=str(sim_path_dir),
    )

    obj_val_and_grad = ag.value_and_grad(objective)

    # Ring: fixed (r_min, r_max, height) fractions of the mesh wavelength.
    ring_init = [factor * mesh_wvl_um for factor in (0.15, 0.30, 0.2)]
    num_params = len(ring_init)

    _, adj_grad = obj_val_and_grad([ring_init])
    adj_grad = np.squeeze(np.array(adj_grad))

    # empirical step size from running other finite difference tests for field
    # cases with permittivity
    fd_step = FD_STEP_MESH_WVL_FACTOR * mesh_wvl_um

    # Interleave (+step, -step) variants so one batched objective call covers them all.
    all_rings = []
    for fd_idx in range(num_params):
        ring_up = ring_init.copy()
        ring_down = ring_init.copy()
        ring_up[fd_idx] += fd_step
        ring_down[fd_idx] -= fd_step
        all_rings.extend((ring_up, ring_down))

    all_obj = np.asarray(objective(all_rings), dtype=float)
    fd_grad = (all_obj[0::2] - all_obj[1::2]) / (2 * fd_step)

    rms_error = np.linalg.norm(fd_grad - adj_grad)
    fd_mag = np.linalg.norm(fd_grad)
    adj_mag = np.linalg.norm(adj_grad)

    # A zero gradient would turn the normalisation below into NaN; fail with a clear message.
    assert fd_mag > 0 and adj_mag > 0, "Gradient magnitude is zero; overlap angle is undefined."

    # Round-off can push the normalised dot product just outside [-1, 1], where
    # arccos returns NaN and the comparison below would fail spuriously.
    dot = np.clip(np.sum((fd_grad / fd_mag) * (adj_grad / adj_mag)), -1.0, 1.0)
    overlap_deg = np.arccos(dot) * 180.0 / np.pi

    print("\n" * 3)
    print("-" * 20)
    print(f"Numerical test #{test_number}")
    print(f"Mesh and adjoint wavelengths: {mesh_wvl_um}, {adj_wvl_um}")
    print(f"Input plane wave angle (deg): {pw_angle_deg}")
    print(f"Background index for monitor: {monitor_bg_index}")
    print(f"Eval function: {eval_fn_name}")
    print(f"RMS Error: {rms_error}")
    print(f"Gradient overlap (deg): {overlap_deg}")
    print(f"FD, Adj magnitudes: {fd_mag}, {adj_mag}")
    print("-" * 20)
    print("\n" * 3)

    assert overlap_deg < OVERLAP_ERROR_THRESHOLD_DEG, (
        "Adjoint and finite difference gradients misaligned."
    )

    test_results = np.zeros((2, num_params))
    test_results[SAVE_FD_LOC, :] = fd_grad
    test_results[SAVE_ADJ_LOC, :] = adj_grad

    if PLOT_FD_ADJ_COMPARISON:
        plt.plot(adj_grad, color="g", linewidth=2.0)
        plt.plot(fd_grad, color="b", linewidth=1.5, linestyle="--")
        plt.title(f"Gradient for objective: {eval_fn_name}")
        plt.legend(["Adjoint", "Finite difference"])
        plt.xlabel("Sample number")
        plt.ylabel("Gradient value")
        plt.show()

    if SAVE_FD_ADJ_DATA:
        # test_number is no longer incremented first, so the file index matches
        # the "Numerical test #" printed above.
        np.save(f"{NUMERICAL_RESULTS_DATA_DIR}/results_{test_number}.npy", test_results)
def get_sim_geometry(mesh_wvl_um):
    """Build the simulation-domain box, centred on the origin.

    Both in-plane extents scale ``mesh_wvl_um`` by
    ``SIMULATION_SIZE_MESH_WVL_FACTOR``; the third extent scales it by
    ``SIMULATION_HEIGHT_WVL_FACTOR``.
    """
    in_plane = SIMULATION_SIZE_MESH_WVL_FACTOR * mesh_wvl_um
    out_of_plane = SIMULATION_HEIGHT_WVL_FACTOR * mesh_wvl_um
    return td.Box(center=(0, 0, 0), size=(in_plane, in_plane, out_of_plane))
def vjp_sphere(sphere, derivative_info):
    """Compute VJPs for a sphere's radius and centre by central differences.

    For each requested path the sphere is perturbed by ``+/- step_size``, the
    permittivity is re-evaluated with ``derivative_info.updated_epsilon``, and
    the resulting permittivity difference is contracted with the
    ``("permittivity",)`` derivative of a unit ``td.CustomMedium``.

    Parameters
    ----------
    sphere
        Object exposing ``radius``, ``center`` and ``updated_copy``.
    derivative_info
        Supplies ``frequencies``, ``paths``, ``updated_epsilon`` and
        ``updated_copy``.

    Returns
    -------
    dict
        Maps each handled path to its gradient. ``("geometry", "radius")`` and
        ``("geometry", "center", i)`` map to a scalar; ``("geometry", "center")``
        maps to a list of three scalars. Other paths are left out of the result.
    """
    # Perturbation is 1/20 of the shortest wavelength in the derivative info.
    min_wvl = td.C_0 / np.max(derivative_info.frequencies)
    step_size = min_wvl / 20.0

    # Same for every path, so build it once instead of once per perturbation.
    derivative_info_custom_medium = derivative_info.updated_copy(
        paths=[("permittivity",)],
        deep=False,
    )

    def finite_difference_gradient(perturb_up, perturb_down):
        """Contract d(eps)/d(param) with the custom-medium permittivity VJP."""
        eps_up = derivative_info.updated_epsilon(perturb_up)
        eps_down = derivative_info.updated_epsilon(perturb_down)
        eps_grad = (eps_up - eps_down) / (2 * step_size)

        # Unit-permittivity medium on the eps grid, frequency axis dropped.
        custom_medium = td.CustomMedium(permittivity=xr.ones_like(eps_grad.isel(f=0, drop=True)))
        vjps_custom_medium = custom_medium._compute_derivatives(derivative_info_custom_medium)

        return np.real(np.sum(eps_grad.sum("f").data * vjps_custom_medium[("permittivity",)]))

    def radius_gradient():
        return finite_difference_gradient(
            sphere.updated_copy(radius=sphere.radius + step_size),
            sphere.updated_copy(radius=sphere.radius - step_size),
        )

    def center_gradient(center_index):
        center_up = list(sphere.center)
        center_down = list(sphere.center)
        center_up[center_index] += step_size
        center_down[center_index] -= step_size
        return finite_difference_gradient(
            sphere.updated_copy(center=center_up),
            sphere.updated_copy(center=center_down),
        )

    vjps = {}
    for path in derivative_info.paths:
        if path[0:2] == ("geometry", "radius"):
            vjps[path] = radius_gradient()
        elif path[0:2] == ("geometry", "center"):
            if len(path) == 2:
                # Whole-centre path: one gradient per axis.
                vjps[path] = [center_gradient(axis) for axis in (0, 1, 2)]
            else:
                # Single-component path: ("geometry", "center", index).
                _, center_index = path[1:]
                vjps[path] = center_gradient(center_index)

    return vjps
+ ) + # the run functions are not imported at module level, so bring them into scope here + from tidy3d.web.api.autograd.autograd import run_async_custom, run_custom + if run_fn == "run_custom": + sim_data = {} + for key, sim_val in simulation_dict.items(): + sim_data[key] = run_custom( + sim_val, + local_gradient=LOCAL_GRADIENT, + verbose=VERBOSE, + user_vjp=user_vjp_single, + ) + elif run_fn == "run_async_custom": + sim_data = run_async_custom( + simulation_dict, + path_dir=sim_path_dir, + local_gradient=LOCAL_GRADIENT, + verbose=VERBOSE, + user_vjp=user_vjp_single, + ) + + objective_vals = [] + for idx in range(len(sphere_parameters_lists)): + objective_vals.append(eval_fn(sim_data[f"numerical_user_vjp_testing_{idx}"])) + + if len(sphere_parameters_lists) == 1: + return objective_vals[0] + + return objective_vals + + return objective + + +def make_eval_fns(): + def transmission(sim_data): + return np.sum(np.abs(sim_data["monitor_fields"].flux.data) ** 2) + + eval_fns = [transmission] + eval_fn_names = ["transmission"] + + return eval_fns, eval_fn_names + + +background_indices = [1.0] +mesh_wvls_um = [1.5] +adj_wvls_um = [1.5] + +orders_x = [(1,)] +orders_y = [(0,)] +polarizations = ["p"] + +pw_angles_deg = [0.0] + +run_functions = ["run_custom", "run_async_custom"] + +test_parameters = [] + +test_number = 0 +for idx in range(len(mesh_wvls_um)): + mesh_wvl_um = mesh_wvls_um[idx] + adj_wvl_um = adj_wvls_um[idx] + + eval_fns, eval_fn_names = make_eval_fns() + + for pw_angle_deg in pw_angles_deg: + for monitor_bg_index in background_indices: + for eval_fn_idx, eval_fn in enumerate(eval_fns): + for run_fn in run_functions: + test_parameters.append( + { + "mesh_wvl_um": mesh_wvl_um, + "adj_wvl_um": adj_wvl_um, + "monitor_bg_index": monitor_bg_index, + "pw_angle_deg": pw_angle_deg, + "eval_fn": eval_fn, + "eval_fn_name": eval_fn_names[eval_fn_idx], + "run_fn": run_fn, + "test_number": test_number, + } + ) + + test_number += 1 + + +@pytest.mark.numerical +@pytest.mark.parametrize( + "test_parameters, dir_name", + zip( + test_parameters, + ([NUMERICAL_RESULTS_DATA_DIR] if SAVE_FD_ADJ_DATA else [None]) * 
len(test_parameters), + ), + indirect=["dir_name"], +) +def test_finite_difference_user_vjp(test_parameters, rng, tmp_path, create_directory): + """Test user_vjp gradients for a sphere's center and radius by comparing them to + numerical finite difference gradients.""" + + ( + mesh_wvl_um, + adj_wvl_um, + monitor_bg_index, + pw_angle_deg, + eval_fn, + eval_fn_name, + run_fn, + test_number, + ) = operator.itemgetter( + "mesh_wvl_um", + "adj_wvl_um", + "monitor_bg_index", + "pw_angle_deg", + "eval_fn", + "eval_fn_name", + "run_fn", + "test_number", + )(test_parameters) + + sim_geometry = get_sim_geometry(mesh_wvl_um) + + dim_um = mesh_wvl_um + thickness_um = 0.5 * mesh_wvl_um + block = td.Box( + center=(sim_geometry.center[0], sim_geometry.center[1], 0), + size=(dim_um, dim_um, thickness_um), + ) + + sim_path_dir = tmp_path / f"test{test_number}" + sim_path_dir.mkdir() + + # make_base_sim accepts no plane-wave angle argument, so pw_angle_deg is kept only as + # test metadata and is not forwarded to the simulation factory + objective = create_objective_function( + block, + lambda mesh_wvl_um=mesh_wvl_um, + adj_wvl_um=adj_wvl_um, + monitor_bg_index=monitor_bg_index: make_base_sim( + mesh_wvl_um=mesh_wvl_um, + adj_wvl_um=adj_wvl_um, + monitor_bg_index=monitor_bg_index, + ), + eval_fn, + run_fn, + sim_path_dir=str(sim_path_dir), + ) + + obj_val_and_grad = ag.value_and_grad(objective) + + sphere_init = [ + *rng.uniform( + low=-SPHERE_OFFSET_MAX_MESH_WVL_FACTOR * mesh_wvl_um, + high=SPHERE_OFFSET_MAX_MESH_WVL_FACTOR * mesh_wvl_um, + size=2, + ), + 0.0, + *rng.uniform( + low=SPHERE_MIN_RADIUS_MESH_WVL_FACTOR * mesh_wvl_um, + high=SPHERE_MAX_RADIUS_MESH_WVL_FACTOR * mesh_wvl_um, + size=1, + ), + ] + + test_results = np.zeros((2, len(sphere_init))) + + obj, adj_grad = obj_val_and_grad([sphere_init]) + adj_grad = np.squeeze(np.array(adj_grad)) + + # empirical step size from running other finite difference tests for field + # cases with permittivity + fd_step = FD_STEP_MESH_WVL_FACTOR * mesh_wvl_um + + all_spheres = [] + # pattern_dot_adj_gradient = 
np.zeros(len(sphere_init)) + + for fd_idx in range(len(sphere_init)): + sphere_up = sphere_init.copy() + sphere_down = sphere_init.copy() + + sphere_up[fd_idx] += fd_step + sphere_down[fd_idx] -= fd_step + + all_spheres.append(sphere_up) + all_spheres.append(sphere_down) + + all_obj = objective(all_spheres) + + fd_grad = np.zeros(len(sphere_init)) + for fd_idx in range(len(sphere_init)): + obj_up_location = 2 * fd_idx + obj_down_location = 2 * fd_idx + 1 + + fd_grad[fd_idx] = (all_obj[obj_up_location] - all_obj[obj_down_location]) / (2 * fd_step) + + rms_error = np.linalg.norm(fd_grad - adj_grad) + fd_mag = np.linalg.norm(fd_grad) + adj_mag = np.linalg.norm(adj_grad) + + dot = np.sum((fd_grad / fd_mag) * (adj_grad / adj_mag)) + overlap_deg = np.arccos(dot) * 180.0 / np.pi + + print("\n" * 3) + print("-" * 20) + print(f"Numerical test #{test_number}") + print(f"Mesh and adjoint wavelengths: {mesh_wvl_um}, {adj_wvl_um}") + print(f"Input plane wave angle (deg): {pw_angle_deg}") + print(f"Background index for monitor: {monitor_bg_index}") + print(f"Eval function: {eval_fn_name}") + print(f"RMS Error: {rms_error}") + print(f"Gradient overlap (deg): {overlap_deg}") + print(f"FD, Adj magnitudes: {fd_mag}, {adj_mag}") + print("-" * 20) + print("\n" * 3) + + assert overlap_deg < OVERLAP_ERROR_THRESHOLD_DEG, ( + "Adjoint and finite difference gradients misaligned." 
+ ) + + test_results[SAVE_FD_LOC, :] = fd_grad + test_results[SAVE_ADJ_LOC, :] = adj_grad + + test_number += 1 + + if PLOT_FD_ADJ_COMPARISON: + plt.plot(adj_grad, color="g", linewidth=2.0) + plt.plot(fd_grad, color="b", linewidth=1.5, linestyle="--") + plt.title(f"Gradient for objective: {eval_fn_name}") + plt.legend(["Adjoint", "Finite difference"]) + plt.xlabel("Sample number") + plt.ylabel("Gradient value") + plt.show() + + if SAVE_FD_ADJ_DATA: + np.save(f"{NUMERICAL_RESULTS_DATA_DIR}/results_{test_number}.npy", test_results) diff --git a/tests/test_components/autograd/test_autograd.py b/tests/test_components/autograd/test_autograd.py index 8621eaded7..04dec5297f 100644 --- a/tests/test_components/autograd/test_autograd.py +++ b/tests/test_components/autograd/test_autograd.py @@ -20,6 +20,7 @@ import tidy3d as td import tidy3d.web as web +from tidy3d.components.autograd import get_static from tidy3d.components.autograd.derivative_utils import DerivativeInfo from tidy3d.components.autograd.field_map import FieldMap from tidy3d.components.autograd.utils import is_tidy_box @@ -28,8 +29,12 @@ from tidy3d.config import config from tidy3d.exceptions import AdjointError from tidy3d.plugins.polyslab import ComplexPolySlab +from tidy3d.plugins.smatrix import ComponentModeler, Port +from tidy3d.plugins.smatrix.run import _run_local from tidy3d.web import run, run_async from tidy3d.web.api.autograd import autograd as autograd_module +from tidy3d.web.api.autograd.autograd import run_async_custom, run_custom +from tidy3d.web.api.autograd.types import NumericalStructureConfig, UserVJPConfig from ...utils import SIM_FULL, AssertLogLevel, run_emulated, tracer_arr @@ -101,6 +106,7 @@ def _make_di(paths, freq): ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) @@ -116,6 +122,7 @@ def _make_di(paths, freq): IS_3D = False POLYSLAB_AXIS = 2 +POLYSLAB_SELECT_VERTICES = 0 # angle of the measurement waveguide 
ROT_ANGLE_WG = 0 * np.pi / 4 @@ -239,7 +246,6 @@ def emulated_run_fwd(simulation, task_name, **run_kwargs) -> td.SimulationData: def emulated_run_bwd(simulation, task_name, **run_kwargs) -> td.SimulationData: """What gets called instead of ``web/api/autograd/autograd.py::_run_tidy3d_bwd``.""" - task_name_fwd = "".join(task_name.partition("_adjoint")[:-2]) # run the adjoint sim @@ -259,6 +265,8 @@ def emulated_run_bwd(simulation, task_name, **run_kwargs) -> td.SimulationData: sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_keys=sim_fields_keys, + numerical_structures=None, + user_vjp=None, ) return traced_fields_vjp @@ -266,6 +274,7 @@ def emulated_run_bwd(simulation, task_name, **run_kwargs) -> td.SimulationData: def emulated_run_async_fwd(simulations, **run_kwargs) -> td.SimulationData: batch_data_orig, task_ids_fwd = {}, {} sim_fields_keys_dict = run_kwargs.pop("sim_fields_keys_dict", None) + for task_name, simulation in simulations.items(): if sim_fields_keys_dict is not None: run_kwargs["sim_fields_keys"] = sim_fields_keys_dict[task_name] @@ -306,7 +315,9 @@ def emulated_run_async_bwd(simulations, **run_kwargs) -> td.SimulationData: return emulated_run_fwd, emulated_run_bwd -def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: +def make_structures( + params: anp.ndarray, polyslab_axis: int = POLYSLAB_AXIS +) -> dict[str, td.Structure]: """Make a dictionary of the structures given the parameters.""" np.random.seed(0) @@ -406,7 +417,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: matrix = np.random.random((N_PARAMS,)) - 0.5 params_01 = 0.5 * (anp.tanh(matrix @ params / 3) + 1) - free_param = "vertices" if POLYSLAB_AXIS == 0 else "slab_bounds" + free_param = "vertices" if polyslab_axis == POLYSLAB_SELECT_VERTICES else "slab_bounds" if free_param == "vertices": radii = 0.5 + 0.5 * params_01 @@ -415,8 +426,6 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: radii = 1.0 shift = 0.1 * 
params_01 slab_bounds = (-0.5 + shift, 0.5 + shift) - # slab_bounds = (-0.5 + shift, 0.5) - # slab_bounds = (-0.5, 0.5 + shift) phis = 2 * anp.pi * anp.linspace(0, 1, NUM_VERTICES + 1)[:NUM_VERTICES] xs = radii * anp.cos(phis) @@ -427,7 +436,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: geometry=td.PolySlab( vertices=vertices, slab_bounds=slab_bounds, - axis=POLYSLAB_AXIS, + axis=polyslab_axis, sidewall_angle=0.00, dilation=0.00, ), @@ -438,7 +447,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: geometry=td.PolySlab( vertices=vertices, slab_bounds=slab_bounds, - axis=POLYSLAB_AXIS, + axis=polyslab_axis, sidewall_angle=0.00, dilation=0.00, ), @@ -658,9 +667,6 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None: args = [("polyslab", "mode")] -# args = [("polyslab", "mode")] - - def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Callable]: if structure_key == ALL_KEY: structure_keys = structure_keys_ @@ -681,10 +687,10 @@ def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Call monitors.append(monitor_traced) monitor_pp_fns[monitor_key] = monitor_pp_fn - def make_sim(*args) -> td.Simulation: + def make_sim(*args, polyslab_axis=POLYSLAB_AXIS) -> td.Simulation: """Make the simulation with all of the fields.""" - structures_traced_dict = make_structures(*args) + structures_traced_dict = make_structures(*args, polyslab_axis=polyslab_axis) structures = list(SIM_BASE.structures) for structure_key in structure_keys: @@ -727,6 +733,581 @@ def test_polyslab_axis_ops(axis): basis_vecs = p.edge_basis_vectors(edges=edges) +def make_polyslab_user_vjp(user_vjp_val): + def polyslab_user_vjp(polyslab, derivative_info): + vjps = {} + + # should there only be one path here since that is how user_vjp is specified? 
+ for path in derivative_info.paths: + # print(f'working on path = {path}') + if path[0:2] == ("geometry", "vertices"): + vjps[path] = user_vjp_val * np.ones(polyslab.vertices.shape) + elif path[0:2] == ("geometry", "slab_bounds"): + if len(path) == 3: + vjps[path] = (user_vjp_val, user_vjp_val)[path[2]] + else: + vjps[path] = (user_vjp_val, user_vjp_val) + + return vjps + + return polyslab_user_vjp + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("polyslab_axis", [0]) # , 1, 2]) +@pytest.mark.parametrize("use_run_async", [False]) # [True, False]) +@pytest.mark.parametrize("use_task_names", [True, False]) +@pytest.mark.parametrize("use_single_user_vjp", [True, False]) +@pytest.mark.parametrize("local_gradient", [True]) # , False]) +def test_autograd_user_vjp( + use_emulated_run, + structure_key, + monitor_key, + polyslab_axis, + use_run_async, + use_task_names, + use_single_user_vjp, + local_gradient, +): + """Test that we can override a vjp with a user defined function.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + task_names = {"test_a", "adjoint", "_test"} + + def make_objective(user_vjp_val): + polyslab_user_vjp = make_polyslab_user_vjp(user_vjp_val) + + user_vjp_tuple = ( + UserVJPConfig( + structure_index=1, + compute_derivatives=polyslab_user_vjp, + path_key=( + ( + "geometry", + "vertices", + ) + ), + ), + UserVJPConfig( + structure_index=1, + compute_derivatives=polyslab_user_vjp, + path_key=( + ( + "geometry", + "slab_bounds", + ) + ), + ), + ) + + user_vjp_single = UserVJPConfig( + structure_index=1, + compute_derivatives=polyslab_user_vjp, + ) + + user_vjp_element = user_vjp_single if use_single_user_vjp else user_vjp_tuple + + def objective(*args): + if use_task_names: + sims = { + task_name: make_sim(*args, polyslab_axis=polyslab_axis) + for task_name in task_names + } + user_vjp = dict.fromkeys(sims.keys(), 
user_vjp_element) + else: + sims = [make_sim(*args, polyslab_axis=polyslab_axis)] * len(task_names) + user_vjp = [user_vjp_element] * len(task_names) + batch_data = {} + if use_run_async: + batch_data = run_async_custom( + sims, user_vjp=user_vjp, local_gradient=local_gradient + ) + else: + if use_task_names: + for task_name, sim in sims.items(): + batch_data[task_name] = run_custom( + sim, + task_name, + user_vjp=user_vjp[task_name], + local_gradient=local_gradient, + ) + else: + for idx, sim in enumerate(sims): + batch_data[idx] = run_custom( + sim, user_vjp=user_vjp[idx], local_gradient=local_gradient + ) + + value = 0.0 + + for _, sim_data in batch_data.items(): + value += postprocess(sim_data) + return value + + return objective + + user_vjp_val = 1.0 + user_vjp_val_scale = 10.0 * user_vjp_val + + if not local_gradient: + with pytest.raises( + td.exceptions.AdjointError, + match="user_vjp specified for a remote gradient not supported.", + ): + val, grad = ag.value_and_grad(make_objective(user_vjp_val))(params0) + else: + val, grad = ag.value_and_grad(make_objective(user_vjp_val))(params0) + val_scale, grad_scale = ag.value_and_grad(make_objective(user_vjp_val_scale))(params0) + + assert np.isclose( + np.sum(np.abs(grad * (user_vjp_val_scale / user_vjp_val) - grad_scale)), 0.0 + ), "Gradients were not set by the user vjp" + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("polyslab_axis", [0, 1, 2]) +@pytest.mark.parametrize("use_run_async", [True, False]) +@pytest.mark.parametrize("use_task_names", [True, False]) +@pytest.mark.parametrize("use_single_user_vjp", [True, False]) +def test_autograd_user_vjp_selective( + use_emulated_run, + structure_key, + monitor_key, + polyslab_axis, + use_run_async, + use_task_names, + use_single_user_vjp, +): + """Test that we can selectively override a vjp with a user defined function that covers some of, but not all, gradient keys.""" + + fn_dict = 
get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + task_names = {"test_a", "adjoint", "_test"} + + def make_objective(user_vjp_val): + polyslab_user_vjp = make_polyslab_user_vjp(user_vjp_val) + + user_vjp_tuple = ( + UserVJPConfig( + structure_index=1, + compute_derivatives=polyslab_user_vjp, + path_key=( + ( + "geometry", + "vertices", + ) + ), + ), + ) + + user_vjp_single = UserVJPConfig( + structure_index=1, + compute_derivatives=polyslab_user_vjp, + path_key=( + ( + "geometry", + "vertices", + ) + ), + ) + + user_vjp_element = user_vjp_single if use_single_user_vjp else user_vjp_tuple + + def objective(*args): + if use_task_names: + sims = { + task_name: make_sim(*args, polyslab_axis=polyslab_axis) + for task_name in task_names + } + user_vjp = dict.fromkeys(task_names, user_vjp_element) + else: + sims = [make_sim(*args, polyslab_axis=polyslab_axis)] * len(task_names) + user_vjp = [user_vjp_element] * len(task_names) + + batch_data = {} + if use_run_async: + batch_data = run_async_custom(sims, user_vjp=user_vjp, local_gradient=True) + else: + if use_task_names: + for task_name, sim in sims.items(): + batch_data[task_name] = run_custom( + sim, task_name, user_vjp=user_vjp[task_name], local_gradient=True + ) + else: + for idx, sim in enumerate(sims): + batch_data[idx] = run_custom( + sim, user_vjp=user_vjp[idx], local_gradient=True + ) + + value = 0.0 + + for _, sim_data in batch_data.items(): + value += postprocess(sim_data) + return value + + return objective + + user_vjp_val = 1.0 + user_vjp_val_scale = 10.0 * user_vjp_val + + val, grad = ag.value_and_grad(make_objective(user_vjp_val))(params0) + val_scale, grad_scale = ag.value_and_grad(make_objective(user_vjp_val_scale))(params0) + + if polyslab_axis == POLYSLAB_SELECT_VERTICES: + assert np.isclose( + np.sum(np.abs(grad * (user_vjp_val_scale / user_vjp_val) - grad_scale)), 0.0 + ), "Gradients were not set by the user vjp when they should have 
been" + else: + assert not np.isclose( + np.sum(np.abs(grad * (user_vjp_val_scale / user_vjp_val) - grad_scale)), 0.0 + ), "Gradients were set by the user vjp when they should not have been" + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("polyslab_axis", [0, 1, 2]) +@pytest.mark.parametrize("use_single_user_vjp", [True, False]) +@pytest.mark.parametrize("run_function", [_run_local, run_custom]) +@pytest.mark.parametrize("local_gradient", [True, False]) +def test_autograd_cm_user_vjp( + use_emulated_run, + structure_key, + monitor_key, + polyslab_axis, + use_single_user_vjp, + run_function, + local_gradient, +): + """Test that we can override a vjp with a user defined function in component modeler simulations.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + def make_objective(user_vjp_val): + polyslab_user_vjp = make_polyslab_user_vjp(user_vjp_val) + + user_vjp_tuple = ( + UserVJPConfig( + structure_index=1, + compute_derivatives=polyslab_user_vjp, + path_key=( + ( + "geometry", + "vertices", + ) + ), + ), + UserVJPConfig( + structure_index=1, + compute_derivatives=polyslab_user_vjp, + path_key=( + ( + "geometry", + "slab_bounds", + ) + ), + ), + ) + + user_vjp_single = UserVJPConfig( + structure_index=1, + compute_derivatives=polyslab_user_vjp, + ) + + user_vjp_element = user_vjp_single if use_single_user_vjp else user_vjp_tuple + + def objective(*args): + base_sim = make_sim(*args, polyslab_axis=polyslab_axis) + find_mode_monitors = [ + monitor for monitor in base_sim.monitors if isinstance(monitor, td.ModeMonitor) + ] + + select_mode_monitor = find_mode_monitors[0] + + stripped_sim = base_sim.updated_copy(sources=[], monitors=[]) + + input_port = Port( + center=select_mode_monitor.center, + size=select_mode_monitor.size, + mode_spec=select_mode_monitor.mode_spec, + direction="-", + name="input_port", + ) + + modeler = 
ComponentModeler( + simulation=stripped_sim, + ports=[input_port], + freqs=select_mode_monitor.freqs, + ) + + smatrix = run_function( + modeler, + user_vjp=user_vjp_element, + local_gradient=local_gradient, + ) + return np.sum(np.abs(smatrix.smatrix().values) ** 2) + + return objective + + user_vjp_val = 1.0 + user_vjp_val_scale = 10.0 * user_vjp_val + + if not local_gradient: + with pytest.raises( + td.exceptions.AdjointError, + match="user_vjp specified for a remote gradient not supported.", + ): + val, grad = ag.value_and_grad(make_objective(user_vjp_val))(params0) + else: + val, grad = ag.value_and_grad(make_objective(user_vjp_val))(params0) + val_scale, grad_scale = ag.value_and_grad(make_objective(user_vjp_val_scale))(params0) + + assert np.isclose( + np.sum(np.abs(grad * (user_vjp_val_scale / user_vjp_val) - grad_scale)), 0.0 + ), "Gradients were not set by the user vjp" + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("polyslab_axis", [0, 1, 2]) +@pytest.mark.parametrize("use_run_async", [True, False]) +@pytest.mark.parametrize("use_single_numerical_structure", [True, False]) +@pytest.mark.parametrize("use_task_names", [True, False]) +@pytest.mark.parametrize("specify_numerical_structure_index", [True, False]) +@pytest.mark.parametrize("local_gradient", [True, False]) +def test_autograd_numerical_structures( + use_emulated_run, + structure_key, + monitor_key, + polyslab_axis, + use_run_async, + use_single_numerical_structure, + use_task_names, + specify_numerical_structure_index, + local_gradient, +): + """Test that we can add numerical structures to autograd simulations.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + task_names = {"test_a", "adjoint", "_test"} + + def make_objective(user_vjp_val): + def objective(*args): + def make_first_polyslab(params): + return make_sim(*args, 
polyslab_axis=polyslab_axis).structures[1] + + def vjp(parameters, derivative_info): + vjps = {} + + for path in derivative_info.paths: + param_idx = path[0] + + vjps[path] = user_vjp_val + + return vjps + + # ensure the numerical_structures are the reason for the autograd run by stripping + # tracers for the simulation creation + static_args = [get_static(arg) for arg in args] + sim = make_sim(*static_args, polyslab_axis=polyslab_axis) + + structures = [s for idx, s in enumerate(sim.structures) if (not (idx == 1))] + sim_strip_structure = sim.updated_copy(structures=structures) + + if specify_numerical_structure_index: + numerical_structure = NumericalStructureConfig( + create=make_first_polyslab, + compute_derivatives=vjp, + parameters=np.array(args).flatten(), + structure_index=1, + ) + else: + numerical_structure = NumericalStructureConfig( + create=make_first_polyslab, + compute_derivatives=vjp, + parameters=np.array(args).flatten(), + ) + numerical_structures = ( + numerical_structure if use_single_numerical_structure else (numerical_structure,) + ) + + if use_task_names: + sims = dict.fromkeys(task_names, sim_strip_structure) + numerical_structures = dict.fromkeys(task_names, numerical_structures) + else: + sims = [sim_strip_structure] * len(task_names) + numerical_structures = [numerical_structures] * len(task_names) + + batch_data = {} + if use_run_async: + batch_data = run_async_custom( + sims, numerical_structures=numerical_structures, local_gradient=local_gradient + ) + else: + if use_task_names: + for task_name, sim in sims.items(): + batch_data[task_name] = run_custom( + sim, + task_name, + numerical_structures=numerical_structures[task_name], + local_gradient=local_gradient, + ) + else: + for idx, sim in enumerate(sims): + batch_data[idx] = run_custom( + sim, + numerical_structures=numerical_structures[idx], + local_gradient=local_gradient, + ) + + value = 0.0 + + for _, sim_data in batch_data.items(): + value += postprocess(sim_data) + return value 
+ + return objective + + user_vjp_val = 1.0 + user_vjp_val_scale = 10.0 * user_vjp_val + + if not local_gradient: + with pytest.raises( + td.exceptions.AdjointError, + match="numerical_structures specified for a remote gradient not supported.", + ): + val, grad = ag.value_and_grad(make_objective(user_vjp_val))(params0) + else: + val, grad = ag.value_and_grad(make_objective(user_vjp_val))(params0) + val_scale, grad_scale = ag.value_and_grad(make_objective(user_vjp_val_scale))(params0) + + assert np.allclose(grad, len(task_names) * user_vjp_val), ( + "Gradients did not accumulate correctly." + ) + + assert np.isclose( + np.sum(np.abs(grad * (user_vjp_val_scale / user_vjp_val) - grad_scale)), 0.0 + ), "Gradients were not set by the user vjp" + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("polyslab_axis", [0, 1, 2]) +@pytest.mark.parametrize("numerical_structures_specification", ["single", "tuple"]) +@pytest.mark.parametrize("run_function", [_run_local, run_custom]) +@pytest.mark.parametrize("specify_numerical_structure_index", [True, False]) +@pytest.mark.parametrize("local_gradient", [True, False]) +def test_autograd_cm_numerical_structures( + use_emulated_run, + structure_key, + monitor_key, + polyslab_axis, + numerical_structures_specification, + run_function, + specify_numerical_structure_index, + local_gradient, +): + """Test that we can add numerical structures to component modeler autograd simulations.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + def make_objective(user_vjp_val): + def objective(*args): + def make_first_polyslab(params): + return make_sim(*args, polyslab_axis=polyslab_axis).structures[1] + + def vjp(parameters, derivative_info): + vjps = {} + + for path in derivative_info.paths: + param_idx = path[0] + + vjps[path] = user_vjp_val + + return vjps + + if specify_numerical_structure_index: + 
numerical_structure = NumericalStructureConfig( + create=make_first_polyslab, + compute_derivatives=vjp, + parameters=np.array(args).flatten(), + structure_index=1, + ) + else: + numerical_structure = NumericalStructureConfig( + create=make_first_polyslab, + compute_derivatives=vjp, + parameters=np.array(args).flatten(), + ) + if numerical_structures_specification == "single": + numerical_structures = numerical_structure + elif numerical_structures_specification == "tuple": + numerical_structures = (numerical_structure,) + + # ensure the numerical_structures are the reason for the autograd run by stripping + # tracers for the simulation creation + static_args = [get_static(arg) for arg in args] + sim = make_sim(*static_args, polyslab_axis=polyslab_axis) + + structures = [s for idx, s in enumerate(sim.structures) if (not (idx == 1))] + sim_strip_structure = sim.updated_copy(structures=structures) + + find_mode_monitors = [ + monitor + for monitor in sim_strip_structure.monitors + if isinstance(monitor, td.ModeMonitor) + ] + + select_mode_monitor = find_mode_monitors[0] + + stripped_sim = sim_strip_structure.updated_copy(sources=[], monitors=[]) + + input_port = Port( + center=select_mode_monitor.center, + size=select_mode_monitor.size, + mode_spec=select_mode_monitor.mode_spec, + direction="-", + name="input_port", + ) + + modeler = ComponentModeler( + simulation=stripped_sim, + ports=[input_port], + freqs=select_mode_monitor.freqs, + ) + + smatrix = run_function( + modeler, numerical_structures=numerical_structures, local_gradient=local_gradient + ) + return np.sum(np.abs(smatrix.smatrix().values) ** 2) + + return objective + + user_vjp_val = 1.0 + user_vjp_val_scale = 10.0 * user_vjp_val + + if not local_gradient: + with pytest.raises( + td.exceptions.AdjointError, + match="ComponentModeler autograd with traced numerical structures requires local_gradient=True.", + ): + val, grad = ag.value_and_grad(make_objective(user_vjp_val))(params0) + else: + val, grad = 
ag.value_and_grad(make_objective(user_vjp_val))(params0) + val_scale, grad_scale = ag.value_and_grad(make_objective(user_vjp_val_scale))(params0) + + assert np.isclose( + np.sum(np.abs(grad * (user_vjp_val_scale / user_vjp_val) - grad_scale)), 0.0 + ), "Gradients were not set by the user vjp" + + @pytest.mark.skipif(not RUN_NUMERICAL, reason="Numerical gradient tests runs through web API.") @pytest.mark.parametrize("structure_key, monitor_key", (_NUMERICAL_COMBINATION,)) def test_autograd_numerical(structure_key, monitor_key): @@ -1847,6 +2428,7 @@ def J(eps): ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) grads_computed = pr._compute_derivatives(derivative_info=info) @@ -1889,6 +2471,7 @@ def test_adaptive_spacing(eps_real): eps_inf_structure={}, bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) with AssertLogLevel("WARNING", contains_str="Based on the material, the adaptive spacing"): @@ -1919,6 +2502,7 @@ def test_cylinder_discretization(eps_real): eps_inf_structure={}, bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) with AssertLogLevel( @@ -2000,6 +2584,7 @@ def J(eps): ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) grads_computed = pr._compute_derivatives(derivative_info=info) diff --git a/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py b/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py index 3f23a7e98f..6c07125338 100644 --- a/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py +++ b/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py @@ -48,6 +48,7 @@ def _deriv_info(freq): "eps_inf_structure": eps_inf, "bounds_intersect": ((-1, -1, -1), (1, 1, 1)), "simulation_bounds": ((-2, -2, -2), (2, 2, 2)), + 
"updated_epsilon": None, } diff --git a/tidy3d/components/autograd/__init__.py b/tidy3d/components/autograd/__init__.py index a2e9eea893..2e751c49fa 100644 --- a/tidy3d/components/autograd/__init__.py +++ b/tidy3d/components/autograd/__init__.py @@ -5,6 +5,7 @@ from .types import ( AutogradFieldMap, AutogradTraced, + NumericalStructureInfo, TracedCoordinate, TracedFloat, TracedSize, @@ -16,6 +17,7 @@ __all__ = [ "AutogradFieldMap", "AutogradTraced", + "NumericalStructureInfo", "TidyArrayBox", "TracedCoordinate", "TracedFloat", diff --git a/tidy3d/components/autograd/derivative_utils.py b/tidy3d/components/autograd/derivative_utils.py index 7c36444687..e558c70f03 100644 --- a/tidy3d/components/autograd/derivative_utils.py +++ b/tidy3d/components/autograd/derivative_utils.py @@ -115,6 +115,9 @@ class DerivativeInfo: frequencies: ArrayLike """Frequencies at which the adjoint gradient should be computed.""" + updated_epsilon: Callable + """Function to return the permittivity upon geometry replacement.""" + H_der_map: Optional[FieldData] = None """Magnetic field gradient map. Dataset where the field components ("Hx", "Hy", "Hz") store the multiplication diff --git a/tidy3d/components/autograd/types.py b/tidy3d/components/autograd/types.py index bb41935695..d3ac113c61 100644 --- a/tidy3d/components/autograd/types.py +++ b/tidy3d/components/autograd/types.py @@ -5,6 +5,7 @@ import copy import typing +from dataclasses import dataclass import pydantic.v1 as pd from autograd.builtins import dict as dict_ag @@ -40,13 +41,27 @@ # The data type that we pass in and out of the web.run() @autograd.primitive AutogradTraced = typing.Union[Box, ArrayLike] PathType = tuple[typing.Union[int, str], ...] +CustomVJPPathType = tuple[typing.Union[int, str, typing.Callable], ...] 
AutogradFieldMap = dict_ag[PathType, AutogradTraced] InterpolationType = typing.Literal["nearest", "linear"] + +@dataclass(frozen=True) +class NumericalStructureInfo: + """Metadata describing a user-supplied numerical structure insertion.""" + + index: int + parameters: typing.Any + function: typing.Callable[..., typing.Any] + structure: typing.Any + vjp: typing.Callable[..., typing.Any] + + __all__ = [ "AutogradFieldMap", "AutogradTraced", + "NumericalStructureInfo", "TracedCoordinate", "TracedFloat", "TracedSize", diff --git a/tidy3d/components/autograd/utils.py b/tidy3d/components/autograd/utils.py index a87e18f98b..3ba24af0d1 100644 --- a/tidy3d/components/autograd/utils.py +++ b/tidy3d/components/autograd/utils.py @@ -5,11 +5,14 @@ from typing import Any import autograd.numpy as anp +import numpy as np +from autograd.extend import Box from autograd.tracer import getval __all__ = [ "asarray1d", "contains", + "contains_tracer", "get_static", "is_tidy_box", "pack_complex_vec", @@ -44,6 +47,20 @@ def contains(target: Any, seq: Iterable[Any]) -> bool: return False +def contains_tracer(value) -> bool: + if isinstance(value, Box): + return True + if isinstance(value, np.ndarray): + return any(contains_tracer(v) for v in value.flat) + if isinstance(value, dict): + return any(contains_tracer(v) for v in value.values()) + if isinstance(value, (list, tuple)): + return any(contains_tracer(v) for v in value) + if isinstance(value, Iterable) and not isinstance(value, (str, bytes)): + return any(contains_tracer(v) for v in value) + return False + + def pack_complex_vec(z): """Ravel [Re(z); Im(z)] into one real vector (autograd-safe).""" return anp.concatenate([anp.ravel(anp.real(z)), anp.ravel(anp.imag(z))]) diff --git a/tidy3d/components/geometry/primitives.py b/tidy3d/components/geometry/primitives.py index 667ef5cb1a..313c91a57a 100644 --- a/tidy3d/components/geometry/primitives.py +++ b/tidy3d/components/geometry/primitives.py @@ -42,6 +42,13 @@ class 
Sphere(base.Centered, base.Circular): >>> b = Sphere(center=(1,2,3), radius=2) """ + radius: TracedSize1D = pydantic.Field( + ..., + title="Radius", + description="Radius of geometry at the ``reference_plane``.", + units=MICROMETER, + ) + def inside( self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] ) -> np.ndarray[bool]: diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index 1074899775..5407873fc8 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -4818,9 +4818,24 @@ def _make_adjoint_monitors(self, sim_fields_keys: list) -> tuple[list, list]: """Get lists of field and permittivity monitors for this simulation.""" index_to_keys = defaultdict(list) + numerical_indices = set() - for _, index, *fields in sim_fields_keys: - index_to_keys[index].append(fields) + for namespace, index, *fields in sim_fields_keys: + if namespace not in {"structures", "numerical"}: + log.warning( + "Encountered unknown namespace '%s' while creating adjoint monitors; ignoring.", + namespace, + ) + continue + + if namespace == "structures": + index_to_keys[index].append(fields) + elif namespace == "numerical": + numerical_indices.add(index) + + for index in numerical_indices: + if not index_to_keys[index]: + index_to_keys[index].append([]) freqs = self._freqs_adjoint diff --git a/tidy3d/components/structure.py b/tidy3d/components/structure.py index 07f19a3e06..6925c4cfbf 100644 --- a/tidy3d/components/structure.py +++ b/tidy3d/components/structure.py @@ -3,6 +3,7 @@ from __future__ import annotations import pathlib +import typing from collections import defaultdict from functools import cmp_to_key from os import PathLike @@ -346,8 +347,13 @@ def _make_adjoint_monitors( return mnt_fld, mnt_eps - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute adjoint gradients given the forward and adjoint fields""" + def _compute_derivatives( + self, + derivative_info: 
DerivativeInfo, + vjp_fns: typing.Optional[dict[tuple[str, str], typing.Callable[..., typing.Any]]] = None, + ) -> AutogradFieldMap: + """Compute adjoint gradients given the forward and adjoint fields provided in derivative_info.""" + """vjp_fns provide alternate derivative computation paths for the geometry or medium derivatives.""" # generate a mapping from the 'medium', or 'geometry' tag to the list of fields for VJP structure_fields_map = defaultdict(list) @@ -366,11 +372,35 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField # loop through sub fields, compute VJPs, and store in the derivative map {path -> vjp_value} derivative_map = {} + # the first level of integration would be to for med_or_geo, field_paths in structure_fields_map.items(): # grab derivative values {field_name -> vjp_value} med_or_geo_field = self.medium if med_or_geo == "medium" else self.geometry - info = derivative_info.updated_copy(paths=field_paths, deep=False) - derivative_values_map = med_or_geo_field._compute_derivatives(derivative_info=info) + + collect_paths_by_keys = {} + for path in field_paths: + if path[0] in collect_paths_by_keys: + collect_paths_by_keys[path[0]].append(path) + else: + collect_paths_by_keys[path[0]] = [path] + + derivative_values_map = {} + for path_key, paths in collect_paths_by_keys.items(): + info = derivative_info.updated_copy(paths=paths, deep=False) + + full_path = (med_or_geo, path_key) + if (vjp_fns is not None) and (full_path in vjp_fns): + full_paths = ((med_or_geo, *path) for path in paths) + info = derivative_info.updated_copy(paths=full_paths, deep=False) + + vjp = vjp_fns[full_path](med_or_geo_field, info) + vjp_strip_med_or_geo = {key[1:]: val for key, val in vjp.items()} + + derivative_values_map.update(vjp_strip_med_or_geo) + else: + derivative_values_map.update( + med_or_geo_field._compute_derivatives(derivative_info=info) + ) # construct map of {field path -> derivative value} for field_path, derivative_value 
in derivative_values_map.items(): diff --git a/tidy3d/plugins/smatrix/run.py b/tidy3d/plugins/smatrix/run.py index 97f9393338..aea07a0fd6 100644 --- a/tidy3d/plugins/smatrix/run.py +++ b/tidy3d/plugins/smatrix/run.py @@ -1,11 +1,12 @@ from __future__ import annotations import json +import typing from os import PathLike -from typing import Any from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.data.index import SimulationDataMap +from tidy3d.exceptions import AdjointError from tidy3d.log import log from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler @@ -14,6 +15,10 @@ from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData from tidy3d.plugins.smatrix.data.types import ComponentModelerDataType from tidy3d.web import Batch, BatchData +from tidy3d.web.api.autograd.types import ( + NumericalStructureConfig, + UserVJPConfig, +) DEFAULT_DATA_DIR = "." @@ -127,7 +132,7 @@ def compose_modeler_data_from_batch_data( def create_batch( modeler: ComponentModelerType, - **kwargs: Any, + **kwargs: typing.Any, ) -> Batch: """Create a simulation Batch from a component modeler. @@ -154,7 +159,11 @@ def create_batch( def _run_local( modeler: ComponentModelerType, path_dir: str = DEFAULT_DATA_DIR, - **kwargs: Any, + numerical_structures: typing.Optional[ + typing.Union[NumericalStructureConfig, tuple[NumericalStructureConfig]] + ] = None, + user_vjp: typing.Optional[typing.Union[UserVJPConfig, tuple[UserVJPConfig]]] = None, + **kwargs: typing.Any, ) -> ComponentModelerDataType: """Execute the full simulation workflow for a given component modeler. @@ -169,6 +178,12 @@ def _run_local( The component modeler defining the simulations to be run. path_dir : str, optional The directory where the batch file will be saved. Defaults to ".". 
+ numerical_structures : typing.Union[NumericalStructureConfig, tuple[NumericalStructureConfig]] = None + Specification of additional structures to add to the base simulation that can be traced via + autograd. This can be a single structure or multiple structures specified in a tuple. + user_vjp : typing.Union[UserVJPConfig, tuple[UserVJPConfig]] = None + Specification of alternate gradient function for certain structures in the simulation. + This can be a single vjp configuration or multiple specified in a tuple. **kwargs Extra keyword arguments propagated to the Batch creation. @@ -183,7 +198,18 @@ def _run_local( from tidy3d.web.api.autograd import autograd as web_ag sims = modeler.sim_dict - if any(web_ag.is_valid_for_autograd(sim) for sim in sims.values()): + + if isinstance(numerical_structures, NumericalStructureConfig): + numerical_structures = (numerical_structures,) + + traced_numerical_structures = numerical_structures and web_ag.has_traced_numerical_structures( + numerical_structures + ) + should_use_autograd = traced_numerical_structures or any( + web_ag.is_valid_for_autograd(sim) for sim in sims.values() + ) + + if should_use_autograd: if len(modeler.element_mappings) > 0: log.warning( "Element mappings are used to populate S-matrix values, but autograd gradients " @@ -199,10 +225,49 @@ def _run_local( kwargs.setdefault("simulation_type", "tidy3d_autograd_async") kwargs.setdefault("path_dir", path_dir) - sim_data_map = _run_async(simulations=sims, **kwargs) + local_gradient = kwargs.get("local_gradient", True) + + if not local_gradient: + if user_vjp is not None: + raise AdjointError("user_vjp specified for a remote gradient not supported.") + + if traced_numerical_structures: + raise AdjointError( + "ComponentModeler autograd with traced numerical structures requires local_gradient=True." 
+ ) + + if numerical_structures: + web_ag.validate_numerical_structure_parameters(numerical_structures) + numerical_structures = dict.fromkeys(sims, numerical_structures) + + if isinstance(user_vjp, UserVJPConfig): + user_vjp = (user_vjp,) + + if user_vjp: + user_vjp = dict.fromkeys(sims, user_vjp) + + if numerical_structures is not None: + for key in numerical_structures: + numerical_structures[key] = web_ag.populate_numerical_structures( + simulation=sims[key], numerical_structures=numerical_structures[key] + ) + + sim_data_map = _run_async( + simulations=sims, + numerical_structures=numerical_structures, + user_vjp=user_vjp, + **kwargs, + ) return compose_modeler_data_from_batch_data(modeler=modeler, batch_data=sim_data_map) + if numerical_structures is not None: + modeler = modeler.updated_copy( + simulation=web_ag.insert_numerical_structures_static( + simulation=modeler.simulation, numerical_structures=numerical_structures + ) + ) + # Filter kwargs to only include valid Batch parameters batch_kwargs = { k: v diff --git a/tidy3d/web/__init__.py b/tidy3d/web/__init__.py index 0cdc8942e5..608f679fd6 100644 --- a/tidy3d/web/__init__.py +++ b/tidy3d/web/__init__.py @@ -11,8 +11,6 @@ # set logger to tidy3d.log before it's invoked in other imports core_config.set_config(log, get_logging_console(), __version__) -# from .api.asynchronous import run_async # NOTE: we use autograd one now (see below) -# autograd compatible wrappers for run and run_async from .api.autograd.autograd import run_async from .api.container import Batch, BatchData, Job from .api.run import run diff --git a/tidy3d/web/api/autograd/__init__.py b/tidy3d/web/api/autograd/__init__.py index e69de29bb2..2ae40f05b6 100644 --- a/tidy3d/web/api/autograd/__init__.py +++ b/tidy3d/web/api/autograd/__init__.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from .autograd import ( + has_traced_numerical_structures, + insert_numerical_structures_static, + populate_numerical_structures, + 
validate_numerical_structure_parameters, +) + +__all__ = [ + "has_traced_numerical_structures", + "insert_numerical_structures_static", + "populate_numerical_structures", + "validate_numerical_structure_parameters", +] diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index c2e4eb965c..f7e269b9b5 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -2,15 +2,18 @@ from __future__ import annotations import typing +from dataclasses import replace from os import PathLike from pathlib import Path from typing import Any +import numpy as np from autograd.builtins import dict as dict_ag from autograd.extend import defvjp, primitive import tidy3d as td -from tidy3d.components.autograd import AutogradFieldMap +from tidy3d.components.autograd import AutogradFieldMap, get_static +from tidy3d.components.autograd.utils import contains_tracer from tidy3d.components.base import TRACED_FIELD_KEYS_ATTR from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType from tidy3d.config import config @@ -50,6 +53,11 @@ from .io_utils import ( upload_sim_fields_keys as _upload_sim_fields_keys_impl, ) +from .types import ( + NumericalStructureConfig, + SetupRunResult, + UserVJPConfig, +) def _resolve_local_gradient(value: typing.Optional[bool]) -> bool: @@ -59,6 +67,72 @@ def _resolve_local_gradient(value: typing.Optional[bool]) -> bool: return bool(config.adjoint.local_gradient) +def insert_numerical_structures_static( + simulation: td.Simulation, + numerical_structures: typing.Sequence[NumericalStructureConfig], +) -> td.Simulation: + """Return a Simulation with numerical structures inserted, without autograd metadata.""" + + structures = list(simulation.structures) + + for numerical_cfg in numerical_structures: + structure = numerical_cfg.create(get_static(numerical_cfg.parameters)) + structures.insert(numerical_cfg.structure_index, structure) + + return 
simulation.updated_copy(structures=structures) + + +def _normalize_simulations_input( + simulations: typing.Union[dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation]], +) -> tuple[dict[str, td.Simulation], dict[str, int]]: + """Normalize simulations to a dict and map each task name to its positional index.""" + + if isinstance(simulations, dict): + return simulations, {name: idx for idx, name in enumerate(simulations)} + + normalized: dict[str, td.Simulation] = {} + name_mapping: dict[str, int] = {} + + for idx, sim in enumerate(simulations): + task_name = Tidy3dStub(simulation=sim).get_default_task_name() + f"_{idx + 1}" + normalized[task_name] = sim + name_mapping[task_name] = idx + + return normalized, name_mapping + + +def has_traced_numerical_structures( + numerical_structures: typing.Union[ + tuple[NumericalStructureConfig], + list[NumericalStructureConfig], + dict[str, NumericalStructureConfig], + ], +) -> bool: + iterable_structures = ( + numerical_structures.values() + if isinstance(numerical_structures, dict) + else numerical_structures + ) + for cfg in iterable_structures: + if contains_tracer(cfg.parameters): + return True + + return False + + +def validate_numerical_structure_parameters( + numerical_structures: tuple[NumericalStructureConfig], +) -> None: + """Validate user-supplied numerical structure configuration.""" + + for numerical_config in numerical_structures: + array_params = np.array(numerical_config.parameters) + if array_params.ndim != 1: + raise AdjointError( + f"Parameters for numerical structure index {numerical_config.structure_index} must be 1D array-like." 
+ ) + + def is_valid_for_autograd(simulation: td.Simulation) -> bool: """Check whether a supplied Simulation can use the autograd path.""" if not isinstance(simulation, td.Simulation): @@ -100,7 +174,7 @@ def is_valid_for_autograd_async(simulations: dict[str, td.Simulation]) -> bool: return True -def run( +def run_custom( simulation: WorkflowType, task_name: typing.Optional[str] = None, folder_name: str = "default", @@ -119,6 +193,10 @@ def run( pay_type: typing.Union[PayType, str] = PayType.AUTO, priority: typing.Optional[int] = None, lazy: typing.Optional[bool] = None, + numerical_structures: typing.Optional[ + typing.Union[NumericalStructureConfig, tuple[NumericalStructureConfig]] + ] = None, + user_vjp: typing.Optional[typing.Union[UserVJPConfig, tuple[UserVJPConfig]]] = None, ) -> WorkflowDataType: """ Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads, @@ -166,6 +244,13 @@ def run( lazy: Optional[bool] = None Whether to return lazy data proxies. Defaults to ``False`` for single runs when unspecified, matching :func:`tidy3d.web.run`. + numerical_structures : typing.Optional[typing.Union[NumericalStructureConfig, tuple[NumericalStructureConfig]]] = None + Specification of additional structures to add to the simulation (or base simulation for ComponentModeler workflows) + that can be traced via autograd. This can be a single structure or multiple structures specified in a tuple. + user_vjp : typing.Optional[typing.Union[UserVJPConfig, tuple[UserVJPConfig]]] = None + Specification of alternate gradient function for certain structures in the simulation. + This can be a single vjp configuration or multiple specified in a tuple. 
+ Returns ------- Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`, :class:`.ModalComponentModelerData`, :class:`.TerminalComponentModelerData`] @@ -224,13 +309,28 @@ def run( stub = Tidy3dStub(simulation=simulation) task_name = stub.get_default_task_name() + if numerical_structures is not None: + if isinstance(numerical_structures, NumericalStructureConfig): + numerical_structures = (numerical_structures,) + + if user_vjp is not None: + if isinstance(user_vjp, UserVJPConfig): + user_vjp = (user_vjp,) + + if numerical_structures is not None: + validate_numerical_structure_parameters(numerical_structures=numerical_structures) + + traced_numerical_structures = has_traced_numerical_structures(numerical_structures or []) + # component modeler path: route autograd-valid modelers to local run from tidy3d.plugins.smatrix.component_modelers.types import ComponentModelerType path = Path(path) if isinstance(simulation, typing.get_args(ComponentModelerType)): - if any(is_valid_for_autograd(s) for s in simulation.sim_dict.values()): + if traced_numerical_structures or ( + any(is_valid_for_autograd(s) for s in simulation.sim_dict.values()) + ): from tidy3d.plugins.smatrix import run as smatrix_run path_dir = path.parent @@ -245,9 +345,28 @@ def run( priority=priority, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + numerical_structures=numerical_structures, + user_vjp=user_vjp, + ) + + should_use_autograd = False + if isinstance(simulation, td.Simulation): + should_use_autograd = is_valid_for_autograd(simulation) or traced_numerical_structures + + if numerical_structures is not None: + numerical_structures = populate_numerical_structures( + simulation=simulation, numerical_structures=numerical_structures + ) + + if should_use_autograd: + if (user_vjp is not None) and (not local_gradient): + raise AdjointError("user_vjp specified for a remote gradient not supported.") + + if traced_numerical_structures 
and (not local_gradient): + raise AdjointError( + "numerical_structures specified for a remote gradient not supported." ) - if isinstance(simulation, td.Simulation) and is_valid_for_autograd(simulation): return _run( simulation=simulation, task_name=task_name, @@ -263,12 +382,64 @@ def run( parent_tasks=parent_tasks, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + numerical_structures=numerical_structures, + user_vjp=user_vjp, pay_type=pay_type, priority=priority, lazy=lazy, ) + simulation_static = simulation + if isinstance(simulation, td.Simulation) and (numerical_structures is not None): + # if there are numerical_structures without traced parameters, we still want + # to insert them into the simulation + simulation_static = insert_numerical_structures_static( + simulation=simulation, + numerical_structures=numerical_structures, + ) + return run_webapi( + simulation=simulation_static, + task_name=task_name, + folder_name=folder_name, + path=path, + callback_url=callback_url, + verbose=verbose, + progress_callback_upload=progress_callback_upload, + progress_callback_download=progress_callback_download, + solver_version=solver_version, + worker_group=worker_group, + simulation_type=simulation_type, + parent_tasks=parent_tasks, + reduce_simulation=reduce_simulation, + pay_type=pay_type, + priority=priority, + lazy=lazy, + ) + + +def run( + simulation: WorkflowType, + task_name: typing.Optional[str] = None, + folder_name: str = "default", + path: PathLike = "simulation_data.hdf5", + callback_url: typing.Optional[str] = None, + verbose: bool = True, + progress_callback_upload: typing.Optional[typing.Callable[[float], None]] = None, + progress_callback_download: typing.Optional[typing.Callable[[float], None]] = None, + solver_version: typing.Optional[str] = None, + worker_group: typing.Optional[str] = None, + simulation_type: str = "tidy3d", + parent_tasks: typing.Optional[list[str]] = None, + local_gradient: typing.Optional[bool] = 
None, + max_num_adjoint_per_fwd: typing.Optional[int] = None, + reduce_simulation: typing.Literal["auto", True, False] = "auto", + pay_type: typing.Union[PayType, str] = PayType.AUTO, + priority: typing.Optional[int] = None, + lazy: typing.Optional[bool] = None, +) -> WorkflowDataType: + """Wrapper for run_custom for usage without numerical_structures or user_vjp for public facing API.""" + return run_custom( simulation=simulation, task_name=task_name, folder_name=folder_name, @@ -281,14 +452,18 @@ def run( worker_group=worker_group, simulation_type=simulation_type, parent_tasks=parent_tasks, + local_gradient=local_gradient, + max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, reduce_simulation=reduce_simulation, pay_type=pay_type, priority=priority, lazy=lazy, + numerical_structures=None, + user_vjp=None, ) -def run_async( +def run_async_custom( simulations: typing.Union[dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation]], folder_name: str = "default", path_dir: PathLike = DEFAULT_DATA_DIR, @@ -304,6 +479,24 @@ def run_async( pay_type: typing.Union[PayType, str] = PayType.AUTO, priority: typing.Optional[int] = None, lazy: typing.Optional[bool] = None, + numerical_structures: typing.Optional[ + typing.Union[ + NumericalStructureConfig, + dict[str, NumericalStructureConfig], + typing.Sequence[NumericalStructureConfig], + dict[str, typing.Sequence[NumericalStructureConfig]], + typing.Sequence[typing.Sequence[NumericalStructureConfig]], + ] + ] = None, + user_vjp: typing.Optional[ + typing.Union[ + UserVJPConfig, + dict[str, UserVJPConfig], + typing.Sequence[UserVJPConfig], + dict[str, typing.Sequence[UserVJPConfig]], + typing.Sequence[typing.Sequence[UserVJPConfig]], + ] + ] = None, ) -> BatchData: """Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server, starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object. 
@@ -344,6 +537,30 @@ def run_async( lazy: Optional[bool] = None Whether to return lazy data proxies. Defaults to ``True`` for batch runs when unspecified, matching :func:`tidy3d.web.run`. + numerical_structures: typing.Optional[typing.Union[ + NumericalStructureConfig, + dict[str, NumericalStructureConfig], + typing.Sequence[NumericalStructureConfig], + dict[str, typing.Sequence[NumericalStructureConfig]], + typing.Sequence[typing.Sequence[NumericalStructureConfig]], + ]] = None + Specification of additional structures to add to the simulations that can be traced via autograd. Different + numerical_structures can be added for different simulations or the same set can be broadcasted to all simulations. + Specifying a single config will broadcast to all simulations. Specifying a dict or a sequence with single configs + as values will set one config for each simulation. Most generally, multiple structures can be specified for each + simulation by specifying a dict with sequence values or a sequence of sequences. + user_vjp: typing.Optional[typing.Union[ + UserVJPConfig, + dict[str, UserVJPConfig], + typing.Sequence[UserVJPConfig], + dict[str, typing.Sequence[UserVJPConfig]], + typing.Sequence[typing.Sequence[UserVJPConfig]], + ]] = None + Specification of alternate gradient function for certain structures in the simulation. Different + user_vjp's can be added for different simulations or the same set can be broadcasted to all simulations. + Specifying a single config will broadcast to all simulations. Specifying a dict or a sequence with single configs + as values will set one config for each simulation. Most generally, multiple user_vjp's can be specified for each + simulation by specifying a dict with sequence values or a sequence of sequences.
Returns ------ @@ -371,18 +588,111 @@ def run_async( lazy = True if lazy is None else bool(lazy) + def validate_and_expand( + fn_arg: typing.Union[NumericalStructureConfig, UserVJPConfig], + fn_arg_name: str, + base_type: type[typing.Union[NumericalStructureConfig, UserVJPConfig]], + orig_sim_arg: typing.Union[ + dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation] + ], + sim_dict: dict[str, tuple[td.Simulation]], + ) -> dict[str, typing.Sequence[typing.Union[NumericalStructureConfig, UserVJPConfig]]]: + """Check and validate the provided numerical_structures or user_vjp type and expand as""" + """necessary to match the provided simulation specification.""" + if fn_arg is None: + return fn_arg + + if isinstance(fn_arg, base_type): + expanded = dict.fromkeys(sim_dict.keys(), fn_arg) + return expanded + + expanded = {} + if not isinstance(fn_arg, type(orig_sim_arg)): + raise AdjointError( + f"{fn_arg_name} type ({type(fn_arg)}) should match simulations type ({type(simulations)})" + ) + + if isinstance(orig_sim_arg, dict): + check_keys = fn_arg.keys() == sim_dict.keys() + + if not check_keys: + raise AdjointError(f"{fn_arg_name} keys do not match simulations keys") + + for key, val in fn_arg.items(): + if isinstance(val, base_type): + expanded[key] = (val,) + else: + expanded[key] = val + + elif isinstance(orig_sim_arg, (list, tuple)): + if not (len(fn_arg) == len(orig_sim_arg)): + raise AdjointError( + f"{fn_arg_name} is not the same length as simulations ({len(expanded)} vs. 
{len(simulations)})" + ) + + for idx, key in enumerate(sim_dict.keys()): + val = fn_arg[idx] + if isinstance(val, (list, tuple)): + expanded[key] = val + else: + expanded[key] = (val,) + + return expanded + if isinstance(simulations, (tuple, list)): sim_dict = {} for i, sim in enumerate(simulations, 1): task_name = Tidy3dStub(simulation=sim).get_default_task_name() + f"_{i}" sim_dict[task_name] = sim - simulations = sim_dict + else: + sim_dict = simulations + + numerical_structures = validate_and_expand( + numerical_structures, + "numerical_structures", + NumericalStructureConfig, + simulations, + sim_dict, + ) + if numerical_structures is not None: + for _, numerical_structures_configs in numerical_structures.items(): + validate_numerical_structure_parameters( + numerical_structures=numerical_structures_configs + ) + + user_vjp = validate_and_expand(user_vjp, "user_vjp", UserVJPConfig, simulations, sim_dict) + + simulations = sim_dict path_dir = Path(path_dir) - if is_valid_for_autograd_async(simulations): + simulations_norm, name_mapping = _normalize_simulations_input(simulations) + + should_use_autograd_async = is_valid_for_autograd_async(simulations_norm) + traced_numerical_structures = (numerical_structures is not None) and any( + has_traced_numerical_structures(numerical_structure) + for _, numerical_structure in numerical_structures.items() + ) + should_use_autograd_async = ( + is_valid_for_autograd_async(simulations_norm) or traced_numerical_structures + ) + + if numerical_structures is not None: + for key in numerical_structures: + numerical_structures[key] = populate_numerical_structures( + simulation=simulations_norm[key], numerical_structures=numerical_structures[key] + ) + + if should_use_autograd_async: + if (user_vjp is not None) and (not local_gradient): + raise AdjointError("user_vjp specified for a remote gradient not supported.") + if traced_numerical_structures and (not local_gradient): + raise AdjointError( + "numerical_structures specified for 
a remote gradient not supported." + ) + return _run_async( - simulations=simulations, + simulations=simulations_norm, folder_name=folder_name, path_dir=path_dir, callback_url=callback_url, @@ -393,12 +703,65 @@ def run_async( parent_tasks=parent_tasks, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + numerical_structures=numerical_structures, + user_vjp=user_vjp, pay_type=pay_type, priority=priority, lazy=lazy, ) + # insert numerical_structures even if not traced + if numerical_structures is not None: + simulations_static = { + name: ( + insert_numerical_structures_static( + simulation=simulations_norm[name], + numerical_structures=numerical_structures[name], + ) + if numerical_structures[name] + else simulations_norm[name] + ) + for name in simulations_norm + } + else: + simulations_static = simulations_norm + return run_async_webapi( + simulations=simulations_static, + folder_name=folder_name, + path_dir=path_dir, + callback_url=callback_url, + num_workers=num_workers, + verbose=verbose, + simulation_type=simulation_type, + solver_version=solver_version, + parent_tasks=parent_tasks, + reduce_simulation=reduce_simulation, + pay_type=pay_type, + priority=priority, + lazy=lazy, + ) + + +def run_async( + simulations: typing.Union[dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation]], + folder_name: str = "default", + path_dir: PathLike = DEFAULT_DATA_DIR, + callback_url: typing.Optional[str] = None, + num_workers: typing.Optional[int] = None, + verbose: bool = True, + simulation_type: str = "tidy3d", + solver_version: typing.Optional[str] = None, + parent_tasks: typing.Optional[dict[str, list[str]]] = None, + local_gradient: typing.Optional[bool] = None, + max_num_adjoint_per_fwd: typing.Optional[int] = None, + reduce_simulation: typing.Literal["auto", True, False] = "auto", + pay_type: typing.Union[PayType, str] = PayType.AUTO, + priority: typing.Optional[int] = None, + lazy: typing.Optional[bool] = None, +) -> 
BatchData: + """Wrapper for run_async_custom for usage without numerical_structures or user_vjp for public facing API.""" + return run_async_custom( simulations=simulations, folder_name=folder_name, path_dir=path_dir, @@ -408,10 +771,14 @@ def run_async( simulation_type=simulation_type, solver_version=solver_version, parent_tasks=parent_tasks, + local_gradient=local_gradient, + max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, reduce_simulation=reduce_simulation, pay_type=pay_type, priority=priority, lazy=lazy, + numerical_structures=None, + user_vjp=None, ) @@ -423,11 +790,18 @@ def _run( task_name: str, local_gradient: bool = False, max_num_adjoint_per_fwd: typing.Optional[int] = None, + numerical_structures: typing.Optional[tuple[NumericalStructureConfig]] = None, + user_vjp: typing.Optional[tuple[UserVJPConfig]] = None, **run_kwargs: Any, ) -> td.SimulationData: """User-facing ``web.run`` function, compatible with ``autograd`` differentiation.""" - traced_fields_sim = setup_run(simulation=simulation) + setup_result = setup_run( + simulation=simulation, + numerical_structures=numerical_structures, + ) + traced_fields_sim = setup_result.sim_fields + simulation = setup_result.simulation # if we register this as not needing adjoint at all (no tracers), call regular run function if not traced_fields_sim: @@ -456,6 +830,8 @@ def _run( aux_data=aux_data, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + numerical_structures=numerical_structures, + user_vjp=user_vjp, **run_kwargs, ) @@ -466,41 +842,62 @@ def _run_async( simulations: dict[str, td.Simulation], local_gradient: bool = False, max_num_adjoint_per_fwd: typing.Optional[int] = None, + numerical_structures: typing.Optional[ + dict[str, typing.Sequence[NumericalStructureConfig]] + ] = None, + user_vjp: typing.Optional[dict[str, typing.Sequence[UserVJPConfig]]] = None, **run_async_kwargs: Any, ) -> dict[str, td.SimulationData]: """User-facing ``web.run_async`` function, compatible 
with ``autograd`` differentiation.""" - task_names = simulations.keys() traced_fields_sim_dict: dict[str, AutogradFieldMap] = {} sims_original: dict[str, td.Simulation] = {} + sims_prepared: dict[str, td.Simulation] = {} + + if max_num_adjoint_per_fwd is None: + max_num_adjoint_per_fwd = config.adjoint.max_adjoint_per_fwd + + numerical_structures = numerical_structures or {} + aux_data_dict = {task_name: {} for task_name in task_names} + for task_name in task_names: sim = simulations[task_name] - traced_fields = setup_run(simulation=sim) + setup_result = setup_run( + simulation=sim, + numerical_structures=numerical_structures.get(task_name), + ) + sim_prepared = setup_result.simulation + traced_fields = setup_result.sim_fields + + sims_prepared[task_name] = sim_prepared + traced_fields_sim_dict[task_name] = traced_fields - payload = sim._serialized_traced_field_keys(traced_fields) - sim_static = sim.to_static() + payload = sim_prepared._serialized_traced_field_keys(traced_fields) + sim_static = sim_prepared.to_static() if payload: sim_static.attrs[TRACED_FIELD_KEYS_ATTR] = payload + sims_original[task_name] = sim_static - traced_fields_sim_dict = dict_ag(traced_fields_sim_dict) # TODO: shortcut primitive running for any items with no tracers? + traced_fields_sim_dict = dict_ag(traced_fields_sim_dict) + sims_original = {name: sims_original[name] for name in traced_fields_sim_dict.keys()} - aux_data_dict = {task_name: {} for task_name in task_names} traced_fields_data_dict = _run_async_primitive( traced_fields_sim_dict, # if you pass as a kwarg it will not trace :/ sims_original=sims_original, aux_data_dict=aux_data_dict, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + numerical_structures=setup_result.numerical_structures, + user_vjp=user_vjp, **run_async_kwargs, ) - # TODO: package this as a Batch? it might be not possible as autograd tracers lose their - # powers when we save them to file. + # TODO: package this as a Batch? 
def populate_numerical_structures(
    simulation: td.Simulation,
    numerical_structures: tuple[NumericalStructureConfig],
) -> typing.Optional[tuple[NumericalStructureConfig]]:
    """Replace every ``structure_index == -1`` placeholder with a concrete list position.

    The config at position ``k`` of ``numerical_structures`` (0-based, counting *all*
    configs, not only the placeholder ones) is assigned ``len(simulation.structures) + k``
    when it carries the placeholder. Configs with any other index are returned untouched.

    Parameters
    ----------
    simulation
        Simulation whose current number of structures fixes the first appended position.
    numerical_structures
        Configs to resolve, in the order in which they will be inserted.

    Returns
    -------
    tuple
        The configs in input order; an empty input yields an empty tuple.
    """
    first_free_index = len(simulation.structures)

    return tuple(
        replace(config, structure_index=first_free_index + offset)
        if config.structure_index == -1
        else config
        for offset, config in enumerate(numerical_structures)
    )


def setup_run(
    simulation: td.Simulation,
    numerical_structures: typing.Optional[tuple[NumericalStructureConfig]] = None,
) -> SetupRunResult:
    """Prepare simulation and traced fields, including numerical structure insertions.

    Parameters
    ----------
    simulation
        User-supplied simulation.
    numerical_structures
        Optional configs; each one is turned into a structure via ``config.create`` (called
        with the static version of ``config.parameters``) and inserted into a copy of the
        simulation at ``config.structure_index``.

    Returns
    -------
    SetupRunResult
        ``sim_fields`` maps traced-field paths to values: the ``("structures", ...)`` entries
        of all regular structures plus one ``("numerical", structure_index, i)`` entry per
        parameter of every numerical structure. ``simulation`` is the simulation including
        the inserted structures, and ``numerical_structures`` holds the configs with
        resolved indices (or the input unchanged when it is empty / ``None``).
    """
    sim_prepared = simulation

    if numerical_structures:
        # Resolve the ``-1`` "append" placeholder before using the index:
        # ``list.insert(-1, x)`` puts ``x`` *before* the last element instead of appending,
        # and a ``-1`` would be carried verbatim into the index filter and the
        # ``("numerical", ...)`` keys below. Configs with explicit indices pass through
        # unchanged, so this is a no-op when the caller already resolved them.
        numerical_structures = populate_numerical_structures(simulation, numerical_structures)

        # NOTE(review): insertions are applied one after another, so an explicit index that
        # is smaller than a previously inserted one shifts that earlier structure by one.
        structures = list(simulation.structures)
        for config in numerical_structures:
            structures.insert(config.structure_index, config.create(get_static(config.parameters)))

        sim_prepared = simulation.updated_copy(structures=structures)

    sim_fields_map = sim_prepared._strip_traced_fields(
        include_untraced_data_arrays=False, starting_path=("structures",)
    )

    if numerical_structures:
        # set membership; ``key[1]`` is part of a dict key and therefore hashable
        numerical_indices = {config.structure_index for config in numerical_structures}

        # sim fields for structures that go through the regular derivative path
        sim_fields_dict = {
            key: value
            for key, value in sim_fields_map.items()
            if not (key[0] == "structures" and key[1] in numerical_indices)
        }

        # sim fields for structures that go through the numerical derivative path
        sim_fields_dict.update(
            (("numerical", config.structure_index, idx), param)
            for config in numerical_structures
            for idx, param in enumerate(config.parameters)
        )

        sim_fields_map = dict_ag(sim_fields_dict)

    return SetupRunResult(
        sim_fields=sim_fields_map,
        simulation=sim_prepared,
        numerical_structures=numerical_structures,
    )


def postprocess_run(traced_fields_data: AutogradFieldMap, aux_data: dict) -> td.SimulationData:
    """Process the return from ``_run_primitive`` into ``SimulationData`` for user."""

    # grab the user's 'SimulationData' and return with the autograd-tracers inserted
    sim_data_original = aux_data[AUX_KEY_SIM_DATA_ORIGINAL]

    return sim_data_original._insert_traced_fields(traced_fields_data)
sim_fields_dict.keys() @@ -653,8 +1116,9 @@ def _run_async_primitive( field_map_fwd_dict = {} for task_name, task_id_fwd in task_ids_fwd_dict.items(): sim_data_orig = sim_data_orig_dict[task_name] - aux_data_dict[task_name][AUX_KEY_FWD_TASK_ID] = task_id_fwd - aux_data_dict[task_name][AUX_KEY_SIM_DATA_ORIGINAL] = sim_data_orig + aux_data = aux_data_dict[task_name] + aux_data[AUX_KEY_FWD_TASK_ID] = task_id_fwd + aux_data[AUX_KEY_SIM_DATA_ORIGINAL] = sim_data_orig field_map = sim_data_orig._strip_traced_fields( include_untraced_data_arrays=True, starting_path=("data",) ) @@ -710,6 +1174,8 @@ def _run_bwd( aux_data: dict, local_gradient: bool, max_num_adjoint_per_fwd: int, + numerical_structures: tuple[NumericalStructureConfig], + user_vjp: tuple[UserVJPConfig], **run_kwargs: Any, ) -> typing.Callable[[AutogradFieldMap], AutogradFieldMap]: """VJP-maker for ``_run_primitive()``. Constructs and runs adjoint simulations, computes grad.""" @@ -761,7 +1227,6 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: td.log.info(f"Running {len(sims_adj)} adjoint simulations") vjp_traced_fields = {} - if local_gradient: # Run all adjoint sims in batch td.log.info("Starting local batch adjoint simulations") @@ -779,11 +1244,14 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: vjp_fields_dict = {} for task_name_adj, sim_data_adj in batch_data_adj.items(): td.log.info(f"Processing VJP contribution from {task_name_adj}") + vjp_fields_dict[task_name_adj] = postprocess_adj( sim_data_adj=sim_data_adj, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_keys=sim_fields_keys, + numerical_structures=numerical_structures, + user_vjp=user_vjp, ) else: td.log.info("Starting server-side batch of adjoint simulations ...") @@ -835,6 +1303,10 @@ def _run_async_bwd( aux_data_dict: dict[str, dict[str, typing.Any]], local_gradient: bool, max_num_adjoint_per_fwd: int, + numerical_structures: typing.Optional[ + dict[str, 
typing.Sequence[NumericalStructureConfig]], + ] = None, + user_vjp: typing.Optional[dict[str, typing.Sequence[UserVJPConfig]],] = None, **run_async_kwargs: Any, ) -> typing.Callable[[dict[str, AutogradFieldMap]], dict[str, AutogradFieldMap]]: """VJP-maker for ``_run_primitive()``. Constructs and runs adjoint simulation, computes grad.""" @@ -844,6 +1316,8 @@ def _run_async_bwd( task_names = data_fields_original_dict.keys() + user_vjp = user_vjp or {} + # get the fwd epsilon and field data from the cached aux_data sim_data_orig_dict = {} sim_data_fwd_dict = {} @@ -856,8 +1330,6 @@ def _run_async_bwd( if local_gradient: sim_data_fwd_dict[task_name] = aux_data[AUX_KEY_SIM_DATA_FWD] - td.log.info("constructing custom vjp function for backwards pass.") - def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, AutogradFieldMap]: """dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}""" @@ -920,11 +1392,17 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd sim_fields_keys = sim_fields_keys_dict[task_name] # Compute VJP contribution + task_user_vjp = user_vjp.get(task_name) + if isinstance(task_user_vjp, UserVJPConfig): + task_user_vjp = (task_user_vjp,) + vjp_results[adj_task_name] = postprocess_adj( sim_data_adj=sim_data_adj, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_keys=sim_fields_keys, + numerical_structures=numerical_structures, + user_vjp=task_user_vjp, ) else: # Set up parent tasks mapping for all adjoint simulations @@ -990,6 +1468,8 @@ def postprocess_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_keys: list[tuple], + user_vjp: tuple[UserVJPConfig], + numerical_structures: tuple[NumericalStructureConfig], ) -> AutogradFieldMap: """Postprocess adjoint results into VJPs (delegated).""" return _postprocess_adj_impl( @@ -997,6 +1477,8 @@ def postprocess_adj( sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, 
sim_fields_keys=sim_fields_keys, + user_vjp=user_vjp, + numerical_structures=numerical_structures, ) diff --git a/tidy3d/web/api/autograd/backward.py b/tidy3d/web/api/autograd/backward.py index 0c596f61dd..31cb5b108f 100644 --- a/tidy3d/web/api/autograd/backward.py +++ b/tidy3d/web/api/autograd/backward.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from collections import defaultdict import numpy as np @@ -9,13 +10,22 @@ from tidy3d import Medium from tidy3d.components.autograd import AutogradFieldMap, get_static from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.components.data.data_array import DataArray +from tidy3d.components.data.data_array import DataArray, FreqDataArray, ScalarFieldDataArray +from tidy3d.components.geometry.base import Box +from tidy3d.components.geometry.utils import GeometryType from tidy3d.config import config from tidy3d.exceptions import AdjointError from tidy3d.packaging import disable_local_subpixel +from .types import ( + NumericalStructureConfig, + UserVJPConfig, +) from .utils import E_to_D, get_derivative_maps +if typing.TYPE_CHECKING: + pass + def setup_adj( data_fields_vjp: AutogradFieldMap, @@ -105,18 +115,76 @@ def postprocess_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_keys: list[tuple], + numerical_structures: typing.Optional[tuple[NumericalStructureConfig]] = None, + user_vjp: typing.Optional[tuple[UserVJPConfig]] = None, ) -> AutogradFieldMap: """Postprocess some data from the adjoint simulation into the VJP for the original sim flds.""" - # map of index into 'structures' to the list of paths we need vjps for + def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]: + """Get all the paths that may appear in autograd for this structure index. 
This allows a""" + """user_vjp to be called for all autograd paths for the structure.""" + all_paths = tuple( + tuple(structure_path) + for namespace, structure_index, *structure_path in sim_fields_keys + if structure_index == match_structure_index + ) + + return all_paths + + user_vjp_lookup: dict[int, dict[tuple[str, str], typing.Callable[..., typing.Any]]] = {} + if user_vjp: + for vjp_config in user_vjp: + structure_index = vjp_config.structure_index + vjp_fn = vjp_config.compute_derivatives + path = vjp_config.path_key + + if path is None: + for match_path in get_all_paths(structure_index): + user_vjp_lookup.setdefault(structure_index, {})[match_path[0:2]] = vjp_fn + else: + user_vjp_lookup.setdefault(structure_index, {})[path] = vjp_fn + + # map of index into 'structures' and 'numerical' to the paths we need VJPs for sim_vjp_map = defaultdict(list) - for _, structure_index, *structure_path in sim_fields_keys: + numerical_vjp_map = defaultdict(set) + numerical_structure_indices = [] + for namespace, structure_index, *structure_path in sim_fields_keys: structure_path = tuple(structure_path) - sim_vjp_map[structure_index].append(structure_path) + if namespace == "structures": + sim_vjp_map[structure_index].append(structure_path) + elif namespace == "numerical": + numerical_vjp_map[structure_index].add(structure_path) + numerical_structure_indices.append(structure_index) + + def lookup_numerical_structure(structure_index: int) -> NumericalStructureConfig: + for numerical_structure in numerical_structures: + if numerical_structure.structure_index == structure_index: + return numerical_structure # store the derivative values given the forward and adjoint data sim_fields_vjp = {} - for structure_index, structure_paths in sim_vjp_map.items(): + all_structure_indices = sorted(set(sim_vjp_map.keys()) | set(numerical_vjp_map.keys())) + + for structure_index in all_structure_indices: + structure_paths = tuple(sim_vjp_map.get(structure_index, ())) + + use_numerical_vjp = 
structure_index in numerical_structure_indices + + numerical_paths_raw = numerical_vjp_map.get(structure_index, set()) + numerical_paths_ordered: tuple[tuple, ...] = () + numerical_value_map: dict[tuple, typing.Any] = {} + numerical_vjp_fn = None + numerical_params_static: tuple[typing.Any, ...] = () + + if use_numerical_vjp: + numerical_structure = lookup_numerical_structure(structure_index) + + numerical_vjp_fn = numerical_structure.compute_derivatives + numerical_params_static = tuple( + get_static(param) for param in numerical_structure.parameters + ) + numerical_paths_ordered = tuple(sorted(numerical_paths_raw)) + # grab the forward and adjoint data fld_fwd = sim_data_fwd._get_adjoint_data(structure_index, data_type="fld") eps_fwd = sim_data_fwd._get_adjoint_data(structure_index, data_type="eps") @@ -215,6 +283,35 @@ def postprocess_adj( rmax_intersect = tuple([min(a, b) for a, b in zip(rmax_sim, rmax_struct)]) bounds_intersect = (rmin_intersect, rmax_intersect) + def updated_epsilon_full( + replacement_geometry: GeometryType, + adjoint_frequencies: typing.Optional[FreqDataArray] = adjoint_frequencies, + structure_index: typing.Optional[int] = structure_index, + eps_box: typing.Optional[Box] = eps_fwd.monitor.geometry, + ) -> ScalarFieldDataArray: + # Return the simulation permittivity for eps_box after replacing the geometry + # for this structure with a new geometry. 
This is helpful for carrying out finite + # difference permittivity computations + sim_orig = sim_data_orig.simulation + sim_orig_grid_spec = td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid) + + update_sim = sim_orig.updated_copy( + structures=[ + sim_orig.structures[idx].updated_copy(geometry=replacement_geometry) + if idx == structure_index + else sim_orig.structures[idx] + for idx in range(len(sim_orig.structures)) + ], + grid_spec=sim_orig_grid_spec, + ) + + eps_by_f = [ + update_sim.epsilon(box=eps_box, coord_key="centers", freq=f) + for f in adjoint_frequencies + ] + + return xr.concat(eps_by_f, dim="f").assign_coords(f=adjoint_frequencies) + # get chunk size - if None, process all frequencies as one chunk freq_chunk_size = config.adjoint.solver_freq_chunk_size n_freqs = len(adjoint_frequencies) @@ -277,48 +374,105 @@ def postprocess_adj( else None ) - # create derivative info with sliced data - derivative_info = DerivativeInfo( - paths=structure_paths, - E_der_map=E_der_map_chunk, - D_der_map=D_der_map_chunk, - H_der_map=H_der_map_chunk, - E_fwd=E_fwd_chunk, - E_adj=E_adj_chunk, - D_fwd=D_fwd_chunk, - D_adj=D_adj_chunk, - H_fwd=H_fwd_chunk, - H_adj=H_adj_chunk, - eps_data=eps_data_chunk, - eps_in=eps_in_chunk, - eps_out=eps_out_chunk, - eps_background=eps_background_chunk, - frequencies=select_adjoint_freqs, # only chunk frequencies - eps_no_structure=eps_no_structure_chunk, - eps_inf_structure=eps_inf_structure_chunk, - bounds=struct_bounds, - bounds_intersect=bounds_intersect, - simulation_bounds=sim_data_orig.simulation.bounds, - is_medium_pec=structure.medium.is_pec, - ) - - # compute derivatives for chunk - vjp_chunk = structure._compute_derivatives(derivative_info) + def updated_epsilon( + replacement_geometry: GeometryType, + select_adjoint_freqs: typing.Optional[FreqDataArray] = select_adjoint_freqs, + updated_epsilon_full: typing.Optional[typing.Callable] = updated_epsilon_full, + ) -> ScalarFieldDataArray: + # Get permittivity 
function for a subset of frequencies + return updated_epsilon_full(replacement_geometry).sel(f=select_adjoint_freqs) + + common_kwargs = { + "E_der_map": E_der_map_chunk, + "D_der_map": D_der_map_chunk, + "H_der_map": H_der_map_chunk, + "E_fwd": E_fwd_chunk, + "E_adj": E_adj_chunk, + "D_fwd": D_fwd_chunk, + "D_adj": D_adj_chunk, + "H_fwd": H_fwd_chunk, + "H_adj": H_adj_chunk, + "eps_data": eps_data_chunk, + "eps_in": eps_in_chunk, + "eps_out": eps_out_chunk, + "eps_background": eps_background_chunk, + "frequencies": select_adjoint_freqs, + "eps_no_structure": eps_no_structure_chunk, + "eps_inf_structure": eps_inf_structure_chunk, + "updated_epsilon": updated_epsilon, + "bounds": struct_bounds, + "bounds_intersect": bounds_intersect, + "simulation_bounds": sim_data_orig.simulation.bounds, + "is_medium_pec": structure.medium.is_pec, + } + + if structure_paths: + derivative_info_struct = DerivativeInfo( + paths=structure_paths, + **common_kwargs, + ) - # accumulate results - for path, value in vjp_chunk.items(): - if path in vjp_value_map: - val = vjp_value_map[path] - if isinstance(val, (list, tuple)) and isinstance(value, (list, tuple)): - vjp_value_map[path] = type(val)(x + y for x, y in zip(val, value)) + vjp_fns = user_vjp_lookup.get(structure_index) + vjp_chunk = structure._compute_derivatives(derivative_info_struct, vjp_fns=vjp_fns) + + for path, value in vjp_chunk.items(): + if path in vjp_value_map: + existing = vjp_value_map[path] + if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)): + vjp_value_map[path] = type(existing)( + x + y for x, y in zip(existing, value) + ) + else: + vjp_value_map[path] = existing + value else: - vjp_value_map[path] += value + vjp_value_map[path] = value + + if use_numerical_vjp: + derivative_info_num = DerivativeInfo( + paths=numerical_paths_ordered, + **common_kwargs, + ) + + gradients = numerical_vjp_fn( + parameters=numerical_params_static, derivative_info=derivative_info_num + ) + + if 
isinstance(gradients, dict): + gradient_items = ( + (path, gradients.get(path)) for path in numerical_paths_ordered + ) else: - vjp_value_map[path] = value + gradients_seq = tuple(gradients) + if len(gradients_seq) != len(numerical_paths_ordered): + raise AdjointError( + f"User VJP for numerical structure index {structure_index} returned {len(gradients_seq)} gradients, " + f"expected {len(numerical_paths_ordered)}." + ) + gradient_items = zip(numerical_paths_ordered, gradients_seq) + + for path, grad_value in gradient_items: + if grad_value is None: + continue + if path in numerical_value_map: + existing = numerical_value_map[path] + if isinstance(existing, (list, tuple)) and isinstance( + grad_value, (list, tuple) + ): + numerical_value_map[path] = type(existing)( + x + y for x, y in zip(existing, grad_value) + ) + else: + numerical_value_map[path] = existing + grad_value + else: + numerical_value_map[path] = grad_value # store vjps in output map for structure_path, vjp_value in vjp_value_map.items(): sim_path = ("structures", structure_index, *list(structure_path)) sim_fields_vjp[sim_path] = vjp_value + for numerical_path, gradient_value in numerical_value_map.items(): + sim_path = ("numerical", structure_index, *list(numerical_path)) + sim_fields_vjp[sim_path] = gradient_value + return sim_fields_vjp diff --git a/tidy3d/web/api/autograd/engine.py b/tidy3d/web/api/autograd/engine.py index c9f36e0a42..8383d0bb57 100644 --- a/tidy3d/web/api/autograd/engine.py +++ b/tidy3d/web/api/autograd/engine.py @@ -2,8 +2,11 @@ from pathlib import Path from typing import Any +import typing +from os.path import basename, dirname, join import tidy3d as td +from tidy3d.components.autograd.types import NumericalStructureInfo from tidy3d.web.api.container import DEFAULT_DATA_PATH, Batch, Job from .io_utils import get_vjp_traced_fields, upload_sim_fields_keys @@ -75,6 +78,7 @@ def _run_async_tidy3d( def _run_async_tidy3d_bwd( simulations: dict[str, td.Simulation], + 
from __future__ import annotations

import typing
from dataclasses import dataclass

if typing.TYPE_CHECKING:
    # only referenced inside (lazily evaluated) annotations
    import tidy3d as td
    from tidy3d.components.autograd import AutogradFieldMap


@dataclass
class NumericalStructureConfig:
    """Description of a structure that is created from parameters and has a user-supplied VJP."""

    create: typing.Callable
    """Function that creates the structure given an untraced version of the parameters"""

    compute_derivatives: typing.Callable
    """Function that computes the vjp for the structure given the same arguments
    that the internal _compute_derivatives function gets."""

    parameters: typing.Any
    """Parameters used for creating the structure."""

    # ``-1`` is the "not specified" placeholder; ``None`` is not a supported value
    structure_index: int = -1
    """Index for structure in the simulation. If not specified, assume the structure is appended into the structure list."""


@dataclass
class UserVJPConfig:
    """User-supplied replacement for the VJP computation of an existing structure."""

    structure_index: int
    """Index for structure to replace vjp."""

    compute_derivatives: typing.Callable
    """Function that computes the vjp for the structure given the same arguments
    that the internal _compute_derivatives function gets."""

    # a path is a tuple of keys (it is stored next to 2-tuples of path components), not a string
    path_key: typing.Optional[tuple[str, ...]] = None
    """Path key this is relevant for. If not specified, assume the supplied function applies for all keys."""


class SetupRunResult(typing.NamedTuple):
    """Bundle of traced fields, prepared simulation and numerical structure configs."""

    sim_fields: AutogradFieldMap
    simulation: td.Simulation
    # ``None`` when no numerical structures were supplied
    numerical_structures: typing.Optional[tuple[NumericalStructureConfig, ...]]


__all__ = [
    "NumericalStructureConfig",
    "SetupRunResult",
    "UserVJPConfig",
]