[Pallas TPU] Add lowerings for scalar absf and rsqrt

This PR is similar to https://github.com/jax-ml/jax/pull/24284

PiperOrigin-RevId: 689546724
This commit is contained in:
Ayaka 2024-10-24 15:58:57 -07:00 committed by jax authors
parent af28595909
commit 5c614470ad

View File

@ -832,11 +832,21 @@ class OpsTest(PallasBaseTest):
"Scalar population count on TPU is only supported in interpret mode"
)
if (
jtu.test_device_matches(["tpu"])
and fn == jnp.abs
and jnp.issubdtype(dtype, jnp.integer)
and not self.INTERPRET
):
self.skipTest(
"Scalar abs for integers on TPU is only supported in interpret mode"
)
# TODO(b/370578663): implement these lowerings on TPU
if jtu.test_device_matches(["tpu"]) and fn in (
jnp.abs, jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan,
jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan,
jnp.atanh, jnp.cbrt, jnp.cosh, jnp.expm1,
jnp.sinh, lax.rsqrt,
jnp.sinh,
):
self.skipTest(f"{fn.__name__} not implemented on TPU")