Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
[Pallas GPU] Add 32-bit lowering rule for lax.erf_inv
Add a 32-bit lowering rule for `lax.erf_inv` for Pallas GPU, and move the original TPU test case into the general test class.

PiperOrigin-RevId: 668681910
This commit is contained in:
parent 48a9159b22
commit 1dff3a2c71
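For context, this change lets `lax.erf_inv` be applied to float32 values inside Pallas GPU kernels. Below is a minimal sketch of the user-facing call, mirroring the test added further down; `interpret=True` is used here only so the snippet runs on any backend and is not part of the change:

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # Element-wise inverse error function inside a Pallas kernel.
  o_ref[...] = lax.erf_inv(x_ref[...])

x = jnp.full((8, 128), 0.72, dtype=jnp.float32)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,  # drop this to exercise the new Triton lowering on GPU
)(x)
print(jnp.allclose(out, lax.erf_inv(x)))
```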
@@ -2497,37 +2497,10 @@ lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules
 skip_mlir_conversions.add(lax.shift_right_logical_p)
 
 
-# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802
-def _erf_inv_32_helper(x):
-  k_degree = 9
-  w_lt_5_constants = [
-      2.81022636e-08, 3.43273939e-07, -3.5233877e-06,
-      -4.39150654e-06, 0.00021858087, -0.00125372503,
-      -0.00417768164, 0.246640727, 1.50140941,
-  ]
-  w_gt_5_constants = [
-      -0.000200214257, 0.000100950558, 0.00134934322,
-      -0.00367342844, 0.00573950773, -0.0076224613,
-      0.00943887047, 1.00167406, 2.83297682,
-  ]
-
-  w = -jnp.log1p(x * -x)
-  w_lt_5 = w < 5.0
-
-  w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0)
-
-  p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0])
-  for i in range(1, k_degree):
-    c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i])
-    p = c + p * w
-
-  return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x)
-
-
 def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
   (x_aval,) = ctx.avals_in
   if x_aval.dtype == jnp.float32:
-    return lower_fun(_erf_inv_32_helper, multiple_results=False)(ctx, x)
+    return lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False)(ctx, x)
   else:
     raise NotImplementedError
 
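Note that the registration of `_erf_inv_lowering_rule` into the `lowering_rules` table is outside this hunk. By analogy with the `lowering_rules[lax.shift_right_logical_p] = ...` assignment visible in the hunk context, it presumably looks like the sketch below (hypothetical; the actual line is not shown in this diff).

```python
# Hypothetical registration, mirroring the pattern in the hunk context above;
# the real assignment lives outside the lines shown in this diff.
lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule
```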
@@ -404,6 +404,21 @@ def lower_jaxpr_to_triton_ir(
   return map(read_env, jaxpr.outvars)
 
 
+def lower_fun(
+    fun: Callable[..., Any], *, multiple_results: bool
+) -> Callable[..., Any]:
+  fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
+
+  def f_lowered(ctx: LoweringRuleContext, *args, **params):
+    wrapped_fun = lu.wrap_init(fn, params)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+    jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
+    out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
+    return out if multiple_results else out[0]
+
+  return f_lowered
+
+
 # # Primitive lowering rules
 # ## Programming model primitives
 
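`lower_fun` is what lets a plain JAX function act as a lowering rule: it traces the function to a jaxpr at the rule's input avals and then lowers that jaxpr through the existing closed-call rule. A rough illustration of the tracing step, using only the public API (`jax.make_jaxpr` stands in here for the internal `pe.trace_to_jaxpr_dynamic` call; the `approx` function is a toy stand-in, not part of the diff):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a JAX-level fallback such as erf_inv_32_lowering_helper.
def approx(x):
  return jnp.where(x < 5.0, x - 2.5, jnp.sqrt(x) - 3.0)

# lower_fun effectively performs this tracing step and then walks the
# resulting jaxpr with the per-primitive Triton lowering rules.
closed_jaxpr = jax.make_jaxpr(approx)(jnp.float32(1.0))
print(closed_jaxpr)
```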
@@ -978,6 +993,27 @@ triton_lowering_rules.update({
             _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"),
         ],
     ),
+    lax.erf_inv_p: _make_dispatch_table(
+        "erf_inv",
+        cuda=[
+            _Fallback(
+                ["float32"],
+                lower_fun(
+                    pallas_utils.erf_inv_32_lowering_helper,
+                    multiple_results=False,
+                ),
+            ),
+        ],
+        rocm=[
+            _Fallback(
+                ["float32"],
+                lower_fun(
+                    pallas_utils.erf_inv_32_lowering_helper,
+                    multiple_results=False,
+                ),
+            ),
+        ],
+    ),
 })
 
 
@@ -1255,21 +1291,6 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int):
   return acc
 
 
-def lower_fun(
-    fun: Callable[..., Any], *, multiple_results: bool
-) -> Callable[..., Any]:
-  fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
-
-  def f_lowered(ctx: LoweringRuleContext, *args, **params):
-    wrapped_fun = lu.wrap_init(fn, params)
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
-    jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
-    out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
-    return out if multiple_results else out[0]
-
-  return f_lowered
-
-
 _JAX_FN_MAPPING = {
     lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
     lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
@@ -183,3 +183,30 @@ def pattern_match_while_to_fori_loop(
       outvars=new_outvars,
   )
   return jaxpr, None
+
+
+# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802
+def erf_inv_32_lowering_helper(x):
+  k_degree = 9
+  w_lt_5_constants = [
+      2.81022636e-08, 3.43273939e-07, -3.5233877e-06,
+      -4.39150654e-06, 0.00021858087, -0.00125372503,
+      -0.00417768164, 0.246640727, 1.50140941,
+  ]
+  w_gt_5_constants = [
+      -0.000200214257, 0.000100950558, 0.00134934322,
+      -0.00367342844, 0.00573950773, -0.0076224613,
+      0.00943887047, 1.00167406, 2.83297682,
+  ]
+
+  w = -jnp.log1p(x * -x)
+  w_lt_5 = w < 5.0
+
+  w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0)
+
+  p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0])
+  for i in range(1, k_degree):
+    c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i])
+    p = c + p * w
+
+  return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x)
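As a quick sanity check, the polynomial approximation above can be evaluated as ordinary JAX (outside any Pallas kernel) and compared against `lax.erf_inv`. The helper's name comes from the diff; the import path and tolerances below are assumptions for illustration:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Assumed location of the helper added above (private module; path may differ).
from jax._src.pallas import utils as pallas_utils

x = jnp.array([-0.999, -0.72, -0.4, 0.0, 0.4, 0.72, 0.999], dtype=jnp.float32)
approx = pallas_utils.erf_inv_32_lowering_helper(x)
# Tolerances are illustrative; the reference lax.erf_inv uses a very similar
# float32 rational approximation, so agreement should be tight.
np.testing.assert_allclose(approx, lax.erf_inv(x), rtol=1e-5, atol=1e-6)
print("approximation matches lax.erf_inv on float32 inputs")
```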
@@ -1513,6 +1513,21 @@ class OpsExtraTest(PallasBaseTest):
       y_ref = jnp.cumsum(x, axis=axis)
       np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)
 
+  @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4])
+  def test_erf_inv(self, x):
+    @functools.partial(
+        self.pallas_call,
+        # TODO(ayx): add float64 support for `erf_inv`
+        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
+    )
+    def kernel(x_ref, o_ref):
+      o_ref[...] = lax.erf_inv(x_ref[...])
+
+    x = jnp.full((8, 128), x)
+    out = kernel(x)
+    expected = lax.erf_inv(x)
+    np.testing.assert_array_equal(out, expected)
+
 
 class OpsExtraInterpretTest(OpsExtraTest):
   INTERPRET = True
@@ -1583,22 +1598,6 @@ class TpuOpsTest(PallasBaseTest):
 
     super().setUp()
 
-  @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4])
-  def test_erf_inv(self, x):
-    @jax.jit
-    @functools.partial(
-        pl.pallas_call,
-        # TODO(ayx): add float64 support for `erf_inv`
-        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
-    )
-    def kernel(x_ref, o_ref):
-      o_ref[...] = lax.erf_inv(x_ref[...])
-
-    x = jnp.full((8, 128), x)
-    out = kernel(x)
-    expected = lax.erf_inv(x)
-    np.testing.assert_array_equal(out, expected)
-
   SIGN_PARAMS = [
       (jnp.int32, (-3, 0, 5)),
       (jnp.uint32, (0, 5)),