[Pallas GPU] Add 32-bit lowering rule for lax.erf_inv

Add a 32-bit lowering rule for `lax.erf_inv` on Pallas GPU, and move the original TPU-only test case into the general test class

PiperOrigin-RevId: 668681910
Ayaka 2024-08-28 17:58:42 -07:00 committed by jax authors
parent 48a9159b22
commit 1dff3a2c71
4 changed files with 79 additions and 59 deletions
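
For context, this is the kind of kernel the change enables on the Pallas GPU (Triton) backend. A minimal sketch, mirroring the test added below; it assumes float32 inputs (there is no float64 lowering yet, see the TODO in the test), and the kernel and array names are illustrative:

import functools
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)
def erf_inv_kernel(x_ref, o_ref):
  # With this change, lax.erf_inv lowers on Pallas GPU for float32.
  o_ref[...] = lax.erf_inv(x_ref[...])

x = jnp.linspace(-0.9, 0.9, 8 * 128, dtype=jnp.float32).reshape(8, 128)
y = erf_inv_kernel(x)  # agrees with lax.erf_inv(x) computed outside the kernel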

View File

@@ -2497,37 +2497,10 @@ lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules
 skip_mlir_conversions.add(lax.shift_right_logical_p)
 
-# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802
-def _erf_inv_32_helper(x):
-  k_degree = 9
-  w_lt_5_constants = [
-      2.81022636e-08, 3.43273939e-07, -3.5233877e-06,
-      -4.39150654e-06, 0.00021858087, -0.00125372503,
-      -0.00417768164, 0.246640727, 1.50140941,
-  ]
-  w_gt_5_constants = [
-      -0.000200214257, 0.000100950558, 0.00134934322,
-      -0.00367342844, 0.00573950773, -0.0076224613,
-      0.00943887047, 1.00167406, 2.83297682,
-  ]
-
-  w = -jnp.log1p(x * -x)
-  w_lt_5 = w < 5.0
-  w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0)
-  p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0])
-  for i in range(1, k_degree):
-    c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i])
-    p = c + p * w
-
-  return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x)
-
-
 def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
   (x_aval,) = ctx.avals_in
   if x_aval.dtype == jnp.float32:
-    return lower_fun(_erf_inv_32_helper, multiple_results=False)(ctx, x)
+    return lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False)(ctx, x)
   else:
     raise NotImplementedError

View File

@@ -404,6 +404,21 @@ def lower_jaxpr_to_triton_ir(
   return map(read_env, jaxpr.outvars)
 
 
+def lower_fun(
+    fun: Callable[..., Any], *, multiple_results: bool
+) -> Callable[..., Any]:
+  fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
+
+  def f_lowered(ctx: LoweringRuleContext, *args, **params):
+    wrapped_fun = lu.wrap_init(fn, params)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+    jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
+    out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
+    return out if multiple_results else out[0]
+
+  return f_lowered
+
+
 # # Primitive lowering rules
 # ## Programming model primitives
@@ -978,6 +993,27 @@ triton_lowering_rules.update({
             _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"),
         ],
     ),
+    lax.erf_inv_p: _make_dispatch_table(
+        "erf_inv",
+        cuda=[
+            _Fallback(
+                ["float32"],
+                lower_fun(
+                    pallas_utils.erf_inv_32_lowering_helper,
+                    multiple_results=False,
+                ),
+            ),
+        ],
+        rocm=[
+            _Fallback(
+                ["float32"],
+                lower_fun(
+                    pallas_utils.erf_inv_32_lowering_helper,
+                    multiple_results=False,
+                ),
+            ),
+        ],
+    ),
 })
@@ -1255,21 +1291,6 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int):
   return acc
 
 
-def lower_fun(
-    fun: Callable[..., Any], *, multiple_results: bool
-) -> Callable[..., Any]:
-  fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
-
-  def f_lowered(ctx: LoweringRuleContext, *args, **params):
-    wrapped_fun = lu.wrap_init(fn, params)
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
-    jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
-    out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
-    return out if multiple_results else out[0]
-
-  return f_lowered
-
-
 _JAX_FN_MAPPING = {
     lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
     lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
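
The `_Fallback` entries above work because `lower_fun` traces the pure-JAX helper to a jaxpr and lowers it through `_closed_call_lowering_rule`, so `erf_inv` reaches the Triton backend only as primitives it already knows how to lower. A quick way to see that, as a sketch; importing the private `jax._src.pallas.utils` module here is purely for illustration:

import jax
import jax.numpy as jnp
from jax._src.pallas import utils as pallas_utils

# Tracing the helper shows only elementwise primitives (log1p, sqrt, mul,
# select_n, and friends), which the Triton lowering can already handle.
print(jax.make_jaxpr(pallas_utils.erf_inv_32_lowering_helper)(jnp.float32(0.5)))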

View File

@@ -183,3 +183,30 @@ def pattern_match_while_to_fori_loop(
       outvars=new_outvars,
   )
   return jaxpr, None
+
+
+# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802
+def erf_inv_32_lowering_helper(x):
+  k_degree = 9
+  w_lt_5_constants = [
+      2.81022636e-08, 3.43273939e-07, -3.5233877e-06,
+      -4.39150654e-06, 0.00021858087, -0.00125372503,
+      -0.00417768164, 0.246640727, 1.50140941,
+  ]
+  w_gt_5_constants = [
+      -0.000200214257, 0.000100950558, 0.00134934322,
+      -0.00367342844, 0.00573950773, -0.0076224613,
+      0.00943887047, 1.00167406, 2.83297682,
+  ]
+
+  w = -jnp.log1p(x * -x)
+  w_lt_5 = w < 5.0
+  w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0)
+  p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0])
+  for i in range(1, k_degree):
+    c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i])
+    p = c + p * w
+
+  return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x)
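
A quick numerical sanity check of the degree-9 polynomial approximation above against `lax.erf_inv` itself, run outside Pallas. A sketch only; the private import path and the tolerances are illustrative:

import numpy as np
import jax.numpy as jnp
from jax import lax
from jax._src.pallas import utils as pallas_utils

x = jnp.linspace(-0.999, 0.999, 101, dtype=jnp.float32)
# Both paths evaluate the polynomial from the XLA source linked above,
# so they should agree to float32 precision.
np.testing.assert_allclose(
    pallas_utils.erf_inv_32_lowering_helper(x), lax.erf_inv(x),
    atol=1e-5, rtol=1e-5,
)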

View File

@@ -1513,6 +1513,21 @@ class OpsExtraTest(PallasBaseTest):
       y_ref = jnp.cumsum(x, axis=axis)
       np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)
 
+  @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4])
+  def test_erf_inv(self, x):
+    @functools.partial(
+        self.pallas_call,
+        # TODO(ayx): add float64 support for `erf_inv`
+        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
+    )
+    def kernel(x_ref, o_ref):
+      o_ref[...] = lax.erf_inv(x_ref[...])
+
+    x = jnp.full((8, 128), x)
+    out = kernel(x)
+    expected = lax.erf_inv(x)
+    np.testing.assert_array_equal(out, expected)
+
 
 class OpsExtraInterpretTest(OpsExtraTest):
   INTERPRET = True
@@ -1583,22 +1598,6 @@ class TpuOpsTest(PallasBaseTest):
     super().setUp()
 
-  @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4])
-  def test_erf_inv(self, x):
-    @jax.jit
-    @functools.partial(
-        pl.pallas_call,
-        # TODO(ayx): add float64 support for `erf_inv`
-        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
-    )
-    def kernel(x_ref, o_ref):
-      o_ref[...] = lax.erf_inv(x_ref[...])
-
-    x = jnp.full((8, 128), x)
-    out = kernel(x)
-    expected = lax.erf_inv(x)
-    np.testing.assert_array_equal(out, expected)
-
   SIGN_PARAMS = [
       (jnp.int32, (-3, 0, 5)),
       (jnp.uint32, (0, 5)),