[ROCm]: Pallas updates for ROCm

This commit is contained in:
Rahul Batra 2024-06-10 23:01:53 +00:00
parent a9edaeb38e
commit 5f87fd1ec8
2 changed files with 266 additions and 50 deletions

View File

@ -78,6 +78,7 @@ class ModuleContext:
grid_mapping: GridMapping
program_ids: Sequence[ir.Value]
traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False)
platform: str
@dataclasses.dataclass
@ -253,6 +254,7 @@ def lower_jaxpr_to_triton_module(
in_shapes,
grid_mapping: GridMapping,
name: str,
platform: str
) -> LoweringResult:
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
with _new_ir_context(), ir.Location.unknown():
@ -284,7 +286,7 @@ def lower_jaxpr_to_triton_module(
if i not in grid_mapping.mapped_dims
]
ctx = ModuleContext(
name, grid_mapping, local_program_ids, mlir.TracebackCaches()
name, grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
@ -395,7 +397,6 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis):
raise ValueError(f"axis must be in [0, 3), but got: {axis}")
return tt_dialect.get_num_programs(axis)
def _atomic_rmw(
op: tt_dialect.RMWOp,
ptr: ir.Value,
@ -554,6 +555,7 @@ class _Extern:
)
@dataclasses.dataclass(frozen=True)
class _Fallback:
arg_types: Sequence[str]
@ -567,7 +569,7 @@ def _make_dispatch_table(
) -> Callable[..., ir.Value]:
def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
h = next((e for e in table if e.matches(ctx.avals_in)), None)
h = next((e for e in table[ctx.context.platform] if e.matches(ctx.avals_in)), None)
if h is None:
arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in)
raise NotImplementedError(
@ -588,12 +590,21 @@ def _make_dispatch_table(
_abs_dispatch_table = _make_dispatch_table(
"abs",
[
_Extern(["int32"], "__nv_abs", "int32"),
_Extern(["int64"], "__nv_llabs", "int64"),
_Extern(["float32"], "__nv_fabsf", "float32"),
_Extern(["float64"], "__nv_fabs", "float64"),
],
{ "cuda":
[
_Extern(["int32"], "__nv_abs", "int32"),
_Extern(["int64"], "__nv_llabs", "int64"),
_Extern(["float32"], "__nv_fabsf", "float32"),
_Extern(["float64"], "__nv_fabs", "float64"),
],
"rocm":
[
_Fallback(["int32"], lambda ctx, x: math_dialect.absi(x)),
_Fallback(["int64"], lambda ctx, x: math_dialect.absi(x)),
_Fallback(["float32"], lambda ctx, x: math_dialect.absf(x)),
_Fallback(["float64"], lambda ctx, x: math_dialect.absf(x)),
]
},
)
@ -615,208 +626,413 @@ triton_lowering_rules.update({
lax.neg_p: lambda ctx, x: _minus(x),
lax.ceil_p: _make_dispatch_table(
"ceil",
[
_Extern(["float32"], "__nv_ceilf", "float32"),
_Extern(["float64"], "__nv_ceil", "float64"),
],
{ "cuda":
[
_Extern(["float32"], "__nv_ceilf", "float32"),
_Extern(["float64"], "__nv_ceil", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_ceil_f32", "float32"),
_Extern(["float64"], "__ocml_ceil_f64", "float64"),
]
},
),
lax.floor_p: _make_dispatch_table(
"floor",
[
_Extern(["float32"], "__nv_floorf", "float32"),
_Extern(["float64"], "__nv_floor", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)),
],
{ "cuda":
[
_Extern(["float32"], "__nv_floorf", "float32"),
_Extern(["float64"], "__nv_floor", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)),
],
"rocm":
[
_Extern(["float32"], "__ocml_floor_f32", "float32"),
_Extern(["float64"], "__ocml_floor_f64", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)),
],
}
),
lax.exp_p: _make_dispatch_table(
"exp",
[
{ "cuda":
[
_Extern(["float32"], "__nv_expf", "float32"),
_Extern(["float64"], "__nv_exp", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)),
],
],
"rocm":
[
_Fallback(["float32"], lambda ctx, x: math_dialect.exp(x)),
_Fallback(["float64"], lambda ctx, x: math_dialect.exp(x)),
_Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)),
],
}
),
lax.exp2_p: _make_dispatch_table(
"exp2",
[
{ "cuda":
[
_Extern(["float32"], "__nv_exp2f", "float32"),
_Extern(["float64"], "__nv_exp2", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)),
],
],
"rocm":
[
_Extern(["float32"], "__ocml_exp2_f32", "float32"),
_Extern(["float64"], "__ocml_exp2_f64", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)),
],
}
),
lax.expm1_p: _make_dispatch_table(
"expm1",
[
{ "cuda":
[
_Extern(["float32"], "__nv_expm1f", "float32"),
_Extern(["float64"], "__nv_expm1", "float64"),
],
],
"rocm":
[
_Extern(["float32"], "__ocml_expm1_f32", "float32"),
_Extern(["float64"], "__ocml_expm1_f64", "float64"),
],
}
),
lax.log_p: _make_dispatch_table(
"log",
[
{ "cuda":
[
_Extern(["float32"], "__nv_logf", "float32"),
_Extern(["float64"], "__nv_log", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.log(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)),
],
],
"rocm":
[
_Extern(["float32"], "__ocml_log_f32", "float32"),
_Extern(["float64"], "__ocml_log_f64", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.log(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)),
],
}
),
lax.log1p_p: _make_dispatch_table(
"log1p",
[
{"cuda":
[
_Extern(["float32"], "__nv_log1pf", "float32"),
_Extern(["float64"], "__nv_log1p", "float64"),
],
],
"rocm":
[
_Extern(["float32"], "__ocml_log1p_f32", "float32"),
_Extern(["float64"], "__ocml_log1p_f64", "float64"),
],
}
),
lax.sqrt_p: _make_dispatch_table(
"sqrt",
[
{"cuda":
[
_Extern(["float32"], "__nv_sqrtf", "float32"),
_Extern(["float64"], "__nv_sqrt", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)),
],
"rocm":
[
_Extern(["float32"], "__ocml_sqrt_f32", "float32"),
_Extern(["float64"], "__ocml_sqrt_f64", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)),
],
}
),
lax.pow_p: _make_dispatch_table(
"pow",
[
{"cuda":
[
_Extern(["float32", "int32"], "__nv_powif", "float32"),
_Extern(["float64", "int32"], "__nv_powi", "float64"),
_Extern(["float32", "float32"], "__nv_powf", "float32"),
_Extern(["float64", "float64"], "__nv_pow", "float64"),
],
"rocm":
[
_Extern(["float32", "int32"], "__ocml_pown_f32", "float32"),
_Extern(["float64", "int32"], "__ocml_pown_f64", "float64"),
_Extern(["float32", "float32"], "__ocml_pow_f32", "float32"),
_Extern(["float64", "float64"], "__ocml_pow_f64", "float64"),
],
}
),
lax.cbrt_p: _make_dispatch_table(
"cbrt",
[
{"cuda":
[
_Extern(["float32"], "__nv_cbrtf", "float32"),
_Extern(["float64"], "__nv_cbrt", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_cbrt_f32", "float32"),
_Extern(["float64"], "__ocml_cbrt_f64", "float64"),
],
}
),
lax.rsqrt_p: _make_dispatch_table(
"rsqrt",
[
{"cuda":
[
_Extern(["float32"], "__nv_rsqrtf", "float32"),
_Extern(["float64"], "__nv_rsqrt", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_rsqrt_f32", "float32"),
_Extern(["float64"], "__ocml_rsqrt_f64", "float64"),
],
}
),
lax.sin_p: _make_dispatch_table(
"sin",
[
{"cuda":
[
_Extern(["float32"], "__nv_sinf", "float32"),
_Extern(["float64"], "__nv_sin", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)),
],
"rocm":
[
_Extern(["float32"], "__ocml_sin_f32", "float32"),
_Extern(["float64"], "__ocml_sin_f64", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)),
],
}
),
lax.cos_p: _make_dispatch_table(
"cos",
[
{"cuda":
[
_Extern(["float32"], "__nv_cosf", "float32"),
_Extern(["float64"], "__nv_cos", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)),
],
"rocm":
[
_Extern(["float32"], "__ocml_cos_f32", "float32"),
_Extern(["float64"], "__ocml_cos_f64", "float64"),
_Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)),
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)),
],
}
),
lax.tan_p: _make_dispatch_table(
"tan",
[
{"cuda":
[
_Extern(["float32"], "__nv_tanf", "float32"),
_Extern(["float64"], "__nv_tan", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_tan_f32", "float32"),
_Extern(["float64"], "__ocml_tan_f64", "float64"),
],
}
),
lax.asin_p: _make_dispatch_table(
"asin",
[
{"cuda":
[
_Extern(["float32"], "__nv_asinf", "float32"),
_Extern(["float64"], "__nv_asin", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_asin_f32", "float32"),
_Extern(["float64"], "__ocml_asin_f64", "float64"),
],
}
),
lax.acos_p: _make_dispatch_table(
"acos",
[
{"cuda":
[
_Extern(["float32"], "__nv_acosf", "float32"),
_Extern(["float64"], "__nv_acos", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_acos_f32", "float32"),
_Extern(["float64"], "__ocml_acos_f64", "float64"),
],
}
),
lax.atan_p: _make_dispatch_table(
"atan",
[
{"cuda":
[
_Extern(["float32"], "__nv_atanf", "float32"),
_Extern(["float64"], "__nv_atan", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_atan_f32", "float32"),
_Extern(["float64"], "__ocml_atan_f64", "float64"),
],
}
),
lax.atan2_p: _make_dispatch_table(
"atan2",
[
{"cuda":
[
_Extern(["float32", "float32"], "__nv_atan2f", "float32"),
_Extern(["float64", "float64"], "__nv_atan2", "float64"),
],
"rocm":
[
_Extern(["float32", "float32"], "__ocml_atan2_f32", "float32"),
_Extern(["float64", "float64"], "__ocml_atan2_f64", "float64"),
],
}
),
lax.sinh_p: _make_dispatch_table(
"sinh",
[
{"cuda":
[
_Extern(["float32"], "__nv_sinhf", "float32"),
_Extern(["float64"], "__nv_sinh", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_sinh_f32", "float32"),
_Extern(["float64"], "__ocml_sinh_f64", "float64"),
],
}
),
lax.cosh_p: _make_dispatch_table(
"cosh",
[
{"cuda":
[
_Extern(["float32"], "__nv_coshf", "float32"),
_Extern(["float64"], "__nv_cosh", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_cosh_f32", "float32"),
_Extern(["float64"], "__ocml_cosh_f64", "float64"),
],
}
),
lax.tanh_p: _make_dispatch_table(
"tanh",
[
{"cuda":
[
_Extern(["float32"], "__nv_tanhf", "float32"),
_Extern(["float64"], "__nv_tanh", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_tanh_f32", "float32"),
_Extern(["float64"], "__ocml_tanh_f64", "float64"),
],
}
),
lax.asinh_p: _make_dispatch_table(
"asinh",
[
{"cuda":
[
_Extern(["float32"], "__nv_asinhf", "float32"),
_Extern(["float64"], "__nv_asinh", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_asinh_f32", "float32"),
_Extern(["float64"], "__ocml_asinh_f64", "float64"),
],
}
),
lax.acosh_p: _make_dispatch_table(
"acosh",
[
{"cuda":
[
_Extern(["float32"], "__nv_acoshf", "float32"),
_Extern(["float64"], "__nv_acosh", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_acosh_f32", "float32"),
_Extern(["float64"], "__ocml_acosh_f64", "float64"),
],
}
),
lax.atanh_p: _make_dispatch_table(
"atanh",
[
{"cuda":
[
_Extern(["float32"], "__nv_atanhf", "float32"),
_Extern(["float64"], "__nv_atanh", "float64"),
],
"rocm":
[
_Extern(["float32"], "__ocml_atanh_f32", "float32"),
_Extern(["float64"], "__ocml_atanh_f64", "float64"),
],
}
),
lax.population_count_p: _make_dispatch_table(
"population_count",
[
{"cuda":
[
_Extern(["int32"], "__nv_popc", "int32"),
_Extern(["int64"], "__nv_popcll", "int32"),
],
"rocm":
[
_Fallback(["int32"], lambda ctx, x: math_dialect.ctpop(x)),
_Fallback(["int64"], lambda ctx, x: math_dialect.ctpop(x)),
],
}
),
lax.clz_p: _make_dispatch_table(
"clz",
[
{"cuda":
[
_Extern(["int32"], "__nv_clz", "int32"),
_Extern(["int64"], "__nv_clzll", "int32"),
],
"rocm":
[
_Fallback(["int32"], lambda ctx, x: math_dialect.ctlz(x)),
_Fallback(["int64"], lambda ctx, x: math_dialect.ctlz(x)),
],
}
),
lax.nextafter_p: _make_dispatch_table(
"nextafter",
[
{"cuda":
[
_Extern(["float32", "float32"], "__nv_nextafterf", "float32"),
_Extern(["float64", "float64"], "__nv_nextafter", "float64"),
],
"rocm":
[
_Extern(["float32", "float32"], "__ocml_nextafter_f32", "float32"),
_Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"),
],
}
),
})

View File

@ -87,7 +87,7 @@ def pallas_call_lowering(
print(grid_mapping)
lowering_result = lowering.lower_jaxpr_to_triton_module(
jaxpr, (*in_shapes, *out_shapes), grid_mapping, name,
jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, lowering_platform
)
module_op = lowering_result.module.operation
if debug: