From b7cb532634afab25041703e68382d9e284c376ef Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:28:37 -0700 Subject: [PATCH 1/7] reference --- diffrax/__init__.py | 2 ++ diffrax/_brownian/path.py | 33 ++++++++++++++++++++++++--------- diffrax/_custom_types.py | 30 ++++++++++++++++++++++++------ 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 67b4ca50..75fa14e4 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -17,9 +17,11 @@ AbstractBrownianIncrement as AbstractBrownianIncrement, AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea, AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea, + AbstractWeakSpaceSpaceLevyArea as AbstractWeakSpaceSpaceLevyArea, BrownianIncrement as BrownianIncrement, SpaceTimeLevyArea as SpaceTimeLevyArea, SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea, + WeakSpaceSpaceLevyArea as WeakSpaceSpaceLevyArea, ) from ._event import ( # Deliberately not provided with `X as X` as these are now deprecated, so we'd like diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 0333caa5..be46b4be 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -18,6 +18,7 @@ RealScalarLike, SpaceTimeLevyArea, SpaceTimeTimeLevyArea, + WeakSpaceSpaceLevyArea, ) from .._misc import ( force_bitcast_convert_type, @@ -27,6 +28,11 @@ from .base import AbstractBrownianPath +_Levy_Areas = Union[ + BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea, WeakSpaceSpaceLevyArea +] + + class UnsafeBrownianPath(AbstractBrownianPath): """Brownian simulation that is only suitable for certain cases. @@ -62,18 +68,14 @@ class UnsafeBrownianPath(AbstractBrownianPath): """ shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) - levy_area: type[ - Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea] - ] = eqx.field(static=True) + levy_area: type[_Levy_Areas] = eqx.field(static=True) key: PRNGKeyArray def __init__( self, shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], key: PRNGKeyArray, - levy_area: type[ - Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea] - ] = BrownianIncrement, + levy_area: type[_Levy_Areas] = BrownianIncrement, ): self.shape = ( jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype()) @@ -141,9 +143,7 @@ def _evaluate_leaf( t1: RealScalarLike, key, shape: jax.ShapeDtypeStruct, - levy_area: type[ - Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea] - ], + levy_area: type[_Levy_Areas], use_levy: bool, ): w_std = jnp.sqrt(t1 - t0).astype(shape.dtype) @@ -157,6 +157,21 @@ def _evaluate_leaf( kk_std = w_std / math.sqrt(720) kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk) + + elif levy_area is WeakSpaceSpaceLevyArea: + # TODO: add doc/reference + key_w, key_hh, key_b = jr.split(key, 3) + w = jr.normal(key_w, shape.shape, shape.dtype) * w_std + hh_std = w_std / math.sqrt(12) + hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std + levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh) + b_std = dt / jnp.sqrt(12) + # TODO: fix for more general shapes + assert len(shape.shape) != 1, "Must be 1D array Wiener process" + b = jr.normal(key_b, shape.shape + shape.shape, shape.dtype) * b_std + b = jnp.tril(b) - jnp.tril(b).T + a = jnp.outer(hh, w) - jnp.outer(w, hh) + b + levy_val = WeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) elif levy_area is SpaceTimeLevyArea: key_w, key_hh = jr.split(key, 2) diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 70ec5a1a..e11f25c8 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -55,7 +55,7 @@ sentinel: Any = eqxi.doc_repr(object(), "sentinel") -class AbstractBrownianIncrement(eqx.Module): +class AbstractBrownianIncrement(eqx.Module, strict=True): """ Abstract base class for all Brownian increments. """ @@ -64,7 +64,7 @@ class AbstractBrownianIncrement(eqx.Module): W: eqx.AbstractVar[BM] -class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement): +class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement, strict=True): """ Abstract base class for all Space Time Levy Areas. """ @@ -72,7 +72,25 @@ class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement): H: eqx.AbstractVar[BM] -class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea): +class AbstractWeakSpaceSpaceLevyArea(AbstractBrownianIncrement, strict=True): + """ + Abstract base class for all weak Space Space Levy Areas. + """ + + H: eqx.AbstractVar[BM] + A: eqx.AbstractVar[BM] + + +class WeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea, strict=True): + """ + Abstract base class for all weak Space Space Levy Areas. + """ + + H: BM + A: BM + + +class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea, strict=True): """ Abstract base class for all Space Time Time Levy Areas. """ @@ -80,7 +98,7 @@ class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea): K: eqx.AbstractVar[BM] -class BrownianIncrement(AbstractBrownianIncrement): +class BrownianIncrement(AbstractBrownianIncrement, strict=True): """ Pytree containing the `dt` time increment and `W` the Brownian motion. """ @@ -89,7 +107,7 @@ class BrownianIncrement(AbstractBrownianIncrement): W: BM -class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea): +class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea, strict=True): """ Pytree containing the `dt` time increment, `W` the Brownian motion, and `H` the Space Time Levy Area. @@ -100,7 +118,7 @@ class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea): H: BM -class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea): +class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea, strict=True): """ Pytree containing the `dt` time increment, `W` the Brownian motion, `H` the Space Time Levy Area, and `K` the Space Time Time Levy Area. From 7cd6d5edc59fd26ba1418905372f6fa366db06d0 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:28:42 -0700 Subject: [PATCH 2/7] weak --- diffrax/_brownian/path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index be46b4be..2ec5afa6 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -159,7 +159,7 @@ def _evaluate_leaf( levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk) elif levy_area is WeakSpaceSpaceLevyArea: - # TODO: add doc/reference + # See (7.4.1) of Foster's thesis key_w, key_hh, key_b = jr.split(key, 3) w = jr.normal(key_w, shape.shape, shape.dtype) * w_std hh_std = w_std / math.sqrt(12) From a54fdd1bc4be250e84a906c7cd92d5aa76a73b9e Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:29:01 -0700 Subject: [PATCH 3/7] format --- diffrax/_brownian/path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 2ec5afa6..e756aa94 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -157,7 +157,7 @@ def _evaluate_leaf( kk_std = w_std / math.sqrt(720) kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk) - + elif levy_area is WeakSpaceSpaceLevyArea: # See (7.4.1) of Foster's thesis key_w, key_hh, key_b = jr.split(key, 3) From 770cb05b4af579b6a3bbaa195bf35c00364a3c5e Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:47:51 -0700 Subject: [PATCH 4/7] format --- diffrax/_brownian/path.py | 5 +---- diffrax/_custom_types.py | 20 +++++++++++--------- test/test_brownian.py | 7 ++++++- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index e756aa94..76a2526b 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -159,18 +159,15 @@ def _evaluate_leaf( levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk) elif levy_area is WeakSpaceSpaceLevyArea: - # See (7.4.1) of Foster's thesis key_w, key_hh, key_b = jr.split(key, 3) w = jr.normal(key_w, shape.shape, shape.dtype) * w_std hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh) b_std = dt / jnp.sqrt(12) - # TODO: fix for more general shapes - assert len(shape.shape) != 1, "Must be 1D array Wiener process" b = jr.normal(key_b, shape.shape + shape.shape, shape.dtype) * b_std b = jnp.tril(b) - jnp.tril(b).T - a = jnp.outer(hh, w) - jnp.outer(w, hh) + b + a = jnp.tensordot(hh, w, axes=0) - jnp.tensordot(w, hh, axes=0) + b levy_val = WeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) elif levy_area is SpaceTimeLevyArea: diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index e11f25c8..775d3745 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -55,7 +55,7 @@ sentinel: Any = eqxi.doc_repr(object(), "sentinel") -class AbstractBrownianIncrement(eqx.Module, strict=True): +class AbstractBrownianIncrement(eqx.Module): """ Abstract base class for all Brownian increments. """ @@ -64,7 +64,7 @@ class AbstractBrownianIncrement(eqx.Module, strict=True): W: eqx.AbstractVar[BM] -class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement, strict=True): +class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement): """ Abstract base class for all Space Time Levy Areas. """ @@ -72,7 +72,7 @@ class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement, strict=True): H: eqx.AbstractVar[BM] -class AbstractWeakSpaceSpaceLevyArea(AbstractBrownianIncrement, strict=True): +class AbstractWeakSpaceSpaceLevyArea(AbstractBrownianIncrement): """ Abstract base class for all weak Space Space Levy Areas. """ @@ -81,16 +81,18 @@ class AbstractWeakSpaceSpaceLevyArea(AbstractBrownianIncrement, strict=True): A: eqx.AbstractVar[BM] -class WeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea, strict=True): +class WeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea): """ - Abstract base class for all weak Space Space Levy Areas. + Davie's approximation to weak Space Space Levy Areas. + See (7.4.1) of Foster's thesis. """ + dt: PyTree[FloatScalarLike, "BM"] H: BM A: BM -class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea, strict=True): +class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea): """ Abstract base class for all Space Time Time Levy Areas. """ @@ -98,7 +100,7 @@ class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea, strict=True): K: eqx.AbstractVar[BM] -class BrownianIncrement(AbstractBrownianIncrement, strict=True): +class BrownianIncrement(AbstractBrownianIncrement): """ Pytree containing the `dt` time increment and `W` the Brownian motion. """ @@ -107,7 +109,7 @@ class BrownianIncrement(AbstractBrownianIncrement, strict=True): W: BM -class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea, strict=True): +class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea): """ Pytree containing the `dt` time increment, `W` the Brownian motion, and `H` the Space Time Levy Area. @@ -118,7 +120,7 @@ class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea, strict=True): H: BM -class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea, strict=True): +class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea): """ Pytree containing the `dt` time increment, `W` the Brownian motion, `H` the Space Time Levy Area, and `K` the Space Time Time Levy Area. diff --git a/test/test_brownian.py b/test/test_brownian.py index 3a265019..c680a9e4 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -36,9 +36,14 @@ def _make_struct(shape, dtype): @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) -@pytest.mark.parametrize("levy_area", _levy_areas) +@pytest.mark.parametrize("levy_area", _levy_areas + (diffrax.WeakSpaceSpaceLevyArea,)) @pytest.mark.parametrize("use_levy", (False, True)) def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): + if ( + levy_area is diffrax.WeakSpaceSpaceLevyArea + and ctr is diffrax.VirtualBrownianTree + ): + return t0 = 0.0 t1 = 2.0 From 96972115ce6539b5be038d2137de5ebed7c3da76 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:20:29 -0700 Subject: [PATCH 5/7] a --- diffrax/_brownian/path.py | 7 ++++++- diffrax/_custom_types.py | 4 +++- test/test_brownian.py | 7 +------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 76a2526b..043c064f 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -166,7 +166,12 @@ def _evaluate_leaf( levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh) b_std = dt / jnp.sqrt(12) b = jr.normal(key_b, shape.shape + shape.shape, shape.dtype) * b_std - b = jnp.tril(b) - jnp.tril(b).T + if b.ndim == 0 or b.size == 1: + b = jnp.zeros(shape=shape.shape + shape.shape, dtype=shape.dtype) + else: + # TODO: generalize to tensors? + assert b.ndim == 2 + b = jnp.tril(b) - jnp.tril(b).T a = jnp.tensordot(hh, w, axes=0) - jnp.tensordot(w, hh, axes=0) + b levy_val = WeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 775d3745..692f538d 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -48,6 +48,7 @@ Args = PyTree[Any] BM = PyTree[Shaped[ArrayLike, "?*bm"], "BM"] +Area = PyTree[Shaped[ArrayLike, "?*area"], "Area"] DenseInfo = dict[str, PyTree[Array]] DenseInfos = dict[str, PyTree[Shaped[Array, "times-1 ..."]]] @@ -88,8 +89,9 @@ class WeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea): """ dt: PyTree[FloatScalarLike, "BM"] + W: BM H: BM - A: BM + A: Area class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea): diff --git a/test/test_brownian.py b/test/test_brownian.py index c680a9e4..3a265019 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -36,14 +36,9 @@ def _make_struct(shape, dtype): @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) -@pytest.mark.parametrize("levy_area", _levy_areas + (diffrax.WeakSpaceSpaceLevyArea,)) +@pytest.mark.parametrize("levy_area", _levy_areas) @pytest.mark.parametrize("use_levy", (False, True)) def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): - if ( - levy_area is diffrax.WeakSpaceSpaceLevyArea - and ctr is diffrax.VirtualBrownianTree - ): - return t0 = 0.0 t1 = 2.0 From 68b750b03ad12bb49291e98018044436e150ff30 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 20 Aug 2024 23:25:56 -0700 Subject: [PATCH 6/7] a --- diffrax/_brownian/path.py | 1 - 1 file changed, 1 deletion(-) diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 043c064f..e14e70b0 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -163,7 +163,6 @@ def _evaluate_leaf( w = jr.normal(key_w, shape.shape, shape.dtype) * w_std hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std - levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh) b_std = dt / jnp.sqrt(12) b = jr.normal(key_b, shape.shape + shape.shape, shape.dtype) * b_std if b.ndim == 0 or b.size == 1: From 8f4b4cc3fe1d5728385af30ca95890da4f094cda Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Fri, 23 Aug 2024 16:23:23 -0700 Subject: [PATCH 7/7] add foster --- diffrax/__init__.py | 3 +- diffrax/_brownian/path.py | 64 +++++++++++++++++++++++++++++++-------- diffrax/_custom_types.py | 14 ++++++++- diffrax/_integrate.py | 2 +- test/test_brownian.py | 16 +++++++++- 5 files changed, 83 insertions(+), 16 deletions(-) diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 75fa14e4..5f055a2c 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -19,9 +19,10 @@ AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea, AbstractWeakSpaceSpaceLevyArea as AbstractWeakSpaceSpaceLevyArea, BrownianIncrement as BrownianIncrement, + DavieFosterWeakSpaceSpaceLevyArea as DavieFosterWeakSpaceSpaceLevyArea, + DavieWeakSpaceSpaceLevyArea as DavieWeakSpaceSpaceLevyArea, SpaceTimeLevyArea as SpaceTimeLevyArea, SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea, - WeakSpaceSpaceLevyArea as WeakSpaceSpaceLevyArea, ) from ._event import ( # Deliberately not provided with `X as X` as these are now deprecated, so we'd like diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index e14e70b0..b2d7f0bc 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -14,11 +14,12 @@ from .._custom_types import ( AbstractBrownianIncrement, BrownianIncrement, + DavieFosterWeakSpaceSpaceLevyArea, + DavieWeakSpaceSpaceLevyArea, levy_tree_transpose, RealScalarLike, SpaceTimeLevyArea, SpaceTimeTimeLevyArea, - WeakSpaceSpaceLevyArea, ) from .._misc import ( force_bitcast_convert_type, @@ -29,7 +30,11 @@ _Levy_Areas = Union[ - BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea, WeakSpaceSpaceLevyArea + BrownianIncrement, + SpaceTimeLevyArea, + SpaceTimeTimeLevyArea, + DavieWeakSpaceSpaceLevyArea, + DavieFosterWeakSpaceSpaceLevyArea, ] @@ -158,21 +163,56 @@ def _evaluate_leaf( kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk) - elif levy_area is WeakSpaceSpaceLevyArea: + elif levy_area is DavieWeakSpaceSpaceLevyArea: + key_w, key_hh, key_b = jr.split(key, 3) + w = jr.normal(key_w, shape.shape, shape.dtype) * w_std + hh_std = w_std / math.sqrt(12) + hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std + if w.ndim == 0 or w.ndim == 1: + a = jnp.zeros_like(w, dtype=shape.dtype) + levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) + else: + b_std = (dt / jnp.sqrt(12)).astype(shape.dtype) + b = ( + jr.normal(key_b, shape.shape + shape.shape[-1:], shape.dtype) + * b_std + ) + b = b - b.transpose(*range(b.ndim - 2), -1, -2) + a = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims( + w, -1 + ) * jnp.expand_dims(hh, -2) + a += b + levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) + + elif levy_area is DavieFosterWeakSpaceSpaceLevyArea: key_w, key_hh, key_b = jr.split(key, 3) w = jr.normal(key_w, shape.shape, shape.dtype) * w_std hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std - b_std = dt / jnp.sqrt(12) - b = jr.normal(key_b, shape.shape + shape.shape, shape.dtype) * b_std - if b.ndim == 0 or b.size == 1: - b = jnp.zeros(shape=shape.shape + shape.shape, dtype=shape.dtype) + if w.ndim == 0 or w.ndim == 1: + a = jnp.zeros_like(w, dtype=shape.dtype) + levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) else: - # TODO: generalize to tensors? - assert b.ndim == 2 - b = jnp.tril(b) - jnp.tril(b).T - a = jnp.tensordot(hh, w, axes=0) - jnp.tensordot(w, hh, axes=0) + b - levy_val = WeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) + tenth_dt = (0.1 * dt).astype(shape.dtype) + hh_squared = hh**2 + b_std = jnp.sqrt( + tenth_dt + * ( + tenth_dt + + jnp.expand_dims(hh_squared, -1) + + jnp.expand_dims(hh_squared, -2) + ) + ).astype(shape.dtype) + b = ( + jr.normal(key_b, shape.shape + shape.shape[-1:], shape.dtype) + * b_std + ) + b = b - b.transpose(*range(b.ndim - 2), -1, -2) + a = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims( + w, -1 + ) * jnp.expand_dims(hh, -2) + a += b + levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) elif levy_area is SpaceTimeLevyArea: key_w, key_hh = jr.split(key, 2) diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 692f538d..8bfe16f1 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -82,7 +82,7 @@ class AbstractWeakSpaceSpaceLevyArea(AbstractBrownianIncrement): A: eqx.AbstractVar[BM] -class WeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea): +class DavieWeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea): """ Davie's approximation to weak Space Space Levy Areas. See (7.4.1) of Foster's thesis. @@ -94,6 +94,18 @@ class WeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea): A: Area +class DavieFosterWeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea): + """ + Davie's approximation to weak Space Space Levy Areas. + See (7.4.2) of Foster's thesis. + """ + + dt: PyTree[FloatScalarLike, "BM"] + W: BM + H: BM + A: Area + + class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea): """ Abstract base class for all Space Time Time Levy Areas. diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 938eee37..a278a7a0 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -1054,7 +1054,7 @@ def _promote(yi): if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): # Specific check to not work even if using HalfSolver(Euler()) if isinstance(solver, Euler): - raise ValueError( + warnings.warn( "An SDE should not be solved with adaptive step sizes with Euler's " "method, as it may not converge to the correct solution." ) diff --git a/test/test_brownian.py b/test/test_brownian.py index 3a265019..978cb49d 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -1,3 +1,7 @@ +import jax + + +jax.config.update("jax_enable_x64", True) import contextlib import math from typing import Literal @@ -36,12 +40,22 @@ def _make_struct(shape, dtype): @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) -@pytest.mark.parametrize("levy_area", _levy_areas) +@pytest.mark.parametrize( + "levy_area", + _levy_areas + + (diffrax.DavieWeakSpaceSpaceLevyArea, diffrax.DavieFosterWeakSpaceSpaceLevyArea), +) @pytest.mark.parametrize("use_levy", (False, True)) def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): t0 = 0.0 t1 = 2.0 + if ( + issubclass(levy_area, diffrax.AbstractWeakSpaceSpaceLevyArea) + and ctr is diffrax.VirtualBrownianTree + ): + return + shapes_dtypes1 = ( ((), None), ((0,), None),