mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
1707 lines
53 KiB
Python
1707 lines
53 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Module for lowering JAX primitives to Triton IR."""
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import functools
|
|
import operator
|
|
from typing import Any, Dict, Sequence, Tuple
|
|
import zlib
|
|
|
|
import jax
|
|
from jax import lax
|
|
from jax import tree_util
|
|
from jax._src import ad_checkpoint
|
|
from jax._src import ad_util
|
|
from jax._src import api_util
|
|
from jax._src import core as jax_core
|
|
from jax._src import linear_util as lu
|
|
from jax._src import pjit
|
|
from jax._src import state
|
|
from jax._src import util
|
|
from jax._src.lax.control_flow import for_loop
|
|
from jax._src.lib import gpu_triton as triton_kernel_call_lib
|
|
from jax._src.lib import hlo_helpers
|
|
from jax._src.lib import version
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.pallas import core as pallas_core
|
|
from jax._src.pallas import indexing
|
|
from jax._src.pallas import primitives
|
|
from jax._src.pallas import utils as pallas_utils
|
|
from jax._src.pallas.pallas_call import pallas_call_p
|
|
from jax._src.state import AbstractRef
|
|
from jax._src.state import discharge
|
|
from jax._src.state import primitives as sp
|
|
from jax._src.util import merge_lists
|
|
from jax._src.util import partition_list
|
|
from jax._src.util import split_list
|
|
from jax._src.util import weakref_lru_cache
|
|
from jax.interpreters import mlir
|
|
from jax.interpreters import partial_eval as pe
|
|
from jax.lib import xla_client as xc
|
|
import jax.numpy as jnp
|
|
from jax_triton import triton_lib
|
|
from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
|
|
from jax_triton.triton_lib import get_triton_type
|
|
import numpy as np
|
|
from triton._C.libtriton.triton import ir as tl_ir
|
|
from triton.compiler import code_generator as code_gen
|
|
import triton.language as tl
|
|
|
|
# TODO(sharadmv): enable type checking
|
|
# mypy: ignore-errors
|
|
|
|
map, unsafe_map = util.safe_map, map
|
|
zip, unsafe_zip = util.safe_zip, zip
|
|
partial = functools.partial
|
|
Grid = Tuple[int, ...]
|
|
NDIndexer = indexing.NDIndexer
|
|
GridMapping = pallas_core.GridMapping
|
|
BlockMapping = pallas_core.BlockMapping
|
|
|
|
|
|
# # General lowering logic
|
|
@dataclasses.dataclass
class TritonModuleContext:
  """Module-wide state shared by every lowering rule of one Triton kernel."""
  name: str
  ir_context: tl_ir.context
  builder: tl_ir.builder
  module: tl_ir.module
  grid_mapping: GridMapping
  # Program-id tensors for the grid dimensions visible to the kernel body
  # (mapped dims are filtered out in lower_jaxpr_to_triton_module).
  program_ids: Sequence[tl.tensor]
|
|
|
|
|
|
@dataclasses.dataclass
class BlockInfo:
  """Describes one input Ref's full array and the block carved out of it."""
  # Shape/dtype of the full (unblocked) array backing the Ref.
  full_shape_dtype: jax.ShapeDtypeStruct
  # Per-dimension start offsets of this block (Triton index tensors).
  start_indices: Sequence[Any]
  # Block extent per dimension; entries may be pallas_core.mapped.
  block_shape: Tuple[int, ...]
|
|
|
|
|
|
@dataclasses.dataclass
class TritonLoweringRuleContext:
  """Per-equation context handed to each primitive lowering rule."""
  context: TritonModuleContext
  avals_in: Any
  avals_out: Any
  block_infos: Sequence[BlockInfo | None]

  def __post_init__(self):
    # Convenience alias so rules can write `ctx.builder` directly.
    self.builder = self.context.builder

  # Expose dataclasses.replace as `ctx.replace(...)` for making tweaked copies.
  replace = dataclasses.replace
|
|
|
|
|
|
@dataclasses.dataclass
class TritonLoweringResult:
  """Keeps pybind11 objects alive."""

  ir_context: tl_ir.context
  module: tl_ir.module
  builder: tl_ir.builder
  # The (possibly collapsed-to-3D) launch grid; see _process_grid_to_3d_grid.
  grid: Tuple[int, ...]
|
|
|
|
|
|
@dataclasses.dataclass
class TritonCompilationResult:
  """Artifacts of compiling a lowered Triton module down to PTX."""
  kernel_name: str
  ttir: str
  ptx: str
  shared_mem_bytes: int
  compute_capability: int
  # Kept so the underlying pybind11 IR objects stay alive.
  lowering_result: TritonLoweringResult
|
|
|
|
|
|
class TritonLoweringException(Exception):
  """Raised when lowering a jaxpr equation to Triton IR fails; the innermost
  failure is wrapped with the offending eqn and rule context for debugging."""
  pass
|
|
|
|
|
|
def _eval_index_map(
    ctx: TritonModuleContext, idx, block_mapping: BlockMapping | None
):
  """Lower a block mapping's index_map jaxpr and scale indices to offsets.

  Returns None when there is no block mapping. Mapped dimensions keep their
  raw index; every other dimension's index is multiplied by its block size
  to produce an element offset into the full array.
  """
  if block_mapping is None:
    return None
  block_indices = tuple(
      lower_jaxpr_to_triton_ir(
          ctx, block_mapping.index_map_jaxpr.jaxpr, None, *idx
      )
  )
  scaled_indices = []
  for index, dim_block_size in zip(block_indices, block_mapping.block_shape):
    if dim_block_size is pallas_core.mapped:
      scaled_indices.append(index)
    else:
      scaled_indices.append(index.__mul__(dim_block_size, _builder=ctx.builder))
  return tuple(scaled_indices)
|
|
|
|
|
|
triton_lowering_rules = {}
|
|
|
|
|
|
def _process_grid_to_3d_grid(builder, grid_mapping: GridMapping):
  """Collapse an nD Pallas grid into the <= 3 axes Triton supports.

  Grids of rank <= 3 pass through unchanged. For higher ranks, all leading
  dimensions (everything but the last two) are folded into axis 0 of size
  prod(leading); the per-dimension indices are then recovered from
  program_id(0) with floordiv/mod arithmetic.

  Returns (new_grid, out_indices) with one index tensor per original grid
  dimension.
  """
  if len(grid_mapping.grid) <= 3:
    program_ids = [
        tl.program_id(axis=i, _builder=builder)
        for i in range(len(grid_mapping.grid))
    ]
    return grid_mapping.grid, program_ids
  grid_prefix = grid_mapping.grid[:-2]
  grid_suffix = grid_mapping.grid[-2:]
  total_axis_size = np.prod(grid_prefix)
  new_grid = (total_axis_size, *grid_suffix)
  out_indices = [0] * len(grid_prefix)
  grid0 = tl.program_id(0, _builder=builder)
  # Un-flatten program_id(0): peel dimensions off right-to-left
  # (fastest-varying last), like converting to a mixed-radix representation.
  for i, s in reversed(list(enumerate(grid_prefix))):
    grid0, out_indices[i] = (
        grid0.__floordiv__(s, _builder=builder),
        grid0.__mod__(s, _builder=builder),
    )
  out_indices = [
      *out_indices,
      tl.program_id(1, _builder=builder),
      tl.program_id(2, _builder=builder),
  ]
  assert len(out_indices) == len(grid_mapping.grid)
  return new_grid, out_indices
|
|
|
|
|
|
def lower_jaxpr_to_triton_module(
    jaxpr: jax_core.Jaxpr, in_shapes, grid_mapping: GridMapping, name: str
) -> TritonLoweringResult:
  """Lowers a Pallas kernel jaxpr into a Triton module.

  Builds the Triton function skeleton (one pointer argument per jaxpr invar),
  materializes program ids for the (possibly collapsed) launch grid, computes
  per-input block start indices from the grid mapping, and then lowers the
  jaxpr body into the function.

  Args:
    jaxpr: The kernel jaxpr; must have no outvars (outputs are via Refs).
    in_shapes: Shape/dtype structs for the full arrays backing each input.
    grid_mapping: Pallas grid/block mapping for the kernel.
    name: Name of the generated Triton function.

  Returns:
    A TritonLoweringResult holding the IR context, module, builder and grid.
    (NOTE: the original annotation said ``tl_ir.module``, which did not match
    the actual return value.)
  """
  jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
  ir_context = tl_ir.context()
  ir_context.load_triton()
  builder = tl_ir.builder(ir_context)
  # TODO(sharadmv): handle multiple devices, right now we assume device 0
  # which is fine when we have multiple of the same GPU but this won't work in
  # general.
  device = 0
  builder.arch = triton_kernel_call_lib.get_compute_capability(device)
  module = builder.create_module()
  in_avals = [var.aval for var in jaxpr.invars]
  triton_types = [get_triton_type(x) for x in in_avals]
  arg_types = [code_gen.str_to_ty(arg) for arg in triton_types]
  assert len(jaxpr.outvars) == 0
  prototype = tl.function_type([], arg_types)
  out = prototype.to_ir(builder)
  fn = builder.get_or_insert_function(module, name, out, "public", False)
  module.push_back(fn)
  entry = fn.add_entry_block()
  args = []
  for i in range(len(in_avals)):
    # All pointers are assumed 16-byte aligned (helps vectorized codegen).
    fn.set_arg_attr(i, "tt.divisibility", 16)
    ptr = tl.tensor(fn.args(i), prototype.param_types[i])
    args.append(ptr)
  builder.set_insertion_point_to_start(entry)
  new_grid, program_ids = _process_grid_to_3d_grid(builder, grid_mapping)
  # The kernel body only sees program ids for non-mapped grid dimensions.
  local_program_ids = [
      pid for i, pid in enumerate(program_ids) if i not in grid_mapping.mapped_dims
  ]
  ctx = TritonModuleContext(
      name, ir_context, builder, module, grid_mapping, local_program_ids
  )
  if grid_mapping.num_index_operands:
    raise NotImplementedError(
        "Scalar prefetch not supported in Triton lowering.")
  # Index maps use the *full* set of program ids (including mapped dims).
  start_indices = map(
      partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings
  )
  block_infos = [
      BlockInfo(
          jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
          start_idx,
          block_mapping.block_shape,
      )
      if block_mapping is not None
      else None
      for shape_dtype, block_mapping, start_idx in zip(
          in_shapes, grid_mapping.block_mappings, start_indices
      )
  ]
  # The kernel has no outputs; `()` asserts the lowered jaxpr returns nothing.
  () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
  module.context = ir_context
  ctx.builder.ret([])
  return TritonLoweringResult(ir_context, module, builder, new_grid)
