
Conversation

@alexunderch
Contributor

alexunderch commented Dec 14, 2025

Hey! This isn't the main PR, so it might be deleted later, but to summarise:

  1. Replaced Sonnet's initialisation with native PyTorch's. The JAX and PyTorch implementations had different network structures; now they share a common MLP parametrisation with LayerNorm after the base layers (a sketch follows right after this list).
  2. Rewrote the JAX implementation using flax.nnx to match the PyTorch and TF implementations.
  3. Deleted all the TF and tensorflow_datasets code to simplify things.
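
A minimal sketch of what such a shared parametrisation could look like (illustrative only, not the exact code in this PR; the layer sizes and the exact placement of LayerNorm are my assumptions):

```python
# Illustrative only: one reading of "a common MLP parametrisation with
# LayerNorm after the base layers" -- each hidden (base) layer is followed by
# LayerNorm and ReLU, with a plain linear output layer on top.
import torch.nn as nn


class MLP(nn.Module):

  def __init__(self, input_size, hidden_sizes, output_size):
    super().__init__()
    layers = []
    sizes = [input_size] + list(hidden_sizes)
    for in_size, out_size in zip(sizes[:-1], sizes[1:]):
      layers += [nn.Linear(in_size, out_size), nn.LayerNorm(out_size), nn.ReLU()]
    layers.append(nn.Linear(sizes[-1], output_size))
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    return self.net(x)
```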

However,

  1. The PyTorch and JAX versions differ in their loss coefficients: in PyTorch/TF it's $\sqrt{t}$, whereas in JAX it's $\frac{t}{2T}$.
  2. The JAX implementation used to apply action masking; however, as noted in the PyTorch implementation, it adds little value because only legal actions are selected during traversal.
  3. The JAX implementation is initialised with the PyTorch parameters to ensure reproducibility.

The goal is to make sure that both implementations converge to similar values at comparable rates -- that would allow us to delete the TF implementation.

@fuyuan-li
Contributor

At first glance, I like your t/2T in the JAX implementation! I think this is what the original paper proposed.

Thank you! I'll follow up with more.

@alexunderch
Contributor Author

alexunderch commented Dec 15, 2025

@fuyuan-li, I found out that $\frac{t}{2T}$ is specific to Linear CFR, which shouldn't work with non-linear MLPs, so I decided to stick with $\sqrt{t}$ as in the original TF implementation.
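
To make the weighting concrete, here is a minimal sketch of an iteration-weighted loss (illustrative only; the function name and tensor shapes are my assumptions, not the PR's code):

```python
# Each sample carries the CFR iteration t at which it was collected; the
# per-sample squared error is weighted by sqrt(t), as in the original TF
# implementation. Using t / (2 * T) instead would correspond to the Linear
# CFR weighting discussed above.
import torch


def weighted_advantage_loss(predictions, targets, iterations):
  weights = torch.sqrt(iterations.float())
  per_sample_error = ((predictions - targets) ** 2).mean(dim=-1)
  return (weights * per_sample_error).mean()
```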

The only remaining difference is the masking used in the TF2 and JAX implementations -- I'm not really sure it's beneficial, but I've kept it for consistency.

@alexunderch
Contributor Author

@fuyuan-li, I added a print_nash_convs argument to DeepCFRSolver to print the exploitability at each optimisation step.

It can hurt performance, but it helps with tracking progress.
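
Usage could look roughly like this (a sketch: the constructor arguments other than print_nash_convs are assumptions about the solver's signature, not taken from this PR):

```python
# Rough usage sketch; argument names besides `print_nash_convs` are assumed.
import pyspiel
from open_spiel.python.pytorch import deep_cfr

game = pyspiel.load_game("kuhn_poker")
solver = deep_cfr.DeepCFRSolver(
    game,
    num_iterations=100,
    num_traversals=40,
    print_nash_convs=True,  # print exploitability at each optimisation step
)
solver.solve()
```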

@fuyuan-li
Contributor

Quick update @alexunderch: the refactored PyTorch implementation does not converge like the original one, using the same hyperparameter set as in #1287.

Will dig in and keep you posted!

@alexunderch
Contributor Author

@fuyuan-li, I erroneously had a stray ReLU in the network. Sorry!
Both implementations should converge now.

@alexunderch
Contributor Author

@lanctot, both versions now converge on Kuhn; see #1406.

However, without additional code improvements (e.g. jitting the buffer, which currently causes constant recompilation because its size keeps changing), the JAX implementation is roughly 25 times slower than the PyTorch one. Maybe we should tackle that as a separate issue.
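
To illustrate the recompilation point (a standalone toy example, not the PR's buffer code): jax.jit specialises a function to the shapes of its inputs, so a buffer that grows between calls triggers a fresh trace and compile every time.

```python
# The traced print fires once per distinct input shape, so a buffer that
# changes size on every call is recompiled on every call.
import jax
import jax.numpy as jnp


@jax.jit
def buffer_mean(buffer):
  print("tracing for shape", buffer.shape)
  return buffer.mean()


for size in (1, 2, 3):
  buffer_mean(jnp.ones((size,)))  # three shapes -> three compilations
```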

When the Leduc results are in, let me know if you want to merge, and I'll clean some things up.

@fuyuan-li
Contributor

@alexunderch, probably a late update -- yes, confirmed here too: both implementations converged on Kuhn poker! Exploitability drops to 0.05 and the policy value converges to the theoretical value (-0.06 for player 0), tested with several random seeds (for both the PyTorch and JAX versions).

The only reason I didn't run simulations over 30+ seeds is that JAX is running very slowly -- 70 minutes for one simulation on Kuhn poker.

Another update -- a Leduc simulation (PyTorch implementation) is running:

  • On the one result so far, exploitability drops to 0.6 and the policy value goes to -0.12 for player 0.

Very happy to continue working on the JAX implementation's performance issue too (otherwise I doubt simulations over multiple seeds are feasible). Do you think it's a good time (i.e., is the working code ready?) to start that new thread? Let me know!

@alexunderch
Contributor Author

@fuyuan-li, thank you for the testing. Good that the results are reproducible. The newest commit should roughly halve the computation time (on my Mac, at least). Further improvement, I think, will require some additional modifications to the ReservoirBuffer.

I reckon it's up to Dr. Lanctot whether he is okay with a slower but consistent-looking implementation, or whether he thinks we need comparable running times.

We can keep all the discussion in this thread, if that works for you.

@alexunderch
Contributor Author

@fuyuan-li, check it out now. The JAX version should be only 4-6 times slower than PyTorch.

That's partly because of jax-ml/jax#16587 (in the append_to_reservior function) and because I allocate the whole buffer right away.
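
Roughly, the pre-allocation idea looks like this (a sketch with made-up names, not the code in this PR): the storage is created at full capacity up front so jitted code always sees the same shapes, and updates go through functional .at[].set indexing (roughly the place where the overhead mentioned above would show up).

```python
# Sketch only: fixed-capacity storage allocated up front, with standard
# reservoir sampling on top. The names init_reservoir / add_to_reservoir are
# illustrative, not the identifiers used in the PR.
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("capacity", "element_size"))
def init_reservoir(capacity, element_size):
  # Full-size buffer from the start, so shapes never change under jit.
  return jnp.zeros((capacity, element_size)), jnp.zeros((), dtype=jnp.int32)


@partial(jax.jit, static_argnames=("capacity",))
def add_to_reservoir(data, num_seen, element, rng, capacity):
  # Standard reservoir sampling: once full, keep the new element with
  # probability capacity / (num_seen + 1).
  idx = jax.random.randint(rng, (), 0, num_seen + 1)
  write_idx = jnp.where(num_seen < capacity, num_seen, jnp.minimum(idx, capacity - 1))
  should_write = (num_seen < capacity) | (idx < capacity)
  new_row = jnp.where(should_write, element, data[write_idx])
  # Functional update; roughly where the copy overhead discussed above bites.
  return data.at[write_idx].set(new_row), num_seen + 1
```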

@lanctot
Collaborator

lanctot commented Dec 18, 2025

Hi guys, great work on this. I'm super impressed to see the community collaboration here!

@lanctot
Collaborator

lanctot commented Dec 18, 2025

> I reckon it's up to Dr. Lanctot whether he is okay with a slower but consistent-looking implementation, or whether he thinks we need comparable running times.

No strong preferences here, I'm mostly happy to see that we can retain these implementations thanks to both of you working on this.

I'm OK with somewhat different running times if one of them is just faster / more efficient. How cryptic is it? I still want it to be readable. So as long as you have enough comments explaining any non-obvious code, I think it's OK to have a more efficient version that is slightly inconsistent with the other one.

@alexunderch
Contributor Author

alexunderch commented Dec 18, 2025

The PyTorch implementation hasn't really changed in terms of readability.
In the JAX implementation, I replaced the buffer with a set of functions and made the training loop jittable. It should still be readable.
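
For illustration, a jittable training step in this style could look like the following (a generic sketch assuming an optax optimiser and a loss that is a pure function of a parameter pytree; it is not the PR's actual code):

```python
# Everything the step needs is passed in as arrays/pytrees, so jax.jit
# compiles it once and reuses it every iteration.
import jax
import optax


def make_train_step(loss_fn, optimizer):

  @jax.jit
  def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

  return train_step
```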

If @fuyuan-li soon reports that their testing is fine, I can clean up the TF and reference implementations, add a couple of comments, and we should be good to go.

P.S. If we continue with refactoring, I will replace the networks and buffers with the corresponding utility imports.

@fuyuan-li
Contributor

fuyuan-li commented Dec 19, 2025

Thank you @alexunderch and @lanctot
A few quick comments:

  1. Convergence and consistency on Kuhn in both PyTorch and JAX are confirmed.
  2. Convergence on Leduc in PyTorch, based on 40 simulations: exploitability comes down to 0.67 on average, and the policy value (for player 0) arrives at -0.14 (std 0.01) across the 40 simulations. (I consider that convergence.)
  3. The Leduc convergence run in JAX is in progress. Re @lanctot ("How cryptic is it?"): about 65 minutes per simulation with the default parameters. (By "default hyperparameters" I mean the hyperparameter sets that were confirmed to converge for PyTorch and Kuhn.)
  4. @alexunderch: a small glitch in your JAX implementation: for the jit decorator on init_reservoir(), do you want to change @jax.jit(static_argnames=("capacity",)) to @partial(jax.jit, static_argnames=("capacity",))? I updated it locally to get the simulation running, but I don't think I need to submit a PR for this -- I'll defer to you to update the branch HEAD to keep things simple.
  5. Given the running time on Leduc in JAX, I expect to get results this weekend (it's running now), but feel free to go ahead if we are all happy with the Kuhn convergence results. I'll come back to log the simulation results when they're ready (regardless of whether this PR is already merged).

@alexunderch
Contributor Author

@fuyuan-li, thank you for the updates. Yes, I will update it on my end. Once everything is confirmed, we can merge.

Just out of interest, could you compare GPU performance? As I mentioned, because of the array copying, JAX may be slower than NumPy on CPU; on paper it should be much faster on GPU. No code modifications needed, just install the CUDA versions...

@alexunderch
Contributor Author

alexunderch commented Jan 13, 2026

@lanctot, I merged as you told me to, but there is that scipy installation error again (for ARM...).

@lanctot
Collaborator

lanctot commented Jan 13, 2026

> @lanctot, I merged as you told me to, but there is that scipy installation error again (for ARM...).

Can you try removing scipy from python_extra_deps.sh? All instances of it. Since it's in requirements.txt, we shouldn't need it in python_extra_deps.sh.

@lanctot
Collaborator

lanctot commented Jan 13, 2026

> @lanctot, I merged as you told me to, but there is that scipy installation error again (for ARM...).
>
> Can you try removing scipy from python_extra_deps.sh? All instances of it. Since it's in requirements.txt, we shouldn't need it in python_extra_deps.sh.

Never mind... that won't do anything. It seems to be failing during the installation via requirements.txt.

I'm not sure what's going on.

Let me trigger a custom wheels test on master... done: https://github.com/google-deepmind/open_spiel/actions/runs/20957953304 (let's see whether this one fails for the same reason or not).

@lanctot
Collaborator

lanctot commented Jan 13, 2026

Wait, it appears that you and @visheshrwl are running into the same problem (he's working on #1426). So my guess is that the ARM wheels are broken due to the scipy upgrade a few days ago. I expect my wheels test to fail. Can you try commenting out the Linux arm64 tests in wheels.yml?

lanctot mentioned this pull request Jan 13, 2026
@alexunderch
Contributor Author

alexunderch commented Jan 13, 2026

I think it's a cool solution, but let's try a different thing first: OpenBLAS.
It's marked as a solution for both Windows and Linux in scipy/scipy#21562.

@lanctot
Collaborator

lanctot commented Jan 13, 2026

@lanctot
Collaborator

lanctot commented Jan 13, 2026

> I think it's a cool solution, but let's try a different thing first: OpenBLAS. It's marked as a solution for both Windows and Linux in scipy/scipy#21562.

Sorry, I don't have any more time to work on this. I will disable them for now and open a bug. Feel free to work on it independently. We can't leave master in a state where the CI doesn't work.

@alexunderch
Contributor Author

If it doesn't work, I'll revert the change and just disable the wheels.

@alexunderch
Contributor Author

alexunderch commented Jan 13, 2026

There was a conflict because of the PR you've just merged; I fixed it. Technically, it doesn't change anything.

I've also updated it with your latest comments.

@alexunderch
Contributor Author

It's not only ARM: https://github.com/google-deepmind/open_spiel/actions/runs/20958630257/job/60230374482

@alexunderch
Contributor Author

Wait @lanctot,

we use manylinux2014, which is based on CentOS 7.

The EOL of that distribution was around 2024, no? pypa/cibuildwheel#1772

Maybe we should update it? When I ran with manylinux_2_28, that wheel seemed to pass.

@lanctot
Collaborator

lanctot commented Jan 13, 2026

> It's not only ARM: https://github.com/google-deepmind/open_spiel/actions/runs/20958630257/job/60230374482

Yeah, extended the disabling: #1443

@lanctot
Collaborator

lanctot commented Jan 13, 2026

> Wait @lanctot,
>
> we use manylinux2014, which is based on CentOS 7.
>
> The EOL of that distribution was around 2024, no? pypa/cibuildwheel#1772
>
> Maybe we should update it? When I ran with manylinux_2_28, that wheel seemed to pass.

Yes, we should do that, but it will likely raise a few other issues that we should deal with separately. I don't think it's the cause of the current failure.

@alexunderch
Contributor Author

Should we merge #1443 here?

@alexunderch
Contributor Author

Quoting Google's AI:

> The Crash: Compiling SciPy 1.17.0 from source requires C++17 headers and a modern version of glibc to handle the new ARPACK/PROPACK C-conversions. manylinux2014 is based on CentOS 7 (glibc 2.17), which is too old. The build fails during the C++ compilation of SciPy, long before it ever gets to your "disabled" tests.

@lanctot
Collaborator

lanctot commented Jan 13, 2026

Note that there are conflicts that need to be resolved; I can't rerun the tests until they are resolved.

@alexunderch
Contributor Author

I'm really sorry, but I found out that I can't just run the tests in my own fork without bothering you...

I think I've made the Linux tests (Ubuntu and ARM) pass. I'll send you the jobs; maybe you'd want to revert patch-81 and patch-82, because they won't be needed.

@lanctot
Collaborator

lanctot commented Jan 13, 2026

It's not a problem... let's just do it separately :)

@lanctot
Collaborator

lanctot commented Jan 13, 2026

Replaced by #1445

lanctot closed this Jan 13, 2026