Complemented libdevice ops with their equivalents from the MLIR Math dialect

This allows some ops, e.g. jnp.exp, to support half-precision (float16 and bfloat16) inputs (#20239).
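
For context, a minimal sketch of the kind of kernel this change enables (not part of the commit; the kernel and shapes are illustrative): jnp.exp on float16 inputs inside a Pallas kernel, which previously failed to lower because the libdevice table only has float32/float64 entries.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def exp_kernel(x_ref, o_ref):
  # With this change, jnp.exp on float16 lowers to the MLIR math dialect op
  # instead of failing to match a libdevice extern.
  o_ref[...] = jnp.exp(x_ref[...])

x = jnp.arange(8, dtype=jnp.float16)
y = pl.pallas_call(
    exp_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x)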

PiperOrigin-RevId: 617766573
Author: Sergei Lebedev, 2024-03-21 01:55:08 -07:00 (committed by jax authors)
Parent: 7d431ad33b
Commit: c5fa14ba9b
2 changed files with 122 additions and 73 deletions

jax/_src/pallas/triton/lowering.py

@@ -41,6 +41,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.lax.control_flow import for_loop
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith as arith_dialect
+from jax._src.lib.mlir.dialects import math as math_dialect
 from jax._src.lib.mlir.dialects import scf as scf_dialect
 from jax._src.lib.triton import dialect as tt_dialect
 from jax._src.pallas import core as pallas_core
@@ -531,22 +532,44 @@ class _Extern:
   symbol: str
   result_type: str
 
-  def matches(self, args: Sequence[jax_core.ShapedArray]) -> bool:
-    if len(args) != len(self.arg_types):
+  def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
+    if len(avals) != len(self.arg_types):
       return False
     return all(
         aval.weak_type or aval.dtype.name == arg_type
-        for aval, arg_type in zip(args, self.arg_types)
+        for aval, arg_type in zip(avals, self.arg_types)
     )
 
+  def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
+    [out_aval] = ctx.avals_out
+    result_type = _dtype_to_ir_type(jnp.dtype(self.result_type))
+    if out_aval.shape:
+      result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
+    return tt_dialect.extern_elementwise(
+        result_type,
+        args,
+        libname="",
+        libpath="",
+        symbol=self.symbol,
+        pure=True,
+    )
+
 
-def _extern_elementwise(
-    name: str, table: Sequence[_Extern]
+@dataclasses.dataclass(frozen=True)
+class _Fallback:
+  arg_types: Sequence[str]
+  lower: Callable[..., ir.Value]
+
+  matches = _Extern.matches
+
+
+def _make_dispatch_table(
+    name: str, table: Sequence[_Extern | _Fallback]
 ) -> Callable[..., ir.Value]:
 
   def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
-    extern = next((e for e in table if e.matches(ctx.avals_in)), None)
-    if extern is None:
+    h = next((e for e in table if e.matches(ctx.avals_in)), None)
+    if h is None:
       arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in)
       raise NotImplementedError(
           f"unsupported types for {name}: {arg_aval_dtypes}"
@@ -554,7 +577,7 @@ def _extern_elementwise(
 
     [out_aval] = ctx.avals_out
     bcast_args = []
-    for aval, arg, arg_type in zip(ctx.avals_in, args, extern.arg_types):
+    for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
      bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
      if aval.weak_type and aval.dtype.name != arg_type:
        bcast_arg = _cast(
@@ -563,90 +586,107 @@ def _extern_elementwise(
             signed=jnp.issubdtype(aval.dtype, jnp.signedinteger),
         )
       bcast_args.append(bcast_arg)
-    result_type = _dtype_to_ir_type(jnp.dtype(extern.result_type))
-    if out_aval.shape:
-      result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
-    return tt_dialect.extern_elementwise(
-        result_type,
-        bcast_args,
-        libname="",
-        libpath="",
-        symbol=extern.symbol,
-        pure=True,
-    )
+    return h.lower(ctx, *bcast_args)
 
   return inner
 
 
+_abs_dispatch_table = _make_dispatch_table(
+    "abs",
+    [
+        _Extern(["int32"], "__nv_abs", "int32"),
+        _Extern(["int64"], "__nv_llabs", "int64"),
+        _Extern(["float32"], "__nv_fabsf", "float32"),
+        _Extern(["float64"], "__nv_fabs", "float64"),
+    ],
+)
+
+
+def _abs_lowering_rule(ctx: LoweringRuleContext, x):
+  try:
+    return _abs_dispatch_table(ctx, x)
+  except NotImplementedError as e:
+    [x_aval] = ctx.avals_in
+    if jnp.issubdtype(x_aval, jnp.integer):
+      return math_dialect.absi(x)
+    elif jnp.issubdtype(x_aval, jnp.floating):
+      return math_dialect.absf(x)
+    else:
+      raise e from None
+
+
+triton_lowering_rules[lax.abs_p] = _abs_lowering_rule
+
 triton_lowering_rules.update({
     lax.neg_p: lambda ctx, x: _minus(x),
-    lax.abs_p: _extern_elementwise(
-        "abs",
-        [
-            _Extern(["int32"], "__nv_abs", "int32"),
-            _Extern(["int64"], "__nv_llabs", "int64"),
-            _Extern(["float32"], "__nv_fabsf", "float32"),
-            _Extern(["float64"], "__nv_fabs", "float64"),
-        ],
-    ),
-    lax.ceil_p: _extern_elementwise(
+    lax.ceil_p: _make_dispatch_table(
         "ceil",
         [
             _Extern(["float32"], "__nv_ceilf", "float32"),
             _Extern(["float64"], "__nv_ceil", "float64"),
         ],
     ),
-    lax.floor_p: _extern_elementwise(
+    lax.floor_p: _make_dispatch_table(
         "floor",
         [
             _Extern(["float32"], "__nv_floorf", "float32"),
             _Extern(["float64"], "__nv_floor", "float64"),
+            _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)),
+            _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)),
         ],
     ),
-    lax.exp_p: _extern_elementwise(
+    lax.exp_p: _make_dispatch_table(
         "exp",
         [
             _Extern(["float32"], "__nv_expf", "float32"),
             _Extern(["float64"], "__nv_exp", "float64"),
+            _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)),
+            _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)),
         ],
     ),
-    lax.exp2_p: _extern_elementwise(
+    lax.exp2_p: _make_dispatch_table(
         "exp2",
         [
             _Extern(["float32"], "__nv_exp2f", "float32"),
             _Extern(["float64"], "__nv_exp2", "float64"),
+            _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)),
+            _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)),
        ],
    ),
-    lax.expm1_p: _extern_elementwise(
+    lax.expm1_p: _make_dispatch_table(
         "expm1",
         [
             _Extern(["float32"], "__nv_expm1f", "float32"),
             _Extern(["float64"], "__nv_expm1", "float64"),
         ],
     ),
-    lax.log_p: _extern_elementwise(
+    lax.log_p: _make_dispatch_table(
         "log",
         [
             _Extern(["float32"], "__nv_logf", "float32"),
             _Extern(["float64"], "__nv_log", "float64"),
+            _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)),
+            _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)),
         ],
     ),
-    lax.log1p_p: _extern_elementwise(
+    lax.log1p_p: _make_dispatch_table(
         "log1p",
         [
             _Extern(["float32"], "__nv_log1pf", "float32"),
             _Extern(["float64"], "__nv_log1p", "float64"),
         ],
     ),
-    lax.sqrt_p: _extern_elementwise(
+    lax.sqrt_p: _make_dispatch_table(
         "sqrt",
         [
             _Extern(["float32"], "__nv_sqrtf", "float32"),
             _Extern(["float64"], "__nv_sqrt", "float64"),
+            _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)),
+            _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)),
         ],
     ),
-    lax.pow_p: _extern_elementwise(
+    lax.pow_p: _make_dispatch_table(
         "pow",
         [
             _Extern(["float32", "int32"], "__nv_powif", "float32"),
@@ -655,126 +695,130 @@ triton_lowering_rules.update({
             _Extern(["float64", "float64"], "__nv_pow", "float64"),
         ],
     ),
-    lax.cbrt_p: _extern_elementwise(
+    lax.cbrt_p: _make_dispatch_table(
         "cbrt",
         [
             _Extern(["float32"], "__nv_cbrtf", "float32"),
             _Extern(["float64"], "__nv_cbrt", "float64"),
         ],
     ),
-    lax.rsqrt_p: _extern_elementwise(
+    lax.rsqrt_p: _make_dispatch_table(
         "rsqrt",
         [
             _Extern(["float32"], "__nv_rsqrtf", "float32"),
             _Extern(["float64"], "__nv_rsqrt", "float64"),
         ],
     ),
-    lax.sin_p: _extern_elementwise(
+    lax.sin_p: _make_dispatch_table(
         "sin",
         [
             _Extern(["float32"], "__nv_sinf", "float32"),
             _Extern(["float64"], "__nv_sin", "float64"),
+            _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)),
+            _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)),
         ],
     ),
-    lax.cos_p: _extern_elementwise(
+    lax.cos_p: _make_dispatch_table(
         "cos",
         [
             _Extern(["float32"], "__nv_cosf", "float32"),
             _Extern(["float64"], "__nv_cos", "float64"),
+            _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)),
+            _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)),
         ],
     ),
-    lax.tan_p: _extern_elementwise(
+    lax.tan_p: _make_dispatch_table(
         "tan",
         [
             _Extern(["float32"], "__nv_tanf", "float32"),
             _Extern(["float64"], "__nv_tan", "float64"),
         ],
     ),
-    lax.asin_p: _extern_elementwise(
+    lax.asin_p: _make_dispatch_table(
         "asin",
         [
             _Extern(["float32"], "__nv_asinf", "float32"),
             _Extern(["float64"], "__nv_asin", "float64"),
         ],
     ),
-    lax.acos_p: _extern_elementwise(
+    lax.acos_p: _make_dispatch_table(
         "acos",
         [
             _Extern(["float32"], "__nv_acosf", "float32"),
             _Extern(["float64"], "__nv_acos", "float64"),
         ],
     ),
-    lax.atan_p: _extern_elementwise(
+    lax.atan_p: _make_dispatch_table(
         "atan",
         [
             _Extern(["float32"], "__nv_atanf", "float32"),
             _Extern(["float64"], "__nv_atan", "float64"),
         ],
     ),
-    lax.atan2_p: _extern_elementwise(
+    lax.atan2_p: _make_dispatch_table(
         "atan2",
         [
             _Extern(["float32", "float32"], "__nv_atan2f", "float32"),
             _Extern(["float64", "float64"], "__nv_atan2", "float64"),
         ],
     ),
-    lax.sinh_p: _extern_elementwise(
+    lax.sinh_p: _make_dispatch_table(
         "sinh",
         [
             _Extern(["float32"], "__nv_sinhf", "float32"),
             _Extern(["float64"], "__nv_sinh", "float64"),
         ],
     ),
-    lax.cosh_p: _extern_elementwise(
+    lax.cosh_p: _make_dispatch_table(
         "cosh",
         [
             _Extern(["float32"], "__nv_coshf", "float32"),
             _Extern(["float64"], "__nv_cosh", "float64"),
         ],
     ),
-    lax.tanh_p: _extern_elementwise(
+    lax.tanh_p: _make_dispatch_table(
         "tanh",
         [
             _Extern(["float32"], "__nv_tanhf", "float32"),
             _Extern(["float64"], "__nv_tanh", "float64"),
         ],
     ),
-    lax.asinh_p: _extern_elementwise(
+    lax.asinh_p: _make_dispatch_table(
         "asinh",
         [
             _Extern(["float32"], "__nv_asinhf", "float32"),
             _Extern(["float64"], "__nv_asinh", "float64"),
         ],
     ),
-    lax.acosh_p: _extern_elementwise(
+    lax.acosh_p: _make_dispatch_table(
         "acosh",
         [
             _Extern(["float32"], "__nv_acoshf", "float32"),
             _Extern(["float64"], "__nv_acosh", "float64"),
         ],
     ),
-    lax.atanh_p: _extern_elementwise(
+    lax.atanh_p: _make_dispatch_table(
         "atanh",
         [
             _Extern(["float32"], "__nv_atanhf", "float32"),
             _Extern(["float64"], "__nv_atanh", "float64"),
         ],
     ),
-    lax.population_count_p: _extern_elementwise(
+    lax.population_count_p: _make_dispatch_table(
         "population_count",
         [
             _Extern(["int32"], "__nv_popc", "int32"),
             _Extern(["int64"], "__nv_popcll", "int32"),
         ],
     ),
-    lax.clz_p: _extern_elementwise(
+    lax.clz_p: _make_dispatch_table(
         "clz",
         [
             _Extern(["int32"], "__nv_clz", "int32"),
             _Extern(["int64"], "__nv_clzll", "int32"),
         ],
     ),
-    lax.nextafter_p: _extern_elementwise(
+    lax.nextafter_p: _make_dispatch_table(
         "nextafter",
         [
             _Extern(["float32", "float32"], "__nv_nextafterf", "float32"),

tests/pallas/pallas_test.py

@@ -1524,15 +1524,22 @@ class PallasCallInterpreterVmapTest(PallasCallVmapTest):
 class PallasOpsTest(PallasTest):
 
   ELEMENTWISE_OPS = [
-      ([jnp.abs, jnp.negative], ["int32", "int64", "float32", "float64"]),
+      (
+          [jnp.abs, jnp.negative],
+          ["int16", "int32", "int64", "float16", "float32", "float64"],
+      ),
+      ([jnp.ceil, jnp.floor], ["float32", "float64"]),
+      (
+          [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
+          ["float16", "float32", "float64"],
+      ),
       (
           # fmt: off
-          [jnp.ceil, jnp.floor, jnp.exp, jnp.exp2, jnp.expm1, jnp.log1p,
-           jnp.sqrt, jnp.cbrt, lax.rsqrt, jnp.sin, jnp.cos, jnp.tan, jnp.asin,
-           jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh,
-           jnp.acosh, jnp.atanh],
+          [jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt, jnp.tan, jnp.asin,
+           jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.asinh, jnp.acosh,
+           jnp.atanh],
           # fmt: on
-          ["float32", "float64"]
+          ["float32", "float64"],
       ),
       ([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]),
   ]
@@ -1552,7 +1559,7 @@ class PallasOpsTest(PallasTest):
     with contextlib.ExitStack() as stack:
       if jnp.dtype(dtype).itemsize == 8:
         stack.enter_context(config.enable_x64(True))
-      x = jnp.array([4.2, 2.4]).astype(dtype)
+      x = jnp.array([0.42, 2.4]).astype(dtype)
       np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
 
   @parameterized.parameters(
@@ -1611,7 +1618,7 @@ class PallasOpsTest(PallasTest):
   @parameterized.named_parameters(
       (f"{fn.__name__}_{dtype}", fn, dtype)
       for fn, dtype in itertools.product(
-          COMPARISON_OPS, ["int32", "uint32", "float32"]
+          COMPARISON_OPS, ["int32", "uint32", "float16", "float32"]
       )
   )
   def test_comparison(self, fn, dtype):
@@ -1674,17 +1681,15 @@ class PallasOpsTest(PallasTest):
   BINARY_OPS = [
       ([jnp.floor_divide], ["int32", "uint32"]),
       (
-          [jnp.add, jnp.subtract, jnp.multiply, jnp.remainder],
-          ["int32", "uint32", "float32"],
+          [jnp.add, jnp.subtract, jnp.multiply],
+          ["int16", "int32", "uint32", "float16", "float32"],
       ),
+      ([jnp.remainder], ["int32", "uint32", "float32"]),
       (
-          [
-              jnp.bitwise_and,
-              jnp.bitwise_or,
-              jnp.bitwise_xor,
-              jnp.bitwise_left_shift,
-              jnp.bitwise_right_shift,
-          ],
+          # fmt: off
+          [jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor,
+           jnp.bitwise_left_shift, jnp.bitwise_right_shift],
+          # fmt: on
           ["int32", "uint32"],
       ),
   ]
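
As a quick way to exercise the widened dtype coverage, here is a hedged spot check (not from the commit) mirroring the new float16 entry in BINARY_OPS; interpret=True runs the kernel on the Pallas interpreter so it works without a GPU, while the same call on a GPU exercises the new Triton lowering path.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.array([1.0, 2.0], dtype=jnp.float16)
y = jnp.array([0.5, 0.25], dtype=jnp.float16)
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # interpreter mode; no GPU required
)(x, y)
np.testing.assert_allclose(out, x + y)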