|
|
|
|
|
|
def lower_jaxpr_to_triton_ir(
    ctx: TritonModuleContext,
    jaxpr: jax_core.Jaxpr,
    block_infos: Sequence[BlockInfo | None] | None,
    *args
) -> Sequence[Any]:
  """Interprets a jaxpr equation-by-equation, emitting Triton IR.

  ``args`` are the Triton values bound to ``jaxpr.invars``; ``block_infos``
  (if given) supplies per-invar BlockInfo used by load/store rules. Returns
  the Triton values for ``jaxpr.outvars``.
  """
  # Maps jaxpr Vars to their lowered Triton values / BlockInfos.
  env = {}
  block_info_env = {}

  def read_env(var: jax_core.Atom):
    # Literals are materialized on the fly as Triton constants.
    if type(var) is jax_core.Literal:
      t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder)
      dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty
      if t.type.scalar != dst_ty:
        # _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64
        # comes out of .tolist() as list[float], which then comes out of
        # _to_tensor as a block of f32.
        t = tl.semantic.cast(t, dst_ty, ctx.builder)
      return t
    return env[var]

  def read_block_info_env(var: jax_core.Atom):
    if type(var) is jax_core.Literal:
      return None
    return block_info_env.get(var, None)

  def write_env(var: jax_core.Var, val):
    env[var] = val

  if block_infos is None:
    block_infos = [None] * len(jaxpr.invars)
  for invar, block_info in zip(jaxpr.invars, block_infos):
    block_info_env[invar] = block_info
  map(write_env, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    if eqn.primitive not in triton_lowering_rules:
      raise NotImplementedError(
          "Unimplemented primitive in Pallas GPU lowering: "
          f"{eqn.primitive.name}. "
          "Please file an issue on https://github.com/google/jax/issues.")
    rule = triton_lowering_rules[eqn.primitive]
    avals_in = [v.aval for v in eqn.invars]
    avals_out = [v.aval for v in eqn.outvars]
    eqn_block_infos = map(read_block_info_env, eqn.invars)
    rule_ctx = TritonLoweringRuleContext(
        ctx, avals_in, avals_out, eqn_block_infos
    )
    try:
      outvals = rule(rule_ctx, *invals, **eqn.params)
    except TritonLoweringException:
      raise  # We only add the extra info to the innermost exception.
    except Exception as e:
      # Wrap with the failing eqn and context so errors are actionable.
      raise TritonLoweringException(
          f"Exception while lowering eqn:\n  {eqn}\n"
          f"With context:\n  {rule_ctx}\n"
          f"With inval shapes={map(lambda t: t.shape, invals)}\n"
          f"With inval types={map(lambda t: t.type, invals)}\n"
          f"In jaxpr:\n{jaxpr}") from e
    if eqn.primitive.multiple_results:
      map(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)
  return map(read_env, jaxpr.outvars)
|
|
|
|
|
|
# # Primitive lowering rules
|
|
# ## Programming model primitives
|
|
def _program_id_lowering_rule(ctx: TritonLoweringRuleContext, *, axis):
|
|
return ctx.context.program_ids[axis]
|
|
|
|
|
|
triton_lowering_rules[primitives.program_id_p] = _program_id_lowering_rule
|
|
# ## Atomic op primitives
|
|
# Maps Pallas atomic-op kinds to the corresponding Triton atomic intrinsics.
_ATOMIC_OP_MAPPING = {
    primitives.AtomicOpType.XCHG: tl.core.atomic_xchg,
    primitives.AtomicOpType.ADD: tl.core.atomic_add,
    primitives.AtomicOpType.MAX: tl.core.atomic_max,
    primitives.AtomicOpType.MIN: tl.core.atomic_min,
    primitives.AtomicOpType.AND: tl.core.atomic_and,
    primitives.AtomicOpType.OR: tl.core.atomic_or,
    primitives.AtomicOpType.XOR: tl.core.atomic_xor,
}


def _atomic_lowering_rule(
    ctx: TritonLoweringRuleContext,
    ptr,
    value,
    *args,
    args_tree,
    masked: bool,
    atomic_type: primitives.AtomicOpType
):
  """Lowers atomic_rmw: compute element pointers from the indexer, then emit
  the matching Triton atomic op (optionally masked)."""
  ref_block_info, *_ = ctx.block_infos
  # args_tree packs the NDIndexer plus an optional mask.
  idx, *mask_rest = tree_util.tree_unflatten(args_tree, args)
  avals_in = ctx.avals_in
  ptr = _compute_pointers_from_indices(
      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
  )
  mask = None
  if masked:
    assert len(mask_rest) == 1
    (mask,) = mask_rest
  if atomic_type not in _ATOMIC_OP_MAPPING:
    raise NotImplementedError(atomic_type)
  op = _ATOMIC_OP_MAPPING[atomic_type]
  return op(ptr, value, mask=mask, _builder=ctx.builder)


triton_lowering_rules[primitives.atomic_rmw_p] = _atomic_lowering_rule
|
|
|
|
|
|
def _atomic_cas_lowering_rule(ctx: TritonLoweringRuleContext, ptr, cmp, val):
  """Lowers atomic compare-and-swap directly to tl.atomic_cas."""
  return tl.atomic_cas(ptr, cmp, val, _builder=ctx.builder)


triton_lowering_rules[primitives.atomic_cas_p] = _atomic_cas_lowering_rule
|
|
|
|
|
|
def _max_contiguous_lowering_rule(ctx: TritonLoweringRuleContext, x, *, values):
  """Lowers the max_contiguous compiler hint (values become tl.constexpr)."""
  values = [tl.constexpr(v) for v in values]
  return tl.max_contiguous(x, values, _builder=ctx.builder)


triton_lowering_rules[primitives.max_contiguous_p] = (
    _max_contiguous_lowering_rule
)


def _multiple_of_lowering_rule(ctx: TritonLoweringRuleContext, x, *, values):
  """Lowers the multiple_of compiler hint (values become tl.constexpr)."""
  values = [tl.constexpr(v) for v in values]
  return tl.multiple_of(x, values, _builder=ctx.builder)


triton_lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule
|
|
|
|
|
|
def _abs_lowering_rule(ctx: TritonLoweringRuleContext, x):
  """Lowers lax.abs to tl.abs."""
  return tl.abs(x, _builder=ctx.builder)


triton_lowering_rules[lax.abs_p] = _abs_lowering_rule


def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
  # clamp(min, x, max) == min(max(x, min), max); both helper rules are
  # defined elsewhere in this module and resolved at call time.
  return _min_lowering_rule(ctx, max_lowering_rule(ctx, min, operand), max)


triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
|
|
|
|
|
|
def _exp_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.exp to tl.exp."""
  return tl.exp(a, _builder=ctx.builder)


triton_lowering_rules[lax.exp_p] = _exp_lowering_rule


def _log_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.log to tl.log."""
  return tl.log(a, _builder=ctx.builder)


triton_lowering_rules[lax.log_p] = _log_lowering_rule


def _log1p_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.log1p to tl.math.log1p."""
  return tl.math.log1p(a, _builder=ctx.builder)


triton_lowering_rules[lax.log1p_p] = _log1p_lowering_rule


def _logistic_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.logistic as sigmoid(a) = 1 / (1 + exp(-a))."""
  one_ = tl.core._to_tensor(1.0, ctx.builder)
  x = tl.exp(a.__neg__(_builder=ctx.builder), _builder=ctx.builder)
  x = x.__add__(one_, _builder=ctx.builder)
  x = one_.__truediv__(x, _builder=ctx.builder)
  return x


triton_lowering_rules[lax.logistic_p] = _logistic_lowering_rule


def _sin_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.sin to tl.sin."""
  return tl.sin(a, _builder=ctx.builder)


triton_lowering_rules[lax.sin_p] = _sin_lowering_rule


def _cos_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.cos to tl.cos."""
  return tl.cos(a, _builder=ctx.builder)


triton_lowering_rules[lax.cos_p] = _cos_lowering_rule


def _mul_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.mul to Triton's multiply."""
  return a.__mul__(b, _builder=ctx.builder)


triton_lowering_rules[lax.mul_p] = _mul_lowering_rule
|
|
|
|
|
|
def _div_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.div: true division for floats, floor division otherwise."""
  floating_dtypes = {tl.float16, tl.float32, tl.float64, tl.bfloat16}
  if a.dtype in floating_dtypes and b.dtype in floating_dtypes:
    return a.__truediv__(b, _builder=ctx.builder)
  return a.__floordiv__(b, _builder=ctx.builder)


triton_lowering_rules[lax.div_p] = _div_lowering_rule
|
|
|
|
|
|
def _iota_lowering_rule(
    ctx: TritonLoweringRuleContext, *, dtype, shape, dimension
):
  """Lowers lax.iota over dimension 0 to tl.arange.

  NOTE(review): `dtype` is not forwarded; the result takes whatever type
  tl.arange produces — confirm callers only rely on integer iota here.
  """
  if dimension != 0:
    raise NotImplementedError()
  return tl.arange(0, shape[0], _builder=ctx.builder)


triton_lowering_rules[lax.iota_p] = _iota_lowering_rule
|
|
|
|
|
|
def _add_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.add (and ad_util.add_any) to Triton's add."""
  return a.__add__(b, _builder=ctx.builder)


triton_lowering_rules[lax.add_p] = _add_lowering_rule
triton_lowering_rules[ad_util.add_any_p] = _add_lowering_rule
|
|
|
|
|
|
def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
  """Lowers lax.integer_pow, special-casing y in {2, 3, -2} as multiplies
  (and one divide) before falling back to tl.math.pow."""
  if y == 2:
    return a.__mul__(a, _builder=ctx.builder)
  if y == 3:
    return a.__mul__(a.__mul__(a, _builder=ctx.builder), _builder=ctx.builder)
  if y == -2:
    one_ = tl.core._to_tensor(1.0, ctx.builder)
    a_sq = a.__mul__(a, _builder=ctx.builder)
    return one_.__truediv__(a_sq, _builder=ctx.builder)
  return tl.math.pow(a, y, _builder=ctx.builder)


triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule


def _pow_lowering_rule(ctx: TritonLoweringRuleContext, a, y):
  """Lowers lax.pow to tl.math.pow."""
  return tl.math.pow(a, y, _builder=ctx.builder)


triton_lowering_rules[lax.pow_p] = _pow_lowering_rule
|
|
|
|
|
|
def _tanh_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.tanh to tl.math.tanh."""
  return tl.math.tanh(a, _builder=ctx.builder)


triton_lowering_rules[lax.tanh_p] = _tanh_lowering_rule


def _min_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.min as where(a < b, a, b)."""
  pred = a.__lt__(b, _builder=ctx.builder)
  return tl.semantic.where(pred, a, b, ctx.builder)


triton_lowering_rules[lax.min_p] = _min_lowering_rule
|
|
|
|
|
|
def _convert_element_type_lowering_rule(
    ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
):
  """Lowers lax.convert_element_type via tl.semantic.cast.

  No-op when the target dtype equals the input aval's dtype. Only float
  16/32/64, bfloat16 and int 32/64 are handled; any other target dtype
  raises ValueError.
  """
  if new_dtype == ctx.avals_in[0].dtype:
    return a
  if new_dtype == jnp.float32:
    new_dtype = tl.float32
  elif new_dtype == jnp.float64:
    new_dtype = tl.float64
  elif new_dtype == jnp.float16:
    new_dtype = tl.float16
  elif new_dtype == jnp.bfloat16:
    new_dtype = tl.bfloat16
  elif new_dtype == jnp.int32:
    new_dtype = tl.int32
  elif new_dtype == jnp.int64:
    new_dtype = tl.int64
  else:
    raise ValueError(f"Unhandled dtype: {new_dtype}")
  return tl.semantic.cast(a, new_dtype, ctx.builder)


triton_lowering_rules[lax.convert_element_type_p] = (
    _convert_element_type_lowering_rule
)
|
|
|
|
|
|
def max_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.max as where(a > b, a, b)."""
  pred = a.__gt__(b, _builder=ctx.builder)
  return tl.semantic.where(pred, a, b, ctx.builder)


triton_lowering_rules[lax.max_p] = max_lowering_rule


def ge_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.ge to Triton's >=."""
  return a.__ge__(b, _builder=ctx.builder)


triton_lowering_rules[lax.ge_p] = ge_lowering_rule


def eq_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.eq to Triton's ==."""
  return a.__eq__(b, _builder=ctx.builder)


triton_lowering_rules[jax.lax.eq_p] = eq_lowering_rule


def ne_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.ne to Triton's !=."""
  return a.__ne__(b, _builder=ctx.builder)


triton_lowering_rules[jax.lax.ne_p] = ne_lowering_rule

def bitwise_and_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.and to Triton's bitwise &."""
  return a.__and__(b, _builder=ctx.builder)


triton_lowering_rules[jax.lax.and_p] = bitwise_and_lowering_rule


def bitwise_or_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.or to Triton's bitwise |."""
  return a.__or__(b, _builder=ctx.builder)


triton_lowering_rules[jax.lax.or_p] = bitwise_or_lowering_rule


def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b):
  # select_n picks cases[pred]: pred==0 -> a, pred==1 -> b, hence the
  # (b, a) argument order for where (where returns b when pred is true).
  return tl.semantic.where(pred, b, a, ctx.builder)


