[pallas:triton] Fallback lowering rules for math functions now use general dtypes

Previously, it was necessary to list all dtypes explicitly, which is why
we had separate fallback rules for float16 and bfloat16 for some functions.

PiperOrigin-RevId: 722729554
This commit is contained in:
Sergei Lebedev 2025-02-03 11:18:41 -08:00 committed by jax authors
parent 2d7e4ab2dc
commit bf6489ff5b

View File

@ -593,12 +593,19 @@ class _Extern:
def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
[out_aval] = ctx.avals_out
bcast_args = []
for aval, arg, arg_type in zip(ctx.avals_in, args, self.arg_types):
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
bcast_args.append(bcast_arg)
result_type = _dtype_to_ir_type(jnp.dtype(self.result_type))
if out_aval.shape:
result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
return tt_dialect.extern_elementwise(
result_type,
args,
bcast_args,
libname="",
libpath="",
symbol=self.symbol,
@ -608,10 +615,23 @@ class _Extern:
@dataclasses.dataclass(frozen=True)
class _Fallback:
arg_types: Sequence[jax.typing.DTypeLike]
lower: Callable[..., ir.Value]
arg_classes: Sequence[jax.typing.DTypeLike]
op: Callable[..., ir.Value]
matches = _Extern.matches
def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
  """Reports whether this fallback applies to the given input avals.

  Returns False when the number of avals differs from the number of
  ``arg_classes``. Otherwise returns True only if the dtype of every aval
  is a subdtype (per ``jnp.issubdtype``) of the class at the same position.
  """
  if len(self.arg_classes) != len(avals):
    return False
  for aval, arg_class in zip(avals, self.arg_classes):
    if not jnp.issubdtype(aval.dtype, arg_class):
      return False
  return True
def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
  """Lowers the primitive by applying ``self.op`` to the broadcast operands.

  Each argument is converted with ``_ensure_ir_value`` using its input aval
  and then broadcast with ``_bcast_to`` to the shape of the single output
  aval.

  Args:
    ctx: Lowering context; ``avals_in`` and ``avals_out`` are read from it.
    *args: Operands of the primitive, in the same order as ``ctx.avals_in``.

  Returns:
    Whatever ``self.op`` returns for the broadcast operands.
  """
  [out_aval] = ctx.avals_out
  bcast_args = [
      _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
      for aval, arg in zip(ctx.avals_in, args)
  ]
  # Pass the broadcast operands, not the raw ``args``: the original computed
  # ``bcast_args`` and then dropped them, unlike ``_Extern.lower``.
  return self.op(*bcast_args)
def _make_dispatch_table(
@ -626,390 +646,452 @@ def _make_dispatch_table(
raise NotImplementedError(
f"unsupported types for {name}: {arg_aval_dtypes}"
)
[out_aval] = ctx.avals_out
bcast_args = []
for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
bcast_args.append(bcast_arg)
return h.lower(ctx, *bcast_args)
return h.lower(ctx, *args)
return inner
_abs_dispatch_table = _make_dispatch_table(
abs_dispatch_table = _make_dispatch_table(
"abs",
cuda=[
_Extern([jnp.int32], "__nv_abs", jnp.int32),
_Extern([jnp.int64], "__nv_llabs", jnp.int64),
_Extern([jnp.float32], "__nv_fabsf", jnp.float32),
_Extern([jnp.float64], "__nv_fabs", jnp.float64),
_Fallback([jnp.integer], math_dialect.absi),
_Fallback([jnp.floating], math_dialect.absf),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.absi(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.absi(x)),
_Fallback([jnp.float32], lambda ctx, x: math_dialect.absf(x)),
_Fallback([jnp.float64], lambda ctx, x: math_dialect.absf(x)),
_Fallback([jnp.integer], math_dialect.absi),
_Fallback([jnp.floating], math_dialect.absf),
],
)
ceil_dispatch_table = _make_dispatch_table(
"ceil",
cuda=[
_Extern([jnp.float32], "__nv_ceilf", jnp.float32),
_Extern([jnp.float64], "__nv_ceil", jnp.float64),
_Fallback([jnp.floating], math_dialect.ceil),
],
rocm=[
_Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.ceil),
],
)
@register_lowering(lax.abs_p)
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
try:
return _abs_dispatch_table(ctx, x)
except NotImplementedError as e:
[x_aval] = ctx.avals_in
if jnp.issubdtype(x_aval, jnp.integer):
return math_dialect.absi(x)
elif jnp.issubdtype(x_aval, jnp.floating):
return math_dialect.absf(x)
else:
raise e from None
floor_dispatch_table = _make_dispatch_table(
"floor",
cuda=[
_Extern([jnp.float32], "__nv_floorf", jnp.float32),
_Extern([jnp.float64], "__nv_floor", jnp.float64),
_Fallback([jnp.floating], math_dialect.floor),
],
rocm=[
_Extern([jnp.float32], "__ocml_floor_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_floor_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.floor),
],
)
# Lowering table for "exp". On CUDA, float32/float64 use the listed extern
# symbols and any other floating dtype falls back to ``math_dialect.exp``.
# On ROCm every floating dtype uses ``math_dialect.exp``; the former explicit
# float32/float64 fallbacks mapped to the same op and were redundant with the
# ``jnp.floating`` entry.
exp_dispatch_table = _make_dispatch_table(
    "exp",
    cuda=[
        _Extern([jnp.float32], "__nv_expf", jnp.float32),
        _Extern([jnp.float64], "__nv_exp", jnp.float64),
        _Fallback([jnp.floating], math_dialect.exp),
    ],
    rocm=[
        _Fallback([jnp.floating], math_dialect.exp),
    ],
)
exp2_dispatch_table = _make_dispatch_table(
"exp2",
cuda=[
_Extern([jnp.float32], "__nv_exp2f", jnp.float32),
_Extern([jnp.float64], "__nv_exp2", jnp.float64),
_Fallback([jnp.floating], math_dialect.exp2),
],
rocm=[
_Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.exp2),
],
)
expm1_dispatch_table = _make_dispatch_table(
"expm1",
cuda=[
_Extern([jnp.float32], "__nv_expm1f", jnp.float32),
_Extern([jnp.float64], "__nv_expm1", jnp.float64),
_Fallback([jnp.floating], math_dialect.expm1),
],
rocm=[
_Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.expm1),
],
)
log_dispatch_table = _make_dispatch_table(
"log",
cuda=[
_Extern([jnp.float32], "__nv_logf", jnp.float32),
_Extern([jnp.float64], "__nv_log", jnp.float64),
_Fallback([jnp.floating], math_dialect.log),
],
rocm=[
_Extern([jnp.float32], "__ocml_log_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.log),
],
)
log1p_dispatch_table = _make_dispatch_table(
"log1p",
cuda=[
_Extern([jnp.float32], "__nv_log1pf", jnp.float32),
_Extern([jnp.float64], "__nv_log1p", jnp.float64),
_Fallback([jnp.floating], math_dialect.log1p),
],
rocm=[
_Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.log1p),
],
)
sqrt_dispatch_table = _make_dispatch_table(
"sqrt",
cuda=[
_Extern([jnp.float32], "__nv_sqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_sqrt", jnp.float64),
_Fallback([jnp.floating], math_dialect.sqrt),
],
rocm=[
_Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.sqrt),
],
)
# Lowering table for "pow". Entries with a (floating, integer) signature
# handle integer exponents; entries with a (floating, floating) signature
# handle floating exponents. The fallbacks pass the math dialect ops directly:
# ``_Fallback.lower`` calls ``op`` with the operands only (no ``ctx``), so a
# ``lambda ctx, x, y: ...`` wrapper would be called with one argument too few.
pow_dispatch_table = _make_dispatch_table(
    "pow",
    cuda=[
        _Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32),
        _Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
        _Fallback([jnp.floating, jnp.integer], math_dialect.fpowi),
        _Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
        _Fallback([jnp.floating, jnp.floating], math_dialect.powf),
    ],
    rocm=[
        _Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32),
        _Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
        _Fallback([jnp.floating, jnp.integer], math_dialect.fpowi),
        _Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
        _Fallback([jnp.floating, jnp.floating], math_dialect.powf),
    ],
)
cbrt_dispatch_table = _make_dispatch_table(
"cbrt",
cuda=[
_Extern([jnp.float32], "__nv_cbrtf", jnp.float32),
_Extern([jnp.float64], "__nv_cbrt", jnp.float64),
_Fallback([jnp.floating], math_dialect.cbrt),
],
rocm=[
_Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.cbrt),
],
)
rsqrt_dispatch_table = _make_dispatch_table(
"rsqrt",
cuda=[
_Extern([jnp.float32], "__nv_rsqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_rsqrt", jnp.float64),
_Fallback([jnp.floating], math_dialect.rsqrt),
],
rocm=[
_Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.rsqrt),
],
)
sin_dispatch_table = _make_dispatch_table(
"sin",
cuda=[
_Extern([jnp.float32], "__nv_sinf", jnp.float32),
_Extern([jnp.float64], "__nv_sin", jnp.float64),
_Fallback([jnp.floating], math_dialect.sin),
],
rocm=[
_Extern([jnp.float32], "__ocml_sin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sin_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.sin),
],
)
cos_dispatch_table = _make_dispatch_table(
"cos",
cuda=[
_Extern([jnp.float32], "__nv_cosf", jnp.float32),
_Extern([jnp.float64], "__nv_cos", jnp.float64),
_Fallback([jnp.floating], math_dialect.cos),
],
rocm=[
_Extern([jnp.float32], "__ocml_cos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cos_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.cos),
],
)
tan_dispatch_table = _make_dispatch_table(
"tan",
cuda=[
_Extern([jnp.float32], "__nv_tanf", jnp.float32),
_Extern([jnp.float64], "__nv_tan", jnp.float64),
_Fallback([jnp.floating], math_dialect.tan),
],
rocm=[
_Extern([jnp.float32], "__ocml_tan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tan_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.tan),
],
)
asin_dispatch_table = _make_dispatch_table(
"asin",
cuda=[
_Extern([jnp.float32], "__nv_asinf", jnp.float32),
_Extern([jnp.float64], "__nv_asin", jnp.float64),
_Fallback([jnp.floating], math_dialect.asin),
],
rocm=[
_Extern([jnp.float32], "__ocml_asin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asin_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.asin),
],
)
acos_dispatch_table = _make_dispatch_table(
"acos",
cuda=[
_Extern([jnp.float32], "__nv_acosf", jnp.float32),
_Extern([jnp.float64], "__nv_acos", jnp.float64),
_Fallback([jnp.floating], math_dialect.acos),
],
rocm=[
_Extern([jnp.float32], "__ocml_acos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acos_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.acos),
],
)
atan_dispatch_table = _make_dispatch_table(
"atan",
cuda=[
_Extern([jnp.float32], "__nv_atanf", jnp.float32),
_Extern([jnp.float64], "__nv_atan", jnp.float64),
_Fallback([jnp.floating], math_dialect.atan),
],
rocm=[
_Extern([jnp.float32], "__ocml_atan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atan_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.atan),
],
)
atan2_dispatch_table = _make_dispatch_table(
"atan2",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64),
_Fallback([jnp.floating, jnp.floating], math_dialect.atan2),
],
rocm=[
_Extern([jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64),
_Fallback([jnp.floating, jnp.floating], math_dialect.atan2),
],
)
sinh_dispatch_table = _make_dispatch_table(
"sinh",
cuda=[
_Extern([jnp.float32], "__nv_sinhf", jnp.float32),
_Extern([jnp.float64], "__nv_sinh", jnp.float64),
_Fallback([jnp.floating], math_dialect.sinh),
],
rocm=[
_Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.sinh),
],
)
cosh_dispatch_table = _make_dispatch_table(
"cosh",
cuda=[
_Extern([jnp.float32], "__nv_coshf", jnp.float32),
_Extern([jnp.float64], "__nv_cosh", jnp.float64),
_Fallback([jnp.floating], math_dialect.cosh),
],
rocm=[
_Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.cosh),
],
)
tanh_dispatch_table = _make_dispatch_table(
"tanh",
cuda=[
_Extern([jnp.float32], "__nv_tanhf", jnp.float32),
_Extern([jnp.float64], "__nv_tanh", jnp.float64),
_Fallback([jnp.floating], math_dialect.tanh),
],
rocm=[
_Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.tanh),
],
)
asinh_dispatch_table = _make_dispatch_table(
"asinh",
cuda=[
_Extern([jnp.float32], "__nv_asinhf", jnp.float32),
_Extern([jnp.float64], "__nv_asinh", jnp.float64),
_Fallback([jnp.floating], math_dialect.asinh),
],
rocm=[
_Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.asinh),
],
)
acosh_dispatch_table = _make_dispatch_table(
"acosh",
cuda=[
_Extern([jnp.float32], "__nv_acoshf", jnp.float32),
_Extern([jnp.float64], "__nv_acosh", jnp.float64),
_Fallback([jnp.floating], math_dialect.acosh),
],
rocm=[
_Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.acosh),
],
)
atanh_dispatch_table = _make_dispatch_table(
"atanh",
cuda=[
_Extern([jnp.float32], "__nv_atanhf", jnp.float32),
_Extern([jnp.float64], "__nv_atanh", jnp.float64),
_Fallback([jnp.floating], math_dialect.atanh),
],
rocm=[
_Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.atanh),
],
)
population_count_dispatch_table = _make_dispatch_table(
"population_count",
cuda=[
_Extern([jnp.int32], "__nv_popc", jnp.int32),
_Extern([jnp.int64], "__nv_popcll", jnp.int32),
_Fallback([jnp.integer], math_dialect.ctpop),
],
rocm=[
_Fallback([jnp.integer], math_dialect.ctpop),
],
)
clz_dispatch_table = _make_dispatch_table(
"clz",
cuda=[
_Extern([jnp.int32], "__nv_clz", jnp.int32),
_Extern([jnp.int64], "__nv_clzll", jnp.int32),
_Fallback([jnp.integer], math_dialect.ctlz),
],
rocm=[
_Fallback([jnp.integer], math_dialect.ctlz),
],
)
nextafter_dispatch_table = _make_dispatch_table(
"nextafter",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64),
],
rocm=[
_Extern(
[jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32
),
_Extern(
[jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64
),
],
)
triton_lowering_rules.update({
lax.abs_p: abs_dispatch_table,
lax.neg_p: lambda ctx, x: _minus(x),
lax.ceil_p: _make_dispatch_table(
"ceil",
cuda=[
_Extern([jnp.float32], "__nv_ceilf", jnp.float32),
_Extern([jnp.float64], "__nv_ceil", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64),
],
),
lax.floor_p: _make_dispatch_table(
"floor",
cuda=[
_Extern([jnp.float32], "__nv_floorf", jnp.float32),
_Extern([jnp.float64], "__nv_floor", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_floor_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_floor_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)),
],
),
lax.exp_p: _make_dispatch_table(
"exp",
cuda=[
_Extern([jnp.float32], "__nv_expf", jnp.float32),
_Extern([jnp.float64], "__nv_exp", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)),
],
rocm=[
_Fallback([jnp.float32], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.float64], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)),
],
),
lax.exp2_p: _make_dispatch_table(
"exp2",
cuda=[
_Extern([jnp.float32], "__nv_exp2f", jnp.float32),
_Extern([jnp.float64], "__nv_exp2", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)),
],
),
lax.expm1_p: _make_dispatch_table(
"expm1",
cuda=[
_Extern([jnp.float32], "__nv_expm1f", jnp.float32),
_Extern([jnp.float64], "__nv_expm1", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64),
],
),
lax.log_p: _make_dispatch_table(
"log",
cuda=[
_Extern([jnp.float32], "__nv_logf", jnp.float32),
_Extern([jnp.float64], "__nv_log", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_log_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)),
],
),
lax.log1p_p: _make_dispatch_table(
"log1p",
cuda=[
_Extern([jnp.float32], "__nv_log1pf", jnp.float32),
_Extern([jnp.float64], "__nv_log1p", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64),
],
),
lax.sqrt_p: _make_dispatch_table(
"sqrt",
cuda=[
_Extern([jnp.float32], "__nv_sqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_sqrt", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
],
),
lax.ceil_p: ceil_dispatch_table,
lax.floor_p: floor_dispatch_table,
lax.exp_p: exp_dispatch_table,
lax.exp2_p: exp2_dispatch_table,
lax.expm1_p: expm1_dispatch_table,
lax.log_p: log_dispatch_table,
lax.log1p_p: log1p_dispatch_table,
lax.sqrt_p: sqrt_dispatch_table,
lax.square_p: lambda ctx, x: _mul(x, x),
lax.pow_p: _make_dispatch_table(
"pow",
cuda=[
_Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32),
_Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
_Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
],
rocm=[
_Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32),
_Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
_Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
],
),
lax.cbrt_p: _make_dispatch_table(
"cbrt",
cuda=[
_Extern([jnp.float32], "__nv_cbrtf", jnp.float32),
_Extern([jnp.float64], "__nv_cbrt", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64),
],
),
lax.rsqrt_p: _make_dispatch_table(
"rsqrt",
cuda=[
_Extern([jnp.float32], "__nv_rsqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_rsqrt", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64),
],
),
lax.sin_p: _make_dispatch_table(
"sin",
cuda=[
_Extern([jnp.float32], "__nv_sinf", jnp.float32),
_Extern([jnp.float64], "__nv_sin", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_sin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sin_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)),
],
),
lax.cos_p: _make_dispatch_table(
"cos",
cuda=[
_Extern([jnp.float32], "__nv_cosf", jnp.float32),
_Extern([jnp.float64], "__nv_cos", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_cos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cos_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)),
],
),
lax.tan_p: _make_dispatch_table(
"tan",
cuda=[
_Extern([jnp.float32], "__nv_tanf", jnp.float32),
_Extern([jnp.float64], "__nv_tan", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_tan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tan_f64", jnp.float64),
],
),
lax.asin_p: _make_dispatch_table(
"asin",
cuda=[
_Extern([jnp.float32], "__nv_asinf", jnp.float32),
_Extern([jnp.float64], "__nv_asin", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_asin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asin_f64", jnp.float64),
],
),
lax.acos_p: _make_dispatch_table(
"acos",
cuda=[
_Extern([jnp.float32], "__nv_acosf", jnp.float32),
_Extern([jnp.float64], "__nv_acos", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_acos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acos_f64", jnp.float64),
],
),
lax.atan_p: _make_dispatch_table(
"atan",
cuda=[
_Extern([jnp.float32], "__nv_atanf", jnp.float32),
_Extern([jnp.float64], "__nv_atan", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_atan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atan_f64", jnp.float64),
],
),
lax.atan2_p: _make_dispatch_table(
"atan2",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64),
],
rocm=[
_Extern(
[jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32
),
_Extern(
[jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64
),
],
),
lax.sinh_p: _make_dispatch_table(
"sinh",
cuda=[
_Extern([jnp.float32], "__nv_sinhf", jnp.float32),
_Extern([jnp.float64], "__nv_sinh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64),
],
),
lax.cosh_p: _make_dispatch_table(
"cosh",
cuda=[
_Extern([jnp.float32], "__nv_coshf", jnp.float32),
_Extern([jnp.float64], "__nv_cosh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64),
],
),
lax.tanh_p: _make_dispatch_table(
"tanh",
cuda=[
_Extern([jnp.float32], "__nv_tanhf", jnp.float32),
_Extern([jnp.float64], "__nv_tanh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64),
],
),
lax.asinh_p: _make_dispatch_table(
"asinh",
cuda=[
_Extern([jnp.float32], "__nv_asinhf", jnp.float32),
_Extern([jnp.float64], "__nv_asinh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64),
],
),
lax.acosh_p: _make_dispatch_table(
"acosh",
cuda=[
_Extern([jnp.float32], "__nv_acoshf", jnp.float32),
_Extern([jnp.float64], "__nv_acosh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64),
],
),
lax.atanh_p: _make_dispatch_table(
"atanh",
cuda=[
_Extern([jnp.float32], "__nv_atanhf", jnp.float32),
_Extern([jnp.float64], "__nv_atanh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64),
],
),
lax.population_count_p: _make_dispatch_table(
"population_count",
cuda=[
_Extern([jnp.int32], "__nv_popc", jnp.int32),
_Extern([jnp.int64], "__nv_popcll", jnp.int32),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.ctpop(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.ctpop(x)),
],
),
lax.clz_p: _make_dispatch_table(
"clz",
cuda=[
_Extern([jnp.int32], "__nv_clz", jnp.int32),
_Extern([jnp.int64], "__nv_clzll", jnp.int32),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.ctlz(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.ctlz(x)),
],
),
lax.nextafter_p: _make_dispatch_table(
"nextafter",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64),
],
rocm=[
_Extern(
[jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32
),
_Extern(
[jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64
),
],
),
lax.pow_p: pow_dispatch_table,
lax.cbrt_p: cbrt_dispatch_table,
lax.rsqrt_p: rsqrt_dispatch_table,
lax.sin_p: sin_dispatch_table,
lax.cos_p: cos_dispatch_table,
lax.tan_p: tan_dispatch_table,
lax.asin_p: asin_dispatch_table,
lax.acos_p: acos_dispatch_table,
lax.atan_p: atan_dispatch_table,
lax.atan2_p: atan2_dispatch_table,
lax.sinh_p: sinh_dispatch_table,
lax.cosh_p: cosh_dispatch_table,
lax.tanh_p: tanh_dispatch_table,
lax.asinh_p: asinh_dispatch_table,
lax.acosh_p: acosh_dispatch_table,
lax.atanh_p: atanh_dispatch_table,
lax.population_count_p: population_count_dispatch_table,
lax.clz_p: clz_dispatch_table,
lax.nextafter_p: nextafter_dispatch_table,
})