1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

[pallas:mosaic_gpu] add logistic op and some tests for unary operations

PiperOrigin-RevId: 681889064
This commit is contained in:
Christos Perivolaropoulos 2024-10-03 08:25:09 -07:00 committed by jax authors
parent ad78147183
commit 5800070c36
2 changed files with 15 additions and 3 deletions
jax/_src/pallas/mosaic_gpu
tests/pallas

@ -892,6 +892,12 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
@register_lowering_rule(lax.logistic_p)
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
a = _ensure_fa(x, x_aval.dtype)
return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math))
@register_lowering_rule(lax.reduce_sum_p)
def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):

@ -43,16 +43,22 @@ class PallasTest(jtu.JaxTestCase):
class PallasCallTest(PallasTest):
def test_add_one(self):
@parameterized.named_parameters(
("add_one", lambda x: x + 1.),
("logistic", jax.lax.logistic),
("square", lambda x: x ** 2),
("rsqrt", jax.lax.rsqrt),
)
def test_unary_ops(self, unary):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
o_ref[...] = unary(x_ref[...])
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
np.testing.assert_array_equal(kernel(x), unary(x))
def test_add_xy(self):
@functools.partial(