mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
[pallas:mosaic_gpu] add logistic op and some tests for unary operations
PiperOrigin-RevId: 681889064
This commit is contained in:
parent
ad78147183
commit
5800070c36
@ -892,6 +892,12 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
|
||||
|
||||
@register_lowering_rule(lax.logistic_p)
|
||||
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
a = _ensure_fa(x, x_aval.dtype)
|
||||
return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math))
|
||||
|
||||
|
||||
@register_lowering_rule(lax.reduce_sum_p)
|
||||
def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
||||
|
@ -43,16 +43,22 @@ class PallasTest(jtu.JaxTestCase):
|
||||
|
||||
class PallasCallTest(PallasTest):
|
||||
|
||||
def test_add_one(self):
|
||||
@parameterized.named_parameters(
|
||||
("add_one", lambda x: x + 1.),
|
||||
("logistic", jax.lax.logistic),
|
||||
("square", lambda x: x ** 2),
|
||||
("rsqrt", jax.lax.rsqrt),
|
||||
)
|
||||
def test_unary_ops(self, unary):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
|
||||
)
|
||||
def kernel(x_ref, o_ref):
|
||||
o_ref[...] = x_ref[...] + 1.0
|
||||
o_ref[...] = unary(x_ref[...])
|
||||
|
||||
x = jnp.arange(256).astype(jnp.float32)
|
||||
np.testing.assert_array_equal(kernel(x), x + 1.0)
|
||||
np.testing.assert_array_equal(kernel(x), unary(x))
|
||||
|
||||
def test_add_xy(self):
|
||||
@functools.partial(
|
||||
|
Loading…
x
Reference in New Issue
Block a user