mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Pallas TPU] Add lowerings for scalar absf
and rsqrt
This PR is similar to https://github.com/jax-ml/jax/pull/24284 PiperOrigin-RevId: 689546724
This commit is contained in:
parent
af28595909
commit
5c614470ad
@ -832,11 +832,21 @@ class OpsTest(PallasBaseTest):
|
||||
"Scalar population count on TPU is only supported in interpret mode"
|
||||
)
|
||||
|
||||
if (
|
||||
jtu.test_device_matches(["tpu"])
|
||||
and fn == jnp.abs
|
||||
and jnp.issubdtype(dtype, jnp.integer)
|
||||
and not self.INTERPRET
|
||||
):
|
||||
self.skipTest(
|
||||
"Scalar abs for integers on TPU is only supported in interpret mode"
|
||||
)
|
||||
|
||||
# TODO(b/370578663): implement these lowerings on TPU
|
||||
if jtu.test_device_matches(["tpu"]) and fn in (
|
||||
jnp.abs, jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan,
|
||||
jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan,
|
||||
jnp.atanh, jnp.cbrt, jnp.cosh, jnp.expm1,
|
||||
jnp.sinh, lax.rsqrt,
|
||||
jnp.sinh,
|
||||
):
|
||||
self.skipTest(f"{fn.__name__} not implemented on TPU")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user