triton_lowering_rules[lax.select_n_p] = select_n_lowering_rule
|
|
|
|
|
|
def _rem_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.rem to Triton's % operator."""
  return a.__mod__(b, _builder=ctx.builder)


triton_lowering_rules[lax.rem_p] = _rem_lowering_rule


def _sub_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.sub to Triton's subtract."""
  return a.__sub__(b, _builder=ctx.builder)


triton_lowering_rules[lax.sub_p] = _sub_lowering_rule


def _lt_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.lt to Triton's <."""
  return a.__lt__(b, _builder=ctx.builder)


triton_lowering_rules[lax.lt_p] = _lt_lowering_rule


def _gt_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers lax.gt to Triton's >."""
  return a.__gt__(b, _builder=ctx.builder)


triton_lowering_rules[lax.gt_p] = _gt_lowering_rule


def _sqrt_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.sqrt to tl.sqrt."""
  return tl.sqrt(a, _builder=ctx.builder)


triton_lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule


def _rsqrt_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.rsqrt to tl.math.rsqrt."""
  return tl.math.rsqrt(a, _builder=ctx.builder)


triton_lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule


def _neg_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers lax.neg to Triton's unary minus."""
  return a.__neg__(_builder=ctx.builder)


triton_lowering_rules[lax.neg_p] = _neg_lowering_rule
|
|
|
|
|
|
def _broadcast_in_dim_lowering_rule(
    ctx: TritonLoweringRuleContext, a, *, broadcast_dimensions, shape
):
  """Lowers lax.broadcast_in_dim: expand_dims each new axis, then broadcast.

  Scalars (non-block tensors) are broadcast directly without expand_dims.
  """
  shape = map(tl.constexpr, shape)
  if not a.type.is_block():
    return tl.broadcast_to(a, shape, _builder=ctx.builder)
  # Axes of the target shape not covered by broadcast_dimensions are new
  # singleton axes that must be inserted before broadcasting.
  expand_dims = [i for i in range(len(shape)) if i not in broadcast_dimensions]
  for dim in expand_dims:
    a = tl.semantic.expand_dims(a, dim, ctx.builder)
  return tl.broadcast_to(a, shape, _builder=ctx.builder)


triton_lowering_rules[jax.lax.broadcast_in_dim_p] = (
    _broadcast_in_dim_lowering_rule
)
|
|
|
|
|
|
def _broadcast_to_lowering_rule(
    ctx: TritonLoweringRuleContext, a, *, shape
):
  """Lowers the Pallas broadcast_to primitive directly to tl.broadcast_to."""
  shape = map(tl.constexpr, shape)
  return tl.broadcast_to(a, shape, _builder=ctx.builder)


triton_lowering_rules[indexing.broadcast_to_p] = (
    _broadcast_to_lowering_rule
)
|
|
|
|
|
|
def _squeeze_lowering_rule(ctx: TritonLoweringRuleContext, a, *, dimensions):
  """Lowers lax.squeeze by deferring to the reshape rule (the output aval
  already carries the squeezed shape, so `dimensions` is not needed)."""
  del dimensions
  return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None)


