mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic TPU] Add support for sqrt and rsqrt in bf16 on TPUv6
PiperOrigin-RevId: 708016513
This commit is contained in:
parent
307c8d3af8
commit
de8fa8fd19
@ -824,7 +824,7 @@ class OpsTest(PallasBaseTest):
|
||||
jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh,
|
||||
jnp.acosh, jnp.atanh],
|
||||
# fmt: on
|
||||
["float32", "float64"],
|
||||
["bfloat16", "float32", "float64"],
|
||||
),
|
||||
([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]),
|
||||
([jnp.logical_not], ["bool"]),
|
||||
@ -843,12 +843,16 @@ class OpsTest(PallasBaseTest):
|
||||
if dtype in ("int16", "float16"):
|
||||
self.skipTest("int16 and float16 are not supported on TPU")
|
||||
if (
|
||||
fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log)
|
||||
fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log,
|
||||
jnp.sqrt, lax.rsqrt)
|
||||
and dtype == "bfloat16"
|
||||
and not jtu.is_device_tpu_at_least(6)
|
||||
):
|
||||
self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
|
||||
if fn in (jnp.sqrt, jnp.sin, jnp.cos) and dtype == "bfloat16":
|
||||
if (
|
||||
fn in (jnp.sin, jnp.cos, jnp.tan, jnp.tanh, jnp.log1p)
|
||||
and dtype == "bfloat16"
|
||||
):
|
||||
self.skipTest(f"bfloat16 {fn.__name__} is not supported on TPU")
|
||||
# TODO(b/370578663): implement these lowerings on TPU
|
||||
if fn in (
|
||||
@ -862,7 +866,10 @@ class OpsTest(PallasBaseTest):
|
||||
|
||||
if (
|
||||
jtu.test_device_matches(["gpu"])
|
||||
and fn in (jnp.ceil, jnp.floor)
|
||||
and fn
|
||||
in (jnp.ceil, jnp.floor, jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt,
|
||||
jnp.tan, jnp.asin, jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh,
|
||||
jnp.asinh, jnp.acosh, jnp.atanh)
|
||||
and dtype == "bfloat16"
|
||||
):
|
||||
self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
|
||||
@ -897,7 +904,10 @@ class OpsTest(PallasBaseTest):
|
||||
|
||||
if (
|
||||
jtu.test_device_matches(["gpu"])
|
||||
and fn in (jnp.ceil, jnp.floor)
|
||||
and fn
|
||||
in (jnp.ceil, jnp.floor, jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt,
|
||||
jnp.tan, jnp.asin, jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh,
|
||||
jnp.asinh, jnp.acosh, jnp.atanh)
|
||||
and dtype == "bfloat16"
|
||||
):
|
||||
self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
|
||||
|
Loading…
x
Reference in New Issue
Block a user