diff --git a/devito/ir/equations/algorithms.py b/devito/ir/equations/algorithms.py index ce844887aa..66b44e34f9 100644 --- a/devito/ir/equations/algorithms.py +++ b/devito/ir/equations/algorithms.py @@ -6,14 +6,15 @@ from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered, frozendict) from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension, - ConditionalDimension) + ConditionalDimension, MultiStage) from devito.types.array import Array from devito.types.basic import AbstractFunction from devito.types.dimension import MultiSubDimension, Thickness from devito.data.allocators import DataReference from devito.logger import warning -__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims'] + +__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims'] def dimension_sort(expr): @@ -95,6 +96,39 @@ def handle_indexed(indexed): return ordering +def lower_multistage(expressions, **kwargs): + """ + Separating the multi-stage time-integrator scheme in stages: + * If the object is MultiStage, it creates the stages of the method. + """ + return _lower_multistage(expressions, **kwargs) + + +@singledispatch +def _lower_multistage(expr, **kwargs): + """ + Default handler for expressions that are not MultiStage. + Simply return them in a list. + """ + return [expr] + + +@_lower_multistage.register(MultiStage) +def _(expr, **kwargs): + """ + Specialized handler for MultiStage expressions. + """ + return expr._evaluate(**kwargs) + + +@_lower_multistage.register(Iterable) +def _(exprs, **kwargs): + """ + Handle iterables of expressions. + """ + return sum([_lower_multistage(expr, **kwargs) for expr in exprs], []) + + def lower_exprs(expressions, subs=None, **kwargs): """ Lowering an expression consists of the following passes: diff --git a/devito/operations/solve.py b/devito/operations/solve.py index 0203dbe26d..57b65d0e91 100644 --- a/devito/operations/solve.py +++ b/devito/operations/solve.py @@ -7,6 +7,8 @@ from devito.finite_differences.derivative import Derivative from devito.tools import as_tuple +from devito.types.multistage import resolve_method + __all__ = ['solve', 'linsolve'] @@ -15,7 +17,7 @@ class SolveError(Exception): pass -def solve(eq, target, **kwargs): +def solve(eq, target, method = None, eq_num = 0, **kwargs): """ Algebraically rearrange an Eq w.r.t. a given symbol. @@ -56,9 +58,12 @@ def solve(eq, target, **kwargs): # We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions if len(sols) > 1: - return target.new_from_mat(sols) + sols_temp = target.new_from_mat(sols) else: - return sols[0] + sols_temp = sols[0] + + method = kwargs.get("method", None) + return sols_temp if method is None else resolve_method(method)(target, sols_temp) def linsolve(expr, target, **kwargs): diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 0d473fe6a2..7f5b769a77 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -17,7 +17,7 @@ InvalidOperator) from devito.logger import (debug, info, perf, warning, is_log_enabled_for, switch_log_level) -from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims +from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims from devito.ir.clusters import ClusterGroup, clusterize from devito.ir.iet import (Callable, CInterface, EntryFunction, DeviceFunction, FindSymbols, MetaCall, derive_parameters, iet_build) @@ -40,7 +40,6 @@ disk_layer) from devito.types.dimension import Thickness - __all__ = ['Operator'] @@ -337,6 +336,8 @@ def _lower_exprs(cls, expressions, **kwargs): * Apply substitution rules; * Shift indices for domain alignment. """ + expressions = lower_multistage(expressions, **kwargs) + expand = kwargs['options'].get('expand', True) # Specialization is performed on unevaluated expressions diff --git a/devito/types/__init__.py b/devito/types/__init__.py index 6ec8bdfd16..d669538384 100644 --- a/devito/types/__init__.py +++ b/devito/types/__init__.py @@ -22,3 +22,6 @@ from .relational import * # noqa from .sparse import * # noqa from .tensor import * # noqa + +from .multistage import * # noqa +from .multistage_new import * # noqa diff --git a/devito/types/multistage.py b/devito/types/multistage.py new file mode 100644 index 0000000000..e480ca904e --- /dev/null +++ b/devito/types/multistage.py @@ -0,0 +1,541 @@ +from devito.types.equation import Eq +from devito.types.dense import TimeFunction +from devito.symbolics import uxreplace +import numpy as np +from devito.types.array import Array +from types import MappingProxyType + +method_registry = {} + + +def register_method(cls=None, *, aliases=None): + """ + Register a time integration method class. + + Parameters + ---------- + cls : class, optional + The method class to register. + aliases : list of str, optional + Additional aliases for the method. + """ + def decorator(cls): + # Register the class name + method_registry[cls.__name__] = cls + + # Register any aliases + if aliases: + for alias in aliases: + method_registry[alias] = cls + + return cls + + if cls is None: + # Called as @register_method(aliases=['alias1']) + return decorator + else: + # Called as @register_method + return decorator(cls) + + +def resolve_method(method): + """ + Resolve a time integration method by name. + + Parameters + ---------- + method : str + Name or alias of the time integration method. + + Returns + ------- + class + The method class. + + Raises + ------ + ValueError + If the method is not found in the registry. + """ + try: + return method_registry[method] + except KeyError: + available = sorted(method_registry.keys()) + raise ValueError( + f"The time integrator '{method}' is not implemented. " + f"Available methods: {available}" + ) + + +def multistage_method(lhs, rhs, method, degree=None, source=None, optimized_feature=None): + method_cls = resolve_method(method) + return method_cls(lhs, rhs, degree=degree, source=source, optimized_feature=optimized_feature) + + +class MultiStage(Eq): + """ + Abstract base class for multi-stage time integration methods + (e.g., Runge-Kutta schemes) in Devito. + + This class represents a symbolic equation of the form `target = rhs` + and provides a mechanism to associate it with a time integration + scheme. The specific integration behavior must be implemented by + subclasses via the `_evaluate` method. + + Parameters + ---------- + lhs : expr-like + The left-hand side of the equation, typically a time-updated Function + (e.g., `u.forward`). + rhs : expr-like, optional + The right-hand side of the equation to integrate. Defaults to 0. + subdomain : SubDomain, optional + A subdomain over which the equation applies. + coefficients : dict, optional + Optional dictionary of symbolic coefficients for the integration. + implicit_dims : tuple, optional + Additional dimensions that should be treated implicitly in the equation. + **kwargs : dict + Additional keyword arguments, such as time integration method selection. + + Notes + ----- + Subclasses must override the `_evaluate()` method to return a sequence + of update expressions for each stage in the integration process. + """ + + def __new__(cls, lhs, rhs, degree=None, source=None, optimized_feature=None, **kwargs): + # Normalize to lists first lhs and rhs + if not isinstance(lhs, (list, tuple)): + lhs = [lhs] + if not isinstance(rhs, (list, tuple)): + rhs = [rhs] + + # Convert to tuples for immutability + lhs_tuple = tuple(i.function for i in lhs) + rhs_tuple = tuple(rhs) + + obj = super().__new__(cls, lhs_tuple[0], rhs_tuple[0], **kwargs) + + # Store all equations as immutable tuples + obj._eq = tuple(Eq(lhs, rhs) for lhs, rhs in zip(lhs_tuple, rhs_tuple)) + obj._lhs = lhs_tuple + obj._rhs = rhs_tuple + obj._deg = degree + # Convert source to tuple of tuples for immutability + obj._src = tuple(tuple(item) + for item in source) if source is not None else None + obj._t = lhs_tuple[0].grid.time_dim + obj._dt = obj._t.spacing + obj._optimized_feature = optimized_feature + + return obj + + @property + def eq(self): + """Full tuple of equations""" + return self._eq + + @property + def lhs(self): + """Tuple of left-hand sides""" + return self._lhs + + @property + def rhs(self): + """Tuple of right-hand sides""" + return self._rhs + + @property + def deg(self): + """Degree parameter (e.g., number of stages)""" + return self._deg + + @property + def src(self): + """Source parameter as tuple of tuples (immutable)""" + return self._src + + @property + def t(self): + """Time (t) parameter""" + return self._t + + @property + def dt(self): + """Time step (dt) parameter""" + return self._dt + + @property + def n_eq(self): + """Number of equations""" + return len(self.lhs) + + def _evaluate(self, **kwargs): + raise NotImplementedError( + f"_evaluate() must be implemented in the subclass {self.__class__.__name__}") + + +class TableauRungeKutta(MultiStage): + """ + Base class for explicit Runge-Kutta (RK) time integration methods defined + via a Butcher tableau. + + This class handles the general structure of RK schemes by using + the Butcher coefficients (`a`, `b`, `c`) to expand a single equation into + a series of intermediate stages followed by a final update. Subclasses + must define `a`, `b`, and `c` as class attributes. + + Parameters + ---------- + a : tuple of tuple of float + The coefficient matrix representing stage dependencies. + b : tuple of float + The weights for the final combination step. + c : tuple of float + The time shifts for each intermediate stage (often the row sums of `a`). + + Attributes + ---------- + a : tuple[tuple[float, ...], ...] + Butcher tableau `a` coefficients (stage coupling). + b : tuple[float, ...] + Butcher tableau `b` coefficients (weights for combining stages). + c : tuple[float, ...] + Butcher tableau `c` coefficients (stage time positions). + s : int + Number of stages in the RK method, inferred from `b`. + """ + + CoeffsBC = tuple[float | np.number, ...] + CoeffsA = tuple[CoeffsBC, ...] + + def __init__(self, lhs, rhs, a: CoeffsA = None, b: CoeffsBC = None, + c: CoeffsBC = None, **kwargs) -> None: + self.a = a if a is not None else getattr(self, 'a', None) + self.b = b if b is not None else getattr(self, 'b', None) + self.c = c if c is not None else getattr(self, 'c', None) + + if self.a is None or self.b is None or self.c is None: + raise ValueError("TableauRungeKutta requires coefficients 'a', 'b', and 'c'.") + + @property + def s(self): + return len(self.b) + + def _evaluate(self, **kwargs): + """ + Generate the stage-wise equations for a Runge-Kutta time integration method. + + This method takes a single equation of the form `Eq(u.forward, rhs)` and + expands it into a sequence of intermediate stage evaluations and a final + update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`. + + Returns + ------- + list of Devito Eq objects + A list of SymPy Eq objects representing: + - `s` stage equations of the form `k_i = rhs evaluated at intermediate state` + - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` + """ + + sregistry = kwargs.get('sregistry') + # Create temporary Arrays to hold each stage + + k = [] + for j in range(self.n_eq): + k_j = [] + for _ in range(self.s): + k_name = sregistry.make_name(prefix="k") + k_j.append(TimeFunction(name=k_name, grid=self.lhs[j].grid, + space_order=self.lhs[j].space_order, time_order=0, dtype=self.lhs[j].dtype)) + k.append(k_j) + + stage_eqs = [] + + # Build each stage + for i in range(self.s): + u_temp = [self.lhs[l] + self.dt * sum(aij * kj for aij, kj in zip( + self.a[i][:i], k[l][:i])) for l in range(self.n_eq)] + t_shift = self.t + self.c[i] + + # Evaluate RHS at intermediate value + stage_rhs = [uxreplace(self.rhs[l], {**{self.lhs[m]: u_temp[m] for m in range( + self.n_eq)}, self.t: t_shift}) for l in range(self.n_eq)] + stage_eqs.extend([Eq(k[l][i], stage_rhs[l]) + for l in range(self.n_eq)]) + + # Final update: u = u + dt * sum(b_i * k_i) + u_next = [self.lhs[l] + self.dt * + sum(bi * ki for bi, ki in zip(self.b, k[l])) for l in range(self.n_eq)] + stage_eqs.extend([Eq(self.lhs[l].forward, u_next[l]) + for l in range(self.n_eq)]) + + return stage_eqs + + +@register_method(aliases=['RK44']) +class RungeKutta44(TableauRungeKutta): + """ + Classic 4th-order Runge-Kutta (RK4) time integration method. + + This class implements the classic explicit Runge-Kutta method of order 4 (RK44). + + Attributes + ---------- + a : tuple[tuple[float, ...], ...] + Coefficients of the `a` matrix for intermediate stage coupling. + b : tuple[float, ...] + Weights for final combination. + c : tuple[float, ...] + Time positions of intermediate stages. + """ + a = ((0, 0, 0, 0), + (1/2, 0, 0, 0), + (0, 1/2, 0, 0), + (0, 0, 1, 0)) + b = (1/6, 1/3, 1/3, 1/6) + c = (0, 1/2, 1/2, 1) + +@register_method(aliases=['RK32']) +class RungeKutta32(TableauRungeKutta): + """ + 3 stages 2nd-order Runge-Kutta (RK32) time integration method. + + This class implements the 3-stages explicit Runge-Kutta method of order 2 (RK32). + + Attributes + ---------- + a : list[list[float]] + Coefficients of the `a` matrix for intermediate stage coupling. + b : list[float] + Weights for final combination. + c : list[float] + Time positions of intermediate stages. + """ + a = ((0, 0, 0), + (1/2, 0, 0), + (0, 1/2, 0)) + b = (0, 0, 1) + c = (0, 1/2, 1/2) + +@register_method(aliases=['RK97']) +class RungeKutta97(TableauRungeKutta): + """ + 9 stages 7th-order Runge-Kutta (RK97) time integration method. + + This class implements the 9-stages explicit Runge-Kutta method of order 7 (RK97). + + Attributes + ---------- + a : list[list[float]] + Coefficients of the `a` matrix for intermediate stage coupling. + b : list[float] + Weights for final combination. + c : list[float] + Time positions of intermediate stages. + """ + a = ((0, 0, 0, 0, 0, 0, 0, 0, 0), + (4/63, 0, 0, 0, 0, 0, 0, 0, 0), + (1/42, 1/14, 0, 0, 0, 0, 0, 0, 0), + (1/28, 0, 3/28, 0, 0, 0, 0, 0, 0), + (12551/19652, 0, -48363/19652, 10976/4913, 0, 0, 0, 0, 0), + (-36616931/27869184, 0, 2370277/442368, -255519173 / + 63700992, 226798819/445906944, 0, 0, 0, 0), + (-10401401/7164612, 0, 47383/8748, -4914455 / + 1318761, -1498465/7302393, 2785280/3739203, 0, 0, 0), + (181002080831/17500000000, 0, -14827049601/400000000, 23296401527134463/857600000000000, + 2937811552328081/949760000000000, -243874470411/69355468750, 2857867601589/3200000000000), + (-228380759/19257212, 0, 4828803/113948, -331062132205/10932626912, -12727101935/3720174304, + 22627205314560/4940625496417, -268403949/461033608, 3600000000000/19176750553961)) + b = (95/2366, 0, 0, 3822231133/16579123200, 555164087/2298419200, 1279328256/9538891505, + 5963949/25894400, 50000000000/599799373173, 28487/712800) + c = (0, 4/63, 2/21, 1/7, 7/17, 13/24, 7/9, 91/100, 1) + +@register_method(aliases=['HORK_EXP']) +class HighOrderRungeKuttaExponential(MultiStage): + # In construction + """ + n stages Runge-Kutta (HORK) time integration method. + + This class implements the arbitrary high-order explicit Runge-Kutta method. + + Attributes + ---------- + a : list[list[float]] + Coefficients of the `a` matrix for intermediate stage coupling. + b : list[float] + Weights for final combination. + c : list[float] + Time positions of intermediate stages. + """ + + def source_derivatives(self, src_index, **kwargs): + + # Compute the base wavelet function + f_deriv = [[src[1] for src in self.src]] + + # Compute derivatives up to order p + for _ in range(self.deg - 1): + f_deriv.append([deriv.diff(self.t) for deriv in f_deriv[-1]]) + + f_deriv.reverse() + return f_deriv + + def ssprk_alpha(self, mu=1): + """ + Computes the coefficients for the Strong Stability Preserving Runge-Kutta (SSPRK) method. + + Parameters: + mu : float + Theoretically, it should be the inverse of the CFL condition (typically mu=1 for best performance). + In practice, mu=1 works better. + degree : int + Degree of the polynomial used in the time-stepping scheme. + + Returns: + numpy.ndarray + Array of SSPRK coefficients. + """ + + alpha = [0] * self.deg + alpha[0] = 1.0 # Initial coefficient + + # recurrence relation to compute the HORK coefficients following the formula in Gottlieb and Gottlieb (2002) + for i in range(1, self.deg): + alpha[i] = 1 / (mu * (i + 1)) * alpha[i - 1] + alpha[1:i] = [1 / (mu * j) * alpha[j - 1] for j in range(1, i)] + alpha[0] = 1 - sum(alpha[1:i + 1]) + + return alpha + + def source_inclusion(self, current_state, stage_values, e_p, **integration_params): + """ + Include source terms in the time integration step. + + This method applies source term contributions to the right-hand side + of the differential equations during time integration, accounting for + time derivatives of the source function and expansion coefficients. + + Parameters + ---------- + current_state : list + Current state variables (u). + stage_values : list + Current stage values (k). + e_p : list + Expansion coefficients for stability control. + **integration_params : dict + Integration parameters containing 't', 'dt', 'mu', 'src_index', + 'src_deriv', 'n_eq'. + + Returns + ------- + tuple + (modified_rhs, updated_e_p) - Updated right-hand side + equations and modified expansion coefficients. + """ + # Extract integration parameters + mu = integration_params['mu'] + src_index = integration_params['src_index'] + src_deriv = integration_params['src_deriv'] + n_eq = integration_params['n_eq'] + + # Build base right-hand side by substituting current stage values + src_lhs = [uxreplace(self.rhs[i], {current_state[m]: stage_values[m] for m in range(n_eq)}) + for i in range(n_eq)] + + # Apply source term contributions if sources exist + if self.src is not None: + p = len(src_deriv) + + # Add source contributions for each derivative order + for i in range(p): + if e_p[i] != 0: + for j, idx in enumerate(src_index): + # Add weighted source derivative contribution + source_contribution = (self.src[j][0] * src_deriv[i][j].subs({self.t: self.t * self.dt}) * e_p[i]) + src_lhs[idx] += source_contribution + + # Update expansion coefficients for next stage + e_p = [e_p[i] + mu*self.dt*e_p[i + 1] for i in range(p - 1)] + [e_p[-1]] + + return src_lhs, e_p + + def _evaluate(self, **kwargs): + """ + Generate the stage-wise equations for a Runge-Kutta time integration method. + + This method takes a single equation of the form `Eq(u.forward, rhs)` and + expands it into a sequence of intermediate stage evaluations and a final + update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`. + + Returns + ------- + list of Eq + A list of SymPy Eq objects representing: + - `s` stage equations of the form `k_i = rhs evaluated at intermediate state` + - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` + """ + + sregistry = kwargs.get('sregistry') + # Create a temporary Array for each variable to save the time stages + # k = [Array(name=f'{sregistry.make_name(prefix='k')}', dimensions=u[i].grid.dimensions, grid=u[i].grid, dtype=u[i].dtype) for i in range(n_eq)] + k = [TimeFunction(name=f'{sregistry.make_name(prefix="k")}', grid=self.lhs[i].grid, + space_order=self.lhs[i].space_order, time_order=0, dtype=self.lhs[i].dtype) for i in range(self.n_eq)] + k_old = [TimeFunction(name=f'{sregistry.make_name(prefix="k_old")}', grid=self.lhs[i].grid, + space_order=self.lhs[i].space_order, time_order=0, dtype=self.lhs[i].dtype) for i in range(self.n_eq)] + # Compute SSPRK coefficients + mu = 1 + alpha = self.ssprk_alpha(mu=mu) + + # Initialize symbolic differentiation for source terms + field_map = {val: i for i, val in enumerate(self.lhs)} + if self.src is not None: + src_index = [field_map[src[2]] for src in self.src] + src_deriv = self.source_derivatives(src_index, **kwargs) + else: + src_index = None + src_deriv = None + + # Expansion coefficients for stability control + e_p = [0] * self.deg + eta = 1 + e_p[-1] = 1 / eta + + stage_eqs = [Eq(ki, ui) for ki, ui in zip(k, self.lhs)] + stage_eqs.extend([Eq(lhs_i.forward, lhs_i*alpha[0]) for lhs_i in self.lhs]) + + # Prepare integration parameters for source inclusion + integration_params = {'mu': mu, 'src_index': src_index, + 'src_deriv': src_deriv, 'n_eq': self.n_eq} + + # Build each stage + for i in range(1, self.deg - 1): + # saving stage variables for consistent spatial operator application + stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)]) + + # include source terms approximation in the current stage evaluation + src_lhs, e_p = self.source_inclusion(self.lhs, k_old, e_p, **integration_params) + + # update stage equations with source contributions + stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)]) + + # include the last stage to the final approximation with the corresponding alpha coefficient + stage_eqs.extend([Eq(lhs_j.forward, lhs_j.forward+k_j*alpha[i]) for lhs_j, k_j in zip(self.lhs, k)]) + + # Final Runge-Kutta updates + stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)]) + src_lhs, e_p = self.source_inclusion(self.lhs, k_old, e_p, **integration_params) + stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)]) + + stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)]) + src_lhs, _ = self.source_inclusion(self.lhs, k_old, e_p, **integration_params) + stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)]) + + # Compute final approximation + stage_eqs.extend([Eq(lhs_j.forward, lhs_j.forward+k_j*alpha[self.deg-1]) for lhs_j, k_j in zip(self.lhs, k)]) + + return stage_eqs + +method_registry = MappingProxyType(method_registry) \ No newline at end of file diff --git a/tests/test_multistage.py b/tests/test_multistage.py new file mode 100644 index 0000000000..b47c83d012 --- /dev/null +++ b/tests/test_multistage.py @@ -0,0 +1,503 @@ +import pytest +import numpy as np +import sympy as sym +import tempfile +import pickle +import os + +from devito import (Grid, Function, TimeFunction, + Derivative, Operator, solve, Eq, configuration) +from devito.types.multistage import multistage_method, MultiStage +from devito.ir.support import SymbolRegistry +from devito.ir.equations import lower_multistage + +configuration['log-level'] = 'DEBUG' + + +def grid_parameters(extent=(10, 10), shape=(3, 3)): + grid = Grid(origin=(0, 0), extent=extent, shape=shape, dtype=np.float64) + x, y = grid.dimensions + dt = grid.stepping_dim.spacing + t = grid.time_dim + dx = extent[0] / (shape[0] - 1) + return grid, x, y, dt, t, dx + + +def time_parameters(tn, dx, scale=1, t0=0): + t0, tn = 0.0, tn + dt0 = scale / dx**2 + nt = int((tn - t0) / dt0) + dt0 = tn / nt + return tn, dt0, nt + + +class Test_API: + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_pickles(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system (2D acoustic) + system_eqs_rhs = [u[1] + src_spatial * src_temporal, + Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Class of the time integration scheme + method = multistage_method(u, system_eqs_rhs, time_int) + + with tempfile.NamedTemporaryFile(delete=False) as tmpfile: + pickle.dump(method, tmpfile) + filename = tmpfile.name + + with open(filename, 'rb') as file: + method_saved = pickle.load(file) + os.remove(filename) + + assert str(method) == str( + method_saved), "Mismatch in PDE after pickling" + + op_orig = Operator(method) + op_saved = Operator(method_saved) + + assert str(op_orig) == str(op_saved) + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_solve(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system (2D acoustic) + system_eqs_rhs = [u[1] + src_spatial * src_temporal, + Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Time integration scheme + pdes = [solve(system_eqs_rhs[i] - u[i], u[i], method=time_int) + for i in range(2)] + + assert all(isinstance(i, MultiStage) + for i in pdes), "Not all elements are instances of MultiStage" + + +class Test_lowering: + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_object(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system (2D acoustic) + system_eqs_rhs = [u[1] + src_spatial * src_temporal, + Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Class of the time integration scheme + pdes = multistage_method(u, system_eqs_rhs, time_int) + + assert isinstance( + pdes, MultiStage), "Not all elements are instances of MultiStage" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_lower_multistage(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system (2D acoustic) + system_eqs_rhs = [u[1] + src_spatial * src_temporal, + Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Class of the time integration scheme + pdes = multistage_method(u, system_eqs_rhs, time_int) + + # Test the lowering process + sregistry = SymbolRegistry() + + # Lower the multistage method - this should not raise an exception + lowered_eqs = lower_multistage(pdes, sregistry=sregistry) + + # Validate the lowered equations + assert lowered_eqs is not None, "Lowering returned None" + assert len(lowered_eqs) > 0, "Lowering returned empty list" + + +class Test_RK: + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_single_equation_integration(self, time_int): + """ + Test single equation time integration with MultiStage methods. + + This test verifies that time integration works correctly for the simplest case: + a single PDE with a single unknown function. This represents the most basic + MultiStage usage scenario (e.g., heat equation, simple wave equation). + """ + + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1, 1), shape=(200, 200)) + + # Define single unknown function + u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, + time_order=1, dtype=np.float64) + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # Single PDE: du/dt = ∇²u + source (diffusion/wave equation) + eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + + Derivative(u_multi_stage, (y, 2), fd_order=2) + + src_spatial * src_temporal) + + # Store initial data for comparison + initial_data = u_multi_stage.data.copy() + + # Time integration scheme - single equation MultiStage object + pde = multistage_method(u_multi_stage, eq_rhs, time_int) + + # Run the operator + op = Operator([pde], subs=grid.spacing_map) # Operator expects a list + op(dt=0.01, time=1) + + # Verify that computation actually occurred (data changed) + assert not np.array_equal( + u_multi_stage.data, initial_data), "Data should have changed" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_decoupled_equations(self, time_int): + """ + Test decoupled time integration where each equation gets its own MultiStage object. + + This test verifies that time integration works when creating separate MultiStage + objects for each equation, as opposed to coupled integration where all equations + are handled by a single MultiStage object. + """ + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1, 1), shape=(200, 200)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u_multi_stage', 'v_multi_stage'] + u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, dtype=np.float64) + for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system - each equation independent for decoupled integration + system_eqs_rhs = [u_multi_stage[1] + src_spatial * src_temporal, + Derivative(u_multi_stage[0], (x, 2), fd_order=2) + + Derivative(u_multi_stage[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Store initial data for comparison + initial_data = [u.data.copy() for u in u_multi_stage] + + # Time integration scheme - create separate MultiStage objects (decoupled) + pdes = [multistage_method(u_multi_stage[i], system_eqs_rhs[i], time_int) + for i in range(len(fun_labels))] + + # Run the operator + op = Operator(pdes, subs=grid.spacing_map) + op(dt=0.01, time=1) + + # Verify that computation actually occurred (data changed) + for i, u in enumerate(u_multi_stage): + assert not np.array_equal( + u.data, initial_data[i]), f"Data should have changed for variable {i}" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_coupled_op_computing(self, time_int): + """ + Test coupled time integration where all equations are handled by a single MultiStage object. + + This test verifies that time integration works correctly when multiple coupled equations + are integrated together within a single MultiStage object, allowing for proper coupling + between the equations during the time stepping process. + """ + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1, 1), shape=(200, 200)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u_multi_stage', 'v_multi_stage'] + u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system - coupled acoustic wave equations + system_eqs_rhs = [u_multi_stage[1], # velocity equation: du/dt = v + Derivative(u_multi_stage[0], (x, 2), fd_order=2) + + Derivative(u_multi_stage[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] # displacement equation: dv/dt = ∇²u + source + + # Store initial data for comparison + initial_data = [u.data.copy() for u in u_multi_stage] + + # Time integration scheme - single coupled MultiStage object + pdes = multistage_method(u_multi_stage, system_eqs_rhs, time_int) + + # Run the operator + op = Operator(pdes, subs=grid.spacing_map) + op(dt=0.01, time=1) + + # Verify that computation actually occurred (data changed) + for i, u in enumerate(u_multi_stage): + assert not np.array_equal( + u.data, initial_data[i]), f"Data should have changed for variable {i}" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_low_order_convergence_ODE(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(10, 10), shape=(3, 3)) + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[:] = 1 + src_temporal = 2 * t * dt + + # Time axis + tn, dt0, nt = time_parameters(3.0, dx, scale=1e-2) + + # Time integrator solution + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u_multi_stage = [TimeFunction(name=name + '_multi_stage', grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + eq_rhs = [ + (-1.5 * u_multi_stage[0] + 0.5 * u_multi_stage[1]) * src_spatial * src_temporal, + (-1.5 * u_multi_stage[1] + 0.5 * u_multi_stage[0]) * src_spatial * src_temporal] + u_multi_stage[0].data[0, :] = 1 + + # Time integration scheme + pdes = multistage_method(u_multi_stage, eq_rhs, time_int) + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + # exact solution + d = np.array([-1, -2]) + a = np.array([[1, 1], [1, -1]]) + exact_sol = np.dot( + np.dot(a, np.diag(np.exp(d * tn**2))), np.linalg.inv(a)) + assert np.max(np.abs(exact_sol[0, 0] - u_multi_stage[0].data[0, :]) + ) < 10 ** -5, "the method is not converging to the solution" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_low_order_convergence(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1000, 1000), shape=(201, 201)) + + # Medium velocity model + vel = Function(name=f"vel_{time_int}", + grid=grid, space_order=2, dtype=np.float64) + vel.data[:] = 1.0 + vel.data[150:, :] = 1.3 + + # Source definition + src_spatial = Function( + name=f"src_spat_{time_int}", grid=grid, space_order=2, dtype=np.float64) + src_spatial.data[100, 100] = 1 / dx**2 + f0 = 0.01 + src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1 / f0))**2) * \ + sym.exp(-(np.pi * f0 * (t * dt - 1 / f0))**2) + + # Time axis + tn, dt0, nt = time_parameters(500.0, dx, scale=1e-1 * np.max(vel.data)) + + # Time integrator solution + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u_multi_stage = [ + TimeFunction(name=f"{name}_multi_stage_{time_int}", grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + eq_rhs = [u_multi_stage[1], (Derivative(u_multi_stage[0], (x, 2), fd_order=2) + + Derivative(u_multi_stage[0], (y, 2), fd_order=2) + + src_spatial * src_temporal) * vel**2] + + # Time integration scheme + pdes = multistage_method(u_multi_stage, eq_rhs, time_int) + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + # Devito's default solution + u = [TimeFunction(name=f"{name}_{time_int}", grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + eq_rhs = [u[1], (Derivative(u[0], (x, 2), fd_order=2) + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal) * vel**2] + + # Time integration scheme + pdes = [Eq(u[i].forward, solve(Eq(u[i].dt - eq_rhs[i]), u[i].forward)) + for i in range(len(fun_labels))] + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + assert (np.linalg.norm(u[0].data[0, :] - u_multi_stage[0].data[0, :]) / np.linalg.norm( + u[0].data[0, :])) < 10**-1, "the method is not converging to the solution" + + +class Test_HORK: + + def test_coupled_op_computing_exp(self, time_int='HORK_EXP'): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1, 1), shape=(201, 201)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u_multi_stage', 'v_multi_stage'] + u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=np.float64) + src_spatial.data[100, 100] = 1 + src_temporal = sym.exp(- 100 * (t - 0.01) ** 2) + + # PDE system + system_eqs_rhs = [u_multi_stage[1], + Derivative(u_multi_stage[0], (x, 2), fd_order=2) + + Derivative(u_multi_stage[0], (y, 2), fd_order=2)] + + # Store initial data for comparison + initial_data = [u.data.copy() for u in u_multi_stage] + + src = [[src_spatial, src_temporal, u_multi_stage[0]], + [src_spatial, src_temporal * 10, u_multi_stage[0]], + [src_spatial, src_temporal, u_multi_stage[1]]] + + # Time integration scheme + pdes = multistage_method( + u_multi_stage, system_eqs_rhs, time_int, degree=4, source=src) + op = Operator(pdes, subs=grid.spacing_map) + op(dt=0.001, time=2000) + + # Verify that computation actually occurred (data changed) + for i, u in enumerate(u_multi_stage): + assert not np.array_equal( + u.data, initial_data[i]), f"Data should have changed for variable {i}" + + + @pytest.mark.parametrize('degree', list(range(3, 11))) + def test_HORK_EXP_convergence(self, degree): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1000, 1000), shape=(201, 201)) + + # Medium velocity model + vel = Function(name="vel", grid=grid, space_order=2, dtype=np.float64) + vel.data[:] = 1.0 + vel.data[150:, :] = 1.3 + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[100, 100] = 1 / dx**2 + f0 = 0.01 + src_temporal = (1 - 2 * (np.pi * f0 * (t - 1 / f0))**2) * sym.exp(-(np.pi * f0 * (t - 1 / f0))**2) + + # Time axis + tn, dt0, nt = time_parameters(500.0, dx, scale=np.max(vel.data)) + + # Time integrator solution + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u_sol', 'v_sol'] + u_multi_stage = [TimeFunction(name=name + '_multi_stage', grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + eq_rhs = [u_multi_stage[1], (Derivative(u_multi_stage[0],(x,2), fd_order=2) + Derivative( + u_multi_stage[0], (y,2), fd_order=2)) * vel**2] + + src = [[src_spatial * vel**2, src_temporal, u_multi_stage[1]]] + + # Time integration scheme + pdes = multistage_method( + u_multi_stage, eq_rhs, 'HORK_EXP', source=src, degree=degree) + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + # Devito's default solution + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1 / f0))**2) * sym.exp(-(np.pi * f0 * (t * dt - 1 / f0))**2) + eq_rhs = [u[1], (Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal) * vel**2] + + # Time integration scheme + pdes = [Eq(u[i].forward, solve(Eq(u[i].dt - eq_rhs[i]), u[i].forward)) + for i in range(len(fun_labels))] + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + assert (np.linalg.norm(u[0].data[0, :] - u_multi_stage[0].data[0, :]) / np.linalg.norm( + u[0].data[0, :])) < 10**-1, "the method is not converging to the solution" \ No newline at end of file