Add MLX (Apple Silicon) implementation of BDH with training and tests #3
+1,494
−0
Summary
This PR contributes a faithful MLX port of the BDH architecture for Apple Silicon. It mirrors the original PyTorch implementation’s math and behavior (byte-level vocab, shared weights, sparse ReLU activations, RoPE attention with Q=K, non-affine LayerNorm), packaged with training and testing scripts. The goal is to make BDH easy to train on Mac M-series devices with unified memory and Metal acceleration.
Motivation
What’s included
Design and equivalence notes
Parameter sharing: same encoder, decoder, and encoder_v reused across layers
Sparse activations: ReLU on latent projections to enforce non-negativity
Attention: RoPE applied with Q=K, causal mask with diagonal=-1 behavior
LayerNorm: non-affine, consistent with original
Output head: same shape, same initialization scale (normal std=0.02)
Differences are API-level only:
Verified shape parity and operation order at each step; generation (top-k) is also equivalent
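To make the design notes above concrete, here is a minimal, framework-free sketch of three of the listed behaviors: ReLU sparsification, non-affine LayerNorm, and a strict causal mask (the "diagonal=-1" behavior) applied to scores computed with shared queries and keys. It is not code from this PR. The function names are hypothetical, RoPE is omitted for brevity, the epsilon value is a placeholder, and applying the mask multiplicatively to raw dot-product scores is an assumption made for illustration.

```python
import math


def relu(xs):
    """Clamp negatives to zero so the latent activations are non-negative."""
    return [max(0.0, x) for x in xs]


def layer_norm_no_affine(xs, eps=1e-5):
    """LayerNorm without learned scale or bias (eps is an assumed value)."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / math.sqrt(var + eps) for x in xs]


def strict_causal_mask(T):
    """1.0 where j < i, else 0.0: a lower triangle shifted by diagonal=-1.

    Position i sees strictly earlier positions only, never itself.
    """
    return [[1.0 if j < i else 0.0 for j in range(T)] for i in range(T)]


def masked_scores_shared_qk(qk):
    """Dot-product scores where the same vectors act as queries and keys (Q=K).

    Assumption: the mask is applied by multiplication; RoPE is left out.
    """
    T = len(qk)
    mask = strict_causal_mask(T)
    return [
        [sum(a * b for a, b in zip(qk[i], qk[j])) * mask[i][j] for j in range(T)]
        for i in range(T)
    ]


if __name__ == "__main__":
    x = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
    for row in masked_scores_shared_qk(x):
        print(row)
```

With Q=K, the unmasked diagonal would hold each token's squared norm and tend to dominate its row, which is one plausible reason for excluding it; the first position consequently has nothing to attend to and its score row is all zeros.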
Performance (Apple Silicon, indicative)
How to use
Status and availability
I’m currently training the MLX build on an Internal Knowledge Map dataset; weights should be ready in about a day. I’ll publish checkpoints to Hugging Face once completed.
A standalone repo is also available here for reference: https://github.com/severian42/BDH-MLX
Impact
Questions for maintainers
Checklist
References