Deep cfr jax refactor #1408
Conversation
At first glance, I like the t/2T in your JAX implementation! I think this is what the original paper proposed. Thank you! Will follow up more.
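For readers unfamiliar with the coefficient being discussed, below is a minimal, hypothetical sketch of what an iteration-weighted regression loss with a t/(2T) coefficient could look like in JAX. The function and argument names (`weighted_advantage_loss`, `apply_fn`, `iterations`, `total_iters`) are illustrative and not the names used in this PR.

```python
import jax.numpy as jnp


def weighted_advantage_loss(params, apply_fn, infostates, targets, iterations, total_iters):
  """Hypothetical sketch: per-sample squared error weighted by t / (2 * T).

  `iterations` holds the CFR iteration t at which each sample was collected,
  and `total_iters` is the current total number of iterations T.
  """
  preds = apply_fn(params, infostates)            # predicted advantages, shape (batch, num_actions)
  per_sample = jnp.mean((preds - targets) ** 2, axis=-1)
  weights = iterations / (2.0 * total_iters)      # the t/2T coefficient under discussion
  return jnp.mean(weights * per_sample)
```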
@fuyuan-li, I found out that the only difference is the masking used in the tf2 and jax implementations -- I am not really sure if it's good, but have kept it for consistency.
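As a rough illustration of the masking being referred to (not the PR's actual code), one could restrict the regression loss to legal actions as sketched below; `legal_actions_mask` is a hypothetical 0/1 mask over actions.

```python
import jax.numpy as jnp


def masked_mse(preds, targets, legal_actions_mask):
  """Sketch: only legal actions contribute to the squared-error loss."""
  sq_err = jnp.where(legal_actions_mask > 0, (preds - targets) ** 2, 0.0)
  # Normalise by the number of legal actions so samples with few legal
  # actions are not down-weighted.
  num_legal = jnp.maximum(jnp.sum(legal_actions_mask, axis=-1), 1)
  return jnp.mean(jnp.sum(sq_err, axis=-1) / num_legal)
```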
@fuyuan-li, I added it -- it can harm performance, but it can help with tracking the progress.
Quick update @alexunderch: the refactored torch impl does not converge like the original impl, using the same hyperparameter set as in #1287. Will dig in and keep you posted!
@fuyuan-li, I erroneously had an odd ReLU in the network. Sorry!
@lanctot, both versions now converge for Kuhn; see #1406. However, without additional code improvements (e.g. around jitting the buffer, which currently causes constant recompilation because its size changes), the running time stays high. When the results for Leduc are in, let me know if you want to merge and I will clean up some stuff.
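For context on the recompilation point: a common workaround (sketched below under assumed names and shapes; this is not the code in this PR) is to keep the buffer arrays at a fixed capacity and track how many rows are valid, so jitted functions always see static shapes and are compiled only once.

```python
import jax
import jax.numpy as jnp


def sample_batch(rng, data, num_valid, batch_size):
  """Sketch: sample from a fixed-capacity buffer without retracing.

  `data` has a static shape (capacity, feature_dim); only the first
  `num_valid` rows contain real samples, so array shapes never change
  and jit compiles once per `batch_size`.
  """
  idx = jax.random.randint(rng, (batch_size,), 0, num_valid)
  return data[idx]


sample_batch = jax.jit(sample_batch, static_argnames="batch_size")
```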
@alexunderch, probably a late update -- yes, confirmed too: both impls converged in Kuhn poker! Exploitability dropped to 0.05 and the policy value converged to the theoretical value (-0.06 for player 0), tested with several random seeds (for both the torch and jax versions). The only reason I didn't run a simulation over 30+ seeds is that jax is running super slow -- 70 minutes for 1 simulation on Kuhn poker. Another update -- the Leduc simulation (in the pytorch impl) is running.
Very happy to continue work on the jax impl's performance issue too (otherwise I doubt simulations over multiple seeds are feasible). Do you think it's a good time (i.e., the working code is ready?) to start this new thread? Let me know!
@fuyuan-li, thank you for your testing. Good that the results are reproducible. The newest commit should lower the computation time by about a factor of two (on my Mac, at least). Further improvement, I think, would require some additional modifications. I reckon it's up to Dr. Lanctot whether he is okay with the slower but consistent-looking implementation, or whether we need comparable running times. We can have all the discussion in this thread, if you feel comfortable.
@fuyuan-li, check it out now. It's partly because of jax-ml/jax#16587.
Hi guys, great work on this... I'm super impressed to see the community collaboration here!
No strong preferences here; I'm mostly happy to see that we can retain these implementations thanks to both of you working on this. I'm OK with slightly incompatible running times if one of them is just faster / more efficient. How cryptic is it? I will still want it to be readable. So as long as you have enough comments explaining any non-obvious code, I think it's OK to have a more efficient version that is slightly inconsistent with the other one.
The PyTorch implementation hasn't really changed in terms of readability. If @fuyuan-li soon reports that their testing is fine, I can clean up the tf and reference implementations and add a couple of comments, and we should be good to go. P.S. If we continue with refactoring, I will replace the networks and buffers with the corresponding utility imports.
Thank you @alexunderch and @lanctot
@fuyuan-li, thank you for your updates. Yes, I will update it locally. When everything is confirmed, we can merge. Just out of interest, can you compare GPU performance? Because, as I mentioned, of the array copying,
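If it helps with the GPU comparison, here is a generic timing sketch (a hypothetical helper, not part of this PR): because JAX dispatches work asynchronously, blocking on the last result ensures the measurement includes device execution and any host-device copies.

```python
import time

import jax


def time_fn(fn, *args, repeats=10):
  """Rough timing sketch for a jitted function on CPU or GPU."""
  jax.block_until_ready(fn(*args))   # warm-up call: trigger compilation first
  start = time.perf_counter()
  out = None
  for _ in range(repeats):
    out = fn(*args)
  jax.block_until_ready(out)         # wait for asynchronous dispatch to finish
  return (time.perf_counter() - start) / repeats
```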
@lanctot, I merged as you told me to, but there is that scipy installation error again (for ARM...).
Can you try removing
Nvm... that won't do anything. It seems to be failing during the installation; I'm not sure what's going on. Let me trigger a custom wheels test on master... done: https://github.com/google-deepmind/open_spiel/actions/runs/20957953304 (see if this one fails for the same reason or not).
Wait, it appears that you and @visheshrwl are running into the same problem (he's working on #1426). So my guess is the ARM wheels are broken due to the scipy upgrade a few days ago. I expect my wheels test to fail. Can you try commenting out the Linux arm64 tests in wheels.yml?
I think it's a cool solution, but let's try a different thing beforehand:
Sorry, I don't have any more time to work on this. I will disable them for now and open a bug. Feel free to work on it independently. We can't leave master in a state where the CI doesn't work.
If it doesn't work, I'll revert the change and just disable the wheels.
There was a conflict because of the PR you've just merged; fixed it. Technically, it doesn't change anything. Updated with your latest comments.
Wait, @lanctot, you (we) use manylinux2014, which is based on CentOS 7. The EOL of the distribution was 2024, no? pypa/cibuildwheel#1772 -- maybe we should update it? When I ran with
Yeah, extended the disabling: #1443 |
Yes, we should do that, but it will likely raise a few other issues that we should deal with separately. I don't think it's the cause of the current issue.
Merging #1443 here?
Quoting Google's AI:
Note that there are conflicts that need to be resolved; I can't rerun the tests until they get resolved.
I feel really sorry, but I found out that I can't run the tests just in my fork and not bother you as much... I think I made the tests for Linux (Ubuntu and ARM) pass. Will send you the jobs; maybe you'd want to revert.
It's not a problem... let's just do it separately :)
Replaced by #1445 |
Hey! It's not the main PR, so it might be deleted later, but to summarise:
- Replaced `Sonnet`'s initialisation with native `Pytorch`'s.
- Implementations with jax and torch had different structures; now they're the same --- i.e. a common `MLP` parametrisation and `LayerNorm` after base layers.
- `jax` implementation using `flax.nnx` to match the implementations of torch and tf.
- Removed the `tf` and `tensorflow_datasets` stuff to simplify the code.

However,
- `torch` and `jax` differ with the loss coefficients (in `torch`/`tf` it's a different weighting than the `t/2T` used in `jax`).
- The `jax` implementation used to use masking; however, as said in the torch implementation, it's not as valuable because in traversal only legal actions are selected.
- The `jax` implementation is initialised with pytorch parameters to secure reproducibility (a small sketch of this and of the MLP structure follows below).

The goal is to make sure that both implementations converge to close values at the same time -- it'll allow us to delete the tf implementation.
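To make the last points concrete, here is a minimal sketch (with assumed layer sizes and names, not the PR's actual classes) of a `flax.nnx` MLP with `LayerNorm` after the base layer, and of copying parameters from a PyTorch `nn.Linear` into it. Note the kernel transpose: PyTorch stores linear weights as (out, in), while `nnx.Linear` stores (in, out).

```python
import jax.numpy as jnp
import torch
from flax import nnx


class MLP(nnx.Module):
  """Sketch of a shared MLP parametrisation with LayerNorm after the base layer."""

  def __init__(self, in_dim, hidden_dim, out_dim, *, rngs: nnx.Rngs):
    self.base = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
    self.norm = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.head = nnx.Linear(hidden_dim, out_dim, rngs=rngs)

  def __call__(self, x):
    x = jnp.maximum(self.norm(self.base(x)), 0.0)  # LayerNorm, then ReLU
    return self.head(x)


# Initialising the jax network from a torch layer for reproducibility (sketch only).
model = MLP(16, 64, 3, rngs=nnx.Rngs(0))
torch_base = torch.nn.Linear(16, 64)
model.base.kernel.value = jnp.asarray(torch_base.weight.detach().numpy().T)  # (out, in) -> (in, out)
model.base.bias.value = jnp.asarray(torch_base.bias.detach().numpy())
```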