Add lowering for lax.sign

This commit is contained in:
Ayaka 2024-07-25 18:04:44 +08:00
parent 4063373b22
commit 6cc09173d5
2 changed files with 46 additions and 0 deletions

View File

@ -1793,6 +1793,27 @@ lowering_rules[lax.neg_p] = _neg_lowering_rule
skip_mlir_conversions.add(lax.neg_p)
def _sign_lowering_helper(x):
if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
return (x != 0).astype(x.dtype)
if jnp.issubdtype(x.dtype, jnp.integer):
return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)
if jnp.issubdtype(x.dtype, jnp.floating):
out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype)
return jnp.where(jnp.isnan(x), jnp.nan, out)
raise NotImplementedError
def _sign_lowering_rule(ctx: LoweringRuleContext, x):
return lower_fun(_sign_lowering_helper, multiple_results=False)(ctx, x)
lowering_rules[lax.sign_p] = _sign_lowering_rule
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
return math.RsqrtOp(x).result

View File

@ -54,6 +54,31 @@ class TpuOpsTest(jtu.JaxTestCase):
expected = lax.erf_inv(x)
np.testing.assert_array_equal(out, expected)
SIGN_PARAMS = [
(jnp.int32, (-3, 0, 5)),
(jnp.uint32, (0, 5)),
(jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
]
@parameterized.named_parameters(
(f"{dtype.__name__}_{value}", dtype, value)
for dtype, values in SIGN_PARAMS
for value in values
)
def test_sign(self, dtype, value):
@jax.jit
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct((4,), dtype),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.sign(x_ref[...])
x = jnp.full((4,), value, dtype=dtype)
out = kernel(x)
expected = jnp.sign(x)
np.testing.assert_array_equal(out, expected)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())