mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[ROCm]: Pallas updates for ROCm
This commit is contained in:
parent
a9edaeb38e
commit
5f87fd1ec8
@@ -78,6 +78,7 @@ class ModuleContext:
|
||||
grid_mapping: GridMapping
|
||||
program_ids: Sequence[ir.Value]
|
||||
traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False)
|
||||
platform: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -253,6 +254,7 @@ def lower_jaxpr_to_triton_module(
|
||||
in_shapes,
|
||||
grid_mapping: GridMapping,
|
||||
name: str,
|
||||
platform: str
|
||||
) -> LoweringResult:
|
||||
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
|
||||
with _new_ir_context(), ir.Location.unknown():
|
||||
@@ -284,7 +286,7 @@ def lower_jaxpr_to_triton_module(
|
||||
if i not in grid_mapping.mapped_dims
|
||||
]
|
||||
ctx = ModuleContext(
|
||||
name, grid_mapping, local_program_ids, mlir.TracebackCaches()
|
||||
name, grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
|
||||
)
|
||||
if grid_mapping.num_index_operands:
|
||||
raise NotImplementedError(
|
||||
@@ -395,7 +397,6 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis):
|
||||
raise ValueError(f"axis must be in [0, 3), but got: {axis}")
|
||||
return tt_dialect.get_num_programs(axis)
|
||||
|
||||
|
||||
def _atomic_rmw(
|
||||
op: tt_dialect.RMWOp,
|
||||
ptr: ir.Value,
|
||||
@@ -554,6 +555,7 @@ class _Extern:
|
||||
)
|
||||
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _Fallback:
|
||||
arg_types: Sequence[str]
|
||||
@@ -567,7 +569,7 @@ def _make_dispatch_table(
|
||||
) -> Callable[..., ir.Value]:
|
||||
|
||||
def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
|
||||
h = next((e for e in table if e.matches(ctx.avals_in)), None)
|
||||
h = next((e for e in table[ctx.context.platform] if e.matches(ctx.avals_in)), None)
|
||||
if h is None:
|
||||
arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in)
|
||||
raise NotImplementedError(
|
||||
@@ -588,12 +590,21 @@ def _make_dispatch_table(
|
||||
|
||||
_abs_dispatch_table = _make_dispatch_table(
|
||||
"abs",
|
||||
[
|
||||
_Extern(["int32"], "__nv_abs", "int32"),
|
||||
_Extern(["int64"], "__nv_llabs", "int64"),
|
||||
_Extern(["float32"], "__nv_fabsf", "float32"),
|
||||
_Extern(["float64"], "__nv_fabs", "float64"),
|
||||
],
|
||||
{ "cuda":
|
||||
[
|
||||
_Extern(["int32"], "__nv_abs", "int32"),
|
||||
_Extern(["int64"], "__nv_llabs", "int64"),
|
||||
_Extern(["float32"], "__nv_fabsf", "float32"),
|
||||
_Extern(["float64"], "__nv_fabs", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Fallback(["int32"], lambda ctx, x: math_dialect.absi(x)),
|
||||
_Fallback(["int64"], lambda ctx, x: math_dialect.absi(x)),
|
||||
_Fallback(["float32"], lambda ctx, x: math_dialect.absf(x)),
|
||||
_Fallback(["float64"], lambda ctx, x: math_dialect.absf(x)),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -615,208 +626,413 @@ triton_lowering_rules.update({
|
||||
lax.neg_p: lambda ctx, x: _minus(x),
|
||||
lax.ceil_p: _make_dispatch_table(
|
||||
"ceil",
|
||||
[
|
||||
_Extern(["float32"], "__nv_ceilf", "float32"),
|
||||
_Extern(["float64"], "__nv_ceil", "float64"),
|
||||
],
|
||||
{ "cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_ceilf", "float32"),
|
||||
_Extern(["float64"], "__nv_ceil", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_ceil_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_ceil_f64", "float64"),
|
||||
]
|
||||
},
|
||||
),
|
||||
lax.floor_p: _make_dispatch_table(
|
||||
"floor",
|
||||
[
|
||||
_Extern(["float32"], "__nv_floorf", "float32"),
|
||||
_Extern(["float64"], "__nv_floor", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)),
|
||||
],
|
||||
{ "cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_floorf", "float32"),
|
||||
_Extern(["float64"], "__nv_floor", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_floor_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_floor_f64", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.exp_p: _make_dispatch_table(
|
||||
"exp",
|
||||
[
|
||||
{ "cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_expf", "float32"),
|
||||
_Extern(["float64"], "__nv_exp", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)),
|
||||
],
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Fallback(["float32"], lambda ctx, x: math_dialect.exp(x)),
|
||||
_Fallback(["float64"], lambda ctx, x: math_dialect.exp(x)),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.exp2_p: _make_dispatch_table(
|
||||
"exp2",
|
||||
[
|
||||
{ "cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_exp2f", "float32"),
|
||||
_Extern(["float64"], "__nv_exp2", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)),
|
||||
],
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_exp2_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_exp2_f64", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.expm1_p: _make_dispatch_table(
|
||||
"expm1",
|
||||
[
|
||||
{ "cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_expm1f", "float32"),
|
||||
_Extern(["float64"], "__nv_expm1", "float64"),
|
||||
],
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_expm1_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_expm1_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.log_p: _make_dispatch_table(
|
||||
"log",
|
||||
[
|
||||
{ "cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_logf", "float32"),
|
||||
_Extern(["float64"], "__nv_log", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.log(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)),
|
||||
],
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_log_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_log_f64", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.log(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.log1p_p: _make_dispatch_table(
|
||||
"log1p",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_log1pf", "float32"),
|
||||
_Extern(["float64"], "__nv_log1p", "float64"),
|
||||
],
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_log1p_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_log1p_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.sqrt_p: _make_dispatch_table(
|
||||
"sqrt",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_sqrtf", "float32"),
|
||||
_Extern(["float64"], "__nv_sqrt", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_sqrt_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_sqrt_f64", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.pow_p: _make_dispatch_table(
|
||||
"pow",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32", "int32"], "__nv_powif", "float32"),
|
||||
_Extern(["float64", "int32"], "__nv_powi", "float64"),
|
||||
_Extern(["float32", "float32"], "__nv_powf", "float32"),
|
||||
_Extern(["float64", "float64"], "__nv_pow", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32", "int32"], "__ocml_pown_f32", "float32"),
|
||||
_Extern(["float64", "int32"], "__ocml_pown_f64", "float64"),
|
||||
_Extern(["float32", "float32"], "__ocml_pow_f32", "float32"),
|
||||
_Extern(["float64", "float64"], "__ocml_pow_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.cbrt_p: _make_dispatch_table(
|
||||
"cbrt",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_cbrtf", "float32"),
|
||||
_Extern(["float64"], "__nv_cbrt", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_cbrt_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_cbrt_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.rsqrt_p: _make_dispatch_table(
|
||||
"rsqrt",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_rsqrtf", "float32"),
|
||||
_Extern(["float64"], "__nv_rsqrt", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_rsqrt_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_rsqrt_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.sin_p: _make_dispatch_table(
|
||||
"sin",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_sinf", "float32"),
|
||||
_Extern(["float64"], "__nv_sin", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_sin_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_sin_f64", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.cos_p: _make_dispatch_table(
|
||||
"cos",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_cosf", "float32"),
|
||||
_Extern(["float64"], "__nv_cos", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_cos_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_cos_f64", "float64"),
|
||||
_Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)),
|
||||
_Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.tan_p: _make_dispatch_table(
|
||||
"tan",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_tanf", "float32"),
|
||||
_Extern(["float64"], "__nv_tan", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_tan_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_tan_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.asin_p: _make_dispatch_table(
|
||||
"asin",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_asinf", "float32"),
|
||||
_Extern(["float64"], "__nv_asin", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_asin_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_asin_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.acos_p: _make_dispatch_table(
|
||||
"acos",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_acosf", "float32"),
|
||||
_Extern(["float64"], "__nv_acos", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_acos_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_acos_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.atan_p: _make_dispatch_table(
|
||||
"atan",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_atanf", "float32"),
|
||||
_Extern(["float64"], "__nv_atan", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_atan_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_atan_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.atan2_p: _make_dispatch_table(
|
||||
"atan2",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32", "float32"], "__nv_atan2f", "float32"),
|
||||
_Extern(["float64", "float64"], "__nv_atan2", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32", "float32"], "__ocml_atan2_f32", "float32"),
|
||||
_Extern(["float64", "float64"], "__ocml_atan2_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.sinh_p: _make_dispatch_table(
|
||||
"sinh",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_sinhf", "float32"),
|
||||
_Extern(["float64"], "__nv_sinh", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_sinh_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_sinh_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.cosh_p: _make_dispatch_table(
|
||||
"cosh",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_coshf", "float32"),
|
||||
_Extern(["float64"], "__nv_cosh", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_cosh_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_cosh_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.tanh_p: _make_dispatch_table(
|
||||
"tanh",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_tanhf", "float32"),
|
||||
_Extern(["float64"], "__nv_tanh", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_tanh_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_tanh_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.asinh_p: _make_dispatch_table(
|
||||
"asinh",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_asinhf", "float32"),
|
||||
_Extern(["float64"], "__nv_asinh", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_asinh_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_asinh_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.acosh_p: _make_dispatch_table(
|
||||
"acosh",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_acoshf", "float32"),
|
||||
_Extern(["float64"], "__nv_acosh", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_acosh_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_acosh_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.atanh_p: _make_dispatch_table(
|
||||
"atanh",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32"], "__nv_atanhf", "float32"),
|
||||
_Extern(["float64"], "__nv_atanh", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32"], "__ocml_atanh_f32", "float32"),
|
||||
_Extern(["float64"], "__ocml_atanh_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.population_count_p: _make_dispatch_table(
|
||||
"population_count",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["int32"], "__nv_popc", "int32"),
|
||||
_Extern(["int64"], "__nv_popcll", "int32"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Fallback(["int32"], lambda ctx, x: math_dialect.ctpop(x)),
|
||||
_Fallback(["int64"], lambda ctx, x: math_dialect.ctpop(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.clz_p: _make_dispatch_table(
|
||||
"clz",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["int32"], "__nv_clz", "int32"),
|
||||
_Extern(["int64"], "__nv_clzll", "int32"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Fallback(["int32"], lambda ctx, x: math_dialect.ctlz(x)),
|
||||
_Fallback(["int64"], lambda ctx, x: math_dialect.ctlz(x)),
|
||||
],
|
||||
}
|
||||
),
|
||||
lax.nextafter_p: _make_dispatch_table(
|
||||
"nextafter",
|
||||
[
|
||||
{"cuda":
|
||||
[
|
||||
_Extern(["float32", "float32"], "__nv_nextafterf", "float32"),
|
||||
_Extern(["float64", "float64"], "__nv_nextafter", "float64"),
|
||||
],
|
||||
"rocm":
|
||||
[
|
||||
_Extern(["float32", "float32"], "__ocml_nextafter_f32", "float32"),
|
||||
_Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"),
|
||||
],
|
||||
}
|
||||
),
|
||||
})
|
||||
|
||||
|
@@ -87,7 +87,7 @@ def pallas_call_lowering(
|
||||
print(grid_mapping)
|
||||
|
||||
lowering_result = lowering.lower_jaxpr_to_triton_module(
|
||||
jaxpr, (*in_shapes, *out_shapes), grid_mapping, name,
|
||||
jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, lowering_platform
|
||||
)
|
||||
module_op = lowering_result.module.operation
|
||||
if debug:
|
||||
|
Loading…
x
Reference in New Issue
Block a user