diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index 5fcbea288..425773faa 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -301,3 +301,149 @@ pip install --force-reinstall https://github.com/bitsandbytes-foundation/bitsand
```

#### NVIDIA Jetson (sm_87) — source build required

NVIDIA Jetson Orin-family devices (Orin Nano, Orin NX, AGX Orin) report CUDA
compute capability **sm_87**, which is not included in the prebuilt aarch64
wheels (see the supported-arch matrix above). Installing the PyPI aarch64 wheel
on Jetson Orin succeeds at import but fails at the first CUDA kernel launch:

```
RuntimeError: Error named symbol not found at line 233 in file /src/csrc/ops.cu
```

Build from source against the JetPack CUDA toolchain:

```bash
git clone --depth 1 --branch 0.46.1 \
  https://github.com/bitsandbytes-foundation/bitsandbytes.git
cd bitsandbytes
PATH=/usr/local/cuda-12.6/bin:$PATH \
  cmake -B build . -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=87
PATH=/usr/local/cuda-12.6/bin:$PATH \
  cmake --build build -j4
pip install .
```

**JetPack version → CUDA version mapping:** use the correct CUDA toolchain path for your JetPack release.

| JetPack release | CUDA version | Toolchain path | Notes |
|---|---|---|---|
| 6.0 / 6.1 | 12.2 | `/usr/local/cuda-12.2/bin` | Ships with NVIDIA JetPack 6.0/6.1 images. |
| 6.2 | 12.6 | `/usr/local/cuda-12.6/bin` | Current release at the time of writing; all numbers in this PR were measured here. |

Substitute the matching `cuda-<version>/bin` path in the `PATH=...` prefix. The `-DCOMPUTE_CAPABILITY=87` flag is required on every Jetson Orin variant regardless of JetPack release (all Orin-family silicon is sm_87).

The build takes ~6 minutes on an Orin Nano Super. The resulting build produces
`bitsandbytes/libbitsandbytes_cuda126.so` linked against JetPack CUDA 12.6,
and 4-bit quantize/dequantize ops run cleanly.
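The JetPack → CUDA mapping above can be sketched as a small shell helper that detects which toolchain is installed before running the build. This is an illustrative snippet, not part of bitsandbytes; `detect_cuda_bin` is a hypothetical name, and it assumes the standard JetPack layout under `/usr/local`.

```bash
# Hypothetical helper: locate the JetPack CUDA toolchain for the build.
# Checks the CUDA versions shipped by JetPack 6.x, newest first.
detect_cuda_bin() {
  root="${1:-/usr/local}"            # optional override of the install root
  for v in 12.6 12.2; do
    if [ -x "$root/cuda-$v/bin/nvcc" ]; then
      echo "$root/cuda-$v/bin"
      return 0
    fi
  done
  return 1
}

if CUDA_BIN=$(detect_cuda_bin); then
  echo "using toolchain: $CUDA_BIN"  # prepend this to PATH for the cmake steps
else
  echo "no JetPack CUDA toolchain found under /usr/local" >&2
fi
```

With the helper, the build prefix becomes `PATH=$CUDA_BIN:$PATH`, so the same recipe works on JetPack 6.0/6.1 and 6.2 without editing the path by hand.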

**Verify the install picked up sm_87:**

```bash
python -m bitsandbytes
```

Expected output on a correctly built Jetson install:

```
=================== bitsandbytes v0.46.1 ===================
Platform: Linux-5.15.148-tegra-aarch64-with-glibc2.35
PyTorch: 2.5.0a0+872d972e41.nv24.08
  CUDA: 12.6
PyTorch settings found: CUDA_VERSION=126, Highest Compute Capability: (8, 7).
Checking that the library is importable and CUDA is callable...
SUCCESS!
```

The key line is `Highest Compute Capability: (8, 7)` — this confirms the
source-built library has sm_87 in its kernel set. If this reports a different
capability or the `SUCCESS!` line is missing, the build did not target sm_87
and 4-bit ops will still fail at first launch.

**Known limitation — paged optimizers do not work on Jetson Orin.**

Standard 8-bit optimizers (`bnb.optim.AdamW8bit`, `Adam8bit`, `Lion8bit`,
etc.) run correctly after a source build with `-DCOMPUTE_CAPABILITY=87`.
However, the paged variants (`paged_adamw_8bit`, `paged_adamw_32bit`,
`PagedAdamW8bit`, `PagedAdam8bit`, and their counterparts) instantiate
without error but fail during training. This was observed on the NVIDIA
Jetson AGX Orin Developer Kit (64 GB, sm_87) per the Hackster.io tutorial
"Fine-Tuning LLMs using NVIDIA Jetson AGX Orin" (2024), and the same
underlying GPU-memory management path is used on all Jetson Orin-family
silicon. Users on Jetson should stay on the non-paged 8-bit variants
(e.g., `optim="adamw_bnb_8bit"` in `transformers.TrainingArguments` or
`SFTConfig`). The non-paged path already provides the ~75% optimizer-state
memory reduction; paging is a safety net for memory spikes, and Jetson's
unified-memory model does not benefit from it in the same way discrete
GPUs do.
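As a guard for training configs, the paged → non-paged substitution can be sketched as a tiny helper. This is a hypothetical convenience function, not part of bitsandbytes or transformers; the string values are standard `optim=` names accepted by `transformers.TrainingArguments`.

```python
# Hypothetical helper (not part of bitsandbytes/transformers): rewrite a
# paged optimizer name to a non-paged equivalent for Jetson Orin configs.
# Keys and values are standard `TrainingArguments(optim=...)` strings.
PAGED_TO_NONPAGED = {
    "paged_adamw_8bit": "adamw_bnb_8bit",  # 8-bit state, no paging
    "paged_adamw_32bit": "adamw_torch",    # plain fp32 AdamW
    "paged_lion_8bit": "lion_8bit",
    "paged_lion_32bit": "lion_32bit",
}

def jetson_safe_optim(optim: str) -> str:
    """Return a non-paged optimizer name; pass non-paged names through."""
    return PAGED_TO_NONPAGED.get(optim, optim)
```

For example, `TrainingArguments(optim=jetson_safe_optim(cfg.optim), ...)` lets a desktop config be reused on Jetson without hand-editing.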

**Verified environments:**

| Device | JetPack | CUDA | Python | Build time | Status |
|---|---|---|---|---|---|
| Jetson Orin Nano Super | 6.2 | 12.6 | 3.10 | ~6 min | Working (0.46.1) |

Other Jetson Orin devices (Orin NX, AGX Orin) are also sm_87 and are expected
to work with the same recipe; please add entries above if you verify one.

**Correctness:** the source-built wheel has been validated beyond merely
running without crashing. On a 16-problem held-out logic-reasoning benchmark
(Carroll-16), three independent lines of evidence support that the sm_87
kernels are numerically correct, not just stable:

- **4-bit NF4 base inference at 1B (TinyLlama 1.1B)** scored within 1/16
  problems of the same model's Ollama Q4_K_M reference — quantization
  preserves base reasoning at the smallest model scale tested.
- **Same-stack 4-bit QLoRA vs bf16 LoRA training** (identical hyperparameters,
  seed, and data; only `load_in_4bit` differs) produced training losses within
  0.4% (0.2963 vs 0.2951) and downstream benchmark scores within
  single-problem noise — 4-bit training is behaviorally equivalent to bf16 at
  a matched config. Peak training memory for the 4-bit run was 54% lower.
- **4-bit NF4 base inference at 3B (Qwen2.5-3B-Instruct)** scored 93.75%
  keyword (15/16) and 0.418 judge composite on Carroll-16 — the highest of any
  non-reasoning-model configuration tested. This confirms the kernels produce
  correct outputs across model families and scales, not just for a single
  architecture.

Both validations are documented in
`elemental/carroll/docs/experiments/TASK_UNSLOTH_4BIT_VALIDATION_2026_04_21.md`
and
`elemental/carroll/docs/experiments/TASK_UNSLOTH_V2_QUANT_VS_BF16_ADAPTER_2026_04_21.md`.
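As a sanity check on the training-loss comparison above, the quoted "within 0.4%" figure is the relative gap between the two final losses, taking the bf16 run as the reference:

```python
# Relative gap between the 4-bit QLoRA and bf16 LoRA final training losses
# quoted above (0.2963 vs 0.2951), with bf16 as the reference value.
loss_4bit = 0.2963
loss_bf16 = 0.2951

rel_gap = abs(loss_4bit - loss_bf16) / loss_bf16
print(f"relative loss gap: {rel_gap:.1%}")  # prints "relative loss gap: 0.4%"
```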

**Memory envelope reference (measured on Jetson Orin Nano Super, 8 GB
unified, batch=1, seq=1024):**

| Model | Inference peak | QLoRA training peak |
|---|---:|---:|
| TinyLlama 1.1B, 4-bit NF4 | 1.05 GB | 1.22 GB |
| Qwen2.5-3B-Instruct, 4-bit NF4 | 2.43 GB | 3.69 GB |
| Llama-3.1-Nemotron-Nano-4B, 4-bit NF4 | 3.84 GB | 4.43 GB |

Training peaks assume attention-only LoRA (r=16, α=32) and standard AdamW
in fp32. Using `adamw_bnb_8bit` reduces optimizer state materially at 3B+
scale, where the trainable-parameter count becomes a significant fraction of
the peak (it saves ~360 MB at 3B/r=32 all-modules and ~2.4 GB at 7B/r=64). At
r=16 attention-only on 1B, the 8-bit optimizer's saving is ≤24 MB — below the
noise floor of peak-memory measurement.

**Known limitation — profiling 3B+ inference can trigger a device reboot.**

`torch.profiler` with `with_stack=True + profile_memory=True + record_shapes=True`
on 3B QLoRA inference at seq=1024 causes memory pressure that can trigger
the Jetson's safety shutdown and reboot the device, wiping `/tmp`. This was
observed once in testing with `active=10` steps of profile recording on a
Qwen2.5-3B-Instruct Carroll-16 eval. Root cause: the profiler's internal
buffers (stack frames, op shape records, memory events) grow into GB-class
resident state that competes with the model for the device's 7.4 GB of
unified memory.

Safer profiling configurations on the Orin Nano Super:
- Drop `with_stack=True` and `profile_memory=True` when profiling 3B+
  inference; keep the lighter-weight op-timing-only config.
- Use a short active window (`active=2` or `3` in the profiler `schedule`)
  instead of `active=10` — a shorter window bounds the resident profiler
  state.
- Profile training (which on Jetson is already memory-bound, but not as tight
  as concurrent profiling plus inference) before profiling inference.
- On the Orin AGX (64 GB), the full-detail profile config works without this
  risk; it is specific to the 8 GB Orin Nano Super memory envelope.
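The safer configuration above can be sketched as follows. This is an illustrative snippet, not from the bitsandbytes codebase; `model` and `batch` are placeholders, and the step counts reflect the short-window guidance in the list above.

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Short active window bounds the profiler's resident state (see above).
light_schedule = schedule(wait=1, warmup=1, active=2)

def run_profiled(model, batch, steps=4):
    # Op-timing-only config for 8 GB Jetson devices: with_stack,
    # profile_memory, and record_shapes are deliberately left off,
    # since their buffers compete with the model for unified memory.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=light_schedule,
    ) as prof:
        with torch.no_grad():
            for _ in range(steps):
                model(batch)
                prof.step()
    return prof
```

`prof.key_averages().table(sort_by="cuda_time_total")` then gives per-op timings without the GB-class stack and memory buffers of the full-detail config.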