diff --git a/pyproject.toml b/pyproject.toml index 6d0b8c2..49a2a0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ name = "sequence_layers" description = "Sequence Layers neural network layer library from Google." readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.13" license = {file = "LICENSE"} authors = [ {name = "RJ Skerry-Ryan", email="rjryan@google.com"}, @@ -24,7 +24,7 @@ dependencies = [ "jaxtyping", "numpy", "orbax-export", - "recurrentgemma[jax]", + "recurrentgemma[jax]>=1.0.1", "typeguard==2.13.3", ] @@ -50,13 +50,15 @@ mlx = [ "mlx", ] dev = [ - "absl-py", + "absl-py>=2.4.0", "chex", "orbax", + "isort", + "pyink", + "pylint>=2.6.0", + "pyrefly>=0.58.0", "pytest", "pytest-xdist", - "pylint>=2.6.0", - "pyink", "tensorflow", # JAX tests use TensorFlow. ] @@ -67,6 +69,24 @@ unstable = true pyink-indentation = 2 pyink-use-majority-quotes = true +[tool.isort] +profile = "google" +line_length = 80 + +[tool.pylint.master] +extension-pkg-whitelist = ["mlx", "mlx.core"] + +[tool.pylint.format] +indent-string = " " + +[tool.pylint.basic] +no-docstring-rgx = "^(_)?test_|^.*Test$|^__.*__$" + +[tool.pylint.messages_control] +disable = ["too-many-lines", "too-many-ancestors", "too-few-public-methods", "duplicate-code"] + + + [build-system] # Build system specify which backend is used to build/install the project (flit, # poetry, setuptools,...). 
All backends are supported by `pip install` @@ -81,4 +101,7 @@ exclude = [ # Do not release test files on PyPI "**/*_test.py", "testdata/**", -] \ No newline at end of file +] + +[tool.pyrefly] +errors = { missing-override-decorator = "error" } \ No newline at end of file diff --git a/sequence_layers/abstract/types.py b/sequence_layers/abstract/types.py deleted file mode 100644 index 4eaa3e9..0000000 --- a/sequence_layers/abstract/types.py +++ /dev/null @@ -1,297 +0,0 @@ -"""Abstract base classes and types for SequenceLayers.""" - -import abc -import enum -import fractions -from typing import Any, Callable, Generic, Iterable, Literal, TypeVar - -import numpy as np - -# Type aliases for generic usage -T = TypeVar('T') -ValuesT = TypeVar('ValuesT') -MaskT = TypeVar('MaskT') -SequenceSelf = TypeVar('SequenceSelf', bound='Sequence') -Shape = tuple[int, ...] -ShapeLike = list[int] | tuple[int, ...] -DType = Any # Can be numpy, jax, or mlx dtype -ChannelSpec = Any # Typically ShapeDType or compatible - -class PaddingMode(enum.Enum): - """Supported padding modes.""" - - # In VALID padding mode, no padding is applied. - # - # Key properties: - # * The physical length of an input array to a VALID padded function shrinks, - # dropping any timesteps whose inputs are computed from implicit edge - # padding. - # * An output timestep is valid when all of its input timesteps are also - # valid. - VALID = 'valid' - - # In SAME padding mode, the input sequence is padded such that the output - # length is equal to the input length before applying striding. - # - # Key properties: - # * The input length is equal to the output length, before applying striding. - # * Padding of `effective_kernel_size - 1` is applied. Half is applied to the - # front and half to the back. If `effective_kernel_size` is even, the extra - # padding is added to the end. - # * An output timestep is valid when its corresponding input timestep is - # valid. 
- SAME = 'same' - - # In CAUSAL_VALID padding mode, the input sequence is padded such that the - # output length is equal to the input length before applying striding. Padding - # is applied such that the output timestep `to` can only depend on input - # timesteps at or before `ti` where `ti * output_ratio = to`. - # - # Key properties: - # * As in SAME padding, the input length is equal to the output length, before - # applying striding. - # * Padding of `effective_kernel_size - 1` is applied to the front of the - # sequence. - # * As in VALID padding, an output timestep is valid iff all of its input - # timesteps are also valid. - CAUSAL_VALID = 'causal_valid' - - # In REVERSE_CAUSAL_VALID padding mode, the input sequence is padded such that - # the output length is equal to the input length before applying striding. - # Padding is applied such that the output timestep `to` can only depend on - # input timesteps at or after `ti` where `ti * output_ratio = to`. - # - # Key properties: - # * As in SAME padding, the input length is equal to the output length, before - # applying striding. - # * Padding of `effective_kernel_size - 1` is applied to the back of the - # sequence. - REVERSE_CAUSAL_VALID = 'reverse_causal_valid' - - # In CAUSAL padding mode, the input sequence is padded such that the output - # length is equal to the input length before applying striding. Padding is - # applied such that the output timestep `to` can only depend on input - # timesteps at or before `ti` where `ti * output_ratio = to`. - # - # Key properties: - # * As in SAME padding, the input length is equal to the output length, before - # applying striding. - # * Padding of `effective_kernel_size - 1` is applied to the front of the - # sequence. - # * As in SAME padding, an output timestep is valid when its corresponding - # input timestep is valid. 
- CAUSAL = 'causal' - - # In REVERSE_CAUSAL padding mode, the input sequence is padded such that the - # output length is equal to the input length before applying striding. Padding - # is applied such that the output timestep `to` can only depend on input - # timesteps at or after `ti` where `ti * output_ratio = to`. - # - # Key properties: - # * As in SAME padding, the input length is equal to the output length, before - # applying striding. - # * Padding of `effective_kernel_size - 1` is applied to the back of the - # sequence. - # * As in SAME padding, an output timestep is valid when its corresponding - # input timestep is valid. - REVERSE_CAUSAL = 'reverse_causal' - - # In SEMICAUSAL padding mode, the input sequence is padded such that the - # output length is equal to the input length before applying striding. Padding - # is applied such that the output timestep `to` can only depend on input - # timesteps at or before `ti` where `ti * output_ratio = to`. - # - # Key properties: - # * As in SAME padding, the input length is equal to the output length, before - # applying striding. - # * Padding of `effective_kernel_size - stride` is applied to the front of the - # sequence, and padding of `stride - 1` timesteps is applied to the back of - # the sequence for a total of `effective_kernel_size - 1` timesteps of - # padding. If `effective_kernel_size` < `stride`, then padding of - # `effective_kernel_size - 1` is applied to the back of the sequence. - # * As in SAME padding, an output timestep is valid when its corresponding - # input timestep is valid. - SEMICAUSAL = 'semicausal' - - # In SEMICAUSAL_FULL padding mode, the input sequence is padded such that the - # output of the corresponding overlap-add or transpose convolution is of the - # same size as the input sequence and perfect reconstruction can be achieved. - # The reconstructed signal is of the same length or of length rounded up to - # cover the full input sequence. 
- SEMICAUSAL_FULL = 'semicausal_full' - -PaddingModeString = Literal[ - 'valid', - 'same', - 'causal_valid', - 'reverse_causal_valid', - 'causal', - 'reverse_causal', - 'semicausal', - 'semicausal_full', -] - -class Sequence(Generic[ValuesT, MaskT], metaclass=abc.ABCMeta): - """Abstract base class for Sequence.""" - - values: ValuesT - mask: MaskT - - @property - @abc.abstractmethod - def shape(self) -> Shape: - pass - - @property - @abc.abstractmethod - def ndim(self) -> int: - pass - - @property - @abc.abstractmethod - def channel_shape(self) -> Shape: - pass - - @property - @abc.abstractmethod - def dtype(self) -> DType: - pass - - @classmethod - @abc.abstractmethod - def from_values(cls, values: ValuesT) -> 'Sequence': - pass - - @classmethod - @abc.abstractmethod - def concatenate_sequences(cls, sequences: Iterable['Sequence']) -> 'Sequence': - pass - - @abc.abstractmethod - def expanded_mask(self) -> Any: - pass - - @abc.abstractmethod - def apply_values( - self, - values_fn: Callable[..., ValuesT], - *args, - **kwargs, - ) -> 'Sequence': - pass - - @abc.abstractmethod - def apply_values_masked( - self: SequenceSelf, - values_fn: Callable[..., ValuesT], - *args, - **kwargs, - ) -> SequenceSelf: - pass - - @abc.abstractmethod - def apply( - self, - apply_fn: Callable[..., tuple[ValuesT, MaskT]], - *args, - **kwargs, - ) -> 'Sequence': - pass - - @abc.abstractmethod - def apply_masked( - self: SequenceSelf, - apply_fn: Callable[..., tuple[ValuesT, MaskT]], - *args, - **kwargs, - ) -> SequenceSelf: - pass - - @abc.abstractmethod - def astype(self: SequenceSelf, dtype: DType | None) -> SequenceSelf: - pass - - @abc.abstractmethod - def lengths(self) -> Any: - pass - - @abc.abstractmethod - def __getitem__(self: SequenceSelf, the_slice: Any) -> SequenceSelf: - pass - - @abc.abstractmethod - def pad_time( - self: SequenceSelf, - pad_left: int, - pad_right: int, - valid: bool, - pad_value: Any | None = None, - ) -> SequenceSelf: - pass - - @abc.abstractmethod - def 
concatenate(self, other: 'Sequence') -> 'Sequence': - pass - - @abc.abstractmethod - def mask_invalid(self, mask_value: Any | None = None) -> 'Sequence': - pass - - @abc.abstractmethod - def unmask(self) -> 'Sequence': - pass - - -class SequenceLayerConfig(metaclass=abc.ABCMeta): - """Configuration for a SequenceLayer.""" - - @abc.abstractmethod - def make(self) -> Any: - """Creates the sequence layer.""" - - @abc.abstractmethod - def copy(self, **kwargs) -> 'SequenceLayerConfig': - """Returns a copy of the config with updated fields.""" - - -class Steppable(metaclass=abc.ABCMeta): - """A sequence processing layer that can be executed layerwise or stepwise.""" - - @property - @abc.abstractmethod - def block_size(self) -> int: - pass - - @property - @abc.abstractmethod - def output_ratio(self) -> fractions.Fraction: - pass - - @property - @abc.abstractmethod - def supports_step(self) -> bool: - pass - - @property - @abc.abstractmethod - def input_latency(self) -> int: - pass - - @property - @abc.abstractmethod - def output_latency(self) -> int: - pass - - @abc.abstractmethod - def get_accumulated_input_latency(self, input_latency: int) -> int: - pass - - @abc.abstractmethod - def get_accumulated_output_latency(self, output_latency: int) -> int: - pass - - @property - @abc.abstractmethod - def receptive_field(self) -> Any: - pass - diff --git a/sequence_layers/abstract/types_test_base.py b/sequence_layers/abstract/types_test_base.py deleted file mode 100644 index f9cd54b..0000000 --- a/sequence_layers/abstract/types_test_base.py +++ /dev/null @@ -1,338 +0,0 @@ -"""Abstract tests for Sequence types.""" - -import abc -from typing import Any, Callable, Sequence as TypingSequence -import dataclasses -from sequence_layers.abstract import types - -import fractions -from absl.testing import parameterized -import numpy as np - -class SequenceLayerTest(parameterized.TestCase): - """Base abstract test class providing common sequence testing assertions.""" - - 
@abc.abstractmethod - def assertSequencesClose(self, x: Any, y: Any, **kwargs): - pass - - @abc.abstractmethod - def assertSequencesNotClose(self, x: Any, y: Any, **kwargs): - pass - - @abc.abstractmethod - def assertSequencesEqual(self, x: Any, y: Any): - pass - - @abc.abstractmethod - def assertSequencesNotEqual(self, x: Any, y: Any): - pass - - @abc.abstractmethod - def assertAllEqual(self, x: Any, y: Any): - pass - - @abc.abstractmethod - def assertAllClose(self, x: Any, y: Any, **kwargs): - pass - - @abc.abstractmethod - def assertNotAllEqual(self, x: Any, y: Any): - pass - - @abc.abstractmethod - def assertNotAllClose(self, x: Any, y: Any, **kwargs): - pass - - -class SequenceTest(SequenceLayerTest): - """Abstract tests for the Sequence class.""" - - @abc.abstractmethod - def get_backend(self) -> Any: - """Returns the backend module (jax.numpy or mlx.core).""" - - @property - @abc.abstractmethod - def Sequence(self) -> type[types.Sequence]: - """Returns the Sequence class for the backend.""" - - @property - @abc.abstractmethod - def MaskedSequence(self) -> Any: - """Returns the MaskedSequence class for the backend.""" - - @property - def check_trees_all_equal(self) -> Callable[[Any, Any], None]: - """Returns a function to check tree equality.""" - return self.assertAllEqual - - def test_mask_invalid_idempotent(self): - xp = self.get_backend() - values = xp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - # Different backends might handle boolean creation differently, but standard numpy-like syntax usually works - mask = xp.array([[True, True, False, False], [False, False, False, True]]) - - x = self.Sequence(values, mask) - masked = x.mask_invalid() - self.assertIsNot(masked, x) - # We can't easily check isinstance here without importing the concrete classes, - # but we can check behavior or use a property if we added one. - # For now, we trust the concrete tests to check types if needed, - # or we could add abstract methods to check types. 
- - masked_again = masked.mask_invalid() - self.assertIs(masked_again, masked) - - masked2 = x.mask_invalid() - self.assertIsNot(masked2, masked) - - @parameterized.named_parameters( - ('mask_value=None', 0.0, None), - ('mask_value=0.0', 0.0, 0.0), - ('mask_value=-1.0', -1.0, -1.0), - ) - def test_mask_invalid(self, mask_value, expected_mask_value): - xp = self.get_backend() - values = xp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, False, False, True]]) - - # Pass mask_value only if it is not None (to test default None behavior vs explicit value) - if expected_mask_value is None: - output = self.Sequence(values, mask).mask_invalid() - fill_value = 0.0 - else: - output = self.Sequence(values, mask).mask_invalid(mask_value) - fill_value = mask_value - - expected_values = xp.array([ - [1.0, 2.0, fill_value, fill_value], - [fill_value, fill_value, fill_value, 40.0], - ]) - self.check_trees_all_equal(output.values, expected_values) - self.check_trees_all_equal(output.mask, mask) - - def test_pad_time(self): - xp = self.get_backend() - values = xp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, False, False, True]]) - - x = self.Sequence(values, mask).mask_invalid() - - y = x.pad_time(0, 0, valid=False) - self.check_trees_all_equal(y.values, x.values) - self.check_trees_all_equal(y.mask, x.mask) - - y = x.pad_time(1, 0, valid=False) - - x_left1 = self.Sequence( - xp.array([ - [0.0, 1.0, 2.0, 3.0, 4.0], - [0.0, 10.0, 20.0, 30.0, 40.0], - ]), - xp.array([ - [False, True, True, False, False], - [False, False, False, False, True], - ]), - ).mask_invalid() - self.check_trees_all_equal(y.values, x_left1.values) - self.check_trees_all_equal(y.mask, x_left1.mask) - - def _create_test_sequence(self, shape): - xp = self.get_backend() - size = 1 - for d in shape: size *= d - values_np = np.arange(size, dtype=np.float32).reshape(shape) - 
mask_np = np.ones(shape[:2], dtype=bool) - if shape[0] > 0 and shape[1] > 1: - mask_np[0, 1] = False - - values = xp.array(values_np) - mask = xp.array(mask_np) - return self.Sequence(values, mask) - - def test_slice(self): - x = self._create_test_sequence((3, 5, 9)) - - self.assertSequencesEqual( - x[:, 1:], self.Sequence(x.values[:, 1:], x.mask[:, 1:]) - ) - self.assertSequencesEqual( - x[:, ::2], self.Sequence(x.values[:, ::2], x.mask[:, ::2]) - ) - self.assertSequencesEqual( - x[::2, ::3], self.Sequence(x.values[::2, ::3], x.mask[::2, ::3]) - ) - - def test_slice_can_slice_channel_dimensions(self): - x = self._create_test_sequence((3, 5, 9, 4)) - - self.assertSequencesEqual( - x[:, 1:, :], self.Sequence(x.values[:, 1:], x.mask[:, 1:]) - ) - self.assertSequencesEqual( - x[:, ::2, :3], - self.Sequence(x.values[:, ::2, :3], x.mask[:, ::2]), - ) - - def test_apply_values(self): - xp = self.get_backend() - values = xp.array([ - [-1.0, 2.0, 3.0, 4.0], - [10.0, -20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, True, False, True]]) - - x = self.Sequence(values, mask) - masked = x.mask_invalid() - - # Simple abs function - fn = abs - - y = x.apply_values(fn) - self.check_trees_all_equal(y.values, fn(x.values)) - self.check_trees_all_equal(y.mask, x.mask) - - y = masked.apply_values(fn) - self.check_trees_all_equal(y.values, fn(masked.values)) - self.check_trees_all_equal(y.mask, x.mask) - - y = masked.apply_values_masked(fn) - self.check_trees_all_equal(y.values, fn(masked.values)) - self.check_trees_all_equal(y.mask, x.mask) - - def test_apply_values_args(self): - xp = self.get_backend() - values = xp.array([ - [-1.0, 2.0, 3.0, 4.0], - [10.0, -20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, True, False, True]]) - x = self.Sequence(values, mask) - - target_shape = (2, 4, 1) - y = x.apply_values(lambda v, s: v.reshape(s), target_shape) - self.check_trees_all_equal(y.values.shape, target_shape) - 
self.check_trees_all_equal(y.mask.shape, (2, 4)) - - def test_from_values(self): - xp = self.get_backend() - values_np = np.array([ - [1.0, 2.0], - [3.0, 4.0] - ], dtype=np.float32) - values = xp.array(values_np) - # Get the class from an instance - seq = self.Sequence(values, xp.array(np.ones(values.shape[:2], dtype=bool))) - SeqClass = type(seq) - - x = SeqClass.from_values(values) - self.check_trees_all_equal(x.values, values) - self.check_trees_all_equal(x.mask, xp.array(np.ones(values.shape[:2], dtype=bool))) - - def test_astype(self): - xp = self.get_backend() - values_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) - mask_np = np.array([[True, False], [False, True]], dtype=bool) - - values = xp.array(values_np) - mask = xp.array(mask_np) - - x = self.Sequence(values, mask) - - # We need a dtype that matches the backend - if xp.__name__ == 'jax.numpy': - dtype = xp.int32 - elif xp.__name__ == 'mlx.core': - dtype = xp.int32 - else: - dtype = np.int32 - - y = x.astype(dtype) - - # Check values match casted version - self.check_trees_all_equal(y.mask, mask) - # y.values might be mlx array, values.astype(dtype) might be numpy if values was numpy? - # values is backend array. values.astype(dtype) should work if dtype is backend dtype. 
- self.check_trees_all_equal(y.values, values.astype(dtype)) - - - -class SteppableTest(parameterized.TestCase): - """Abstract tests for Steppable layers.""" - - @abc.abstractmethod - def create_steppable(self) -> Any: - """Creates a basic Steppable instance that should have default properties.""" - - def test_steppable_defaults(self): - layer = self.create_steppable() - self.assertEqual(layer.block_size, 1) - self.assertEqual(layer.output_ratio, fractions.Fraction(1)) - self.assertTrue(layer.supports_step) - self.assertEqual(layer.input_latency, 0) - self.assertEqual(layer.output_latency, 0) - self.assertEqual(layer.get_accumulated_input_latency(0), 0) - self.assertEqual(layer.get_accumulated_output_latency(0), 0) - - -class SequenceLayerConfigTest(SequenceLayerTest): - - @abc.abstractmethod - def get_config_base_cls(self) -> type[types.SequenceLayerConfig]: - """Returns the backend-specific SequenceLayerConfig class.""" - - def test_copy(self): - ConfigBase = self.get_config_base_cls() - - @dataclasses.dataclass(frozen=True) - class Config(ConfigBase): - a: int = 1234 - b: str = 'default string' - - def make(self) -> Any: - return 'dummy_layer' - - config = Config() - new_config = config.copy(b='new string') - self.assertEqual(new_config.a, config.a) - self.assertEqual(new_config.b, 'new string') - - def test_copy_raises_on_non_dataclass(self): - ConfigBase = self.get_config_base_cls() - - class NonDataclassConfig(ConfigBase): - - def make(self) -> Any: - return 'dummy_layer' - - config = NonDataclassConfig() - with self.assertRaises(TypeError): - new_config = config.copy() - del new_config - - def test_copy_disallows_new_fields(self): - ConfigBase = self.get_config_base_cls() - - @dataclasses.dataclass(frozen=True) - class Config(ConfigBase): - - def make(self) -> Any: - return 'dummy_layer' - - config = Config() - # dataclasses.replace raises TypeError for unknown arguments - # JAX implementation wraps it in AttributeError - with self.assertRaises((TypeError, 
AttributeError)): - new_config = config.copy(field_does_not_exist=1234) - del new_config - diff --git a/sequence_layers/jax/__init__.py b/sequence_layers/jax/__init__.py index 85bb162..dcbb3a0 100644 --- a/sequence_layers/jax/__init__.py +++ b/sequence_layers/jax/__init__.py @@ -27,3 +27,7 @@ from sequence_layers.jax.simple import * from sequence_layers.jax.time_varying import * from sequence_layers.jax.types import * + +# (re-export the names for typechecking) +from . import backend as backend +from . import types as types diff --git a/sequence_layers/jax/backend.py b/sequence_layers/jax/backend.py new file mode 100644 index 0000000..f29d4e1 --- /dev/null +++ b/sequence_layers/jax/backend.py @@ -0,0 +1,26 @@ +"""Backend-specific helpers (JAX)""" + +from typing import override + +import jax.numpy as jnp + +from sequence_layers.specs import backend as spec +from sequence_layers.specs import types as types_spec + + +class BackendWrapper(spec.xp): + """Thin wrapper around JAX to match NumPy interface for tests.""" + + bool_ = jnp.bool_ + int32 = jnp.int32 + + @override + def array(self, a, dtype=None) -> types_spec.Array: + return jnp.array(a, dtype=dtype) + + @override + def zeros(self, shape, dtype=None) -> types_spec.Array: + return jnp.zeros(shape, dtype=dtype) + + +xp: spec.xp = BackendWrapper() diff --git a/sequence_layers/jax/test_utils.py b/sequence_layers/jax/test_utils.py index ce6741e..13add8f 100644 --- a/sequence_layers/jax/test_utils.py +++ b/sequence_layers/jax/test_utils.py @@ -18,7 +18,9 @@ import itertools import logging import random -from typing import Any, Callable, Iterable, Mapping, Sequence as TypingSequence, TypeVar +from typing import Any, Callable, Iterable, Mapping +from typing import Sequence as TypingSequence +from typing import TypeVar from absl.testing import absltest from absl.testing import parameterized @@ -27,10 +29,12 @@ import jax import jax.numpy as jnp import numpy as np + from sequence_layers.jax import types from 
sequence_layers.jax import typing as jt from sequence_layers.jax import utils - +import sequence_layers.jax as sl +from sequence_layers.specs import test_utils as spec _SequenceLayerT = TypeVar('_SequenceLayerT', bound=types.SequenceLayer) _T = TypeVar('_T') @@ -777,9 +781,11 @@ def _mask_and_pad_to_max_length( return a, b -class SequenceLayerTest(parameterized.TestCase): +class SequenceLayerTest(spec.SequenceLayerTest): """Base class for SequenceLayer tests.""" + sl = sl + def setUp(self): super().setUp() # To avoid flakes, fix random seeds. diff --git a/sequence_layers/jax/types.py b/sequence_layers/jax/types.py index 57e9d3f..1c82b6c 100644 --- a/sequence_layers/jax/types.py +++ b/sequence_layers/jax/types.py @@ -20,7 +20,21 @@ import functools import math import typing -from typing import Any, Callable, Generic, Iterable, Literal, MutableMapping, ParamSpec, Protocol, Self, Sequence as TypingSequence, TypeVar, override +from typing import ( + Any, + Callable, + Concatenate, + Generic, + Iterable, + Literal, + MutableMapping, + override, + ParamSpec, + Protocol, + Self, +) +from typing import Sequence as TypingSequence +from typing import TypeVar from absl import logging from flax import linen as nn @@ -29,12 +43,11 @@ from jax import numpy as jnp import jaxtyping import numpy as np +import typeguard -from sequence_layers.abstract import types from sequence_layers.jax import sharding as sharding_lib from sequence_layers.jax import typing as jt -import typeguard - +from sequence_layers.specs import types as spec __all__ = ( # go/keep-sorted start @@ -87,17 +100,17 @@ # Sequence type aliases: MASK_DTYPE = np.bool_ -# A rank 2+ tensor of any type. +# A rank 2+ array of any type. ValuesT = TypeVar('ValuesT', bound=jt.Shaped[jt.ArrayT, 'B T *C']) -# A boolean batched mask tensor. True indicates a given timepoint is valid, and +# A boolean batched mask array. True indicates a given timepoint is valid, and # False indicates it is invalid. 
MaskT = TypeVar('MaskT', bound=jt.Bool[jt.ArrayT, 'B T']) -# An integer batched lengths tensor. +# An integer batched lengths array. LengthsT = TypeVar('LengthsT', bound=jt.Int[jt.ArrayT, 'B']) -# A rank 2 boolean tensor with unit dimensions inserted to match their +# A rank 2 boolean array with unit dimensions inserted to match their # corresponding values (e.g. for broadcasting). ExpandedMaskT = TypeVar('ExpandedMaskT', bound=jt.Bool[jt.ArrayT, 'B T *C']) @@ -189,8 +202,8 @@ def get_einsum( jax.core.ShapedArray, ) -PaddingMode = types.PaddingMode -PaddingModeString = types.PaddingModeString +PaddingMode = spec.PaddingMode +PaddingModeString = spec.PaddingModeString def validate_padding(padding: str) -> PaddingModeString: @@ -254,7 +267,9 @@ def sequence_mask(lengths: LengthsT, maxlen: int) -> MaskT: ) -class Sequence(types.Sequence[ValuesT, MaskT], struct.PyTreeNode): +class Sequence( + Generic[ValuesT, MaskT], spec.Sequence[ValuesT, MaskT], struct.PyTreeNode +): """A generic sequence container that preserves masking information.""" values: ValuesT @@ -344,6 +359,7 @@ def dtype(self) -> DType: return self.values.dtype @classmethod + @override def from_lengths( cls, values: ValuesT, lengths: LengthsT, is_masked: bool = False ) -> 'Sequence': @@ -405,7 +421,7 @@ def apply_values( return Sequence(values_fn(self.values, *args, **kwargs), self.mask) @override - def apply_values_masked( + def apply_values_masked( # pyrefly: ignore[bad-override] self: SequenceSelf, values_fn: Callable[..., ValuesT], *args: ApplyValuesMaskedParams.args, @@ -426,7 +442,7 @@ def apply( return Sequence(values, mask) @override - def apply_masked( + def apply_masked( # pyrefly: ignore[bad-override] self: SequenceSelf, apply_fn: Callable[..., tuple[ValuesT, MaskT]], *args: ApplyMaskedParams.args, @@ -471,7 +487,7 @@ def __getitem__( ) @override - def pad_time( + def pad_time( # pyrefly: ignore[bad-override] self: SequenceSelf, pad_left: jt.ScalarInt, pad_right: jt.ScalarInt, @@ -537,7 +553,9 
@@ def unmask(self) -> 'Sequence': return self -class MaskedSequence(Sequence[ValuesT, MaskT]): +class MaskedSequence( # pyrefly: ignore[inconsistent-inheritance] + Sequence[ValuesT, MaskT], spec.MaskedSequence[ValuesT, MaskT] +): """Sequence whose invalid timesteps are masked to zero.""" @override @@ -555,21 +573,19 @@ def unmask(self) -> Sequence: def mask_invalid( - sequence: Sequence, + self: Sequence, mask_value: complex | None = None, ) -> 'Sequence': """Returns a sequence whose invalid timesteps are replaced with mask_value.""" - expanded_mask = sequence.expanded_mask() + expanded_mask = self.expanded_mask() if mask_value is None: - masked_values = jnp.zeros_like(sequence.values) + masked_values = jnp.zeros_like(self.values) result_type = MaskedSequence else: - masked_values = jnp.full( - sequence.values.shape, mask_value, sequence.values.dtype - ) + masked_values = jnp.full(self.values.shape, mask_value, self.values.dtype) result_type = Sequence - masked_values = jnp.where(expanded_mask, sequence.values, masked_values) - return result_type(masked_values, sequence.mask) + masked_values = jnp.where(expanded_mask, self.values, masked_values) + return result_type(masked_values, self.mask) # Defined outside of Sequence so that mask_invalid can return a MaskedSequence. 
@@ -598,6 +614,8 @@ def __getitem__(cls, item): class SequenceT(Sequence, metaclass=MetaSequenceT): + """Allows typing to be: SequenceT[Float, "B T C"]""" + pass @@ -611,22 +629,22 @@ def _sequence_checker_fn(value, origin_type, args, memo): values_dtype, mask_dtype = args try: - typeguard.check_type_internal( + typeguard.check_type_internal( # pyrefly: ignore[missing-attribute] value.values, values_dtype, memo=memo, ) - except typeguard.TypeCheckError as exc: + except typeguard.TypeCheckError as exc: # pyrefly: ignore[missing-attribute] exc.append_path_element('values') raise try: - typeguard.check_type_internal( + typeguard.check_type_internal( # pyrefly: ignore[missing-attribute] value.mask, mask_dtype, memo=memo, ) - except typeguard.TypeCheckError as exc: + except typeguard.TypeCheckError as exc: # pyrefly: ignore[missing-attribute] exc.append_path_element('mask') raise @@ -669,7 +687,7 @@ def _add_custom_checker_lookup_fn(lookup_fn): _add_custom_checker_lookup_fn(_sequence_checker_lookup_fn) -class Steppable(types.Steppable): +class Steppable(spec.Steppable[Sequence, Sequence, ChannelSpec]): """A sequence processing layer that can be executed layerwise or stepwise. # Step-wise execution: @@ -735,7 +753,7 @@ class Steppable(types.Steppable): ``` """ - path: str # Provided by nn.Module. + path: tuple[str, ...] # Provided by nn.Module. @property @override @@ -880,6 +898,7 @@ def receptive_field_per_step(self) -> dict[int, ReceptiveField]: ) @abc.abstractmethod + @override def layer( self, x: Sequence, *, training: bool, constants: Constants | None = None ) -> Sequence: @@ -899,6 +918,7 @@ def layer( truncated to only represent valid frames. """ + @override def layer_with_emits( self, x: Sequence, @@ -906,11 +926,11 @@ def layer_with_emits( training: bool, constants: Constants | None = None, ) -> tuple[Sequence, Emits]: - """Process this layer layer-wise, producing emitted tensors. + """Process this layer layer-wise, producing emitted arrays. 
This is like `layer`, except it has an additional return value which is the - "emitted" tensors for the layer. The emitted tensors are a structure of - tensors whose whose values are `ArrayLike`s or `Sequence`s. + "emitted" arrays for the layer. The emitted arrays are a structure of + arrays whose values are `ArrayLike`s or `Sequence`s. Args: x: Input sequence with values shaped [b, t_i, ...]. @@ -924,11 +944,12 @@ def layer_with_emits( y: The outputs corresponding to this layer with values shaped [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been truncated to only represent valid frames. - emits: A nest of emitted tensors or Sequences. + emits: A nest of emitted arrays or Sequences. """ outputs = self.layer(x, training=training, constants=constants) return outputs, () + @override def __call__( self, x: Sequence, training: bool, constants: Constants | None = None ) -> Sequence: @@ -936,6 +957,7 @@ def __call__( return self.layer(x, training=training, constants=constants) @abc.abstractmethod + @override def step( self, x: Sequence, @@ -944,12 +966,12 @@ def step( training: bool, constants: Constants | None = None, ) -> tuple[Sequence, State]: - """Process this layer step-wise, producing emitted tensors. + """Process this layer step-wise, producing emitted arrays. Args: x: Input sequence with values shaped [b, t_i, ...], where t_i is a multiple of block_size. - state: A structure of state tensors matching get_initial_state. The + state: A structure of state arrays matching get_initial_state. The previous state for this layer. training: Python bool. Whether we are in training mode. constants: A dictionary of constant name to ArrayLike or sl.Sequence. @@ -960,10 +982,11 @@ def step( Returns: y: The outputs corresponding to this step with values shaped [b, t_o, ...] where `t_o == t_i * output_ratio`. - state: A structure of state tensors matching get_initial_state. 
The new state for this layer. """ + @override def step_with_emits( self, x: Sequence, @@ -972,16 +995,16 @@ def step_with_emits( training: bool, constants: Constants | None = None, ) -> tuple[Sequence, State, Emits]: - """Process this layer step-wise, producing emitted tensors. + """Process this layer step-wise, producing emitted arrays. This is like `step`, except it has an additional return value which is the - "emitted" tensors for the step. The emitted tensors are a structure of - tensors whose values are `ArrayLike`s or `Sequence`s. + "emitted" arrays for the step. The emitted arrays are a structure of + arrays whose values are `ArrayLike`s or `Sequence`s. Args: x: Input sequence with values shaped [b, t_i, ...], where t_i is a multiple of block_size. - state: A structure of state tensors matching get_initial_state. The + state: A structure of state arrays matching get_initial_state. The previous state for this layer. training: Python bool. Whether we are in training mode. constants: A dictionary of constant name to ArrayLike or sl.Sequence. @@ -992,14 +1015,15 @@ def step_with_emits( Returns: y: The outputs corresponding to this step with values shaped [b, t_o, ...] where `t_o == t_i * output_ratio`. - state: A structure of state tensors matching get_initial_state. The + state: A structure of state arrays matching get_initial_state. The new state for this layer. - emits: A nest of emitted tensors or Sequences. + emits: A nest of emitted arrays or Sequences. """ outputs, state = self.step(x, state, training=training, constants=constants) return outputs, state, () @abc.abstractmethod + @override def get_initial_state( self, batch_size: int, @@ -1021,14 +1045,15 @@ def get_initial_state( attention layer this may contain the source sequence to attend to. Returns: - An integer, TensorShape or structure of integer/TensorShapes. + An integer, shape or structure of integer/shapes. 
""" @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: - """Returns the output shape this layer produces for an input shape. + """Returns the output channel shape this layer produces for an input channel shape. Args: input_shape: A shape representing the channels dimension of the input @@ -1089,12 +1114,14 @@ def get_output_spec_for_sequence( return self.get_output_spec(x.channel_spec, constants=constants) @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: """Returns the layer's output dtype for the specified input dtype.""" @nn.nowrap + @override def get_output_spec( self, input_spec: ChannelSpec, @@ -1261,14 +1288,19 @@ def check_step_with_emits_fn( return check_step_with_emits_fn -class SequenceLayer(nn.Module, Steppable): +class SequenceLayer( + nn.Module, Steppable, spec.SequenceLayer[Sequence, Sequence, ChannelSpec] +): """Base Module for Sequence Layers.""" -class PreservesType: +class PreservesType( + SequenceLayer, spec.PreservesType[Sequence, Sequence, ChannelSpec] +): """A mix-in for layers that do not change the input dtype.""" @nn.nowrap + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: @@ -1276,10 +1308,13 @@ def get_output_dtype( return input_dtype -class PreservesShape: +class PreservesShape( + SequenceLayer, spec.PreservesShape[Sequence, Sequence, ChannelSpec] +): """A mix-in for layers that do not change the input shape.""" @nn.nowrap + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: @@ -1287,8 +1322,8 @@ def get_output_shape( return tuple(input_shape) -class Emitting(SequenceLayer, metaclass=abc.ABCMeta): - """A SequenceLayer that emits auxiliary tensors. 
+class Emitting(SequenceLayer, spec.Emitting[Sequence, Sequence, ChannelSpec]): + """A SequenceLayer that emits auxiliary arrays. This is a convenience subclass that implements step and layer in terms of step_with_emits and layer_with_emits, so that implementors need only implement @@ -1297,6 +1332,7 @@ class Emitting(SequenceLayer, metaclass=abc.ABCMeta): do not produce emits. """ + @override def step( self, x: Sequence, @@ -1311,6 +1347,7 @@ def step( return output, state @abc.abstractmethod + @override def step_with_emits( self, x: Sequence, @@ -1321,6 +1358,7 @@ def step_with_emits( ) -> tuple[Sequence, State, Emits]: pass + @override def layer( self, x: Sequence, @@ -1334,6 +1372,7 @@ def layer( return outputs @abc.abstractmethod + @override def layer_with_emits( self, x: Sequence, @@ -1344,7 +1383,7 @@ def layer_with_emits( pass -class Stateless(SequenceLayer): +class Stateless(SequenceLayer, spec.Stateless[Sequence, Sequence, ChannelSpec]): """A SequenceLayer with no state over time required for step-wise processing. Sub-classes must only implement: @@ -1354,9 +1393,11 @@ class Stateless(SequenceLayer): """ @property + @override def receptive_field_per_step(self) -> dict[int, ReceptiveField]: return {0: (0, 0)} + @override def get_initial_state( self, batch_size: int, @@ -1367,9 +1408,11 @@ def get_initial_state( ) -> State: del batch_size del input_spec + del training del constants return () + @override def step( self, x: Sequence, @@ -1380,9 +1423,36 @@ def step( ) -> tuple[Sequence, State]: return self.layer(x, training=training, constants=constants), state + @abc.abstractmethod + @override + def get_output_shape( + self, input_shape: ShapeLike, *, constants: Constants | None = None + ) -> Shape: + pass -class StatelessEmitting(Emitting): - """A SequenceLayer with no state over time that emits auxiliary tensors. 
+ @abc.abstractmethod + @override + def get_output_dtype( + self, input_dtype: DType, *, constants: Constants | None = None + ) -> DType: + pass + + @abc.abstractmethod + @override + def layer( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> 'Sequence': + pass + + +class StatelessEmitting( + Emitting, spec.StatelessEmitting[Sequence, Sequence, ChannelSpec] +): + """A SequenceLayer with no state over time that emits auxiliary arrays. Sub-classes must only implement: - layer_with_emits @@ -1391,9 +1461,11 @@ class StatelessEmitting(Emitting): """ @property + @override def receptive_field_per_step(self) -> dict[int, ReceptiveField]: return {0: (0, 0)} + @override def step_with_emits( self, x: Sequence, @@ -1407,6 +1479,7 @@ def step_with_emits( ) return outputs, state, emits + @override def get_initial_state( self, batch_size: int, @@ -1415,21 +1488,59 @@ def get_initial_state( training: bool, constants: Constants | None = None, ) -> State: + del batch_size + del input_spec + del training + del constants return () + @abc.abstractmethod + @override + def get_output_shape( + self, input_shape: ShapeLike, *, constants: Constants | None = None + ) -> Shape: + pass -class StatelessPointwise(PreservesShape, Stateless): + @abc.abstractmethod + @override + def get_output_dtype( + self, input_dtype: DType, *, constants: Constants | None = None + ) -> DType: + pass + + @abc.abstractmethod + @override + def layer_with_emits( + self, + x: Sequence[ValuesT, MaskT], + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence[ValuesT, MaskT], Emits]: + pass + + +class StatelessPointwise( + PreservesShape, + Stateless, + spec.StatelessPointwise[Sequence, Sequence, ChannelSpec], +): """A SequenceLayer that has no state and operates pointwise on its input.""" -class StatelessPointwiseFunctor(StatelessPointwise, metaclass=abc.ABCMeta): +class StatelessPointwiseFunctor( + StatelessPointwise, + 
spec.StatelessPointwiseFunctor[Sequence, Sequence, ChannelSpec], +): """A stateless SequenceLayer for simple pointwise processing fns.""" @abc.abstractmethod + @override def fn(self, values: ValuesT, mask: MaskT) -> tuple[ValuesT, MaskT]: """Transforms each scalar in values independently.""" @property + @override def mask_required(self): """Returns true if fn can change the sequence's masked state. @@ -1438,7 +1549,8 @@ def mask_required(self): return True @check_layer - def layer( + @override + def layer( # pyrefly: ignore[missing-override-decorator] self, x: Sequence, *, @@ -1455,7 +1567,7 @@ def layer( return y -class SequenceLayerConfig(types.SequenceLayerConfig): +class SequenceLayerConfig(spec.SequenceLayerConfig): """Base class for SequenceLayer configuration objects. Requires a no-argument make() method which returns a SequenceLayer. @@ -1467,9 +1579,11 @@ class SequenceLayerConfig(types.SequenceLayerConfig): """ @abc.abstractmethod + @override def make(self) -> SequenceLayer: """Builds a SequenceLayer from this config.""" + @override def copy(self, **kwargs) -> Self: """Create a copy of this config. 
@@ -1501,6 +1615,7 @@ def copy(self, **kwargs) -> Self: ) try: + assert not isinstance(self, type), 'replace must be called on an instance' return typing.cast(Self, dataclasses.replace(self, **kwargs)) except TypeError as type_error: raise AttributeError( diff --git a/sequence_layers/jax/types_test.py b/sequence_layers/jax/types_test.py index c38420f..7b7b5cf 100644 --- a/sequence_layers/jax/types_test.py +++ b/sequence_layers/jax/types_test.py @@ -15,9 +15,7 @@ import dataclasses import typing -from typing import Any -from absl.testing import parameterized import chex import flax.linen as nn import jax @@ -25,37 +23,33 @@ import jaxtyping import numpy as np -from sequence_layers.abstract import types_test_base from sequence_layers.jax import simple from sequence_layers.jax import test_utils from sequence_layers.jax import types from sequence_layers.jax import typing as jt +from sequence_layers.specs import types_behaviors as spec -class Foo(nn.Module): +class ModuleInterfaceTest( + test_utils.SequenceLayerTest, spec.ModuleInterfaceTest +): + pass - @nn.compact - def __call__(self, x: types.Sequence) -> types.Sequence: - return x - -class SequenceTest(types_test_base.SequenceTest, test_utils.SequenceLayerTest): +class SequenceTest(test_utils.SequenceLayerTest, spec.SequenceTest): """Tests for the Sequence class.""" - def get_backend(self): - return jnp - - @property - def Sequence(self): - return types.Sequence - - @property - def MaskedSequence(self): - return types.MaskedSequence - def test_type_checks(self): """Test type checks in Sequence.__post_init__.""" + class Foo(nn.Module): + + @nn.compact + def __call__( + self, x: types.Sequence[types.ValuesT, types.MaskT] + ) -> types.Sequence: + return x + # Allowed: Both array-like. 
types.Sequence(jnp.zeros((2, 3, 5)), jnp.zeros((2, 3), dtype=jnp.bool_)) types.Sequence(np.zeros((2, 3, 5)), np.zeros((2, 3), dtype=jnp.bool_)) @@ -109,26 +103,6 @@ def test_type_checks(self): with self.assertRaises(jaxtyping.TypeCheckError): types.Sequence(np.zeros((2, 3, 5)), np.zeros((1, 3), dtype=jnp.bool_)) - def test_mask_invalid_idempotent(self): - values = jnp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - mask = jnp.array([[True, True, False, False], [False, False, False, True]]) - - x = types.Sequence(values, mask) - masked = x.mask_invalid() - self.assertIsNot(masked, x) - self.assertIsInstance(masked, types.MaskedSequence) - - masked_again = masked.mask_invalid() - self.assertIs(masked_again, masked) - self.assertIsInstance(masked_again, types.MaskedSequence) - - masked2 = x.mask_invalid() - self.assertIsNot(masked2, masked) - self.assertIsInstance(masked2, types.MaskedSequence) - def test_type_annotation(self): if not jt.runtime_type_checking_enabled: self.skipTest('Type checking is disabled.') @@ -137,7 +111,7 @@ def test_type_annotation(self): def f( x: types.SequenceT[jt.Float, 'B T C'], ) -> types.SequenceT[jt.Float, 'B T C 1']: - return types.Sequence(x.values[..., jnp.newaxis], x.mask) + return types.Sequence(x.values[..., jnp.newaxis], x.mask) # type: ignore[return-value] values = jnp.zeros((2, 3, 5)) mask = jnp.zeros((2, 3), dtype=jnp.bool_) @@ -171,7 +145,7 @@ def test_type_annotation_masked_sequence(self): def f( x: types.SequenceT[jt.Float, 'B T C'], ) -> types.SequenceT[jt.Float, 'B T C 1']: - return types.MaskedSequence(x.values[..., jnp.newaxis], x.mask) + return types.MaskedSequence(x.values[..., jnp.newaxis], x.mask) # type: ignore[return-value] values = jnp.zeros((2, 3, 5)) mask = jnp.zeros((2, 3), dtype=jnp.bool_) @@ -222,28 +196,11 @@ def fn(x: types.Sequence) -> types.Sequence: y = fn(x) self.assertSequencesEqual(y, x) - def test_from_lengths(self): - x_expected = test_utils.random_sequence(5, 17, 2) - x = 
types.Sequence.from_lengths(x_expected.values, x_expected.lengths()) - self.assertSequencesEqual(x, x_expected) - - # Out of range lengths are clipped to 0 or max. - x = types.Sequence.from_lengths(x_expected.values, [-1, 0, 5, 17, 18]) - self.assertAllEqual(x.lengths(), jnp.asarray([0, 0, 5, 17, 17])) - self.assertNotIsInstance(x, types.MaskedSequence) - - # Return type is MaskedSequence if is_masked=True. - x = types.Sequence.from_lengths( - x_expected.values, [-1, 0, 5, 17, 18], is_masked=True - ) - self.assertAllEqual(x.lengths(), jnp.asarray([0, 0, 5, 17, 17])) - self.assertIsInstance(x, types.MaskedSequence) - - -class SequenceLayerConfigTest(types_test_base.SequenceLayerConfigTest): - def get_config_base_cls(self): - return types.SequenceLayerConfig +class SequenceLayerConfigTest( + test_utils.SequenceLayerTest, spec.SequenceLayerConfigTest +): + pass def test_copy_raises_on_mutable_attribute(self): @@ -284,28 +241,36 @@ def make(self) -> simple.Identity: del new_config -class SteppableTest(types_test_base.SteppableTest): +class SteppableTest(test_utils.SequenceLayerTest, spec.SteppableTest): + pass - def create_steppable(self): - class DefaultSteppable(types.Steppable): +class PreservesTypeTest(test_utils.SequenceLayerTest, spec.PreservesTypeTest): + pass + + +class PreservesShapeTest(test_utils.SequenceLayerTest, spec.PreservesShapeTest): + pass + + +class StatelessTest(test_utils.SequenceLayerTest, spec.StatelessTest): + pass - def layer(self, x, *, constants=None): - return x - def step(self, x, state, *, constants=None): - return x, state +class EmittingTest(test_utils.SequenceLayerTest, spec.EmittingTest): + pass - def get_initial_state(self, batch_size, input_spec, *, constants=None): - return 0 - def get_output_shape(self, input_shape, *, constants=None): - return input_shape +class StatelessEmittingTest( + test_utils.SequenceLayerTest, spec.StatelessEmittingTest +): + pass - def get_output_dtype(self, input_dtype, *, constants=None): - return 
input_dtype - return DefaultSteppable() +class StatelessPointwiseFunctorTest( + test_utils.SequenceLayerTest, spec.StatelessPointwiseFunctorTest +): + pass if __name__ == '__main__': diff --git a/sequence_layers/jax/typing.py b/sequence_layers/jax/typing.py index 9b75781..a83a2ba 100644 --- a/sequence_layers/jax/typing.py +++ b/sequence_layers/jax/typing.py @@ -13,27 +13,37 @@ # limitations under the License. """Wrappers for making jaxtyping easier to use and understand.""" +from typing import Callable, TYPE_CHECKING, TypeVar, Union + import jax import jax.numpy as jnp -from jaxtyping import AbstractDtype, Bool, config as jaxtyping_config, Float, Int, PyTree, Shaped, jaxtyped, TypeCheckError +from jaxtyping import AbstractDtype +from jaxtyping import Bool +from jaxtyping import config as jaxtyping_config +from jaxtyping import Float +from jaxtyping import Int +from jaxtyping import jaxtyped +from jaxtyping import PyTree +from jaxtyping import Shaped +from jaxtyping import TypeCheckError import numpy as np import typeguard -from typing import Callable, TypeVar, Union - - -class _MetaArrayT(type): - types = () - def __instancecheck__(cls, obj): - return isinstance(obj, cls.types) +if TYPE_CHECKING: + ArrayT = jax.Array | np.ndarray +else: + class _MetaArrayT(type): + types = () -class JaxArrayT(metaclass=_MetaArrayT): - types = (jax.Array, jax.ShapeDtypeStruct) + def __instancecheck__(cls, obj): + return isinstance(obj, cls.types) + class JaxArrayT(metaclass=_MetaArrayT): + types = (jax.Array, jax.ShapeDtypeStruct) -class ArrayT(metaclass=_MetaArrayT): - types = (JaxArrayT, np.ndarray) + class ArrayT(metaclass=_MetaArrayT): + types = (JaxArrayT, np.ndarray) Scalar = Shaped[ArrayT, ''] | Shaped[np.generic, ''] | Shaped[jnp.generic, ''] diff --git a/sequence_layers/jax/utils.py b/sequence_layers/jax/utils.py index 64029fb..c4e0b40 100644 --- a/sequence_layers/jax/utils.py +++ b/sequence_layers/jax/utils.py @@ -21,13 +21,16 @@ import pprint import re import typing -from 
typing import Any, Callable, Protocol, Self, Sequence as TypingSequence, TypeVar +from typing import Any, Callable, Protocol, Self +from typing import Sequence as TypingSequence +from typing import TypeVar import flax.core.scope import flax.linen as nn import jax import jax.numpy as jnp import numpy as np + from sequence_layers.jax import meta from sequence_layers.jax import types from sequence_layers.jax import typing as jt @@ -2228,6 +2231,7 @@ def layer_with_emits_spec( values_spec, types.ShapeDType(values_spec.shape[:2], dtype=types.MASK_DTYPE), ) + def layer_fn( layer: types.SequenceLayer, x: types.Sequence, diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 8cb7355..95f38af 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -13,4 +13,7 @@ # limitations under the License. """Sequence layers in MLX.""" -from sequence_layers.mlx.types import * \ No newline at end of file +from sequence_layers.mlx.types import * + +from . import backend +from . 
import types diff --git a/sequence_layers/mlx/backend.py b/sequence_layers/mlx/backend.py new file mode 100644 index 0000000..f73847a --- /dev/null +++ b/sequence_layers/mlx/backend.py @@ -0,0 +1,26 @@ +"""Backend-specific helpers (MLX)""" + +from typing import override + +import mlx.core as mx + +from sequence_layers.specs import backend as spec +from sequence_layers.specs import types as types_spec + + +class BackendWrapper(spec.xp): + """Thin wrapper around MLX to match NumPy interface for tests.""" + + bool_ = mx.bool_ + int32 = mx.int32 + + @override + def array(self, a, dtype=None) -> types_spec.Array: + return mx.array(a, dtype=dtype) + + @override + def zeros(self, shape, dtype=None) -> types_spec.Array: + return mx.zeros(shape, dtype=dtype) + + +xp: spec.xp = BackendWrapper() diff --git a/sequence_layers/mlx/test_utils.py b/sequence_layers/mlx/test_utils.py new file mode 100644 index 0000000..9953b74 --- /dev/null +++ b/sequence_layers/mlx/test_utils.py @@ -0,0 +1,47 @@ +"""Test utilities for MLX sequence layers.""" + +from typing import override + +import mlx.core as mx +import numpy as np + +from sequence_layers.mlx import types +import sequence_layers.mlx as sl +from sequence_layers.specs import test_utils as spec + + +def _mask_and_pad_to_max_length( + a: types.Sequence, b: types.Sequence +) -> tuple[types.Sequence, types.Sequence]: + """Masks and pads two sequences to the same max length.""" + # Only compare values in non-masked regions. 
+ a = a.mask_invalid() + b = b.mask_invalid() + a_time = a.values.shape[1] + b_time = b.values.shape[1] + max_time = max(a_time, b_time) + a = a.pad_time(0, max_time - a_time, valid=False) + b = b.pad_time(0, max_time - b_time, valid=False) + return a, b + + +class SequenceLayerTest(spec.SequenceLayerTest): + """Base class for MLX SequenceLayer tests.""" + + sl = sl # pyrefly: ignore[bad-assignment] # module-as-protocol + + @override + def assertAllEqual(self, x, y): + """Asserts that two arrays are equal.""" + x_np = np.array(x) if isinstance(x, mx.array) else x + y_np = np.array(y) if isinstance(y, mx.array) else y + np.testing.assert_array_equal(x_np, y_np) + + @override + def assertSequencesEqual( # pyrefly: ignore[bad-override] + self, x: types.Sequence, y: types.Sequence + ): + """After padding, checks sequence values are equal and masks are equal.""" + x, y = _mask_and_pad_to_max_length(x, y) + self.assertAllEqual(x.values, y.values) + self.assertAllEqual(x.mask, y.mask) diff --git a/sequence_layers/mlx/types.py b/sequence_layers/mlx/types.py index 5647618..80121d9 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -2,16 +2,26 @@ import abc import dataclasses -import enum import fractions import functools -from typing import Callable, Generic, Iterable, TypeVar, override +import math +import types +from typing import ( + Any, + Callable, + cast, + Iterable, + MutableMapping, + override, + Self, + TypeVar, +) +import jaxtyping as jt +from mlx import nn import mlx.core as mx -import mlx.nn as nn -import numpy as np -from sequence_layers.abstract import types +from sequence_layers.specs import types as spec # Type aliases. 
MASK_DTYPE = mx.bool_ @@ -20,18 +30,22 @@ MaskT = TypeVar('MaskT', bound=mx.array) LengthsT = TypeVar('LengthsT', bound=mx.array) ExpandedMaskT = TypeVar('ExpandedMaskT', bound=mx.array) +NewValuesT = TypeVar('NewValuesT', bound=mx.array) +NewMaskT = TypeVar('NewMaskT', bound=mx.array) SequenceSelf = TypeVar('SequenceSelf', bound='Sequence') Shape = tuple[int, ...] ShapeLike = list[int] | tuple[int, ...] -DType = np.dtype +DType = mx.Dtype State = object # Any pytree. -Constants = dict[str, object] -Emits = object +Constants = MutableMapping[str, jt.PyTree[mx.array]] +Emits = jt.PyTree[mx.array] # Receptive field. ReceptiveField = tuple[float | int, float | int] | None +InputT = TypeVar('InputT', bound='Sequence') +OutputT = TypeVar('OutputT', bound='Sequence') __all__ = ( # go/keep-sorted start @@ -67,6 +81,7 @@ # go/keep-sorted end ) + class ShapeDType: """Lightweight replacement for jax.ShapeDtypeStruct.""" @@ -74,9 +89,11 @@ def __init__(self, shape: Shape, dtype: DType): self.shape = shape self.dtype = dtype + @override def __repr__(self) -> str: return f'ShapeDType(shape={self.shape}, dtype={self.dtype})' + @override def __eq__(self, other: object) -> bool: if not isinstance(other, ShapeDType): return NotImplemented @@ -88,14 +105,17 @@ def __hash__(self) -> int: ChannelSpec = ShapeDType -PaddingMode = types.PaddingMode +PaddingMode = spec.PaddingMode -def sequence_mask(lengths: LengthsT, maxlen: int) -> MaskT: - return mx.arange(maxlen)[None, :] < mx.array(lengths)[:, None] +def sequence_mask(lengths: LengthsT, maxlen: int) -> mx.array: + """Generates a boolean mask for sequences based on lengths.""" + return mx.arange(maxlen)[None, :] < mx.array(lengths)[:, None] # pylint: disable=unsubscriptable-object -class Sequence(types.Sequence[ValuesT, MaskT]): +class Sequence[ValuesT: mx.array, MaskT: mx.array]( + spec.Sequence[ValuesT, MaskT] +): """A generic sequence container that preserves masking information.""" values: ValuesT @@ -134,6 +154,23 @@ def 
dtype(self) -> DType: """Returns the dtype of the sequence values.""" return self.values.dtype + @classmethod + @override + def from_lengths( + cls, + values: ValuesT, + lengths: LengthsT, + is_masked: bool = False, + ) -> 'Sequence': + """Constructs a sequence from values and per-batch element lengths.""" + values_arr = mx.array(values) + mask = sequence_mask(lengths, maxlen=values_arr.shape[1]) + return ( + MaskedSequence(values_arr, mask) + if is_masked + else Sequence(values_arr, mask) + ) + @classmethod @override def from_values(cls, values: ValuesT) -> 'MaskedSequence': @@ -161,7 +198,7 @@ def concatenate_sequences(cls, sequences: Iterable['Sequence']) -> 'Sequence': ) @override - def expanded_mask(self) -> ExpandedMaskT: + def expanded_mask(self) -> mx.array: """Returns the Sequence mask expanded to match values rank.""" return self.mask.reshape(self.mask.shape + (1,) * (self.values.ndim - 2)) @@ -177,13 +214,16 @@ def apply_values( @override def apply_values_masked( - self: SequenceSelf, - values_fn: Callable[..., ValuesT], + self, + values_fn: Callable[..., NewValuesT], *args, **kwargs, - ) -> SequenceSelf: + ) -> 'Sequence[NewValuesT, MaskT]': """Transforms values, preserving masked state.""" - return type(self)(values_fn(self.values, *args, **kwargs), self.mask) + return cast( + 'Sequence[NewValuesT, MaskT]', + type(self)(values_fn(self.values, *args, **kwargs), self.mask), + ) @override def apply( @@ -198,14 +238,14 @@ def apply( @override def apply_masked( - self: SequenceSelf, - apply_fn: Callable[..., tuple[ValuesT, MaskT]], + self, + apply_fn: Callable[..., tuple[NewValuesT, NewMaskT]], *args, **kwargs, - ) -> SequenceSelf: + ) -> 'Sequence[NewValuesT, NewMaskT]': """Transforms values/mask, preserving masked state.""" values, mask = apply_fn(self.values, self.mask, *args, **kwargs) - return type(self)(values, mask) + return cast('Sequence[NewValuesT, NewMaskT]', type(self)(values, mask)) @override def astype(self: SequenceSelf, dtype: DType | None) 
class MaskedSequence[ValuesT: mx.array, MaskT: mx.array](
    Sequence[ValuesT, MaskT], spec.MaskedSequence[ValuesT, MaskT]
):
  """Sequence whose invalid timesteps are masked to zero."""

  @override
  def apply_values_masked(
      self,
      values_fn: Callable[..., NewValuesT],
      *args,
      **kwargs,
  ) -> 'MaskedSequence[NewValuesT, MaskT]':
    """Transforms values, preserving this sequence's masked state."""
    # Overridden only to narrow the return type to MaskedSequence.
    return cast(
        'MaskedSequence[NewValuesT, MaskT]',
        type(self)(values_fn(self.values, *args, **kwargs), self.mask),
    )

  @override
  def apply_masked(
      self,
      apply_fn: Callable[..., tuple[NewValuesT, NewMaskT]],
      *args,
      **kwargs,
  ) -> 'MaskedSequence[NewValuesT, NewMaskT]':
    """Transforms values and mask, preserving the masked state."""
    values, mask = apply_fn(self.values, self.mask, *args, **kwargs)
    return cast(
        'MaskedSequence[NewValuesT, NewMaskT]', type(self)(values, mask)
    )

  @override
  def mask_invalid(self, mask_value: complex | None = None) -> Sequence:
    """Masks invalid timesteps; a no-op here when mask_value is None.

    A MaskedSequence's invalid timesteps are already zero, so the default
    (None) returns self unchanged. A non-None mask_value requires re-masking
    with that fill value via the module-level mask_invalid.
    """
    if mask_value is None:
      return self
    return mask_invalid(self, mask_value)
outputs.""" @functools.wraps(step_fn) - def wrapper(self, x, state, *, constants=None): + def wrapper(self, x, state, *, training: bool, constants=None): if not self.supports_step: raise ValueError(f'{self.__class__.__name__} does not support step().') block_size = self.block_size @@ -362,7 +431,7 @@ def wrapper(self, x, state, *, constants=None): f'{self.__class__.__name__} received input with' f' {x.shape=} not a multiple of {block_size=}.' ) - y, state = step_fn(self, x, state, constants=constants) + y, state = step_fn(self, x, state, training=training, constants=constants) _check_output_spec(self, x, y, constants) _check_output_ratio(self, x, y) return y, state @@ -375,8 +444,71 @@ def wrapper(self, x, state, *, constants=None): # --------------------------------------------------------------------------- -class Steppable(types.Steppable): - """A sequence processing layer that supports layer and step modes.""" +class Steppable(spec.Steppable[Sequence, Sequence, ChannelSpec]): + """A sequence processing layer that can be executed layerwise or stepwise. + + # Step-wise execution: + + A SequenceLayer supports step-wise execution if its `supports_step` property + is true. Most built-in SequenceLayers support step-wise processing by default, + but may support processing features that are not causal and therefore cannot + be executed step-by-step (e.g. non-causal convolutions, bidirectional RNNs, + etc.). + + When executing step-wise, use the `step` or `step_with_emits` method to + process a block of inputs (a `Sequence` shaped `[b, block_size * n, ...]`) and + a `state` input whose structure matches `get_initial_state`. + + This produces: + - An output `Sequence` shaped `[b, block_size * n * output_ratio, ...]` + whose `...` shape matches `get_output_shape`. + - A `state` output whose structure matches `get_initial_state`. + - (Optionally) an `emits` output. 
  @override
  def get_accumulated_input_latency(self, input_latency: int) -> int:
    """Accumulates an incoming latency through this layer.

    Args:
      input_latency: The latency accumulated by upstream layers.

    Returns:
      The incoming latency divided by this layer's output ratio (rounded up),
      plus this layer's own input latency.
    """
    return math.ceil(input_latency / self.output_ratio) + self.input_latency
  @property
  @override
  def receptive_field(self) -> ReceptiveField:
    """Unsupported by the MLX backend; always raises NotImplementedError."""
    raise NotImplementedError(
        'receptive_field is not implemented by MLX Steppable.'
    )

  @abc.abstractmethod
  @override
  def layer(
      self, x: Sequence, *, training: bool, constants: Constants | None = None
  ) -> Sequence:
    """Process this layer layer-wise.

    Args:
      x: Input sequence with values shaped [b, t_i, ...].
      training: Python bool. Whether we are in training mode.
      constants: A dictionary of constant name to array or sl.Sequence.
        Values or sequences that are "constant" with respect to the
        SequenceLayer, but may affect its processing. For example, for an
        attention layer this may contain the source sequence to attend to.

    Returns:
      y: The outputs corresponding to this layer with values shaped
        [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been
        truncated to only represent valid frames.
    """

  @override
  def layer_with_emits(
      self, x: Sequence, *, training: bool, constants: Constants | None = None
  ) -> tuple[Sequence, Emits]:
    """Process this layer layer-wise, producing emitted arrays.

    This is like `layer`, except it has an additional return value which is
    the "emitted" arrays for the layer. The emitted arrays are a structure of
    arrays whose values are arrays or `Sequence`s.

    Args:
      x: Input sequence with values shaped [b, t_i, ...].
      training: Python bool. Whether we are in training mode.
      constants: A dictionary of constant name to array or sl.Sequence.
        Values or sequences that are "constant" with respect to the
        SequenceLayer, but may affect its processing. For example, for an
        attention layer this may contain the key/value sequence to attend to.

    Returns:
      y: The outputs corresponding to this layer with values shaped
        [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been
        truncated to only represent valid frames.
      emits: A nest of emitted arrays or Sequences.
    """
    # Default implementation: layers with no emits return an empty tuple.
    return self.layer(x, training=training, constants=constants), ()

  @abc.abstractmethod
  @override
  def step(
      self,
      x: Sequence,
      state: State,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> tuple[Sequence, State]:
    """Process this layer step-wise.

    Args:
      x: Input sequence with values shaped [b, t_i, ...], where t_i is a
        multiple of block_size.
      state: A structure of state arrays matching get_initial_state. The
        previous state for this layer.
      training: Python bool. Whether we are in training mode.
      constants: A dictionary of constant name to array or sl.Sequence.
        Values or sequences that are "constant" with respect to the
        SequenceLayer, but may affect its processing. For example, for an
        attention layer this may contain the key/value sequence to attend to.

    Returns:
      y: The outputs corresponding to this step with values shaped
        [b, t_o, ...] where `t_o == t_i * output_ratio`.
      state: A structure of state arrays matching get_initial_state. The
        new state for this layer.
    """

  @override
  def step_with_emits(
      self,
      x: Sequence,
      state: State,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> tuple[Sequence, State, Emits]:
    """Process this layer step-wise, producing emitted arrays.

    This is like `step`, except it has an additional return value which is
    the "emitted" arrays for the step. The emitted arrays are a structure of
    arrays whose values are arrays or `Sequence`s.

    Args:
      x: Input sequence with values shaped [b, t_i, ...], where t_i is a
        multiple of block_size.
      state: A structure of state arrays matching get_initial_state. The
        previous state for this layer.
      training: Python bool. Whether we are in training mode.
      constants: A dictionary of constant name to array or sl.Sequence.
        Values or sequences that are "constant" with respect to the
        SequenceLayer, but may affect its processing. For example, for an
        attention layer this may contain the key/value sequence to attend to.

    Returns:
      y: The outputs corresponding to this step with values shaped
        [b, t_o, ...] where `t_o == t_i * output_ratio`.
      state: A structure of state arrays matching get_initial_state. The
        new state for this layer.
      emits: A nest of emitted arrays or Sequences.
    """
    # Default implementation: layers with no emits return an empty tuple.
    y, state = self.step(x, state, training=training, constants=constants)
    return y, state, ()

  @abc.abstractmethod
  @override
  def get_initial_state(
      self,
      batch_size: int,
      input_spec: ChannelSpec,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> State:
    """Returns the initial state for this SequenceLayer.

    Args:
      batch_size: The batch size to create state for.
      input_spec: An input ChannelSpec representing the channel shape and
        dtype of the input that will be stepped.
      training: Python bool. Whether we are in training mode.
      constants: A dictionary of constant name to array or sl.Sequence.
        Values or sequences that are "constant" with respect to the
        SequenceLayer, but may affect its processing. For example, for an
        attention layer this may contain the source sequence to attend to.

    Returns:
      An integer, shape, or structure of integer/shapes.
    """

  @abc.abstractmethod
  @override
  def get_output_shape(
      self,
      input_shape: ShapeLike,
      *,
      constants: Constants | None = None,
  ) -> Shape:
    """Returns the output channel shape produced for an input channel shape.

    Args:
      input_shape: A shape representing the channels dimension of the input
        sequence (i.e. not including the batch or time dimension).
      constants: A dictionary of constant name to array or sl.Sequence.
        Values or sequences that are "constant" with respect to the
        SequenceLayer, but may affect its processing. For example, for an
        attention layer this may contain the source sequence to attend to.

    Returns:
      A shape representing the output channels dimensions (i.e. not including
      the batch or time dimension).
    """

  @abc.abstractmethod
  @override
  def get_output_dtype(
      self,
      input_dtype: DType,
      *,
      constants: Constants | None = None,
  ) -> DType:
    """Returns the layer's output dtype for the specified input dtype.

    Args:
      input_dtype: The dtype of the input features.
      constants: A dictionary of constant name to array or sl.Sequence.
        Values or sequences that are "constant" with respect to the
        SequenceLayer, but may affect its processing.

    Returns:
      The dtype of the output features.
    """

  @override
  def get_output_spec(
      self,
      input_spec: ChannelSpec,
      *,
      constants: Constants | None = None,
  ) -> ChannelSpec:
    """Returns the output spec this layer produces for the provided input spec.

    Args:
      input_spec: A ChannelSpec which represents the channels shape and dtype
        of the input sequence (i.e. not including the batch or time
        dimension).
      constants: A dictionary of constant name to array or sl.Sequence.
        Values or sequences that are "constant" with respect to the
        SequenceLayer, but may affect its processing.

    Returns:
      The ChannelSpec of the output features.
    """
    # Composed from the shape and dtype hooks so subclasses only implement
    # those two.
    shape = self.get_output_shape(input_spec.shape, constants=constants)
    dtype = self.get_output_dtype(input_spec.dtype, constants=constants)
    return ChannelSpec(shape, dtype)
class Stateless(SequenceLayer, spec.Stateless[Sequence, Sequence, ChannelSpec]):
  """A SequenceLayer with no state over time required for step-wise processing.

  Sub-classes must also implement:
  - layer
  - get_output_shape
  - get_output_dtype
  """

  @override
  def get_initial_state(
      self,
      batch_size: int,
      input_spec: ChannelSpec,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> State:
    """Returns an empty state: stateless layers carry nothing across steps."""
    del batch_size
    del input_spec
    del training
    del constants
    return ()

  @abc.abstractmethod
  @override
  def get_output_shape(
      self,
      input_shape: ShapeLike,
      *,
      constants: Constants | None = None,
  ) -> Shape:
    """Returns the output channel shape for an input channel shape."""
    ...

  @abc.abstractmethod
  @override
  def get_output_dtype(
      self,
      input_dtype: DType,
      *,
      constants: Constants | None = None,
  ) -> DType:
    """Returns the output dtype for an input dtype."""
    ...

  @abc.abstractmethod
  @override
  def layer(
      self,
      x: Sequence,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> Sequence:
    """Process this layer layer-wise."""
    ...

  @override
  def step(
      self,
      x: Sequence,
      state: State,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> tuple[Sequence, State]:
    """Steps by applying layer() to the block; passes state through unchanged."""
    return self.layer(x, training=training, constants=constants), state
class Emitting(
    SequenceLayer,
    spec.Emitting[Sequence, Sequence, ChannelSpec],
):
  """A SequenceLayer that emits auxiliary arrays.

  This is a convenience subclass that implements step and layer in terms of
  step_with_emits and layer_with_emits, so that implementors need only
  implement two of the four methods. For emits that are substantially
  expensive to compute subclasses can choose to implement all four and save
  computation in those that do not produce emits.
  """

  @abc.abstractmethod
  @override
  def get_initial_state(
      self,
      batch_size: int,
      input_spec: ChannelSpec,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> State:
    """Returns the initial state for step-wise processing."""
    ...

  @abc.abstractmethod
  @override
  def get_output_shape(
      self,
      input_shape: ShapeLike,
      *,
      constants: Constants | None = None,
  ) -> Shape:
    """Returns the output channel shape for an input channel shape."""
    ...

  @abc.abstractmethod
  @override
  def get_output_dtype(
      self,
      input_dtype: DType,
      *,
      constants: Constants | None = None,
  ) -> DType:
    """Returns the output dtype for an input dtype."""
    ...

  @abc.abstractmethod
  @override
  def step_with_emits(
      self,
      x: Sequence,
      state: State,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> tuple[Sequence, State, Emits]:
    """Process this layer step-wise, producing emitted arrays."""
    ...

  @abc.abstractmethod
  @override
  def layer_with_emits(
      self,
      x: Sequence,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> tuple[Sequence, Emits]:
    """Process this layer layer-wise, producing emitted arrays."""
    ...

  @override
  def step(
      self,
      x: Sequence,
      state: State,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> tuple[Sequence, State]:
    """Steps via step_with_emits, discarding the emits."""
    output, state, _ = self.step_with_emits(
        x, state, training=training, constants=constants
    )
    return output, state

  @override
  def layer(
      self,
      x: Sequence,
      *,
      training: bool,
      constants: Constants | None = None,
  ) -> Sequence:
    """Processes via layer_with_emits, discarding the emits."""
    outputs, _ = self.layer_with_emits(
        x, training=training, constants=constants
    )
    return outputs
+ @override def get_initial_state( self, batch_size: int, input_spec: ChannelSpec, *, + training: bool, constants: Constants | None = None, ) -> State: - return () \ No newline at end of file + del batch_size + del input_spec + del training + del constants + return () + + @override + def step_with_emits( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State, Emits]: + outputs, emits = self.layer_with_emits( + x, training=training, constants=constants + ) + return outputs, state, emits diff --git a/sequence_layers/mlx/types_test.py b/sequence_layers/mlx/types_test.py index 6e12492..da6163c 100644 --- a/sequence_layers/mlx/types_test.py +++ b/sequence_layers/mlx/types_test.py @@ -1,62 +1,57 @@ -import mlx.core as mx -import numpy as np -from sequence_layers.abstract import types_test_base -from sequence_layers.mlx import types -from absl.testing import parameterized +"""Tests for MLX sequence types.""" + from absl.testing import absltest +from sequence_layers.mlx import test_utils +from sequence_layers.specs import types_behaviors as spec + -class SequenceTest(types_test_base.SequenceTest): +class ModuleInterfaceTest( + test_utils.SequenceLayerTest, spec.ModuleInterfaceTest +): + pass - def get_backend(self): - return mx - @property - def Sequence(self): - return types.Sequence +class SequenceTest(test_utils.SequenceLayerTest, spec.SequenceTest): + pass - @property - def MaskedSequence(self): - return types.MaskedSequence - def assertAllEqual(self, a, b): - a = np.array(a) if isinstance(a, mx.array) else a - b = np.array(b) if isinstance(b, mx.array) else b - np.testing.assert_array_equal(a, b) +class SequenceLayerConfigTest( + test_utils.SequenceLayerTest, spec.SequenceLayerConfigTest +): + pass - def assertSequencesEqual(self, a, b): - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.mask, b.mask) +class SteppableTest(test_utils.SequenceLayerTest, spec.SteppableTest): + pass 
@runtime_checkable
class ModuleSpec(Protocol):
  """Protocol for a backend-specific SequenceLayers module.

  A backend package (a `sequence_layers.<backend>` module, typically imported
  as `sl`) is expected to satisfy this protocol structurally.
  """

  @property
  def backend(self) -> _backend.ModuleSpec:
    """The backend helpers submodule."""

  @property
  def types(self) -> _types.ModuleSpec:
    """The types submodule."""

  # Identifiers that backend-specific implementations should expose at top
  # level. Declaring them as read-only properties allows for covariance
  # (subclasses of types_module.Sequence satisfy the protocol).

  @property
  def Sequence(self) -> type[_types.Sequence]:
    ...

  @property
  def MaskedSequence(self) -> type[_types.MaskedSequence]:
    ...

  @property
  def SequenceLayer(self) -> type[_types.SequenceLayer]:
    ...

  @property
  def SequenceLayerConfig(self) -> type[_types.SequenceLayerConfig]:
    ...
class _AbcParameterizedTestCaseMeta(abc.ABCMeta, type(parameterized.TestCase)):
  """Metaclass for abstract parameterized test cases.

  Combines ABCMeta with parameterized.TestCase's metaclass so abstract test
  base classes can subclass TestCase without a metaclass conflict.
  """


class SequenceLayerTest[SequenceT: types_spec.Sequence = types_spec.Sequence](
    parameterized.TestCase, metaclass=_AbcParameterizedTestCaseMeta
):
  """Base test class providing common sequence testing assertions.

  Binds a backend implementation to tests.
  """

  # The backend-specific `sequence_layers.<backend>` module under test.
  # NOTE(review): not assigned here — presumably concrete subclasses set it;
  # confirm against backend test_utils.
  sl: specs.ModuleSpec

  @property
  def xp(self) -> backend_spec.xp:
    """Returns the backend's NumPy-compatible array module."""
    return self.sl.backend.xp

  @abc.abstractmethod
  def assertSequencesEqual(self, x: SequenceT, y: SequenceT) -> None:  # pylint: disable=invalid-name
    """After padding, checks sequence values are equal and masks are equal."""

  @abc.abstractmethod
  def assertAllEqual(self, x: Any, y: Any) -> None:  # pylint: disable=invalid-name
    """Asserts that two arrays are equal."""
+DType = Any # Can be numpy, jax, or mlx dtype + + +class ChannelSpec(Protocol): + """Protocol for channel specifications.""" + + @property + def shape(self) -> Shape: + """The shape of the channel.""" + + @property + def dtype(self) -> Any: + """The dtype of the channel.""" + + +State = Any +Constants = MutableMapping[str, jt.PyTree[Array]] +Emits = jt.PyTree[Array] + + +ValuesT = TypeVar('ValuesT', bound=Array) +MaskT = TypeVar('MaskT', bound=Array) +ChannelSpecT = TypeVar('ChannelSpecT', bound=ChannelSpec) + +LengthsT = TypeVar('LengthsT', bound=Array) + +InputT = TypeVar('InputT', bound='Sequence') +OutputT = TypeVar('OutputT', bound='Sequence') + + +class PaddingMode(enum.Enum): + """Supported padding modes.""" + + # In VALID padding mode, no padding is applied. + # + # Key properties: + # * The physical length of an input array to a VALID padded function shrinks, + # dropping any timesteps whose inputs are computed from implicit edge + # padding. + # * An output timestep is valid when all of its input timesteps are also + # valid. + VALID = 'valid' + + # In SAME padding mode, the input sequence is padded such that the output + # length is equal to the input length before applying striding. + # + # Key properties: + # * The input length is equal to the output length, before applying striding. + # * Padding of `effective_kernel_size - 1` is applied. Half is applied to the + # front and half to the back. If `effective_kernel_size` is even, the extra + # padding is added to the end. + # * An output timestep is valid when its corresponding input timestep is + # valid. + SAME = 'same' + + # In CAUSAL_VALID padding mode, the input sequence is padded such that the + # output length is equal to the input length before applying striding. Padding + # is applied such that the output timestep `to` can only depend on input + # timesteps at or before `ti` where `ti * output_ratio = to`. 
+ # + # Key properties: + # * As in SAME padding, the input length is equal to the output length, before + # applying striding. + # * Padding of `effective_kernel_size - 1` is applied to the front of the + # sequence. + # * As in VALID padding, an output timestep is valid iff all of its input + # timesteps are also valid. + CAUSAL_VALID = 'causal_valid' + + # In REVERSE_CAUSAL_VALID padding mode, the input sequence is padded such that + # the output length is equal to the input length before applying striding. + # Padding is applied such that the output timestep `to` can only depend on + # input timesteps at or after `ti` where `ti * output_ratio = to`. + # + # Key properties: + # * As in SAME padding, the input length is equal to the output length, before + # applying striding. + # * Padding of `effective_kernel_size - 1` is applied to the back of the + # sequence. + REVERSE_CAUSAL_VALID = 'reverse_causal_valid' + + # In CAUSAL padding mode, the input sequence is padded such that the output + # length is equal to the input length before applying striding. Padding is + # applied such that the output timestep `to` can only depend on input + # timesteps at or before `ti` where `ti * output_ratio = to`. + # + # Key properties: + # * As in SAME padding, the input length is equal to the output length, before + # applying striding. + # * Padding of `effective_kernel_size - 1` is applied to the front of the + # sequence. + # * As in SAME padding, an output timestep is valid when its corresponding + # input timestep is valid. + CAUSAL = 'causal' + + # In REVERSE_CAUSAL padding mode, the input sequence is padded such that the + # output length is equal to the input length before applying striding. Padding + # is applied such that the output timestep `to` can only depend on input + # timesteps at or after `ti` where `ti * output_ratio = to`. + # + # Key properties: + # * As in SAME padding, the input length is equal to the output length, before + # applying striding. 
+ # * Padding of `effective_kernel_size - 1` is applied to the back of the + # sequence. + # * As in SAME padding, an output timestep is valid when its corresponding + # input timestep is valid. + REVERSE_CAUSAL = 'reverse_causal' + + # In SEMICAUSAL padding mode, the input sequence is padded such that the + # output length is equal to the input length before applying striding. Padding + # is applied such that the output timestep `to` can only depend on input + # timesteps at or before `ti` where `ti * output_ratio = to`. + # + # Key properties: + # * As in SAME padding, the input length is equal to the output length, before + # applying striding. + # * Padding of `effective_kernel_size - stride` is applied to the front of the + # sequence, and padding of `stride - 1` timesteps is applied to the back of + # the sequence for a total of `effective_kernel_size - 1` timesteps of + # padding. If `effective_kernel_size` < `stride`, then padding of + # `effective_kernel_size - 1` is applied to the back of the sequence. + # * As in SAME padding, an output timestep is valid when its corresponding + # input timestep is valid. + SEMICAUSAL = 'semicausal' + + # In SEMICAUSAL_FULL padding mode, the input sequence is padded such that the + # output of the corresponding overlap-add or transpose convolution is of the + # same size as the input sequence and perfect reconstruction can be achieved. + # The reconstructed signal is of the same length or of length rounded up to + # cover the full input sequence. 
+ SEMICAUSAL_FULL = 'semicausal_full' + + +PaddingModeString = Literal[ + 'valid', + 'same', + 'causal_valid', + 'reverse_causal_valid', + 'causal', + 'reverse_causal', + 'semicausal', + 'semicausal_full', +] + + +class Sequence[ValuesT = Array, MaskT = Array](metaclass=abc.ABCMeta): + """Abstract base class for Sequence.""" + + values: ValuesT + mask: MaskT + + def __init__(self, values: ValuesT, mask: MaskT): + raise NotImplementedError('Subclasses must implement __init__') + + @property + @abc.abstractmethod + def shape(self) -> Shape: + """The shape of the sequence as (batch, time, ...channels).""" + + @property + @abc.abstractmethod + def ndim(self) -> int: + """The number of dimensions of the sequence values.""" + + @property + @abc.abstractmethod + def channel_shape(self) -> Shape: + """The shape of the channels in the sequence.""" + + @property + @abc.abstractmethod + def dtype(self) -> DType: + """The dtype of the sequence values.""" + + @classmethod + @abc.abstractmethod + def from_values(cls, values: ValuesT) -> Self: + """Creates a Sequence from values with a default mask.""" + + @classmethod + @abc.abstractmethod + def from_lengths( + cls, + values: ValuesT, + lengths: LengthsT, + is_masked: bool = False, + ) -> Self: + """Creates a Sequence from values and lengths.""" + + @classmethod + @abc.abstractmethod + def concatenate_sequences(cls, sequences: Iterable[Self]) -> Self: + """Concatenates multiple sequences into one.""" + + @abc.abstractmethod + def expanded_mask(self) -> Any: + """Returns the mask expanded to the shape of the values.""" + + @abc.abstractmethod + def apply_values[NewValuesT: Array, **P]( + self, + values_fn: Callable[Concatenate[ValuesT, P], NewValuesT], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'Sequence[NewValuesT, MaskT]': + """Applies a function to the sequence values.""" + + @abc.abstractmethod + def apply_values_masked[NewValuesT: Array, **P]( + self, + values_fn: Callable[Concatenate[ValuesT, P], NewValuesT], + *args: 
P.args, + **kwargs: P.kwargs, + ) -> 'Sequence[NewValuesT, MaskT]': + """Applies a function to the sequence values, respecting the mask.""" + + @abc.abstractmethod + def apply[NewValuesT: Array, NewMaskT: Array, **P]( + self, + apply_fn: Callable[Concatenate[ValuesT, P], tuple[NewValuesT, NewMaskT]], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'Sequence[NewValuesT, NewMaskT]': + """Applies a function to both values and mask.""" + + @abc.abstractmethod + def apply_masked[NewValuesT: Array, NewMaskT: Array, **P]( + self, + apply_fn: Callable[Concatenate[ValuesT, P], tuple[NewValuesT, NewMaskT]], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'Sequence[NewValuesT, NewMaskT]': + """Applies a function to values and mask, respecting the mask.""" + + @abc.abstractmethod + def astype(self, dtype: DType | None) -> Self: + """Returns a copy of the sequence with a new dtype.""" + + @abc.abstractmethod + def lengths(self) -> Any: + """Returns the lengths of the sequences in the batch.""" + + @abc.abstractmethod + def __getitem__( + self, + the_slice: slice | tuple[int | slice | None | EllipsisType, ...], + ) -> Self: + ... 
+ + @abc.abstractmethod + def pad_time( + self, + pad_left: int, + pad_right: int, + valid: bool, + pad_value: Any | None = None, + ) -> Self: + """Pads the sequence along the time dimension.""" + + @abc.abstractmethod + def concatenate(self, other: Self) -> Self: + """Concatenates another sequence to this one.""" + + @abc.abstractmethod + def mask_invalid( + self, mask_value: Any | None = None + ) -> 'Sequence[ValuesT, MaskT]': + """Returns a MaskedSequence with invalid timesteps zeroed.""" + + @abc.abstractmethod + def unmask(self) -> 'Sequence[ValuesT, MaskT]': + """Returns a Sequence with no masking applied.""" + + +class MaskedSequence(Sequence[ValuesT, MaskT]): + """A sequence whose invalid timesteps are masked to zero.""" + + @abc.abstractmethod + @override + def apply_values_masked[NewValuesT: Array, **P]( + self, + values_fn: Callable[Concatenate[ValuesT, P], NewValuesT], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'MaskedSequence[NewValuesT, MaskT]': + ... + + @abc.abstractmethod + @override + def apply_masked[NewValuesT: Array, NewMaskT: Array, **P]( + self, + apply_fn: Callable[Concatenate[ValuesT, P], tuple[NewValuesT, NewMaskT]], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'MaskedSequence[NewValuesT, NewMaskT]': + ... + + +class SequenceLayerConfig(metaclass=abc.ABCMeta): + """Configuration for a SequenceLayer.""" + + @abc.abstractmethod + def make(self) -> Any: + """Creates the sequence layer.""" + + @abc.abstractmethod + def copy(self, **kwargs: Any) -> Self: + """Returns a copy of the config with updated fields.""" + + +class Steppable[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](metaclass=abc.ABCMeta): + """A sequence processing layer that can be executed layerwise or stepwise. 
+ + The backend must implement: + - layer_with_emits + - step_with_emits + """ + + @property + @abc.abstractmethod + def block_size(self) -> int: + """The block size this layer processes at once.""" + + @property + @abc.abstractmethod + def output_ratio(self) -> fractions.Fraction: + """The ratio of output timesteps to input timesteps.""" + + @property + @abc.abstractmethod + def supports_step(self) -> bool: + """Returns True if the layer supports stepwise processing.""" + + @property + @abc.abstractmethod + def input_latency(self) -> int: + """The number of future timesteps required for the current output.""" + + @property + @abc.abstractmethod + def output_latency(self) -> int: + """The number of timesteps the output is delayed.""" + + @abc.abstractmethod + def get_accumulated_input_latency(self, input_latency: int) -> int: + """Calculates the total input latency including previous layers.""" + + @abc.abstractmethod + def get_accumulated_output_latency(self, output_latency: int) -> int: + """Calculates the total output latency including previous layers.""" + + @abc.abstractmethod + def layer( + self, x: InputT, *, training: bool, constants: Constants | None = None + ) -> OutputT: + """Process this layer layer-wise. + + Args: + x: Input sequence with values shaped [b, t_i, ...]. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + y: The outputs corresponding to this layer with values shaped + [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been + truncated to only represent valid frames. 
+ """ + + @abc.abstractmethod + def layer_with_emits( + self, + x: InputT, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, Emits]: + """Process this layer layer-wise, producing emitted arrays. + + This is like `layer`, except it has an additional return value which is the + "emitted" arrays for the layer. The emitted arrays are a structure of + arrays whose values are arrays or `Sequence`s. + + Args: + x: Input sequence with values shaped [b, t_i, ...]. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this layer with values shaped + [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been + truncated to only represent valid frames. + emits: A nest of emitted arrays or Sequences. + """ + + @abc.abstractmethod + def step( + self, + x: InputT, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, State]: + """Process this layer step-wise. + + Args: + x: Input sequence with values shaped [b, t_i, ...], where t_i is a + multiple of block_size. + state: A structure of state arrays matching get_initial_state. The + previous state for this layer. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this step with values shaped [b, t_o, ...] + where `t_o == t_i * output_ratio`. + state: A structure of state arrays matching get_initial_state. 
The + new state for this layer. + """ + + @abc.abstractmethod + def step_with_emits( + self, + x: InputT, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, State, Emits]: + """Process this layer step-wise, producing emitted arrays. + + This is like `step`, except it has an additional return value which is the + "emitted" arrays for the step. The emitted arrays are a structure of + arrays whose values are arrays or `Sequence`s. + + Args: + x: Input sequence with values shaped [b, t_i, ...], where t_i is a + multiple of block_size. + state: A structure of state arrays matching get_initial_state. The + previous state for this layer. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this step with values shaped [b, t_o, ...] + where `t_o == t_i * output_ratio`. + state: A structure of state arrays matching get_initial_state. The + new state for this layer. + emits: A nest of emitted arrays or Sequences. + """ + + @abc.abstractmethod + def get_initial_state( + self, + batch_size: int, + input_spec: ChannelSpecT, + *, + training: bool, + constants: Constants | None = None, + ) -> State: + """Returns the initial state for this SequenceLayer. + + Args: + batch_size: The batch size to create state for. + input_spec: An input ChannelSpec representing the channel shape and dtype + of the input that will be stepped. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. 
For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + An integer, shape, or structure of integer/shapes. + """ + + @abc.abstractmethod + def get_output_shape( + self, + input_shape: ShapeLike, + *, + constants: Constants | None = None, + ) -> Shape: + """Returns the output channel shape this layer produces for an input channel shape. + + Args: + input_shape: A shape representing the channels dimension of the input + sequence (i.e. not including the batch or time dimension). + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + A shape representing the output channels dimensions (i.e. not including + the batch or time dimension). + """ + + @abc.abstractmethod + def get_output_dtype( + self, + input_dtype: DType, + *, + constants: Constants | None = None, + ) -> DType: + """Returns the layer's output dtype for the specified input dtype. + + Args: + input_dtype: The dtype of the input features. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. + + Returns: + The dtype of the output features. + """ + + @abc.abstractmethod + def get_output_spec( + self, + input_spec: ChannelSpecT, + *, + constants: Constants | None = None, + ) -> ChannelSpec: + """Returns the output spec this layer produces for the provided input spec. + + Args: + input_spec: A ChannelSpec which represents the channels shape and dtype of + the input sequence (i.e. not including the batch or time dimension). + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. 
+ + Returns: + The ChannelSpec of the output features. + """ + + @property + @abc.abstractmethod + def receptive_field(self) -> Any: + """The receptive field of the layer.""" + + +class SequenceLayer[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](Steppable[InputT, OutputT, ChannelSpecT], metaclass=abc.ABCMeta): + """Base class for Sequence Layers.""" + + +# --------------------------------------------------------------------------- +# Mixins +# --------------------------------------------------------------------------- + + +class PreservesType[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](SequenceLayer[InputT, OutputT, ChannelSpecT]): + """A mix-in for layers that do not change the input dtype.""" + + @abc.abstractmethod + @override + def get_output_dtype( + self, + input_dtype: DType, + *, + constants: Constants | None = None, + ) -> DType: + ... + + +class PreservesShape[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](SequenceLayer[InputT, OutputT, ChannelSpecT]): + """A mix-in for layers that do not change the input channel shape.""" + + @abc.abstractmethod + @override + def get_output_shape( + self, + input_shape: ShapeLike, + *, + constants: Constants | None = None, + ) -> Shape: + ... + + +# --------------------------------------------------------------------------- +# Stateless variants +# --------------------------------------------------------------------------- + + +class Stateless[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](SequenceLayer[InputT, OutputT, ChannelSpecT]): + """A layer with no state over time required for step-wise processing. 
+ + The backend must implement: + - get_initial_state + - step + Further sub-classes must only implement: + - layer + - get_output_shape + - get_output_dtype + """ + + @abc.abstractmethod + @override + def get_output_shape( + self, input_shape: ShapeLike, *, constants: Constants | None = None + ) -> Shape: + ... + + @abc.abstractmethod + @override + def get_output_dtype( + self, input_dtype: DType, *, constants: Constants | None = None + ) -> DType: + ... + + @abc.abstractmethod + @override + def layer( + self, + x: InputT, + *, + training: bool, + constants: Constants | None = None, + ) -> OutputT: + ... + + @abc.abstractmethod + @override + def get_initial_state( + self, + batch_size: int, + input_spec: ChannelSpecT, + *, + training: bool, + constants: Constants | None = None, + ) -> State: + ... + + @abc.abstractmethod + @override + def step( + self, + x: InputT, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, State]: + ... + + +class StatelessPointwise[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +]( + PreservesShape[InputT, OutputT, ChannelSpecT], + Stateless[InputT, OutputT, ChannelSpecT], + metaclass=abc.ABCMeta, +): + """Stateless layer that operates pointwise (preserves shape).""" + + +class StatelessPointwiseFunctor[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](StatelessPointwise[InputT, OutputT, ChannelSpecT]): + """Stateless pointwise layer defined by a fn(values, mask). + + The backend must implement: + - layer + Further sub-classes must only implement: + - fn + - mask_required + """ + + @abc.abstractmethod + def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: + """Transforms each scalar in values independently.""" + + @property + @abc.abstractmethod + def mask_required(self) -> bool: + """Returns true if fn can change the sequence's masked state. + + If fn(0) -> 0, then mask_required() is False. 
+ """ + + @abc.abstractmethod + @override + def layer( + self, + x: InputT, + *, + training: bool, + constants: Constants | None = None, + ) -> OutputT: + ... + + +# --------------------------------------------------------------------------- +# Emitting variants +# --------------------------------------------------------------------------- + + +class Emitting[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](SequenceLayer[InputT, OutputT, ChannelSpecT]): + """A Steppable layer that emits auxiliary arrays. + + This is a convenience subclass that implements step and layer in terms of + step_with_emits and layer_with_emits. + + The backend must implement: + - step + - layer + Further sub-classes must only implement: + - step_with_emits + - layer_with_emits + """ + + @abc.abstractmethod + @override + def step( + self, + x: InputT, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, State]: + ... + + @abc.abstractmethod + @override + def layer( + self, + x: InputT, + *, + training: bool, + constants: Constants | None = None, + ) -> OutputT: + ... + + @abc.abstractmethod + @override + def step_with_emits( + self, + x: InputT, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, State, Emits]: + ... + + @abc.abstractmethod + @override + def layer_with_emits( + self, + x: InputT, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, Emits]: + ... + + +class StatelessEmitting[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](Emitting[InputT, OutputT, ChannelSpecT]): + """A Steppable layer with no state over time that emits auxiliary arrays. 
+ + The backend must implement: + - get_initial_state + - step_with_emits + Further sub-classes must only implement: + - layer_with_emits + - get_output_shape + - get_output_dtype + """ + + @abc.abstractmethod + @override + def get_initial_state( + self, + batch_size: int, + input_spec: ChannelSpecT, + *, + training: bool, + constants: Constants | None = None, + ) -> State: + ... + + @abc.abstractmethod + @override + def step_with_emits( + self, + x: InputT, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, State, Emits]: + ... + + @abc.abstractmethod + @override + def get_output_shape( + self, input_shape: ShapeLike, *, constants: Constants | None = None + ) -> Shape: + ... + + @abc.abstractmethod + @override + def get_output_dtype( + self, input_dtype: DType, *, constants: Constants | None = None + ) -> DType: + ... + + @abc.abstractmethod + @override + def layer_with_emits( + self, + x: InputT, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[OutputT, Emits]: + ... 
+ + +@runtime_checkable +class ModuleSpec(Protocol): + """Specification for sequence_layers..types""" + + # pylint: disable=invalid-name + + @property + def Sequence(self) -> type[Sequence]: + """The Sequence class for this backend.""" + + @property + def MaskedSequence(self) -> type[MaskedSequence]: + """The MaskedSequence class for this backend.""" + + @property + def SequenceLayer(self) -> type[SequenceLayer]: + """The SequenceLayer class for this backend.""" + + @property + def SequenceLayerConfig(self) -> type[SequenceLayerConfig]: + """The SequenceLayerConfig class for this backend.""" + + @property + def Steppable(self) -> type[Steppable]: + """The Steppable class for this backend.""" + + @property + def PreservesShape(self) -> type[PreservesShape]: + """The PreservesShape class for this backend.""" + + @property + def Stateless(self) -> type[Stateless]: + """The Stateless class for this backend.""" + + @property + def StatelessPointwise(self) -> type[StatelessPointwise]: + """The StatelessPointwise class for this backend.""" + + @property + def StatelessPointwiseFunctor(self) -> type[StatelessPointwiseFunctor]: + """The StatelessPointwiseFunctor class for this backend.""" + + @property + def PreservesType(self) -> type[PreservesType]: + """The PreservesType class for this backend.""" + + @property + def Emitting(self) -> type[Emitting]: + """The Emitting class for this backend.""" + + @property + def StatelessEmitting(self) -> type[StatelessEmitting]: + """The StatelessEmitting class for this backend.""" + + +__all__ = [ + name + for name, attr in ModuleSpec.__dict__.items() + if isinstance(attr, property) +] diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py new file mode 100644 index 0000000..a1bb773 --- /dev/null +++ b/sequence_layers/specs/types_behaviors.py @@ -0,0 +1,772 @@ +# pylint: disable=abstract-method +"""Generic tests for Sequence types.""" + +import dataclasses +import fractions +from typing 
import Any, NamedTuple, override +import unittest.mock + +from absl.testing import parameterized +import numpy as np + +from sequence_layers.specs import types as types_spec +from sequence_layers.specs.test_utils import SequenceLayerTest + + +class DummyChannelSpec(NamedTuple): + """Dummy channel spec for testing.""" + + shape: types_spec.Shape + dtype: types_spec.DType + + +class DefaultTestLayer(types_spec.SequenceLayer): + """A default test layer for testing.""" + + @property + @override + def block_size(self) -> int: + return 1 + + @property + @override + def output_ratio(self) -> fractions.Fraction: + return fractions.Fraction(1) + + @property + @override + def supports_step(self) -> bool: + return True + + @property + @override + def input_latency(self) -> int: + return 0 + + @property + @override + def output_latency(self) -> int: + return 0 + + @property + @override + def receptive_field(self) -> Any: + return 1 + + @override + def get_accumulated_input_latency(self, input_latency: int) -> int: + return input_latency + + @override + def get_accumulated_output_latency(self, output_latency: int) -> int: + return output_latency + + @override + def layer( + self, + x: types_spec.Sequence, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> types_spec.Sequence: + return x + + @override + def layer_with_emits( + self, + x: types_spec.Sequence, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.Emits]: + return self.layer(x, training=training, constants=constants), ( + 'test_emits', + ) + + @override + def step( + self, + x: types_spec.Sequence, + state: types_spec.State, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.State]: + return x, ('new_test_state',) + + @override + def step_with_emits( + self, + x: types_spec.Sequence, + state: types_spec.State, + *, + training: bool, + constants: 
types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.State, types_spec.Emits]: + return *self.step(x, state, training=training, constants=constants), ( + 'test_emits', + ) + + @override + def get_initial_state( + self, + batch_size: int, + input_spec: types_spec.ChannelSpec, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> types_spec.State: + return ('test_state',) + + @override + def get_output_shape( + self, + input_shape: types_spec.ShapeLike, + *, + constants: types_spec.Constants | None = None, + ) -> types_spec.Shape: + return tuple(input_shape) + (1,) + + @override + def get_output_dtype( + self, + input_dtype: types_spec.DType, + *, + constants: types_spec.Constants | None = None, + ) -> types_spec.DType: + return np.float64 + + @override + def get_output_spec( + self, + input_spec: Any, + *, + constants: types_spec.Constants | None = None, + ) -> Any: + shape = self.get_output_shape(input_spec.shape, constants=constants) + dtype = self.get_output_dtype(input_spec.dtype, constants=constants) + return DummyChannelSpec(shape, dtype) + + +class ModuleInterfaceTest(SequenceLayerTest): + + def test_backend_specific_module_has_interface(self) -> None: + self.assertIsInstance(self.sl.types, types_spec.ModuleSpec) + + +class SequenceTest(SequenceLayerTest): + """Generic tests for the Sequence class.""" + + @parameterized.named_parameters( + ('mask_value=None', 0.0, None), + ('mask_value=0.0', 0.0, 0.0), + ('mask_value=-1.0', -1.0, -1.0), + ) + def test_mask_invalid( + self, mask_value: float, expected_mask_value: float | None + ) -> None: + values = self.xp.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, False, False, True]] + ) + + # Pass mask_value only if it is not None (to test default None behavior vs explicit value) + if expected_mask_value is None: + output = self.sl.Sequence(values, mask).mask_invalid() + fill_value = 0.0 
+ else: + output = self.sl.Sequence(values, mask).mask_invalid(mask_value) + fill_value = mask_value + + expected_values = self.xp.array([ + [1.0, 2.0, fill_value, fill_value], + [fill_value, fill_value, fill_value, 40.0], + ]) + self.assertAllEqual(output.values, expected_values) + self.assertAllEqual(output.mask, mask) + + def test_pad_time(self) -> None: + values = self.xp.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, False, False, True]] + ) + + x = self.sl.Sequence(values, mask).mask_invalid() + + y = x.pad_time(0, 0, valid=False) + self.assertAllEqual(y.values, x.values) + self.assertAllEqual(y.mask, x.mask) + + y = x.pad_time(1, 0, valid=False) + + x_left1 = self.sl.Sequence( + self.xp.array([ + [0.0, 1.0, 2.0, 3.0, 4.0], + [0.0, 10.0, 20.0, 30.0, 40.0], + ]), + self.xp.array([ + [False, True, True, False, False], + [False, False, False, False, True], + ]), + ).mask_invalid() + self.assertAllEqual(y.values, x_left1.values) + self.assertAllEqual(y.mask, x_left1.mask) + + def _create_test_sequence( + self, shape: types_spec.Shape + ) -> types_spec.Sequence[types_spec.Array, types_spec.Array]: + """Creates a test sequence with specific shape.""" + size = 1 + for d in shape: + size *= d + values_np = np.arange(size, dtype=np.float32).reshape(shape) + mask_np = np.ones(shape[:2], dtype=bool) + if shape[0] > 0 and shape[1] > 1: + mask_np[0, 1] = False + + values = self.xp.array(values_np) + mask = self.xp.array(mask_np) + return self.sl.Sequence(values, mask) + + def test_slice(self) -> None: + x = self._create_test_sequence((3, 5, 9)) + + self.assertSequencesEqual( + x[:, 1:], self.sl.Sequence(x.values[:, 1:], x.mask[:, 1:]) + ) + self.assertSequencesEqual( + x[:, ::2], self.sl.Sequence(x.values[:, ::2], x.mask[:, ::2]) + ) + self.assertSequencesEqual( + x[::2, ::3], self.sl.Sequence(x.values[::2, ::3], x.mask[::2, ::3]) + ) + + def test_slice_can_slice_channel_dimensions(self) -> 
None: + x = self._create_test_sequence((3, 5, 9, 4)) + + self.assertSequencesEqual( + x[:, 1:, :], self.sl.Sequence(x.values[:, 1:], x.mask[:, 1:]) + ) + self.assertSequencesEqual( + x[:, ::2, :3], + self.sl.Sequence(x.values[:, ::2, :3], x.mask[:, ::2]), + ) + + def test_apply_values(self) -> None: + values = self.xp.array([ + [-1.0, 2.0, 3.0, 4.0], + [10.0, -20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, True, False, True]] + ) + + x = self.sl.Sequence(values, mask) + masked = x.mask_invalid() + + # Simple abs function + fn = abs + + y = x.apply_values(fn) + self.assertAllEqual(y.values, fn(x.values)) + self.assertAllEqual(y.mask, x.mask) + + y = masked.apply_values(fn) + self.assertAllEqual(y.values, fn(masked.values)) + self.assertAllEqual(y.mask, x.mask) + + y = masked.apply_values_masked(fn) + self.assertAllEqual(y.values, fn(masked.values)) + self.assertAllEqual(y.mask, x.mask) + + def test_apply_values_args(self) -> None: + values = self.xp.array([ + [-1.0, 2.0, 3.0, 4.0], + [10.0, -20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, True, False, True]] + ) + x = self.sl.Sequence(values, mask) + + target_shape = (2, 4, 1) + y = x.apply_values(lambda v, s: v.reshape(s), target_shape) + self.assertAllEqual(y.values.shape, target_shape) + self.assertAllEqual(y.mask.shape, (2, 4)) + + def test_from_values(self) -> None: + values_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + values = self.xp.array(values_np) + x = self.sl.Sequence.from_values(values) + self.assertAllEqual(x.values, values) + self.assertAllEqual( + x.mask, self.xp.array(np.ones(values.shape[:2], dtype=bool)) + ) + + def test_astype(self) -> None: + values_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + mask_np = np.array([[True, False], [False, True]], dtype=bool) + + values = self.xp.array(values_np) + mask = self.xp.array(mask_np) + + x = self.sl.Sequence(values, mask) + + y = 
x.astype(self.xp.int32) + + # Check values match casted version + self.assertAllEqual(y.mask, mask) + # y.values might be mlx array, values.astype(dtype) might be numpy if values was numpy? + # values is backend array. values.astype(dtype) should work if dtype is backend dtype. + self.assertAllEqual(y.values, values.astype(self.xp.int32)) + + def test_mask_invalid_idempotent(self) -> None: + values = self.xp.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, False, False, True]] + ) + + x = self.sl.Sequence(values, mask) + masked = x.mask_invalid() + self.assertIsNot(masked, x) + self.assertIsInstance(masked, self.sl.MaskedSequence) + + masked_again = masked.mask_invalid() + self.assertIs(masked_again, masked) + self.assertIsInstance(masked_again, self.sl.MaskedSequence) + + masked2 = x.mask_invalid() + self.assertIsNot(masked2, masked) + self.assertIsInstance(masked2, self.sl.MaskedSequence) + + def test_from_lengths(self) -> None: + """Tests creating a sequence from lengths.""" + values = self.xp.array( + np.arange(5 * 17 * 2).reshape((5, 17, 2)).astype(np.float32) + ) + lengths_np = np.array([0, 5, 10, 17, 12], dtype=np.int32) + mask_np = np.arange(17)[None, :] < lengths_np[:, None] + mask = self.xp.array(mask_np) + + x_expected = self.sl.Sequence(values, mask) + x = self.sl.Sequence.from_lengths(x_expected.values, lengths_np) + self.assertAllEqual(x.values, x_expected.values) + self.assertAllEqual(x.mask, x_expected.mask) + + # Out of range lengths are clipped to 0 or max. + x = self.sl.Sequence.from_lengths( + x_expected.values, self.xp.array([-1, 5, 10, 17, 18]) + ) + self.assertAllEqual(x.lengths(), self.xp.array([0, 5, 10, 17, 17])) + self.assertNotIsInstance(x, self.sl.MaskedSequence) + + # Return type is MaskedSequence if is_masked=True. 
+ x = self.sl.Sequence.from_lengths( + x_expected.values, [-1, 5, 10, 17, 18], is_masked=True + ) + self.assertAllEqual(x.lengths(), self.xp.array([0, 5, 10, 17, 17])) + self.assertIsInstance(x, self.sl.MaskedSequence) + + +class SteppableTest(SequenceLayerTest): + + def create_steppable(self) -> types_spec.Steppable: + """Creates a basic Steppable instance.""" + backend_sl = self.sl + + class DefaultSteppable(DefaultTestLayer, backend_sl.types.Steppable): + """Mock layer for testing.""" + + @override + def layer_with_emits(self, *args, **kwargs): + return backend_sl.types.Steppable.layer_with_emits( + self, *args, **kwargs + ) + + @override + def step_with_emits(self, *args, **kwargs): + return backend_sl.types.Steppable.step_with_emits(self, *args, **kwargs) + + return DefaultSteppable() + + def test_steppable_defaults(self) -> None: + layer = self.create_steppable() + self.assertEqual(layer.block_size, 1) + self.assertEqual(layer.output_ratio, fractions.Fraction(1)) + self.assertTrue(layer.supports_step) + self.assertEqual(layer.input_latency, 0) + self.assertEqual(layer.output_latency, 0) + self.assertEqual(layer.get_accumulated_input_latency(0), 0) + self.assertEqual(layer.get_accumulated_output_latency(0), 0) + + def test_get_output_spec(self) -> None: + layer = self.create_steppable() + input_spec = DummyChannelSpec(shape=(2, 3), dtype=np.float32) + output_spec = layer.get_output_spec(input_spec) + self.assertEqual(output_spec.shape, (2, 3, 1)) + self.assertEqual(output_spec.dtype, np.float64) + + def create_sequence(self) -> types_spec.Sequence: + """Creates a test sequence.""" + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + + def test_steppable_with_emits_defaults_to_tuple_with_empty_emits( + self, + ) -> None: + layer = self.create_steppable() + seq = self.create_sequence() + state_in = {'a': 'b'} + state_out = {1: 2} + + with unittest.mock.patch.object( + layer, 'layer', return_value=seq + ) as 
mock_layer: + out, emits = layer.layer_with_emits(seq, training=False, constants=None) + self.assertEqual(out, seq) + self.assertEqual(emits, ()) + mock_layer.assert_called_with(seq, training=False, constants=None) + + with unittest.mock.patch.object( + layer, 'step', return_value=(seq, state_out) + ) as mock_step: + out, state, emits = layer.step_with_emits( + seq, state_in, training=True, constants=None + ) + self.assertEqual(out, seq) + self.assertEqual(state, state_out) + self.assertEqual(emits, ()) + mock_step.assert_called_with(seq, state_in, training=True, constants=None) + + +class SequenceLayerConfigTest(SequenceLayerTest): + + def test_copy(self) -> None: + backend_sl = self.sl + + @dataclasses.dataclass(frozen=True) + class Config(backend_sl.SequenceLayerConfig): + """Mock config.""" + + a: int = 1234 + b: str = 'default string' + + @override + def make(self) -> Any: + """Makes a dummy layer.""" + return 'dummy_layer' + + config = Config() # type: ignore + new_config = config.copy(b='new string') + self.assertEqual(new_config.a, config.a) + self.assertEqual(new_config.b, 'new string') + + def test_copy_raises_on_non_dataclass(self) -> None: + backend_sl = self.sl + + class NonDataclassConfig(backend_sl.SequenceLayerConfig): # pylint: disable=too-few-public-methods + """Non-dataclass mock config.""" + + @override + def make(self) -> Any: + """Makes a dummy layer.""" + return 'dummy_layer' + + config = NonDataclassConfig() # type: ignore + with self.assertRaises(TypeError): + new_config = config.copy() + del new_config + + def test_copy_disallows_new_fields(self) -> None: + backend_sl = self.sl + + @dataclasses.dataclass(frozen=True) + class Config(backend_sl.SequenceLayerConfig): + """Mock config.""" + + @override + def make(self) -> Any: + """Makes a dummy layer.""" + return 'dummy_layer' + + config = Config() # type: ignore + # dataclasses.replace raises TypeError for unknown arguments + # JAX implementation wraps it in AttributeError + with 
self.assertRaises((TypeError, AttributeError)): + new_config = config.copy(field_does_not_exist=1234) + del new_config + + +class PreservesTypeTest(SequenceLayerTest): + + def create_layer(self) -> types_spec.PreservesType: + """Creates a preserves type layer.""" + backend_sl = self.sl + + class DummyLayer(DefaultTestLayer, backend_sl.types.PreservesType): + """Mock layer for testing.""" + + @override + def get_output_dtype(self, *args, **kwargs): + return backend_sl.types.PreservesType.get_output_dtype( + self, *args, **kwargs + ) + + return DummyLayer() + + def test_preserves_dtype(self) -> None: + layer = self.create_layer() + self.assertEqual(layer.get_output_dtype('fake_dtype123'), 'fake_dtype123') + + +class PreservesShapeTest(SequenceLayerTest): + + def create_layer(self) -> types_spec.PreservesShape: + """Creates a preserves shape layer.""" + backend_sl = self.sl + + class DummyLayer(DefaultTestLayer, backend_sl.types.PreservesShape): + """Mock layer for testing.""" + + @override + def get_output_shape(self, *args, **kwargs): + return backend_sl.types.PreservesShape.get_output_shape( + self, *args, **kwargs + ) + + return DummyLayer() + + def test_preserves_shape(self) -> None: + layer = self.create_layer() + self.assertEqual(layer.get_output_shape((1, 2, 3, 5)), (1, 2, 3, 5)) + + +class StatelessTest(SequenceLayerTest): + + def create_sequence(self) -> types_spec.Sequence: + """Creates a default test sequence.""" + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + + def create_layer(self) -> types_spec.Stateless: + """Creates a stateless layer.""" + backend_sl = self.sl + + class DummyLayer(DefaultTestLayer, backend_sl.types.Stateless): + """Mock layer for testing.""" + + @override + def get_initial_state(self, *args, **kwargs): + return backend_sl.types.Stateless.get_initial_state( + self, *args, **kwargs + ) + + @override + def step(self, *args, **kwargs): + return backend_sl.types.Stateless.step(self, 
*args, **kwargs) + + return DummyLayer() + + def test_stateless_behaviors(self) -> None: + layer = self.create_layer() + + # Initial state must be empty + self.assertEqual( + layer.get_initial_state( + 32, + DummyChannelSpec(shape=(2, 3), dtype=np.float32), + training=False, + ), + (), + ) + + # step unconditionally delegates to layer and returns the input state unchanged + x = self.create_sequence() + with unittest.mock.patch.object( + layer, 'layer', return_value='layer_out' + ) as mock_layer: + out, state = layer.step( + x, 'mock_state', training=True, constants={'c': 1} + ) + self.assertEqual(out, 'layer_out') + self.assertEqual(state, 'mock_state') + mock_layer.assert_called_once_with(x, training=True, constants={'c': 1}) + + +class EmittingTest(SequenceLayerTest): + + def create_sequence(self) -> types_spec.Sequence: + """Creates a default test sequence.""" + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + + def create_layer(self) -> types_spec.Emitting: + """Creates an emitting layer.""" + backend_sl = self.sl + + class DummyLayer(DefaultTestLayer, backend_sl.types.Emitting): + """Mock layer for testing.""" + + @override + def layer(self, *args, **kwargs): + return backend_sl.types.Emitting.layer(self, *args, **kwargs) + + @override + def step(self, *args, **kwargs): + return backend_sl.types.Emitting.step(self, *args, **kwargs) + + return DummyLayer() + + def test_emitting_drops_emits_on_standard_calls(self) -> None: + layer = self.create_layer() + x = self.create_sequence() + + with unittest.mock.patch.object( + layer, 'layer_with_emits', return_value=('out', 'emits') + ) as m_layer: + self.assertEqual(layer.layer(x, training=False), 'out') + m_layer.assert_called_once_with(x, training=False, constants=None) + + with unittest.mock.patch.object( + layer, 'step_with_emits', return_value=('out', 'state', 'emits') + ) as m_step: + out, state = layer.step(x, 'state', training=True, constants={'c': 1}) + 
self.assertEqual(out, 'out') + self.assertEqual(state, 'state') + m_step.assert_called_once_with( + x, 'state', training=True, constants={'c': 1} + ) + + +class StatelessEmittingTest(SequenceLayerTest): + + def create_sequence(self) -> types_spec.Sequence: + """Creates a default test sequence.""" + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + + def create_layer(self) -> types_spec.SequenceLayer: + """Creates a stateless emitting layer.""" + backend_sl = self.sl + + class DummyLayer(DefaultTestLayer, backend_sl.types.StatelessEmitting): + """Mock layer for testing.""" + + @override + def get_initial_state(self, *args, **kwargs): + return backend_sl.types.StatelessEmitting.get_initial_state( + self, *args, **kwargs + ) + + @override + def step_with_emits(self, *args, **kwargs): + return backend_sl.types.StatelessEmitting.step_with_emits( + self, *args, **kwargs + ) + + return DummyLayer() + + def test_stateless_emitting_behaviors(self) -> None: + layer = self.create_layer() + + self.assertEqual( + layer.get_initial_state( + 32, + DummyChannelSpec(shape=(2, 3), dtype=np.float32), + training=False, + ), + (), + ) + + x = self.create_sequence() + with unittest.mock.patch.object( + layer, 'layer_with_emits', return_value=('out', 'emits') + ) as m_layer: + out, state, emits = layer.step_with_emits(x, 'state', training=False) + self.assertEqual(out, 'out') + self.assertEqual(state, 'state') + self.assertEqual(emits, 'emits') + m_layer.assert_called_once_with(x, training=False, constants=None) + + +class StatelessPointwiseFunctorTest(SequenceLayerTest): + + def create_layer( + self, is_mask_required: bool + ) -> types_spec.SequenceLayer[Any]: + """Creates a stateless pointwise functor layer.""" + + backend_sl = self.sl + + class DummyLayer( + DefaultTestLayer, backend_sl.types.StatelessPointwiseFunctor + ): + """Mock layer for testing.""" + + @override + def layer(self, *args, **kwargs): + return 
backend_sl.types.StatelessPointwiseFunctor.layer( + self, *args, **kwargs + ) + + @override + def get_output_shape(self, *args, **kwargs): + return backend_sl.types.StatelessPointwiseFunctor.get_output_shape( + self, *args, **kwargs + ) + + @property + @override + def mask_required(self) -> bool: + """Whether mask is required.""" + return is_mask_required + + @override + def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: + """Pointwise function.""" + return values, mask + + return DummyLayer() + + def create_sequence( + self, + ) -> types_spec.Sequence[types_spec.Array, types_spec.Array]: + """Creates a test sequence.""" + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + + def test_layer_applies_fn_based_on_mask_required(self) -> None: + for mask_required in [True, False]: + with self.subTest(mask_required=mask_required): + layer = self.create_layer(mask_required) + x = self.create_sequence() + # Mock the apply methods on the Sequence class itself so we return a valid Sequence + # that satisfies any @check_layer decorators. + with unittest.mock.patch.object( + type(x), 'apply', return_value=x + ) as mock_apply: + with unittest.mock.patch.object( + type(x), 'apply_masked', return_value=x + ) as mock_apply_masked: + layer.layer(x, training=False) + + if mask_required: + mock_apply.assert_called_once() + mock_apply_masked.assert_not_called() + else: + mock_apply_masked.assert_called_once() + mock_apply.assert_not_called()