mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[pallas:triton] Fallback lowering rules for math functions now use general dtypes
Previously, it was necessary to list all dtypes explicitly, which is why we had separate fallback rules for float16 and bfloat16 for some functions.

PiperOrigin-RevId: 722729554
This commit is contained in:
parent
2d7e4ab2dc
commit
bf6489ff5b
@ -593,12 +593,19 @@ class _Extern:
|
||||
|
||||
def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
|
||||
[out_aval] = ctx.avals_out
|
||||
bcast_args = []
|
||||
for aval, arg, arg_type in zip(ctx.avals_in, args, self.arg_types):
|
||||
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
|
||||
if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
|
||||
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
|
||||
bcast_args.append(bcast_arg)
|
||||
|
||||
result_type = _dtype_to_ir_type(jnp.dtype(self.result_type))
|
||||
if out_aval.shape:
|
||||
result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
|
||||
return tt_dialect.extern_elementwise(
|
||||
result_type,
|
||||
args,
|
||||
bcast_args,
|
||||
libname="",
|
||||
libpath="",
|
||||
symbol=self.symbol,
|
||||
@ -608,10 +615,23 @@ class _Extern:
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class _Fallback:
  """Dispatch-table entry that lowers by calling a plain op builder.

  An entry is keyed on dtype *classes* (for example ``jnp.floating``), so a
  single entry accepts every dtype that ``jnp.issubdtype`` places under the
  class.
  """

  # One dtype class per positional operand; tested with `jnp.issubdtype`.
  arg_classes: Sequence[jax.typing.DTypeLike]
  # Called with the broadcast operands only -- no lowering context is passed.
  op: Callable[..., ir.Value]

  def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
    """Returns True if ``avals`` line up one-to-one with ``arg_classes``.

    Args:
      avals: abstract values of the operands being lowered.

    Returns:
      False when the operand count differs from ``len(arg_classes)``;
      otherwise whether every ``aval.dtype`` is a sub-dtype of the class at
      the same position.
    """
    if len(avals) != len(self.arg_classes):
      return False
    return all(
        jnp.issubdtype(aval.dtype, arg_class)
        for aval, arg_class in zip(avals, self.arg_classes)
    )

  def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
    """Calls ``op`` on the operands after broadcasting them to the output shape.

    Args:
      ctx: lowering context; supplies ``avals_in`` and the single output aval.
      *args: operands, paired positionally with ``ctx.avals_in``.

    Returns:
      Whatever ``self.op`` returns for the broadcast operands.
    """
    [out_aval] = ctx.avals_out
    # Each operand goes through `_ensure_ir_value` and is broadcast to the
    # output shape; `op` must receive these values, not the raw `args`.
    bcast_args = [
        _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
        for aval, arg in zip(ctx.avals_in, args)
    ]
    return self.op(*bcast_args)
|
||||
|
||||
|
||||
def _make_dispatch_table(
|
||||
@ -626,390 +646,452 @@ def _make_dispatch_table(
|
||||
raise NotImplementedError(
|
||||
f"unsupported types for {name}: {arg_aval_dtypes}"
|
||||
)
|
||||
|
||||
[out_aval] = ctx.avals_out
|
||||
bcast_args = []
|
||||
for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
|
||||
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
|
||||
if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
|
||||
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
|
||||
bcast_args.append(bcast_arg)
|
||||
return h.lower(ctx, *bcast_args)
|
||||
return h.lower(ctx, *args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
# Candidates for lowering `abs`. The i32/i64/f32/f64 CUDA rows name extern
# symbols; the `_Fallback` rows accept any integer / floating dtype and call
# the op builder directly with the operand.
abs_dispatch_table = _make_dispatch_table(
    "abs",
    cuda=[
        _Extern([jnp.int32], "__nv_abs", jnp.int32),
        _Extern([jnp.int64], "__nv_llabs", jnp.int64),
        _Extern([jnp.float32], "__nv_fabsf", jnp.float32),
        _Extern([jnp.float64], "__nv_fabs", jnp.float64),
        _Fallback([jnp.integer], math_dialect.absi),
        _Fallback([jnp.floating], math_dialect.absf),
    ],
    rocm=[
        _Fallback([jnp.integer], math_dialect.absi),
        _Fallback([jnp.floating], math_dialect.absf),
    ],
)
|
||||
|
||||
ceil_dispatch_table = _make_dispatch_table(
|
||||
"ceil",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_ceilf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_ceil", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.ceil),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.ceil),
|
||||
],
|
||||
)
|
||||
|
||||
@register_lowering(lax.abs_p)
|
||||
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
try:
|
||||
return _abs_dispatch_table(ctx, x)
|
||||
except NotImplementedError as e:
|
||||
[x_aval] = ctx.avals_in
|
||||
if jnp.issubdtype(x_aval, jnp.integer):
|
||||
return math_dialect.absi(x)
|
||||
elif jnp.issubdtype(x_aval, jnp.floating):
|
||||
return math_dialect.absf(x)
|
||||
else:
|
||||
raise e from None
|
||||
floor_dispatch_table = _make_dispatch_table(
|
||||
"floor",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_floorf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_floor", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.floor),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_floor_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_floor_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.floor),
|
||||
],
|
||||
)
|
||||
|
||||
# Candidates for lowering `exp`. On ROCm a single class-based fallback is
# enough: `jnp.floating` already admits float32 and float64, and all rows
# would call the same `math_dialect.exp` builder.
exp_dispatch_table = _make_dispatch_table(
    "exp",
    cuda=[
        _Extern([jnp.float32], "__nv_expf", jnp.float32),
        _Extern([jnp.float64], "__nv_exp", jnp.float64),
        _Fallback([jnp.floating], math_dialect.exp),
    ],
    rocm=[
        _Fallback([jnp.floating], math_dialect.exp),
    ],
)
|
||||
|
||||
exp2_dispatch_table = _make_dispatch_table(
|
||||
"exp2",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_exp2f", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_exp2", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.exp2),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.exp2),
|
||||
],
|
||||
)
|
||||
|
||||
expm1_dispatch_table = _make_dispatch_table(
|
||||
"expm1",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_expm1f", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_expm1", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.expm1),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.expm1),
|
||||
],
|
||||
)
|
||||
|
||||
log_dispatch_table = _make_dispatch_table(
|
||||
"log",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_logf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_log", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.log),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_log_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_log_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.log),
|
||||
],
|
||||
)
|
||||
|
||||
log1p_dispatch_table = _make_dispatch_table(
|
||||
"log1p",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_log1pf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_log1p", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.log1p),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.log1p),
|
||||
],
|
||||
)
|
||||
|
||||
sqrt_dispatch_table = _make_dispatch_table(
|
||||
"sqrt",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_sqrtf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_sqrt", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.sqrt),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.sqrt),
|
||||
],
|
||||
)
|
||||
|
||||
# Candidates for lowering `pow`, split by exponent kind (integer vs. float).
# `_Fallback.lower` calls its op with the operands only, so the fallbacks pass
# the two-operand builders directly instead of a `(ctx, x, y)` lambda.
pow_dispatch_table = _make_dispatch_table(
    "pow",
    cuda=[
        _Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32),
        _Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
        _Fallback([jnp.floating, jnp.integer], math_dialect.fpowi),
        _Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
        _Fallback([jnp.floating, jnp.floating], math_dialect.powf),
    ],
    rocm=[
        _Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32),
        _Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
        _Fallback([jnp.floating, jnp.integer], math_dialect.fpowi),
        _Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
        _Fallback([jnp.floating, jnp.floating], math_dialect.powf),
    ],
)
|
||||
|
||||
cbrt_dispatch_table = _make_dispatch_table(
|
||||
"cbrt",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_cbrtf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_cbrt", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.cbrt),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.cbrt),
|
||||
],
|
||||
)
|
||||
|
||||
rsqrt_dispatch_table = _make_dispatch_table(
|
||||
"rsqrt",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_rsqrtf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_rsqrt", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.rsqrt),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.rsqrt),
|
||||
],
|
||||
)
|
||||
|
||||
sin_dispatch_table = _make_dispatch_table(
|
||||
"sin",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_sinf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_sin", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.sin),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_sin_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_sin_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.sin),
|
||||
],
|
||||
)
|
||||
|
||||
cos_dispatch_table = _make_dispatch_table(
|
||||
"cos",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_cosf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_cos", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.cos),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_cos_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_cos_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.cos),
|
||||
],
|
||||
)
|
||||
|
||||
tan_dispatch_table = _make_dispatch_table(
|
||||
"tan",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_tanf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_tan", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.tan),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_tan_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_tan_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.tan),
|
||||
],
|
||||
)
|
||||
|
||||
asin_dispatch_table = _make_dispatch_table(
|
||||
"asin",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_asinf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_asin", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.asin),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_asin_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_asin_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.asin),
|
||||
],
|
||||
)
|
||||
|
||||
acos_dispatch_table = _make_dispatch_table(
|
||||
"acos",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_acosf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_acos", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.acos),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_acos_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_acos_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.acos),
|
||||
],
|
||||
)
|
||||
|
||||
atan_dispatch_table = _make_dispatch_table(
|
||||
"atan",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_atanf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_atan", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.atan),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_atan_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_atan_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.atan),
|
||||
],
|
||||
)
|
||||
|
||||
atan2_dispatch_table = _make_dispatch_table(
|
||||
"atan2",
|
||||
cuda=[
|
||||
_Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32),
|
||||
_Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64),
|
||||
_Fallback([jnp.floating, jnp.floating], math_dialect.atan2),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32),
|
||||
_Extern([jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64),
|
||||
_Fallback([jnp.floating, jnp.floating], math_dialect.atan2),
|
||||
],
|
||||
)
|
||||
|
||||
sinh_dispatch_table = _make_dispatch_table(
|
||||
"sinh",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_sinhf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_sinh", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.sinh),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.sinh),
|
||||
],
|
||||
)
|
||||
|
||||
cosh_dispatch_table = _make_dispatch_table(
|
||||
"cosh",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_coshf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_cosh", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.cosh),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.cosh),
|
||||
],
|
||||
)
|
||||
|
||||
tanh_dispatch_table = _make_dispatch_table(
|
||||
"tanh",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_tanhf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_tanh", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.tanh),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.tanh),
|
||||
],
|
||||
)
|
||||
|
||||
asinh_dispatch_table = _make_dispatch_table(
|
||||
"asinh",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_asinhf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_asinh", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.asinh),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.asinh),
|
||||
],
|
||||
)
|
||||
|
||||
acosh_dispatch_table = _make_dispatch_table(
|
||||
"acosh",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_acoshf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_acosh", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.acosh),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.acosh),
|
||||
],
|
||||
)
|
||||
|
||||
atanh_dispatch_table = _make_dispatch_table(
|
||||
"atanh",
|
||||
cuda=[
|
||||
_Extern([jnp.float32], "__nv_atanhf", jnp.float32),
|
||||
_Extern([jnp.float64], "__nv_atanh", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.atanh),
|
||||
],
|
||||
rocm=[
|
||||
_Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32),
|
||||
_Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64),
|
||||
_Fallback([jnp.floating], math_dialect.atanh),
|
||||
],
|
||||
)
|
||||
|
||||
population_count_dispatch_table = _make_dispatch_table(
|
||||
"population_count",
|
||||
cuda=[
|
||||
_Extern([jnp.int32], "__nv_popc", jnp.int32),
|
||||
_Extern([jnp.int64], "__nv_popcll", jnp.int32),
|
||||
_Fallback([jnp.integer], math_dialect.ctpop),
|
||||
],
|
||||
rocm=[
|
||||
_Fallback([jnp.integer], math_dialect.ctpop),
|
||||
],
|
||||
)
|
||||
|
||||
clz_dispatch_table = _make_dispatch_table(
|
||||
"clz",
|
||||
cuda=[
|
||||
_Extern([jnp.int32], "__nv_clz", jnp.int32),
|
||||
_Extern([jnp.int64], "__nv_clzll", jnp.int32),
|
||||
_Fallback([jnp.integer], math_dialect.ctlz),
|
||||
],
|
||||
rocm=[
|
||||
_Fallback([jnp.integer], math_dialect.ctlz),
|
||||
],
|
||||
)
|
||||
|
||||
nextafter_dispatch_table = _make_dispatch_table(
|
||||
"nextafter",
|
||||
cuda=[
|
||||
_Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32),
|
||||
_Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64),
|
||||
],
|
||||
rocm=[
|
||||
_Extern(
|
||||
[jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32
|
||||
),
|
||||
_Extern(
|
||||
[jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Register the elementwise lowering rules. Every primitive is listed exactly
# once; the math primitives point at the module-level dispatch tables rather
# than at inline copies of those tables.
triton_lowering_rules.update({
    lax.abs_p: abs_dispatch_table,
    lax.neg_p: lambda ctx, x: _minus(x),
    lax.ceil_p: ceil_dispatch_table,
    lax.floor_p: floor_dispatch_table,
    lax.exp_p: exp_dispatch_table,
    lax.exp2_p: exp2_dispatch_table,
    lax.expm1_p: expm1_dispatch_table,
    lax.log_p: log_dispatch_table,
    lax.log1p_p: log1p_dispatch_table,
    lax.sqrt_p: sqrt_dispatch_table,
    lax.square_p: lambda ctx, x: _mul(x, x),
    lax.pow_p: pow_dispatch_table,
    lax.cbrt_p: cbrt_dispatch_table,
    lax.rsqrt_p: rsqrt_dispatch_table,
    lax.sin_p: sin_dispatch_table,
    lax.cos_p: cos_dispatch_table,
    lax.tan_p: tan_dispatch_table,
    lax.asin_p: asin_dispatch_table,
    lax.acos_p: acos_dispatch_table,
    lax.atan_p: atan_dispatch_table,
    lax.atan2_p: atan2_dispatch_table,
    lax.sinh_p: sinh_dispatch_table,
    lax.cosh_p: cosh_dispatch_table,
    lax.tanh_p: tanh_dispatch_table,
    lax.asinh_p: asinh_dispatch_table,
    lax.acosh_p: acosh_dispatch_table,
    lax.atanh_p: atanh_dispatch_table,
    lax.population_count_p: population_count_dispatch_table,
    lax.clz_p: clz_dispatch_table,
    lax.nextafter_p: nextafter_dispatch_table,
})
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user