triton_lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule
|
|
|
|
|
|
def _reshape_lowering_rule(
    ctx: TritonLoweringRuleContext, a, *, new_sizes, dimensions
):
  """Lowers lax.reshape; the target shape is taken from ctx.avals_out.

  Singleton-dimension-only reshapes are rewritten as expand_dims /
  size-1 reductions (recursively), since those avoid a real data reshape;
  anything else falls through to tl.reshape.
  """
  del new_sizes, dimensions
  # Short-circuit to avoid unneeded reshape.
  dst_shp = ctx.avals_out[0].shape
  if tuple(s.value for s in a.shape) == dst_shp:
    return a
  # Scalars: broadcast to the destination shape, or pass through if the
  # destination is also scalar.
  if not a.type.is_block():
    if dst_shp:
      return tl.broadcast_to(a, [tl.constexpr(s) for s in dst_shp],
                             _builder=ctx.builder)
    return a
  # Expand-dims or reduce-sum to handle singleton dims.
  if ([s.value for s in a.shape if s.value != 1] ==
      [s for s in dst_shp if s != 1]):
    # Eliminate one difference and recurse.
    for i in range(max(len(a.shape), len(dst_shp))):
      if (i < len(a.shape) and i < len(dst_shp) and
          a.shape[i].value == dst_shp[i]):
        continue
      # Use expand_dims to add a singleton dim.
      if i < len(dst_shp) and dst_shp[i] == 1:
        return _reshape_lowering_rule(
            ctx, tl.semantic.expand_dims(a, i, builder=ctx.builder),
            new_sizes=None, dimensions=None)
      # Use a reduction to eliminate singleton dim.
      if a.shape[i].value == 1:
        # Build a context whose avals describe the pre/post-reduction shapes
        # so the reduction lowering sees consistent types.
        reduce_ctx = ctx.replace(
            avals_in=[ctx.avals_in[0].update(
                shape=tuple(d.value for d in a.shape))],
            avals_out=[ctx.avals_in[0].update(
                shape=tuple(d.value for di, d in enumerate(a.shape)
                            if di != i))])
        return _reshape_lowering_rule(
            ctx,
            _reduce_lowering(jnp.add, reduce_ctx, a, axes=(i,)),
            new_sizes=None, dimensions=None)

  shape = [tl.constexpr(s) for s in dst_shp]
  return tl.reshape(a, shape, _builder=ctx.builder)


triton_lowering_rules[jax.lax.reshape_p] = _reshape_lowering_rule
|
|
|
|
|
|
def _compute_pointers_from_indices(
    root_ptr: tl.core.tensor,
    block_info: BlockInfo | None,
    nd_indexer: NDIndexer,
    array_shape: Tuple[int, ...],
    builder: tl_ir.builder,
) -> tl.core.tensor:
  """Turns an NDIndexer into a block of element pointers off root_ptr.

  For each dimension, builds an offset tensor (from a Slice, a full slice,
  or an integer/array index), shifts it by the block's start index when a
  BlockInfo is present, scales by the row-major stride, broadcasts all
  per-dimension offsets to the indexer's result shape, and sums them onto
  root_ptr.
  """
  if block_info is None:
    full_shape = array_shape
    num_mapped_dims = 0
    block_shape = array_shape
  else:
    full_shape = block_info.full_shape_dtype.shape
    num_mapped_dims = sum(
        b is pallas_core.mapped for b in block_info.block_shape
    )
    block_shape = block_info.block_shape
  strides = pallas_utils.strides_from_shape(full_shape)
  indexer_shape = nd_indexer.get_indexer_shape()
  int_indexer_shape = nd_indexer.int_indexer_shape
  indices = nd_indexer.indices
  # Trailing dims of the result shape that come from slice indices (the
  # leading dims come from advanced integer indexing).
  other_shape = indexer_shape[len(int_indexer_shape) :]
  bcast_indices = []
  other_shape_idx = 0
  if block_info is None:
    start_index_offsets = [None] * len(indices)
  else:
    start_index_offsets = block_info.start_indices
  assert len(indices) + num_mapped_dims == len(full_shape)
  assert len(start_index_offsets) == len(full_shape)
  indexer_iter = iter(indices)
  for dim_stride, dim_block_size, start_offset in zip(
      strides, block_shape, start_index_offsets
  ):
    if dim_block_size is pallas_core.mapped:
      # Mapped dims contribute no extent; index 0 (plus start offset below).
      index = tl.core._to_tensor(0, builder)
    else:
      index = next(indexer_iter)
    if isinstance(index, primitives.Slice):
      # Handle slices with static and dynamic indices and static sizes
      if isinstance(index.start, int):
        ptr_dim_offset = tl.arange(
            index.start, index.start + index.size, _builder=builder
        )
      else:
        ptr_dim_offset = index.start.__add__(
            tl.arange(0, index.size, _builder=builder), _builder=builder
        )
      # We need to add broadcastable dimensions for the advanced int indexing
      # and for previous slices
      num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
      num_right_expand_dims = len(other_shape) - other_shape_idx - 1
      other_shape_idx += 1
    elif isinstance(index, slice):
      if index != slice(None):
        raise NotImplementedError("Only `slice(None)` allowed.")
      ptr_dim_offset = tl.arange(0, dim_block_size, _builder=builder)
      num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
      num_right_expand_dims = len(other_shape) - other_shape_idx - 1
      other_shape_idx += 1
    else:
      # indexer is either a *scalar* or an array of size `int_indexer_shape`
      ptr_dim_offset = index
      num_left_expand_dims = 0
      num_right_expand_dims = len(other_shape)
      if not ptr_dim_offset.type.is_block():
        num_left_expand_dims = max(len(indexer_shape) - 1, 0)
      else:
        num_right_expand_dims = len(other_shape)
    if not ptr_dim_offset.type.is_block() and indexer_shape:
      # Scalar offset: lift it to an all-1s block so it broadcasts below.
      ptr_dim_offset = tl.broadcast_to(
          ptr_dim_offset,
          [tl.constexpr(1)] * len(indexer_shape),
          _builder=builder,
      )
    else:
      for _ in range(num_left_expand_dims):
        ptr_dim_offset = tl.semantic.expand_dims(ptr_dim_offset, 0, builder)
      for _ in range(num_right_expand_dims):
        ndim = len(ptr_dim_offset.shape)
        ptr_dim_offset = tl.semantic.expand_dims(ptr_dim_offset, ndim, builder)
    if start_offset is not None:
      ptr_dim_offset = ptr_dim_offset.__add__(start_offset, _builder=builder)
    stride_size = tl.core._to_tensor(int(dim_stride), builder)
    bcast_indices.append(ptr_dim_offset.__mul__(stride_size, _builder=builder))
  block_shapes = [
      () if not index.type.is_block() else tuple(index.type.get_block_shapes())
      for index in bcast_indices
  ]
  # Broadcast every per-dimension offset to the common indexer shape.
  bcast_indices = [
      tl.core.broadcast_to(
          index, map(tl.constexpr, indexer_shape), _builder=builder
      )
      if indexer_shape != block_shape
      else index
      for index, block_shape in zip(bcast_indices, block_shapes)
  ]
  ptr = root_ptr
  for bcast_idx in bcast_indices:
    ptr = ptr.__add__(bcast_idx, _builder=builder)
  return ptr
|
|
|
|
|
|
def _pack_indices(non_slice_idx, indexed_dims):
  """Scatter non-slice indices back into their dimension positions.

  Dimensions flagged True in ``indexed_dims`` consume the next value from
  ``non_slice_idx`` in order; all other dimensions get ``slice(None)``.
  """
  pending = iter(non_slice_idx)
  packed = []
  for is_indexed in indexed_dims:
    packed.append(next(pending) if is_indexed else slice(None))
  return tuple(packed)
|
|
|
|
|
|
def _get_lowering_rule(
    ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims
):
  """Lowers state get_p: build an NDIndexer from the packed indices, compute
  element pointers, and emit an unmasked tl.load."""
  ref_block_info, *_ = ctx.block_infos
  idx = _pack_indices(non_slice_idx, indexed_dims)
  avals_in = ctx.avals_in
  idx_avals = _pack_indices(avals_in[1:], indexed_dims)
  # Non-pointer "refs" are scalars already materialized as values.
  if not isinstance(ptr.type, tl.pointer_type):
    assert len(avals_in) == 1
    return ptr
  if non_slice_idx:
    # All advanced (integer/array) indices must agree on one shape.
    (int_indexer_shape,) = {
        i.shape for i in idx_avals if not isinstance(i, slice)
    }
  else:
    int_indexer_shape = ()
  idx = tuple(
      primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
      for s, slc in zip(avals_in[0].shape, idx)
  )
  idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
  ptr = _compute_pointers_from_indices(
      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
  )
  return tl.load(ptr, _builder=ctx.builder)


