From a7ceb3913d994cacd9cc3a377dceadeb85ef9e25 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 27 Mar 2026 02:45:45 -0700 Subject: [PATCH 1/9] refactor(types.py): add Steppable methods to abstract, training args to MLX, sync comments - Require training parameter across all backend calls - Add explicit get_output_dtype abstract method - Refine Steppable generic type aliases - Add explicit unittesting for default mixin alias behaviors in abstract/types_test_base.py --- sequence_layers/abstract/types.py | 454 +++++++++++++++++++- sequence_layers/abstract/types_test_base.py | 176 ++++++++ sequence_layers/jax/types.py | 121 ++++-- sequence_layers/jax/types_test.py | 103 +++-- sequence_layers/mlx/types.py | 449 ++++++++++++++++--- sequence_layers/mlx/types_test.py | 67 ++- 6 files changed, 1246 insertions(+), 124 deletions(-) diff --git a/sequence_layers/abstract/types.py b/sequence_layers/abstract/types.py index 4eaa3e9..de14366 100644 --- a/sequence_layers/abstract/types.py +++ b/sequence_layers/abstract/types.py @@ -16,6 +16,9 @@ ShapeLike = list[int] | tuple[int, ...] 
DType = Any # Can be numpy, jax, or mlx dtype ChannelSpec = Any # Typically ShapeDType or compatible +State = Any +Constants = Any +Emits = Any class PaddingMode(enum.Enum): """Supported padding modes.""" @@ -162,6 +165,16 @@ def dtype(self) -> DType: def from_values(cls, values: ValuesT) -> 'Sequence': pass + @classmethod + @abc.abstractmethod + def from_lengths( + cls, + values: ValuesT, + lengths: Any, + is_masked: bool = False, + ) -> 'Sequence': + pass + @classmethod @abc.abstractmethod def concatenate_sequences(cls, sequences: Iterable['Sequence']) -> 'Sequence': @@ -216,7 +229,10 @@ def lengths(self) -> Any: pass @abc.abstractmethod - def __getitem__(self: SequenceSelf, the_slice: Any) -> SequenceSelf: + def __getitem__( + self: SequenceSelf, + the_slice: slice | tuple[int | slice | None | type(Ellipsis), ...], + ) -> SequenceSelf: pass @abc.abstractmethod @@ -242,6 +258,11 @@ def unmask(self) -> 'Sequence': pass +class MaskedSequence(Sequence[ValuesT, MaskT], metaclass=abc.ABCMeta): + """A sequence whose invalid timesteps are masked to zero.""" + pass + + class SequenceLayerConfig(metaclass=abc.ABCMeta): """Configuration for a SequenceLayer.""" @@ -290,8 +311,439 @@ def get_accumulated_input_latency(self, input_latency: int) -> int: def get_accumulated_output_latency(self, output_latency: int) -> int: pass + @abc.abstractmethod + def layer( + self, x: Sequence, *, training: bool, constants: Constants | None = None + ) -> Sequence: + """Process this layer layer-wise. + + Args: + x: Input sequence with values shaped [b, t_i, ...]. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + y: The outputs corresponding to this layer with values shaped + [b, t_o, ...] 
where `t_o == t_i * output_ratio`. t_o may have been + truncated to only represent valid frames. + """ + + @abc.abstractmethod + def layer_with_emits( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, Emits]: + """Process this layer layer-wise, producing emitted arrays. + + This is like `layer`, except it has an additional return value which is the + "emitted" arrays for the layer. The emitted arrays are a structure of + arrays whose values are arrays or `Sequence`s. + + Args: + x: Input sequence with values shaped [b, t_i, ...]. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this layer with values shaped + [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been + truncated to only represent valid frames. + emits: A nest of emitted arrays or Sequences. + """ + + @abc.abstractmethod + def step( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State]: + """Process this layer step-wise. + + Args: + x: Input sequence with values shaped [b, t_i, ...], where t_i is a + multiple of block_size. + state: A structure of state arrays matching get_initial_state. The + previous state for this layer. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this step with values shaped [b, t_o, ...] 
+ where `t_o == t_i * output_ratio`. + state: A structure of state arrays matching get_initial_state. The + new state for this layer. + """ + + @abc.abstractmethod + def step_with_emits( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State, Emits]: + """Process this layer step-wise, producing emitted arrays. + + This is like `step`, except it has an additional return value which is the + "emitted" arrays for the step. The emitted arrays are a structure of + arrays whose values are arrays or `Sequence`s. + + Args: + x: Input sequence with values shaped [b, t_i, ...], where t_i is a + multiple of block_size. + state: A structure of state arrays matching get_initial_state. The + previous state for this layer. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this step with values shaped [b, t_o, ...] + where `t_o == t_i * output_ratio`. + state: A structure of state arrays matching get_initial_state. The + new state for this layer. + emits: A nest of emitted arrays or Sequences. + """ + + @abc.abstractmethod + def get_initial_state( + self, + batch_size: int, + input_spec: ChannelSpec, + *, + training: bool, + constants: Constants | None = None, + ) -> State: + """Returns the initial state for this SequenceLayer. + + Args: + batch_size: The batch size to create state for. + input_spec: An input ChannelSpec representing the channel shape and dtype + of the input that will be stepped. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. 
+ Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + An integer, shape, or structure of integer/shapes. + """ + + @abc.abstractmethod + def get_output_shape( + self, + input_shape: ShapeLike, + *, + constants: Constants | None = None, + ) -> Shape: + """Returns the output channel shape this layer produces for an input channel shape. + + Args: + input_shape: A shape representing the channels dimension of the input + sequence (i.e. not including the batch or time dimension). + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + A shape representing the output channels dimensions (i.e. not including + the batch or time dimension). + """ + + @abc.abstractmethod + def get_output_dtype( + self, + input_dtype: DType, + *, + constants: Constants | None = None, + ) -> DType: + """Returns the layer's output dtype for the specified input dtype. + + Args: + input_dtype: The dtype of the input features. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. + + Returns: + The dtype of the output features. 
+ """ + @property @abc.abstractmethod def receptive_field(self) -> Any: pass + +class SequenceLayer(Steppable): + """Base class for Sequence Layers.""" + pass + + +# --------------------------------------------------------------------------- +# Mixins +# --------------------------------------------------------------------------- + + +class PreservesType: + """A mix-in for layers that do not change the input dtype.""" + + @abc.abstractmethod + def get_output_dtype( + self, + input_dtype: DType, + *, + constants: Constants | None = None, + ) -> DType: + pass + + +class PreservesShape: + """A mix-in for layers that do not change the input channel shape.""" + + @abc.abstractmethod + def get_output_shape( + self, + input_shape: ShapeLike, + *, + constants: Constants | None = None, + ) -> Shape: + pass + + +# --------------------------------------------------------------------------- +# Stateless variants +# --------------------------------------------------------------------------- + + +class Stateless(SequenceLayer): + """A layer with no state over time required for step-wise processing. 
+ + The backend must implement: + - get_initial_state + - step + Further sub-classes must only implement: + - layer + - get_output_shape + - get_output_dtype + """ + + @abc.abstractmethod + def get_output_shape( + self, input_shape: ShapeLike, *, constants: Constants | None = None + ) -> Shape: + pass + + @abc.abstractmethod + def get_output_dtype( + self, input_dtype: DType, *, constants: Constants | None = None + ) -> DType: + pass + + @abc.abstractmethod + def layer( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> Sequence: + pass + + @abc.abstractmethod + def get_initial_state( + self, + batch_size: int, + input_spec: ChannelSpec, + *, + training: bool, + constants: Constants | None = None, + ) -> State: + pass + + @abc.abstractmethod + def step( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State]: + pass + + +class StatelessPointwise(PreservesShape, Stateless): + """Stateless layer that operates pointwise (preserves shape).""" + + +class StatelessPointwiseFunctor(StatelessPointwise, metaclass=abc.ABCMeta): + """Stateless pointwise layer defined by a fn(values, mask). + + The backend must implement: + - layer + Further sub-classes must only implement: + - fn + - mask_required + """ + + @abc.abstractmethod + def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: + """Transforms each scalar in values independently.""" + + @property + @abc.abstractmethod + def mask_required(self) -> bool: + """Returns true if fn can change the sequence's masked state. + + If fn(0) -> 0, then mask_required() is False. 
+ """ + + @abc.abstractmethod + def layer( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> Sequence: + pass + + +# --------------------------------------------------------------------------- +# Emitting variants +# --------------------------------------------------------------------------- + + +class Emitting(SequenceLayer, metaclass=abc.ABCMeta): + """A Steppable layer that emits auxiliary arrays. + + This is a convenience subclass that implements step and layer in terms of + step_with_emits and layer_with_emits. + + The backend must implement: + - step + - layer + Further sub-classes must only implement: + - step_with_emits + - layer_with_emits + """ + + @abc.abstractmethod + def step( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State]: + pass + + @abc.abstractmethod + def layer( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> Sequence: + pass + + @abc.abstractmethod + def step_with_emits( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State, Emits]: + pass + + @abc.abstractmethod + def layer_with_emits( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, Emits]: + pass + + +class StatelessEmitting(Emitting): + """A Steppable layer with no state over time that emits auxiliary arrays. 
+ + The backend must implement: + - get_initial_state + - step_with_emits + Further sub-classes must only implement: + - layer_with_emits + - get_output_shape + - get_output_dtype + """ + + @abc.abstractmethod + def get_initial_state( + self, + batch_size: int, + input_spec: ChannelSpec, + *, + training: bool, + constants: Constants | None = None, + ) -> State: + pass + + @abc.abstractmethod + def step_with_emits( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State, Emits]: + pass + + @abc.abstractmethod + def get_output_shape( + self, input_shape: ShapeLike, *, constants: Constants | None = None + ) -> Shape: + pass + + @abc.abstractmethod + def get_output_dtype( + self, input_dtype: DType, *, constants: Constants | None = None + ) -> DType: + pass + + @abc.abstractmethod + def layer_with_emits( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, Emits]: + pass diff --git a/sequence_layers/abstract/types_test_base.py b/sequence_layers/abstract/types_test_base.py index f9cd54b..3bb6a43 100644 --- a/sequence_layers/abstract/types_test_base.py +++ b/sequence_layers/abstract/types_test_base.py @@ -8,6 +8,7 @@ import fractions from absl.testing import parameterized import numpy as np +import unittest.mock class SequenceLayerTest(parameterized.TestCase): """Base abstract test class providing common sequence testing assertions.""" @@ -265,6 +266,51 @@ def test_astype(self): # values is backend array. values.astype(dtype) should work if dtype is backend dtype. 
self.check_trees_all_equal(y.values, values.astype(dtype)) + def test_mask_invalid_idempotent(self): + xp = self.get_backend() + values = xp.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + ]) + mask = xp.array([[True, True, False, False], [False, False, False, True]]) + + x = self.Sequence(values, mask) + masked = x.mask_invalid() + self.assertIsNot(masked, x) + self.assertIsInstance(masked, self.MaskedSequence) + + masked_again = masked.mask_invalid() + self.assertIs(masked_again, masked) + self.assertIsInstance(masked_again, self.MaskedSequence) + + masked2 = x.mask_invalid() + self.assertIsNot(masked2, masked) + self.assertIsInstance(masked2, self.MaskedSequence) + + def test_from_lengths(self): + xp = self.get_backend() + values = xp.array(np.arange(5 * 17 * 2).reshape((5, 17, 2)).astype(np.float32)) + lengths_np = np.array([0, 5, 10, 17, 12], dtype=np.int32) + mask_np = np.arange(17)[None, :] < lengths_np[:, None] + mask = xp.array(mask_np) + + x_expected = self.Sequence(values, mask) + x = self.Sequence.from_lengths(x_expected.values, lengths_np) + self.check_trees_all_equal(x.values, x_expected.values) + self.check_trees_all_equal(x.mask, x_expected.mask) + + # Out of range lengths are clipped to 0 or max. + x = self.Sequence.from_lengths(x_expected.values, [-1, 5, 10, 17, 18]) + self.check_trees_all_equal(x.lengths(), xp.array([0, 5, 10, 17, 17])) + self.assertNotIsInstance(x, self.MaskedSequence) + + # Return type is MaskedSequence if is_masked=True. 
+ x = self.Sequence.from_lengths( + x_expected.values, [-1, 5, 10, 17, 18], is_masked=True + ) + self.check_trees_all_equal(x.lengths(), xp.array([0, 5, 10, 17, 17])) + self.assertIsInstance(x, self.MaskedSequence) + class SteppableTest(parameterized.TestCase): @@ -284,6 +330,22 @@ def test_steppable_defaults(self): self.assertEqual(layer.get_accumulated_input_latency(0), 0) self.assertEqual(layer.get_accumulated_output_latency(0), 0) + def test_steppable_with_emits_defaults_to_tuple_with_empty_emits(self): + layer = self.create_steppable() + + with unittest.mock.patch.object(layer, 'layer', return_value='mock_layer_out') as mock_layer: + out, emits = layer.layer_with_emits('mock_x', training=False, constants=None) + self.assertEqual(out, 'mock_layer_out') + self.assertEqual(emits, ()) + mock_layer.assert_called_with('mock_x', training=False, constants=None) + + with unittest.mock.patch.object(layer, 'step', return_value=('step_out', 'state_out')) as mock_step: + out, state, emits = layer.step_with_emits('mock_x', 'state_in', training=True, constants=None) + self.assertEqual(out, 'step_out') + self.assertEqual(state, 'state_out') + self.assertEqual(emits, ()) + mock_step.assert_called_with('mock_x', 'state_in', training=True, constants=None) + class SequenceLayerConfigTest(SequenceLayerTest): @@ -336,3 +398,117 @@ def make(self) -> Any: new_config = config.copy(field_does_not_exist=1234) del new_config + +class PreservesTypeTest(parameterized.TestCase): + + @abc.abstractmethod + def create_layer(self) -> types.PreservesType: + pass + + def test_preserves_dtype(self): + layer = self.create_layer() + self.assertEqual(layer.get_output_dtype('fake_dtype123'), 'fake_dtype123') + + +class PreservesShapeTest(parameterized.TestCase): + + @abc.abstractmethod + def create_layer(self) -> types.PreservesShape: + pass + + def test_preserves_shape(self): + layer = self.create_layer() + self.assertEqual(layer.get_output_shape((1, 2, 3, 5)), (1, 2, 3, 5)) + + +class 
StatelessTest(parameterized.TestCase): + + @abc.abstractmethod + def create_layer(self) -> types.Stateless: + pass + + def test_stateless_behaviors(self): + layer = self.create_layer() + + # Initial state must be empty + self.assertEqual( + layer.get_initial_state(32, 'fake_spec', training=False), () + ) + + # step unconditionally delegates to layer and returns identical empty state + with unittest.mock.patch.object(layer, 'layer', return_value='layer_out') as mock_layer: + out, state = layer.step('mock_x', 'mock_state', training=True, constants={'c': 1}) + self.assertEqual(out, 'layer_out') + self.assertEqual(state, 'mock_state') + mock_layer.assert_called_once_with('mock_x', training=True, constants={'c': 1}) + + +class EmittingTest(parameterized.TestCase): + + @abc.abstractmethod + def create_layer(self) -> types.Emitting: + pass + + def test_emitting_drops_emits_on_standard_calls(self): + layer = self.create_layer() + + with unittest.mock.patch.object(layer, 'layer_with_emits', return_value=('out', 'emits')) as m_layer: + self.assertEqual(layer.layer('mock_x', training=False), 'out') + m_layer.assert_called_once_with('mock_x', training=False, constants=None) + + with unittest.mock.patch.object(layer, 'step_with_emits', return_value=('out', 'state', 'emits')) as m_step: + out, state = layer.step('mock_x', 'state', training=True, constants={'c': 1}) + self.assertEqual(out, 'out') + self.assertEqual(state, 'state') + m_step.assert_called_once_with('mock_x', 'state', training=True, constants={'c': 1}) + + +class StatelessEmittingTest(parameterized.TestCase): + + @abc.abstractmethod + def create_layer(self) -> types.StatelessEmitting: + pass + + def test_stateless_emitting_behaviors(self): + layer = self.create_layer() + + self.assertEqual( + layer.get_initial_state(32, 'fake_spec', training=False), () + ) + + with unittest.mock.patch.object(layer, 'layer_with_emits', return_value=('out', 'emits')) as m_layer: + out, state, emits = layer.step_with_emits('mock_x', 
'state', training=False) + self.assertEqual(out, 'out') + self.assertEqual(state, 'state') + self.assertEqual(emits, 'emits') + m_layer.assert_called_once_with('mock_x', training=False, constants=None) + + +class StatelessPointwiseFunctorTest(parameterized.TestCase): + + @abc.abstractmethod + def create_layer(self, mask_required: bool) -> types.StatelessPointwiseFunctor: + pass + + @abc.abstractmethod + def create_sequence(self) -> types.Sequence: + pass + + def test_layer_applies_fn_based_on_mask_required(self): + for mask_required in [True, False]: + with self.subTest(mask_required=mask_required): + layer = self.create_layer(mask_required) + x = self.create_sequence() + # Mock the apply methods on the Sequence class itself so we return a valid Sequence + # that satisfies any @check_layer decorators. + with unittest.mock.patch.object(type(x), 'apply', return_value=x) as mock_apply: + with unittest.mock.patch.object(type(x), 'apply_masked', return_value=x) as mock_apply_masked: + layer.layer(x, training=False) + + if mask_required: + mock_apply.assert_called_once() + mock_apply_masked.assert_not_called() + else: + mock_apply_masked.assert_called_once() + mock_apply.assert_not_called() + diff --git a/sequence_layers/jax/types.py b/sequence_layers/jax/types.py index 57e9d3f..5886433 100644 --- a/sequence_layers/jax/types.py +++ b/sequence_layers/jax/types.py @@ -87,17 +87,17 @@ # Sequence type aliases: MASK_DTYPE = np.bool_ -# A rank 2+ tensor of any type. +# A rank 2+ array of any type. ValuesT = TypeVar('ValuesT', bound=jt.Shaped[jt.ArrayT, 'B T *C']) -# A boolean batched mask tensor. True indicates a given timepoint is valid, and +# A boolean batched mask array. True indicates a given timepoint is valid, and # False indicates it is invalid. MaskT = TypeVar('MaskT', bound=jt.Bool[jt.ArrayT, 'B T']) -# An integer batched lengths tensor. +# An integer batched lengths array. 
LengthsT = TypeVar('LengthsT', bound=jt.Int[jt.ArrayT, 'B']) -# A rank 2 boolean tensor with unit dimensions inserted to match their +# A rank 2 boolean array with unit dimensions inserted to match their # corresponding values (e.g. for broadcasting). ExpandedMaskT = TypeVar('ExpandedMaskT', bound=jt.Bool[jt.ArrayT, 'B T *C']) @@ -537,7 +537,7 @@ def unmask(self) -> 'Sequence': return self -class MaskedSequence(Sequence[ValuesT, MaskT]): +class MaskedSequence(Sequence[ValuesT, MaskT], types.MaskedSequence[ValuesT, MaskT]): """Sequence whose invalid timesteps are masked to zero.""" @override @@ -899,6 +899,7 @@ def layer( truncated to only represent valid frames. """ + @override def layer_with_emits( self, x: Sequence, @@ -906,11 +907,11 @@ def layer_with_emits( training: bool, constants: Constants | None = None, ) -> tuple[Sequence, Emits]: - """Process this layer layer-wise, producing emitted tensors. + """Process this layer layer-wise, producing emitted arrays. This is like `layer`, except it has an additional return value which is the - "emitted" tensors for the layer. The emitted tensors are a structure of - tensors whose whose values are `ArrayLike`s or `Sequence`s. + "emitted" arrays for the layer. The emitted arrays are a structure of + arrays whose whose values are `ArrayLike`s or `Sequence`s. Args: x: Input sequence with values shaped [b, t_i, ...]. @@ -924,11 +925,12 @@ def layer_with_emits( y: The outputs corresponding to this layer with values shaped [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been truncated to only represent valid frames. - emits: A nest of emitted tensors or Sequences. + emits: A nest of emitted arrays or Sequences. 
""" outputs = self.layer(x, training=training, constants=constants) return outputs, () + @override def __call__( self, x: Sequence, training: bool, constants: Constants | None = None ) -> Sequence: @@ -944,12 +946,12 @@ def step( training: bool, constants: Constants | None = None, ) -> tuple[Sequence, State]: - """Process this layer step-wise, producing emitted tensors. + """Process this layer step-wise, producing emitted arrays. Args: x: Input sequence with values shaped [b, t_i, ...], where t_i is a multiple of block_size. - state: A structure of state tensors matching get_initial_state. The + state: A structure of state arrays matching get_initial_state. The previous state for this layer. training: Python bool. Whether we are in training mode. constants: A dictionary of constant name to ArrayLike or sl.Sequence. @@ -960,10 +962,11 @@ def step( Returns: y: The outputs corresponding to this step with values shaped [b, t_o, ...] where `t_o == t_i * output_ratio`. - state: A structure of state tensors matching get_initial_state. The + state: A structure of state arrays matching get_initial_state. The new state for this layer. """ + @override def step_with_emits( self, x: Sequence, @@ -972,16 +975,16 @@ def step_with_emits( training: bool, constants: Constants | None = None, ) -> tuple[Sequence, State, Emits]: - """Process this layer step-wise, producing emitted tensors. + """Process this layer step-wise, producing emitted arrays. This is like `step`, except it has an additional return value which is the - "emitted" tensors for the step. The emitted tensors are a structure of - tensors whose values are `ArrayLike`s or `Sequence`s. + "emitted" arrays for the step. The emitted arrays are a structure of + arrays whose values are `ArrayLike`s or `Sequence`s. Args: x: Input sequence with values shaped [b, t_i, ...], where t_i is a multiple of block_size. - state: A structure of state tensors matching get_initial_state. 
The + state: A structure of state arrays matching get_initial_state. The previous state for this layer. training: Python bool. Whether we are in training mode. constants: A dictionary of constant name to ArrayLike or sl.Sequence. @@ -992,9 +995,9 @@ def step_with_emits( Returns: y: The outputs corresponding to this step with values shaped [b, t_o, ...] where `t_o == t_i * output_ratio`. - state: A structure of state tensors matching get_initial_state. The + state: A structure of state arrays matching get_initial_state. The new state for this layer. - emits: A nest of emitted tensors or Sequences. + emits: A nest of emitted arrays or Sequences. """ outputs, state = self.step(x, state, training=training, constants=constants) return outputs, state, () @@ -1021,14 +1024,15 @@ def get_initial_state( attention layer this may contain the source sequence to attend to. Returns: - An integer, TensorShape or structure of integer/TensorShapes. + An integer, shape or structure of integer/shapes. """ @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: - """Returns the output shape this layer produces for an input shape. + """Returns the output channel shape this layer produces for an input channel shape. 
Args: input_shape: A shape representing the channels dimension of the input @@ -1089,11 +1093,13 @@ def get_output_spec_for_sequence( return self.get_output_spec(x.channel_spec, constants=constants) @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: """Returns the layer's output dtype for the specified input dtype.""" + @nn.nowrap def get_output_spec( self, @@ -1261,14 +1267,15 @@ def check_step_with_emits_fn( return check_step_with_emits_fn -class SequenceLayer(nn.Module, Steppable): +class SequenceLayer(nn.Module, Steppable, types.SequenceLayer): """Base Module for Sequence Layers.""" -class PreservesType: +class PreservesType(types.PreservesType): """A mix-in for layers that do not change the input dtype.""" @nn.nowrap + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: @@ -1276,10 +1283,11 @@ def get_output_dtype( return input_dtype -class PreservesShape: +class PreservesShape(types.PreservesShape): """A mix-in for layers that do not change the input shape.""" @nn.nowrap + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: @@ -1287,8 +1295,8 @@ def get_output_shape( return tuple(input_shape) -class Emitting(SequenceLayer, metaclass=abc.ABCMeta): - """A SequenceLayer that emits auxiliary tensors. +class Emitting(SequenceLayer, types.Emitting): + """A SequenceLayer that emits auxiliary arrays. This is a convenience subclass that implements step and layer in terms of step_with_emits and layer_with_emits, so that implementors need only implement @@ -1297,6 +1305,7 @@ class Emitting(SequenceLayer, metaclass=abc.ABCMeta): do not produce emits. 
""" + @override def step( self, x: Sequence, @@ -1311,6 +1320,7 @@ def step( return output, state @abc.abstractmethod + @override def step_with_emits( self, x: Sequence, @@ -1321,6 +1331,7 @@ def step_with_emits( ) -> tuple[Sequence, State, Emits]: pass + @override def layer( self, x: Sequence, @@ -1334,6 +1345,7 @@ def layer( return outputs @abc.abstractmethod + @override def layer_with_emits( self, x: Sequence, @@ -1344,7 +1356,7 @@ def layer_with_emits( pass -class Stateless(SequenceLayer): +class Stateless(SequenceLayer, types.Stateless): """A SequenceLayer with no state over time required for step-wise processing. Sub-classes must only implement: @@ -1367,6 +1379,7 @@ def get_initial_state( ) -> State: del batch_size del input_spec + del training del constants return () @@ -1380,9 +1393,31 @@ def step( ) -> tuple[Sequence, State]: return self.layer(x, training=training, constants=constants), state + @abc.abstractmethod + def get_output_shape( + self, input_shape: ShapeLike, *, constants: Constants | None = None + ) -> Shape: + pass -class StatelessEmitting(Emitting): - """A SequenceLayer with no state over time that emits auxiliary tensors. + @abc.abstractmethod + def get_output_dtype( + self, input_dtype: DType, *, constants: Constants | None = None + ) -> DType: + pass + + @abc.abstractmethod + def layer( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> 'Sequence': + pass + + +class StatelessEmitting(Emitting, types.StatelessEmitting): + """A SequenceLayer with no state over time that emits auxiliary arrays. 
Sub-classes must only implement: - layer_with_emits @@ -1415,14 +1450,40 @@ def get_initial_state( training: bool, constants: Constants | None = None, ) -> State: + del batch_size + del input_spec + del training + del constants return () + @abc.abstractmethod + def get_output_shape( + self, input_shape: ShapeLike, *, constants: Constants | None = None + ) -> Shape: + pass + + @abc.abstractmethod + def get_output_dtype( + self, input_dtype: DType, *, constants: Constants | None = None + ) -> DType: + pass + + @abc.abstractmethod + def layer_with_emits( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, Emits]: + pass + -class StatelessPointwise(PreservesShape, Stateless): +class StatelessPointwise(PreservesShape, Stateless, types.StatelessPointwise): """A SequenceLayer that has no state and operates pointwise on its input.""" -class StatelessPointwiseFunctor(StatelessPointwise, metaclass=abc.ABCMeta): +class StatelessPointwiseFunctor(StatelessPointwise, types.StatelessPointwiseFunctor): """A stateless SequenceLayer for simple pointwise processing fns.""" @abc.abstractmethod diff --git a/sequence_layers/jax/types_test.py b/sequence_layers/jax/types_test.py index c38420f..e297f9f 100644 --- a/sequence_layers/jax/types_test.py +++ b/sequence_layers/jax/types_test.py @@ -39,7 +39,7 @@ def __call__(self, x: types.Sequence) -> types.Sequence: return x -class SequenceTest(types_test_base.SequenceTest, test_utils.SequenceLayerTest): +class SequenceTest(test_utils.SequenceLayerTest, types_test_base.SequenceTest): """Tests for the Sequence class.""" def get_backend(self): @@ -109,25 +109,7 @@ def test_type_checks(self): with self.assertRaises(jaxtyping.TypeCheckError): types.Sequence(np.zeros((2, 3, 5)), np.zeros((1, 3), dtype=jnp.bool_)) - def test_mask_invalid_idempotent(self): - values = jnp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - mask = jnp.array([[True, True, False, False], [False, False, 
False, True]]) - x = types.Sequence(values, mask) - masked = x.mask_invalid() - self.assertIsNot(masked, x) - self.assertIsInstance(masked, types.MaskedSequence) - - masked_again = masked.mask_invalid() - self.assertIs(masked_again, masked) - self.assertIsInstance(masked_again, types.MaskedSequence) - - masked2 = x.mask_invalid() - self.assertIsNot(masked2, masked) - self.assertIsInstance(masked2, types.MaskedSequence) def test_type_annotation(self): if not jt.runtime_type_checking_enabled: @@ -222,22 +204,7 @@ def fn(x: types.Sequence) -> types.Sequence: y = fn(x) self.assertSequencesEqual(y, x) - def test_from_lengths(self): - x_expected = test_utils.random_sequence(5, 17, 2) - x = types.Sequence.from_lengths(x_expected.values, x_expected.lengths()) - self.assertSequencesEqual(x, x_expected) - # Out of range lengths are clipped to 0 or max. - x = types.Sequence.from_lengths(x_expected.values, [-1, 0, 5, 17, 18]) - self.assertAllEqual(x.lengths(), jnp.asarray([0, 0, 5, 17, 17])) - self.assertNotIsInstance(x, types.MaskedSequence) - - # Return type is MaskedSequence if is_masked=True. 
- x = types.Sequence.from_lengths( - x_expected.values, [-1, 0, 5, 17, 18], is_masked=True - ) - self.assertAllEqual(x.lengths(), jnp.asarray([0, 0, 5, 17, 17])) - self.assertIsInstance(x, types.MaskedSequence) class SequenceLayerConfigTest(types_test_base.SequenceLayerConfigTest): @@ -290,10 +257,10 @@ def create_steppable(self): class DefaultSteppable(types.Steppable): - def layer(self, x, *, constants=None): + def layer(self, x, *, training: bool, constants=None): return x - def step(self, x, state, *, constants=None): + def step(self, x, state, *, training: bool, constants=None): return x, state def get_initial_state(self, batch_size, input_spec, *, constants=None): @@ -308,5 +275,69 @@ def get_output_dtype(self, input_dtype, *, constants=None): return DefaultSteppable() +class PreservesTypeTest(types_test_base.PreservesTypeTest): + def create_layer(self): + class DummyLayer(types.PreservesType, types.SequenceLayer): + def layer(self, x, *, training: bool, constants=None): return x + def step(self, x, state, *, training: bool, constants=None): return x, state + def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () + def get_output_shape(self, input_shape, *, constants=None): return input_shape + return DummyLayer() + + +class PreservesShapeTest(types_test_base.PreservesShapeTest): + def create_layer(self): + class DummyLayer(types.PreservesShape, types.SequenceLayer): + def layer(self, x, *, training: bool, constants=None): return x + def step(self, x, state, *, training: bool, constants=None): return x, state + def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + +class StatelessTest(types_test_base.StatelessTest): + def create_layer(self): + class DummyLayer(types.Stateless, types.SequenceLayer): + def layer(self, x, *, training: bool, constants=None): return x + def 
get_output_shape(self, input_shape, *, constants=None): return input_shape + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + +class EmittingTest(types_test_base.EmittingTest): + def create_layer(self): + class DummyLayer(types.Emitting, types.SequenceLayer): + def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () + def layer_with_emits(self, x, *, training: bool, constants=None): return x, () + def step_with_emits(self, x, state, *, training: bool, constants=None): return x, state, () + def get_output_shape(self, input_shape, *, constants=None): return input_shape + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + @property + def receptive_field_per_step(self): return {0: (0, 0)} + return DummyLayer() + + +class StatelessEmittingTest(types_test_base.StatelessEmittingTest): + def create_layer(self): + class DummyLayer(types.StatelessEmitting, types.SequenceLayer): + def layer_with_emits(self, x, *, training: bool, constants=None): return x, () + def get_output_shape(self, input_shape, *, constants=None): return input_shape + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + +class StatelessPointwiseFunctorTest(types_test_base.StatelessPointwiseFunctorTest): + def create_layer(self, is_mask_required: bool): + class DummyLayer(types.StatelessPointwiseFunctor, types.SequenceLayer): + @property + def mask_required(self): return is_mask_required + def fn(self, values, mask): return values, mask + def get_output_shape(self, input_shape, *, constants=None): return input_shape + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + def create_sequence(self): + return types.Sequence(jnp.zeros((2, 3, 5)), jnp.zeros((2, 3), dtype=jnp.bool_)) + if __name__ == '__main__': test_utils.main() diff --git a/sequence_layers/mlx/types.py 
b/sequence_layers/mlx/types.py index 5647618..c897322 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -134,6 +134,19 @@ def dtype(self) -> DType: """Returns the dtype of the sequence values.""" return self.values.dtype + @classmethod + @override + def from_lengths( + cls, + values: ValuesT, + lengths: LengthsT, + is_masked: bool = False, + ) -> 'Sequence': + """Constructs a sequence from values and per-batch element lengths.""" + values = mx.array(values) + mask = sequence_mask(lengths, maxlen=values.shape[1]) + return MaskedSequence(values, mask) if is_masked else Sequence(values, mask) + @classmethod @override def from_values(cls, values: ValuesT) -> 'MaskedSequence': @@ -222,7 +235,7 @@ def lengths(self) -> mx.array: @override def __getitem__( self: SequenceSelf, - the_slice, + the_slice: slice | tuple[int | slice | None | type(Ellipsis), ...], ) -> SequenceSelf: """Slices the Sequence values and mask.""" if isinstance(the_slice, slice): @@ -276,11 +289,11 @@ def unmask(self) -> 'Sequence': return self -class MaskedSequence(Sequence[ValuesT, MaskT]): +class MaskedSequence(Sequence[ValuesT, MaskT], types.MaskedSequence[ValuesT, MaskT]): """Sequence whose invalid timesteps are masked to zero.""" @override - def mask_invalid(self, mask_value: complex | None = None) -> 'Sequence': + def mask_invalid(self, mask_value: complex | None = None) -> Sequence: if mask_value is None: return self return mask_invalid(self, mask_value) @@ -341,8 +354,8 @@ def check_layer(layer_fn): """Validates layer inputs and outputs.""" @functools.wraps(layer_fn) - def wrapper(self, x, *, constants=None): - y = layer_fn(self, x, constants=constants) + def wrapper(self, x, *, training: bool, constants=None): + y = layer_fn(self, x, training=training, constants=constants) _check_output_spec(self, x, y, constants) return y @@ -353,7 +366,7 @@ def check_step(step_fn): """Validates step inputs and outputs.""" @functools.wraps(step_fn) - def wrapper(self, x, state, 
*, constants=None): + def wrapper(self, x, state, *, training: bool, constants=None): if not self.supports_step: raise ValueError(f'{self.__class__.__name__} does not support step().') block_size = self.block_size @@ -362,7 +375,7 @@ def wrapper(self, x, state, *, constants=None): f'{self.__class__.__name__} received input with' f' {x.shape=} not a multiple of {block_size=}.' ) - y, state = step_fn(self, x, state, constants=constants) + y, state = step_fn(self, x, state, training=training, constants=constants) _check_output_spec(self, x, y, constants) _check_output_ratio(self, x, y) return y, state @@ -376,7 +389,70 @@ def wrapper(self, x, state, *, constants=None): class Steppable(types.Steppable): - """A sequence processing layer that supports layer and step modes.""" + """A sequence processing layer that can be executed layerwise or stepwise. + + # Step-wise execution: + + A SequenceLayer supports step-wise execution if its `supports_step` property + is true. Most built-in SequenceLayers support step-wise processing by default, + but may support processing features that are not causal and therefore cannot + be executed step-by-step (e.g. non-causal convolutions, bidirectional RNNs, + etc.). + + When executing step-wise, use the `step` or `step_with_emits` method to + process a block of inputs (a `Sequence` shaped `[b, block_size * n, ...]`) and + a `state` input whose structure matches `get_initial_state`. + + This produces: + - An output `Sequence` shaped `[b, block_size * n * output_ratio, ...]` + whose `...` shape matches `get_output_shape`. + - A `state` output whose structure matches `get_initial_state`. + - (Optionally) an `emits` output. + + The output `Sequence` is the primary output of the step, while the `emits` + represent "auxiliary" outputs that are produced by the layer (for example, + debug output). 
+ + # Layer-wise execution: + + When executing layer-wise, use the `layer` or `layer_with_emits` method to + process inputs (a `Sequence` shaped `[b, t, ...]`). + + This produces: + - An output `Sequence` shaped `[b, t * output_ratio, ...]` + whose `...` shape matches `get_output_shape`. + - (Optionally) an `emits` output. + + The output `Sequence` is the primary output of the layer, while the `emits` + represent "auxiliary" outputs that are produced by the layer (for example, + debug output). + + # Latency + + SequenceLayers have an input and output "latency" to describe their latency + characteristics. Latency is the number of input or output timesteps from + step-wise execution that are input or output before the step-wise output of + the layer matches the layer-wise output of the layer. + + An invariant that all layers must maintain is that for the layer-wise output + and step-wise output: + + ``` + y_layer = l.layer(x, training=training) + + # Pad x with input_latency timesteps to process the entire sequence: + x = x.pad_time(0, l.input_latency, valid=False) + + y_step, _, _ = utils.step_by_step_dynamic(l, x, training=training) + ``` + + The step-wise output is equivalent to the layer-wise output after dropping the + initial latency timesteps of the step-wise output: + + ``` + y_layer == y_step[:, l.output_latency:] + ``` + """ @property @override @@ -431,14 +507,49 @@ def receptive_field(self) -> ReceptiveField: @abc.abstractmethod @override def layer( - self, x: Sequence, *, constants: Constants | None = None + self, x: Sequence, *, training: bool, constants: Constants | None = None ) -> Sequence: - """Process this layer layer-wise.""" + """Process this layer layer-wise. + + Args: + x: Input sequence with values shaped [b, t_i, ...]. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence.
+ Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + y: The outputs corresponding to this layer with values shaped + [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been + truncated to only represent valid frames. + """ + @override def layer_with_emits( - self, x: Sequence, *, constants: Constants | None = None + self, x: Sequence, *, training: bool, constants: Constants | None = None ) -> tuple[Sequence, Emits]: - return self.layer(x, constants=constants), () + """Process this layer layer-wise, producing emitted arrays. + + This is like `layer`, except it has an additional return value which is the + "emitted" arrays for the layer. The emitted arrays are a structure of + arrays whose values are arrays or `Sequence`s. + + Args: + x: Input sequence with values shaped [b, t_i, ...]. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this layer with values shaped + [b, t_o, ...] where `t_o == t_i * output_ratio`. t_o may have been + truncated to only represent valid frames. + emits: A nest of emitted arrays or Sequences. + """ + return self.layer(x, training=training, constants=constants), () @abc.abstractmethod @override @@ -447,18 +558,63 @@ def step( x: Sequence, state: State, *, + training: bool, constants: Constants | None = None, ) -> tuple[Sequence, State]: - """Process this layer step-wise.""" + """Process this layer step-wise. + + Args: + x: Input sequence with values shaped [b, t_i, ...], where t_i is a + multiple of block_size. 
+ state: A structure of state arrays matching get_initial_state. The + previous state for this layer. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this step with values shaped [b, t_o, ...] + where `t_o == t_i * output_ratio`. + state: A structure of state arrays matching get_initial_state. The + new state for this layer. + """ + @override def step_with_emits( self, x: Sequence, state: State, *, + training: bool, constants: Constants | None = None, ) -> tuple[Sequence, State, Emits]: - y, state = self.step(x, state, constants=constants) + """Process this layer step-wise, producing emitted arrays. + + This is like `step`, except it has an additional return value which is the + "emitted" arrays for the step. The emitted arrays are a structure of + arrays whose values are arrays or `Sequence`s. + + Args: + x: Input sequence with values shaped [b, t_i, ...], where t_i is a + multiple of block_size. + state: A structure of state arrays matching get_initial_state. The + previous state for this layer. + training: Python bool. Whether we are in training mode. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the key/value sequence to attend to. + + Returns: + y: The outputs corresponding to this step with values shaped [b, t_o, ...] + where `t_o == t_i * output_ratio`. + state: A structure of state arrays matching get_initial_state. The + new state for this layer. + emits: A nest of emitted arrays or Sequences. 
+ """ + y, state = self.step(x, state, training=training, constants=constants) return y, state, () @abc.abstractmethod @@ -468,9 +624,23 @@ def get_initial_state( batch_size: int, input_spec: ChannelSpec, *, + training: bool, constants: Constants | None = None, ) -> State: - """Returns the initial state for step-wise processing.""" + """Returns the initial state for this SequenceLayer. + + Args: + batch_size: The batch size to create state for. + input_spec: An input ChannelSpec representing the channel shape and dtype + of the input that will be stepped. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + An integer, shape, or structure of integer/shapes. + """ @abc.abstractmethod @override @@ -480,7 +650,20 @@ def get_output_shape( *, constants: Constants | None = None, ) -> Shape: - """Returns the output channel shape for an input channel shape.""" + """Returns the output channel shape this layer produces for an input channel shape. + + Args: + input_shape: A shape representing the channels dimension of the input + sequence (i.e. not including the batch or time dimension). + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. For example, for an + attention layer this may contain the source sequence to attend to. + + Returns: + A shape representing the output channels dimensions (i.e. not including + the batch or time dimension). + """ @abc.abstractmethod @override @@ -490,7 +673,17 @@ def get_output_dtype( *, constants: Constants | None = None, ) -> DType: - """Returns the output dtype for an input dtype.""" + """Returns the layer's output dtype for the specified input dtype. 
+ + Args: + input_dtype: The dtype of the input features. + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. + + Returns: + The dtype of the output features. + """ def get_output_spec( self, @@ -508,8 +701,8 @@ def get_output_spec( # --------------------------------------------------------------------------- -class SequenceLayer(nn.Module, Steppable): - """Base MLX Module for Sequence Layers.""" +class SequenceLayer(nn.Module, Steppable, types.SequenceLayer): + """Base Module for Sequence Layers.""" class SequenceLayerConfig(types.SequenceLayerConfig): """Base class for SequenceLayer configuration objects.""" @@ -529,19 +722,24 @@ def copy(self, **kwargs) -> 'SequenceLayerConfig': # --------------------------------------------------------------------------- -class PreservesType: - """Mix-in: layer does not change the input dtype.""" +class PreservesType(types.PreservesType): + """A mix-in for layers that do not change the input dtype.""" + @override def get_output_dtype( - self, input_dtype: DType, *, constants: Constants | None = None + self, + input_dtype: DType, + *, + constants: Constants | None = None, ) -> DType: del constants return input_dtype -class PreservesShape: - """Mix-in: layer does not change the input channel shape.""" +class PreservesShape(types.PreservesShape): + """A mix-in for layers that do not change the input shape.""" + @override def get_output_shape( self, input_shape: ShapeLike, @@ -557,34 +755,74 @@ def get_output_shape( # --------------------------------------------------------------------------- -class Stateless(SequenceLayer): - """A SequenceLayer with no step state.""" +class Stateless(SequenceLayer, types.Stateless): + """A SequenceLayer with no state over time required for step-wise processing. 
+ + Sub-classes must only implement: + - layer + - get_output_shape + - get_output_dtype + """ def get_initial_state( self, batch_size: int, input_spec: ChannelSpec, *, + training: bool, constants: Constants | None = None, ) -> State: + del batch_size + del input_spec + del training + del constants return () + @abc.abstractmethod + def get_output_shape( + self, + input_shape: ShapeLike, + *, + constants: Constants | None = None, + ) -> Shape: + pass + + @abc.abstractmethod + def get_output_dtype( + self, + input_dtype: DType, + *, + constants: Constants | None = None, + ) -> DType: + pass + + @abc.abstractmethod + def layer( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> Sequence: + pass + def step( self, x: Sequence, state: State, *, + training: bool, constants: Constants | None = None, ) -> tuple[Sequence, State]: - return self.layer(x, constants=constants), state + return self.layer(x, training=training, constants=constants), state -class StatelessPointwise(PreservesShape, Stateless): - """Stateless layer that operates pointwise (preserves shape).""" +class StatelessPointwise(PreservesShape, Stateless, types.StatelessPointwise): + """A SequenceLayer that has no state and operates pointwise on its input.""" -class StatelessPointwiseFunctor(StatelessPointwise, metaclass=abc.ABCMeta): - """Stateless pointwise layer defined by a fn(values, mask).""" +class StatelessPointwiseFunctor(StatelessPointwise, types.StatelessPointwiseFunctor): + """A stateless SequenceLayer for simple pointwise processing fns.""" @abc.abstractmethod def fn(self, values: ValuesT, mask: MaskT) -> tuple[ValuesT, MaskT]: @@ -592,12 +830,21 @@ def fn(self, values: ValuesT, mask: MaskT) -> tuple[ValuesT, MaskT]: @property def mask_required(self): + """Returns true if fn can change the sequence's masked state. + + If fn(0) -> 0, then mask_required() is False. 
+ """ return True @check_layer def layer( - self, x: Sequence, *, constants: Constants | None = None + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, ) -> Sequence: + del training if self.mask_required: y = x.apply(self.fn) else: @@ -613,18 +860,44 @@ def layer( # --------------------------------------------------------------------------- -class Emitting(SequenceLayer, metaclass=abc.ABCMeta): - """A SequenceLayer that emits auxiliary tensors.""" +class Emitting(SequenceLayer, types.Emitting): + """A SequenceLayer that emits auxiliary arrays. - def step( + This is a convenience subclass that implements step and layer in terms of + step_with_emits and layer_with_emits, so that implementors need only implement + two of the four methods. For emits that are substantially expensive to compute + subclasses can choose to implement all four and save computation in those that + do not produce emits. + """ + + @abc.abstractmethod + def get_initial_state( self, - x: Sequence, - state: State, + batch_size: int, + input_spec: ChannelSpec, *, + training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State]: - y, state, _ = self.step_with_emits(x, state, constants=constants) - return y, state + ) -> State: + pass + + @abc.abstractmethod + def get_output_shape( + self, + input_shape: ShapeLike, + *, + constants: Constants | None = None, + ) -> Shape: + pass + + @abc.abstractmethod + def get_output_dtype( + self, + input_dtype: DType, + *, + constants: Constants | None = None, + ) -> DType: + pass @abc.abstractmethod def step_with_emits( @@ -632,41 +905,107 @@ def step_with_emits( x: Sequence, state: State, *, + training: bool, constants: Constants | None = None, ) -> tuple[Sequence, State, Emits]: pass - def layer( - self, x: Sequence, *, constants: Constants | None = None - ) -> Sequence: - y, _ = self.layer_with_emits(x, constants=constants) - return y - @abc.abstractmethod def layer_with_emits( - self, x: Sequence, *, 
constants: Constants | None = None + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, ) -> tuple[Sequence, Emits]: pass + def step( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State]: + output, state, _ = self.step_with_emits( + x, state, training=training, constants=constants + ) + return output, state + + def layer( + self, + x: Sequence, + *, + training: bool, + constants: Constants | None = None, + ) -> Sequence: + outputs, _ = self.layer_with_emits( + x, training=training, constants=constants + ) + return outputs + -class StatelessEmitting(Emitting): - """Stateless layer that emits auxiliary tensors.""" +class StatelessEmitting(Emitting, types.StatelessEmitting): + """A SequenceLayer with no state over time that emits auxiliary arrays. - def step_with_emits( + Sub-classes must only implement: + - layer_with_emits + - get_output_shape + - get_output_dtype + """ + + @abc.abstractmethod + def get_output_shape( + self, + input_shape: ShapeLike, + *, + constants: Constants | None = None, + ) -> Shape: + pass + + @abc.abstractmethod + def get_output_dtype( + self, + input_dtype: DType, + *, + constants: Constants | None = None, + ) -> DType: + pass + + @abc.abstractmethod + def layer_with_emits( self, x: Sequence, - state: State, *, + training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State, Emits]: - y, emits = self.layer_with_emits(x, constants=constants) - return y, state, emits + ) -> tuple[Sequence, Emits]: + pass def get_initial_state( self, batch_size: int, input_spec: ChannelSpec, *, + training: bool, constants: Constants | None = None, ) -> State: - return () \ No newline at end of file + del batch_size + del input_spec + del training + del constants + return () + + def step_with_emits( + self, + x: Sequence, + state: State, + *, + training: bool, + constants: Constants | None = None, + ) -> tuple[Sequence, State, Emits]: + 
outputs, emits = self.layer_with_emits( + x, training=training, constants=constants + ) + return outputs, state, emits \ No newline at end of file diff --git a/sequence_layers/mlx/types_test.py b/sequence_layers/mlx/types_test.py index 6e12492..e65c28e 100644 --- a/sequence_layers/mlx/types_test.py +++ b/sequence_layers/mlx/types_test.py @@ -35,10 +35,10 @@ def create_steppable(self): class DefaultSteppable(types.Steppable): - def layer(self, x, *, constants=None): + def layer(self, x, *, training: bool, constants=None): return x - def step(self, x, state, *, constants=None): + def step(self, x, state, *, training: bool, constants=None): return x, state def get_initial_state(self, batch_size, input_spec, *, constants=None): @@ -59,5 +59,68 @@ def get_config_base_cls(self): return types.SequenceLayerConfig +class PreservesTypeTest(types_test_base.PreservesTypeTest): + def create_layer(self): + class DummyLayer(types.PreservesType, types.SequenceLayer): + def layer(self, x, *, training: bool, constants=None): return x + def step(self, x, state, *, training: bool, constants=None): return x, state + def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () + def get_output_shape(self, input_shape, *, constants=None): return input_shape + return DummyLayer() + + +class PreservesShapeTest(types_test_base.PreservesShapeTest): + def create_layer(self): + class DummyLayer(types.PreservesShape, types.SequenceLayer): + def layer(self, x, *, training: bool, constants=None): return x + def step(self, x, state, *, training: bool, constants=None): return x, state + def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + +class StatelessTest(types_test_base.StatelessTest): + def create_layer(self): + class DummyLayer(types.Stateless, types.SequenceLayer): + def layer(self, x, *, training: bool, 
constants=None): return x + def get_output_shape(self, input_shape, *, constants=None): return input_shape + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + +class EmittingTest(types_test_base.EmittingTest): + def create_layer(self): + class DummyLayer(types.Emitting, types.SequenceLayer): + def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () + def layer_with_emits(self, x, *, training: bool, constants=None): return x, () + def step_with_emits(self, x, state, *, training: bool, constants=None): return x, state, () + def get_output_shape(self, input_shape, *, constants=None): return input_shape + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + +class StatelessEmittingTest(types_test_base.StatelessEmittingTest): + def create_layer(self): + class DummyLayer(types.StatelessEmitting, types.SequenceLayer): + def layer_with_emits(self, x, *, training: bool, constants=None): return x, () + def get_output_shape(self, input_shape, *, constants=None): return input_shape + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + +class StatelessPointwiseFunctorTest(types_test_base.StatelessPointwiseFunctorTest): + def create_layer(self, is_mask_required: bool): + class DummyLayer(types.StatelessPointwiseFunctor, types.SequenceLayer): + @property + def mask_required(self): return is_mask_required + def fn(self, values, mask): return values, mask + def get_output_shape(self, input_shape, *, constants=None): return input_shape + def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype + return DummyLayer() + + def create_sequence(self): + return types.Sequence(mx.zeros((2, 3, 5)), mx.zeros((2, 3), dtype=mx.bool_)) + + if __name__ == '__main__': absltest.main() From 4e4ffa00fdaf49955fee81e6267b449636693dce Mon Sep 17 00:00:00 2001 From: Julian Salazar 
Date: Fri, 3 Apr 2026 03:53:06 -0700 Subject: [PATCH 2/9] refactor: Update reqs for module protocols, typechecking --- pyproject.toml | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6d0b8c2..aefdb47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ name = "sequence_layers" description = "Sequence Layers neural network layer library from Google." readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.13" license = {file = "LICENSE"} authors = [ {name = "RJ Skerry-Ryan", email="rjryan@google.com"}, @@ -24,7 +24,7 @@ dependencies = [ "jaxtyping", "numpy", "orbax-export", - "recurrentgemma[jax]", + "recurrentgemma[jax]>=1.0.1", "typeguard==2.13.3", ] @@ -50,13 +50,14 @@ mlx = [ "mlx", ] dev = [ - "absl-py", + "absl-py>=2.4.0", "chex", "orbax", + "pyink", + "pylint>=2.6.0", + "pyrefly>=0.58.0", "pytest", "pytest-xdist", - "pylint>=2.6.0", - "pyink", "tensorflow", # JAX tests use TensorFlow. 
] @@ -81,4 +82,6 @@ exclude = [ # Do not release test files on PyPI "**/*_test.py", "testdata/**", -] \ No newline at end of file +] + +[tool.pyrefly] \ No newline at end of file From 1868fc063b0d98b4bb42ac474750054609ad5da9 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 2 Apr 2026 13:03:01 -0700 Subject: [PATCH 3/9] refactor: Backend module protocols, move to specs/ --- sequence_layers/abstract/types_test_base.py | 514 --------------- sequence_layers/jax/__init__.py | 4 + sequence_layers/jax/backend.py | 22 + sequence_layers/jax/test_utils.py | 7 +- sequence_layers/jax/types.py | 46 +- sequence_layers/jax/types_test.py | 151 ++--- sequence_layers/jax/utils.py | 1 + sequence_layers/mlx/__init__.py | 6 +- sequence_layers/mlx/backend.py | 22 + sequence_layers/mlx/types.py | 252 +++++--- sequence_layers/mlx/types_test.py | 121 +--- sequence_layers/specs/__init__.py | 38 ++ sequence_layers/specs/backend.py | 41 ++ sequence_layers/{abstract => specs}/types.py | 396 ++++++++---- sequence_layers/specs/types_behaviors.py | 636 +++++++++++++++++++ 15 files changed, 1300 insertions(+), 957 deletions(-) delete mode 100644 sequence_layers/abstract/types_test_base.py create mode 100644 sequence_layers/jax/backend.py create mode 100644 sequence_layers/mlx/backend.py create mode 100644 sequence_layers/specs/__init__.py create mode 100644 sequence_layers/specs/backend.py rename sequence_layers/{abstract => specs}/types.py (75%) create mode 100644 sequence_layers/specs/types_behaviors.py diff --git a/sequence_layers/abstract/types_test_base.py b/sequence_layers/abstract/types_test_base.py deleted file mode 100644 index 3bb6a43..0000000 --- a/sequence_layers/abstract/types_test_base.py +++ /dev/null @@ -1,514 +0,0 @@ -"""Abstract tests for Sequence types.""" - -import abc -from typing import Any, Callable, Sequence as TypingSequence -import dataclasses -from sequence_layers.abstract import types - -import fractions -from absl.testing import parameterized -import numpy as 
np -import unittest.mock - -class SequenceLayerTest(parameterized.TestCase): - """Base abstract test class providing common sequence testing assertions.""" - - @abc.abstractmethod - def assertSequencesClose(self, x: Any, y: Any, **kwargs): - pass - - @abc.abstractmethod - def assertSequencesNotClose(self, x: Any, y: Any, **kwargs): - pass - - @abc.abstractmethod - def assertSequencesEqual(self, x: Any, y: Any): - pass - - @abc.abstractmethod - def assertSequencesNotEqual(self, x: Any, y: Any): - pass - - @abc.abstractmethod - def assertAllEqual(self, x: Any, y: Any): - pass - - @abc.abstractmethod - def assertAllClose(self, x: Any, y: Any, **kwargs): - pass - - @abc.abstractmethod - def assertNotAllEqual(self, x: Any, y: Any): - pass - - @abc.abstractmethod - def assertNotAllClose(self, x: Any, y: Any, **kwargs): - pass - - -class SequenceTest(SequenceLayerTest): - """Abstract tests for the Sequence class.""" - - @abc.abstractmethod - def get_backend(self) -> Any: - """Returns the backend module (jax.numpy or mlx.core).""" - - @property - @abc.abstractmethod - def Sequence(self) -> type[types.Sequence]: - """Returns the Sequence class for the backend.""" - - @property - @abc.abstractmethod - def MaskedSequence(self) -> Any: - """Returns the MaskedSequence class for the backend.""" - - @property - def check_trees_all_equal(self) -> Callable[[Any, Any], None]: - """Returns a function to check tree equality.""" - return self.assertAllEqual - - def test_mask_invalid_idempotent(self): - xp = self.get_backend() - values = xp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - # Different backends might handle boolean creation differently, but standard numpy-like syntax usually works - mask = xp.array([[True, True, False, False], [False, False, False, True]]) - - x = self.Sequence(values, mask) - masked = x.mask_invalid() - self.assertIsNot(masked, x) - # We can't easily check isinstance here without importing the concrete classes, - # but we can check 
behavior or use a property if we added one. - # For now, we trust the concrete tests to check types if needed, - # or we could add abstract methods to check types. - - masked_again = masked.mask_invalid() - self.assertIs(masked_again, masked) - - masked2 = x.mask_invalid() - self.assertIsNot(masked2, masked) - - @parameterized.named_parameters( - ('mask_value=None', 0.0, None), - ('mask_value=0.0', 0.0, 0.0), - ('mask_value=-1.0', -1.0, -1.0), - ) - def test_mask_invalid(self, mask_value, expected_mask_value): - xp = self.get_backend() - values = xp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, False, False, True]]) - - # Pass mask_value only if it is not None (to test default None behavior vs explicit value) - if expected_mask_value is None: - output = self.Sequence(values, mask).mask_invalid() - fill_value = 0.0 - else: - output = self.Sequence(values, mask).mask_invalid(mask_value) - fill_value = mask_value - - expected_values = xp.array([ - [1.0, 2.0, fill_value, fill_value], - [fill_value, fill_value, fill_value, 40.0], - ]) - self.check_trees_all_equal(output.values, expected_values) - self.check_trees_all_equal(output.mask, mask) - - def test_pad_time(self): - xp = self.get_backend() - values = xp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, False, False, True]]) - - x = self.Sequence(values, mask).mask_invalid() - - y = x.pad_time(0, 0, valid=False) - self.check_trees_all_equal(y.values, x.values) - self.check_trees_all_equal(y.mask, x.mask) - - y = x.pad_time(1, 0, valid=False) - - x_left1 = self.Sequence( - xp.array([ - [0.0, 1.0, 2.0, 3.0, 4.0], - [0.0, 10.0, 20.0, 30.0, 40.0], - ]), - xp.array([ - [False, True, True, False, False], - [False, False, False, False, True], - ]), - ).mask_invalid() - self.check_trees_all_equal(y.values, x_left1.values) - self.check_trees_all_equal(y.mask, x_left1.mask) - - def 
_create_test_sequence(self, shape): - xp = self.get_backend() - size = 1 - for d in shape: size *= d - values_np = np.arange(size, dtype=np.float32).reshape(shape) - mask_np = np.ones(shape[:2], dtype=bool) - if shape[0] > 0 and shape[1] > 1: - mask_np[0, 1] = False - - values = xp.array(values_np) - mask = xp.array(mask_np) - return self.Sequence(values, mask) - - def test_slice(self): - x = self._create_test_sequence((3, 5, 9)) - - self.assertSequencesEqual( - x[:, 1:], self.Sequence(x.values[:, 1:], x.mask[:, 1:]) - ) - self.assertSequencesEqual( - x[:, ::2], self.Sequence(x.values[:, ::2], x.mask[:, ::2]) - ) - self.assertSequencesEqual( - x[::2, ::3], self.Sequence(x.values[::2, ::3], x.mask[::2, ::3]) - ) - - def test_slice_can_slice_channel_dimensions(self): - x = self._create_test_sequence((3, 5, 9, 4)) - - self.assertSequencesEqual( - x[:, 1:, :], self.Sequence(x.values[:, 1:], x.mask[:, 1:]) - ) - self.assertSequencesEqual( - x[:, ::2, :3], - self.Sequence(x.values[:, ::2, :3], x.mask[:, ::2]), - ) - - def test_apply_values(self): - xp = self.get_backend() - values = xp.array([ - [-1.0, 2.0, 3.0, 4.0], - [10.0, -20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, True, False, True]]) - - x = self.Sequence(values, mask) - masked = x.mask_invalid() - - # Simple abs function - fn = abs - - y = x.apply_values(fn) - self.check_trees_all_equal(y.values, fn(x.values)) - self.check_trees_all_equal(y.mask, x.mask) - - y = masked.apply_values(fn) - self.check_trees_all_equal(y.values, fn(masked.values)) - self.check_trees_all_equal(y.mask, x.mask) - - y = masked.apply_values_masked(fn) - self.check_trees_all_equal(y.values, fn(masked.values)) - self.check_trees_all_equal(y.mask, x.mask) - - def test_apply_values_args(self): - xp = self.get_backend() - values = xp.array([ - [-1.0, 2.0, 3.0, 4.0], - [10.0, -20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, True, False, True]]) - x = self.Sequence(values, mask) - 
- target_shape = (2, 4, 1) - y = x.apply_values(lambda v, s: v.reshape(s), target_shape) - self.check_trees_all_equal(y.values.shape, target_shape) - self.check_trees_all_equal(y.mask.shape, (2, 4)) - - def test_from_values(self): - xp = self.get_backend() - values_np = np.array([ - [1.0, 2.0], - [3.0, 4.0] - ], dtype=np.float32) - values = xp.array(values_np) - # Get the class from an instance - seq = self.Sequence(values, xp.array(np.ones(values.shape[:2], dtype=bool))) - SeqClass = type(seq) - - x = SeqClass.from_values(values) - self.check_trees_all_equal(x.values, values) - self.check_trees_all_equal(x.mask, xp.array(np.ones(values.shape[:2], dtype=bool))) - - def test_astype(self): - xp = self.get_backend() - values_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) - mask_np = np.array([[True, False], [False, True]], dtype=bool) - - values = xp.array(values_np) - mask = xp.array(mask_np) - - x = self.Sequence(values, mask) - - # We need a dtype that matches the backend - if xp.__name__ == 'jax.numpy': - dtype = xp.int32 - elif xp.__name__ == 'mlx.core': - dtype = xp.int32 - else: - dtype = np.int32 - - y = x.astype(dtype) - - # Check values match casted version - self.check_trees_all_equal(y.mask, mask) - # y.values might be mlx array, values.astype(dtype) might be numpy if values was numpy? - # values is backend array. values.astype(dtype) should work if dtype is backend dtype. 
- self.check_trees_all_equal(y.values, values.astype(dtype)) - - def test_mask_invalid_idempotent(self): - xp = self.get_backend() - values = xp.array([ - [1.0, 2.0, 3.0, 4.0], - [10.0, 20.0, 30.0, 40.0], - ]) - mask = xp.array([[True, True, False, False], [False, False, False, True]]) - - x = self.Sequence(values, mask) - masked = x.mask_invalid() - self.assertIsNot(masked, x) - self.assertIsInstance(masked, self.MaskedSequence) - - masked_again = masked.mask_invalid() - self.assertIs(masked_again, masked) - self.assertIsInstance(masked_again, self.MaskedSequence) - - masked2 = x.mask_invalid() - self.assertIsNot(masked2, masked) - self.assertIsInstance(masked2, self.MaskedSequence) - - def test_from_lengths(self): - xp = self.get_backend() - values = xp.array(np.arange(5 * 17 * 2).reshape((5, 17, 2)).astype(np.float32)) - lengths_np = np.array([0, 5, 10, 17, 12], dtype=np.int32) - mask_np = np.arange(17)[None, :] < lengths_np[:, None] - mask = xp.array(mask_np) - - x_expected = self.Sequence(values, mask) - x = self.Sequence.from_lengths(x_expected.values, lengths_np) - self.check_trees_all_equal(x.values, x_expected.values) - self.check_trees_all_equal(x.mask, x_expected.mask) - - # Out of range lengths are clipped to 0 or max. - x = self.Sequence.from_lengths(x_expected.values, [-1, 5, 10, 17, 18]) - self.check_trees_all_equal(x.lengths(), xp.array([0, 5, 10, 17, 17])) - self.assertNotIsInstance(x, self.MaskedSequence) - - # Return type is MaskedSequence if is_masked=True. 
- x = self.Sequence.from_lengths( - x_expected.values, [-1, 5, 10, 17, 18], is_masked=True - ) - self.check_trees_all_equal(x.lengths(), xp.array([0, 5, 10, 17, 17])) - self.assertIsInstance(x, self.MaskedSequence) - - - -class SteppableTest(parameterized.TestCase): - """Abstract tests for Steppable layers.""" - - @abc.abstractmethod - def create_steppable(self) -> Any: - """Creates a basic Steppable instance that should have default properties.""" - - def test_steppable_defaults(self): - layer = self.create_steppable() - self.assertEqual(layer.block_size, 1) - self.assertEqual(layer.output_ratio, fractions.Fraction(1)) - self.assertTrue(layer.supports_step) - self.assertEqual(layer.input_latency, 0) - self.assertEqual(layer.output_latency, 0) - self.assertEqual(layer.get_accumulated_input_latency(0), 0) - self.assertEqual(layer.get_accumulated_output_latency(0), 0) - - def test_steppable_with_emits_defaults_to_tuple_with_empty_emits(self): - layer = self.create_steppable() - - with unittest.mock.patch.object(layer, 'layer', return_value='mock_layer_out') as mock_layer: - out, emits = layer.layer_with_emits('mock_x', training=False, constants=None) - self.assertEqual(out, 'mock_layer_out') - self.assertEqual(emits, ()) - mock_layer.assert_called_with('mock_x', training=False, constants=None) - - with unittest.mock.patch.object(layer, 'step', return_value=('step_out', 'state_out')) as mock_step: - out, state, emits = layer.step_with_emits('mock_x', 'state_in', training=True, constants=None) - self.assertEqual(out, 'step_out') - self.assertEqual(state, 'state_out') - self.assertEqual(emits, ()) - mock_step.assert_called_with('mock_x', 'state_in', training=True, constants=None) - - -class SequenceLayerConfigTest(SequenceLayerTest): - - @abc.abstractmethod - def get_config_base_cls(self) -> type[types.SequenceLayerConfig]: - """Returns the backend-specific SequenceLayerConfig class.""" - - def test_copy(self): - ConfigBase = self.get_config_base_cls() - - 
@dataclasses.dataclass(frozen=True) - class Config(ConfigBase): - a: int = 1234 - b: str = 'default string' - - def make(self) -> Any: - return 'dummy_layer' - - config = Config() - new_config = config.copy(b='new string') - self.assertEqual(new_config.a, config.a) - self.assertEqual(new_config.b, 'new string') - - def test_copy_raises_on_non_dataclass(self): - ConfigBase = self.get_config_base_cls() - - class NonDataclassConfig(ConfigBase): - - def make(self) -> Any: - return 'dummy_layer' - - config = NonDataclassConfig() - with self.assertRaises(TypeError): - new_config = config.copy() - del new_config - - def test_copy_disallows_new_fields(self): - ConfigBase = self.get_config_base_cls() - - @dataclasses.dataclass(frozen=True) - class Config(ConfigBase): - - def make(self) -> Any: - return 'dummy_layer' - - config = Config() - # dataclasses.replace raises TypeError for unknown arguments - # JAX implementation wraps it in AttributeError - with self.assertRaises((TypeError, AttributeError)): - new_config = config.copy(field_does_not_exist=1234) - del new_config - - -class PreservesTypeTest(parameterized.TestCase): - - @abc.abstractmethod - def create_layer(self) -> types.PreservesType: - pass - - def test_preserves_dtype(self): - layer = self.create_layer() - self.assertEqual(layer.get_output_dtype('fake_dtype123'), 'fake_dtype123') - - -class PreservesShapeTest(parameterized.TestCase): - - @abc.abstractmethod - def create_layer(self) -> types.PreservesShape: - pass - - def test_preserves_shape(self): - layer = self.create_layer() - self.assertEqual(layer.get_output_shape((1, 2, 3, 5)), (1, 2, 3, 5)) - - -class StatelessTest(parameterized.TestCase): - - @abc.abstractmethod - def create_layer(self) -> types.Stateless: - pass - - def test_stateless_behaviors(self): - layer = self.create_layer() - - # Initial state must be empty - self.assertEqual( - layer.get_initial_state(32, 'fake_spec', training=False), () - ) - - # step unconditionally delegates to layer and 
returns identical empty state - with unittest.mock.patch.object(layer, 'layer', return_value='layer_out') as mock_layer: - out, state = layer.step('mock_x', 'mock_state', training=True, constants={'c': 1}) - self.assertEqual(out, 'layer_out') - self.assertEqual(state, 'mock_state') - mock_layer.assert_called_once_with('mock_x', training=True, constants={'c': 1}) - - -class EmittingTest(parameterized.TestCase): - - @abc.abstractmethod - def create_layer(self) -> types.Emitting: - pass - - def test_emitting_drops_emits_on_standard_calls(self): - layer = self.create_layer() - - with unittest.mock.patch.object(layer, 'layer_with_emits', return_value=('out', 'emits')) as m_layer: - self.assertEqual(layer.layer('mock_x', training=False), 'out') - m_layer.assert_called_once_with('mock_x', training=False, constants=None) - - with unittest.mock.patch.object(layer, 'step_with_emits', return_value=('out', 'state', 'emits')) as m_step: - out, state = layer.step('mock_x', 'state', training=True, constants={'c': 1}) - self.assertEqual(out, 'out') - self.assertEqual(state, 'state') - m_step.assert_called_once_with('mock_x', 'state', training=True, constants={'c': 1}) - - -class StatelessEmittingTest(parameterized.TestCase): - - @abc.abstractmethod - def create_layer(self) -> types.StatelessEmitting: - pass - - def test_stateless_emitting_behaviors(self): - layer = self.create_layer() - - self.assertEqual( - layer.get_initial_state(32, 'fake_spec', training=False), () - ) - - with unittest.mock.patch.object(layer, 'layer_with_emits', return_value=('out', 'emits')) as m_layer: - out, state, emits = layer.step_with_emits('mock_x', 'state', training=False) - self.assertEqual(out, 'out') - self.assertEqual(state, 'state') - self.assertEqual(emits, 'emits') - m_layer.assert_called_once_with('mock_x', training=False, constants=None) - - -class StatelessPointwiseFunctorTest(parameterized.TestCase): - - @abc.abstractmethod - def create_layer(self, mask_required: bool) -> 
types.StatelessPointwiseFunctor: - pass - - @abc.abstractmethod - def create_sequence(self) -> types.Sequence: - pass - - def test_layer_applies_fn_based_on_mask_required(self): - for mask_required in [True, False]: - with self.subTest(mask_required=mask_required): - layer = self.create_layer(mask_required) - x = self.create_sequence() - # Mock the apply methods on the Sequence class itself so we return a valid Sequence - # that satisfies any @check_layer decorators. - with unittest.mock.patch.object(type(x), 'apply', return_value=x) as mock_apply: - with unittest.mock.patch.object(type(x), 'apply_masked', return_value=x) as mock_apply_masked: - layer.layer(x, training=False) - - if mask_required: - mock_apply.assert_called_once() - mock_apply_masked.assert_not_called() - else: - mock_apply_masked.assert_called_once() - mock_apply.assert_not_called() - diff --git a/sequence_layers/jax/__init__.py b/sequence_layers/jax/__init__.py index 85bb162..b922ee7 100644 --- a/sequence_layers/jax/__init__.py +++ b/sequence_layers/jax/__init__.py @@ -13,6 +13,10 @@ # limitations under the License. """Sequence layers in JAX.""" +# (re-export the names for typechecking) +from . import backend as backend +from . 
import types as types + # pylint: disable=wildcard-import from sequence_layers.jax.attention import * from sequence_layers.jax.combinators import * diff --git a/sequence_layers/jax/backend.py b/sequence_layers/jax/backend.py new file mode 100644 index 0000000..4efdc75 --- /dev/null +++ b/sequence_layers/jax/backend.py @@ -0,0 +1,22 @@ +"""Backend-specific helpers (JAX)""" + +import jax.numpy as jnp + +from sequence_layers.specs import backend +from sequence_layers.specs import types as types_spec + + +class BackendWrapper: + """Thin wrapper around JAX to match NumPy interface for tests.""" + + bool_ = jnp.bool_ + int32 = jnp.int32 + + def array(self, a, dtype=None) -> types_spec.Array: + return jnp.array(a, dtype=dtype) + + def zeros(self, shape, dtype=None) -> types_spec.Array: + return jnp.zeros(shape, dtype=dtype) + + +xp: backend.xp = BackendWrapper() diff --git a/sequence_layers/jax/test_utils.py b/sequence_layers/jax/test_utils.py index ce6741e..1a75441 100644 --- a/sequence_layers/jax/test_utils.py +++ b/sequence_layers/jax/test_utils.py @@ -27,10 +27,13 @@ import jax import jax.numpy as jnp import numpy as np + +import sequence_layers.jax as sl from sequence_layers.jax import types from sequence_layers.jax import typing as jt from sequence_layers.jax import utils +from sequence_layers.specs import types_behaviors as types_behaviors_spec _SequenceLayerT = TypeVar('_SequenceLayerT', bound=types.SequenceLayer) _T = TypeVar('_T') @@ -777,9 +780,11 @@ def _mask_and_pad_to_max_length( return a, b -class SequenceLayerTest(parameterized.TestCase): +class SequenceLayerTest(types_behaviors_spec.SequenceLayerTest[types.Sequence]): """Base class for SequenceLayer tests.""" + sl = sl + def setUp(self): super().setUp() # To avoid flakes, fix random seeds. 
diff --git a/sequence_layers/jax/types.py b/sequence_layers/jax/types.py index 5886433..86add8f 100644 --- a/sequence_layers/jax/types.py +++ b/sequence_layers/jax/types.py @@ -30,7 +30,7 @@ import jaxtyping import numpy as np -from sequence_layers.abstract import types +from sequence_layers.specs import types as spec from sequence_layers.jax import sharding as sharding_lib from sequence_layers.jax import typing as jt import typeguard @@ -189,8 +189,8 @@ def get_einsum( jax.core.ShapedArray, ) -PaddingMode = types.PaddingMode -PaddingModeString = types.PaddingModeString +PaddingMode = spec.PaddingMode +PaddingModeString = spec.PaddingModeString def validate_padding(padding: str) -> PaddingModeString: @@ -254,7 +254,9 @@ def sequence_mask(lengths: LengthsT, maxlen: int) -> MaskT: ) -class Sequence(types.Sequence[ValuesT, MaskT], struct.PyTreeNode): +class Sequence[ValuesT, MaskT]( + spec.Sequence[ValuesT, MaskT], struct.PyTreeNode +): """A generic sequence container that preserves masking information.""" values: ValuesT @@ -537,7 +539,9 @@ def unmask(self) -> 'Sequence': return self -class MaskedSequence(Sequence[ValuesT, MaskT], types.MaskedSequence[ValuesT, MaskT]): +class MaskedSequence( + Sequence[ValuesT, MaskT], spec.MaskedSequence[ValuesT, MaskT] +): """Sequence whose invalid timesteps are masked to zero.""" @override @@ -598,6 +602,8 @@ def __getitem__(cls, item): class SequenceT(Sequence, metaclass=MetaSequenceT): + """Allows typing to be: SequenceT[Float, "B T C"]""" + pass @@ -669,7 +675,7 @@ def _add_custom_checker_lookup_fn(lookup_fn): _add_custom_checker_lookup_fn(_sequence_checker_lookup_fn) -class Steppable(types.Steppable): +class Steppable(spec.Steppable): """A sequence processing layer that can be executed layerwise or stepwise. 
# Step-wise execution: @@ -880,6 +886,7 @@ def receptive_field_per_step(self) -> dict[int, ReceptiveField]: ) @abc.abstractmethod + @override def layer( self, x: Sequence, *, training: bool, constants: Constants | None = None ) -> Sequence: @@ -1099,7 +1106,6 @@ def get_output_dtype( ) -> DType: """Returns the layer's output dtype for the specified input dtype.""" - @nn.nowrap def get_output_spec( self, @@ -1267,11 +1273,11 @@ def check_step_with_emits_fn( return check_step_with_emits_fn -class SequenceLayer(nn.Module, Steppable, types.SequenceLayer): +class SequenceLayer(nn.Module, Steppable, spec.SequenceLayer): """Base Module for Sequence Layers.""" -class PreservesType(types.PreservesType): +class PreservesType(SequenceLayer, spec.PreservesType): """A mix-in for layers that do not change the input dtype.""" @nn.nowrap @@ -1283,7 +1289,7 @@ def get_output_dtype( return input_dtype -class PreservesShape(types.PreservesShape): +class PreservesShape(SequenceLayer, spec.PreservesShape): """A mix-in for layers that do not change the input shape.""" @nn.nowrap @@ -1295,7 +1301,7 @@ def get_output_shape( return tuple(input_shape) -class Emitting(SequenceLayer, types.Emitting): +class Emitting(SequenceLayer, spec.Emitting): """A SequenceLayer that emits auxiliary arrays. This is a convenience subclass that implements step and layer in terms of @@ -1356,7 +1362,7 @@ def layer_with_emits( pass -class Stateless(SequenceLayer, types.Stateless): +class Stateless(SequenceLayer, spec.Stateless): """A SequenceLayer with no state over time required for step-wise processing. 
Sub-classes must only implement: @@ -1369,6 +1375,7 @@ class Stateless(SequenceLayer, types.Stateless): def receptive_field_per_step(self) -> dict[int, ReceptiveField]: return {0: (0, 0)} + @override def get_initial_state( self, batch_size: int, @@ -1383,6 +1390,7 @@ def get_initial_state( del constants return () + @override def step( self, x: Sequence, @@ -1394,18 +1402,21 @@ def step( return self.layer(x, training=training, constants=constants), state @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: pass @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: pass @abc.abstractmethod + @override def layer( self, x: Sequence, @@ -1416,7 +1427,7 @@ def layer( pass -class StatelessEmitting(Emitting, types.StatelessEmitting): +class StatelessEmitting(Emitting, spec.StatelessEmitting): """A SequenceLayer with no state over time that emits auxiliary arrays. 
Sub-classes must only implement: @@ -1429,6 +1440,7 @@ class StatelessEmitting(Emitting, types.StatelessEmitting): def receptive_field_per_step(self) -> dict[int, ReceptiveField]: return {0: (0, 0)} + @override def step_with_emits( self, x: Sequence, @@ -1479,11 +1491,13 @@ def layer_with_emits( pass -class StatelessPointwise(PreservesShape, Stateless, types.StatelessPointwise): +class StatelessPointwise(PreservesShape, Stateless, spec.StatelessPointwise): """A SequenceLayer that has no state and operates pointwise on its input.""" -class StatelessPointwiseFunctor(StatelessPointwise, types.StatelessPointwiseFunctor): +class StatelessPointwiseFunctor( + StatelessPointwise, spec.StatelessPointwiseFunctor +): """A stateless SequenceLayer for simple pointwise processing fns.""" @abc.abstractmethod @@ -1516,7 +1530,7 @@ def layer( return y -class SequenceLayerConfig(types.SequenceLayerConfig): +class SequenceLayerConfig(spec.SequenceLayerConfig): """Base class for SequenceLayer configuration objects. Requires a no-argument make() method which returns a SequenceLayer. 
diff --git a/sequence_layers/jax/types_test.py b/sequence_layers/jax/types_test.py index e297f9f..4bf0632 100644 --- a/sequence_layers/jax/types_test.py +++ b/sequence_layers/jax/types_test.py @@ -25,37 +25,34 @@ import jaxtyping import numpy as np -from sequence_layers.abstract import types_test_base +import sequence_layers.jax as sl from sequence_layers.jax import simple from sequence_layers.jax import test_utils from sequence_layers.jax import types from sequence_layers.jax import typing as jt +from sequence_layers.specs import types_behaviors as spec -class Foo(nn.Module): +class ModuleInterfaceTest(spec.ModuleInterfaceTest): + sl = sl - @nn.compact - def __call__(self, x: types.Sequence) -> types.Sequence: - return x - -class SequenceTest(test_utils.SequenceLayerTest, types_test_base.SequenceTest): +class SequenceTest(test_utils.SequenceLayerTest, spec.SequenceTest): """Tests for the Sequence class.""" - def get_backend(self): - return jnp - - @property - def Sequence(self): - return types.Sequence - - @property - def MaskedSequence(self): - return types.MaskedSequence + sl = sl def test_type_checks(self): """Test type checks in Sequence.__post_init__.""" + class Foo(nn.Module): + + @nn.compact + def __call__( + self, x: types.Sequence[types.ValuesT, types.MaskT] + ) -> types.Sequence: + return x + # Allowed: Both array-like. 
types.Sequence(jnp.zeros((2, 3, 5)), jnp.zeros((2, 3), dtype=jnp.bool_)) types.Sequence(np.zeros((2, 3, 5)), np.zeros((2, 3), dtype=jnp.bool_)) @@ -109,8 +106,6 @@ def test_type_checks(self): with self.assertRaises(jaxtyping.TypeCheckError): types.Sequence(np.zeros((2, 3, 5)), np.zeros((1, 3), dtype=jnp.bool_)) - - def test_type_annotation(self): if not jt.runtime_type_checking_enabled: self.skipTest('Type checking is disabled.') @@ -119,7 +114,7 @@ def test_type_annotation(self): def f( x: types.SequenceT[jt.Float, 'B T C'], ) -> types.SequenceT[jt.Float, 'B T C 1']: - return types.Sequence(x.values[..., jnp.newaxis], x.mask) + return types.Sequence(x.values[..., jnp.newaxis], x.mask) # type: ignore[return-value] values = jnp.zeros((2, 3, 5)) mask = jnp.zeros((2, 3), dtype=jnp.bool_) @@ -153,7 +148,7 @@ def test_type_annotation_masked_sequence(self): def f( x: types.SequenceT[jt.Float, 'B T C'], ) -> types.SequenceT[jt.Float, 'B T C 1']: - return types.MaskedSequence(x.values[..., jnp.newaxis], x.mask) + return types.MaskedSequence(x.values[..., jnp.newaxis], x.mask) # type: ignore[return-value] values = jnp.zeros((2, 3, 5)) mask = jnp.zeros((2, 3), dtype=jnp.bool_) @@ -205,12 +200,8 @@ def fn(x: types.Sequence) -> types.Sequence: self.assertSequencesEqual(y, x) - - -class SequenceLayerConfigTest(types_test_base.SequenceLayerConfigTest): - - def get_config_base_cls(self): - return types.SequenceLayerConfig +class SequenceLayerConfigTest(spec.SequenceLayerConfigTest): + sl = sl def test_copy_raises_on_mutable_attribute(self): @@ -251,93 +242,33 @@ def make(self) -> simple.Identity: del new_config -class SteppableTest(types_test_base.SteppableTest): +class SteppableTest(spec.SteppableTest): + sl = sl - def create_steppable(self): - class DefaultSteppable(types.Steppable): +class PreservesTypeTest(spec.PreservesTypeTest): + sl = sl - def layer(self, x, *, training: bool, constants=None): - return x - def step(self, x, state, *, training: bool, constants=None): - 
return x, state - - def get_initial_state(self, batch_size, input_spec, *, constants=None): - return 0 - - def get_output_shape(self, input_shape, *, constants=None): - return input_shape - - def get_output_dtype(self, input_dtype, *, constants=None): - return input_dtype - - return DefaultSteppable() - - -class PreservesTypeTest(types_test_base.PreservesTypeTest): - def create_layer(self): - class DummyLayer(types.PreservesType, types.SequenceLayer): - def layer(self, x, *, training: bool, constants=None): return x - def step(self, x, state, *, training: bool, constants=None): return x, state - def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () - def get_output_shape(self, input_shape, *, constants=None): return input_shape - return DummyLayer() - - -class PreservesShapeTest(types_test_base.PreservesShapeTest): - def create_layer(self): - class DummyLayer(types.PreservesShape, types.SequenceLayer): - def layer(self, x, *, training: bool, constants=None): return x - def step(self, x, state, *, training: bool, constants=None): return x, state - def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return DummyLayer() - - -class StatelessTest(types_test_base.StatelessTest): - def create_layer(self): - class DummyLayer(types.Stateless, types.SequenceLayer): - def layer(self, x, *, training: bool, constants=None): return x - def get_output_shape(self, input_shape, *, constants=None): return input_shape - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return DummyLayer() - - -class EmittingTest(types_test_base.EmittingTest): - def create_layer(self): - class DummyLayer(types.Emitting, types.SequenceLayer): - def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () - def layer_with_emits(self, x, *, training: bool, 
constants=None): return x, () - def step_with_emits(self, x, state, *, training: bool, constants=None): return x, state, () - def get_output_shape(self, input_shape, *, constants=None): return input_shape - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - @property - def receptive_field_per_step(self): return {0: (0, 0)} - return DummyLayer() - - -class StatelessEmittingTest(types_test_base.StatelessEmittingTest): - def create_layer(self): - class DummyLayer(types.StatelessEmitting, types.SequenceLayer): - def layer_with_emits(self, x, *, training: bool, constants=None): return x, () - def get_output_shape(self, input_shape, *, constants=None): return input_shape - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return DummyLayer() - - -class StatelessPointwiseFunctorTest(types_test_base.StatelessPointwiseFunctorTest): - def create_layer(self, is_mask_required: bool): - class DummyLayer(types.StatelessPointwiseFunctor, types.SequenceLayer): - @property - def mask_required(self): return is_mask_required - def fn(self, values, mask): return values, mask - def get_output_shape(self, input_shape, *, constants=None): return input_shape - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return DummyLayer() - - def create_sequence(self): - return types.Sequence(jnp.zeros((2, 3, 5)), jnp.zeros((2, 3), dtype=jnp.bool_)) +class PreservesShapeTest(spec.PreservesShapeTest): + sl = sl + + +class StatelessTest(spec.StatelessTest): + sl = sl + + +class EmittingTest(spec.EmittingTest): + sl = sl + + +class StatelessEmittingTest(spec.StatelessEmittingTest): + sl = sl + + +class StatelessPointwiseFunctorTest(spec.StatelessPointwiseFunctorTest): + sl = sl + if __name__ == '__main__': test_utils.main() diff --git a/sequence_layers/jax/utils.py b/sequence_layers/jax/utils.py index 64029fb..23b549d 100644 --- a/sequence_layers/jax/utils.py +++ b/sequence_layers/jax/utils.py @@ -2228,6 
+2228,7 @@ def layer_with_emits_spec( values_spec, types.ShapeDType(values_spec.shape[:2], dtype=types.MASK_DTYPE), ) + def layer_fn( layer: types.SequenceLayer, x: types.Sequence, diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 8cb7355..58124eb 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -13,4 +13,8 @@ # limitations under the License. """Sequence layers in MLX.""" -from sequence_layers.mlx.types import * \ No newline at end of file +# (re-export the names for typechecking) +from . import backend as backend +from . import types as types + +from sequence_layers.mlx.types import * diff --git a/sequence_layers/mlx/backend.py b/sequence_layers/mlx/backend.py new file mode 100644 index 0000000..4d96407 --- /dev/null +++ b/sequence_layers/mlx/backend.py @@ -0,0 +1,22 @@ +"""Backend-specific helpers (MLX)""" + +import mlx.core as mx + +from sequence_layers.specs import backend +from sequence_layers.specs import types as types_spec + + +class BackendWrapper: + """Thin wrapper around MLX to match NumPy interface for tests.""" + + bool_ = mx.bool_ + int32 = mx.int32 + + def array(self, a, dtype=None) -> types_spec.Array: + return mx.array(a, dtype=dtype) + + def zeros(self, shape, dtype=None) -> types_spec.Array: + return mx.zeros(shape, dtype=dtype) + + +xp: backend.xp = BackendWrapper() diff --git a/sequence_layers/mlx/types.py b/sequence_layers/mlx/types.py index c897322..667d18d 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -5,13 +5,15 @@ import enum import fractions import functools -from typing import Callable, Generic, Iterable, TypeVar, override +import types +from typing import Any, Callable, Generic, Iterable, Self, TypeVar, override, cast +import jaxtyping as jt import mlx.core as mx import mlx.nn as nn import numpy as np -from sequence_layers.abstract import types +from sequence_layers.specs import types as spec # Type aliases. 
MASK_DTYPE = mx.bool_ @@ -20,11 +22,13 @@ MaskT = TypeVar('MaskT', bound=mx.array) LengthsT = TypeVar('LengthsT', bound=mx.array) ExpandedMaskT = TypeVar('ExpandedMaskT', bound=mx.array) +NewValuesT = TypeVar('NewValuesT', bound=mx.array) +NewMaskT = TypeVar('NewMaskT', bound=mx.array) SequenceSelf = TypeVar('SequenceSelf', bound='Sequence') Shape = tuple[int, ...] ShapeLike = list[int] | tuple[int, ...] -DType = np.dtype +DType = mx.Dtype State = object # Any pytree. Constants = dict[str, object] Emits = object @@ -32,6 +36,8 @@ # Receptive field. ReceptiveField = tuple[float | int, float | int] | None +InputT = TypeVar('InputT', bound='Sequence') +OutputT = TypeVar('OutputT', bound='Sequence') __all__ = ( # go/keep-sorted start @@ -67,6 +73,7 @@ # go/keep-sorted end ) + class ShapeDType: """Lightweight replacement for jax.ShapeDtypeStruct.""" @@ -88,14 +95,16 @@ def __hash__(self) -> int: ChannelSpec = ShapeDType -PaddingMode = types.PaddingMode +PaddingMode = spec.PaddingMode -def sequence_mask(lengths: LengthsT, maxlen: int) -> MaskT: +def sequence_mask(lengths: LengthsT, maxlen: int) -> mx.array: return mx.arange(maxlen)[None, :] < mx.array(lengths)[:, None] -class Sequence(types.Sequence[ValuesT, MaskT]): +class Sequence[ValuesT: mx.array, MaskT: mx.array]( + spec.Sequence[ValuesT, MaskT] +): """A generic sequence container that preserves masking information.""" values: ValuesT @@ -143,9 +152,13 @@ def from_lengths( is_masked: bool = False, ) -> 'Sequence': """Constructs a sequence from values and per-batch element lengths.""" - values = mx.array(values) - mask = sequence_mask(lengths, maxlen=values.shape[1]) - return MaskedSequence(values, mask) if is_masked else Sequence(values, mask) + values_arr = mx.array(values) + mask = sequence_mask(lengths, maxlen=values_arr.shape[1]) + return ( + MaskedSequence(values_arr, mask) + if is_masked + else Sequence(values_arr, mask) + ) @classmethod @override @@ -174,8 +187,9 @@ def concatenate_sequences(cls, sequences: 
Iterable['Sequence']) -> 'Sequence':
     )
 
   @override
-  def expanded_mask(self) -> ExpandedMaskT:
+  def expanded_mask(self) -> mx.array:
     """Returns the Sequence mask expanded to match values rank."""
+    # Append trailing singleton axes so the [b, t] mask broadcasts over values.
     return self.mask.reshape(self.mask.shape + (1,) * (self.values.ndim - 2))
 
   @override
@@ -190,13 +204,16 @@ def apply_values(
   @override
   def apply_values_masked(
-      self: SequenceSelf,
-      values_fn: Callable[..., ValuesT],
+      self,
+      values_fn: Callable[..., NewValuesT],
       *args,
       **kwargs,
-  ) -> SequenceSelf:
+  ) -> 'Sequence[NewValuesT, MaskT]':
     """Transforms values, preserving masked state."""
-    return type(self)(values_fn(self.values, *args, **kwargs), self.mask)
+    return cast(
+        'Sequence[NewValuesT, MaskT]',
+        type(self)(values_fn(self.values, *args, **kwargs), self.mask),
+    )
 
   @override
   def apply(
@@ -211,14 +228,14 @@ def apply(
   @override
   def apply_masked(
-      self: SequenceSelf,
-      apply_fn: Callable[..., tuple[ValuesT, MaskT]],
+      self,
+      apply_fn: Callable[..., tuple[NewValuesT, NewMaskT]],
       *args,
       **kwargs,
-  ) -> SequenceSelf:
+  ) -> 'Sequence[NewValuesT, NewMaskT]':
     """Transforms values/mask, preserving masked state."""
     values, mask = apply_fn(self.values, self.mask, *args, **kwargs)
-    return type(self)(values, mask)
+    return cast('Sequence[NewValuesT, NewMaskT]', type(self)(values, mask))
 
   @override
   def astype(self: SequenceSelf, dtype: DType | None) -> SequenceSelf:
@@ -235,7 +252,7 @@ def lengths(self) -> mx.array:
   @override
   def __getitem__(
       self: SequenceSelf,
-      the_slice: slice | tuple[int | slice | None | type(Ellipsis), ...],
+      the_slice: slice | tuple[int | slice | None | types.EllipsisType, ...],
   ) -> SequenceSelf:
     """Slices the Sequence values and mask."""
     if isinstance(the_slice, slice):
@@ -275,8 +292,10 @@ def concatenate(self, other: 'Sequence') -> 'Sequence':
     """Concatenates with other on the time dimension."""
     values = mx.concatenate([self.values, other.values], axis=1)
     mask = 
mx.concatenate([self.mask, other.mask], axis=1) - return_type = type(self) if type(self) is type(other) else Sequence - return return_type(values, mask) + if type(self) is type(other): + return type(self)(values, mask) + else: + return Sequence(values, mask) @override def mask_invalid(self, mask_value: complex | None = None) -> 'Sequence': @@ -289,9 +308,35 @@ def unmask(self) -> 'Sequence': return self -class MaskedSequence(Sequence[ValuesT, MaskT], types.MaskedSequence[ValuesT, MaskT]): +class MaskedSequence[ValuesT: mx.array, MaskT: mx.array]( + Sequence[ValuesT, MaskT], spec.MaskedSequence[ValuesT, MaskT] +): """Sequence whose invalid timesteps are masked to zero.""" + @override + def apply_values_masked( + self, + values_fn: Callable[..., NewValuesT], + *args, + **kwargs, + ) -> 'MaskedSequence[NewValuesT, MaskT]': + return cast( + 'MaskedSequence[NewValuesT, MaskT]', + type(self)(values_fn(self.values, *args, **kwargs), self.mask), + ) + + @override + def apply_masked( + self, + apply_fn: Callable[..., tuple[NewValuesT, NewMaskT]], + *args, + **kwargs, + ) -> 'MaskedSequence[NewValuesT, NewMaskT]': + values, mask = apply_fn(self.values, self.mask, *args, **kwargs) + return cast( + 'MaskedSequence[NewValuesT, NewMaskT]', type(self)(values, mask) + ) + @override def mask_invalid(self, mask_value: complex | None = None) -> Sequence: if mask_value is None: @@ -311,10 +356,10 @@ def mask_invalid( expanded_mask = sequence.expanded_mask() if mask_value is None: masked_values = mx.zeros_like(sequence.values) - result_type = MaskedSequence + result_type: type[Sequence] = MaskedSequence else: masked_values = mx.full( - sequence.values.shape, mask_value, sequence.values.dtype + sequence.values.shape, mask_value, sequence.values.dtype # type: ignore[arg-type] ) result_type = Sequence masked_values = mx.where(expanded_mask, sequence.values, masked_values) @@ -322,7 +367,7 @@ def mask_invalid( # Defined outside of Sequence so mask_invalid can return MaskedSequence. 
-Sequence.mask_invalid = mask_invalid +Sequence.mask_invalid = mask_invalid # type: ignore[assignment] # --------------------------------------------------------------------------- # Check decorators @@ -388,7 +433,7 @@ def wrapper(self, x, state, *, training: bool, constants=None): # --------------------------------------------------------------------------- -class Steppable(types.Steppable): +class Steppable(spec.Steppable[InputT, OutputT], Generic[InputT, OutputT]): """A sequence processing layer that can be executed layerwise or stepwise. # Step-wise execution: @@ -482,13 +527,18 @@ def output_latency(self) -> int: @override def get_accumulated_input_latency(self, input_latency: int) -> int: import math + return math.ceil(input_latency / self.output_ratio) + self.input_latency @override def get_accumulated_output_latency(self, output_latency: int) -> int: output_ratio = self.output_ratio if required_delay := -output_latency % (1 / output_ratio): - path = '/'.join(self.path) if hasattr(self, 'path') else self.__class__.__name__ + path = ( + '/'.join(self.path) + if hasattr(self, 'path') + else self.__class__.__name__ + ) raise ValueError( f'Input to {self.__class__.__name__}(path={path!r}) has a step-wise' f' incoming {output_latency=} which is not divisible' @@ -501,14 +551,15 @@ def get_accumulated_output_latency(self, output_latency: int) -> int: @property @override def receptive_field(self) -> ReceptiveField: - raise NotImplementedError('receptive_field is not implemented by MLX Steppable.') - + raise NotImplementedError( + 'receptive_field is not implemented by MLX Steppable.' + ) @abc.abstractmethod @override def layer( - self, x: Sequence, *, training: bool, constants: Constants | None = None - ) -> Sequence: + self, x: InputT, *, training: bool, constants: Constants | None = None + ) -> OutputT: """Process this layer layer-wise. 
Args: @@ -527,8 +578,8 @@ def layer( @override def layer_with_emits( - self, x: Sequence, *, training: bool, constants: Constants | None = None - ) -> tuple[Sequence, Emits]: + self, x: InputT, *, training: bool, constants: Constants | None = None + ) -> tuple[OutputT, Emits]: """Process this layer layer-wise, producing emitted arrays. This is like `layer`, except it has an additional return value which is the @@ -555,12 +606,12 @@ def layer_with_emits( @override def step( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State]: + ) -> tuple[OutputT, State]: """Process this layer step-wise. Args: @@ -584,12 +635,12 @@ def step( @override def step_with_emits( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State, Emits]: + ) -> tuple[OutputT, State, Emits]: """Process this layer step-wise, producing emitted arrays. This is like `step`, except it has an additional return value which is the @@ -701,20 +752,22 @@ def get_output_spec( # --------------------------------------------------------------------------- -class SequenceLayer(nn.Module, Steppable, types.SequenceLayer): +class SequenceLayer[InputT: Sequence, OutputT: Sequence]( + nn.Module, Steppable[InputT, OutputT], spec.SequenceLayer[InputT, OutputT] +): """Base Module for Sequence Layers.""" -class SequenceLayerConfig(types.SequenceLayerConfig): + +class SequenceLayerConfig(spec.SequenceLayerConfig): """Base class for SequenceLayer configuration objects.""" @abc.abstractmethod def make(self) -> SequenceLayer: """Builds a SequenceLayer from this config.""" - def copy(self, **kwargs) -> 'SequenceLayerConfig': + def copy(self, **kwargs) -> Self: """Returns a copy of the config with updated fields.""" - return dataclasses.replace(self, **kwargs) - + return cast(Self, dataclasses.replace(cast(Any, self), **kwargs)) # 
--------------------------------------------------------------------------- @@ -722,7 +775,7 @@ def copy(self, **kwargs) -> 'SequenceLayerConfig': # --------------------------------------------------------------------------- -class PreservesType(types.PreservesType): +class PreservesType(SequenceLayer, spec.PreservesType): """A mix-in for layers that do not change the input dtype.""" @override @@ -736,7 +789,9 @@ def get_output_dtype( return input_dtype -class PreservesShape(types.PreservesShape): +class PreservesShape[InputT: Sequence, OutputT: Sequence]( + SequenceLayer[InputT, OutputT], spec.PreservesShape[InputT, OutputT] +): """A mix-in for layers that do not change the input shape.""" @override @@ -755,15 +810,18 @@ def get_output_shape( # --------------------------------------------------------------------------- -class Stateless(SequenceLayer, types.Stateless): +class Stateless[InputT: Sequence, OutputT: Sequence]( + SequenceLayer[InputT, OutputT], spec.Stateless[InputT, OutputT] +): """A SequenceLayer with no state over time required for step-wise processing. - Sub-classes must only implement: + Sub-classes must also implement: - layer - get_output_shape - get_output_dtype """ + @override def get_initial_state( self, batch_size: int, @@ -779,56 +837,69 @@ def get_initial_state( return () @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None, ) -> Shape: - pass + ... @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None, ) -> DType: - pass + ... @abc.abstractmethod + @override def layer( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> Sequence: - pass + ) -> OutputT: + ... 
+ @override def step( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State]: + ) -> tuple[OutputT, State]: return self.layer(x, training=training, constants=constants), state -class StatelessPointwise(PreservesShape, Stateless, types.StatelessPointwise): +class StatelessPointwise[InputT: Sequence, OutputT: Sequence]( + PreservesShape[InputT, OutputT], + Stateless[InputT, OutputT], + spec.StatelessPointwise[InputT, OutputT], +): """A SequenceLayer that has no state and operates pointwise on its input.""" -class StatelessPointwiseFunctor(StatelessPointwise, types.StatelessPointwiseFunctor): +class StatelessPointwiseFunctor[InputT: Sequence, OutputT: Sequence]( + StatelessPointwise[InputT, OutputT], + spec.StatelessPointwiseFunctor[InputT, OutputT], +): """A stateless SequenceLayer for simple pointwise processing fns.""" @abc.abstractmethod + @override def fn(self, values: ValuesT, mask: MaskT) -> tuple[ValuesT, MaskT]: """Transforms each scalar in values independently.""" @property + @override def mask_required(self): """Returns true if fn can change the sequence's masked state. @@ -837,13 +908,14 @@ def mask_required(self): return True @check_layer + @override def layer( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> Sequence: + ) -> OutputT: del training if self.mask_required: y = x.apply(self.fn) @@ -852,7 +924,7 @@ def layer( # Ensure MaskedSequence -> Sequence conversion for apply. 
if isinstance(y, MaskedSequence) and self.mask_required: y = Sequence(y.values, y.mask) - return y + return cast(OutputT, y) # --------------------------------------------------------------------------- @@ -860,7 +932,11 @@ def layer( # --------------------------------------------------------------------------- -class Emitting(SequenceLayer, types.Emitting): +class Emitting( + SequenceLayer[InputT, OutputT], + spec.Emitting[InputT, OutputT], + Generic[InputT, OutputT], +): """A SequenceLayer that emits auxiliary arrays. This is a convenience subclass that implements step and layer in terms of @@ -871,6 +947,7 @@ class Emitting(SequenceLayer, types.Emitting): """ @abc.abstractmethod + @override def get_initial_state( self, batch_size: int, @@ -879,83 +956,92 @@ def get_initial_state( training: bool, constants: Constants | None = None, ) -> State: - pass + ... @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None, ) -> Shape: - pass + ... @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None, ) -> DType: - pass + ... @abc.abstractmethod + @override def step_with_emits( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State, Emits]: - pass + ) -> tuple[OutputT, State, Emits]: + ... @abc.abstractmethod + @override def layer_with_emits( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, Emits]: - pass + ) -> tuple[OutputT, Emits]: + ... 
+ @override def step( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State]: + ) -> tuple[OutputT, State]: output, state, _ = self.step_with_emits( x, state, training=training, constants=constants ) return output, state + @override def layer( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> Sequence: + ) -> OutputT: outputs, _ = self.layer_with_emits( x, training=training, constants=constants ) return outputs -class StatelessEmitting(Emitting, types.StatelessEmitting): +class StatelessEmitting[InputT: Sequence, OutputT: Sequence]( + Emitting[InputT, OutputT], spec.StatelessEmitting[InputT, OutputT] +): """A SequenceLayer with no state over time that emits auxiliary arrays. - Sub-classes must only implement: + Sub-classes must implement: - layer_with_emits - get_output_shape - get_output_dtype """ @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, @@ -965,24 +1051,27 @@ def get_output_shape( pass @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None, ) -> DType: - pass + ... @abc.abstractmethod + @override def layer_with_emits( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, Emits]: - pass + ) -> tuple[OutputT, Emits]: + ... 
+ @override def get_initial_state( self, batch_size: int, @@ -997,15 +1086,16 @@ def get_initial_state( del constants return () + @override def step_with_emits( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State, Emits]: + ) -> tuple[OutputT, State, Emits]: outputs, emits = self.layer_with_emits( x, training=training, constants=constants ) - return outputs, state, emits \ No newline at end of file + return outputs, state, emits diff --git a/sequence_layers/mlx/types_test.py b/sequence_layers/mlx/types_test.py index e65c28e..8713558 100644 --- a/sequence_layers/mlx/types_test.py +++ b/sequence_layers/mlx/types_test.py @@ -1,125 +1,50 @@ import mlx.core as mx import numpy as np -from sequence_layers.abstract import types_test_base -from sequence_layers.mlx import types + +import sequence_layers.mlx as sl +from sequence_layers.specs import types_behaviors as spec from absl.testing import parameterized from absl.testing import absltest -class SequenceTest(types_test_base.SequenceTest): - - def get_backend(self): - return mx - - @property - def Sequence(self): - return types.Sequence - - @property - def MaskedSequence(self): - return types.MaskedSequence - - def assertAllEqual(self, a, b): - a = np.array(a) if isinstance(a, mx.array) else a - b = np.array(b) if isinstance(b, mx.array) else b - np.testing.assert_array_equal(a, b) - - def assertSequencesEqual(self, a, b): - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.mask, b.mask) - - -class SteppableTest(types_test_base.SteppableTest): - - def create_steppable(self): - - class DefaultSteppable(types.Steppable): - - def layer(self, x, *, training: bool, constants=None): - return x - - def step(self, x, state, *, training: bool, constants=None): - return x, state - - def get_initial_state(self, batch_size, input_spec, *, constants=None): - return () - - def get_output_shape(self, input_shape, *, constants=None): - return 
input_shape +class ModuleInterfaceTest(spec.ModuleInterfaceTest): + sl = sl - def get_output_dtype(self, input_dtype, *, constants=None): - return input_dtype - return DefaultSteppable() +class SequenceTest(spec.SequenceTest): + sl = sl -class SequenceLayerConfigTest(types_test_base.SequenceLayerConfigTest): +class SequenceLayerConfigTest(spec.SequenceLayerConfigTest): + sl = sl - def get_config_base_cls(self): - return types.SequenceLayerConfig +class SteppableTest(spec.SteppableTest): + sl = sl -class PreservesTypeTest(types_test_base.PreservesTypeTest): - def create_layer(self): - class DummyLayer(types.PreservesType, types.SequenceLayer): - def layer(self, x, *, training: bool, constants=None): return x - def step(self, x, state, *, training: bool, constants=None): return x, state - def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () - def get_output_shape(self, input_shape, *, constants=None): return input_shape - return DummyLayer() +class PreservesTypeTest(spec.PreservesTypeTest): + sl = sl -class PreservesShapeTest(types_test_base.PreservesShapeTest): - def create_layer(self): - class DummyLayer(types.PreservesShape, types.SequenceLayer): - def layer(self, x, *, training: bool, constants=None): return x - def step(self, x, state, *, training: bool, constants=None): return x, state - def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return DummyLayer() +class PreservesShapeTest(spec.PreservesShapeTest): + sl = sl -class StatelessTest(types_test_base.StatelessTest): - def create_layer(self): - class DummyLayer(types.Stateless, types.SequenceLayer): - def layer(self, x, *, training: bool, constants=None): return x - def get_output_shape(self, input_shape, *, constants=None): return input_shape - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return 
DummyLayer() +class StatelessTest(spec.StatelessTest): + sl = sl -class EmittingTest(types_test_base.EmittingTest): - def create_layer(self): - class DummyLayer(types.Emitting, types.SequenceLayer): - def get_initial_state(self, batch_size, input_spec, *, training: bool, constants=None): return () - def layer_with_emits(self, x, *, training: bool, constants=None): return x, () - def step_with_emits(self, x, state, *, training: bool, constants=None): return x, state, () - def get_output_shape(self, input_shape, *, constants=None): return input_shape - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return DummyLayer() +class EmittingTest(spec.EmittingTest): + sl = sl -class StatelessEmittingTest(types_test_base.StatelessEmittingTest): - def create_layer(self): - class DummyLayer(types.StatelessEmitting, types.SequenceLayer): - def layer_with_emits(self, x, *, training: bool, constants=None): return x, () - def get_output_shape(self, input_shape, *, constants=None): return input_shape - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return DummyLayer() +class StatelessEmittingTest(spec.StatelessEmittingTest): + sl = sl -class StatelessPointwiseFunctorTest(types_test_base.StatelessPointwiseFunctorTest): - def create_layer(self, is_mask_required: bool): - class DummyLayer(types.StatelessPointwiseFunctor, types.SequenceLayer): - @property - def mask_required(self): return is_mask_required - def fn(self, values, mask): return values, mask - def get_output_shape(self, input_shape, *, constants=None): return input_shape - def get_output_dtype(self, input_dtype, *, constants=None): return input_dtype - return DummyLayer() - def create_sequence(self): - return types.Sequence(mx.zeros((2, 3, 5)), mx.zeros((2, 3), dtype=mx.bool_)) +class StatelessPointwiseFunctorTest(spec.StatelessPointwiseFunctorTest): + sl = sl if __name__ == '__main__': diff --git a/sequence_layers/specs/__init__.py 
b/sequence_layers/specs/__init__.py
new file mode 100644
index 0000000..0060178
--- /dev/null
+++ b/sequence_layers/specs/__init__.py
@@ -0,0 +1,38 @@
+# https://typing.python.org/en/latest/spec/protocol.html#modules-as-implementations-of-protocols
+
+from typing import Protocol, runtime_checkable
+
+from . import backend as _backend
+from . import types as _types
+
+
+@runtime_checkable
+class ModuleSpec(Protocol):
+  """Protocol for a backend-specific SequenceLayers module (sequence_layers.<backend> as sl)."""
+
+  @property
+  def backend(self) -> _backend.ModuleSpec:
+    ...
+
+  @property
+  def types(self) -> _types.ModuleSpec:
+    ...
+
+  # Identifiers that backend-specific implementations should expose at top level.
+  # Read-only properties allow covariance: a subclass of _types.Sequence still satisfies the protocol.
+
+  @property
+  def Sequence(self) -> type[_types.Sequence]:
+    ...
+
+  @property
+  def MaskedSequence(self) -> type[_types.MaskedSequence]:
+    ...
+
+  @property
+  def SequenceLayer(self) -> type[_types.SequenceLayer]:
+    ...
+
+  @property
+  def SequenceLayerConfig(self) -> type[_types.SequenceLayerConfig]:
+    ...
diff --git a/sequence_layers/specs/backend.py b/sequence_layers/specs/backend.py
new file mode 100644
index 0000000..bea6cd9
--- /dev/null
+++ b/sequence_layers/specs/backend.py
@@ -0,0 +1,41 @@
+"""Specification for backend-specific helpers."""
+
+from typing import Any, Protocol, runtime_checkable
+
+from sequence_layers.specs import types as types_spec
+
+
+Array = types_spec.Array
+
+
+class xp(Protocol):
+  """NumPy-compatible interface to enable generic behavior tests.
+
+  https://numpy.org/doc/stable/reference/routines.html#routines
+  https://docs.jax.dev/en/latest/jax.numpy.html
+  """
+
+  bool_: Any
+  int32: Any
+
+  def array(self, a: Any, dtype: Any = None) -> Array:
+    ...
+
+  def zeros(self, shape: tuple[int, ...], dtype: Any = None) -> Array:
+    ...
+
+
+@runtime_checkable
+class ModuleSpec(Protocol):
+  """Specification for sequence_layers.<backend>.backend."""
+
+  @property
+  def xp(self) -> xp:
+    ...
+
+
+__all__ = [
+    name
+    for name, attr in ModuleSpec.__dict__.items()
+    if isinstance(attr, property)
+]
diff --git a/sequence_layers/abstract/types.py b/sequence_layers/specs/types.py
similarity index 75%
rename from sequence_layers/abstract/types.py
rename to sequence_layers/specs/types.py
index de14366..07a66a0
--- a/sequence_layers/abstract/types.py
+++ b/sequence_layers/specs/types.py
@@ -1,17 +1,29 @@
-"""Abstract base classes and types for SequenceLayers."""
+"""Signatures for the types module.
+
+See the corresponding _behaviors module for behaviors.
+
+If you are adding a new class or method to be implemented per backend, make
+sure to add it to the ModuleSpec protocol.
+"""
 
 import abc
 import enum
 import fractions
-from typing import Any, Callable, Generic, Iterable, Literal, TypeVar
+from types import EllipsisType
+from typing import Any, Callable, Concatenate, Generic, Iterable, Literal, Protocol, Self, TypeVar, override, runtime_checkable
 
 import numpy as np
+import numpy.typing as npt
+import jaxtyping as jt
+
+
+# Backend-agnostic array type aliases.
+ArrayLike = npt.ArrayLike
+
+Array = jt.Shaped[Any, '...']
 
 # Type aliases for generic usage
 T = TypeVar('T')
-ValuesT = TypeVar('ValuesT')
-MaskT = TypeVar('MaskT')
-SequenceSelf = TypeVar('SequenceSelf', bound='Sequence')
 Shape = tuple[int, ...]
 ShapeLike = list[int] | tuple[int, ...]
 DType = Any  # Can be numpy, jax, or mlx dtype
@@ -20,6 +32,16 @@
 Constants = Any
 Emits = Any
 
+# TODO: Do these defaults do anything?
apparently not +ValuesT = TypeVar('ValuesT', bound=Array) +MaskT = TypeVar('MaskT', bound=Array) + +LengthsT = TypeVar('LengthsT', bound=Array) +# SequenceT = TypeVar('SequenceT', bound='Sequence[Array, Array]', default='Sequence[Array, Array]') +InputT = TypeVar('InputT', bound='Sequence') +OutputT = TypeVar('OutputT', bound='Sequence') + + class PaddingMode(enum.Enum): """Supported padding modes.""" @@ -123,6 +145,7 @@ class PaddingMode(enum.Enum): # cover the full input sequence. SEMICAUSAL_FULL = 'semicausal_full' + PaddingModeString = Literal[ 'valid', 'same', @@ -134,133 +157,157 @@ class PaddingMode(enum.Enum): 'semicausal_full', ] -class Sequence(Generic[ValuesT, MaskT], metaclass=abc.ABCMeta): + +class Sequence[ValuesT = Array, MaskT = Array](metaclass=abc.ABCMeta): """Abstract base class for Sequence.""" values: ValuesT mask: MaskT + @abc.abstractmethod + def __init__(self, values: ValuesT, mask: MaskT): + ... + @property @abc.abstractmethod def shape(self) -> Shape: - pass + ... @property @abc.abstractmethod def ndim(self) -> int: - pass + ... @property @abc.abstractmethod def channel_shape(self) -> Shape: - pass - + ... + @property @abc.abstractmethod def dtype(self) -> DType: - pass + ... @classmethod @abc.abstractmethod - def from_values(cls, values: ValuesT) -> 'Sequence': - pass + def from_values(cls, values: ValuesT) -> Self: + ... @classmethod @abc.abstractmethod def from_lengths( cls, values: ValuesT, - lengths: Any, + lengths: LengthsT, is_masked: bool = False, - ) -> 'Sequence': - pass + ) -> Self: + ... @classmethod @abc.abstractmethod - def concatenate_sequences(cls, sequences: Iterable['Sequence']) -> 'Sequence': - pass + def concatenate_sequences(cls, sequences: Iterable[Self]) -> Self: + ... @abc.abstractmethod def expanded_mask(self) -> Any: - pass + ... 
@abc.abstractmethod - def apply_values( + def apply_values[NewValuesT: Array, **P]( self, - values_fn: Callable[..., ValuesT], - *args, - **kwargs, - ) -> 'Sequence': - pass - - @abc.abstractmethod - def apply_values_masked( - self: SequenceSelf, - values_fn: Callable[..., ValuesT], - *args, - **kwargs, - ) -> SequenceSelf: - pass - - @abc.abstractmethod - def apply( + values_fn: Callable[Concatenate[ValuesT, P], NewValuesT], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'Sequence[NewValuesT, MaskT]': + ... + + @abc.abstractmethod + def apply_values_masked[NewValuesT: Array, **P]( self, - apply_fn: Callable[..., tuple[ValuesT, MaskT]], - *args, - **kwargs, - ) -> 'Sequence': - pass - + values_fn: Callable[Concatenate[ValuesT, P], NewValuesT], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'Sequence[NewValuesT, MaskT]': + ... + @abc.abstractmethod - def apply_masked( - self: SequenceSelf, - apply_fn: Callable[..., tuple[ValuesT, MaskT]], - *args, - **kwargs, - ) -> SequenceSelf: - pass + def apply[NewValuesT: Array, NewMaskT: Array, **P]( + self, + apply_fn: Callable[Concatenate[ValuesT, P], tuple[NewValuesT, NewMaskT]], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'Sequence[NewValuesT, NewMaskT]': + ... + + @abc.abstractmethod + def apply_masked[NewValuesT: Array, NewMaskT: Array, **P]( + self, + apply_fn: Callable[Concatenate[ValuesT, P], tuple[NewValuesT, NewMaskT]], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'Sequence[NewValuesT, NewMaskT]': + ... @abc.abstractmethod - def astype(self: SequenceSelf, dtype: DType | None) -> SequenceSelf: - pass + def astype(self, dtype: DType | None) -> Self: + ... @abc.abstractmethod def lengths(self) -> Any: - pass + ... @abc.abstractmethod def __getitem__( - self: SequenceSelf, - the_slice: slice | tuple[int | slice | None | type(Ellipsis), ...], - ) -> SequenceSelf: - pass + self, + the_slice: slice | tuple[int | slice | None | EllipsisType, ...], + ) -> Self: + ... 
@abc.abstractmethod def pad_time( - self: SequenceSelf, + self, pad_left: int, pad_right: int, valid: bool, pad_value: Any | None = None, - ) -> SequenceSelf: - pass + ) -> Self: + ... @abc.abstractmethod - def concatenate(self, other: 'Sequence') -> 'Sequence': - pass - + def concatenate(self, other: Self) -> Self: + ... + @abc.abstractmethod - def mask_invalid(self, mask_value: Any | None = None) -> 'Sequence': - pass + def mask_invalid( + self, mask_value: Any | None = None + ) -> 'Sequence[ValuesT, MaskT]': + ... @abc.abstractmethod - def unmask(self) -> 'Sequence': - pass + def unmask(self) -> 'Sequence[ValuesT, MaskT]': + ... -class MaskedSequence(Sequence[ValuesT, MaskT], metaclass=abc.ABCMeta): +class MaskedSequence(Sequence[ValuesT, MaskT]): """A sequence whose invalid timesteps are masked to zero.""" - pass + + @abc.abstractmethod + def apply_values_masked[NewValuesT: Array, **P]( + self, + values_fn: Callable[Concatenate[ValuesT, P], NewValuesT], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'MaskedSequence[NewValuesT, MaskT]': + ... + + @abc.abstractmethod + def apply_masked[NewValuesT: Array, NewMaskT: Array, **P]( + self, + apply_fn: Callable[Concatenate[ValuesT, P], tuple[NewValuesT, NewMaskT]], + *args: P.args, + **kwargs: P.kwargs, + ) -> 'MaskedSequence[NewValuesT, NewMaskT]': + ... class SequenceLayerConfig(metaclass=abc.ABCMeta): @@ -271,50 +318,55 @@ def make(self) -> Any: """Creates the sequence layer.""" @abc.abstractmethod - def copy(self, **kwargs) -> 'SequenceLayerConfig': + def copy(self, **kwargs: Any) -> Self: """Returns a copy of the config with updated fields.""" -class Steppable(metaclass=abc.ABCMeta): - """A sequence processing layer that can be executed layerwise or stepwise.""" +class Steppable[InputT = Sequence, OutputT = Sequence](metaclass=abc.ABCMeta): + """A sequence processing layer that can be executed layerwise or stepwise. 
+ + The backend must implement: + - layer_with_emits + - step_with_emits + """ @property @abc.abstractmethod def block_size(self) -> int: - pass + ... @property @abc.abstractmethod def output_ratio(self) -> fractions.Fraction: - pass + ... @property @abc.abstractmethod def supports_step(self) -> bool: - pass + ... @property @abc.abstractmethod def input_latency(self) -> int: - pass + ... @property @abc.abstractmethod def output_latency(self) -> int: - pass + ... @abc.abstractmethod def get_accumulated_input_latency(self, input_latency: int) -> int: - pass + ... @abc.abstractmethod def get_accumulated_output_latency(self, output_latency: int) -> int: - pass + ... @abc.abstractmethod def layer( - self, x: Sequence, *, training: bool, constants: Constants | None = None - ) -> Sequence: + self, x: InputT, *, training: bool, constants: Constants | None = None + ) -> OutputT: """Process this layer layer-wise. Args: @@ -334,11 +386,11 @@ def layer( @abc.abstractmethod def layer_with_emits( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, Emits]: + ) -> tuple[OutputT, Emits]: """Process this layer layer-wise, producing emitted arrays. This is like `layer`, except it has an additional return value which is the @@ -363,12 +415,12 @@ def layer_with_emits( @abc.abstractmethod def step( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State]: + ) -> tuple[OutputT, State]: """Process this layer step-wise. Args: @@ -392,12 +444,12 @@ def step( @abc.abstractmethod def step_with_emits( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State, Emits]: + ) -> tuple[OutputT, State, Emits]: """Process this layer step-wise, producing emitted arrays. 
This is like `step`, except it has an additional return value which is the @@ -492,12 +544,15 @@ def get_output_dtype( @property @abc.abstractmethod def receptive_field(self) -> Any: - pass + ... -class SequenceLayer(Steppable): +class SequenceLayer[InputT = Sequence, OutputT = Sequence]( + Steppable[InputT, OutputT] +): """Base class for Sequence Layers.""" - pass + + ... # --------------------------------------------------------------------------- @@ -505,30 +560,34 @@ class SequenceLayer(Steppable): # --------------------------------------------------------------------------- -class PreservesType: +class PreservesType(SequenceLayer): """A mix-in for layers that do not change the input dtype.""" @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None, ) -> DType: - pass + ... -class PreservesShape: +class PreservesShape[InputT = Sequence, OutputT = Sequence]( + SequenceLayer[InputT, OutputT] +): """A mix-in for layers that do not change the input channel shape.""" @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None, ) -> Shape: - pass + ... # --------------------------------------------------------------------------- @@ -536,9 +595,11 @@ def get_output_shape( # --------------------------------------------------------------------------- -class Stateless(SequenceLayer): +class Stateless[InputT = Sequence, OutputT = Sequence]( + SequenceLayer[InputT, OutputT] +): """A layer with no state over time required for step-wise processing. - + The backend must implement: - get_initial_state - step @@ -549,28 +610,32 @@ class Stateless(SequenceLayer): """ @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: - pass + ... @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: - pass + ... 
@abc.abstractmethod + @override def layer( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> Sequence: - pass + ) -> OutputT: + ... @abc.abstractmethod + @override def get_initial_state( self, batch_size: int, @@ -579,27 +644,32 @@ def get_initial_state( training: bool, constants: Constants | None = None, ) -> State: - pass + ... @abc.abstractmethod + @override def step( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State]: - pass + ) -> tuple[OutputT, State]: + ... -class StatelessPointwise(PreservesShape, Stateless): +class StatelessPointwise[InputT = Sequence, OutputT = Sequence]( + PreservesShape[InputT, OutputT], Stateless[InputT, OutputT] +): """Stateless layer that operates pointwise (preserves shape).""" -class StatelessPointwiseFunctor(StatelessPointwise, metaclass=abc.ABCMeta): +class StatelessPointwiseFunctor[InputT = Sequence, OutputT = Sequence]( + StatelessPointwise[InputT, OutputT] +): """Stateless pointwise layer defined by a fn(values, mask). - + The backend must implement: - layer Further sub-classes must only implement: @@ -615,19 +685,20 @@ def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: @abc.abstractmethod def mask_required(self) -> bool: """Returns true if fn can change the sequence's masked state. - + If fn(0) -> 0, then mask_required() is False. """ @abc.abstractmethod + @override def layer( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> Sequence: - pass + ) -> OutputT: + ... # --------------------------------------------------------------------------- @@ -635,7 +706,9 @@ def layer( # --------------------------------------------------------------------------- -class Emitting(SequenceLayer, metaclass=abc.ABCMeta): +class Emitting[InputT = Sequence, OutputT = Sequence]( + SequenceLayer[InputT, OutputT] +): """A Steppable layer that emits auxiliary arrays. 
This is a convenience subclass that implements step and layer in terms of @@ -652,49 +725,51 @@ class Emitting(SequenceLayer, metaclass=abc.ABCMeta): @abc.abstractmethod def step( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State]: - pass + ) -> tuple[OutputT, State]: + ... @abc.abstractmethod def layer( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> Sequence: - pass + ) -> OutputT: + ... @abc.abstractmethod def step_with_emits( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State, Emits]: - pass + ) -> tuple[OutputT, State, Emits]: + ... @abc.abstractmethod def layer_with_emits( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, Emits]: - pass + ) -> tuple[OutputT, Emits]: + ... -class StatelessEmitting(Emitting): +class StatelessEmitting[InputT = Sequence, OutputT = Sequence]( + Emitting[InputT, OutputT] +): """A Steppable layer with no state over time that emits auxiliary arrays. - + The backend must implement: - get_initial_state - step_with_emits @@ -705,6 +780,7 @@ class StatelessEmitting(Emitting): """ @abc.abstractmethod + @override def get_initial_state( self, batch_size: int, @@ -713,37 +789,85 @@ def get_initial_state( training: bool, constants: Constants | None = None, ) -> State: - pass + ... @abc.abstractmethod def step_with_emits( self, - x: Sequence, + x: InputT, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, State, Emits]: - pass + ) -> tuple[OutputT, State, Emits]: + ... @abc.abstractmethod def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: - pass + ... @abc.abstractmethod def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: - pass + ... 
@abc.abstractmethod def layer_with_emits( self, - x: Sequence, + x: InputT, *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, Emits]: - pass + ) -> tuple[OutputT, Emits]: + ... + + +@runtime_checkable +class ModuleSpec(Protocol): + """Specification for sequence_layers..types""" + + @property + def Sequence(self) -> type[Sequence]: + ... + + @property + def MaskedSequence(self) -> type[MaskedSequence]: + ... + + @property + def SequenceLayer(self) -> type[SequenceLayer]: + ... + + @property + def SequenceLayerConfig(self) -> type[SequenceLayerConfig]: + ... + + @property + def Steppable(self) -> type[Steppable]: + ... + + @property + def PreservesShape(self) -> type[PreservesShape]: + ... + + @property + def Stateless(self) -> type[Stateless]: + ... + + @property + def StatelessPointwise(self) -> type[StatelessPointwise]: + ... + + @property + def StatelessPointwiseFunctor(self) -> type[StatelessPointwiseFunctor]: + ... + + +__all__ = [ + name + for name, attr in ModuleSpec.__dict__.items() + if isinstance(attr, property) +] diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py new file mode 100644 index 0000000..427c151 --- /dev/null +++ b/sequence_layers/specs/types_behaviors.py @@ -0,0 +1,636 @@ +"""Abstract tests for Sequence types.""" + +import abc +from types import ModuleType +from typing import Any, Callable, Sequence as TypingSequence, TYPE_CHECKING, cast, override +import dataclasses +from sequence_layers import specs +from sequence_layers.specs import backend as backend_spec +from sequence_layers.specs import types as spec + +import fractions +from absl.testing import parameterized +import numpy as np +import unittest.mock + + +class DefaultTestLayer(spec.SequenceLayer): + + @override + def layer( + self, + x: spec.Sequence, + *, + training: bool, + constants: spec.Constants | None = None, + ) -> spec.Sequence: + return x + + @override + def layer_with_emits( + self, + x: 
spec.Sequence, + *, + training: bool, + constants: spec.Constants | None = None, + ) -> tuple[spec.Sequence, spec.Emits]: + return self.layer(x, training=training, constants=constants), ( + 'test_emits', + ) + + @override + def step( + self, + x: spec.Sequence, + state: spec.State, + *, + training: bool, + constants: spec.Constants | None = None, + ) -> tuple[spec.Sequence, spec.State]: + return x, ('new_test_state',) + + @override + def step_with_emits( + self, + x: spec.Sequence, + state: spec.State, + *, + training: bool, + constants: spec.Constants | None = None, + ) -> tuple[spec.Sequence, spec.State, spec.Emits]: + return *self.step(x, state, training=training, constants=constants), ( + 'test_emits', + ) + + @override + def get_initial_state( + self, + batch_size: int, + input_spec: spec.ChannelSpec, + *, + training: bool, + constants: spec.Constants | None = None, + ) -> spec.State: + return ('test_state',) + + @override + def get_output_shape( + self, + input_shape: spec.ShapeLike, + *, + constants: spec.Constants | None = None, + ) -> spec.Shape: + return tuple(input_shape) + (1,) + + @override + def get_output_dtype( + self, input_dtype: spec.DType, *, constants: spec.Constants | None = None + ) -> spec.DType: + return np.float64 + + +class SequenceLayerTest[SequenceT: spec.Sequence = spec.Sequence]( + parameterized.TestCase +): + """Base test class providing common sequence testing assertions and binds a backend implementation to tests.""" + + # sequence_layers. module + sl: specs.ModuleSpec + + @property + def xp(self) -> backend_spec.xp: + return self.sl.backend.xp + + @abc.abstractmethod + def assertSequencesEqual(self, x: SequenceT, y: SequenceT) -> None: + ... + + @abc.abstractmethod + def assertAllEqual(self, x: Any, y: Any) -> None: + ... 
+ + +class ModuleInterfaceTest(SequenceLayerTest): + + def test_backend_specific_module_has_interface(self) -> None: + self.assertIsInstance(self.sl.types, spec.ModuleSpec) + + +class SequenceTest(SequenceLayerTest): + """Abstract tests for the Sequence class.""" + + @parameterized.named_parameters( + ('mask_value=None', 0.0, None), + ('mask_value=0.0', 0.0, 0.0), + ('mask_value=-1.0', -1.0, -1.0), + ) + def test_mask_invalid( + self, mask_value: float, expected_mask_value: float | None + ) -> None: + values = self.xp.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, False, False, True]] + ) + + # Pass mask_value only if it is not None (to test default None behavior vs explicit value) + if expected_mask_value is None: + output = self.sl.Sequence(values, mask).mask_invalid() + fill_value = 0.0 + else: + output = self.sl.Sequence(values, mask).mask_invalid(mask_value) + fill_value = mask_value + + expected_values = self.xp.array([ + [1.0, 2.0, fill_value, fill_value], + [fill_value, fill_value, fill_value, 40.0], + ]) + self.assertAllEqual(output.values, expected_values) + self.assertAllEqual(output.mask, mask) + + def test_pad_time(self) -> None: + values = self.xp.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, False, False, True]] + ) + + x = self.sl.Sequence(values, mask).mask_invalid() + + y = x.pad_time(0, 0, valid=False) + self.assertAllEqual(y.values, x.values) + self.assertAllEqual(y.mask, x.mask) + + y = x.pad_time(1, 0, valid=False) + + x_left1 = self.sl.Sequence( + self.xp.array([ + [0.0, 1.0, 2.0, 3.0, 4.0], + [0.0, 10.0, 20.0, 30.0, 40.0], + ]), + self.xp.array([ + [False, True, True, False, False], + [False, False, False, False, True], + ]), + ).mask_invalid() + self.assertAllEqual(y.values, x_left1.values) + self.assertAllEqual(y.mask, x_left1.mask) + + def _create_test_sequence( + self, shape: 
spec.Shape + ) -> spec.Sequence[spec.Array, spec.Array]: + size = 1 + for d in shape: + size *= d + values_np = np.arange(size, dtype=np.float32).reshape(shape) + mask_np = np.ones(shape[:2], dtype=bool) + if shape[0] > 0 and shape[1] > 1: + mask_np[0, 1] = False + + values = self.xp.array(values_np) + mask = self.xp.array(mask_np) + return self.sl.Sequence(values, mask) + + def test_slice(self) -> None: + x = self._create_test_sequence((3, 5, 9)) + + self.assertSequencesEqual( + x[:, 1:], self.sl.Sequence(x.values[:, 1:], x.mask[:, 1:]) + ) + self.assertSequencesEqual( + x[:, ::2], self.sl.Sequence(x.values[:, ::2], x.mask[:, ::2]) + ) + self.assertSequencesEqual( + x[::2, ::3], self.sl.Sequence(x.values[::2, ::3], x.mask[::2, ::3]) + ) + + def test_slice_can_slice_channel_dimensions(self) -> None: + x = self._create_test_sequence((3, 5, 9, 4)) + + self.assertSequencesEqual( + x[:, 1:, :], self.sl.Sequence(x.values[:, 1:], x.mask[:, 1:]) + ) + self.assertSequencesEqual( + x[:, ::2, :3], + self.sl.Sequence(x.values[:, ::2, :3], x.mask[:, ::2]), + ) + + def test_apply_values(self) -> None: + values = self.xp.array([ + [-1.0, 2.0, 3.0, 4.0], + [10.0, -20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, True, False, True]] + ) + + x = self.sl.Sequence(values, mask) + masked = x.mask_invalid() + + # Simple abs function + fn = abs + + y = x.apply_values(fn) + self.assertAllEqual(y.values, fn(x.values)) + self.assertAllEqual(y.mask, x.mask) + + y = masked.apply_values(fn) + self.assertAllEqual(y.values, fn(masked.values)) + self.assertAllEqual(y.mask, x.mask) + + y = masked.apply_values_masked(fn) + self.assertAllEqual(y.values, fn(masked.values)) + self.assertAllEqual(y.mask, x.mask) + + def test_apply_values_args(self) -> None: + values = self.xp.array([ + [-1.0, 2.0, 3.0, 4.0], + [10.0, -20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, True, False, True]] + ) + x = self.sl.Sequence(values, mask) 
+ + target_shape = (2, 4, 1) + y = x.apply_values(lambda v, s: v.reshape(s), target_shape) + self.assertAllEqual(y.values.shape, target_shape) + self.assertAllEqual(y.mask.shape, (2, 4)) + + def test_from_values(self) -> None: + values_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + values = self.xp.array(values_np) + # Get the class from an instance + seq = self.sl.Sequence( + values, self.xp.array(np.ones(values.shape[:2], dtype=bool)) + ) + SeqClass = type(seq) + + x = SeqClass.from_values(values) # type: ignore + self.assertAllEqual(x.values, values) + self.assertAllEqual( + x.mask, self.xp.array(np.ones(values.shape[:2], dtype=bool)) + ) + + def test_astype(self) -> None: + values_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + mask_np = np.array([[True, False], [False, True]], dtype=bool) + + values = self.xp.array(values_np) + mask = self.xp.array(mask_np) + + x = self.sl.Sequence(values, mask) + + y = x.astype(self.xp.int32) + + # Check values match casted version + self.assertAllEqual(y.mask, mask) + # y.values might be mlx array, values.astype(dtype) might be numpy if values was numpy? + # values is backend array. values.astype(dtype) should work if dtype is backend dtype. 
+ self.assertAllEqual(y.values, values.astype(self.xp.int32)) + + def test_mask_invalid_idempotent(self) -> None: + values = self.xp.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + ]) + mask = self.xp.array( + [[True, True, False, False], [False, False, False, True]] + ) + + x = self.sl.Sequence(values, mask) + masked = x.mask_invalid() + self.assertIsNot(masked, x) + self.assertIsInstance(masked, self.sl.MaskedSequence) + + masked_again = masked.mask_invalid() + self.assertIs(masked_again, masked) + self.assertIsInstance(masked_again, self.sl.MaskedSequence) + + masked2 = x.mask_invalid() + self.assertIsNot(masked2, masked) + self.assertIsInstance(masked2, self.sl.MaskedSequence) + + def test_from_lengths(self) -> None: + values = self.xp.array( + np.arange(5 * 17 * 2).reshape((5, 17, 2)).astype(np.float32) + ) + lengths_np = np.array([0, 5, 10, 17, 12], dtype=np.int32) + mask_np = np.arange(17)[None, :] < lengths_np[:, None] + mask = self.xp.array(mask_np) + + x_expected = self.sl.Sequence(values, mask) + x = self.sl.Sequence.from_lengths(x_expected.values, lengths_np) + self.assertAllEqual(x.values, x_expected.values) + self.assertAllEqual(x.mask, x_expected.mask) + + # Out of range lengths are clipped to 0 or max. + x = self.sl.Sequence.from_lengths( + x_expected.values, self.xp.array([-1, 5, 10, 17, 18]) + ) + self.assertAllEqual(x.lengths(), self.xp.array([0, 5, 10, 17, 17])) + self.assertNotIsInstance(x, self.sl.MaskedSequence) + + # Return type is MaskedSequence if is_masked=True. 
+ x = self.sl.Sequence.from_lengths( + x_expected.values, [-1, 5, 10, 17, 18], is_masked=True + ) + self.assertAllEqual(x.lengths(), self.xp.array([0, 5, 10, 17, 17])) + self.assertIsInstance(x, self.sl.MaskedSequence) + + +class SteppableTest(SequenceLayerTest): + """Abstract tests for Steppable layers.""" + + def create_steppable(self) -> spec.Steppable: + """Creates a basic Steppable instance that should have default properties.""" + + class DefaultSteppable(DefaultTestLayer, self.sl.types.Steppable): # type: ignore[name-defined, misc] + + @override + def layer_with_emits(self, *args, **kwargs): + return super(DefaultTestLayer, self).layer_with_emits(*args, **kwargs) # type: ignore + + @override + def step_with_emits(self, *args, **kwargs): + return super(DefaultTestLayer, self).step_with_emits(*args, **kwargs) # type: ignore + + return DefaultSteppable() # type: ignore + + def test_steppable_defaults(self) -> None: + layer = self.create_steppable() + self.assertEqual(layer.block_size, 1) + self.assertEqual(layer.output_ratio, fractions.Fraction(1)) + self.assertTrue(layer.supports_step) + self.assertEqual(layer.input_latency, 0) + self.assertEqual(layer.output_latency, 0) + self.assertEqual(layer.get_accumulated_input_latency(0), 0) + self.assertEqual(layer.get_accumulated_output_latency(0), 0) + + def create_sequence(self) -> spec.Sequence: + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + + def test_steppable_with_emits_defaults_to_tuple_with_empty_emits( + self, + ) -> None: + layer = self.create_steppable() + seq = self.create_sequence() + state_in = {'a': 'b'} + state_out = {1: 2} + + with unittest.mock.patch.object( + layer, 'layer', return_value=seq + ) as mock_layer: + out, emits = layer.layer_with_emits(seq, training=False, constants=None) + self.assertEqual(out, seq) + self.assertEqual(emits, ()) + mock_layer.assert_called_with(seq, training=False, constants=None) + + with 
unittest.mock.patch.object( + layer, 'step', return_value=(seq, state_out) + ) as mock_step: + out, state, emits = layer.step_with_emits( + seq, state_in, training=True, constants=None + ) + self.assertEqual(out, seq) + self.assertEqual(state, state_out) + self.assertEqual(emits, ()) + mock_step.assert_called_with(seq, state_in, training=True, constants=None) + + +class SequenceLayerConfigTest(SequenceLayerTest): + + def test_copy(self) -> None: + + @dataclasses.dataclass(frozen=True) + class Config(self.sl.SequenceLayerConfig): # type: ignore[name-defined,misc] + a: int = 1234 + b: str = 'default string' + + def make(self) -> Any: + return 'dummy_layer' + + config = Config() # type: ignore + new_config = config.copy(b='new string') + self.assertEqual(new_config.a, config.a) + self.assertEqual(new_config.b, 'new string') + + def test_copy_raises_on_non_dataclass(self) -> None: + + class NonDataclassConfig(self.sl.SequenceLayerConfig): # type: ignore[name-defined,misc] + + def make(self) -> Any: + return 'dummy_layer' + + config = NonDataclassConfig() # type: ignore + with self.assertRaises(TypeError): + new_config = config.copy() + del new_config + + def test_copy_disallows_new_fields(self) -> None: + + @dataclasses.dataclass(frozen=True) + class Config(self.sl.SequenceLayerConfig): # type: ignore[name-defined,misc] + + def make(self) -> Any: + return 'dummy_layer' + + config = Config() # type: ignore + # dataclasses.replace raises TypeError for unknown arguments + # JAX implementation wraps it in AttributeError + with self.assertRaises((TypeError, AttributeError)): + new_config = config.copy(field_does_not_exist=1234) + del new_config + + +class PreservesTypeTest(SequenceLayerTest): + + def create_layer(self) -> spec.PreservesType: + class DummyLayer(DefaultTestLayer, self.sl.types.PreservesType): # type: ignore[name-defined, misc] + + @override + def get_output_dtype(self, *args, **kwargs): + return super(DefaultTestLayer, self).get_output_dtype(*args, **kwargs) 
# type: ignore + + return DummyLayer() # type: ignore + + def test_preserves_dtype(self) -> None: + layer = self.create_layer() + self.assertEqual(layer.get_output_dtype('fake_dtype123'), 'fake_dtype123') + + +class PreservesShapeTest(SequenceLayerTest): + + def create_layer(self) -> spec.PreservesShape: + class DummyLayer(DefaultTestLayer, self.sl.types.PreservesShape): # type: ignore[name-defined, misc] + + @override + def get_output_shape(self, *args, **kwargs): + return super(DefaultTestLayer, self).get_output_shape(*args, **kwargs) # type: ignore + + return DummyLayer() # type: ignore + + def test_preserves_shape(self) -> None: + layer = self.create_layer() + self.assertEqual(layer.get_output_shape((1, 2, 3, 5)), (1, 2, 3, 5)) + + +class StatelessTest(SequenceLayerTest): + + def create_layer(self) -> spec.Stateless: + class DummyLayer(DefaultTestLayer, self.sl.types.Stateless): # type: ignore[name-defined, misc] + + @override + def get_initial_state(self, *args, **kwargs): + return super(DefaultTestLayer, self).get_initial_state(*args, **kwargs) # type: ignore + + @override + def step(self, *args, **kwargs): + return super(DefaultTestLayer, self).step(*args, **kwargs) # type: ignore + + return DummyLayer() # type: ignore + + def test_stateless_behaviors(self) -> None: + layer = self.create_layer() + + # Initial state must be empty + self.assertEqual( + layer.get_initial_state(32, 'fake_spec', training=False), () + ) + + # step unconditionally delegates to layer and returns identical empty state + with unittest.mock.patch.object( + layer, 'layer', return_value='layer_out' + ) as mock_layer: + out, state = layer.step('mock_x', 'mock_state', training=True, constants={'c': 1}) # type: ignore + self.assertEqual(out, 'layer_out') + self.assertEqual(state, 'mock_state') + mock_layer.assert_called_once_with( + 'mock_x', training=True, constants={'c': 1} + ) + + +class EmittingTest(SequenceLayerTest): + + def create_layer(self) -> spec.Emitting: + class 
DummyLayer(DefaultTestLayer, self.sl.types.Emitting): # type: ignore[name-defined, misc] + + @override + def layer(self, *args, **kwargs): + return super(DefaultTestLayer, self).layer(*args, **kwargs) # type: ignore + + @override + def step(self, *args, **kwargs): + return super(DefaultTestLayer, self).step(*args, **kwargs) # type: ignore + + return DummyLayer() # type: ignore + + def test_emitting_drops_emits_on_standard_calls(self) -> None: + layer = self.create_layer() + + with unittest.mock.patch.object( + layer, 'layer_with_emits', return_value=('out', 'emits') + ) as m_layer: + self.assertEqual(layer.layer('mock_x', training=False), 'out') # type: ignore + m_layer.assert_called_once_with('mock_x', training=False, constants=None) + + with unittest.mock.patch.object( + layer, 'step_with_emits', return_value=('out', 'state', 'emits') + ) as m_step: + out, state = layer.step('mock_x', 'state', training=True, constants={'c': 1}) # type: ignore + self.assertEqual(out, 'out') + self.assertEqual(state, 'state') + m_step.assert_called_once_with( + 'mock_x', 'state', training=True, constants={'c': 1} + ) + + +class StatelessEmittingTest(SequenceLayerTest): + + def create_layer(self) -> spec.SequenceLayer: + class DummyLayer(DefaultTestLayer, self.sl.types.StatelessEmitting): # type: ignore[name-defined, misc] + + @override + def get_initial_state(self, *args, **kwargs): + return super(DefaultTestLayer, self).get_initial_state(*args, **kwargs) # type: ignore + + @override + def step_with_emits(self, *args, **kwargs): + return super(DefaultTestLayer, self).step_with_emits(*args, **kwargs) # type: ignore + + return DummyLayer() # type: ignore + + def test_stateless_emitting_behaviors(self) -> None: + layer = self.create_layer() + + self.assertEqual( + layer.get_initial_state(32, 'fake_spec', training=False), () + ) + + with unittest.mock.patch.object( + layer, 'layer_with_emits', return_value=('out', 'emits') + ) as m_layer: + out, state, emits = 
layer.step_with_emits('mock_x', 'state', training=False) # type: ignore + self.assertEqual(out, 'out') + self.assertEqual(state, 'state') + self.assertEqual(emits, 'emits') + m_layer.assert_called_once_with('mock_x', training=False, constants=None) + + +class StatelessPointwiseFunctorTest(SequenceLayerTest): + + def create_layer(self, is_mask_required: bool) -> spec.SequenceLayer[Any]: + + class DummyLayer(DefaultTestLayer, self.sl.types.StatelessPointwiseFunctor): # type: ignore[name-defined, misc] + + @override + def layer(self, *args, **kwargs): + return super(DefaultTestLayer, self).layer(*args, **kwargs) # type: ignore + + @override + def get_output_shape(self, *args, **kwargs): + return super(DefaultTestLayer, self).get_output_shape(*args, **kwargs) # type: ignore + + @property + @override + def mask_required(self_inner) -> bool: + return is_mask_required + + @override + def fn(self_inner, values: Any, mask: Any) -> tuple[Any, Any]: + return values, mask + + return DummyLayer() # type: ignore + + def create_sequence(self) -> spec.Sequence[spec.Array, spec.Array]: + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + + def test_layer_applies_fn_based_on_mask_required(self) -> None: + for mask_required in [True, False]: + with self.subTest(mask_required=mask_required): + layer = self.create_layer(mask_required) + x = self.create_sequence() + # Mock the apply methods on the Sequence class itself so we return a valid Sequence + # that satisfies any @check_layer decorators. 
+ with unittest.mock.patch.object( + type(x), 'apply', return_value=x + ) as mock_apply: + with unittest.mock.patch.object( + type(x), 'apply_masked', return_value=x + ) as mock_apply_masked: + layer.layer(x, training=False) + + if mask_required: + mock_apply.assert_called_once() + mock_apply_masked.assert_not_called() + else: + mock_apply_masked.assert_called_once() + mock_apply.assert_not_called() From ef1dc36744c1ba797d5d966fe3c332582cfa1afe Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Wed, 8 Apr 2026 13:56:34 -0700 Subject: [PATCH 4/9] refactor(test): Extract abstract SequenceLayerTest and enforce abstract methods --- sequence_layers/jax/test_utils.py | 5 +- sequence_layers/jax/types_test.py | 49 ++-- sequence_layers/mlx/test_utils.py | 8 + sequence_layers/mlx/types_test.py | 54 ++--- sequence_layers/specs/test_utils.py | 37 +++ sequence_layers/specs/types.py | 12 + sequence_layers/specs/types_behaviors.py | 273 ++++++++++++++++------- 7 files changed, 306 insertions(+), 132 deletions(-) create mode 100644 sequence_layers/mlx/test_utils.py create mode 100644 sequence_layers/specs/test_utils.py diff --git a/sequence_layers/jax/test_utils.py b/sequence_layers/jax/test_utils.py index 1a75441..567c375 100644 --- a/sequence_layers/jax/test_utils.py +++ b/sequence_layers/jax/test_utils.py @@ -32,8 +32,7 @@ from sequence_layers.jax import types from sequence_layers.jax import typing as jt from sequence_layers.jax import utils - -from sequence_layers.specs import types_behaviors as types_behaviors_spec +from sequence_layers.specs import test_utils as spec _SequenceLayerT = TypeVar('_SequenceLayerT', bound=types.SequenceLayer) _T = TypeVar('_T') @@ -780,7 +779,7 @@ def _mask_and_pad_to_max_length( return a, b -class SequenceLayerTest(types_behaviors_spec.SequenceLayerTest[types.Sequence]): +class SequenceLayerTest(spec.SequenceLayerTest): """Base class for SequenceLayer tests.""" sl = sl diff --git a/sequence_layers/jax/types_test.py 
b/sequence_layers/jax/types_test.py index 4bf0632..7b7b5cf 100644 --- a/sequence_layers/jax/types_test.py +++ b/sequence_layers/jax/types_test.py @@ -15,9 +15,7 @@ import dataclasses import typing -from typing import Any -from absl.testing import parameterized import chex import flax.linen as nn import jax @@ -25,7 +23,6 @@ import jaxtyping import numpy as np -import sequence_layers.jax as sl from sequence_layers.jax import simple from sequence_layers.jax import test_utils from sequence_layers.jax import types @@ -33,15 +30,15 @@ from sequence_layers.specs import types_behaviors as spec -class ModuleInterfaceTest(spec.ModuleInterfaceTest): - sl = sl +class ModuleInterfaceTest( + test_utils.SequenceLayerTest, spec.ModuleInterfaceTest +): + pass class SequenceTest(test_utils.SequenceLayerTest, spec.SequenceTest): """Tests for the Sequence class.""" - sl = sl - def test_type_checks(self): """Test type checks in Sequence.__post_init__.""" @@ -200,8 +197,10 @@ def fn(x: types.Sequence) -> types.Sequence: self.assertSequencesEqual(y, x) -class SequenceLayerConfigTest(spec.SequenceLayerConfigTest): - sl = sl +class SequenceLayerConfigTest( + test_utils.SequenceLayerTest, spec.SequenceLayerConfigTest +): + pass def test_copy_raises_on_mutable_attribute(self): @@ -242,32 +241,36 @@ def make(self) -> simple.Identity: del new_config -class SteppableTest(spec.SteppableTest): - sl = sl +class SteppableTest(test_utils.SequenceLayerTest, spec.SteppableTest): + pass -class PreservesTypeTest(spec.PreservesTypeTest): - sl = sl +class PreservesTypeTest(test_utils.SequenceLayerTest, spec.PreservesTypeTest): + pass -class PreservesShapeTest(spec.PreservesShapeTest): - sl = sl +class PreservesShapeTest(test_utils.SequenceLayerTest, spec.PreservesShapeTest): + pass -class StatelessTest(spec.StatelessTest): - sl = sl +class StatelessTest(test_utils.SequenceLayerTest, spec.StatelessTest): + pass -class EmittingTest(spec.EmittingTest): - sl = sl +class 
EmittingTest(test_utils.SequenceLayerTest, spec.EmittingTest): + pass -class StatelessEmittingTest(spec.StatelessEmittingTest): - sl = sl +class StatelessEmittingTest( + test_utils.SequenceLayerTest, spec.StatelessEmittingTest +): + pass -class StatelessPointwiseFunctorTest(spec.StatelessPointwiseFunctorTest): - sl = sl +class StatelessPointwiseFunctorTest( + test_utils.SequenceLayerTest, spec.StatelessPointwiseFunctorTest +): + pass if __name__ == '__main__': diff --git a/sequence_layers/mlx/test_utils.py b/sequence_layers/mlx/test_utils.py new file mode 100644 index 0000000..abdfc61 --- /dev/null +++ b/sequence_layers/mlx/test_utils.py @@ -0,0 +1,8 @@ +import sequence_layers.mlx as sl +from sequence_layers.specs import test_utils as spec + + +class SequenceLayerTest(spec.SequenceLayerTest): + """Base class for MLX SequenceLayer tests.""" + + sl = sl diff --git a/sequence_layers/mlx/types_test.py b/sequence_layers/mlx/types_test.py index 8713558..3981d1b 100644 --- a/sequence_layers/mlx/types_test.py +++ b/sequence_layers/mlx/types_test.py @@ -1,50 +1,54 @@ -import mlx.core as mx -import numpy as np - -import sequence_layers.mlx as sl +from sequence_layers.mlx import test_utils from sequence_layers.specs import types_behaviors as spec -from absl.testing import parameterized from absl.testing import absltest -class ModuleInterfaceTest(spec.ModuleInterfaceTest): - sl = sl +class ModuleInterfaceTest( + test_utils.SequenceLayerTest, spec.ModuleInterfaceTest +): + pass -class SequenceTest(spec.SequenceTest): - sl = sl +class SequenceTest(test_utils.SequenceLayerTest, spec.SequenceTest): + pass -class SequenceLayerConfigTest(spec.SequenceLayerConfigTest): - sl = sl +class SequenceLayerConfigTest( + test_utils.SequenceLayerTest, spec.SequenceLayerConfigTest +): + pass -class SteppableTest(spec.SteppableTest): - sl = sl +class SteppableTest(test_utils.SequenceLayerTest, spec.SteppableTest): + pass -class PreservesTypeTest(spec.PreservesTypeTest): - sl = sl +class 
PreservesTypeTest(test_utils.SequenceLayerTest, spec.PreservesTypeTest): + pass -class PreservesShapeTest(spec.PreservesShapeTest): - sl = sl +class PreservesShapeTest(test_utils.SequenceLayerTest, spec.PreservesShapeTest): + pass -class StatelessTest(spec.StatelessTest): - sl = sl +class StatelessTest(test_utils.SequenceLayerTest, spec.StatelessTest): + pass -class EmittingTest(spec.EmittingTest): - sl = sl +class EmittingTest(test_utils.SequenceLayerTest, spec.EmittingTest): + pass -class StatelessEmittingTest(spec.StatelessEmittingTest): - sl = sl +class StatelessEmittingTest( + test_utils.SequenceLayerTest, spec.StatelessEmittingTest +): + pass -class StatelessPointwiseFunctorTest(spec.StatelessPointwiseFunctorTest): - sl = sl +class StatelessPointwiseFunctorTest( + test_utils.SequenceLayerTest, spec.StatelessPointwiseFunctorTest +): + pass if __name__ == '__main__': diff --git a/sequence_layers/specs/test_utils.py b/sequence_layers/specs/test_utils.py new file mode 100644 index 0000000..90596d3 --- /dev/null +++ b/sequence_layers/specs/test_utils.py @@ -0,0 +1,37 @@ +"""Test utilities for sequence layers.""" + +import abc +from typing import Any +from absl.testing import parameterized +from sequence_layers import specs +from sequence_layers.specs import backend as backend_spec +from sequence_layers.specs import types as spec + + +class _AbcParameterizedTestCaseMeta(abc.ABCMeta, type(parameterized.TestCase)): + pass + + +class SequenceLayerTest[SequenceT: spec.Sequence = spec.Sequence]( + parameterized.TestCase, metaclass=_AbcParameterizedTestCaseMeta +): + """Base test class providing common sequence testing assertions. + + Binds a backend implementation to tests. + """ + + # sequence_layers. 
module + sl: specs.ModuleSpec + + @property + def xp(self) -> backend_spec.xp: + """Returns the backend module.""" + return self.sl.backend.xp + + @abc.abstractmethod + def assertSequencesEqual(self, x: SequenceT, y: SequenceT) -> None: # pylint: disable=invalid-name + """After padding, checks sequence values are equal and masks are equal.""" + + @abc.abstractmethod + def assertAllEqual(self, x: Any, y: Any) -> None: # pylint: disable=invalid-name + """Asserts that two arrays are equal.""" diff --git a/sequence_layers/specs/types.py b/sequence_layers/specs/types.py index 07a66a0..2b7a89b 100644 --- a/sequence_layers/specs/types.py +++ b/sequence_layers/specs/types.py @@ -865,6 +865,18 @@ def StatelessPointwise(self) -> type[StatelessPointwise]: def StatelessPointwiseFunctor(self) -> type[StatelessPointwiseFunctor]: ... + @property + def PreservesType(self) -> type[PreservesType]: + ... + + @property + def Emitting(self) -> type[Emitting]: + ... + + @property + def StatelessEmitting(self) -> type[StatelessEmitting]: + ... 
+ __all__ = [ name diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py index 427c151..ece3d7a 100644 --- a/sequence_layers/specs/types_behaviors.py +++ b/sequence_layers/specs/types_behaviors.py @@ -1,20 +1,58 @@ -"""Abstract tests for Sequence types.""" +# pylint: disable=abstract-method +"""Generic tests for Sequence types.""" -import abc -from types import ModuleType -from typing import Any, Callable, Sequence as TypingSequence, TYPE_CHECKING, cast, override import dataclasses -from sequence_layers import specs -from sequence_layers.specs import backend as backend_spec -from sequence_layers.specs import types as spec - import fractions +from typing import Any, override +import unittest.mock + from absl.testing import parameterized import numpy as np -import unittest.mock + +from sequence_layers.specs import types as spec +from sequence_layers.specs.test_utils import SequenceLayerTest class DefaultTestLayer(spec.SequenceLayer): + """A default test layer for testing.""" + + @property + @override + def block_size(self) -> int: + return 1 + + @property + @override + def output_ratio(self) -> fractions.Fraction: + return fractions.Fraction(1) + + @property + @override + def supports_step(self) -> bool: + return True + + @property + @override + def input_latency(self) -> int: + return 0 + + @property + @override + def output_latency(self) -> int: + return 0 + + @property + @override + def receptive_field(self) -> Any: + return 1 + + @override + def get_accumulated_input_latency(self, input_latency: int) -> int: + return input_latency + + @override + def get_accumulated_output_latency(self, output_latency: int) -> int: + return output_latency @override def layer( @@ -89,27 +127,6 @@ def get_output_dtype( return np.float64 -class SequenceLayerTest[SequenceT: spec.Sequence = spec.Sequence]( - parameterized.TestCase -): - """Base test class providing common sequence testing assertions and binds a backend implementation to 
tests.""" - - # sequence_layers. module - sl: specs.ModuleSpec - - @property - def xp(self) -> backend_spec.xp: - return self.sl.backend.xp - - @abc.abstractmethod - def assertSequencesEqual(self, x: SequenceT, y: SequenceT) -> None: - ... - - @abc.abstractmethod - def assertAllEqual(self, x: Any, y: Any) -> None: - ... - - class ModuleInterfaceTest(SequenceLayerTest): def test_backend_specific_module_has_interface(self) -> None: @@ -117,7 +134,7 @@ def test_backend_specific_module_has_interface(self) -> None: class SequenceTest(SequenceLayerTest): - """Abstract tests for the Sequence class.""" + """Generic tests for the Sequence class.""" @parameterized.named_parameters( ('mask_value=None', 0.0, None), @@ -183,6 +200,7 @@ def test_pad_time(self) -> None: def _create_test_sequence( self, shape: spec.Shape ) -> spec.Sequence[spec.Array, spec.Array]: + """Creates a test sequence with specific shape.""" size = 1 for d in shape: size *= d @@ -264,13 +282,7 @@ def test_apply_values_args(self) -> None: def test_from_values(self) -> None: values_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) values = self.xp.array(values_np) - # Get the class from an instance - seq = self.sl.Sequence( - values, self.xp.array(np.ones(values.shape[:2], dtype=bool)) - ) - SeqClass = type(seq) - - x = SeqClass.from_values(values) # type: ignore + x = self.sl.Sequence.from_values(values) self.assertAllEqual(x.values, values) self.assertAllEqual( x.mask, self.xp.array(np.ones(values.shape[:2], dtype=bool)) @@ -316,6 +328,7 @@ def test_mask_invalid_idempotent(self) -> None: self.assertIsInstance(masked2, self.sl.MaskedSequence) def test_from_lengths(self) -> None: + """Tests creating a sequence from lengths.""" values = self.xp.array( np.arange(5 * 17 * 2).reshape((5, 17, 2)).astype(np.float32) ) @@ -344,22 +357,27 @@ def test_from_lengths(self) -> None: class SteppableTest(SequenceLayerTest): - """Abstract tests for Steppable layers.""" def create_steppable(self) -> spec.Steppable: - 
"""Creates a basic Steppable instance that should have default properties.""" + """Creates a basic Steppable instance.""" + backend_sl = self.sl - class DefaultSteppable(DefaultTestLayer, self.sl.types.Steppable): # type: ignore[name-defined, misc] + class DefaultSteppable( + DefaultTestLayer, backend_sl.types.Steppable + ): + """Mock layer for testing.""" @override def layer_with_emits(self, *args, **kwargs): - return super(DefaultTestLayer, self).layer_with_emits(*args, **kwargs) # type: ignore + return backend_sl.types.Steppable.layer_with_emits( + self, *args, **kwargs + ) @override def step_with_emits(self, *args, **kwargs): - return super(DefaultTestLayer, self).step_with_emits(*args, **kwargs) # type: ignore + return backend_sl.types.Steppable.step_with_emits(self, *args, **kwargs) - return DefaultSteppable() # type: ignore + return DefaultSteppable() def test_steppable_defaults(self) -> None: layer = self.create_steppable() @@ -371,7 +389,9 @@ def test_steppable_defaults(self) -> None: self.assertEqual(layer.get_accumulated_input_latency(0), 0) self.assertEqual(layer.get_accumulated_output_latency(0), 0) + @override def create_sequence(self) -> spec.Sequence: + """Creates a test sequence.""" return self.sl.Sequence( self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) ) @@ -407,13 +427,18 @@ def test_steppable_with_emits_defaults_to_tuple_with_empty_emits( class SequenceLayerConfigTest(SequenceLayerTest): def test_copy(self) -> None: + backend_sl = self.sl @dataclasses.dataclass(frozen=True) - class Config(self.sl.SequenceLayerConfig): # type: ignore[name-defined,misc] + class Config(backend_sl.SequenceLayerConfig): + """Mock config.""" + a: int = 1234 b: str = 'default string' + @override def make(self) -> Any: + """Makes a dummy layer.""" return 'dummy_layer' config = Config() # type: ignore @@ -422,10 +447,14 @@ def make(self) -> Any: self.assertEqual(new_config.b, 'new string') def test_copy_raises_on_non_dataclass(self) -> None: + 
backend_sl = self.sl - class NonDataclassConfig(self.sl.SequenceLayerConfig): # type: ignore[name-defined,misc] + class NonDataclassConfig(backend_sl.SequenceLayerConfig): # pylint: disable=too-few-public-methods + """Non-dataclass mock config.""" + @override def make(self) -> Any: + """Makes a dummy layer.""" return 'dummy_layer' config = NonDataclassConfig() # type: ignore @@ -434,11 +463,15 @@ def make(self) -> Any: del new_config def test_copy_disallows_new_fields(self) -> None: + backend_sl = self.sl @dataclasses.dataclass(frozen=True) - class Config(self.sl.SequenceLayerConfig): # type: ignore[name-defined,misc] + class Config(backend_sl.SequenceLayerConfig): + """Mock config.""" + @override def make(self) -> Any: + """Makes a dummy layer.""" return 'dummy_layer' config = Config() # type: ignore @@ -452,13 +485,21 @@ def make(self) -> Any: class PreservesTypeTest(SequenceLayerTest): def create_layer(self) -> spec.PreservesType: - class DummyLayer(DefaultTestLayer, self.sl.types.PreservesType): # type: ignore[name-defined, misc] + """Creates a preserves type layer.""" + backend_sl = self.sl + + class DummyLayer( + DefaultTestLayer, backend_sl.types.PreservesType + ): + """Mock layer for testing.""" @override def get_output_dtype(self, *args, **kwargs): - return super(DefaultTestLayer, self).get_output_dtype(*args, **kwargs) # type: ignore + return backend_sl.types.PreservesType.get_output_dtype( + self, *args, **kwargs + ) - return DummyLayer() # type: ignore + return DummyLayer() def test_preserves_dtype(self) -> None: layer = self.create_layer() @@ -468,13 +509,21 @@ def test_preserves_dtype(self) -> None: class PreservesShapeTest(SequenceLayerTest): def create_layer(self) -> spec.PreservesShape: - class DummyLayer(DefaultTestLayer, self.sl.types.PreservesShape): # type: ignore[name-defined, misc] + """Creates a preserves shape layer.""" + backend_sl = self.sl + + class DummyLayer( + DefaultTestLayer, backend_sl.types.PreservesShape + ): + """Mock layer for 
testing.""" @override def get_output_shape(self, *args, **kwargs): - return super(DefaultTestLayer, self).get_output_shape(*args, **kwargs) # type: ignore + return backend_sl.types.PreservesShape.get_output_shape( + self, *args, **kwargs + ) - return DummyLayer() # type: ignore + return DummyLayer() def test_preserves_shape(self) -> None: layer = self.create_layer() @@ -483,18 +532,33 @@ def test_preserves_shape(self) -> None: class StatelessTest(SequenceLayerTest): + @override + def create_sequence(self) -> spec.Sequence: + """Creates a default test sequence.""" + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + def create_layer(self) -> spec.Stateless: - class DummyLayer(DefaultTestLayer, self.sl.types.Stateless): # type: ignore[name-defined, misc] + """Creates a stateless layer.""" + backend_sl = self.sl + + class DummyLayer( + DefaultTestLayer, backend_sl.types.Stateless + ): + """Mock layer for testing.""" @override def get_initial_state(self, *args, **kwargs): - return super(DefaultTestLayer, self).get_initial_state(*args, **kwargs) # type: ignore + return backend_sl.types.Stateless.get_initial_state( + self, *args, **kwargs + ) @override def step(self, *args, **kwargs): - return super(DefaultTestLayer, self).step(*args, **kwargs) # type: ignore + return backend_sl.types.Stateless.step(self, *args, **kwargs) - return DummyLayer() # type: ignore + return DummyLayer() def test_stateless_behaviors(self) -> None: layer = self.create_layer() @@ -505,66 +569,98 @@ def test_stateless_behaviors(self) -> None: ) # step unconditionally delegates to layer and returns identical empty state + x = self.create_sequence() with unittest.mock.patch.object( layer, 'layer', return_value='layer_out' ) as mock_layer: - out, state = layer.step('mock_x', 'mock_state', training=True, constants={'c': 1}) # type: ignore + out, state = layer.step( + x, 'mock_state', training=True, constants={'c': 1} + ) self.assertEqual(out, 
'layer_out') self.assertEqual(state, 'mock_state') - mock_layer.assert_called_once_with( - 'mock_x', training=True, constants={'c': 1} - ) + mock_layer.assert_called_once_with(x, training=True, constants={'c': 1}) class EmittingTest(SequenceLayerTest): + @override + def create_sequence(self) -> spec.Sequence: + """Creates a default test sequence.""" + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + def create_layer(self) -> spec.Emitting: - class DummyLayer(DefaultTestLayer, self.sl.types.Emitting): # type: ignore[name-defined, misc] + """Creates an emitting layer.""" + backend_sl = self.sl + + class DummyLayer( + DefaultTestLayer, backend_sl.types.Emitting + ): + """Mock layer for testing.""" @override def layer(self, *args, **kwargs): - return super(DefaultTestLayer, self).layer(*args, **kwargs) # type: ignore + return backend_sl.types.Emitting.layer(self, *args, **kwargs) @override def step(self, *args, **kwargs): - return super(DefaultTestLayer, self).step(*args, **kwargs) # type: ignore + return backend_sl.types.Emitting.step(self, *args, **kwargs) - return DummyLayer() # type: ignore + return DummyLayer() def test_emitting_drops_emits_on_standard_calls(self) -> None: layer = self.create_layer() + x = self.create_sequence() with unittest.mock.patch.object( layer, 'layer_with_emits', return_value=('out', 'emits') ) as m_layer: - self.assertEqual(layer.layer('mock_x', training=False), 'out') # type: ignore - m_layer.assert_called_once_with('mock_x', training=False, constants=None) + self.assertEqual(layer.layer(x, training=False), 'out') + m_layer.assert_called_once_with(x, training=False, constants=None) with unittest.mock.patch.object( layer, 'step_with_emits', return_value=('out', 'state', 'emits') ) as m_step: - out, state = layer.step('mock_x', 'state', training=True, constants={'c': 1}) # type: ignore + out, state = layer.step(x, 'state', training=True, constants={'c': 1}) self.assertEqual(out, 'out') 
self.assertEqual(state, 'state') m_step.assert_called_once_with( - 'mock_x', 'state', training=True, constants={'c': 1} + x, 'state', training=True, constants={'c': 1} ) class StatelessEmittingTest(SequenceLayerTest): + @override + def create_sequence(self) -> spec.Sequence: + """Creates a default test sequence.""" + return self.sl.Sequence( + self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) + ) + def create_layer(self) -> spec.SequenceLayer: - class DummyLayer(DefaultTestLayer, self.sl.types.StatelessEmitting): # type: ignore[name-defined, misc] + """Creates a stateless emitting layer.""" + backend_sl = self.sl + + class DummyLayer( + DefaultTestLayer, backend_sl.types.StatelessEmitting + ): + """Mock layer for testing.""" @override def get_initial_state(self, *args, **kwargs): - return super(DefaultTestLayer, self).get_initial_state(*args, **kwargs) # type: ignore + return backend_sl.types.StatelessEmitting.get_initial_state( + self, *args, **kwargs + ) @override def step_with_emits(self, *args, **kwargs): - return super(DefaultTestLayer, self).step_with_emits(*args, **kwargs) # type: ignore + return backend_sl.types.StatelessEmitting.step_with_emits( + self, *args, **kwargs + ) - return DummyLayer() # type: ignore + return DummyLayer() def test_stateless_emitting_behaviors(self) -> None: layer = self.create_layer() @@ -573,42 +669,57 @@ def test_stateless_emitting_behaviors(self) -> None: layer.get_initial_state(32, 'fake_spec', training=False), () ) + x = self.create_sequence() with unittest.mock.patch.object( layer, 'layer_with_emits', return_value=('out', 'emits') ) as m_layer: - out, state, emits = layer.step_with_emits('mock_x', 'state', training=False) # type: ignore + out, state, emits = layer.step_with_emits(x, 'state', training=False) self.assertEqual(out, 'out') self.assertEqual(state, 'state') self.assertEqual(emits, 'emits') - m_layer.assert_called_once_with('mock_x', training=False, constants=None) + 
m_layer.assert_called_once_with(x, training=False, constants=None) class StatelessPointwiseFunctorTest(SequenceLayerTest): def create_layer(self, is_mask_required: bool) -> spec.SequenceLayer[Any]: + """Creates a stateless pointwise functor layer.""" - class DummyLayer(DefaultTestLayer, self.sl.types.StatelessPointwiseFunctor): # type: ignore[name-defined, misc] + backend_sl = self.sl + + class DummyLayer( + DefaultTestLayer, backend_sl.types.StatelessPointwiseFunctor + ): + """Mock layer for testing.""" @override def layer(self, *args, **kwargs): - return super(DefaultTestLayer, self).layer(*args, **kwargs) # type: ignore + return backend_sl.types.StatelessPointwiseFunctor.layer( + self, *args, **kwargs + ) @override def get_output_shape(self, *args, **kwargs): - return super(DefaultTestLayer, self).get_output_shape(*args, **kwargs) # type: ignore + return backend_sl.types.StatelessPointwiseFunctor.get_output_shape( + self, *args, **kwargs + ) @property @override - def mask_required(self_inner) -> bool: + def mask_required(self) -> bool: + """Whether mask is required.""" return is_mask_required @override - def fn(self_inner, values: Any, mask: Any) -> tuple[Any, Any]: + def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: + """Pointwise function.""" return values, mask - return DummyLayer() # type: ignore + return DummyLayer() + @override def create_sequence(self) -> spec.Sequence[spec.Array, spec.Array]: + """Creates a test sequence.""" return self.sl.Sequence( self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) ) From 7fa83a4166e99d535f4933f2d3931a388731daef Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Wed, 8 Apr 2026 13:56:48 -0700 Subject: [PATCH 5/9] test(mlx): Implement missing assertions to stop vacuous passing --- pyproject.toml | 5 +++++ sequence_layers/mlx/test_utils.py | 35 +++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index aefdb47..03dfd2f 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,11 @@ unstable = true pyink-indentation = 2 pyink-use-majority-quotes = true +[tool.pylint.format] +indent-string = " " + + + [build-system] # Build system specify which backend is used to build/install the project (flit, # poetry, setuptools,...). All backends are supported by `pip install` diff --git a/sequence_layers/mlx/test_utils.py b/sequence_layers/mlx/test_utils.py index abdfc61..29cf520 100644 --- a/sequence_layers/mlx/test_utils.py +++ b/sequence_layers/mlx/test_utils.py @@ -1,8 +1,43 @@ +"""Test utilities for MLX sequence layers.""" + +from typing import override +import numpy as np +import mlx.core as mx import sequence_layers.mlx as sl +from sequence_layers.mlx import types from sequence_layers.specs import test_utils as spec +def _mask_and_pad_to_max_length( + a: types.Sequence, b: types.Sequence +) -> tuple[types.Sequence, types.Sequence]: + """Masks and pads two sequences to the same max length.""" + # Only compare values in non-masked regions. 
+ a = a.mask_invalid() + b = b.mask_invalid() + a_time = a.values.shape[1] + b_time = b.values.shape[1] + max_time = max(a_time, b_time) + a = a.pad_time(0, max_time - a_time, valid=False) + b = b.pad_time(0, max_time - b_time, valid=False) + return a, b + + class SequenceLayerTest(spec.SequenceLayerTest): """Base class for MLX SequenceLayer tests.""" sl = sl + + @override + def assertAllEqual(self, x, y): + """Asserts that two arrays are equal.""" + x_np = np.array(x) if isinstance(x, mx.array) else x + y_np = np.array(y) if isinstance(y, mx.array) else y + np.testing.assert_array_equal(x_np, y_np) + + @override + def assertSequencesEqual(self, x: types.Sequence, y: types.Sequence): + """After padding, checks sequence values are equal and masks are equal.""" + x, y = _mask_and_pad_to_max_length(x, y) + self.assertAllEqual(x.values, y.values) + self.assertAllEqual(x.mask, y.mask) From 64ba05db6e6b6dcb5ae601a4133644085233e8a3 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Wed, 8 Apr 2026 16:32:44 -0700 Subject: [PATCH 6/9] chore: Update linting configurations for pylint and pyrefly --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 03dfd2f..3ce4706 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,9 @@ pyink-use-majority-quotes = true [tool.pylint.format] indent-string = " " +[tool.pylint.basic] +no-docstring-rgx = "^(_)?test_|^.*Test$" + [build-system] @@ -89,4 +92,5 @@ exclude = [ "testdata/**", ] -[tool.pyrefly] \ No newline at end of file +[tool.pyrefly] +errors = { missing-override-decorator = "error" } \ No newline at end of file From a261193af8ab0e53fda6fa3ac5fce9fbcc4adee3 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 9 Apr 2026 16:50:31 -0700 Subject: [PATCH 7/9] refactor: move get_output_spec to Steppable and standardize tests --- sequence_layers/mlx/types.py | 90 ++++++++++---- sequence_layers/mlx/types_test.py | 6 +- sequence_layers/specs/types.py | 
151 +++++++++++++++++------ sequence_layers/specs/types_behaviors.py | 70 +++++++---- 4 files changed, 228 insertions(+), 89 deletions(-) diff --git a/sequence_layers/mlx/types.py b/sequence_layers/mlx/types.py index 667d18d..e7dda5d 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -2,16 +2,25 @@ import abc import dataclasses -import enum import fractions import functools +import math import types -from typing import Any, Callable, Generic, Iterable, Self, TypeVar, override, cast +from typing import ( + Any, + Callable, + cast, + Generic, + Iterable, + MutableMapping, + override, + Self, + TypeVar, +) import jaxtyping as jt +from mlx import nn import mlx.core as mx -import mlx.nn as nn -import numpy as np from sequence_layers.specs import types as spec @@ -30,8 +39,8 @@ ShapeLike = list[int] | tuple[int, ...] DType = mx.Dtype State = object # Any pytree. -Constants = dict[str, object] -Emits = object +Constants = MutableMapping[str, jt.PyTree[mx.array]] +Emits = jt.PyTree[mx.array] # Receptive field. 
ReceptiveField = tuple[float | int, float | int] | None @@ -81,9 +90,11 @@ def __init__(self, shape: Shape, dtype: DType): self.shape = shape self.dtype = dtype + @override def __repr__(self) -> str: return f'ShapeDType(shape={self.shape}, dtype={self.dtype})' + @override def __eq__(self, other: object) -> bool: if not isinstance(other, ShapeDType): return NotImplemented @@ -99,7 +110,8 @@ def __hash__(self) -> int: def sequence_mask(lengths: LengthsT, maxlen: int) -> mx.array: - return mx.arange(maxlen)[None, :] < mx.array(lengths)[:, None] + """Generates a boolean mask for sequences based on lengths.""" + return mx.arange(maxlen)[None, :] < mx.array(lengths)[:, None] # pylint: disable=unsubscriptable-object class Sequence[ValuesT: mx.array, MaskT: mx.array]( @@ -258,8 +270,8 @@ def __getitem__( if isinstance(the_slice, slice): the_slice = (the_slice,) return type(self)( - self.values.__getitem__(the_slice), - self.mask.__getitem__(the_slice[:2]), + self.values[the_slice], + self.mask[the_slice[:2]], ) @override @@ -294,8 +306,7 @@ def concatenate(self, other: 'Sequence') -> 'Sequence': mask = mx.concatenate([self.mask, other.mask], axis=1) if type(self) is type(other): return type(self)(values, mask) - else: - return Sequence(values, mask) + return Sequence(values, mask) @override def mask_invalid(self, mask_value: complex | None = None) -> 'Sequence': @@ -375,6 +386,7 @@ def mask_invalid( def _check_output_spec(layer, x, y, constants): + """Checks that the output spec of a layer matches the expected spec.""" expected = layer.get_output_spec(x.channel_spec, constants=constants) if y.channel_shape != expected.shape: raise ValueError( @@ -386,6 +398,7 @@ def _check_output_spec(layer, x, y, constants): def _check_output_ratio(layer, x, y): + """Checks that the output length of a layer matches the expected length.""" expected_length = x.shape[1] * layer.output_ratio if y.shape[1] != expected_length: raise ValueError( @@ -433,7 +446,9 @@ def wrapper(self, x, state, *, 
training: bool, constants=None): # --------------------------------------------------------------------------- -class Steppable(spec.Steppable[InputT, OutputT], Generic[InputT, OutputT]): +class Steppable[InputT: Sequence, OutputT: Sequence]( + spec.Steppable[InputT, OutputT, ChannelSpec] +): """A sequence processing layer that can be executed layerwise or stepwise. # Step-wise execution: @@ -526,8 +541,6 @@ def output_latency(self) -> int: @override def get_accumulated_input_latency(self, input_latency: int) -> int: - import math - return math.ceil(input_latency / self.output_ratio) + self.input_latency @override @@ -736,12 +749,25 @@ def get_output_dtype( The dtype of the output features. """ + @override def get_output_spec( self, input_spec: ChannelSpec, *, constants: Constants | None = None, ) -> ChannelSpec: + """Returns the output spec this layer produces for the provided input spec. + + Args: + input_spec: A ChannelSpec which represents the channels shape and dtype of + the input sequence (i.e. not including the batch or time dimension). + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. + + Returns: + The ChannelSpec of the output features. 
+ """ shape = self.get_output_shape(input_spec.shape, constants=constants) dtype = self.get_output_dtype(input_spec.dtype, constants=constants) return ChannelSpec(shape, dtype) @@ -753,7 +779,10 @@ def get_output_spec( class SequenceLayer[InputT: Sequence, OutputT: Sequence]( - nn.Module, Steppable[InputT, OutputT], spec.SequenceLayer[InputT, OutputT] + nn.Module, + Steppable[InputT, OutputT], + spec.SequenceLayer[InputT, OutputT, ChannelSpec], + metaclass=abc.ABCMeta, ): """Base Module for Sequence Layers.""" @@ -762,9 +791,11 @@ class SequenceLayerConfig(spec.SequenceLayerConfig): """Base class for SequenceLayer configuration objects.""" @abc.abstractmethod + @override def make(self) -> SequenceLayer: """Builds a SequenceLayer from this config.""" + @override def copy(self, **kwargs) -> Self: """Returns a copy of the config with updated fields.""" return cast(Self, dataclasses.replace(cast(Any, self), **kwargs)) @@ -775,7 +806,11 @@ def copy(self, **kwargs) -> Self: # --------------------------------------------------------------------------- -class PreservesType(SequenceLayer, spec.PreservesType): +class PreservesType[InputT: Sequence, OutputT: Sequence]( + SequenceLayer[InputT, OutputT], + spec.PreservesType[InputT, OutputT, ChannelSpec], + metaclass=abc.ABCMeta, +): """A mix-in for layers that do not change the input dtype.""" @override @@ -790,7 +825,9 @@ def get_output_dtype( class PreservesShape[InputT: Sequence, OutputT: Sequence]( - SequenceLayer[InputT, OutputT], spec.PreservesShape[InputT, OutputT] + SequenceLayer[InputT, OutputT], + spec.PreservesShape[InputT, OutputT, ChannelSpec], + metaclass=abc.ABCMeta, ): """A mix-in for layers that do not change the input shape.""" @@ -811,7 +848,7 @@ def get_output_shape( class Stateless[InputT: Sequence, OutputT: Sequence]( - SequenceLayer[InputT, OutputT], spec.Stateless[InputT, OutputT] + SequenceLayer[InputT, OutputT], spec.Stateless[InputT, OutputT, ChannelSpec] ): """A SequenceLayer with no state over time 
required for step-wise processing. @@ -882,14 +919,15 @@ def step( class StatelessPointwise[InputT: Sequence, OutputT: Sequence]( PreservesShape[InputT, OutputT], Stateless[InputT, OutputT], - spec.StatelessPointwise[InputT, OutputT], + spec.StatelessPointwise[InputT, OutputT, ChannelSpec], + metaclass=abc.ABCMeta, ): """A SequenceLayer that has no state and operates pointwise on its input.""" class StatelessPointwiseFunctor[InputT: Sequence, OutputT: Sequence]( StatelessPointwise[InputT, OutputT], - spec.StatelessPointwiseFunctor[InputT, OutputT], + spec.StatelessPointwiseFunctor[InputT, OutputT, ChannelSpec], ): """A stateless SequenceLayer for simple pointwise processing fns.""" @@ -907,9 +945,9 @@ def mask_required(self): """ return True - @check_layer @override - def layer( + @check_layer + def layer( # pyrefly: ignore[missing-override-decorator] self, x: InputT, *, @@ -932,10 +970,9 @@ def layer( # --------------------------------------------------------------------------- -class Emitting( +class Emitting[InputT: Sequence, OutputT: Sequence]( SequenceLayer[InputT, OutputT], - spec.Emitting[InputT, OutputT], - Generic[InputT, OutputT], + spec.Emitting[InputT, OutputT, ChannelSpec], ): """A SequenceLayer that emits auxiliary arrays. @@ -1030,7 +1067,8 @@ def layer( class StatelessEmitting[InputT: Sequence, OutputT: Sequence]( - Emitting[InputT, OutputT], spec.StatelessEmitting[InputT, OutputT] + Emitting[InputT, OutputT], + spec.StatelessEmitting[InputT, OutputT, ChannelSpec], ): """A SequenceLayer with no state over time that emits auxiliary arrays. 
diff --git a/sequence_layers/mlx/types_test.py b/sequence_layers/mlx/types_test.py index 3981d1b..d6a903a 100644 --- a/sequence_layers/mlx/types_test.py +++ b/sequence_layers/mlx/types_test.py @@ -1,6 +1,9 @@ +"""Tests for MLX sequence types.""" + +from absl.testing import absltest + from sequence_layers.mlx import test_utils from sequence_layers.specs import types_behaviors as spec -from absl.testing import absltest class ModuleInterfaceTest( @@ -51,5 +54,6 @@ class StatelessPointwiseFunctorTest( pass + if __name__ == '__main__': absltest.main() diff --git a/sequence_layers/specs/types.py b/sequence_layers/specs/types.py index 2b7a89b..a8b207b 100644 --- a/sequence_layers/specs/types.py +++ b/sequence_layers/specs/types.py @@ -10,12 +10,22 @@ import enum import fractions from types import EllipsisType -from typing import Any, Callable, Concatenate, Generic, Iterable, Literal, Protocol, Self, TypeVar, override, runtime_checkable +from typing import ( + Any, + Callable, + Concatenate, + Iterable, + Literal, + MutableMapping, + override, + Protocol, + runtime_checkable, + Self, + TypeVar, +) -import numpy as np -import numpy.typing as npt import jaxtyping as jt - +import numpy.typing as npt # NEW ArrayLike = npt.ArrayLike @@ -27,14 +37,28 @@ Shape = tuple[int, ...] ShapeLike = list[int] | tuple[int, ...] DType = Any # Can be numpy, jax, or mlx dtype -ChannelSpec = Any # Typically ShapeDType or compatible + + +class ChannelSpec(Protocol): + """Protocol for channel specifications.""" + + @property + def shape(self) -> Shape: + ... + + @property + def dtype(self) -> Any: + ... + + State = Any -Constants = Any -Emits = Any +Constants = MutableMapping[str, jt.PyTree[Array]] +Emits = jt.PyTree[Array] # TODO: Do these defaults do anything? 
apparently not ValuesT = TypeVar('ValuesT', bound=Array) MaskT = TypeVar('MaskT', bound=Array) +ChannelSpecT = TypeVar('ChannelSpecT', bound=ChannelSpec) LengthsT = TypeVar('LengthsT', bound=Array) # SequenceT = TypeVar('SequenceT', bound='Sequence[Array, Array]', default='Sequence[Array, Array]') @@ -164,9 +188,8 @@ class Sequence[ValuesT = Array, MaskT = Array](metaclass=abc.ABCMeta): values: ValuesT mask: MaskT - @abc.abstractmethod def __init__(self, values: ValuesT, mask: MaskT): - ... + raise NotImplementedError('Subclasses must implement __init__') @property @abc.abstractmethod @@ -292,6 +315,7 @@ class MaskedSequence(Sequence[ValuesT, MaskT]): """A sequence whose invalid timesteps are masked to zero.""" @abc.abstractmethod + @override def apply_values_masked[NewValuesT: Array, **P]( self, values_fn: Callable[Concatenate[ValuesT, P], NewValuesT], @@ -301,6 +325,7 @@ def apply_values_masked[NewValuesT: Array, **P]( ... @abc.abstractmethod + @override def apply_masked[NewValuesT: Array, NewMaskT: Array, **P]( self, apply_fn: Callable[Concatenate[ValuesT, P], tuple[NewValuesT, NewMaskT]], @@ -322,7 +347,11 @@ def copy(self, **kwargs: Any) -> Self: """Returns a copy of the config with updated fields.""" -class Steppable[InputT = Sequence, OutputT = Sequence](metaclass=abc.ABCMeta): +class Steppable[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](metaclass=abc.ABCMeta): """A sequence processing layer that can be executed layerwise or stepwise. The backend must implement: @@ -479,7 +508,7 @@ def step_with_emits( def get_initial_state( self, batch_size: int, - input_spec: ChannelSpec, + input_spec: ChannelSpecT, *, training: bool, constants: Constants | None = None, @@ -541,26 +570,50 @@ def get_output_dtype( The dtype of the output features. 
""" + @abc.abstractmethod + def get_output_spec( + self, + input_spec: ChannelSpecT, + *, + constants: Constants | None = None, + ) -> ChannelSpec: + """Returns the output spec this layer produces for the provided input spec. + + Args: + input_spec: A ChannelSpec which represents the channels shape and dtype of + the input sequence (i.e. not including the batch or time dimension). + constants: A dictionary of constant name to array or sl.Sequence. + Values or sequences that are "constant" with respect to the + SequenceLayer, but may affect its processing. + + Returns: + The ChannelSpec of the output features. + """ + @property @abc.abstractmethod def receptive_field(self) -> Any: ... -class SequenceLayer[InputT = Sequence, OutputT = Sequence]( - Steppable[InputT, OutputT] -): +class SequenceLayer[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](Steppable[InputT, OutputT, ChannelSpecT], metaclass=abc.ABCMeta): """Base class for Sequence Layers.""" - ... - # --------------------------------------------------------------------------- # Mixins # --------------------------------------------------------------------------- -class PreservesType(SequenceLayer): +class PreservesType[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](SequenceLayer[InputT, OutputT, ChannelSpecT]): """A mix-in for layers that do not change the input dtype.""" @abc.abstractmethod @@ -574,9 +627,11 @@ def get_output_dtype( ... 
-class PreservesShape[InputT = Sequence, OutputT = Sequence]( - SequenceLayer[InputT, OutputT] -): +class PreservesShape[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](SequenceLayer[InputT, OutputT, ChannelSpecT]): """A mix-in for layers that do not change the input channel shape.""" @abc.abstractmethod @@ -595,9 +650,11 @@ def get_output_shape( # --------------------------------------------------------------------------- -class Stateless[InputT = Sequence, OutputT = Sequence]( - SequenceLayer[InputT, OutputT] -): +class Stateless[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](SequenceLayer[InputT, OutputT, ChannelSpecT]): """A layer with no state over time required for step-wise processing. The backend must implement: @@ -639,7 +696,7 @@ def layer( def get_initial_state( self, batch_size: int, - input_spec: ChannelSpec, + input_spec: ChannelSpecT, *, training: bool, constants: Constants | None = None, @@ -659,15 +716,23 @@ def step( ... -class StatelessPointwise[InputT = Sequence, OutputT = Sequence]( - PreservesShape[InputT, OutputT], Stateless[InputT, OutputT] +class StatelessPointwise[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +]( + PreservesShape[InputT, OutputT, ChannelSpecT], + Stateless[InputT, OutputT, ChannelSpecT], + metaclass=abc.ABCMeta, ): """Stateless layer that operates pointwise (preserves shape).""" -class StatelessPointwiseFunctor[InputT = Sequence, OutputT = Sequence]( - StatelessPointwise[InputT, OutputT] -): +class StatelessPointwiseFunctor[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](StatelessPointwise[InputT, OutputT, ChannelSpecT]): """Stateless pointwise layer defined by a fn(values, mask). 
The backend must implement: @@ -706,9 +771,11 @@ def layer( # --------------------------------------------------------------------------- -class Emitting[InputT = Sequence, OutputT = Sequence]( - SequenceLayer[InputT, OutputT] -): +class Emitting[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](SequenceLayer[InputT, OutputT, ChannelSpecT]): """A Steppable layer that emits auxiliary arrays. This is a convenience subclass that implements step and layer in terms of @@ -723,6 +790,7 @@ class Emitting[InputT = Sequence, OutputT = Sequence]( """ @abc.abstractmethod + @override def step( self, x: InputT, @@ -734,6 +802,7 @@ def step( ... @abc.abstractmethod + @override def layer( self, x: InputT, @@ -744,6 +813,7 @@ def layer( ... @abc.abstractmethod + @override def step_with_emits( self, x: InputT, @@ -755,6 +825,7 @@ def step_with_emits( ... @abc.abstractmethod + @override def layer_with_emits( self, x: InputT, @@ -765,9 +836,11 @@ def layer_with_emits( ... -class StatelessEmitting[InputT = Sequence, OutputT = Sequence]( - Emitting[InputT, OutputT] -): +class StatelessEmitting[ + InputT = Sequence, + OutputT = Sequence, + ChannelSpecT: ChannelSpec = ChannelSpec, +](Emitting[InputT, OutputT, ChannelSpecT]): """A Steppable layer with no state over time that emits auxiliary arrays. The backend must implement: @@ -784,7 +857,7 @@ class StatelessEmitting[InputT = Sequence, OutputT = Sequence]( def get_initial_state( self, batch_size: int, - input_spec: ChannelSpec, + input_spec: ChannelSpecT, *, training: bool, constants: Constants | None = None, @@ -792,6 +865,7 @@ def get_initial_state( ... @abc.abstractmethod + @override def step_with_emits( self, x: InputT, @@ -803,18 +877,21 @@ def step_with_emits( ... @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: ... 
@abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: ... @abc.abstractmethod + @override def layer_with_emits( self, x: InputT, @@ -829,6 +906,8 @@ def layer_with_emits( class ModuleSpec(Protocol): """Specification for sequence_layers..types""" + # pylint: disable=invalid-name + @property def Sequence(self) -> type[Sequence]: ... diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py index ece3d7a..844f3e7 100644 --- a/sequence_layers/specs/types_behaviors.py +++ b/sequence_layers/specs/types_behaviors.py @@ -3,7 +3,7 @@ import dataclasses import fractions -from typing import Any, override +from typing import Any, NamedTuple, override import unittest.mock from absl.testing import parameterized @@ -13,6 +13,13 @@ from sequence_layers.specs.test_utils import SequenceLayerTest +class DummyChannelSpec(NamedTuple): + """Dummy channel spec for testing.""" + + shape: spec.Shape + dtype: spec.DType + + class DefaultTestLayer(spec.SequenceLayer): """A default test layer for testing.""" @@ -126,6 +133,17 @@ def get_output_dtype( ) -> spec.DType: return np.float64 + @override + def get_output_spec( + self, + input_spec: Any, + *, + constants: spec.Constants | None = None, + ) -> Any: + shape = self.get_output_shape(input_spec.shape, constants=constants) + dtype = self.get_output_dtype(input_spec.dtype, constants=constants) + return DummyChannelSpec(shape, dtype) + class ModuleInterfaceTest(SequenceLayerTest): @@ -362,9 +380,7 @@ def create_steppable(self) -> spec.Steppable: """Creates a basic Steppable instance.""" backend_sl = self.sl - class DefaultSteppable( - DefaultTestLayer, backend_sl.types.Steppable - ): + class DefaultSteppable(DefaultTestLayer, backend_sl.types.Steppable): """Mock layer for testing.""" @override @@ -389,7 +405,13 @@ def test_steppable_defaults(self) -> None: self.assertEqual(layer.get_accumulated_input_latency(0), 0) 
self.assertEqual(layer.get_accumulated_output_latency(0), 0) - @override + def test_get_output_spec(self) -> None: + layer = self.create_steppable() + input_spec = DummyChannelSpec(shape=(2, 3), dtype=np.float32) + output_spec = layer.get_output_spec(input_spec) + self.assertEqual(output_spec.shape, (2, 3, 1)) + self.assertEqual(output_spec.dtype, np.float64) + def create_sequence(self) -> spec.Sequence: """Creates a test sequence.""" return self.sl.Sequence( @@ -488,9 +510,7 @@ def create_layer(self) -> spec.PreservesType: """Creates a preserves type layer.""" backend_sl = self.sl - class DummyLayer( - DefaultTestLayer, backend_sl.types.PreservesType - ): + class DummyLayer(DefaultTestLayer, backend_sl.types.PreservesType): """Mock layer for testing.""" @override @@ -512,9 +532,7 @@ def create_layer(self) -> spec.PreservesShape: """Creates a preserves shape layer.""" backend_sl = self.sl - class DummyLayer( - DefaultTestLayer, backend_sl.types.PreservesShape - ): + class DummyLayer(DefaultTestLayer, backend_sl.types.PreservesShape): """Mock layer for testing.""" @override @@ -532,7 +550,6 @@ def test_preserves_shape(self) -> None: class StatelessTest(SequenceLayerTest): - @override def create_sequence(self) -> spec.Sequence: """Creates a default test sequence.""" return self.sl.Sequence( @@ -543,9 +560,7 @@ def create_layer(self) -> spec.Stateless: """Creates a stateless layer.""" backend_sl = self.sl - class DummyLayer( - DefaultTestLayer, backend_sl.types.Stateless - ): + class DummyLayer(DefaultTestLayer, backend_sl.types.Stateless): """Mock layer for testing.""" @override @@ -565,7 +580,12 @@ def test_stateless_behaviors(self) -> None: # Initial state must be empty self.assertEqual( - layer.get_initial_state(32, 'fake_spec', training=False), () + layer.get_initial_state( + 32, + DummyChannelSpec(shape=(2, 3), dtype=np.float32), + training=False, + ), + (), ) # step unconditionally delegates to layer and returns identical empty state @@ -583,7 +603,6 @@ def 
test_stateless_behaviors(self) -> None: class EmittingTest(SequenceLayerTest): - @override def create_sequence(self) -> spec.Sequence: """Creates a default test sequence.""" return self.sl.Sequence( @@ -594,9 +613,7 @@ def create_layer(self) -> spec.Emitting: """Creates an emitting layer.""" backend_sl = self.sl - class DummyLayer( - DefaultTestLayer, backend_sl.types.Emitting - ): + class DummyLayer(DefaultTestLayer, backend_sl.types.Emitting): """Mock layer for testing.""" @override @@ -632,7 +649,6 @@ def test_emitting_drops_emits_on_standard_calls(self) -> None: class StatelessEmittingTest(SequenceLayerTest): - @override def create_sequence(self) -> spec.Sequence: """Creates a default test sequence.""" return self.sl.Sequence( @@ -643,9 +659,7 @@ def create_layer(self) -> spec.SequenceLayer: """Creates a stateless emitting layer.""" backend_sl = self.sl - class DummyLayer( - DefaultTestLayer, backend_sl.types.StatelessEmitting - ): + class DummyLayer(DefaultTestLayer, backend_sl.types.StatelessEmitting): """Mock layer for testing.""" @override @@ -666,7 +680,12 @@ def test_stateless_emitting_behaviors(self) -> None: layer = self.create_layer() self.assertEqual( - layer.get_initial_state(32, 'fake_spec', training=False), () + layer.get_initial_state( + 32, + DummyChannelSpec(shape=(2, 3), dtype=np.float32), + training=False, + ), + (), ) x = self.create_sequence() @@ -717,7 +736,6 @@ def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: return DummyLayer() - @override def create_sequence(self) -> spec.Sequence[spec.Array, spec.Array]: """Creates a test sequence.""" return self.sl.Sequence( From 42173b362db4d3584812bad2d5dbb2cb13c4b99b Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 9 Apr 2026 16:50:34 -0700 Subject: [PATCH 8/9] chore: resolve lint warnings, update config, and apply auto-formatting --- pyproject.toml | 13 +++- sequence_layers/jax/__init__.py | 8 +- sequence_layers/jax/backend.py | 6 +- sequence_layers/jax/test_utils.py | 6 +- 
sequence_layers/jax/types.py | 110 +++++++++++++++++++--------- sequence_layers/jax/typing.py | 34 ++++++--- sequence_layers/jax/utils.py | 5 +- sequence_layers/mlx/__init__.py | 7 +- sequence_layers/mlx/backend.py | 6 +- sequence_layers/mlx/test_utils.py | 12 ++- sequence_layers/mlx/types.py | 110 +++++++++++++--------------- sequence_layers/mlx/types_test.py | 1 - sequence_layers/specs/backend.py | 8 +- sequence_layers/specs/test_utils.py | 2 + sequence_layers/specs/types.py | 101 ++++++++++++------------- 15 files changed, 245 insertions(+), 184 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ce4706..49a2a0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dev = [ "absl-py>=2.4.0", "chex", "orbax", + "isort", "pyink", "pylint>=2.6.0", "pyrefly>=0.58.0", @@ -68,11 +69,21 @@ unstable = true pyink-indentation = 2 pyink-use-majority-quotes = true +[tool.isort] +profile = "google" +line_length = 80 + +[tool.pylint.master] +extension-pkg-whitelist = ["mlx", "mlx.core"] + [tool.pylint.format] indent-string = " " [tool.pylint.basic] -no-docstring-rgx = "^(_)?test_|^.*Test$" +no-docstring-rgx = "^(_)?test_|^.*Test$|^__.*__$" + +[tool.pylint.messages_control] +disable = ["too-many-lines", "too-many-ancestors", "too-few-public-methods", "duplicate-code"] diff --git a/sequence_layers/jax/__init__.py b/sequence_layers/jax/__init__.py index b922ee7..dcbb3a0 100644 --- a/sequence_layers/jax/__init__.py +++ b/sequence_layers/jax/__init__.py @@ -13,10 +13,6 @@ # limitations under the License. """Sequence layers in JAX.""" -# (re-export the names for typechecking) -from . import backend as backend -from . import types as types - # pylint: disable=wildcard-import from sequence_layers.jax.attention import * from sequence_layers.jax.combinators import * @@ -31,3 +27,7 @@ from sequence_layers.jax.simple import * from sequence_layers.jax.time_varying import * from sequence_layers.jax.types import * + +# (re-export the names for typechecking) +from . 
import backend as backend +from . import types as types diff --git a/sequence_layers/jax/backend.py b/sequence_layers/jax/backend.py index 4efdc75..54d0b83 100644 --- a/sequence_layers/jax/backend.py +++ b/sequence_layers/jax/backend.py @@ -1,20 +1,24 @@ """Backend-specific helpers (JAX)""" +from typing import override + import jax.numpy as jnp from sequence_layers.specs import backend from sequence_layers.specs import types as types_spec -class BackendWrapper: +class BackendWrapper(backend.xp): """Thin wrapper around JAX to match NumPy interface for tests.""" bool_ = jnp.bool_ int32 = jnp.int32 + @override def array(self, a, dtype=None) -> types_spec.Array: return jnp.array(a, dtype=dtype) + @override def zeros(self, shape, dtype=None) -> types_spec.Array: return jnp.zeros(shape, dtype=dtype) diff --git a/sequence_layers/jax/test_utils.py b/sequence_layers/jax/test_utils.py index 567c375..13add8f 100644 --- a/sequence_layers/jax/test_utils.py +++ b/sequence_layers/jax/test_utils.py @@ -18,7 +18,9 @@ import itertools import logging import random -from typing import Any, Callable, Iterable, Mapping, Sequence as TypingSequence, TypeVar +from typing import Any, Callable, Iterable, Mapping +from typing import Sequence as TypingSequence +from typing import TypeVar from absl.testing import absltest from absl.testing import parameterized @@ -28,10 +30,10 @@ import jax.numpy as jnp import numpy as np -import sequence_layers.jax as sl from sequence_layers.jax import types from sequence_layers.jax import typing as jt from sequence_layers.jax import utils +import sequence_layers.jax as sl from sequence_layers.specs import test_utils as spec _SequenceLayerT = TypeVar('_SequenceLayerT', bound=types.SequenceLayer) diff --git a/sequence_layers/jax/types.py b/sequence_layers/jax/types.py index 86add8f..1c82b6c 100644 --- a/sequence_layers/jax/types.py +++ b/sequence_layers/jax/types.py @@ -20,7 +20,21 @@ import functools import math import typing -from typing import Any, Callable, 
Generic, Iterable, Literal, MutableMapping, ParamSpec, Protocol, Self, Sequence as TypingSequence, TypeVar, override +from typing import ( + Any, + Callable, + Concatenate, + Generic, + Iterable, + Literal, + MutableMapping, + override, + ParamSpec, + Protocol, + Self, +) +from typing import Sequence as TypingSequence +from typing import TypeVar from absl import logging from flax import linen as nn @@ -29,12 +43,11 @@ from jax import numpy as jnp import jaxtyping import numpy as np +import typeguard -from sequence_layers.specs import types as spec from sequence_layers.jax import sharding as sharding_lib from sequence_layers.jax import typing as jt -import typeguard - +from sequence_layers.specs import types as spec __all__ = ( # go/keep-sorted start @@ -254,8 +267,8 @@ def sequence_mask(lengths: LengthsT, maxlen: int) -> MaskT: ) -class Sequence[ValuesT, MaskT]( - spec.Sequence[ValuesT, MaskT], struct.PyTreeNode +class Sequence( + Generic[ValuesT, MaskT], spec.Sequence[ValuesT, MaskT], struct.PyTreeNode ): """A generic sequence container that preserves masking information.""" @@ -346,6 +359,7 @@ def dtype(self) -> DType: return self.values.dtype @classmethod + @override def from_lengths( cls, values: ValuesT, lengths: LengthsT, is_masked: bool = False ) -> 'Sequence': @@ -407,7 +421,7 @@ def apply_values( return Sequence(values_fn(self.values, *args, **kwargs), self.mask) @override - def apply_values_masked( + def apply_values_masked( # pyrefly: ignore[bad-override] self: SequenceSelf, values_fn: Callable[..., ValuesT], *args: ApplyValuesMaskedParams.args, @@ -428,7 +442,7 @@ def apply( return Sequence(values, mask) @override - def apply_masked( + def apply_masked( # pyrefly: ignore[bad-override] self: SequenceSelf, apply_fn: Callable[..., tuple[ValuesT, MaskT]], *args: ApplyMaskedParams.args, @@ -473,7 +487,7 @@ def __getitem__( ) @override - def pad_time( + def pad_time( # pyrefly: ignore[bad-override] self: SequenceSelf, pad_left: jt.ScalarInt, pad_right: 
jt.ScalarInt, @@ -539,7 +553,7 @@ def unmask(self) -> 'Sequence': return self -class MaskedSequence( +class MaskedSequence( # pyrefly: ignore[inconsistent-inheritance] Sequence[ValuesT, MaskT], spec.MaskedSequence[ValuesT, MaskT] ): """Sequence whose invalid timesteps are masked to zero.""" @@ -559,21 +573,19 @@ def unmask(self) -> Sequence: def mask_invalid( - sequence: Sequence, + self: Sequence, mask_value: complex | None = None, ) -> 'Sequence': """Returns a sequence whose invalid timesteps are replaced with mask_value.""" - expanded_mask = sequence.expanded_mask() + expanded_mask = self.expanded_mask() if mask_value is None: - masked_values = jnp.zeros_like(sequence.values) + masked_values = jnp.zeros_like(self.values) result_type = MaskedSequence else: - masked_values = jnp.full( - sequence.values.shape, mask_value, sequence.values.dtype - ) + masked_values = jnp.full(self.values.shape, mask_value, self.values.dtype) result_type = Sequence - masked_values = jnp.where(expanded_mask, sequence.values, masked_values) - return result_type(masked_values, sequence.mask) + masked_values = jnp.where(expanded_mask, self.values, masked_values) + return result_type(masked_values, self.mask) # Defined outside of Sequence so that mask_invalid can return a MaskedSequence. 
@@ -617,22 +629,22 @@ def _sequence_checker_fn(value, origin_type, args, memo): values_dtype, mask_dtype = args try: - typeguard.check_type_internal( + typeguard.check_type_internal( # pyrefly: ignore[missing-attribute] value.values, values_dtype, memo=memo, ) - except typeguard.TypeCheckError as exc: + except typeguard.TypeCheckError as exc: # pyrefly: ignore[missing-attribute] exc.append_path_element('values') raise try: - typeguard.check_type_internal( + typeguard.check_type_internal( # pyrefly: ignore[missing-attribute] value.mask, mask_dtype, memo=memo, ) - except typeguard.TypeCheckError as exc: + except typeguard.TypeCheckError as exc: # pyrefly: ignore[missing-attribute] exc.append_path_element('mask') raise @@ -675,7 +687,7 @@ def _add_custom_checker_lookup_fn(lookup_fn): _add_custom_checker_lookup_fn(_sequence_checker_lookup_fn) -class Steppable(spec.Steppable): +class Steppable(spec.Steppable[Sequence, Sequence, ChannelSpec]): """A sequence processing layer that can be executed layerwise or stepwise. # Step-wise execution: @@ -741,7 +753,7 @@ class Steppable(spec.Steppable): ``` """ - path: str # Provided by nn.Module. + path: tuple[str, ...] # Provided by nn.Module. 
@property @override @@ -945,6 +957,7 @@ def __call__( return self.layer(x, training=training, constants=constants) @abc.abstractmethod + @override def step( self, x: Sequence, @@ -1010,6 +1023,7 @@ def step_with_emits( return outputs, state, () @abc.abstractmethod + @override def get_initial_state( self, batch_size: int, @@ -1107,6 +1121,7 @@ def get_output_dtype( """Returns the layer's output dtype for the specified input dtype.""" @nn.nowrap + @override def get_output_spec( self, input_spec: ChannelSpec, @@ -1273,11 +1288,15 @@ def check_step_with_emits_fn( return check_step_with_emits_fn -class SequenceLayer(nn.Module, Steppable, spec.SequenceLayer): +class SequenceLayer( + nn.Module, Steppable, spec.SequenceLayer[Sequence, Sequence, ChannelSpec] +): """Base Module for Sequence Layers.""" -class PreservesType(SequenceLayer, spec.PreservesType): +class PreservesType( + SequenceLayer, spec.PreservesType[Sequence, Sequence, ChannelSpec] +): """A mix-in for layers that do not change the input dtype.""" @nn.nowrap @@ -1289,7 +1308,9 @@ def get_output_dtype( return input_dtype -class PreservesShape(SequenceLayer, spec.PreservesShape): +class PreservesShape( + SequenceLayer, spec.PreservesShape[Sequence, Sequence, ChannelSpec] +): """A mix-in for layers that do not change the input shape.""" @nn.nowrap @@ -1301,7 +1322,7 @@ def get_output_shape( return tuple(input_shape) -class Emitting(SequenceLayer, spec.Emitting): +class Emitting(SequenceLayer, spec.Emitting[Sequence, Sequence, ChannelSpec]): """A SequenceLayer that emits auxiliary arrays. This is a convenience subclass that implements step and layer in terms of @@ -1362,7 +1383,7 @@ def layer_with_emits( pass -class Stateless(SequenceLayer, spec.Stateless): +class Stateless(SequenceLayer, spec.Stateless[Sequence, Sequence, ChannelSpec]): """A SequenceLayer with no state over time required for step-wise processing. 
Sub-classes must only implement: @@ -1372,6 +1393,7 @@ class Stateless(SequenceLayer, spec.Stateless): """ @property + @override def receptive_field_per_step(self) -> dict[int, ReceptiveField]: return {0: (0, 0)} @@ -1427,7 +1449,9 @@ def layer( pass -class StatelessEmitting(Emitting, spec.StatelessEmitting): +class StatelessEmitting( + Emitting, spec.StatelessEmitting[Sequence, Sequence, ChannelSpec] +): """A SequenceLayer with no state over time that emits auxiliary arrays. Sub-classes must only implement: @@ -1437,6 +1461,7 @@ class StatelessEmitting(Emitting, spec.StatelessEmitting): """ @property + @override def receptive_field_per_step(self) -> dict[int, ReceptiveField]: return {0: (0, 0)} @@ -1454,6 +1479,7 @@ def step_with_emits( ) return outputs, state, emits + @override def get_initial_state( self, batch_size: int, @@ -1469,42 +1495,52 @@ def get_initial_state( return () @abc.abstractmethod + @override def get_output_shape( self, input_shape: ShapeLike, *, constants: Constants | None = None ) -> Shape: pass @abc.abstractmethod + @override def get_output_dtype( self, input_dtype: DType, *, constants: Constants | None = None ) -> DType: pass @abc.abstractmethod + @override def layer_with_emits( self, - x: Sequence, + x: Sequence[ValuesT, MaskT], *, training: bool, constants: Constants | None = None, - ) -> tuple[Sequence, Emits]: + ) -> tuple[Sequence[ValuesT, MaskT], Emits]: pass -class StatelessPointwise(PreservesShape, Stateless, spec.StatelessPointwise): +class StatelessPointwise( + PreservesShape, + Stateless, + spec.StatelessPointwise[Sequence, Sequence, ChannelSpec], +): """A SequenceLayer that has no state and operates pointwise on its input.""" class StatelessPointwiseFunctor( - StatelessPointwise, spec.StatelessPointwiseFunctor + StatelessPointwise, + spec.StatelessPointwiseFunctor[Sequence, Sequence, ChannelSpec], ): """A stateless SequenceLayer for simple pointwise processing fns.""" @abc.abstractmethod + @override def fn(self, values: ValuesT, 
mask: MaskT) -> tuple[ValuesT, MaskT]: """Transforms each scalar in values independently.""" @property + @override def mask_required(self): """Returns true if fn can change the sequence's masked state. @@ -1513,7 +1549,8 @@ def mask_required(self): return True @check_layer - def layer( + @override + def layer( # pyrefly: ignore[missing-override-decorator] self, x: Sequence, *, @@ -1542,9 +1579,11 @@ class SequenceLayerConfig(spec.SequenceLayerConfig): """ @abc.abstractmethod + @override def make(self) -> SequenceLayer: """Builds a SequenceLayer from this config.""" + @override def copy(self, **kwargs) -> Self: """Create a copy of this config. @@ -1576,6 +1615,7 @@ def copy(self, **kwargs) -> Self: ) try: + assert not isinstance(self, type), 'replace must be called on an instance' return typing.cast(Self, dataclasses.replace(self, **kwargs)) except TypeError as type_error: raise AttributeError( diff --git a/sequence_layers/jax/typing.py b/sequence_layers/jax/typing.py index 9b75781..a83a2ba 100644 --- a/sequence_layers/jax/typing.py +++ b/sequence_layers/jax/typing.py @@ -13,27 +13,37 @@ # limitations under the License. 
"""Wrappers for making jaxtyping easier to use and understand.""" +from typing import Callable, TYPE_CHECKING, TypeVar, Union + import jax import jax.numpy as jnp -from jaxtyping import AbstractDtype, Bool, config as jaxtyping_config, Float, Int, PyTree, Shaped, jaxtyped, TypeCheckError +from jaxtyping import AbstractDtype +from jaxtyping import Bool +from jaxtyping import config as jaxtyping_config +from jaxtyping import Float +from jaxtyping import Int +from jaxtyping import jaxtyped +from jaxtyping import PyTree +from jaxtyping import Shaped +from jaxtyping import TypeCheckError import numpy as np import typeguard -from typing import Callable, TypeVar, Union - - -class _MetaArrayT(type): - types = () - def __instancecheck__(cls, obj): - return isinstance(obj, cls.types) +if TYPE_CHECKING: + ArrayT = jax.Array | np.ndarray +else: + class _MetaArrayT(type): + types = () -class JaxArrayT(metaclass=_MetaArrayT): - types = (jax.Array, jax.ShapeDtypeStruct) + def __instancecheck__(cls, obj): + return isinstance(obj, cls.types) + class JaxArrayT(metaclass=_MetaArrayT): + types = (jax.Array, jax.ShapeDtypeStruct) -class ArrayT(metaclass=_MetaArrayT): - types = (JaxArrayT, np.ndarray) + class ArrayT(metaclass=_MetaArrayT): + types = (JaxArrayT, np.ndarray) Scalar = Shaped[ArrayT, ''] | Shaped[np.generic, ''] | Shaped[jnp.generic, ''] diff --git a/sequence_layers/jax/utils.py b/sequence_layers/jax/utils.py index 23b549d..c4e0b40 100644 --- a/sequence_layers/jax/utils.py +++ b/sequence_layers/jax/utils.py @@ -21,13 +21,16 @@ import pprint import re import typing -from typing import Any, Callable, Protocol, Self, Sequence as TypingSequence, TypeVar +from typing import Any, Callable, Protocol, Self +from typing import Sequence as TypingSequence +from typing import TypeVar import flax.core.scope import flax.linen as nn import jax import jax.numpy as jnp import numpy as np + from sequence_layers.jax import meta from sequence_layers.jax import types from sequence_layers.jax 
import typing as jt diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 58124eb..95f38af 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -13,8 +13,7 @@ # limitations under the License. """Sequence layers in MLX.""" -# (re-export the names for typechecking) -from . import backend as backend -from . import types as types - from sequence_layers.mlx.types import * + +from . import backend +from . import types diff --git a/sequence_layers/mlx/backend.py b/sequence_layers/mlx/backend.py index 4d96407..76531f6 100644 --- a/sequence_layers/mlx/backend.py +++ b/sequence_layers/mlx/backend.py @@ -1,20 +1,24 @@ """Backend-specific helpers (MLX)""" +from typing import override + import mlx.core as mx from sequence_layers.specs import backend from sequence_layers.specs import types as types_spec -class BackendWrapper: +class BackendWrapper(backend.xp): """Thin wrapper around MLX to match NumPy interface for tests.""" bool_ = mx.bool_ int32 = mx.int32 + @override def array(self, a, dtype=None) -> types_spec.Array: return mx.array(a, dtype=dtype) + @override def zeros(self, shape, dtype=None) -> types_spec.Array: return mx.zeros(shape, dtype=dtype) diff --git a/sequence_layers/mlx/test_utils.py b/sequence_layers/mlx/test_utils.py index 29cf520..9953b74 100644 --- a/sequence_layers/mlx/test_utils.py +++ b/sequence_layers/mlx/test_utils.py @@ -1,10 +1,12 @@ """Test utilities for MLX sequence layers.""" from typing import override -import numpy as np + import mlx.core as mx -import sequence_layers.mlx as sl +import numpy as np + from sequence_layers.mlx import types +import sequence_layers.mlx as sl from sequence_layers.specs import test_utils as spec @@ -26,7 +28,7 @@ def _mask_and_pad_to_max_length( class SequenceLayerTest(spec.SequenceLayerTest): """Base class for MLX SequenceLayer tests.""" - sl = sl + sl = sl # pyrefly: ignore[bad-assignment] # module-as-protocol @override def assertAllEqual(self, x, y): @@ 
-36,7 +38,9 @@ def assertAllEqual(self, x, y): np.testing.assert_array_equal(x_np, y_np) @override - def assertSequencesEqual(self, x: types.Sequence, y: types.Sequence): + def assertSequencesEqual( # pyrefly: ignore[bad-override] + self, x: types.Sequence, y: types.Sequence + ): """After padding, checks sequence values are equal and masks are equal.""" x, y = _mask_and_pad_to_max_length(x, y) self.assertAllEqual(x.values, y.values) diff --git a/sequence_layers/mlx/types.py b/sequence_layers/mlx/types.py index e7dda5d..80121d9 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -10,7 +10,6 @@ Any, Callable, cast, - Generic, Iterable, MutableMapping, override, @@ -201,7 +200,6 @@ def concatenate_sequences(cls, sequences: Iterable['Sequence']) -> 'Sequence': @override def expanded_mask(self) -> mx.array: """Returns the Sequence mask expanded to match values rank.""" - print(self, type(self), dir(self), self.mask, type(self.mask)) return self.mask.reshape(self.mask.shape + (1,) * (self.values.ndim - 2)) @override @@ -446,9 +444,7 @@ def wrapper(self, x, state, *, training: bool, constants=None): # --------------------------------------------------------------------------- -class Steppable[InputT: Sequence, OutputT: Sequence]( - spec.Steppable[InputT, OutputT, ChannelSpec] -): +class Steppable(spec.Steppable[Sequence, Sequence, ChannelSpec]): """A sequence processing layer that can be executed layerwise or stepwise. # Step-wise execution: @@ -571,8 +567,8 @@ def receptive_field(self) -> ReceptiveField: @abc.abstractmethod @override def layer( - self, x: InputT, *, training: bool, constants: Constants | None = None - ) -> OutputT: + self, x: Sequence, *, training: bool, constants: Constants | None = None + ) -> Sequence: """Process this layer layer-wise. 
Args: @@ -591,8 +587,8 @@ def layer( @override def layer_with_emits( - self, x: InputT, *, training: bool, constants: Constants | None = None - ) -> tuple[OutputT, Emits]: + self, x: Sequence, *, training: bool, constants: Constants | None = None + ) -> tuple[Sequence, Emits]: """Process this layer layer-wise, producing emitted arrays. This is like `layer`, except it has an additional return value which is the @@ -619,12 +615,12 @@ def layer_with_emits( @override def step( self, - x: InputT, + x: Sequence, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[OutputT, State]: + ) -> tuple[Sequence, State]: """Process this layer step-wise. Args: @@ -648,12 +644,12 @@ def step( @override def step_with_emits( self, - x: InputT, + x: Sequence, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[OutputT, State, Emits]: + ) -> tuple[Sequence, State, Emits]: """Process this layer step-wise, producing emitted arrays. This is like `step`, except it has an additional return value which is the @@ -778,10 +774,10 @@ def get_output_spec( # --------------------------------------------------------------------------- -class SequenceLayer[InputT: Sequence, OutputT: Sequence]( +class SequenceLayer( nn.Module, - Steppable[InputT, OutputT], - spec.SequenceLayer[InputT, OutputT, ChannelSpec], + Steppable, + spec.SequenceLayer[Sequence, Sequence, ChannelSpec], metaclass=abc.ABCMeta, ): """Base Module for Sequence Layers.""" @@ -806,9 +802,9 @@ def copy(self, **kwargs) -> Self: # --------------------------------------------------------------------------- -class PreservesType[InputT: Sequence, OutputT: Sequence]( - SequenceLayer[InputT, OutputT], - spec.PreservesType[InputT, OutputT, ChannelSpec], +class PreservesType( + SequenceLayer, + spec.PreservesType[Sequence, Sequence, ChannelSpec], metaclass=abc.ABCMeta, ): """A mix-in for layers that do not change the input dtype.""" @@ -824,9 +820,9 @@ def get_output_dtype( return 
input_dtype -class PreservesShape[InputT: Sequence, OutputT: Sequence]( - SequenceLayer[InputT, OutputT], - spec.PreservesShape[InputT, OutputT, ChannelSpec], +class PreservesShape( + SequenceLayer, + spec.PreservesShape[Sequence, Sequence, ChannelSpec], metaclass=abc.ABCMeta, ): """A mix-in for layers that do not change the input shape.""" @@ -847,9 +843,7 @@ def get_output_shape( # --------------------------------------------------------------------------- -class Stateless[InputT: Sequence, OutputT: Sequence]( - SequenceLayer[InputT, OutputT], spec.Stateless[InputT, OutputT, ChannelSpec] -): +class Stateless(SequenceLayer, spec.Stateless[Sequence, Sequence, ChannelSpec]): """A SequenceLayer with no state over time required for step-wise processing. Sub-classes must also implement: @@ -897,37 +891,37 @@ def get_output_dtype( @override def layer( self, - x: InputT, + x: Sequence, *, training: bool, constants: Constants | None = None, - ) -> OutputT: + ) -> Sequence: ... @override def step( self, - x: InputT, + x: Sequence, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[OutputT, State]: + ) -> tuple[Sequence, State]: return self.layer(x, training=training, constants=constants), state -class StatelessPointwise[InputT: Sequence, OutputT: Sequence]( - PreservesShape[InputT, OutputT], - Stateless[InputT, OutputT], - spec.StatelessPointwise[InputT, OutputT, ChannelSpec], +class StatelessPointwise( + PreservesShape, + Stateless, + spec.StatelessPointwise[Sequence, Sequence, ChannelSpec], metaclass=abc.ABCMeta, ): """A SequenceLayer that has no state and operates pointwise on its input.""" -class StatelessPointwiseFunctor[InputT: Sequence, OutputT: Sequence]( - StatelessPointwise[InputT, OutputT], - spec.StatelessPointwiseFunctor[InputT, OutputT, ChannelSpec], +class StatelessPointwiseFunctor( + StatelessPointwise, + spec.StatelessPointwiseFunctor[Sequence, Sequence, ChannelSpec], ): """A stateless SequenceLayer for simple pointwise 
processing fns.""" @@ -945,15 +939,15 @@ def mask_required(self): """ return True - @override @check_layer + @override def layer( # pyrefly: ignore[missing-override-decorator] self, - x: InputT, + x: Sequence, *, training: bool, constants: Constants | None = None, - ) -> OutputT: + ) -> Sequence: del training if self.mask_required: y = x.apply(self.fn) @@ -962,7 +956,7 @@ def layer( # pyrefly: ignore[missing-override-decorator] # Ensure MaskedSequence -> Sequence conversion for apply. if isinstance(y, MaskedSequence) and self.mask_required: y = Sequence(y.values, y.mask) - return cast(OutputT, y) + return cast(Sequence, y) # --------------------------------------------------------------------------- @@ -970,9 +964,9 @@ def layer( # pyrefly: ignore[missing-override-decorator] # --------------------------------------------------------------------------- -class Emitting[InputT: Sequence, OutputT: Sequence]( - SequenceLayer[InputT, OutputT], - spec.Emitting[InputT, OutputT, ChannelSpec], +class Emitting( + SequenceLayer, + spec.Emitting[Sequence, Sequence, ChannelSpec], ): """A SequenceLayer that emits auxiliary arrays. @@ -1019,34 +1013,34 @@ def get_output_dtype( @override def step_with_emits( self, - x: InputT, + x: Sequence, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[OutputT, State, Emits]: + ) -> tuple[Sequence, State, Emits]: ... @abc.abstractmethod @override def layer_with_emits( self, - x: InputT, + x: Sequence, *, training: bool, constants: Constants | None = None, - ) -> tuple[OutputT, Emits]: + ) -> tuple[Sequence, Emits]: ... 
@override def step( self, - x: InputT, + x: Sequence, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[OutputT, State]: + ) -> tuple[Sequence, State]: output, state, _ = self.step_with_emits( x, state, training=training, constants=constants ) @@ -1055,20 +1049,20 @@ def step( @override def layer( self, - x: InputT, + x: Sequence, *, training: bool, constants: Constants | None = None, - ) -> OutputT: + ) -> Sequence: outputs, _ = self.layer_with_emits( x, training=training, constants=constants ) return outputs -class StatelessEmitting[InputT: Sequence, OutputT: Sequence]( - Emitting[InputT, OutputT], - spec.StatelessEmitting[InputT, OutputT, ChannelSpec], +class StatelessEmitting( + Emitting, + spec.StatelessEmitting[Sequence, Sequence, ChannelSpec], ): """A SequenceLayer with no state over time that emits auxiliary arrays. @@ -1102,11 +1096,11 @@ def get_output_dtype( @override def layer_with_emits( self, - x: InputT, + x: Sequence, *, training: bool, constants: Constants | None = None, - ) -> tuple[OutputT, Emits]: + ) -> tuple[Sequence, Emits]: ... 
@override @@ -1127,12 +1121,12 @@ def get_initial_state( @override def step_with_emits( self, - x: InputT, + x: Sequence, state: State, *, training: bool, constants: Constants | None = None, - ) -> tuple[OutputT, State, Emits]: + ) -> tuple[Sequence, State, Emits]: outputs, emits = self.layer_with_emits( x, training=training, constants=constants ) diff --git a/sequence_layers/mlx/types_test.py b/sequence_layers/mlx/types_test.py index d6a903a..da6163c 100644 --- a/sequence_layers/mlx/types_test.py +++ b/sequence_layers/mlx/types_test.py @@ -54,6 +54,5 @@ class StatelessPointwiseFunctorTest( pass - if __name__ == '__main__': absltest.main() diff --git a/sequence_layers/specs/backend.py b/sequence_layers/specs/backend.py index bea6cd9..fb64595 100644 --- a/sequence_layers/specs/backend.py +++ b/sequence_layers/specs/backend.py @@ -4,10 +4,10 @@ from sequence_layers.specs import types as types_spec - Array = types_spec.Array +# pylint: disable=invalid-name class xp(Protocol): """NumPy-compatible interface to enable generic behavior tests. @@ -19,10 +19,10 @@ class xp(Protocol): int32: Any def array(self, a: Any, dtype: Any = None) -> Array: - ... + """Creates an array.""" def zeros(self, shape: tuple[int, ...], dtype: Any = None) -> Array: - ... + """Creates an array of zeros.""" @runtime_checkable @@ -31,7 +31,7 @@ class ModuleSpec(Protocol): @property def xp(self) -> xp: - ... 
+ """Returns the NumPy-compatible interface.""" __all__ = [ diff --git a/sequence_layers/specs/test_utils.py b/sequence_layers/specs/test_utils.py index 90596d3..6bd2ce4 100644 --- a/sequence_layers/specs/test_utils.py +++ b/sequence_layers/specs/test_utils.py @@ -2,7 +2,9 @@ import abc from typing import Any + from absl.testing import parameterized + from sequence_layers import specs from sequence_layers.specs import backend as backend_spec from sequence_layers.specs import types as spec diff --git a/sequence_layers/specs/types.py b/sequence_layers/specs/types.py index a8b207b..5d5e8be 100644 --- a/sequence_layers/specs/types.py +++ b/sequence_layers/specs/types.py @@ -10,24 +10,13 @@ import enum import fractions from types import EllipsisType -from typing import ( - Any, - Callable, - Concatenate, - Iterable, - Literal, - MutableMapping, - override, - Protocol, - runtime_checkable, - Self, - TypeVar, -) +from typing import (Any, Callable, Concatenate, Iterable, Literal, + MutableMapping, override, Protocol, runtime_checkable, Self, + TypeVar) import jaxtyping as jt import numpy.typing as npt -# NEW ArrayLike = npt.ArrayLike Array = jt.Shaped[Any, '...'] @@ -44,24 +33,24 @@ class ChannelSpec(Protocol): @property def shape(self) -> Shape: - ... + """The shape of the channel.""" @property def dtype(self) -> Any: - ... + """The dtype of the channel.""" State = Any Constants = MutableMapping[str, jt.PyTree[Array]] Emits = jt.PyTree[Array] -# TODO: Do these defaults do anything? 
apparently not + ValuesT = TypeVar('ValuesT', bound=Array) MaskT = TypeVar('MaskT', bound=Array) ChannelSpecT = TypeVar('ChannelSpecT', bound=ChannelSpec) LengthsT = TypeVar('LengthsT', bound=Array) -# SequenceT = TypeVar('SequenceT', bound='Sequence[Array, Array]', default='Sequence[Array, Array]') + InputT = TypeVar('InputT', bound='Sequence') OutputT = TypeVar('OutputT', bound='Sequence') @@ -194,27 +183,27 @@ def __init__(self, values: ValuesT, mask: MaskT): @property @abc.abstractmethod def shape(self) -> Shape: - ... + """The shape of the sequence as (batch, time, ...channels).""" @property @abc.abstractmethod def ndim(self) -> int: - ... + """The number of dimensions of the sequence values.""" @property @abc.abstractmethod def channel_shape(self) -> Shape: - ... + """The shape of the channels in the sequence.""" @property @abc.abstractmethod def dtype(self) -> DType: - ... + """The dtype of the sequence values.""" @classmethod @abc.abstractmethod def from_values(cls, values: ValuesT) -> Self: - ... + """Creates a Sequence from values with a default mask.""" @classmethod @abc.abstractmethod @@ -224,16 +213,16 @@ def from_lengths( lengths: LengthsT, is_masked: bool = False, ) -> Self: - ... + """Creates a Sequence from values and lengths.""" @classmethod @abc.abstractmethod def concatenate_sequences(cls, sequences: Iterable[Self]) -> Self: - ... + """Concatenates multiple sequences into one.""" @abc.abstractmethod def expanded_mask(self) -> Any: - ... + """Returns the mask expanded to the shape of the values.""" @abc.abstractmethod def apply_values[NewValuesT: Array, **P]( @@ -242,7 +231,7 @@ def apply_values[NewValuesT: Array, **P]( *args: P.args, **kwargs: P.kwargs, ) -> 'Sequence[NewValuesT, MaskT]': - ... 
+ """Applies a function to the sequence values.""" @abc.abstractmethod def apply_values_masked[NewValuesT: Array, **P]( @@ -251,7 +240,7 @@ def apply_values_masked[NewValuesT: Array, **P]( *args: P.args, **kwargs: P.kwargs, ) -> 'Sequence[NewValuesT, MaskT]': - ... + """Applies a function to the sequence values, respecting the mask.""" @abc.abstractmethod def apply[NewValuesT: Array, NewMaskT: Array, **P]( @@ -260,7 +249,7 @@ def apply[NewValuesT: Array, NewMaskT: Array, **P]( *args: P.args, **kwargs: P.kwargs, ) -> 'Sequence[NewValuesT, NewMaskT]': - ... + """Applies a function to both values and mask.""" @abc.abstractmethod def apply_masked[NewValuesT: Array, NewMaskT: Array, **P]( @@ -269,15 +258,15 @@ def apply_masked[NewValuesT: Array, NewMaskT: Array, **P]( *args: P.args, **kwargs: P.kwargs, ) -> 'Sequence[NewValuesT, NewMaskT]': - ... + """Applies a function to values and mask, respecting the mask.""" @abc.abstractmethod def astype(self, dtype: DType | None) -> Self: - ... + """Returns a copy of the sequence with a new dtype.""" @abc.abstractmethod def lengths(self) -> Any: - ... + """Returns the lengths of the sequences in the batch.""" @abc.abstractmethod def __getitem__( @@ -294,21 +283,21 @@ def pad_time( valid: bool, pad_value: Any | None = None, ) -> Self: - ... + """Pads the sequence along the time dimension.""" @abc.abstractmethod def concatenate(self, other: Self) -> Self: - ... + """Concatenates another sequence to this one.""" @abc.abstractmethod def mask_invalid( self, mask_value: Any | None = None ) -> 'Sequence[ValuesT, MaskT]': - ... + """Returns a MaskedSequence with invalid timesteps zeroed.""" @abc.abstractmethod def unmask(self) -> 'Sequence[ValuesT, MaskT]': - ... + """Returns a Sequence with no masking applied.""" class MaskedSequence(Sequence[ValuesT, MaskT]): @@ -362,35 +351,35 @@ class Steppable[ @property @abc.abstractmethod def block_size(self) -> int: - ... 
+ """The block size this layer processes at once.""" @property @abc.abstractmethod def output_ratio(self) -> fractions.Fraction: - ... + """The ratio of output timesteps to input timesteps.""" @property @abc.abstractmethod def supports_step(self) -> bool: - ... + """Returns True if the layer supports stepwise processing.""" @property @abc.abstractmethod def input_latency(self) -> int: - ... + """The number of future timesteps required for the current output.""" @property @abc.abstractmethod def output_latency(self) -> int: - ... + """The number of timesteps the output is delayed.""" @abc.abstractmethod def get_accumulated_input_latency(self, input_latency: int) -> int: - ... + """Calculates the total input latency including previous layers.""" @abc.abstractmethod def get_accumulated_output_latency(self, output_latency: int) -> int: - ... + """Calculates the total output latency including previous layers.""" @abc.abstractmethod def layer( @@ -593,7 +582,7 @@ def get_output_spec( @property @abc.abstractmethod def receptive_field(self) -> Any: - ... + """The receptive field of the layer.""" class SequenceLayer[ @@ -910,51 +899,51 @@ class ModuleSpec(Protocol): @property def Sequence(self) -> type[Sequence]: - ... + """The Sequence class for this backend.""" @property def MaskedSequence(self) -> type[MaskedSequence]: - ... + """The MaskedSequence class for this backend.""" @property def SequenceLayer(self) -> type[SequenceLayer]: - ... + """The SequenceLayer class for this backend.""" @property def SequenceLayerConfig(self) -> type[SequenceLayerConfig]: - ... + """The SequenceLayerConfig class for this backend.""" @property def Steppable(self) -> type[Steppable]: - ... + """The Steppable class for this backend.""" @property def PreservesShape(self) -> type[PreservesShape]: - ... + """The PreservesShape class for this backend.""" @property def Stateless(self) -> type[Stateless]: - ... 
+ """The Stateless class for this backend.""" @property def StatelessPointwise(self) -> type[StatelessPointwise]: - ... + """The StatelessPointwise class for this backend.""" @property def StatelessPointwiseFunctor(self) -> type[StatelessPointwiseFunctor]: - ... + """The StatelessPointwiseFunctor class for this backend.""" @property def PreservesType(self) -> type[PreservesType]: - ... + """The PreservesType class for this backend.""" @property def Emitting(self) -> type[Emitting]: - ... + """The Emitting class for this backend.""" @property def StatelessEmitting(self) -> type[StatelessEmitting]: - ... + """The StatelessEmitting class for this backend.""" __all__ = [ From 2269159c1541156a48bfbe23bdaa767eddcaab48 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 10 Apr 2026 10:27:50 -0700 Subject: [PATCH 9/9] chore: Enforce spec import naming conventions --- sequence_layers/jax/backend.py | 6 +- sequence_layers/mlx/backend.py | 6 +- sequence_layers/specs/test_utils.py | 6 +- sequence_layers/specs/types_behaviors.py | 91 +++++++++++++----------- 4 files changed, 58 insertions(+), 51 deletions(-) diff --git a/sequence_layers/jax/backend.py b/sequence_layers/jax/backend.py index 54d0b83..f29d4e1 100644 --- a/sequence_layers/jax/backend.py +++ b/sequence_layers/jax/backend.py @@ -4,11 +4,11 @@ import jax.numpy as jnp -from sequence_layers.specs import backend +from sequence_layers.specs import backend as spec from sequence_layers.specs import types as types_spec -class BackendWrapper(backend.xp): +class BackendWrapper(spec.xp): """Thin wrapper around JAX to match NumPy interface for tests.""" bool_ = jnp.bool_ @@ -23,4 +23,4 @@ def zeros(self, shape, dtype=None) -> types_spec.Array: return jnp.zeros(shape, dtype=dtype) -xp: backend.xp = BackendWrapper() +xp: spec.xp = BackendWrapper() diff --git a/sequence_layers/mlx/backend.py b/sequence_layers/mlx/backend.py index 76531f6..f73847a 100644 --- a/sequence_layers/mlx/backend.py +++ b/sequence_layers/mlx/backend.py 
@@ -4,11 +4,11 @@ import mlx.core as mx -from sequence_layers.specs import backend +from sequence_layers.specs import backend as spec from sequence_layers.specs import types as types_spec -class BackendWrapper(backend.xp): +class BackendWrapper(spec.xp): """Thin wrapper around MLX to match NumPy interface for tests.""" bool_ = mx.bool_ @@ -23,4 +23,4 @@ def zeros(self, shape, dtype=None) -> types_spec.Array: return mx.zeros(shape, dtype=dtype) -xp: backend.xp = BackendWrapper() +xp: spec.xp = BackendWrapper() diff --git a/sequence_layers/specs/test_utils.py b/sequence_layers/specs/test_utils.py index 6bd2ce4..26b1d8d 100644 --- a/sequence_layers/specs/test_utils.py +++ b/sequence_layers/specs/test_utils.py @@ -7,14 +7,14 @@ from sequence_layers import specs from sequence_layers.specs import backend as backend_spec -from sequence_layers.specs import types as spec +from sequence_layers.specs import types as types_spec class _AbcParameterizedTestCaseMeta(abc.ABCMeta, type(parameterized.TestCase)): - pass + """Metaclass for abstract parameterized test cases.""" -class SequenceLayerTest[SequenceT: spec.Sequence = spec.Sequence]( +class SequenceLayerTest[SequenceT: types_spec.Sequence = types_spec.Sequence]( parameterized.TestCase, metaclass=_AbcParameterizedTestCaseMeta ): """Base test class providing common sequence testing assertions. 
diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py index 844f3e7..a1bb773 100644 --- a/sequence_layers/specs/types_behaviors.py +++ b/sequence_layers/specs/types_behaviors.py @@ -9,18 +9,18 @@ from absl.testing import parameterized import numpy as np -from sequence_layers.specs import types as spec +from sequence_layers.specs import types as types_spec from sequence_layers.specs.test_utils import SequenceLayerTest class DummyChannelSpec(NamedTuple): """Dummy channel spec for testing.""" - shape: spec.Shape - dtype: spec.DType + shape: types_spec.Shape + dtype: types_spec.DType -class DefaultTestLayer(spec.SequenceLayer): +class DefaultTestLayer(types_spec.SequenceLayer): """A default test layer for testing.""" @property @@ -64,21 +64,21 @@ def get_accumulated_output_latency(self, output_latency: int) -> int: @override def layer( self, - x: spec.Sequence, + x: types_spec.Sequence, *, training: bool, - constants: spec.Constants | None = None, - ) -> spec.Sequence: + constants: types_spec.Constants | None = None, + ) -> types_spec.Sequence: return x @override def layer_with_emits( self, - x: spec.Sequence, + x: types_spec.Sequence, *, training: bool, - constants: spec.Constants | None = None, - ) -> tuple[spec.Sequence, spec.Emits]: + constants: types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.Emits]: return self.layer(x, training=training, constants=constants), ( 'test_emits', ) @@ -86,23 +86,23 @@ def layer_with_emits( @override def step( self, - x: spec.Sequence, - state: spec.State, + x: types_spec.Sequence, + state: types_spec.State, *, training: bool, - constants: spec.Constants | None = None, - ) -> tuple[spec.Sequence, spec.State]: + constants: types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.State]: return x, ('new_test_state',) @override def step_with_emits( self, - x: spec.Sequence, - state: spec.State, + x: types_spec.Sequence, + state: 
types_spec.State, *, training: bool, - constants: spec.Constants | None = None, - ) -> tuple[spec.Sequence, spec.State, spec.Emits]: + constants: types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.State, types_spec.Emits]: return *self.step(x, state, training=training, constants=constants), ( 'test_emits', ) @@ -111,26 +111,29 @@ def step_with_emits( def get_initial_state( self, batch_size: int, - input_spec: spec.ChannelSpec, + input_spec: types_spec.ChannelSpec, *, training: bool, - constants: spec.Constants | None = None, - ) -> spec.State: + constants: types_spec.Constants | None = None, + ) -> types_spec.State: return ('test_state',) @override def get_output_shape( self, - input_shape: spec.ShapeLike, + input_shape: types_spec.ShapeLike, *, - constants: spec.Constants | None = None, - ) -> spec.Shape: + constants: types_spec.Constants | None = None, + ) -> types_spec.Shape: return tuple(input_shape) + (1,) @override def get_output_dtype( - self, input_dtype: spec.DType, *, constants: spec.Constants | None = None - ) -> spec.DType: + self, + input_dtype: types_spec.DType, + *, + constants: types_spec.Constants | None = None, + ) -> types_spec.DType: return np.float64 @override @@ -138,7 +141,7 @@ def get_output_spec( self, input_spec: Any, *, - constants: spec.Constants | None = None, + constants: types_spec.Constants | None = None, ) -> Any: shape = self.get_output_shape(input_spec.shape, constants=constants) dtype = self.get_output_dtype(input_spec.dtype, constants=constants) @@ -148,7 +151,7 @@ def get_output_spec( class ModuleInterfaceTest(SequenceLayerTest): def test_backend_specific_module_has_interface(self) -> None: - self.assertIsInstance(self.sl.types, spec.ModuleSpec) + self.assertIsInstance(self.sl.types, types_spec.ModuleSpec) class SequenceTest(SequenceLayerTest): @@ -216,8 +219,8 @@ def test_pad_time(self) -> None: self.assertAllEqual(y.mask, x_left1.mask) def _create_test_sequence( - self, shape: spec.Shape - ) -> 
spec.Sequence[spec.Array, spec.Array]: + self, shape: types_spec.Shape + ) -> types_spec.Sequence[types_spec.Array, types_spec.Array]: """Creates a test sequence with specific shape.""" size = 1 for d in shape: @@ -376,7 +379,7 @@ def test_from_lengths(self) -> None: class SteppableTest(SequenceLayerTest): - def create_steppable(self) -> spec.Steppable: + def create_steppable(self) -> types_spec.Steppable: """Creates a basic Steppable instance.""" backend_sl = self.sl @@ -412,7 +415,7 @@ def test_get_output_spec(self) -> None: self.assertEqual(output_spec.shape, (2, 3, 1)) self.assertEqual(output_spec.dtype, np.float64) - def create_sequence(self) -> spec.Sequence: + def create_sequence(self) -> types_spec.Sequence: """Creates a test sequence.""" return self.sl.Sequence( self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) @@ -506,7 +509,7 @@ def make(self) -> Any: class PreservesTypeTest(SequenceLayerTest): - def create_layer(self) -> spec.PreservesType: + def create_layer(self) -> types_spec.PreservesType: """Creates a preserves type layer.""" backend_sl = self.sl @@ -528,7 +531,7 @@ def test_preserves_dtype(self) -> None: class PreservesShapeTest(SequenceLayerTest): - def create_layer(self) -> spec.PreservesShape: + def create_layer(self) -> types_spec.PreservesShape: """Creates a preserves shape layer.""" backend_sl = self.sl @@ -550,13 +553,13 @@ def test_preserves_shape(self) -> None: class StatelessTest(SequenceLayerTest): - def create_sequence(self) -> spec.Sequence: + def create_sequence(self) -> types_spec.Sequence: """Creates a default test sequence.""" return self.sl.Sequence( self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) ) - def create_layer(self) -> spec.Stateless: + def create_layer(self) -> types_spec.Stateless: """Creates a stateless layer.""" backend_sl = self.sl @@ -603,13 +606,13 @@ def test_stateless_behaviors(self) -> None: class EmittingTest(SequenceLayerTest): - def create_sequence(self) -> 
spec.Sequence: + def create_sequence(self) -> types_spec.Sequence: """Creates a default test sequence.""" return self.sl.Sequence( self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) ) - def create_layer(self) -> spec.Emitting: + def create_layer(self) -> types_spec.Emitting: """Creates an emitting layer.""" backend_sl = self.sl @@ -649,13 +652,13 @@ def test_emitting_drops_emits_on_standard_calls(self) -> None: class StatelessEmittingTest(SequenceLayerTest): - def create_sequence(self) -> spec.Sequence: + def create_sequence(self) -> types_spec.Sequence: """Creates a default test sequence.""" return self.sl.Sequence( self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_) ) - def create_layer(self) -> spec.SequenceLayer: + def create_layer(self) -> types_spec.SequenceLayer: """Creates a stateless emitting layer.""" backend_sl = self.sl @@ -701,7 +704,9 @@ def test_stateless_emitting_behaviors(self) -> None: class StatelessPointwiseFunctorTest(SequenceLayerTest): - def create_layer(self, is_mask_required: bool) -> spec.SequenceLayer[Any]: + def create_layer( + self, is_mask_required: bool + ) -> types_spec.SequenceLayer[Any]: """Creates a stateless pointwise functor layer.""" backend_sl = self.sl @@ -736,7 +741,9 @@ def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: return DummyLayer() - def create_sequence(self) -> spec.Sequence[spec.Array, spec.Array]: + def create_sequence( + self, + ) -> types_spec.Sequence[types_spec.Array, types_spec.Array]: """Creates a test sequence.""" return self.sl.Sequence( self.xp.zeros((2, 3, 5)), self.xp.zeros((2, 3), dtype=self.xp.bool_)