Shape polymorphism is now usable independently of jax2tf, so its tests should also be independent of jax2tf. I started by forking jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py. I then removed from the new file the tests and helper functions that do not make sense outside of jax2tf. For now we keep the existing tests in jax2tf, because some of them exercise other code paths.

While adding these tests we found two bugs. They are fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515.

Because these tests now run on GitHub and Kokoro, they also revealed a couple of bugs in the tests themselves. We fix those here, both in jax2tf/tests/shape_poly_test.py and in its copy, tests/shape_poly_test.py.

PiperOrigin-RevId: 583816243