triton_lowering_rules[sp.get_p] = _get_lowering_rule
|
|
|
|
|
|
def _masked_load_lowering_rule(
    ctx: TritonLoweringRuleContext,
    ptr,
    *args,
    args_tree,
    masked,
    eviction_policy,
    cache_modifier,
    is_volatile
):
  """Lowers the Pallas load primitive to tl.load with optional mask/other
  and the given cache/eviction/volatility options."""
  ref_block_info, *_ = ctx.block_infos
  # args_tree packs the NDIndexer plus optional (mask,) or (mask, other).
  idx, *mask_other = tree_util.tree_unflatten(args_tree, args)
  avals_in = ctx.avals_in
  # Non-pointer "refs" are scalars already materialized as values.
  if not isinstance(ptr.type, tl.pointer_type):
    assert len(avals_in) == 1
    return ptr
  ptr = _compute_pointers_from_indices(
      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
  )
  mask, other = None, None
  if masked:
    assert 0 < len(mask_other) <= 2
    if len(mask_other) == 2:
      mask, other = mask_other
    elif len(mask_other) == 1:
      (mask,) = mask_other
  return tl.load(
      ptr,
      mask=mask,
      other=other,
      cache_modifier=cache_modifier,
      volatile=is_volatile,
      eviction_policy=eviction_policy,
      _builder=ctx.builder,
  )


triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule
|
|
|
|
|
|
def _swap_lowering_rule(
    ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
):
  """Lowers state swap_p: load the old value at the indexed location, store
  the new value there, and return the old value. No mask is applied."""
  ref_block_info, *_ = ctx.block_infos
  avals_in = ctx.avals_in
  idx = _pack_indices(non_slice_idx, indexed_dims)
  idx_avals = _pack_indices(avals_in[2:], indexed_dims)
  if non_slice_idx:
    # All advanced (integer/array) indices must agree on one shape.
    (int_indexer_shape,) = {
        i.shape for i in idx_avals if not isinstance(i, slice)
    }
  else:
    int_indexer_shape = ()
  idx = tuple(
      primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
      for s, slc in zip(avals_in[0].shape, idx)
  )
  idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
  ptr = _compute_pointers_from_indices(
      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
  )
  mask = None
  # NOTE: load-then-store; not atomic with respect to other programs.
  old_value = tl.load(ptr, mask=mask, _builder=ctx.builder)
  tl.store(ptr, value, mask=mask, _builder=ctx.builder)
  return old_value


triton_lowering_rules[sp.swap_p] = _swap_lowering_rule
|
|
|
|
|
|
def _masked_swap_lowering_rule(
    ctx: TritonLoweringRuleContext,
    ptr,
    value,
    *args,
    args_tree,
    masked,
    eviction_policy
):
  """Lowers the Pallas swap primitive as a (possibly masked) tl.store.

  NOTE(review): unlike _swap_lowering_rule this returns the result of
  tl.store rather than the previous value — confirm callers do not rely on
  the old value here.
  """
  # Element type must match between the ref and the stored value.
  ptr_type = (
      ptr.type.element_ty.element_ty
      if ptr.type.is_block()
      else ptr.type.element_ty
  )
  value_type = value.type.element_ty if value.type.is_block() else value.type
  assert ptr_type == value_type, (ptr_type, value_type)
  ref_block_info, *_ = ctx.block_infos
  # args_tree packs the NDIndexer plus an optional mask.
  idx, *mask_other = tree_util.tree_unflatten(args_tree, args)
  avals_in = ctx.avals_in
  idx_avals, *_ = tree_util.tree_unflatten(args_tree, avals_in[2:])
  ptr = _compute_pointers_from_indices(
      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
  )
  mask = None
  if masked:
    assert len(mask_other) == 1
    (mask,) = mask_other
  return tl.store(ptr, value, mask=mask, _builder=ctx.builder)


triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
|
|
|
|
|
|
def _addupdate_lowering_rule(
    ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
):
  """Lowers state addupdate_p (ref[idx] += value) as a Triton atomic add.

  Returns [] since addupdate has no results.
  """
  ref_block_info, *_ = ctx.block_infos
  avals_in = ctx.avals_in
  idx = _pack_indices(non_slice_idx, indexed_dims)
  if non_slice_idx:
    # All advanced (integer/array) indices must agree on one shape; here the
    # shape is read off the Triton tensors' constexpr dims.
    (int_indexer_shape,) = {
        tuple(map(lambda x: x.value, i.shape)) for i in non_slice_idx
    }
  else:
    int_indexer_shape = ()
  idx = tuple(
      primitives.Slice.from_slice(slc, s) if isinstance(slc, slice) else slc
      for s, slc in zip(avals_in[0].shape, idx)
  )
  idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
  ptr = _compute_pointers_from_indices(
      ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
  )
  tl.atomic_add(ptr, value, _builder=ctx.builder)
  return []


triton_lowering_rules[sp.addupdate_p] = _addupdate_lowering_rule
|
|
|
|
|
|
def _transpose_lowering(ctx: TritonLoweringRuleContext, a, *, permutation):
  """Lowers lax.transpose; only the 2D swap permutation (1, 0) is supported
  and maps to tl.trans."""
  if permutation != (1, 0):
    raise NotImplementedError(permutation)
  return tl.trans(a, _builder=ctx.builder)


triton_lowering_rules[lax.transpose_p] = _transpose_lowering
|
|
|
|
|
|
def _dot_general_lowering(
    ctx: TritonLoweringRuleContext,
    a,
    b,
    *,
    dimension_numbers,
    precision,
    preferred_element_type
):
  """Lowers lax.dot_general to tl.dot.

  Only plain 2D matmuls are supported: no batch dimensions and exactly one
  contracting dimension per operand.
  """
  (a_contract_dims, b_contract_dims), batch_dims = dimension_numbers
  assert batch_dims == ((), ())
  (a_contract_dim,) = a_contract_dims
  (b_contract_dim,) = b_contract_dims
  # tl.dot contracts a's last axis against b's first; transpose to match.
  if a_contract_dim == 0:
    a = tl.trans(a, _builder=ctx.builder)
  if b_contract_dim == 1:
    b = tl.trans(b, _builder=ctx.builder)
  # TF32 is allowed unless the user asked for HIGHEST precision.
  allow_tf32 = precision in (lax.Precision.HIGH, lax.Precision.DEFAULT)
  return tl.dot(a, b, _builder=ctx.builder, allow_tf32=allow_tf32)


triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering
|
|
|
|
|
|
def _reduction_lowering(body, ctx: TritonLoweringRuleContext, args, axes):
  """Lowers a (possibly multi-operand) single-axis reduction to a tt.reduce op.

  `body` is the binary combine function (e.g. jnp.add). It is traced to a
  jaxpr and lowered into the reduce op's combine region. Returns one wrapped
  result tensor per reduced operand.
  """
  flat_args = tree_util.tree_leaves(args)
  (axis,) = axes
  # Avals for the combine function's operands: the reduced axis is mapped away.
  mapped_avals = [
      jax_core.mapped_aval(aval.shape[axis], axis, aval)
      for aval in ctx.avals_in
  ]
  in_tree = tree_util.tree_structure((args, args))
  flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(body), in_tree
  )
  # Trace the combine function; it takes two copies of the mapped operands.
  combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      flat_fun, [*mapped_avals, *mapped_avals]
  )
  out_tree = out_tree_thunk()
  del out_tree  # Not needed
  if consts:
    raise NotImplementedError("Reductions with constants not supported.")
  element_types = [arg.type.scalar for arg in flat_args]
  builder = ctx.builder
  reduce_op = builder.create_reduce([t.handle for t in flat_args], axis)
  region = reduce_op.get_region(0)
  # Save the insertion point so we can resume emitting into the enclosing
  # block once the combine region is filled in.
  old_ip = builder.get_insertion_point()
  param_types = element_types * 2
  ir_param_types = [ty.to_ir(builder) for ty in param_types]
  block = builder.create_block_with_parent(region, ir_param_types)
  combine_args = [
      tl.core.tensor(block.arg(i), ty) for i, ty in enumerate(param_types)
  ]
  results = lower_jaxpr_to_triton_ir(ctx, combine_jaxpr, None, *combine_args)
  handles = [r.handle for r in results]
  builder.create_reduce_ret(*handles)
  builder.restore_insertion_point(old_ip)
  reduce_op.verify()

  def wrap_tensor(x, scalar_ty):
    # Re-wrap a raw reduce result handle as a tl tensor of the output type.
    if ctx.avals_out[0].shape:
      res_ty = tl.block_type(scalar_ty, ctx.avals_out[0].shape)
    else:
      # 0d-tensor -> scalar
      res_ty = scalar_ty
    return tl.core.tensor(x, res_ty)

  results = [
      wrap_tensor(reduce_op.get_result(i), ty)
      for i, ty in enumerate(element_types)
  ]
  return results
|
|
|
|
|
|
def _reduce_lowering(body, ctx: TritonLoweringRuleContext, a, *, axes):
  """Lowers a reduction over one or more axes.

  Multi-axis reductions are performed one axis at a time (highest axis
  first); the final remaining axis is handled by `_reduction_lowering`.
  """
  assert isinstance(axes, tuple)
  if not axes:
    # Reducing over no axes is the identity.
    return a
  while len(axes) > 1:
    axis = max(axes)
    # Output avals after dropping the reduced axis.
    dst_avals = tuple(v.update(shape=v.shape[:axis] + v.shape[axis + 1:])
                      for v in ctx.avals_in)
    a = _reduce_lowering(
        body, ctx.replace(avals_out=dst_avals), a, axes=(axis,))
    # Adding an intervening -(-reduce(.)) introduces a convert_layout between
    # reduces, which seems necessary for correctness.
    # TODO(bjp): Get rid of the double negation.
    # https://github.com/openai/triton/issues/1776
    a = a.__neg__(_builder=ctx.builder).__neg__(_builder=ctx.builder)
    ctx = ctx.replace(avals_in=dst_avals)
    axes = tuple(ax for ax in axes if ax != axis)
  return _reduction_lowering(body, ctx, a, axes=axes)[0]


