[Mosaic TPU] Add support for sqrt and rsqrt in bf16 on TPUv6

PiperOrigin-RevId: 708016513
This commit is contained in:
Adam Paszke 2024-12-19 13:42:01 -08:00 committed by jax authors
parent 307c8d3af8
commit de8fa8fd19

View File

@ -824,7 +824,7 @@ class OpsTest(PallasBaseTest):
jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh,
jnp.acosh, jnp.atanh],
# fmt: on
["float32", "float64"],
["bfloat16", "float32", "float64"],
),
([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]),
([jnp.logical_not], ["bool"]),
@ -843,12 +843,16 @@ class OpsTest(PallasBaseTest):
if dtype in ("int16", "float16"):
self.skipTest("int16 and float16 are not supported on TPU")
if (
fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log)
fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log,
jnp.sqrt, lax.rsqrt)
and dtype == "bfloat16"
and not jtu.is_device_tpu_at_least(6)
):
self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
if fn in (jnp.sqrt, jnp.sin, jnp.cos) and dtype == "bfloat16":
if (
fn in (jnp.sin, jnp.cos, jnp.tan, jnp.tanh, jnp.log1p)
and dtype == "bfloat16"
):
self.skipTest(f"bfloat16 {fn.__name__} is not supported on TPU")
# TODO(b/370578663): implement these lowerings on TPU
if fn in (
@ -862,7 +866,10 @@ class OpsTest(PallasBaseTest):
if (
jtu.test_device_matches(["gpu"])
and fn in (jnp.ceil, jnp.floor)
and fn
in (jnp.ceil, jnp.floor, jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt,
jnp.tan, jnp.asin, jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh,
jnp.asinh, jnp.acosh, jnp.atanh)
and dtype == "bfloat16"
):
self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
@ -897,7 +904,10 @@ class OpsTest(PallasBaseTest):
if (
jtu.test_device_matches(["gpu"])
and fn in (jnp.ceil, jnp.floor)
and fn
in (jnp.ceil, jnp.floor, jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt,
jnp.tan, jnp.asin, jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh,
jnp.asinh, jnp.acosh, jnp.atanh)
and dtype == "bfloat16"
):
self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")