mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Complemented libdevice ops with the ones from the MLIR Math dialect
This allows some ops, e.g. jnp.exp, to support half-precision inputs (#20239).

PiperOrigin-RevId: 617766573
This commit is contained in:
parent
7d431ad33b
commit
c5fa14ba9b
@ -41,6 +41,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith as arith_dialect
|
||||
from jax._src.lib.mlir.dialects import math as math_dialect
|
||||
from jax._src.lib.mlir.dialects import scf as scf_dialect
|
||||
from jax._src.lib.triton import dialect as tt_dialect
|
||||
from jax._src.pallas import core as pallas_core
|
||||
@ -531,22 +532,44 @@ class _Extern:
|
||||
symbol: str
|
||||
result_type: str
|
||||
|
||||
def matches(self, args: Sequence[jax_core.ShapedArray]) -> bool:
|
||||
if len(args) != len(self.arg_types):
|
||||
def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
|
||||
if len(avals) != len(self.arg_types):
|
||||
return False
|
||||
return all(
|
||||
aval.weak_type or aval.dtype.name == arg_type
|
||||
for aval, arg_type in zip(args, self.arg_types)
|
||||
for aval, arg_type in zip(avals, self.arg_types)
|
||||
)
|
||||
|
||||
def lower(self, ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
  """Emits a call to the extern function named by ``self.symbol``.

  Args:
    ctx: Lowering context; ``ctx.avals_out`` must hold exactly one aval.
    *args: Operands forwarded unchanged to the extern call.

  Returns:
    The value produced by ``tt_dialect.extern_elementwise``.
  """
  [out_aval] = ctx.avals_out
  result_type = _dtype_to_ir_type(jnp.dtype(self.result_type))
  # A non-empty output shape means the result is a ranked tensor of the
  # element type rather than a bare scalar.
  if out_aval.shape:
    result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
  return tt_dialect.extern_elementwise(
      result_type,
      args,
      libname="",
      libpath="",
      symbol=self.symbol,
      pure=True,
  )
|
||||
|
||||
|
||||
def _extern_elementwise(
|
||||
name: str, table: Sequence[_Extern]
|
||||
@dataclasses.dataclass(frozen=True)
class _Fallback:
  """Dispatch-table entry that lowers through an arbitrary callable.

  Entries are selected with the same argument-type matching as ``_Extern``.
  """

  # Expected dtype name of each positional argument, in order.
  arg_types: Sequence[str]
  # Invoked as ``lower(ctx, *args)`` to produce the lowered value.
  lower: Callable[..., ir.Value]

  def matches(self, avals) -> bool:
    """Reports whether ``avals`` fit ``arg_types``, reusing ``_Extern.matches``."""
    return _Extern.matches(self, avals)
|
||||
|
||||
|
||||
def _make_dispatch_table(
|
||||
name: str, table: Sequence[_Extern | _Fallback]
|
||||
) -> Callable[..., ir.Value]:
|
||||
|
||||
def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
|
||||
extern = next((e for e in table if e.matches(ctx.avals_in)), None)
|
||||
if extern is None:
|
||||
h = next((e for e in table if e.matches(ctx.avals_in)), None)
|
||||
if h is None:
|
||||
arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in)
|
||||
raise NotImplementedError(
|
||||
f"unsupported types for {name}: {arg_aval_dtypes}"
|
||||
@ -554,7 +577,7 @@ def _extern_elementwise(
|
||||
|
||||
[out_aval] = ctx.avals_out
|
||||
bcast_args = []
|
||||
for aval, arg, arg_type in zip(ctx.avals_in, args, extern.arg_types):
|
||||
for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
|
||||
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
|
||||
if aval.weak_type and aval.dtype.name != arg_type:
|
||||
bcast_arg = _cast(
|
||||
@ -563,90 +586,107 @@ def _extern_elementwise(
|
||||
signed=jnp.issubdtype(aval.dtype, jnp.signedinteger),
|
||||
)
|
||||
bcast_args.append(bcast_arg)
|
||||
|
||||
result_type = _dtype_to_ir_type(jnp.dtype(extern.result_type))
|
||||
if out_aval.shape:
|
||||
result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
|
||||
return tt_dialect.extern_elementwise(
|
||||
result_type,
|
||||
bcast_args,
|
||||
libname="",
|
||||
libpath="",
|
||||
symbol=extern.symbol,
|
||||
pure=True,
|
||||
)
|
||||
return h.lower(ctx, *bcast_args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
_abs_dispatch_table = _make_dispatch_table(
|
||||
"abs",
|
||||
[
|
||||
_Extern(["int32"], "__nv_abs", "int32"),
|
||||
_Extern(["int64"], "__nv_llabs", "int64"),
|
||||
_Extern(["float32"], "__nv_fabsf", "float32"),
|
||||
_Extern(["float64"], "__nv_fabs", "float64"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers ``lax.abs_p``, trying the dispatch table before the Math dialect.

  ``_abs_dispatch_table`` is tried first. If it raises ``NotImplementedError``,
  the op is lowered with ``math_dialect.absi`` for integer dtypes or
  ``math_dialect.absf`` for floating dtypes instead.

  Args:
    ctx: Lowering context; ``ctx.avals_in`` must hold exactly one aval.
    x: The operand whose absolute value is taken.

  Returns:
    The result of the dispatch table, or of the Math dialect op.

  Raises:
    NotImplementedError: If the dispatch table rejects the input and its
      dtype is neither integer nor floating.
  """
  try:
    return _abs_dispatch_table(ctx, x)
  except NotImplementedError as e:
    [x_aval] = ctx.avals_in
    # Classify the dtype, not the aval object: ``jnp.issubdtype`` is defined
    # over dtypes, and the rest of this file passes ``aval.dtype`` to it.
    x_dtype = x_aval.dtype
    if jnp.issubdtype(x_dtype, jnp.integer):
      return math_dialect.absi(x)
    elif jnp.issubdtype(x_dtype, jnp.floating):
      return math_dialect.absf(x)
    else:
      # Re-raise the dispatch table's error without chaining this handler.
      raise e from None
|
||||
|
||||
|
||||
triton_lowering_rules[lax.abs_p] = _abs_lowering_rule
|
||||
|
||||
|
||||
triton_lowering_rules.update({
|
||||
lax.neg_p: lambda ctx, x: _minus(x),
|
||||
lax.abs_p: _extern_elementwise(
|
||||
"abs",
|
||||
[
|
||||
_Extern(["int32"], "__nv_abs", "int32"),
|
||||
_Extern(["int64"], "__nv_llabs", "int64"),
|
||||
_Extern(["float32"], "__nv_fabsf", "float32"),
|
||||
_Extern(["float64"], "__nv_fabs", "float64"),
|
||||
],
|
||||
),
|
||||
lax.ceil_p: _extern_elementwise(
|
||||
lax.ceil_p: _make_dispatch_table(
|
||||
"ceil",
|
||||
[
|
||||
_Extern(["float32"], "__nv_ceilf", "float32"),
|
||||
_Extern(["float64"], "__nv_ceil", "float64"),
|
||||
],
|
||||
),
|
||||
lax.floor_p: _extern_elementwise(
|
||||
lax.floor_p: _make_dispatch_table(
|
||||
"floor",
|
||||
[
|
||||
_Extern(["float32"], "__nv_floorf", "float32"),
|
||||
_Extern(["float64"], "__nv_floor", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)),
|
||||
],
|
||||
),
|
||||
lax.exp_p: _extern_elementwise(
|
||||
lax.exp_p: _make_dispatch_table(
|
||||
"exp",
|
||||
[
|
||||
_Extern(["float32"], "__nv_expf", "float32"),
|
||||
_Extern(["float64"], "__nv_exp", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)),
|
||||
],
|
||||
),
|
||||
lax.exp2_p: _extern_elementwise(
|
||||
lax.exp2_p: _make_dispatch_table(
|
||||
"exp2",
|
||||
[
|
||||
_Extern(["float32"], "__nv_exp2f", "float32"),
|
||||
_Extern(["float64"], "__nv_exp2", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)),
|
||||
],
|
||||
),
|
||||
lax.expm1_p: _extern_elementwise(
|
||||
lax.expm1_p: _make_dispatch_table(
|
||||
"expm1",
|
||||
[
|
||||
_Extern(["float32"], "__nv_expm1f", "float32"),
|
||||
_Extern(["float64"], "__nv_expm1", "float64"),
|
||||
],
|
||||
),
|
||||
lax.log_p: _extern_elementwise(
|
||||
lax.log_p: _make_dispatch_table(
|
||||
"log",
|
||||
[
|
||||
_Extern(["float32"], "__nv_logf", "float32"),
|
||||
_Extern(["float64"], "__nv_log", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.log(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)),
|
||||
],
|
||||
),
|
||||
lax.log1p_p: _extern_elementwise(
|
||||
lax.log1p_p: _make_dispatch_table(
|
||||
"log1p",
|
||||
[
|
||||
_Extern(["float32"], "__nv_log1pf", "float32"),
|
||||
_Extern(["float64"], "__nv_log1p", "float64"),
|
||||
],
|
||||
),
|
||||
lax.sqrt_p: _extern_elementwise(
|
||||
lax.sqrt_p: _make_dispatch_table(
|
||||
"sqrt",
|
||||
[
|
||||
_Extern(["float32"], "__nv_sqrtf", "float32"),
|
||||
_Extern(["float64"], "__nv_sqrt", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)),
|
||||
],
|
||||
),
|
||||
lax.pow_p: _extern_elementwise(
|
||||
lax.pow_p: _make_dispatch_table(
|
||||
"pow",
|
||||
[
|
||||
_Extern(["float32", "int32"], "__nv_powif", "float32"),
|
||||
@ -655,126 +695,130 @@ triton_lowering_rules.update({
|
||||
_Extern(["float64", "float64"], "__nv_pow", "float64"),
|
||||
],
|
||||
),
|
||||
lax.cbrt_p: _extern_elementwise(
|
||||
lax.cbrt_p: _make_dispatch_table(
|
||||
"cbrt",
|
||||
[
|
||||
_Extern(["float32"], "__nv_cbrtf", "float32"),
|
||||
_Extern(["float64"], "__nv_cbrt", "float64"),
|
||||
],
|
||||
),
|
||||
lax.rsqrt_p: _extern_elementwise(
|
||||
lax.rsqrt_p: _make_dispatch_table(
|
||||
"rsqrt",
|
||||
[
|
||||
_Extern(["float32"], "__nv_rsqrtf", "float32"),
|
||||
_Extern(["float64"], "__nv_rsqrt", "float64"),
|
||||
],
|
||||
),
|
||||
lax.sin_p: _extern_elementwise(
|
||||
lax.sin_p: _make_dispatch_table(
|
||||
"sin",
|
||||
[
|
||||
_Extern(["float32"], "__nv_sinf", "float32"),
|
||||
_Extern(["float64"], "__nv_sin", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)),
|
||||
],
|
||||
),
|
||||
lax.cos_p: _extern_elementwise(
|
||||
lax.cos_p: _make_dispatch_table(
|
||||
"cos",
|
||||
[
|
||||
_Extern(["float32"], "__nv_cosf", "float32"),
|
||||
_Extern(["float64"], "__nv_cos", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)),
|
||||
],
|
||||
),
|
||||
lax.tan_p: _extern_elementwise(
|
||||
lax.tan_p: _make_dispatch_table(
|
||||
"tan",
|
||||
[
|
||||
_Extern(["float32"], "__nv_tanf", "float32"),
|
||||
_Extern(["float64"], "__nv_tan", "float64"),
|
||||
],
|
||||
),
|
||||
lax.asin_p: _extern_elementwise(
|
||||
lax.asin_p: _make_dispatch_table(
|
||||
"asin",
|
||||
[
|
||||
_Extern(["float32"], "__nv_asinf", "float32"),
|
||||
_Extern(["float64"], "__nv_asin", "float64"),
|
||||
],
|
||||
),
|
||||
lax.acos_p: _extern_elementwise(
|
||||
lax.acos_p: _make_dispatch_table(
|
||||
"acos",
|
||||
[
|
||||
_Extern(["float32"], "__nv_acosf", "float32"),
|
||||
_Extern(["float64"], "__nv_acos", "float64"),
|
||||
],
|
||||
),
|
||||
lax.atan_p: _extern_elementwise(
|
||||
lax.atan_p: _make_dispatch_table(
|
||||
"atan",
|
||||
[
|
||||
_Extern(["float32"], "__nv_atanf", "float32"),
|
||||
_Extern(["float64"], "__nv_atan", "float64"),
|
||||
],
|
||||
),
|
||||
lax.atan2_p: _extern_elementwise(
|
||||
lax.atan2_p: _make_dispatch_table(
|
||||
"atan2",
|
||||
[
|
||||
_Extern(["float32", "float32"], "__nv_atan2f", "float32"),
|
||||
_Extern(["float64", "float64"], "__nv_atan2", "float64"),
|
||||
],
|
||||
),
|
||||
lax.sinh_p: _extern_elementwise(
|
||||
lax.sinh_p: _make_dispatch_table(
|
||||
"sinh",
|
||||
[
|
||||
_Extern(["float32"], "__nv_sinhf", "float32"),
|
||||
_Extern(["float64"], "__nv_sinh", "float64"),
|
||||
],
|
||||
),
|
||||
lax.cosh_p: _extern_elementwise(
|
||||
lax.cosh_p: _make_dispatch_table(
|
||||
"cosh",
|
||||
[
|
||||
_Extern(["float32"], "__nv_coshf", "float32"),
|
||||
_Extern(["float64"], "__nv_cosh", "float64"),
|
||||
],
|
||||
),
|
||||
lax.tanh_p: _extern_elementwise(
|
||||
lax.tanh_p: _make_dispatch_table(
|
||||
"tanh",
|
||||
[
|
||||
_Extern(["float32"], "__nv_tanhf", "float32"),
|
||||
_Extern(["float64"], "__nv_tanh", "float64"),
|
||||
],
|
||||
),
|
||||
lax.asinh_p: _extern_elementwise(
|
||||
lax.asinh_p: _make_dispatch_table(
|
||||
"asinh",
|
||||
[
|
||||
_Extern(["float32"], "__nv_asinhf", "float32"),
|
||||
_Extern(["float64"], "__nv_asinh", "float64"),
|
||||
],
|
||||
),
|
||||
lax.acosh_p: _extern_elementwise(
|
||||
lax.acosh_p: _make_dispatch_table(
|
||||
"acosh",
|
||||
[
|
||||
_Extern(["float32"], "__nv_acoshf", "float32"),
|
||||
_Extern(["float64"], "__nv_acosh", "float64"),
|
||||
],
|
||||
),
|
||||
lax.atanh_p: _extern_elementwise(
|
||||
lax.atanh_p: _make_dispatch_table(
|
||||
"atanh",
|
||||
[
|
||||
_Extern(["float32"], "__nv_atanhf", "float32"),
|
||||
_Extern(["float64"], "__nv_atanh", "float64"),
|
||||
],
|
||||
),
|
||||
lax.population_count_p: _extern_elementwise(
|
||||
lax.population_count_p: _make_dispatch_table(
|
||||
"population_count",
|
||||
[
|
||||
_Extern(["int32"], "__nv_popc", "int32"),
|
||||
_Extern(["int64"], "__nv_popcll", "int32"),
|
||||
],
|
||||
),
|
||||
lax.clz_p: _extern_elementwise(
|
||||
lax.clz_p: _make_dispatch_table(
|
||||
"clz",
|
||||
[
|
||||
_Extern(["int32"], "__nv_clz", "int32"),
|
||||
_Extern(["int64"], "__nv_clzll", "int32"),
|
||||
],
|
||||
),
|
||||
lax.nextafter_p: _extern_elementwise(
|
||||
lax.nextafter_p: _make_dispatch_table(
|
||||
"nextafter",
|
||||
[
|
||||
_Extern(["float32", "float32"], "__nv_nextafterf", "float32"),
|
||||
|
@ -1524,15 +1524,22 @@ class PallasCallInterpreterVmapTest(PallasCallVmapTest):
|
||||
class PallasOpsTest(PallasTest):
|
||||
|
||||
ELEMENTWISE_OPS = [
|
||||
([jnp.abs, jnp.negative], ["int32", "int64", "float32", "float64"]),
|
||||
(
|
||||
[jnp.abs, jnp.negative],
|
||||
["int16", "int32", "int64", "float16", "float32", "float64"],
|
||||
),
|
||||
([jnp.ceil, jnp.floor], ["float32", "float64"]),
|
||||
(
|
||||
[jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
|
||||
["float16", "float32", "float64"],
|
||||
),
|
||||
(
|
||||
# fmt: off
|
||||
[jnp.ceil, jnp.floor, jnp.exp, jnp.exp2, jnp.expm1, jnp.log1p,
|
||||
jnp.sqrt, jnp.cbrt, lax.rsqrt, jnp.sin, jnp.cos, jnp.tan, jnp.asin,
|
||||
jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh,
|
||||
jnp.acosh, jnp.atanh],
|
||||
[jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt, jnp.tan, jnp.asin,
|
||||
jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.asinh, jnp.acosh,
|
||||
jnp.atanh],
|
||||
# fmt: on
|
||||
["float32", "float64"]
|
||||
["float32", "float64"],
|
||||
),
|
||||
([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]),
|
||||
]
|
||||
@ -1552,7 +1559,7 @@ class PallasOpsTest(PallasTest):
|
||||
with contextlib.ExitStack() as stack:
|
||||
if jnp.dtype(dtype).itemsize == 8:
|
||||
stack.enter_context(config.enable_x64(True))
|
||||
x = jnp.array([4.2, 2.4]).astype(dtype)
|
||||
x = jnp.array([0.42, 2.4]).astype(dtype)
|
||||
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -1611,7 +1618,7 @@ class PallasOpsTest(PallasTest):
|
||||
@parameterized.named_parameters(
|
||||
(f"{fn.__name__}_{dtype}", fn, dtype)
|
||||
for fn, dtype in itertools.product(
|
||||
COMPARISON_OPS, ["int32", "uint32", "float32"]
|
||||
COMPARISON_OPS, ["int32", "uint32", "float16", "float32"]
|
||||
)
|
||||
)
|
||||
def test_comparison(self, fn, dtype):
|
||||
@ -1674,17 +1681,15 @@ class PallasOpsTest(PallasTest):
|
||||
BINARY_OPS = [
|
||||
([jnp.floor_divide], ["int32", "uint32"]),
|
||||
(
|
||||
[jnp.add, jnp.subtract, jnp.multiply, jnp.remainder],
|
||||
["int32", "uint32", "float32"],
|
||||
[jnp.add, jnp.subtract, jnp.multiply],
|
||||
["int16", "int32", "uint32", "float16", "float32"],
|
||||
),
|
||||
([jnp.remainder], ["int32", "uint32", "float32"]),
|
||||
(
|
||||
[
|
||||
jnp.bitwise_and,
|
||||
jnp.bitwise_or,
|
||||
jnp.bitwise_xor,
|
||||
jnp.bitwise_left_shift,
|
||||
jnp.bitwise_right_shift,
|
||||
],
|
||||
# fmt: off
|
||||
[jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor,
|
||||
jnp.bitwise_left_shift, jnp.bitwise_right_shift],
|
||||
# fmt: on
|
||||
["int32", "uint32"],
|
||||
),
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user