triton_lowering_rules[lax.reduce_max_p] = functools.partial(
    _reduce_lowering, jnp.maximum
)
triton_lowering_rules[lax.reduce_min_p] = functools.partial(
    _reduce_lowering, jnp.minimum
)
triton_lowering_rules[lax.reduce_sum_p] = functools.partial(
    _reduce_lowering, jnp.add
)
|
|
|
|
|
|
def _argreduce_lowering(
|
|
body, ctx: TritonLoweringRuleContext, a, *, axes, index_dtype
|
|
):
|
|
if index_dtype != jnp.int32:
|
|
raise ValueError("`index_type` must be f32.")
|
|
if len(axes) != 1:
|
|
raise ValueError("`pallas` reduce operations only support one reduce axis.")
|
|
(axis,) = axes
|
|
n = ctx.avals_in[0].shape[axis]
|
|
index = tl.arange(0, n, _builder=ctx.builder)
|
|
if len(a.shape) > 1:
|
|
# Broadcast index across the non-reduced axes
|
|
expand_dims_index = [tl.constexpr(None)] * len(a.shape)
|
|
expand_dims_index[axis] = slice(None)
|
|
index = index.__getitem__(expand_dims_index, _builder=ctx.builder)
|
|
index = tl.core.broadcast_to(index, a.shape, _builder=ctx.builder)
|
|
ctx = ctx.replace(
|
|
avals_in=[
|
|
ctx.avals_in[0],
|
|
ctx.avals_in[0].update(dtype=jnp.dtype("int32")),
|
|
]
|
|
)
|
|
_, indices = _reduction_lowering(body, ctx, (a, index), axes=axes)
|
|
return indices
|
|
|
|
|
|
def _reduce_argmax_combine(left, right):
|
|
value1, index1 = left
|
|
value2, index2 = right
|
|
gt = value1 > value2
|
|
lt = value1 < value2
|
|
index_min = jnp.minimum(index1, index2)
|
|
index_ret = jnp.where(gt, index1, jnp.where(lt, index2, index_min))
|
|
value_ret = jnp.maximum(value1, value2)
|
|
return value_ret, index_ret
|
|
|
|
|
|
# Argmax lowers through the generic argreduce with a max-style combine body.
triton_lowering_rules[lax.argmax_p] = functools.partial(
    _argreduce_lowering, _reduce_argmax_combine
)
|
|
|
|
|
|
def _reduce_argmin_combine(left, right):
|
|
value1, index1 = left
|
|
value2, index2 = right
|
|
gt = value1 > value2
|
|
lt = value1 < value2
|
|
index_min = jnp.minimum(index1, index2)
|
|
index_ret = jnp.where(lt, index1, jnp.where(gt, index2, index_min))
|
|
value_ret = jnp.minimum(value1, value2)
|
|
return value_ret, index_ret
|
|
|
|
|
|
# Argmin lowers through the generic argreduce with a min-style combine body.
triton_lowering_rules[lax.argmin_p] = functools.partial(
    _argreduce_lowering, _reduce_argmin_combine
)
|
|
|
|
|
|
def _pjit_lowering_rule(ctx: TritonLoweringRuleContext, *args, jaxpr, **_):
|
|
if jaxpr.consts:
|
|
raise NotImplementedError
|
|
return lower_jaxpr_to_triton_ir(
|
|
ctx.context, jaxpr.jaxpr, ctx.block_infos, *args
|
|
)
|
|
|
|
|
|
triton_lowering_rules[pjit.pjit_p] = _pjit_lowering_rule
|
|
|
|
|
|
def _closed_call_lowering_rule(
|
|
ctx: TritonLoweringRuleContext, *args, call_jaxpr, **_
|
|
):
|
|
jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
|
|
if consts:
|
|
raise NotImplementedError
|
|
return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args)
|
|
|
|
|
|
triton_lowering_rules[jax_core.closed_call_p] = _closed_call_lowering_rule
|
|
|
|
|
|
def _remat_lowering_rule(ctx: TritonLoweringRuleContext, *args, jaxpr, **_):
  """Lowers remat by inlining its jaxpr: checkpointing is a no-op in a kernel."""
  return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args)


triton_lowering_rules[ad_checkpoint.remat_p] = _remat_lowering_rule
|
|
|
|
|
|
def _is_read_only(ref_effects) -> bool:
|
|
if len(ref_effects) == 0:
|
|
return True
|
|
if len(ref_effects) > 1:
|
|
# Means we must have a write or accum effect so not read-only
|
|
return False
|
|
(eff,) = ref_effects
|
|
return isinstance(eff, state.ReadEffect)
|
|
|
|
|
|
def _for_lowering_rule(
    ctx: TritonLoweringRuleContext,
    *args,
    jaxpr,
    which_linear,
    nsteps,
    reverse,
    unroll
):
  """Lowers for_loop.for_p to an scf.for op.

  `Ref` arguments that are real memory (pointers) are used in place; all
  other state is discharged from the jaxpr and threaded through the scf.for
  carry. Read-only refs are excluded from the carry entirely.
  """
  del which_linear
  if reverse or unroll != 1:
    raise NotImplementedError
  builder = ctx.builder
  lower_bound = builder.get_int32(0)
  upper_bound = builder.get_int32(nsteps)
  step = builder.get_int32(1)
  current_block = builder.get_insertion_block()
  init_args = args
  # Partially discharge state from jaxpr for non-pointers
  should_discharge = [not isinstance(a, AbstractRef) for a in ctx.avals_in]
  # The leading True discharges the loop-index argument's slot.
  discharged_jaxpr, () = discharge.discharge_state(
      jaxpr, (), should_discharge=[True, *should_discharge]
  )
  in_avals = [v.aval for v in jaxpr.invars]
  state_effects = state.get_ref_state_effects(in_avals, jaxpr.effects)[1:]
  # Read-only `Ref`s don't need to be passed in explicitly as loop arguments so
  # we can filter them out.
  # NOTE(review): this reuses `is_loop_arg` several times below, which relies
  # on `map` being JAX's list-returning safe_map rather than the builtin
  # iterator — confirm the rebinding near the top of the file.
  read_only = map(_is_read_only, state_effects)
  is_loop_arg = map(
      operator.and_, map(operator.not_, read_only), should_discharge
  )
  ptrs, _ = partition_list(should_discharge, init_args)
  non_loop_args, loop_args = partition_list(is_loop_arg, init_args)
  for_op = builder.create_for_op(
      lower_bound, upper_bound, step, [arg.handle for arg in loop_args]
  )
  loop_block = builder.create_block()
  builder.set_insertion_point_to_start(loop_block)
  loop_index = tl.core.tensor(for_op.get_induction_var(), tl.core.int32)
  # Emit loop body
  for_body_args = [
      tl.core.tensor(for_op.get_body(0).arg(i + 1), arg.type)
      for i, arg in enumerate(loop_args)
  ]
  loop_body_args = merge_lists(is_loop_arg, non_loop_args, for_body_args)
  out_discharged = lower_jaxpr_to_triton_ir(
      ctx.context,
      discharged_jaxpr,
      [None, *ctx.block_infos],
      loop_index,
      *loop_body_args
  )
  # Recombine discharged outputs with untouched pointers, then keep only the
  # values that are actually carried by the scf.for.
  all_out = merge_lists(should_discharge, ptrs, out_discharged)
  _, loop_out = partition_list(is_loop_arg, all_out)
  if loop_out:
    builder.create_yield_op([arg.handle for arg in loop_out])
  loop_block.merge_block_before(for_op.get_body(0))
  for_results = [for_op.get_result(i) for i in range(len(loop_args))]
  builder.set_insertion_point_to_end(current_block)
  for_out = [tl.core.tensor(r, a.type) for r, a in zip(for_results, loop_args)]
  return merge_lists(is_loop_arg, non_loop_args, for_out)


