mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add lowering for lax.sign
This commit is contained in:
parent
4063373b22
commit
6cc09173d5
@ -1793,6 +1793,27 @@ lowering_rules[lax.neg_p] = _neg_lowering_rule
|
||||
skip_mlir_conversions.add(lax.neg_p)
|
||||
|
||||
|
||||
def _sign_lowering_helper(x):
|
||||
if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
|
||||
return (x != 0).astype(x.dtype)
|
||||
|
||||
if jnp.issubdtype(x.dtype, jnp.integer):
|
||||
return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)
|
||||
|
||||
if jnp.issubdtype(x.dtype, jnp.floating):
|
||||
out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype)
|
||||
return jnp.where(jnp.isnan(x), jnp.nan, out)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _sign_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
return lower_fun(_sign_lowering_helper, multiple_results=False)(ctx, x)
|
||||
|
||||
|
||||
lowering_rules[lax.sign_p] = _sign_lowering_rule
|
||||
|
||||
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
return math.RsqrtOp(x).result
|
||||
|
||||
|
@ -54,6 +54,31 @@ class TpuOpsTest(jtu.JaxTestCase):
|
||||
expected = lax.erf_inv(x)
|
||||
np.testing.assert_array_equal(out, expected)
|
||||
|
||||
SIGN_PARAMS = [
|
||||
(jnp.int32, (-3, 0, 5)),
|
||||
(jnp.uint32, (0, 5)),
|
||||
(jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
|
||||
]
|
||||
|
||||
@parameterized.named_parameters(
|
||||
(f"{dtype.__name__}_{value}", dtype, value)
|
||||
for dtype, values in SIGN_PARAMS
|
||||
for value in values
|
||||
)
|
||||
def test_sign(self, dtype, value):
|
||||
@jax.jit
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct((4,), dtype),
|
||||
)
|
||||
def kernel(x_ref, o_ref):
|
||||
o_ref[...] = jnp.sign(x_ref[...])
|
||||
|
||||
x = jnp.full((4,), value, dtype=dtype)
|
||||
out = kernel(x)
|
||||
expected = jnp.sign(x)
|
||||
np.testing.assert_array_equal(out, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user