DeepHAMX is a JAX-accelerated reimplementation of DeepHAM: a global solution method for heterogeneous-agent models with aggregate shocks (Han, Yang, and E). This repository keeps the same high-level algorithm and configuration style as the original PyTorch code, while focusing on faster execution through JAX, jax.numpy, Haiku, and Optax.
Scope. The training entry point in this fork targets the Krusell–Smith (KS) model. Additional JSON configs under `srcx/configs/` may be carried over from upstream experiments; only the KS training script is wired for JAX in the current layout.
Original paper authors: Jiequn Han, Yucheng Yang, and Weinan E.
Upstream repository: github.com/frankhan91/DeepHAM
| Path | Role |
|---|---|
| `srcx/` | JAX implementation: training, value networks, policy, datasets, utilities |
| `srcx/configs/KS/` | Example Krusell–Smith experiment configs (JSON) |
| `srcx/train_KS.py` | Main training script (absl flags) |
| `data/` | Model inputs (for example `.mat` policy matrices); run outputs go under `data/simul_results/` |
| `environment.yml` | Conda environment specification |
| `scripts/smoke_check.py` | Quick JAX / import / shape sanity check (no full training run) |
The KS configs reference MATLAB-derived assets such as `data/KS_policy_N50_v1.mat`. If those files are not in your clone, obtain the corresponding data from the original DeepHAM repository or your own preprocessing pipeline, and align `mats_path` in the JSON config with your local paths.
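As a quick sanity check on the data layout, here is a minimal sketch that loads the asset a config points at. It assumes `mats_path` sits at the top level of the JSON (adjust the lookup to the actual config structure), and the printed variable names are whatever the `.mat` file happens to contain:

```python
import json
from pathlib import Path
from scipy.io import loadmat

# Assumption: "mats_path" is a top-level key in the KS config JSON;
# adapt the lookup if the real config nests it differently.
config = json.loads(Path("srcx/configs/KS/game_nn_n50.json").read_text())
mat_file = Path(config["mats_path"])
assert mat_file.exists(), f"missing MATLAB asset: {mat_file}"

mat = loadmat(mat_file)
# List the stored variables (skipping MATLAB's __header__ etc.)
print(sorted(k for k in mat if not k.startswith("__")))
```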
Tested setup: Python 3.11, JAX built for CUDA 12 (GPU). CPU-only JAX installs are possible but are not the configuration used for the benchmarks below.
Create the environment from the provided spec:

```bash
conda env create -f environment.yml
conda activate deephamx
```

Or set it up manually:

```bash
conda create -n deephamx python=3.11
conda activate deephamx
pip install "jax[cuda12]" dm-haiku optax absl-py tqdm scipy
```

Install the JAX variant that matches your platform (`jax[cuda12]` vs CPU wheels) following the official JAX installation guide.
Note. CUDA 12 is strongly recommended if you use NVIDIA GPUs. Matching JAX wheels to an older CUDA stack can be brittle; the authors tested primarily on CUDA 12.
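To confirm which backend your JAX install actually uses, a quick check (plain JAX, not a repository script) is:

```python
import jax

# On a working CUDA 12 GPU install this lists CUDA devices,
# e.g. [CudaDevice(id=0)]; a CPU-only wheel reports CpuDevice instead.
print(jax.devices())
print(jax.default_backend())  # "gpu" on GPU builds, "cpu" otherwise
```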
After activating your JAX environment, run from the repository root:

```bash
python3 scripts/smoke_check.py
```

This script imports `srcx` modules, runs a short `simul_shocks` / `next_wealth` path, and performs one Haiku MLP forward pass. It mirrors `train_KS.py` by turning on `jax_enable_x64` when `srcx/param.py` has `DTYPE = "float64"`, and in that case it also asserts float64 dtypes on the sampled tensors so silent float32 truncation is caught early.
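For readers new to Haiku, the forward pass the smoke check performs has roughly this shape; the layer sizes and input shapes below are illustrative, not the ones used in `srcx`:

```python
import haiku as hk
import jax
import jax.numpy as jnp

# Illustrative MLP; the real network definitions live in srcx.
def forward(x):
    return hk.nets.MLP([32, 32, 1])(x)

model = hk.transform(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones((8, 4))                 # dummy batch: 8 samples, 4 features
params = model.init(rng, x)          # build parameters from a sample input
y = model.apply(params, rng, x)
print(y.shape, y.dtype)              # (8, 1); float64 if jax_enable_x64 is on
```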
Run from the `srcx` directory so relative config and data paths resolve as in the defaults:

```bash
cd srcx
python3 train_KS.py
```

`train_KS.py` uses `absl.flags`:
| Flag | Short | Default | Meaning |
|---|---|---|---|
| `--config_path` | `-c` | `./configs/KS/game_nn_n50.json` | Path to the experiment JSON |
| `--exp_name` | `-n` | `test` | Suffix for the output run directory name |
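For orientation, flags like these are typically declared as follows with absl; this is a sketch, and the real definitions in `train_KS.py` may differ in detail:

```python
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string("config_path", "./configs/KS/game_nn_n50.json",
                    "Path to the experiment JSON", short_name="c")
flags.DEFINE_string("exp_name", "test",
                    "Suffix for the output run directory name", short_name="n")

def main(argv):
    # Training would start here; we just echo the parsed flags.
    print(FLAGS.config_path, FLAGS.exp_name)

if __name__ == "__main__":
    app.run(main)
```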
Example:
```bash
cd srcx
python3 train_KS.py -c ./configs/KS/game_nn_n10.json -n my_run
```

Checkpoints and logs are written under `data/simul_results/KS/` according to policy type, sampling mode, number of agents, and `exp_name` (see `train_KS.py` for the exact naming pattern).
The global dtype for the KS run is set in `srcx/param.py` (`DTYPE = "float64"` or `"float32"`). For float64, the script enables `jax_enable_x64` automatically.
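A minimal sketch of that switch (the actual logic lives in `train_KS.py`; `DTYPE` here stands in for the value read from `srcx/param.py`):

```python
import jax
import jax.numpy as jnp

DTYPE = "float64"  # stand-in for the value in srcx/param.py

# Enable 64-bit mode before creating arrays; without the flag,
# jax.numpy arrays silently default to float32.
if DTYPE == "float64":
    jax.config.update("jax_enable_x64", True)

x = jnp.zeros(3)
assert x.dtype == jnp.float64, "x64 mode not active"
```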
Benchmarks were run on NVIDIA GeForce RTX 3090 nodes (2× AMD EPYC 7H12, 64 cores each). Speedups come from computational improvements (JAX/XLA, batching, etc.); algorithm parameters and the overall solving procedure were not intentionally changed relative to the reference implementation. You can cross-check numerical behavior using validation utilities from the original DeepHAM repository.
| Precision | Implementation | CUDA streams | Wall-clock time | Validation utility (U) | Final aggregate capital (K) |
|---|---|---|---|---|---|
| FP32 | JAX | 1 | 17 min | 104.035 | 39.237 |
| FP32 | JAX | 2 | 17 min | 104.096 | 38.329 |
| FP64 | JAX | 1 | 41 min | 103.728 | 38.380 |
| FP64 | JAX | 2 | 34 min | 103.707 | 39.361 |
| FP32 | Original (ORG) | 1 | 57 min | 103.702 | 39.014 |
| FP64 | Original (ORG) | 1 | 65 min | 104.126 | 39.326 |
In the FP32, single-stream setting above, the JAX implementation is roughly 3.4× faster than the original PyTorch run (17 min vs. 57 min).
If you use DeepHAM in research, please cite the original paper:
```bibtex
@article{HanYangE2021deepham,
  title={Deep{HAM}: A global solution method for heterogeneous agent models with aggregate shocks},
  author={Han, Jiequn and Yang, Yucheng and E, Weinan},
  journal={arXiv preprint arXiv:2112.14377},
  year={2021}
}
```

If this JAX port is useful in your work, a star on the repository is appreciated.
This project is distributed under the GNU Lesser General Public License v2.1 — see LICENSE. Respect upstream licensing and attribution when reusing or redistributing code and data.
- DeepHAM (method and theory): jiequnhan@gmail.com, yucheng.yang@uzh.ch
- DeepHAMX (JAX port): wang2021@stu.pku.edu.cn