Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

This change only adds tests and documentation; I will improve error reporting separately.

For TPU, the errors come from either the Pallas lowering or Mosaic. I plan to raise a lowering exception for all unsupported cases, so that a better Python stack trace is available.

For GPU, we get a RET_CHECK instead of a Python exception, so I had to add a skipTest. I will fix the error message separately.

To be able to put the test in pallas_test::PallasCallTest, I moved the skipTest for TPU from setUp to the individual tests that need it.

PiperOrigin-RevId: 653195289
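
The move of the TPU skip from setUp into individual tests can be sketched as follows. This is a minimal illustration using only the standard unittest module, not the actual test code from this change: the class name PallasCallTestSketch, the test names, and the current_backend() helper are all hypothetical stand-ins.

```python
import unittest


def current_backend() -> str:
    # Hypothetical stand-in for a backend query; not a JAX or Pallas API.
    return "tpu"


class PallasCallTestSketch(unittest.TestCase):
    # There is deliberately no backend skip in setUp: a skip there would
    # apply to every test in the class, including tests that work on TPU.

    def test_backend_independent(self):
        # Runs on every backend because nothing skips it.
        self.assertEqual(1 + 1, 2)

    def test_not_yet_supported_on_tpu(self):
        # The skip lives in the one test that needs it.
        if current_backend() == "tpu":
            self.skipTest("Not supported on TPU")
        self.assertTrue(True)


def run_sketch() -> unittest.TestResult:
    # Run the sketch class and return the collected result.
    loader = unittest.defaultTestLoader
    suite = loader.loadTestsFromTestCase(PallasCallTestSketch)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

With per-test skips, a new test added to the same class runs on TPU by default and only opts out when it explicitly calls skipTest.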