Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 04:46:06 +00:00

Breaks tests:

    lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).

PiperOrigin-RevId: 514897538