mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 15:36:07 +00:00

Add 32-bit lowering rule for `lax.erf_inv` for Pallas GPU, and move the original TPU test case into the general test

PiperOrigin-RevId: 668681910