# rocm_jax/jax/_src/pallas/triton/lowering.py
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for lowering JAX primitives to Triton IR."""
from __future__ import annotations
from collections.abc import Callable, Sequence
import dataclasses
import functools
import math
import operator
from typing import Any, Hashable, TypeVar
import jax
from jax import lax
from jax import tree_util
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.control_flow import for_loop
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import math as math_dialect
from jax._src.lib.mlir.dialects import scf as scf_dialect
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.util import merge_lists
from jax._src.util import partition_list
from jax._src.util import split_list
import jax.numpy as jnp
import numpy as np
# TODO(sharadmv): Enable type checking.
# mypy: ignore-errors
# pytype: skip-file
_T = TypeVar("_T")
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
NDIndexer = indexing.NDIndexer
GridMapping = pallas_core.GridMapping
BlockMapping = pallas_core.BlockMapping
Blocked = pallas_core.Blocked
# # General lowering logic
@dataclasses.dataclass
class ModuleContext:
name: str
grid_mapping: GridMapping
program_ids: Sequence[ir.Value]
traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False)
platform: str
@dataclasses.dataclass
class BlockInfo:
full_shape_dtype: jax.ShapeDtypeStruct
start_indices: Sequence[Any]
block_shape: tuple[int, ...] # TODO(necula): can this contain "mapped"?
@dataclasses.dataclass
class LoweringRuleContext:
context: ModuleContext
avals_in: Sequence[jax_core.ShapedArray]
avals_out: Sequence[jax_core.ShapedArray]
block_infos: Sequence[BlockInfo | None] # TODO(necula): can this be None?
replace = dataclasses.replace
@dataclasses.dataclass
class LoweringResult:
"""Keeps pybind11 objects alive."""
module: ir.Module
grid: tuple[int, ...]
class LoweringError(Exception):
pass
def _eval_index_map(
ctx: ModuleContext, idx, block_mapping: BlockMapping
):
block_indices = lower_jaxpr_to_triton_ir(
ctx, block_mapping.index_map_jaxpr.jaxpr, None, *idx
)
block_indices = (
_ensure_ir_value(i, jax_core.ShapedArray((), jnp.int32))
for i in block_indices
)
return tuple(
i if b is pallas_core.mapped else _mul(i, _ir_constant(b, i.type))
for i, b in zip(block_indices, block_mapping.block_shape)
)
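# Illustrative note (hypothetical shapes): for a block_shape of (128, 256) and
# grid indices (i, j) returned by the index map, the start indices computed
# above are (128 * i, 256 * j); a dimension marked `pallas_core.mapped` passes
# its index through unscaled.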
def _bcast_to(a: ir.Value, shape: tuple[int, ...]) -> ir.Value:
if not ir.RankedTensorType.isinstance(a.type):
if not shape:
return a
return tt_dialect.splat(ir.RankedTensorType.get(shape, a.type), a)
else:
a_type = ir.RankedTensorType(a.type)
if a_type.shape == [*shape]:
return a
if a_type.rank != len(shape) or not all(
a_type.shape[i] in (dim, 1) for i, dim in enumerate(shape)
):
raise ValueError(f"Cannot broadcast from {a_type.shape} to {[*shape]}")
return tt_dialect.broadcast(
ir.RankedTensorType.get(shape, a_type.element_type, a_type.encoding), a
)
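# Illustrative note (hypothetical shapes): a scalar is splatted to the target
# shape, while a tensor of shape (1, 256) broadcasts to (128, 256) because
# every source dimension is either 1 or equal to the target dimension;
# broadcasting (64, 256) to (128, 256) raises a ValueError.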
def _bcast(
x: ir.Value,
y: ir.Value,
x_aval: jax_core.ShapedArray,
y_aval: jax_core.ShapedArray,
out_aval: jax_core.ShapedArray,
) -> ir.Value:
if isinstance(x, (np.ndarray, np.number, int, float)):
x_dtype = x_aval.dtype
if x_aval.weak_type:
x_dtype = y_aval.dtype
x = _ir_constant(x, _dtype_to_ir_type(x_dtype))
if isinstance(y, (np.ndarray, np.number, int, float)):
y_dtype = y_aval.dtype
if y_aval.weak_type:
y_dtype = x_aval.dtype
y = _ir_constant(y, _dtype_to_ir_type(y_dtype))
if x_aval.shape != out_aval.shape:
x = _bcast_to(x, out_aval.shape)
if y_aval.shape != out_aval.shape:
y = _bcast_to(y, out_aval.shape)
return x, y
triton_lowering_rules = {}
def register_lowering(primitive: jax_core.Primitive) -> Callable[[_T], _T]:
def wrapper(fn):
triton_lowering_rules[primitive] = fn
return fn
return wrapper
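# Hedged usage sketch (the rule name here is hypothetical): lowering rules are
# registered with the decorator form used throughout this file, e.g.
#
#   @register_lowering(lax.exp_p)
#   def _exp_lowering_rule(ctx: LoweringRuleContext, x):
#     return math_dialect.exp(x)
#
# In this module lax.exp_p is actually handled via the dispatch tables below;
# the snippet only illustrates the registration mechanism.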
def _process_grid_to_3d_grid(grid_mapping: GridMapping):
launch_grid = []
launch_grid_to_pallas_grid = []
# Preserve grid order provided to pallas_call
for i, s in enumerate(grid_mapping.grid):
if i not in grid_mapping.vmapped_dims:
launch_grid.append(s)
launch_grid_to_pallas_grid.append(i)
# For mapped dims, iterate from inner to outer. This follows the pallas_call
# batching rule that prepends the vmapped dimension.
for dim in reversed(grid_mapping.vmapped_dims):
s = grid_mapping.grid[dim]
launch_grid.append(s)
launch_grid_to_pallas_grid.append(dim)
num_collapse = len(launch_grid[:-2])
cuda_yz_limit = 2**16 - 1
# Check Z and then Y launch dims to make sure they're within CUDA bounds
if (num_collapse + 1 < len(launch_grid) and
launch_grid[num_collapse + 1] > cuda_yz_limit):
num_collapse += 2
elif (num_collapse < len(launch_grid) and
launch_grid[num_collapse] > cuda_yz_limit):
num_collapse += 1
collapse_dims = launch_grid[:num_collapse]
prog_id_dims = launch_grid[num_collapse:]
if len(collapse_dims) == 0:
prog_ids = [None] * len(prog_id_dims)
for i in range(len(prog_id_dims)):
out_idx = launch_grid_to_pallas_grid[i]
prog_ids[out_idx] = _program_id(i)
return prog_id_dims, prog_ids
else:
new_grid = [math.prod(collapse_dims), *prog_id_dims]
assert new_grid[0] < 2**31 - 1, \
"Cannot fix pallas kernel launch grid within CUDA limits"
out_indices = [None] * len(grid_mapping.grid)
grid0 = _program_id(0)
for i, s in enumerate(collapse_dims):
out_idx = launch_grid_to_pallas_grid[i]
s = _i32_constant(s)
out_indices[out_idx] = _mod(grid0, s, signed=False)
grid0 = _floordiv(grid0, s, signed=False)
for i in range(len(prog_id_dims)):
out_idx = launch_grid_to_pallas_grid[num_collapse + i]
out_indices[out_idx] = _program_id(i + 1)
assert len(out_indices) == len(grid_mapping.grid)
return new_grid, out_indices
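# Worked example (hypothetical sizes): a Pallas grid of (8, 4, 5, 6) with no
# vmapped dims has four axes, so the two leading axes are collapsed into
# launch_grid[0] = 8 * 4 = 32 and their program ids are recovered as
# pid0 % 8 and (pid0 // 8) % 4, while the remaining axes map directly to
# pid1 and pid2. If the Y or Z launch dimension would exceed the CUDA limit
# of 2**16 - 1, it is folded into the collapsed dimension as well.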
def _new_ir_context() -> ir.Context:
ctx = ir.Context()
tt_dialect.register_dialect(ctx)
ctx.load_all_available_dialects()
return ctx
# Many Triton operations require that their inputs and outputs have sizes that
# are a power of 2 (they are defined to have a TensorSizeTrait that enforces
# this). This check exists only to produce a nicer error message; without it
# the Triton lowering would still fail, just with an opaque C++ exception.
# We currently apply this check only to load/store operations.
def _check_tensor_size(shape: tuple[int | pallas_core.Mapped, ...]):
size = math.prod(1 if d is pallas_core.mapped else d for d in shape)
power_of_2 = (size & (size - 1)) == 0
if not power_of_2:
raise ValueError(
"The Pallas Triton lowering currently requires that all "
"operations have array arguments and results whose size "
"is a power of 2. Encountered an array of "
f"shape {shape}")
def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
grid_mapping: GridMapping,
name_and_src_info: pallas_core.NameAndSrcInfo,
platform: str
) -> LoweringResult:
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
if jaxpr.invars[grid_mapping.slice_scratch_ops]:
raise NotImplementedError(
"scratch memory not implemented in the Triton backend"
)
with grid_mapping.trace_env():
jaxpr, _ = pe.dce_jaxpr(
jaxpr, [True] * len(jaxpr.outvars), instantiate=True
)
with _new_ir_context(), ir.Location.unknown():
module = ir.Module.create()
attrs = module.operation.attributes
module_name = name_and_src_info.name
attrs["sym_name"] = ir.StringAttr.get(module_name)
param_types = [
tt_dialect.PointerType.get(_dtype_to_ir_type(var.aval.dtype), 1)
for var in jaxpr.invars
]
assert len(jaxpr.outvars) == 0
fn_type = ir.FunctionType.get(param_types, [])
fn = tt_dialect.FuncOp(
name_and_src_info.name,
ir.TypeAttr.get(fn_type),
sym_visibility="public",
res_attrs=ir.DictAttr.get(dict(noinline=ir.BoolAttr.get(False))),
ip=ir.InsertionPoint.at_block_begin(module.body),
)
fn.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get({"tt.divisibility": mlir.i32_attr(32)})]
* len(param_types)
)
fn.body.blocks.append(*fn_type.inputs)
[entry] = fn.body.blocks
with ir.InsertionPoint(entry):
new_grid, program_ids = _process_grid_to_3d_grid(grid_mapping)
local_program_ids = [
pid
for i, pid in enumerate(program_ids)
if i not in grid_mapping.vmapped_dims
]
ctx = ModuleContext(
name_and_src_info.name,
grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"Scalar prefetch not supported in Triton lowering."
)
if not all(isinstance(bm.indexing_mode, Blocked)
for bm in grid_mapping.block_mappings):
raise NotImplementedError(
"Only Blocked indexing mode is supported in Triton lowering."
)
start_indices = map(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
block_infos = [
BlockInfo(
block_mapping.array_shape_dtype,
start_idx,
block_mapping.block_shape,
)
for block_mapping, start_idx in zip(
grid_mapping.block_mappings,
start_indices,
)
]
() = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *entry.arguments)
tt_dialect.return_([])
return LoweringResult(module, new_grid)
def lower_jaxpr_to_triton_ir(
ctx: ModuleContext,
jaxpr: jax_core.Jaxpr,
block_infos: Sequence[BlockInfo | None] | None,
*args,
) -> Sequence[Any]:
env = {}
block_info_env = {}
def read_env(atom: jax_core.Atom):
return atom.val if isinstance(atom, jax_core.Literal) else env[atom]
def read_block_info_env(atom: jax_core.Atom):
if isinstance(atom, jax_core.Literal):
return None
return block_info_env.get(atom, None)
def write_env(var: jax_core.Var, val):
env[var] = val
if block_infos is not None:
for invar, block_info in zip(jaxpr.invars, block_infos):
block_info_env[invar] = block_info
map(write_env, jaxpr.invars, args)
for eqn in jaxpr.eqns:
invals = map(read_env, eqn.invars)
if eqn.primitive not in triton_lowering_rules:
raise NotImplementedError(
"Unimplemented primitive in Pallas GPU lowering: "
f"{eqn.primitive.name}. "
"Please file an issue on https://github.com/jax-ml/jax/issues.")
rule = triton_lowering_rules[eqn.primitive]
avals_in = [v.aval for v in eqn.invars]
avals_out = [v.aval for v in eqn.outvars]
eqn_block_infos = map(read_block_info_env, eqn.invars)
loc = mlir._source_info_to_location(ctx, eqn.primitive, eqn.source_info)
rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos)
try:
with source_info_util.user_context(eqn.source_info.traceback), loc:
outvals = rule(rule_ctx, *invals, **eqn.params)
except LoweringError:
raise # We only add the extra info to the innermost exception.
except Exception as e:
inval_types = map(lambda t: getattr(t, "type", None), invals)
raise LoweringError(
f"Exception while lowering eqn:\n {eqn}\nWith context:\n "
f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}\n"
f"msg={e}"
) from e
if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, outvals)
else:
write_env(eqn.outvars[0], outvals)
return map(read_env, jaxpr.outvars)
def lower_fun(
fun: Callable[..., Any], *, multiple_results: bool
) -> Callable[..., Any]:
fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
def f_lowered(ctx: LoweringRuleContext, *args, **params):
wrapped_fun = lu.wrap_init(fn, params)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
return out if multiple_results else out[0]
return f_lowered
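# Hedged usage sketch: `lower_fun` turns an existing JAX-level implementation
# into a lowering rule by tracing it to a jaxpr and lowering that jaxpr, e.g.
#
#   triton_lowering_rules[lax.logistic_p] = lower_fun(
#       lambda a: 1 / (1 + jnp.exp(-a)), multiple_results=False)
#
# which is exactly how _JAX_FN_MAPPING below is wired up.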
# # Primitive lowering rules
# ## Programming model primitives
def _program_id(axis: int) -> ir.Value:
if axis not in range(3):
raise ValueError(f"axis must be in [0, 3), but got: {axis}")
return tt_dialect.get_program_id(axis)
@register_lowering(primitives.program_id_p)
def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis):
return ctx.context.program_ids[axis]
@register_lowering(primitives.num_programs_p)
def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis):
if axis not in range(3):
raise ValueError(f"axis must be in [0, 3), but got: {axis}")
return tt_dialect.get_num_programs(axis)
def _atomic_rmw(
op: tt_dialect.RMWOp,
ptr: ir.Value,
val: ir.Value,
mask: ir.Value | None = None,
semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
) -> ir.Value:
if ir.RankedTensorType.isinstance(ptr.type):
ptr_type = ir.RankedTensorType(ptr.type)
element_type = tt_dialect.PointerType(ptr_type.element_type)
result_type = ir.RankedTensorType.get(
ptr_type.shape, element_type.pointee_type, ptr_type.encoding
)
else:
result_type = tt_dialect.PointerType(ptr.type).pointee_type
return tt_dialect.atomic_rmw(
result_type, op, ptr, val, mask=mask, sem=semantic, scope=sync_scope
)
@register_lowering(primitives.atomic_rmw_p)
def _atomic_lowering_rule(
ctx: LoweringRuleContext,
*args_flat,
args_tree,
atomic_type: primitives.AtomicOpType,
):
ptr, indexers, val, mask = args_tree.unflatten(args_flat)
*_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
if len(indexers) != 1:
raise NotImplementedError("Only single indexer is supported.")
idx = indexers[0]
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
)
val = _ensure_ir_value(val, value_aval)
if mask is not None:
mask = _ensure_ir_value(mask, mask_aval)
if atomic_type == primitives.AtomicOpType.XCHG:
op = tt_dialect.RMWOp.XCHG
elif atomic_type == primitives.AtomicOpType.ADD:
if isinstance(val.type, ir.IntegerType):
op = tt_dialect.RMWOp.ADD
else:
op = tt_dialect.RMWOp.FADD
elif atomic_type == primitives.AtomicOpType.MIN:
op = tt_dialect.RMWOp.MIN
elif atomic_type == primitives.AtomicOpType.MAX:
op = tt_dialect.RMWOp.MAX
elif atomic_type == primitives.AtomicOpType.AND:
op = tt_dialect.RMWOp.AND
elif atomic_type == primitives.AtomicOpType.OR:
op = tt_dialect.RMWOp.OR
elif atomic_type == primitives.AtomicOpType.XOR:
op = tt_dialect.RMWOp.XOR
else:
raise NotImplementedError(f"unsupported atomic operation: {atomic_type}")
return _atomic_rmw(op, ptr, val, mask=mask)
@register_lowering(primitives.atomic_cas_p)
def _atomic_cas_lowering_rule(ctx: LoweringRuleContext, ptr, cmp, val):
_, cmp_aval, val_aval = ctx.avals_in
if ir.RankedTensorType.isinstance(ptr.type):
ptr_type = ir.RankedTensorType(ptr.type)
element_type = tt_dialect.PointerType(ptr_type.element_type)
result_type = ir.RankedTensorType.get(
ptr_type.shape, element_type.pointee_type, ptr_type.encoding
)
else:
result_type = tt_dialect.PointerType(ptr.type).pointee_type
return tt_dialect.atomic_cas(
result_type,
ptr,
_ensure_ir_value(cmp, cmp_aval),
_ensure_ir_value(val, val_aval),
sem=tt_dialect.MemSemantic.ACQUIRE_RELEASE,
scope=tt_dialect.MemSyncScope.GPU,
)
def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes):
flat_args = tree_util.tree_leaves(args)
(axis,) = axes
dtype = ctx.avals_in[0].dtype
in_avals = [
jax_core.ShapedArray((), dtype=dtype),
jax_core.ShapedArray((), dtype=dtype),
]
in_tree = tree_util.tree_structure((args, args))
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(body), in_tree
)
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
flat_fun, in_avals
)
out_tree = out_tree_thunk()
del out_tree # Not needed
if consts:
raise NotImplementedError("Associative scan with constants not supported.")
element_types = [_element_type(arg.type) for arg in flat_args]
scan_op = tt_dialect.ScanOp(flat_args, axis)
param_types = element_types * 2
entry = scan_op.regions[0].blocks.append(*param_types)
with ir.InsertionPoint.at_block_begin(entry):
results = lower_jaxpr_to_triton_ir(
ctx.context, combine_jaxpr, None, *entry.arguments
)
tt_dialect.scan_return(results)
scan_op.verify()
return list(scan_op.result)
@register_lowering(lax.cumsum_p)
def _cumsum_lowering_rule(
ctx: LoweringRuleContext, x, *, axis: int, reverse: bool
):
if reverse:
raise NotImplementedError("Reverse cumsum is not supported.")
return _associative_scan_lowering(jnp.add, ctx, x, (axis,))[0]
@register_lowering(lax.not_p)
def _not_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
return arith_dialect.xori(x, _full(x.type, ~x_aval.dtype.type(0)))
@dataclasses.dataclass(frozen=True)
class _Extern:
arg_types: Sequence[jax.typing.DTypeLike]
symbol: str
result_type: str
def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
if len(avals) != len(self.arg_types):
return False
return all(
aval.dtype == jnp.dtype(arg_type)
or (aval.weak_type and aval.dtype.kind == jnp.dtype(arg_type).kind)
for aval, arg_type in zip(avals, self.arg_types)
)
def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
[out_aval] = ctx.avals_out
result_type = _dtype_to_ir_type(jnp.dtype(self.result_type))
if out_aval.shape:
result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
return tt_dialect.extern_elementwise(
result_type,
args,
libname="",
libpath="",
symbol=self.symbol,
pure=True,
)
@dataclasses.dataclass(frozen=True)
class _Fallback:
arg_types: Sequence[jax.typing.DTypeLike]
lower: Callable[..., ir.Value]
matches = _Extern.matches
def _make_dispatch_table(
name: str, **tables: Sequence[_Extern | _Fallback]
) -> Callable[..., ir.Value]:
def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
table = tables[ctx.context.platform]
h = next((e for e in table if e.matches(ctx.avals_in)), None)
if h is None:
arg_aval_dtypes = tuple(aval.dtype for aval in ctx.avals_in)
raise NotImplementedError(
f"unsupported types for {name}: {arg_aval_dtypes}"
)
[out_aval] = ctx.avals_out
bcast_args = []
for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
bcast_args.append(bcast_arg)
return h.lower(ctx, *bcast_args)
return inner
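# Illustrative note: each dispatch table maps a platform ("cuda" or "rocm") to
# an ordered list of candidates; the first entry whose argument dtypes match
# the input avals wins. An _Extern lowers to a libdevice/ocml call, while a
# _Fallback lowers through the MLIR math dialect. For example, abs(float32)
# matches the "__nv_fabsf" extern on CUDA below but falls back to
# math_dialect.absf on ROCm.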
_abs_dispatch_table = _make_dispatch_table(
"abs",
cuda=[
_Extern([jnp.int32], "__nv_abs", jnp.int32),
_Extern([jnp.int64], "__nv_llabs", jnp.int64),
_Extern([jnp.float32], "__nv_fabsf", jnp.float32),
_Extern([jnp.float64], "__nv_fabs", jnp.float64),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.absi(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.absi(x)),
_Fallback([jnp.float32], lambda ctx, x: math_dialect.absf(x)),
_Fallback([jnp.float64], lambda ctx, x: math_dialect.absf(x)),
],
)
@register_lowering(lax.abs_p)
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
try:
return _abs_dispatch_table(ctx, x)
except NotImplementedError as e:
[x_aval] = ctx.avals_in
if jnp.issubdtype(x_aval, jnp.integer):
return math_dialect.absi(x)
elif jnp.issubdtype(x_aval, jnp.floating):
return math_dialect.absf(x)
else:
raise e from None
triton_lowering_rules.update({
lax.neg_p: lambda ctx, x: _minus(x),
lax.ceil_p: _make_dispatch_table(
"ceil",
cuda=[
_Extern([jnp.float32], "__nv_ceilf", jnp.float32),
_Extern([jnp.float64], "__nv_ceil", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64),
],
),
lax.floor_p: _make_dispatch_table(
"floor",
cuda=[
_Extern([jnp.float32], "__nv_floorf", jnp.float32),
_Extern([jnp.float64], "__nv_floor", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_floor_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_floor_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)),
],
),
lax.exp_p: _make_dispatch_table(
"exp",
cuda=[
_Extern([jnp.float32], "__nv_expf", jnp.float32),
_Extern([jnp.float64], "__nv_exp", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)),
],
rocm=[
_Fallback([jnp.float32], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.float64], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)),
],
),
lax.exp2_p: _make_dispatch_table(
"exp2",
cuda=[
_Extern([jnp.float32], "__nv_exp2f", jnp.float32),
_Extern([jnp.float64], "__nv_exp2", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)),
],
),
lax.expm1_p: _make_dispatch_table(
"expm1",
cuda=[
_Extern([jnp.float32], "__nv_expm1f", jnp.float32),
_Extern([jnp.float64], "__nv_expm1", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64),
],
),
lax.log_p: _make_dispatch_table(
"log",
cuda=[
_Extern([jnp.float32], "__nv_logf", jnp.float32),
_Extern([jnp.float64], "__nv_log", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_log_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)),
],
),
lax.log1p_p: _make_dispatch_table(
"log1p",
cuda=[
_Extern([jnp.float32], "__nv_log1pf", jnp.float32),
_Extern([jnp.float64], "__nv_log1p", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64),
],
),
lax.sqrt_p: _make_dispatch_table(
"sqrt",
cuda=[
_Extern([jnp.float32], "__nv_sqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_sqrt", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
],
),
lax.pow_p: _make_dispatch_table(
"pow",
cuda=[
_Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32),
_Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
_Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
],
rocm=[
_Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32),
_Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
_Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
],
),
lax.cbrt_p: _make_dispatch_table(
"cbrt",
cuda=[
_Extern([jnp.float32], "__nv_cbrtf", jnp.float32),
_Extern([jnp.float64], "__nv_cbrt", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64),
],
),
lax.rsqrt_p: _make_dispatch_table(
"rsqrt",
cuda=[
_Extern([jnp.float32], "__nv_rsqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_rsqrt", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64),
],
),
lax.sin_p: _make_dispatch_table(
"sin",
cuda=[
_Extern([jnp.float32], "__nv_sinf", jnp.float32),
_Extern([jnp.float64], "__nv_sin", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_sin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sin_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)),
],
),
lax.cos_p: _make_dispatch_table(
"cos",
cuda=[
_Extern([jnp.float32], "__nv_cosf", jnp.float32),
_Extern([jnp.float64], "__nv_cos", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_cos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cos_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)),
],
),
lax.tan_p: _make_dispatch_table(
"tan",
cuda=[
_Extern([jnp.float32], "__nv_tanf", jnp.float32),
_Extern([jnp.float64], "__nv_tan", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_tan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tan_f64", jnp.float64),
],
),
lax.asin_p: _make_dispatch_table(
"asin",
cuda=[
_Extern([jnp.float32], "__nv_asinf", jnp.float32),
_Extern([jnp.float64], "__nv_asin", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_asin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asin_f64", jnp.float64),
],
),
lax.acos_p: _make_dispatch_table(
"acos",
cuda=[
_Extern([jnp.float32], "__nv_acosf", jnp.float32),
_Extern([jnp.float64], "__nv_acos", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_acos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acos_f64", jnp.float64),
],
),
lax.atan_p: _make_dispatch_table(
"atan",
cuda=[
_Extern([jnp.float32], "__nv_atanf", jnp.float32),
_Extern([jnp.float64], "__nv_atan", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_atan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atan_f64", jnp.float64),
],
),
lax.atan2_p: _make_dispatch_table(
"atan2",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64),
],
rocm=[
_Extern(
[jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32
),
_Extern(
[jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64
),
],
),
lax.sinh_p: _make_dispatch_table(
"sinh",
cuda=[
_Extern([jnp.float32], "__nv_sinhf", jnp.float32),
_Extern([jnp.float64], "__nv_sinh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64),
],
),
lax.cosh_p: _make_dispatch_table(
"cosh",
cuda=[
_Extern([jnp.float32], "__nv_coshf", jnp.float32),
_Extern([jnp.float64], "__nv_cosh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64),
],
),
lax.tanh_p: _make_dispatch_table(
"tanh",
cuda=[
_Extern([jnp.float32], "__nv_tanhf", jnp.float32),
_Extern([jnp.float64], "__nv_tanh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64),
],
),
lax.asinh_p: _make_dispatch_table(
"asinh",
cuda=[
_Extern([jnp.float32], "__nv_asinhf", jnp.float32),
_Extern([jnp.float64], "__nv_asinh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64),
],
),
lax.acosh_p: _make_dispatch_table(
"acosh",
cuda=[
_Extern([jnp.float32], "__nv_acoshf", jnp.float32),
_Extern([jnp.float64], "__nv_acosh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64),
],
),
lax.atanh_p: _make_dispatch_table(
"atanh",
cuda=[
_Extern([jnp.float32], "__nv_atanhf", jnp.float32),
_Extern([jnp.float64], "__nv_atanh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64),
],
),
lax.population_count_p: _make_dispatch_table(
"population_count",
cuda=[
_Extern([jnp.int32], "__nv_popc", jnp.int32),
_Extern([jnp.int64], "__nv_popcll", jnp.int32),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.ctpop(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.ctpop(x)),
],
),
lax.clz_p: _make_dispatch_table(
"clz",
cuda=[
_Extern([jnp.int32], "__nv_clz", jnp.int32),
_Extern([jnp.int64], "__nv_clzll", jnp.int32),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.ctlz(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.ctlz(x)),
],
),
lax.nextafter_p: _make_dispatch_table(
"nextafter",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32 ),
_Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64),
],
rocm=[
_Extern(
[jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32
),
_Extern(
[jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64
),
],
),
})
def _minus(x: ir.Value) -> ir.Value:
if tt_dialect.PointerType.isinstance(_element_type(x.type)):
raise NotImplementedError(f"unsupported type: {x.type}")
return _sub(_full(x.type, 0), x)
def _add(x: ir.Value, y: ir.Value):
x_element_type = _element_type(x.type)
y_element_type = _element_type(y.type)
if tt_dialect.PointerType.isinstance(x_element_type):
assert not tt_dialect.PointerType.isinstance(y_element_type)
return tt_dialect.addptr(x.type, x, y)
if tt_dialect.PointerType.isinstance(y_element_type):
return tt_dialect.addptr(y.type, y, x)
assert x.type == y.type, (str(x.type), str(y.type))
if isinstance(x_element_type, ir.IntegerType):
return arith_dialect.addi(x, y)
if isinstance(x_element_type, ir.FloatType):
return arith_dialect.addf(x, y)
raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}")
def _sub(x: ir.Value, y: ir.Value) -> ir.Value:
x_element_type = _element_type(x.type)
y_element_type = _element_type(y.type)
if tt_dialect.PointerType.isinstance(x_element_type):
return tt_dialect.addptr(x.type, x, _minus(y))
elif not tt_dialect.PointerType.isinstance(y_element_type):
assert x.type == y.type, (str(x.type), str(y.type))
if isinstance(x_element_type, ir.IntegerType):
return arith_dialect.subi(x, y)
elif isinstance(x_element_type, ir.FloatType):
return arith_dialect.subf(x, y)
raise NotImplementedError(f"unsupported dtype: {y.type}")
def _mul(x: ir.Value, y: ir.Value) -> ir.Value:
assert x.type == y.type, (str(x.type), str(y.type))
x_element_type = _element_type(x.type)
if isinstance(x_element_type, ir.IntegerType):
return arith_dialect.muli(x, y)
elif isinstance(x_element_type, ir.FloatType):
return arith_dialect.mulf(x, y)
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
def _floordiv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
assert x.type == y.type, (str(x.type), str(y.type))
x_element_type = _element_type(x.type)
if isinstance(x_element_type, (ir.F32Type, ir.F64Type)):
return arith_dialect.divf(x, y)
if not isinstance(x_element_type, ir.IntegerType):
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
if signed:
return arith_dialect.divsi(x, y)
else:
return arith_dialect.divui(x, y)
def _truediv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
assert x.type == y.type, (str(x.type), str(y.type))
x_element_type = _element_type(x.type)
if isinstance(x_element_type, ir.IntegerType):
x_element_type = ir.F32Type.get()
x = _int_float_cast(x, x_element_type, signed=signed)
y = _int_float_cast(y, x_element_type, signed=signed)
if isinstance(x_element_type, (ir.F32Type, ir.F64Type)):
return arith_dialect.divf(x, y)
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
def _mod(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
assert x.type == y.type, (str(x.type), str(y.type))
x_element_type = _element_type(x.type)
if isinstance(x_element_type, ir.FloatType):
return arith_dialect.remf(x, y)
if not isinstance(x_element_type, ir.IntegerType):
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
if signed:
return arith_dialect.remsi(x, y)
else:
return arith_dialect.remui(x, y)
def _cmp(
x: ir.Value,
y: ir.Value,
si_pred: arith_dialect.CmpIPredicate,
ui_pred: arith_dialect.CmpIPredicate,
f_pred: arith_dialect.CmpFPredicate,
*,
signed: bool,
) -> ir.Value:
assert x.type == y.type, (str(x.type), str(y.type))
x_element_type = _element_type(x.type)
if isinstance(x_element_type, ir.IntegerType):
return arith_dialect.cmpi(si_pred if signed else ui_pred, x, y)
elif isinstance(x_element_type, ir.FloatType):
return arith_dialect.cmpf(f_pred, x, y)
else:
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
_equal = functools.partial(
_cmp,
si_pred=arith_dialect.CmpIPredicate.eq,
ui_pred=arith_dialect.CmpIPredicate.eq,
f_pred=arith_dialect.CmpFPredicate.OEQ,
)
_not_equal = functools.partial(
_cmp,
si_pred=arith_dialect.CmpIPredicate.ne,
ui_pred=arith_dialect.CmpIPredicate.ne,
f_pred=arith_dialect.CmpFPredicate.UNE,
)
_less_than = functools.partial(
_cmp,
si_pred=arith_dialect.CmpIPredicate.slt,
ui_pred=arith_dialect.CmpIPredicate.ult,
f_pred=arith_dialect.CmpFPredicate.OLT,
)
_less_equal = functools.partial(
_cmp,
si_pred=arith_dialect.CmpIPredicate.sle,
ui_pred=arith_dialect.CmpIPredicate.ule,
f_pred=arith_dialect.CmpFPredicate.OLE,
)
_greater_than = functools.partial(
_cmp,
si_pred=arith_dialect.CmpIPredicate.sgt,
ui_pred=arith_dialect.CmpIPredicate.ugt,
f_pred=arith_dialect.CmpFPredicate.OGT,
)
_greater_equal = functools.partial(
_cmp,
si_pred=arith_dialect.CmpIPredicate.sge,
ui_pred=arith_dialect.CmpIPredicate.uge,
f_pred=arith_dialect.CmpFPredicate.OGE,
)
_JAX_TO_TRITON_BINARY = {
lax.add_p: _add,
lax.sub_p: _sub,
lax.mul_p: _mul,
lax.and_p: arith_dialect.andi,
lax.or_p: arith_dialect.ori,
lax.xor_p: arith_dialect.xori,
lax.shift_left_p: arith_dialect.shli,
lax.shift_right_arithmetic_p: arith_dialect.shrsi,
lax.shift_right_logical_p: arith_dialect.shrui,
ad_util.add_any_p: _add,
}
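# Note: the loops below bind `fn` as a default argument (`fn=fn`) so that each
# generated rule captures the current function rather than the loop variable.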
for prim, fn in _JAX_TO_TRITON_BINARY.items():
def signless_rule(ctx: LoweringRuleContext, x, y, fn=fn):
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
return fn(x, y)
triton_lowering_rules[prim] = signless_rule
_JAX_TO_TRITON_SIGNED_BINARY = {
lax.rem_p: _mod,
lax.eq_p: _equal,
lax.ne_p: _not_equal,
lax.gt_p: _greater_than,
lax.ge_p: _greater_equal,
lax.lt_p: _less_than,
lax.le_p: _less_equal,
}
for prim, fn in _JAX_TO_TRITON_SIGNED_BINARY.items():
def signed_rule(ctx: LoweringRuleContext, x, y, fn=fn):
x_aval, _ = ctx.avals_in
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
return fn(x, y, signed=jnp.issubdtype(x_aval.dtype, jnp.signedinteger))
triton_lowering_rules[prim] = signed_rule
@register_lowering(primitives.debug_print_p)
def debug_print_lowering_rule(
ctx: LoweringRuleContext,
*args: ir.Value,
fmt: str,
has_placeholders: bool,
):
if has_placeholders:
raise ValueError(
"pl.debug_print() does not support placeholders when lowering to Triton"
)
tt_dialect.print_(
f" {fmt} ",
hex=False,
args=args,
is_signed=ir.DenseI32ArrayAttr.get([
jnp.issubdtype(aval.dtype, jnp.signedinteger) for aval in ctx.avals_in
]),
)
return ()
def _set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None:
if not ir.BlockArgument.isinstance(v):
v.owner.attributes[name] = attr
return
arg = ir.BlockArgument(v)
name += f"_arg{arg.arg_number}"
owner = arg.owner
is_entry = owner.region.blocks[0] == owner
if not is_entry:
return
if (op := owner.owner.operation) and not isinstance(op, tt_dialect.FuncOp):
op.attributes[name] = attr
@register_lowering(primitives.multiple_of_p)
def _multiple_of_rule(ctx: LoweringRuleContext, x, values: Sequence[int]):
[x_aval] = ctx.avals_in
assert max(1, len(x_aval.shape)) == len(values)
_set_attr(
x,
"tt.divisibility",
ir.DenseIntElementsAttr.get(np.asarray(values, dtype=np.int32)),
)
return x
@register_lowering(primitives.max_contiguous_p)
def _max_contiguous_rule(ctx: LoweringRuleContext, x, values: Sequence[int]):
[x_aval] = ctx.avals_in
assert len(x_aval.shape) == len(values)
_set_attr(
x,
"tt.contiguity",
ir.DenseIntElementsAttr.get(np.asarray(values, dtype=np.int32)),
)
return x
@register_lowering(sp.broadcast_to_p)
def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]):
(x_aval,) = ctx.avals_in
return _bcast_to(_ensure_ir_value(x, x_aval), shape)
@register_lowering(lax.integer_pow_p)
def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int):
if y == 0:
return _full(x.type, 1)
is_reciprocal = y < 0
if is_reciprocal:
y = -y
acc = None
while y > 0:
y, mod = divmod(y, 2)
if mod:
acc = x if acc is None else _mul(acc, x)
if y > 0:
x = _mul(x, x)
assert acc is not None
[x_aval] = ctx.avals_in
[out_aval] = ctx.avals_out
acc = _cast(acc, x_aval.dtype, out_aval.dtype)
if is_reciprocal:
signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger)
return _truediv(_full(acc.type, 1), acc, signed=signed)
else:
return acc
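# Worked example (plain integers, not executed): for y = 5 the loop above
# performs exponentiation by squaring: acc picks up x for the low bit, x is
# squared to x**2 and then x**4, and the final set bit gives
# acc = x * x**4 = x**5, i.e. three multiplies instead of four. Negative
# exponents take the reciprocal of the result at the end.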
_JAX_FN_MAPPING = {
lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
}
for prim, fn in _JAX_FN_MAPPING.items():
triton_lowering_rules[prim] = lower_fun(fn, multiple_results=False)
@register_lowering(lax.min_p)
def _min_lowering_rule(ctx: LoweringRuleContext, x, y):
# TODO(slebedev): Consider allowing customizing nan behavior.
x_aval, y_aval = ctx.avals_in
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
if jnp.issubdtype(x_aval.dtype, jnp.floating):
# TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
return arith_dialect.minnumf(x, y)
if not jnp.issubdtype(x_aval.dtype, jnp.integer):
raise NotImplementedError(
f"unsupported dtypes: {x_aval.dtype} and {y_aval.dtype}"
)
if jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
return arith_dialect.minsi(x, y)
else:
return arith_dialect.minui(x, y)
@register_lowering(lax.max_p)
def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
# TODO(slebedev): Consider allowing customizing nan behavior.
x_aval, y_aval = ctx.avals_in
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
if jnp.issubdtype(x_aval.dtype, jnp.floating):
# TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
return arith_dialect.maxnumf(x, y)
if not jnp.issubdtype(x_aval.dtype, jnp.integer):
raise NotImplementedError(
f"unsupported dtypes: {x_aval.dtype} and {y_aval.dtype}"
)
if jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
return arith_dialect.maxsi(x, y)
else:
return arith_dialect.maxui(x, y)
@register_lowering(lax.div_p)
def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
x_aval, y_aval = ctx.avals_in
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger) or jnp.issubdtype(
y_aval.dtype, jnp.signedinteger
)
if jnp.issubdtype(x_aval.dtype, np.floating) or jnp.issubdtype(
y_aval.dtype, np.floating
):
return _truediv(x, y, signed=signed)
return _floordiv(x, y, signed=signed)
register_lowering(lax.sign_p)(
lower_fun(pallas_utils.sign_lowering_helper, multiple_results=False)
)
register_lowering(lax.erf_inv_p)(
lower_fun(pallas_utils.erf_inv_lowering_helper, multiple_results=False)
)
@register_lowering(lax.iota_p)
def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension):
iota = _make_range(0, shape[dimension])
iota = _cast(iota, jnp.int32, dtype)
for i in range(len(shape)):
if i != dimension:
iota = _expand_dims(iota, i)
return _bcast_to(iota, shape)
def _element_type(t: ir.Type) -> ir.Type:
if ir.RankedTensorType.isinstance(t):
return ir.RankedTensorType(t).element_type
else:
return t
def _make_range(start: int, end: int) -> ir.Value:
if end <= start:
raise ValueError(
f"end must be greater than start, but got: {end} <= {start}"
)
if max(start, end) >= 2**32:
raise ValueError("start and end must fit in int32")
return tt_dialect.make_range(
ir.RankedTensorType.get([end - start], ir.IntegerType.get_signless(32)),
start,
end,
)
def _full(t: ir.Type, v: object) -> ir.Value:
element_type = _element_type(t)
if isinstance(element_type, ir.IntegerType):
result = arith_dialect.constant(element_type, int(v))
elif isinstance(element_type, ir.FloatType):
result = arith_dialect.constant(element_type, float(v))
else:
raise NotImplementedError
if ir.RankedTensorType.isinstance(t):
return tt_dialect.splat(t, result)
else:
return result
def _splat(x: ir.Value, shape: Sequence[int]) -> ir.Value:
if ir.RankedTensorType.isinstance(x.type):
raise TypeError("cannot splat a tensor")
if not shape:
return x
return tt_dialect.splat(ir.RankedTensorType.get(shape, x.type), x)
def _expand_dims(x: ir.Value, axis: int) -> ir.Value:
if not ir.RankedTensorType.isinstance(x.type):
shape = list(ir.RankedTensorType(x.type).shape)
shape.insert(axis, 1)
return _splat(x, shape)
return tt_dialect.expand_dims(x, axis)
def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
src_element_type = ir.FloatType(_element_type(src.type))
dst_element_type = ir.FloatType(_element_type(dst_type))
if src_element_type.width == 8 or dst_element_type.width == 8:
return tt_dialect.fp_to_fp(
dst_type,
src,
rounding=tt_dialect.RoundingMode.RTNE,
)
if src_element_type.width > dst_element_type.width:
return arith_dialect.truncf(dst_type, src)
elif src_element_type.width < dst_element_type.width:
return arith_dialect.extf(dst_type, src)
else:
raise NotImplementedError
def _int_int_cast(src: ir.Value, dst_type: ir.Type, signed: bool) -> ir.Value:
src_element_type = ir.IntegerType(_element_type(src.type))
dst_element_type = ir.IntegerType(_element_type(dst_type))
assert src_element_type != dst_element_type
if dst_element_type.width == 1:
return _not_equal(src, _full(src.type, 0), signed=signed)
if src_element_type.width == dst_element_type.width:
return arith_dialect.bitcast(dst_type, src)
elif src_element_type.width > dst_element_type.width:
return arith_dialect.trunci(dst_type, src)
elif signed and src_element_type.width != 1:
return arith_dialect.extsi(dst_type, src)
else:
return arith_dialect.extui(dst_type, src)
def _float_int_cast(
src: ir.Value, dst_type: ir.Type, *, signed: bool
) -> ir.Value:
src_element_type = _element_type(src.type)
if not isinstance(src_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)):
raise NotImplementedError(f"cannot cast {src} tp {dst_type}")
dst_element_type = ir.IntegerType(_element_type(dst_type))
if dst_element_type.width == 1:
return _not_equal(src, _full(src.type, 0), signed=signed)
elif signed:
return arith_dialect.fptosi(dst_type, src)
else:
return arith_dialect.fptoui(dst_type, src)
def _int_float_cast(
src: ir.Value, dst_type: ir.Type, *, signed: bool
) -> ir.Value:
src_element_type = ir.IntegerType(_element_type(src.type))
dst_element_type = _element_type(dst_type)
if not isinstance(
dst_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)
):
raise NotImplementedError(f"cannot cast {src} tp {dst_type}")
if src_element_type.width == 1 or not signed:
return arith_dialect.uitofp(dst_type, src)
else:
return arith_dialect.sitofp(dst_type, src)
def _cast(
src: ir.Value,
src_type: jax.typing.DTypeLike,
dst_type: jax.typing.DTypeLike,
) -> ir.Value:
return _ir_cast(
src,
_dtype_to_ir_type(dst_type),
signed=jnp.issubdtype(src_type, jnp.signedinteger),
)
def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
if ir.RankedTensorType.isinstance(
src.type
) and not ir.RankedTensorType.isinstance(dst_type):
src_type = ir.RankedTensorType(src.type)
dst_type = ir.RankedTensorType.get(
src_type.shape,
dst_type,
src_type.encoding,
)
if src.type == dst_type:
return src
src_element_type = _element_type(src.type)
dst_element_type = _element_type(dst_type)
if isinstance(src_element_type, ir.Float8E4M3FNUZType) or isinstance(
dst_element_type, ir.Float8E4M3FNUZType
):
# TODO(slebedev): Check the CUDA version and raise conditionally.
raise NotImplementedError("cannot cast from or to float8_e4m3fnuz")
if isinstance(src_element_type, (ir.F16Type, ir.BF16Type)) and not isinstance(
dst_element_type, ir.F32Type
):
return _ir_cast(
_ir_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
)
if isinstance(src_element_type, ir.FloatType) and isinstance(
dst_element_type, ir.FloatType
):
return _float_float_cast(src, dst_type)
if isinstance(src_element_type, ir.IntegerType) and isinstance(
dst_element_type, ir.IntegerType
):
return _int_int_cast(src, dst_type, signed=signed)
if isinstance(src_element_type, ir.FloatType) and isinstance(
dst_element_type, ir.IntegerType
):
return _float_int_cast(src, dst_type, signed=signed)
if isinstance(src_element_type, ir.IntegerType) and isinstance(
dst_element_type, ir.FloatType
):
return _int_float_cast(src, dst_type, signed=signed)
if tt_dialect.PointerType.isinstance(src_element_type) and isinstance(
dst_element_type, ir.IntegerType
):
if dst_element_type.width == 64:
return tt_dialect.ptr_to_int(dst_type, src)
elif dst_element_type.width == 1:
x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed)
zero = _full(x.type, 0)
return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed)
if isinstance(
src_element_type, ir.IntegerType
) and tt_dialect.PointerType.isinstance(dst_element_type):
return tt_dialect.int_to_ptr(dst_type, src)
if tt_dialect.PointerType.isinstance(
src_element_type
) and tt_dialect.PointerType.isinstance(dst_element_type):
return tt_dialect.bitcast(dst_type, src)
raise NotImplementedError(f"cannot cast {src} to {dst_type}")
@register_lowering(lax.convert_element_type_p)
def _convert_element_type_lowering_rule(
ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
):
[x_aval] = ctx.avals_in
x = _ensure_ir_value(x, x_aval)
if new_dtype == x_aval.dtype:
return x
return _cast(x, x_aval.dtype, new_dtype)
@register_lowering(lax.select_n_p)
def select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, y):
pred_aval, a_aval, b_aval = ctx.avals_in
[out_aval] = ctx.avals_out
pred, x = _bcast(pred, x, pred_aval, a_aval, out_aval)
pred, y = _bcast(pred, y, pred_aval, b_aval, out_aval)
return arith_dialect.select(pred, y, x)
@register_lowering(lax.broadcast_in_dim_p)
def _broadcast_in_dim_lowering_rule(
ctx: LoweringRuleContext, x, *, broadcast_dimensions, shape
):
x = _ensure_ir_value(x, *ctx.avals_in)
if not ir.RankedTensorType.isinstance(x.type):
return _bcast_to(x, shape)
expand_dims = [i for i in range(len(shape)) if i not in broadcast_dimensions]
for dim in expand_dims:
x = _expand_dims(x, dim)
return _bcast_to(x, shape)
@register_lowering(lax.squeeze_p)
def _squeeze_lowering_rule(ctx: LoweringRuleContext, a, *, dimensions):
del dimensions
return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None)
def _reshape(x: ir.Value, shape: Sequence[int]) -> ir.Value:
if not shape:
raise ValueError("cannot reshape to an empty shape")
ty = ir.RankedTensorType(x.type)
return tt_dialect.reshape(
ir.RankedTensorType.get(shape, ty.element_type, ty.encoding),
x,
allow_reorder=False,
)
@register_lowering(lax.reshape_p)
def _reshape_lowering_rule(
ctx: LoweringRuleContext, a, *, new_sizes, dimensions
):
del new_sizes # Unused.
if dimensions is not None:
raise ValueError("`dimensions` is not supported.")
a = _ensure_ir_value(a, *ctx.avals_in)
[out_aval] = ctx.avals_out
if not ir.RankedTensorType.isinstance(a.type):
assert all(dim_size == 1 for dim_size in out_aval.shape)
return _splat(a, out_aval.shape)
# TODO(slebedev): Check that the following comment still applies.
# Expand-dims or reduce-sum to handle singleton dims as `tl.reshape` is not
# currently implemented.
dst_shape = [*out_aval.shape]
i = 0
while (
ir.RankedTensorType.isinstance(a.type)
and (a_shape := ir.RankedTensorType(a.type).shape) != dst_shape
):
dim_size = a_shape[i] if i < len(a_shape) else None
dst_dim_size = dst_shape[i] if i < len(dst_shape) else None
if dim_size == dst_dim_size:
i += 1
elif dst_dim_size == 1:
a = _expand_dims(a, axis=i)
i += 1
elif dim_size == 1:
in_shape = a_shape
out_shape = tuple(d for di, d in enumerate(a_shape) if di != i)
reduce_ctx = ctx.replace(
avals_in=[ctx.avals_in[0].update(shape=in_shape)],
avals_out=[ctx.avals_in[0].update(shape=out_shape)],
)
a = _reduce_lowering(jnp.add, reduce_ctx, a, axes=(i,))
else: # We expect this to fail.
return _reshape(a, dst_shape)
return a
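# Illustrative walk-through (hypothetical shapes): reshaping (4, 8) to
# (4, 1, 8) inserts the unit dim via _expand_dims, while (4, 1, 8) to (4, 8)
# removes it by reduce-summing over the length-1 axis (a value-preserving
# no-op). Reshapes that are not pure insertions/removals of unit dims fall
# through to _reshape, which the comment above notes is expected to fail.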
def _compute_pointers_from_indices(
root_ptr: ir.Value,
block_info: BlockInfo | None,
nd_indexer: NDIndexer,
array_shape_dtype: Any,
) -> ir.Value:
if block_info is None: # TODO(necula): is this branch dead?
full_shape = array_shape_dtype.shape
num_mapped_dims = 0
block_shape = array_shape_dtype.shape
else:
full_shape = block_info.full_shape_dtype.shape
num_mapped_dims = sum(
b is pallas_core.mapped for b in block_info.block_shape
)
block_shape = block_info.block_shape
strides = pallas_utils.strides_from_shape(full_shape)
indexer_shape = nd_indexer.get_indexer_shape()
int_indexer_shape = nd_indexer.int_indexer_shape
_check_tensor_size(indexer_shape)
indices = nd_indexer.indices
other_shape = indexer_shape[len(int_indexer_shape) :]
other_shape_idx = 0
if block_info is None:
start_index_offsets = [None] * len(indices)
else:
start_index_offsets = block_info.start_indices
assert len(indices) + num_mapped_dims == len(full_shape)
assert len(start_index_offsets) == len(full_shape)
array_dtype = jnp.dtype(array_shape_dtype.dtype)
full_size = math.prod(full_shape) * array_dtype.itemsize
# Use 64-bit indexing when offset might be >= 2**32 bytes.
offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32)
if indexer_shape:
offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0)
else:
offsets = _ir_constant(0, offset_eltype)
indexer_iter = iter(indices)
for dim_stride, dim_block_size, start_offset in zip(
strides, block_shape, start_index_offsets
):
if dim_block_size is pallas_core.mapped:
index = _ir_constant(0, offset_eltype)
else:
index = next(indexer_iter)
if isinstance(index, slice):
index = primitives.Slice.from_slice(index, dim_block_size)
if isinstance(index, primitives.Slice):
if index.is_dynamic_start or (index.stride != 1):
start = index.start
if not index.is_dynamic_start:
start = _ir_constant(start, offset_eltype)
start = _ir_cast(start, offset_eltype, signed=False)
iota = _ir_cast(_make_range(0, index.size), offset_eltype, signed=False)
if index.stride != 1:
iota = _mul(iota, _full(iota.type, index.stride))
dim_offsets = _add(_bcast_to(start, [index.size]), iota)
else:
iota = _make_range(index.start, index.start + index.size)
dim_offsets = _ir_cast(iota, offset_eltype, signed=False)
other_shape_idx += 1
for _ in other_shape[other_shape_idx:]:
rank = ir.RankedTensorType(dim_offsets.type).rank
dim_offsets = _expand_dims(dim_offsets, rank)
else:
# indexer is either a *scalar* or an array of size `int_indexer_shape`
dim_offsets = index
if not isinstance(dim_offsets, ir.Value):
dim_offsets = _ir_constant(dim_offsets, offset_eltype)
dim_offsets = _ir_cast(dim_offsets, offset_eltype, signed=False)
if ir.RankedTensorType.isinstance(dim_offsets.type):
for _ in other_shape:
rank = ir.RankedTensorType(dim_offsets.type).rank
dim_offsets = _expand_dims(dim_offsets, rank)
if ir.RankedTensorType.isinstance(dim_offsets.type):
rank = ir.RankedTensorType(dim_offsets.type).rank
for _ in range(len(indexer_shape) - rank):
dim_offsets = _expand_dims(dim_offsets, 0)
dim_offsets = _bcast_to(dim_offsets, indexer_shape)
if start_offset is not None:
start_offset = _ir_cast(start_offset, offset_eltype, signed=False)
dim_offsets = _add(dim_offsets, _bcast_to(start_offset, indexer_shape))
dim_offsets = _mul(dim_offsets, _full(dim_offsets.type, dim_stride))
offsets = _add(offsets, dim_offsets)
return _add(_bcast_to(root_ptr, indexer_shape), offsets)
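# Worked example (hypothetical shapes and names): for a full array of shape
# (M, N) the row-major strides are (N, 1). An indexer (Slice(a, 64),
# Slice(b, 128)) then produces per-dimension offsets block_start0 + a + r and
# block_start1 + b + c, which are scaled by their strides and summed:
#   offsets[r, c] = (block_start0 + a + r) * N + (block_start1 + b + c)
# The result is root_ptr broadcast to the indexer shape plus these offsets;
# 64-bit offsets are used only when the array spans more than 2**32 bytes.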
@register_lowering(sp.get_p)
def _get_lowering_rule(ctx: LoweringRuleContext, ptr, *idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if not tt_dialect.PointerType.isinstance(ptr.type):
assert len(indexers) == 0
return ptr
if len(indexers) > 1:
raise NotImplementedError("No support for multiple indexers yet.")
indexer = indexers[0]
args_flat, args_tree = tree_util.tree_flatten((ptr, (indexer,), None, None))
return _masked_load_lowering_rule(
ctx,
*args_flat,
args_tree=args_tree,
eviction_policy=None,
cache_modifier=None,
is_volatile=False,
)
_STR_TO_EVICTION_POLICY = {str(e): e for e in tt_dialect.EvictionPolicy}
_STR_TO_CACHE_MODIFIER = {str(c): c for c in tt_dialect.CacheModifier}
def _load(
ptr: ir.Value,
mask: ir.Value | None = None,
other: ir.Value | None = None,
*,
cache_modifier: str | None = None,
eviction_policy: str | None = None,
is_volatile: bool = False,
) -> ir.Value:
if cache_modifier is None:
cache_modifier = tt_dialect.CacheModifier.NONE
elif cache_modifier == ".ca" or cache_modifier == ".cg":
cache_modifier = _STR_TO_CACHE_MODIFIER[cache_modifier]
else:
raise ValueError(f"unsupported cache modifier: {cache_modifier}")
if eviction_policy is None:
eviction_policy = tt_dialect.EvictionPolicy.NORMAL
else:
try:
eviction_policy = _STR_TO_EVICTION_POLICY[eviction_policy]
except KeyError:
raise ValueError(
f"unsupported eviction policy: {eviction_policy}"
) from None
if tt_dialect.PointerType.isinstance(ptr.type):
ptr_type = tt_dialect.PointerType(ptr.type)
if ir.RankedTensorType.isinstance(ptr_type.pointee_type):
raise NotImplementedError("loading from a block pointer is not supported")
ptr_type = _element_type(ptr.type)
if not tt_dialect.PointerType.isinstance(ptr_type):
raise ValueError(f"unsupported pointer type: {ptr_type}")
ptr_type = tt_dialect.PointerType(ptr_type)
if other is not None and mask is None:
raise ValueError("other requires mask to be provided")
if not ir.RankedTensorType.isinstance(ptr.type):
if other is not None and ir.RankedTensorType.isinstance(other.type):
raise ValueError("other cannot be a block if pointer is not a block")
if mask is not None and ir.RankedTensorType.isinstance(mask.type):
raise ValueError("mask cannot be a block if pointer is not a block")
pointee_type = ptr_type.pointee_type
is_int1 = isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1
if is_int1:
pointee_type = ir.IntegerType.get_signless(8)
ptr = _ir_cast(
ptr,
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
signed=False,
)
if other is not None:
other = _ir_cast(other, pointee_type, signed=False)
result = tt_dialect.load(
ptr,
mask=mask,
other=other,
cache=cache_modifier,
evict=eviction_policy,
is_volatile=is_volatile,
)
return (
result
if not is_int1
else _ir_cast(result, ir.IntegerType.get_signless(1), signed=False)
)
@register_lowering(primitives.load_p)
def _masked_load_lowering_rule(
ctx: LoweringRuleContext,
*args_flat,
args_tree,
eviction_policy,
cache_modifier,
is_volatile,
):
ptr, indexers, mask, other = args_tree.unflatten(args_flat)
*_, mask_aval, other_aval = args_tree.unflatten(ctx.avals_in)
if len(indexers) > 1:
raise NotImplementedError("No support for multiple indexers yet.")
idx = indexers[0]
if not tt_dialect.PointerType.isinstance(ptr.type):
assert len(ctx.avals_in) == 1
return ptr
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
)
if mask is not None:
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
if other is not None:
other = _bcast_to(
_ensure_ir_value(other, other_aval), idx.get_indexer_shape()
)
return _load(
ptr,
mask=mask,
other=other,
cache_modifier=cache_modifier,
is_volatile=is_volatile,
eviction_policy=eviction_policy,
)
@register_lowering(sp.swap_p)
def _swap_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if not tt_dialect.PointerType.isinstance(ptr.type):
assert len(indexers) == 0
return ptr
if len(indexers) > 1:
raise NotImplementedError("No support for multiple indexers yet.")
indexer = indexers[0]
args_flat, args_tree = tree_util.tree_flatten((ptr, (indexer,), value, None))
return _masked_swap_lowering_rule(
ctx, *args_flat, args_tree=args_tree, eviction_policy=None
)
def _store(
ptr: ir.Value,
value: ir.Value,
mask: ir.Value | None = None,
*,
cache_modifier: str | None = None,
eviction_policy: str | None = None,
) -> ir.Value:
  if cache_modifier is None:
    cache_modifier = tt_dialect.CacheModifier.NONE
  elif cache_modifier == ".ca":
    raise ValueError(f"unsupported cache modifier: {cache_modifier}")
  else:
    cache_modifier = _STR_TO_CACHE_MODIFIER[cache_modifier]
if eviction_policy is None:
eviction_policy = tt_dialect.EvictionPolicy.NORMAL
else:
try:
eviction_policy = _STR_TO_EVICTION_POLICY[eviction_policy]
except KeyError:
raise ValueError(
f"unsupported eviction policy: {eviction_policy}"
) from None
if tt_dialect.PointerType.isinstance(ptr.type):
ptr_type = tt_dialect.PointerType(ptr.type)
if ir.RankedTensorType.isinstance(ptr_type.pointee_type):
      raise NotImplementedError("storing to a block pointer is not supported")
ptr_type = _element_type(ptr.type)
if not tt_dialect.PointerType.isinstance(ptr_type):
raise ValueError(f"unsupported pointer type: {ptr_type}")
ptr_type = tt_dialect.PointerType(ptr_type)
if not ir.RankedTensorType.isinstance(ptr.type):
if ir.RankedTensorType.isinstance(value.type):
raise ValueError("value cannot be a block if pointer is not a block")
if mask is not None and ir.RankedTensorType.isinstance(mask.type):
raise ValueError("mask cannot be a block if pointer is not a block")
pointee_type = ptr_type.pointee_type
if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1:
pointee_type = ir.IntegerType.get_signless(8)
ptr = _ir_cast(
ptr,
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
signed=False,
)
value = _ir_cast(value, pointee_type, signed=False)
return tt_dialect.store(
ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy
)
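
# Lowering rule for `primitives.swap_p`: returns the value previously held at
# the indexed locations and stores the new one. The old value is read with a
# masked `tt.load` that uses the (broadcast) new value as the `other` fill for
# masked-out lanes, and then the new value is stored under the same mask.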
@register_lowering(primitives.swap_p)
def _masked_swap_lowering_rule(
ctx: LoweringRuleContext, *args_flat, args_tree, eviction_policy
):
ptr, indexers, value, mask = args_tree.unflatten(args_flat)
*_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
if len(indexers) > 1:
raise NotImplementedError("No support for multiple indexers yet.")
idx = indexers[0]
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
)
other = None
if value is not None:
value = _ensure_ir_value(value, value_aval)
if mask is not None:
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
if value is not None:
other = _bcast_to(value, idx.get_indexer_shape())
old_value = _load(ptr, mask=mask, other=other)
_store(ptr, value, mask=mask, eviction_policy=eviction_policy)
return old_value
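
# `addupdate_p` (`ref[idx] += value`) is lowered to an atomic RMW: FADD for
# floating-point element types, ADD for integers. The primitive has no output,
# so the RMW result is discarded.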
@register_lowering(sp.addupdate_p)
def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
if not tt_dialect.PointerType.isinstance(ptr.type):
assert len(indexers) == 0
return ptr
if len(indexers) > 1:
raise NotImplementedError("No support for multiple indexers yet.")
indexer = indexers[0]
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], indexer, ctx.avals_in[0]
)
op = tt_dialect.RMWOp.FADD
if isinstance(_element_type(value.type), ir.IntegerType):
op = tt_dialect.RMWOp.ADD
_atomic_rmw(op, ptr, value)
return []
@register_lowering(lax.transpose_p)
def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation):
return tt_dialect.trans(x, permutation)
def _check_dot_operands(
x_type: ir.RankedTensorType, y_type: ir.RankedTensorType, options: Any
):
# TODO(slebedev): Ensure that the dtypes are supported by CUDA.
return
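
# Shared helper for the `dot_general` lowering below. All operand dimensions
# must be at least 16 (mirroring Triton's `tt.dot` constraints); accumulation
# happens in f32 for f32/bf16 inputs, in i32 for int8 inputs, and otherwise
# follows `out_type`.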
def _dot(
x: ir.Value,
y: ir.Value,
acc: ir.Value | None = None,
*,
allow_tf32: bool = True,
max_num_imprecise_acc: int | None = None,
out_type: ir.Type | None = None,
) -> ir.Value:
if out_type is None:
out_type = ir.F32Type.get()
elif isinstance(out_type, ir.BF16Type):
raise NotImplementedError(f"unsupported output type: {out_type}")
x_type = ir.RankedTensorType(x.type)
y_type = ir.RankedTensorType(y.type)
if min(*x_type.shape, *y_type.shape) < 16:
raise ValueError("all dimensions of x and y must be >= 16 ")
if x_type.element_type != y_type.element_type:
raise ValueError(
"x and y must have the same element type, but got:"
f" {x_type.element_type} and {y_type.element_type}"
)
_check_dot_operands(x_type, y_type, object())
element_type = x_type.element_type
if isinstance(element_type, ir.IntegerType):
if element_type.width != 8:
raise TypeError(f"unsupported element type: {element_type}")
element_type = ir.IntegerType.get_signless(32)
elif isinstance(element_type, (ir.F32Type, ir.BF16Type)):
element_type = ir.F32Type.get()
else:
element_type = out_type
if element_type != out_type:
raise TypeError(
f"output type {out_type} does not match element type {element_type}"
)
m, _ = x_type.shape
_, n = y_type.shape
if acc is None:
acc = _full(ir.RankedTensorType.get([m, n], element_type), 0)
if max_num_imprecise_acc is None:
if isinstance(element_type, ir.FloatType) and element_type.width == 8:
# TODO(slebedev): Fill in from options.
raise NotImplementedError
else:
max_num_imprecise_acc = 0
# Ideally, replace all allow_tf32 usages with InputPrecision directly.
input_precision = tt_dialect.InputPrecision.IEEE
if allow_tf32:
input_precision = tt_dialect.InputPrecision.TF32
return tt_dialect.dot(
x,
y,
acc,
max_num_imprecise_acc=max_num_imprecise_acc,
input_precision=input_precision
)
_TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT)
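
# `Precision.HIGH` and `Precision.DEFAULT` permit TF32 inputs, while
# `Precision.HIGHEST` forces IEEE f32. An illustrative kernel-level example:
#
#   acc = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)  # IEEE, no TF32
#
# Operands are transposed as needed so that the contraction runs over `a`'s
# last and `b`'s first dimension before calling `_dot`.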
@register_lowering(lax.dot_general_p)
def _dot_general_lowering(
ctx: LoweringRuleContext,
a,
b,
*,
dimension_numbers,
precision,
preferred_element_type,
):
del preferred_element_type # Unused.
((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
assert batch_dims == ((), ())
if a_contract_dim == 0:
a = tt_dialect.trans(a, (1, 0))
if b_contract_dim == 1:
b = tt_dialect.trans(b, (1, 0))
if precision is None:
allow_tf32 = True
else:
prec_a, prec_b = precision
allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS
[out_aval] = ctx.avals_out
out_dtype = acc_dtype = out_aval.dtype
if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
acc_dtype = jnp.dtype(jnp.float32)
return _cast(
_dot(
a,
b,
allow_tf32=allow_tf32,
out_type=_dtype_to_ir_type(acc_dtype),
),
acc_dtype,
out_dtype,
)
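
# Reductions are lowered to `tt.reduce` with a combine region: `body` (e.g.
# `jnp.maximum`) is traced to a jaxpr over scalar avals, and that jaxpr is
# lowered into the region, which is terminated by `tt.reduce.return`.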
def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes):
flat_args = tree_util.tree_leaves(a)
(axis,) = axes
mapped_avals = [jax_core.ShapedArray((), aval.dtype) for aval in ctx.avals_in]
in_tree = tree_util.tree_structure((a, a))
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(body), in_tree
)
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
flat_fun, [*mapped_avals, *mapped_avals]
)
out_tree = out_tree_thunk()
del out_tree # Not needed
if consts:
raise NotImplementedError("Reductions with constants not supported.")
element_types = [_element_type(arg.type) for arg in flat_args]
reduce_op = tt_dialect.ReduceOp(flat_args, axis)
param_types = element_types * 2
entry = reduce_op.regions[0].blocks.append(*param_types)
with ir.InsertionPoint.at_block_begin(entry):
results = lower_jaxpr_to_triton_ir(
ctx.context, combine_jaxpr, None, *entry.arguments
)
tt_dialect.reduce_return(results)
reduce_op.verify()
return list(reduce_op.result)
def _reduce_lowering(body, ctx: LoweringRuleContext, a, *, axes):
assert isinstance(axes, tuple)
if not axes:
return a
while len(axes) > 1:
axis = max(axes)
dst_avals = tuple(v.update(shape=v.shape[:axis] + v.shape[axis + 1:])
for v in ctx.avals_in)
a = _reduce_lowering(
body, ctx.replace(avals_out=dst_avals), a, axes=(axis,))
# Adding an intervening -(-reduce(.)) introduces a convert_layout between
# reduces, which seems necessary for correctness.
# TODO(bjp): Get rid of the double negation.
# https://github.com/openai/triton/issues/1776
a = _minus(_minus(a))
ctx = ctx.replace(avals_in=dst_avals)
axes = tuple(ax for ax in axes if ax != axis)
return _reduction_lowering(body, ctx, a, axes=axes)[0]
triton_lowering_rules[lax.reduce_max_p] = functools.partial(
_reduce_lowering, jnp.maximum
)
triton_lowering_rules[lax.reduce_min_p] = functools.partial(
_reduce_lowering, jnp.minimum
)
triton_lowering_rules[lax.reduce_sum_p] = functools.partial(
_reduce_lowering, jnp.add
)
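
# argmax/argmin are lowered as a paired (value, index) reduction: an iota
# along the reduced axis supplies the indices, and the combine functions below
# break ties by picking the smaller index, so the first occurrence wins.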
def _argreduce_lowering(
body, ctx: LoweringRuleContext, a, *, axes, index_dtype
):
if index_dtype != jnp.int32:
raise ValueError("`index_type` must be i32.")
if len(axes) != 1:
raise ValueError("`pallas` reduce operations only support one reduce axis.")
[axis] = axes
[a_aval] = ctx.avals_in
index = _make_range(0, a_aval.shape[axis])
if len(a_aval.shape) > 1:
# Broadcast index across the non-reduced axes
for i in range(len(a_aval.shape)):
if i != axis:
index = _expand_dims(index, i)
index = _bcast_to(index, a_aval.shape)
ctx = ctx.replace(avals_in=[a_aval, a_aval.update(dtype=jnp.dtype(jnp.int32))])
_, indices = _reduction_lowering(body, ctx, (a, index), axes=axes)
return indices
def _reduce_argmax_combine(left, right):
value1, index1 = left
value2, index2 = right
gt = value1 > value2
lt = value1 < value2
index_min = jnp.minimum(index1, index2)
index_ret = jnp.where(gt, index1, jnp.where(lt, index2, index_min))
value_ret = jnp.maximum(value1, value2)
return value_ret, index_ret
triton_lowering_rules[lax.argmax_p] = functools.partial(
_argreduce_lowering, _reduce_argmax_combine
)
def _reduce_argmin_combine(left, right):
value1, index1 = left
value2, index2 = right
gt = value1 > value2
lt = value1 < value2
index_min = jnp.minimum(index1, index2)
index_ret = jnp.where(lt, index1, jnp.where(gt, index2, index_min))
value_ret = jnp.minimum(value1, value2)
return value_ret, index_ret
triton_lowering_rules[lax.argmin_p] = functools.partial(
_argreduce_lowering, _reduce_argmin_combine
)
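
# Call-like primitives (pjit, closed_call, custom_jvp_call, remat) are lowered
# by inlining their jaxprs; stop_gradient is lowered as the identity.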
@register_lowering(pjit.pjit_p)
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
if jaxpr.consts:
raise NotImplementedError
return lower_jaxpr_to_triton_ir(
ctx.context, jaxpr.jaxpr, ctx.block_infos, *args
)
@register_lowering(jax_core.closed_call_p)
@register_lowering(custom_derivatives.custom_jvp_call_p)
def _closed_call_lowering_rule(
ctx: LoweringRuleContext, *args, call_jaxpr, **_
):
jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
if consts:
raise NotImplementedError
return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args)
@register_lowering(ad_checkpoint.remat_p)
def _remat_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args)
triton_lowering_rules[ad_util.stop_gradient_p] = lambda _, x: x
@register_lowering(lax.axis_index_p)
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
grid_names = ctx.context.grid_mapping.grid_names
if axis_name in grid_names:
# We are querying a named axis corresponding to a grid dimension.
return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
raise LookupError(f"Axis name {axis_name} not found in grid.")
def _is_read_only(ref_effects) -> bool:
if len(ref_effects) == 0:
return True
if len(ref_effects) > 1:
# Means we must have a write or accum effect so not read-only
return False
(eff,) = ref_effects
return isinstance(eff, state.ReadEffect)
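
# `for_p` is lowered to `scf.for`. State for non-pointer arguments is
# discharged from the jaxpr so they become ordinary loop-carried values;
# pointer-backed refs and read-only values are not threaded through the carry
# and are instead captured directly in the loop body.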
@register_lowering(for_loop.for_p)
def _for_lowering_rule(
ctx: LoweringRuleContext,
*args,
jaxpr,
which_linear,
nsteps,
reverse,
unroll,
):
del which_linear
if reverse or unroll != 1:
raise NotImplementedError
_i_constant = _i64_constant if config.enable_x64.value else _i32_constant
lower_bound = _i_constant(0)
upper_bound = _i_constant(nsteps)
step = _i_constant(1)
init_args = map(_ensure_ir_value, args, ctx.avals_in)
# Partially discharge state from jaxpr for non-pointers
should_discharge = [
not isinstance(a, state.AbstractRef) for a in ctx.avals_in
]
discharged_jaxpr, () = discharge.discharge_state(
jaxpr, (), should_discharge=[True, *should_discharge]
)
in_avals = [v.aval for v in jaxpr.invars]
state_effects = state.get_ref_state_effects(in_avals, jaxpr.effects)[1:]
# Read-only `Ref`s don't need to be passed in explicitly as loop arguments so
# we can filter them out.
read_only = map(_is_read_only, state_effects)
is_loop_arg = map(
operator.and_, map(operator.not_, read_only), should_discharge
)
ptrs, _ = partition_list(should_discharge, init_args)
non_loop_args, loop_args = partition_list(is_loop_arg, init_args)
for_op = scf_dialect.ForOp(lower_bound, upper_bound, step, loop_args)
with ir.InsertionPoint(for_op.body):
loop_index = for_op.induction_variable
for_body_args = [
for_op.body.arguments[i + 1] for i, _ in enumerate(loop_args)
]
loop_body_args = merge_lists(is_loop_arg, non_loop_args, for_body_args)
out_discharged = lower_jaxpr_to_triton_ir(
ctx.context,
discharged_jaxpr,
[None, *ctx.block_infos],
loop_index,
*loop_body_args,
)
all_out = merge_lists(should_discharge, ptrs, out_discharged)
_, loop_out = partition_list(is_loop_arg, all_out)
scf_dialect.yield_(loop_out)
return merge_lists(is_loop_arg, non_loop_args, list(for_op.results_))
def _lower_jaxpr_to_for_loop(
ctx: LoweringRuleContext,
jaxpr: jax_core.Jaxpr,
lower_bound,
upper_bound,
consts,
*args,
has_loop_index: bool,
step: int = 1,
bound_type: ir.IntegerType | None = None,
):
if step != 1:
raise NotImplementedError
if bound_type is None or bound_type.width == 32:
step = _i32_constant(step)
else:
step = _i64_constant(step)
for_op = scf_dialect.ForOp(lower_bound, upper_bound, step, args)
with ir.InsertionPoint.at_block_begin(for_op.body):
loop_index = for_op.induction_variable
for_body_args = [for_op.body.arguments[i + 1] for i, _ in enumerate(args)]
if has_loop_index:
jaxpr_args = [*consts, loop_index, *for_body_args]
else:
jaxpr_args = [*consts, *for_body_args]
all_out = lower_jaxpr_to_triton_ir(
ctx.context, jaxpr, ctx.block_infos, *jaxpr_args
)
scf_dialect.yield_(all_out)
return list(for_op.results_)
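
# Only scans that match the `fori_loop` pattern are supported: no extensive
# inputs or outputs, no reverse, no unrolling. A sketch of the kind of loop
# this rule handles inside a Pallas kernel (illustrative only):
#
#   def body(i, carry):
#     return carry + x_ref[i]
#   acc = jax.lax.fori_loop(0, 8, body, 0.0)
#
# which, with static bounds, JAX typically traces to a `scan` that carries the
# loop index.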
@register_lowering(lax.scan_p)
def _scan_lowering_rule(
ctx: LoweringRuleContext,
*args,
jaxpr,
linear,
length,
reverse,
unroll,
num_consts,
num_carry,
_split_transpose,
):
del _split_transpose
# Only implements fori_loop-like scans
num_extensive = len(args) - num_consts - num_carry
if num_extensive: raise NotImplementedError
if reverse: raise NotImplementedError
if unroll != 1: raise NotImplementedError
del linear, num_extensive, unroll, reverse
jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts
if jaxpr_consts: raise NotImplementedError
del jaxpr_consts
jaxpr, has_loop_index = (
pallas_utils.pattern_match_scan_to_fori_loop(jaxpr, num_consts, num_carry)
)
args = map(_ensure_ir_value, args, ctx.avals_in)
consts, args = util.split_list(args, [num_consts])
if has_loop_index:
lower_bound, *args = args
upper_bound = _add(lower_bound, _ir_constant(length, lower_bound.type))
bound_type = lower_bound.type
else:
lower_bound = _i32_constant(0)
upper_bound = _i32_constant(length)
bound_type = ir.IntegerType.get_signless(32)
for_out = _lower_jaxpr_to_for_loop(
ctx, jaxpr, lower_bound, upper_bound, consts, *args,
has_loop_index=has_loop_index, step=1, bound_type=bound_type)
if has_loop_index:
# Need to return the final loop index value if the outer scan expects
# it as an output
return [upper_bound, *for_out]
return for_out
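
# `lax.while_loop`s whose cond is a single `i < ub` comparison on the first
# two carry values and whose body increments `i` by exactly 1 (leaving `ub`
# untouched) are rewritten into an `scf.for`. Anything else returns None and
# falls through to the generic `scf.while` lowering below.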
def _maybe_pattern_match_fori_loop(
ctx: LoweringRuleContext,
*args,
cond_nconsts,
cond_jaxpr,
body_nconsts,
body_jaxpr,
):
if cond_nconsts:
return None
_, cond_invars = split_list(cond_jaxpr.jaxpr.invars, [cond_nconsts])
cond_in_avals = [v.aval for v in cond_invars]
if len(cond_in_avals) < 2:
return None
# Check that the first two carry values are scalar ints
a1, a2 = cond_in_avals[:2]
if a1.shape != () or a1.dtype not in (jnp.int32, jnp.int64):
return None
if a2.shape != () or a2.dtype not in (jnp.int32, jnp.int64):
return None
# Check that the only eqn in the cond checks the loop index condition
v1, v2 = cond_invars[:2]
outvar = cond_jaxpr.jaxpr.outvars[0]
assert outvar.aval.dtype == jnp.bool_
if len(cond_jaxpr.jaxpr.eqns) != 1:
return None
eqn = cond_jaxpr.jaxpr.eqns[0]
if eqn.primitive != lax.lt_p:
return None
if eqn.outvars != [outvar]:
return None
if eqn.invars != [v1, v2]:
return None
# Check that the carry is updated in the body appropriately
_, body_invars = split_list(body_jaxpr.jaxpr.invars, [body_nconsts])
v1, v2 = body_invars[:2]
vo1, vo2 = body_jaxpr.jaxpr.outvars[:2]
  # The upper bound must be passed through the body unchanged (loop-invariant),
  # otherwise this is not a fori_loop pattern.
  if v2 is not vo2:
    return None
# Check that we increment the loop index in the body
for i, eqn in enumerate(body_jaxpr.jaxpr.eqns):
if eqn.primitive is lax.add_p:
if eqn.invars[0] is v1:
if isinstance(eqn.invars[1], jax_core.Literal):
if eqn.invars[1].val == 1:
if eqn.outvars[0] == vo1:
eqn_index = i
break
else:
return None
jaxpr = body_jaxpr.jaxpr
new_invars = (*jaxpr.invars[:body_nconsts],
jaxpr.invars[body_nconsts],
*jaxpr.invars[body_nconsts + 2:])
new_outvars = tuple(jaxpr.outvars[2:])
jaxpr = jaxpr.replace(
eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1:],
invars=new_invars,
outvars=new_outvars)
_, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
(lb, ub), args = carry[:2], carry[2:]
const_block_infos, args_block_infos = split_list(ctx.block_infos,
[body_nconsts])
ctx = ctx.replace(block_infos=[*const_block_infos, None,
*args_block_infos[2:]])
for_out = _lower_jaxpr_to_for_loop(
ctx,
jaxpr,
lb,
ub,
body_consts,
*args,
has_loop_index=True,
step=1,
bound_type=lb.type,
)
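  # On exit the loop index equals `ub`, so both of the first two carry values
  # (the loop index and the upper bound) are returned as `ub`.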
return [ub, ub, *for_out]
@register_lowering(lax.while_p)
def _while_lowering_rule(
ctx: LoweringRuleContext,
*args,
cond_nconsts,
cond_jaxpr,
body_nconsts,
body_jaxpr,
):
args = map(_ensure_ir_value, args, ctx.avals_in)
# First, try to pattern match to fori_loop and lower to scf.for if possible
result = _maybe_pattern_match_fori_loop(ctx, *args, cond_nconsts=cond_nconsts,
body_nconsts=body_nconsts, cond_jaxpr=cond_jaxpr,
body_jaxpr=body_jaxpr)
if result is not None:
return result
# Fall back to default while lowering
cond_consts, body_consts, carry = util.split_list(
args, [cond_nconsts, body_nconsts]
)
cond_const_block_infos, body_const_block_infos, carry_block_infos = (
util.split_list(ctx.block_infos, [cond_nconsts, body_nconsts])
)
cond_const_types = [a.type for a in cond_consts]
body_const_types = [a.type for a in body_consts]
carry_types = [a.type for a in carry]
all_types = [*cond_const_types, *body_const_types, *carry_types]
while_op = scf_dialect.WhileOp(all_types, args)
before_block = while_op.before.blocks.append(*all_types)
cond_consts_, _, carry_ = util.split_list(
before_block.arguments,
[cond_nconsts, body_nconsts],
)
cond_args = [*cond_consts_, *carry_]
with ir.InsertionPoint.at_block_begin(before_block):
[cond] = lower_jaxpr_to_triton_ir(
ctx.context,
cond_jaxpr.jaxpr,
[*cond_const_block_infos, *carry_block_infos],
*cond_args,
)
scf_dialect.condition(cond, before_block.arguments)
after_block = while_op.after.blocks.append(*all_types)
cond_consts_, body_consts_, carry_ = util.split_list(
after_block.arguments,
[cond_nconsts, body_nconsts],
)
all_args = [*cond_consts_, *body_consts_, *carry_]
cond_const_args, body_const_args, carry_args = util.split_list(
all_args, [cond_nconsts, body_nconsts]
)
with ir.InsertionPoint.at_block_begin(after_block):
loop_out = lower_jaxpr_to_triton_ir(
ctx.context,
body_jaxpr.jaxpr,
[*body_const_block_infos, *carry_block_infos],
*body_const_args,
*carry_args
)
all_handles = [*cond_const_args, *body_const_args, *loop_out]
if all_handles:
scf_dialect.yield_(all_handles)
all_out = list(while_op.results_)
return all_out[cond_nconsts + body_nconsts :]
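
# `lax.cond` (and `lax.switch`) is lowered to a chain of `scf.if` ops: index 0
# takes the `then` region, and the remaining branches are handled by a
# recursive call on `index - 1` inside the `else` region.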
@register_lowering(lax.cond_p)
def _cond_lowering_rule(
ctx: LoweringRuleContext,
index,
*args, # *consts, *ops
branches, # tuple(jaxprs)
):
block_infos = ctx.block_infos
def to_type(out_aval):
element_type = _dtype_to_ir_type(out_aval.dtype)
if not out_aval.shape:
return element_type
return ir.RankedTensorType.get(out_aval.shape, element_type)
out_types = [to_type(out) for out in ctx.avals_out]
use_branch0 = _equal(index, _ir_constant(0, index.type), signed=False)
# TODO(bjp): Switch to scf.index_switch once exposed in triton.cc
if_op = scf_dialect.IfOp(use_branch0, out_types, hasElse=True)
with ir.InsertionPoint.at_block_begin(if_op.then_block):
outs0 = lower_jaxpr_to_triton_ir(
ctx.context,
branches[0].jaxpr,
block_infos[1:],
*args)
scf_dialect.yield_(outs0)
with ir.InsertionPoint.at_block_begin(if_op.else_block):
# TODO(bjp): Instead of linear nest of 'if's, partition into halves.
if len(branches) > 2:
outs1 = _cond_lowering_rule(
ctx,
_sub(index, _ir_constant(1, index.type)),
*args,
branches=branches[1:],
)
else:
outs1 = lower_jaxpr_to_triton_ir(
ctx.context,
branches[1].jaxpr,
block_infos[1:],
*args)
scf_dialect.yield_(outs1)
return list(if_op.results_)
def _ensure_ir_value(x: object, aval: jax_core.ShapedArray) -> ir.Value:
if isinstance(x, ir.Value):
return x
elif isinstance(x, (np.number, np.ndarray, int, float)):
return _ir_constant(x, _dtype_to_ir_type(aval.dtype))
raise NotImplementedError
def _ir_constant(v: object, t: ir.Type) -> ir.Value:
if isinstance(v, (np.number, np.ndarray, int, float)):
if isinstance(t, ir.IntegerType):
v = int(v)
else:
assert isinstance(t, ir.FloatType)
v = float(v)
return arith_dialect.constant(t, v)
raise NotImplementedError
def _i32_constant(v: int) -> ir.Value:
return arith_dialect.constant(ir.IntegerType.get_signless(32), v)
def _i64_constant(v: int) -> ir.Value:
return arith_dialect.constant(ir.IntegerType.get_signless(64), v)
def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type:
if jnp.issubdtype(dtype, np.integer):
# All integer types in Triton are signless.
return ir.IntegerType.get_signless(dtype.itemsize * 8)
return mlir.dtype_to_ir_type(dtype)