diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 6a9fe216d..c8f910e18 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2497,37 +2497,10 @@ lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules skip_mlir_conversions.add(lax.shift_right_logical_p) -# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802 -def _erf_inv_32_helper(x): - k_degree = 9 - w_lt_5_constants = [ - 2.81022636e-08, 3.43273939e-07, -3.5233877e-06, - -4.39150654e-06, 0.00021858087, -0.00125372503, - -0.00417768164, 0.246640727, 1.50140941, - ] - w_gt_5_constants = [ - -0.000200214257, 0.000100950558, 0.00134934322, - -0.00367342844, 0.00573950773, -0.0076224613, - 0.00943887047, 1.00167406, 2.83297682, - ] - - w = -jnp.log1p(x * -x) - w_lt_5 = w < 5.0 - - w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0) - - p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0]) - for i in range(1, k_degree): - c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i]) - p = c + p * w - - return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x) - - def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): (x_aval,) = ctx.avals_in if x_aval.dtype == jnp.float32: - return lower_fun(_erf_inv_32_helper, multiple_results=False)(ctx, x) + return lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False)(ctx, x) else: raise NotImplementedError diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 44b198c0f..4057b125f 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -404,6 +404,21 @@ def lower_jaxpr_to_triton_ir( return map(read_env, jaxpr.outvars) +def lower_fun( + fun: Callable[..., Any], *, multiple_results: bool +) -> Callable[..., Any]: + fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) + + def f_lowered(ctx: LoweringRuleContext, *args, **params): + wrapped_fun = lu.wrap_init(fn, params) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) + out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) + return out if multiple_results else out[0] + + return f_lowered + + # # Primitive lowering rules # ## Programming model primitives @@ -978,6 +993,27 @@ triton_lowering_rules.update({ _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"), ], ), + lax.erf_inv_p: _make_dispatch_table( + "erf_inv", + cuda=[ + _Fallback( + ["float32"], + lower_fun( + pallas_utils.erf_inv_32_lowering_helper, + multiple_results=False, + ), + ), + ], + rocm=[ + _Fallback( + ["float32"], + lower_fun( + pallas_utils.erf_inv_32_lowering_helper, + multiple_results=False, + ), + ), + ], + ), }) @@ -1255,21 +1291,6 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): return acc -def lower_fun( - fun: Callable[..., Any], *, multiple_results: bool -) -> Callable[..., Any]: - fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) - - def f_lowered(ctx: LoweringRuleContext, *args, **params): - wrapped_fun = lu.wrap_init(fn, params) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) - jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) - out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) - return out if multiple_results else out[0] - - return f_lowered - - _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 295134bd9..6fc816e27 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -183,3 +183,30 @@ def pattern_match_while_to_fori_loop( outvars=new_outvars, ) return jaxpr, None + + +# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802 +def erf_inv_32_lowering_helper(x): + k_degree = 9 + w_lt_5_constants = [ + 2.81022636e-08, 3.43273939e-07, -3.5233877e-06, + -4.39150654e-06, 0.00021858087, -0.00125372503, + -0.00417768164, 0.246640727, 1.50140941, + ] + w_gt_5_constants = [ + -0.000200214257, 0.000100950558, 0.00134934322, + -0.00367342844, 0.00573950773, -0.0076224613, + 0.00943887047, 1.00167406, 2.83297682, + ] + + w = -jnp.log1p(x * -x) + w_lt_5 = w < 5.0 + + w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0) + + p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0]) + for i in range(1, k_degree): + c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i]) + p = c + p * w + + return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index b2171c462..ab7ffcb48 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1513,6 +1513,21 @@ class OpsExtraTest(PallasBaseTest): y_ref = jnp.cumsum(x, axis=axis) np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) + @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4]) + def test_erf_inv(self, x): + @functools.partial( + self.pallas_call, + # TODO(ayx): add float64 support for `erf_inv` + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.erf_inv(x_ref[...]) + + x = jnp.full((8, 128), x) + out = kernel(x) + expected = lax.erf_inv(x) + np.testing.assert_array_equal(out, expected) + class OpsExtraInterpretTest(OpsExtraTest): INTERPRET = True @@ -1583,22 +1598,6 @@ class TpuOpsTest(PallasBaseTest): super().setUp() - @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4]) - def test_erf_inv(self, x): - @jax.jit - @functools.partial( - pl.pallas_call, - # TODO(ayx): add float64 support for `erf_inv` - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - ) - def kernel(x_ref, o_ref): - o_ref[...] = lax.erf_inv(x_ref[...]) - - x = jnp.full((8, 128), x) - out = kernel(x) - expected = lax.erf_inv(x) - np.testing.assert_array_equal(out, expected) - SIGN_PARAMS = [ (jnp.int32, (-3, 0, 5)), (jnp.uint32, (0, 5)),