triton_lowering_rules[for_loop.for_p] = _for_lowering_rule
|
|
|
|
def _lower_jaxpr_to_for_loop(ctx: TritonLoweringRuleContext, jaxpr: jax_core.Jaxpr,
                             lower_bound, upper_bound, consts, *args,
                             has_loop_index: bool,
                             step: int = 1,
                             bound_type: tl.dtype = tl.int32):
  """Emits an scf.for running `jaxpr` from lower_bound to upper_bound.

  Args:
    jaxpr: the loop body.
    lower_bound/upper_bound: IR handles of the loop bounds.
    consts: values closed over by the body (not carried through the loop).
    *args: the loop carry; one scf.for result is returned per carry value.
    has_loop_index: whether the jaxpr takes the induction variable as its
      first post-consts argument.
    step: loop step; only 1 is supported.
    bound_type: integer type of the bounds (tl.int32 or tl.int64).
  """
  if step != 1:
    raise NotImplementedError
  builder = ctx.builder
  # The step constant must match the width of the bounds.
  if bound_type == tl.int64:
    step = builder.get_int64(step)
  else:
    step = builder.get_int32(step)
  current_block = builder.get_insertion_block()
  for_op = builder.create_for_op(
      lower_bound, upper_bound, step, [arg.handle for arg in args]
  )
  loop_block = builder.create_block()
  builder.set_insertion_point_to_start(loop_block)
  loop_index = tl.core.tensor(for_op.get_induction_var(), tl.core.int32)
  # Emit loop body
  for_body_args = [
      # arg(0) is the induction variable; carries start at arg(1).
      tl.core.tensor(for_op.get_body(0).arg(i + 1), arg.type)
      for i, arg in enumerate(args)
  ]
  if has_loop_index:
    jaxpr_args = [*consts, loop_index, *for_body_args]
  else:
    jaxpr_args = [*consts, *for_body_args]
  all_out = lower_jaxpr_to_triton_ir(
      ctx.context,
      jaxpr,
      ctx.block_infos,
      *jaxpr_args)
  if all_out:
    builder.create_yield_op([arg.handle for arg in all_out])
  loop_block.merge_block_before(for_op.get_body(0))
  for_results = [for_op.get_result(i) for i in range(len(args))]
  builder.set_insertion_point_to_end(current_block)
  return [tl.core.tensor(r, a.type) for r, a in zip(for_results, args)]
|
|
|
|
|
|
def _scan_lowering_rule(
    ctx: TritonLoweringRuleContext,
    *args,
    jaxpr,
    linear,
    length,
    reverse,
    unroll,
    num_consts,
    num_carry,
):
  """Lowers lax.scan to an scf.for loop.

  Only scans that are equivalent to fori_loop are supported: no extensive
  (scanned-over) inputs/outputs, no reverse, no unrolling.
  """
  # Only implements fori_loop-like scans
  num_extensive = len(args) - num_consts - num_carry
  if num_extensive: raise NotImplementedError
  if reverse: raise NotImplementedError
  if unroll != 1: raise NotImplementedError
  del linear, num_extensive, unroll, reverse

  builder = ctx.builder
  jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts
  if jaxpr_consts: raise NotImplementedError
  del jaxpr_consts

  # Rewrite the scan body into fori_loop form; tells us whether the carry
  # includes an explicit loop index.
  jaxpr, has_loop_index = (
      pallas_utils.pattern_match_scan_to_fori_loop(jaxpr, num_consts, num_carry)
  )
  consts, args = util.split_list(args, [num_consts])
  if has_loop_index:
    lb, *args = args
    lower_bound = lb.handle
    ub = lb.__add__(tl.constexpr(length), _builder=builder)
    upper_bound = ub.handle
    bound_type = ub.type
  else:
    lower_bound = builder.get_int32(0)
    upper_bound = builder.get_int32(length)
    bound_type = tl.int32
  for_out = _lower_jaxpr_to_for_loop(
      ctx, jaxpr, lower_bound, upper_bound, consts, *args,
      has_loop_index=has_loop_index, step=1, bound_type=bound_type)
  if has_loop_index:
    # Need to return the final loop index value if the outer scan expects
    # it as an output
    return [tl.core.tensor(upper_bound, bound_type), *for_out]
  return for_out


triton_lowering_rules[lax.scan_p] = _scan_lowering_rule
|
|
|
|
def _maybe_pattern_match_fori_loop(ctx: TritonLoweringRuleContext, *args,
                                   cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr
                                   ):
  """Detects while loops that are really fori_loops and lowers them as scf.for.

  Matches the pattern: carry = (i, upper_bound, *rest); the cond jaxpr is a
  single `i < upper_bound`; the body increments `i` by the literal 1 and
  passes `upper_bound` through unchanged. Returns the lowered outputs on a
  match, or None to fall back to the generic while lowering.
  """
  if cond_nconsts:
    return None
  _, cond_invars = split_list(cond_jaxpr.jaxpr.invars, [cond_nconsts])
  cond_in_avals = [v.aval for v in cond_invars]
  if len(cond_in_avals) < 2:
    return None
  # Check that the first two carry values are scalar ints
  a1, a2 = cond_in_avals[:2]
  if a1.shape != () or a1.dtype not in (jnp.int32, jnp.int64):
    return None
  if a2.shape != () or a2.dtype not in (jnp.int32, jnp.int64):
    return None
  # Check that the only eqn in the cond checks the loop index condition
  v1, v2 = cond_invars[:2]
  outvar = cond_jaxpr.jaxpr.outvars[0]
  assert outvar.aval.dtype == jnp.bool_
  if len(cond_jaxpr.jaxpr.eqns) != 1:
    return None
  eqn = cond_jaxpr.jaxpr.eqns[0]
  if eqn.primitive != lax.lt_p:
    return None
  if eqn.outvars != [outvar]:
    return None
  if eqn.invars != [v1, v2]:
    return None
  # Check that the carry is updated in the body appropriately
  _, body_invars = split_list(body_jaxpr.jaxpr.invars, [body_nconsts])
  v1, v2 = body_invars[:2]
  vo1, vo2 = body_jaxpr.jaxpr.outvars[:2]
  # Upper bound should be constant
  if v2 is not vo2:
    return None
  # Check that we increment the loop index in the body
  for i, eqn in enumerate(body_jaxpr.jaxpr.eqns):
    if eqn.primitive is lax.add_p:
      if eqn.invars[0] is v1:
        if isinstance(eqn.invars[1], jax_core.Literal):
          if eqn.invars[1].val == 1:
            if eqn.outvars[0] == vo1:
              eqn_index = i
              break
  else:
    # for/else: no `i + 1` equation found — not a fori_loop.
    return None
  jaxpr = body_jaxpr.jaxpr
  # Drop the `i + 1` equation and the (i, upper_bound) carry outputs; keep
  # the index invar so the body can still read the induction variable.
  new_invars = tuple((*jaxpr.invars[:body_nconsts],
                      jaxpr.invars[body_nconsts],
                      *jaxpr.invars[body_nconsts + 2:]))
  new_outvars = tuple(jaxpr.outvars[2:])
  jaxpr = jaxpr.replace(
      eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1:],
      invars=new_invars,
      outvars=new_outvars)
  _, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
  (lb, ub), args = carry[:2], carry[2:]
  const_block_infos, args_block_infos = split_list(ctx.block_infos,
                                                   [body_nconsts])
  ctx = ctx.replace(block_infos=[*const_block_infos, None,
                                 *args_block_infos[2:]])
  for_out = _lower_jaxpr_to_for_loop(ctx, jaxpr, lb.handle, ub.handle,
                                     body_consts, *args, has_loop_index=True,
                                     step=1, bound_type=lb.type)
  # The while's carry shape is preserved: final index == upper bound.
  return [ub, ub, *for_out]
|
|
|
|
def _while_lowering_rule(
    ctx: TritonLoweringRuleContext,
    *args,
    cond_nconsts,
    cond_jaxpr,
    body_nconsts,
    body_jaxpr
):
  """Lowers lax.while_p to an scf.while op.

  All arguments (cond consts, body consts, carry) are threaded through the
  while op so both regions can see them; only the carry portion of the
  results is returned to the caller.
  """
  # First, try to pattern match to fori_loop and lower to scf.for if possible
  result = _maybe_pattern_match_fori_loop(ctx, *args, cond_nconsts=cond_nconsts,
                                          body_nconsts=body_nconsts, cond_jaxpr=cond_jaxpr,
                                          body_jaxpr=body_jaxpr)
  if result is not None:
    return result
  # Fall back to default while lowering
  num_args = len(args)
  cond_consts, body_consts, carry = util.split_list(
      args, [cond_nconsts, body_nconsts]
  )
  cond_const_block_infos, body_const_block_infos, carry_block_infos = (
      util.split_list(ctx.block_infos, [cond_nconsts, body_nconsts])
  )
  current_bb = ctx.builder.get_insertion_block()
  cond_const_types = [a.type.to_ir(ctx.builder) for a in cond_consts]
  body_const_types = [a.type.to_ir(ctx.builder) for a in body_consts]
  carry_types = [a.type.to_ir(ctx.builder) for a in carry]
  all_types = [*cond_const_types, *body_const_types, *carry_types]
  while_op = ctx.builder.create_while_op(
      [*cond_const_types, *body_const_types, *carry_types],
      [arg.handle for arg in args],
  )
  # "Before" region: evaluate the condition on the current carry.
  before_block = ctx.builder.create_block_with_parent(
      while_op.get_before(), all_types
  )
  ctx.builder.set_insertion_point_to_start(before_block)
  cond_consts_, _, carry_ = util.split_list(
      [before_block.arg(i) for i in range(num_args)],
      [cond_nconsts, body_nconsts],
  )
  cond_args = [
      tl.core.tensor(a, b.type)
      for a, b in zip([*cond_consts_, *carry_], [*cond_consts, *carry])
  ]
  (cond,) = lower_jaxpr_to_triton_ir(
      ctx.context,
      cond_jaxpr.jaxpr,
      [*cond_const_block_infos, *carry_block_infos],
      *cond_args
  )
  # Forward all block args to the "after" region alongside the condition.
  ctx.builder.create_condition_op(
      cond.handle, [before_block.arg(i) for i in range(num_args)]
  )
  # "After" region: run the body and yield the next iteration's values.
  after_block = ctx.builder.create_block_with_parent(
      while_op.get_after(), all_types
  )
  ctx.builder.set_insertion_point_to_start(after_block)
  cond_consts_, body_consts_, carry_ = util.split_list(
      [after_block.arg(i) for i in range(num_args)],
      [cond_nconsts, body_nconsts],
  )
  all_args = [
      tl.core.tensor(a, b.type)
      for a, b in zip(
          [*cond_consts_, *body_consts_, *carry_],
          [*cond_consts, *body_consts, *carry],
      )
  ]
  cond_const_args, body_const_args, carry_args = util.split_list(
      all_args, [cond_nconsts, body_nconsts]
  )
  loop_out = lower_jaxpr_to_triton_ir(
      ctx.context,
      body_jaxpr.jaxpr,
      [*body_const_block_infos, *carry_block_infos],
      *body_const_args,
      *carry_args
  )
  # Consts are yielded through unchanged; only the carry is updated.
  cond_consts_handles = [a.handle for a in cond_const_args]
  body_consts_handles = [a.handle for a in body_const_args]
  loop_out_handles = [a.handle for a in loop_out]
  all_handles = [*cond_consts_handles, *body_consts_handles, *loop_out_handles]
  if all_handles:
    ctx.builder.create_yield_op(all_handles)
  ctx.builder.set_insertion_point_to_end(current_bb)
  all_out = [
      tl.core.tensor(while_op.get_result(i), a.type) for i, a in enumerate(args)
  ]
  return all_out[cond_nconsts + body_nconsts :]


triton_lowering_rules[lax.while_p] = _while_lowering_rule
|
|
|
|
|
|
def _cond_lowering_rule(
    ctx: TritonLoweringRuleContext,
    index,
    *args,  # *consts, *ops
    branches,  # tuple(jaxprs)
    linear,
):
  """Lowers lax.cond to a (possibly nested) scf.if.

  N-way conds are lowered as a linear chain of two-way ifs: branch 0 goes in
  the `then` block and the remaining branches recurse into the `else` block
  with `index - 1`.
  """
  block_infos = ctx.block_infos
  current_bb = ctx.builder.get_insertion_block()

  def to_type(out_aval):
    # Maps an output aval to the tl result type of the if op.
    elt_type = code_gen.str_to_ty(get_triton_type(out_aval)).element_ty
    if not out_aval.shape:
      # Scalar types get special handling.
      return elt_type
    return tl.block_type(elt_type, out_aval.shape)

  out_types = [to_type(out) for out in ctx.avals_out]
  out_ir_types = [t.to_ir(ctx.builder) for t in out_types]

  use_branch0 = index.__eq__(0, _builder=ctx.builder)
  # TODO(bjp): Switch to scf.index_switch once exposed in triton.cc
  if_op = ctx.builder.create_if_op(
      out_ir_types,  # retTypes
      use_branch0.handle,  # condition
      True)  # withElse
  # Lower then block.
  ctx.builder.set_insertion_point_to_start(if_op.get_then_block())
  outs0 = lower_jaxpr_to_triton_ir(
      ctx.context,
      branches[0].jaxpr,
      block_infos[1:],
      *args)
  if outs0:
    ctx.builder.create_yield_op([out.handle for out in outs0])
  # Lower else block.
  ctx.builder.set_insertion_point_to_start(if_op.get_else_block())
  # TODO(bjp): Instead of linear nest of 'if's, partition into halves.
  if len(branches) > 2:
    # Recurse over the remaining branches; `index - 1` selects among them.
    outs1 = _cond_lowering_rule(
        ctx,
        index.__sub__(1, _builder=ctx.builder),
        *args,
        branches=branches[1:],
        linear=linear)
  else:
    outs1 = lower_jaxpr_to_triton_ir(
        ctx.context,
        branches[1].jaxpr,
        block_infos[1:],
        *args)
  if outs1:
    ctx.builder.create_yield_op([out.handle for out in outs1])
  ctx.builder.set_insertion_point_to_end(current_bb)
  all_out = [
      tl.core.tensor(if_op.get_result(i), ty)
      for i, ty in enumerate(out_types)
  ]
  return all_out


triton_lowering_rules[lax.cond_p] = _cond_lowering_rule
|
|
|
|
|
|
@weakref_lru_cache
def compile_jaxpr(
    jaxpr: jax_core.Jaxpr,
    in_shapes,
    grid_mapping: GridMapping,
    name: str,
    num_warps: int,
    num_stages: int,
    debug: bool,
) -> TritonCompilationResult:
  """Lowers `jaxpr` to a Triton module and compiles it to PTX.

  The result is cached (keyed weakly on the jaxpr) so repeated pallas_call
  invocations of the same kernel reuse one compilation.
  """
  lowering_result = lower_jaxpr_to_triton_module(
      jaxpr, in_shapes, grid_mapping, name
  )
  # Snapshot the TTIR text before compilation mutates the module in place.
  ttir = str(lowering_result.module)
  ptx, kernel_name, shared_mem_bytes, compute_capability = (
      compile_ttir_to_ptx_inplace(
          lowering_result.module,
          device=0,
          num_warps=num_warps,
          num_stages=num_stages,
          dump=debug,
      )
  )
  return TritonCompilationResult(
      kernel_name, ttir, ptx, shared_mem_bytes, compute_capability,
      lowering_result
  )
|
|
|
|
|
|
def pallas_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    jaxpr: jax_core.Jaxpr,
    name: str,
    in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
    which_linear: Tuple[bool, ...],
    interpret: bool,
    debug: bool,
    input_output_aliases: Tuple[Tuple[int, int], ...],
    grid_mapping: GridMapping,
    triton_params: Dict[str, Any] | None = None,
    **compiler_params: Any,
):
  """MLIR lowering rule for pallas_call on CUDA.

  Compiles the kernel jaxpr to a Triton kernel and emits a custom call that
  launches it. In interpret mode, falls back to the Python implementation.
  """
  if interpret:
    return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
        ctx,
        *in_nodes,
        jaxpr=jaxpr,
        name=name,
        out_shapes=out_shapes,
        in_shapes=in_shapes,
        which_linear=which_linear,
        interpret=interpret,
        debug=debug,
        input_output_aliases=input_output_aliases,
        grid_mapping=grid_mapping,
        **compiler_params
    )
  num_warps = compiler_params.get("num_warps", 4)
  num_stages = compiler_params.get("num_stages", 3)
  if debug:
    print(jaxpr)
    print(grid_mapping)
  # Compilation is cached on the jaxpr, so this is cheap on cache hits.
  compilation_result = compile_jaxpr(
      jaxpr,
      tuple((*in_shapes, *out_shapes)),
      grid_mapping,
      name,
      num_warps,
      num_stages,
      debug=debug,
  )

  if debug:
    compilation_result.lowering_result.module.dump()

  kernel = triton_kernel_call_lib.TritonKernel(
      compilation_result.kernel_name,
      num_warps,
      compilation_result.shared_mem_bytes,
      compilation_result.ptx,
      compilation_result.ttir,
      compilation_result.compute_capability,
  )

  grid = triton_lib.normalize_grid(
      compilation_result.lowering_result.grid, metaparams={}
  )

  # One array parameter per input and output buffer.
  kernel_params = []
  for _ in range(len(in_shapes) + len(out_shapes)):
    kernel_params.append(
        triton_kernel_call_lib.create_array_parameter(
            0,  # bytes to zero # TODO(cjfj): Expose through user API.
            16,  # divisible by 16
        )
    )

  kernel_call = triton_kernel_call_lib.TritonKernelCall(
      kernel, grid[0], grid[1], grid[2], kernel_params
  )

  out_types = [
      ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
      for shape in out_shapes
  ]

  xc.register_custom_call_target(
      name, triton_kernel_call_lib.get_custom_call(), platform="CUDA"
  )

  if triton_params is None:
    triton_params = {}
  serialized_metadata = triton_params.get("serialized_metadata", b"")
  # The to_proto signature changed in jaxlib 0.4.15 (gained a name argument).
  if version >= (0, 4, 15):
    kernel_call_proto = kernel_call.to_proto(name, serialized_metadata)
  else:
    kernel_call_proto = kernel_call.to_proto(serialized_metadata)
  return hlo_helpers.custom_call(
      call_target_name=name,
      out_types=out_types,
      operands=in_nodes,
      backend_config=zlib.compress(kernel_call_proto),
      operand_layouts=triton_lib.avals_to_layouts(ctx.avals_in),
      result_layouts=triton_lib.avals_to_layouts(ctx.avals_out),
      operand_output_aliases=dict(input_output_aliases),
  )


mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda")
|