# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for lowering JAX to Mosaic-compatible MLIR dialects."""
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Sequence
|
|
import contextlib
|
|
import dataclasses
|
|
import functools
|
|
import string
|
|
from typing import Any, Hashable
|
|
|
|
import jax
|
|
from jax import api_util
|
|
from jax import lax
|
|
from jax import tree_util
|
|
from jax._src import ad_util
|
|
from jax._src import checkify
|
|
from jax._src import core as jax_core
|
|
from jax._src import custom_derivatives
|
|
from jax._src import debugging
|
|
from jax._src import dtypes
|
|
from jax._src import linear_util as lu
|
|
from jax._src import mesh as mesh_lib
|
|
from jax._src import pjit
|
|
from jax._src import prng
|
|
from jax._src import source_info_util
|
|
from jax._src import state
|
|
from jax._src import traceback_util
|
|
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.lax import lax as lax_internal
|
|
from jax._src.lax.control_flow import for_loop
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import arith
|
|
from jax._src.lib.mlir.dialects import func
|
|
from jax._src.lib.mlir.dialects import math
|
|
from jax._src.lib.mlir.dialects import memref
|
|
from jax._src.lib.mlir.dialects import scf
|
|
from jax._src.lib.mlir.dialects import vector
|
|
from jax._src.pallas import core as pallas_core
|
|
from jax._src.pallas import pallas_call
|
|
from jax._src.pallas import primitives
|
|
from jax._src.pallas import utils as pallas_utils
|
|
from jax._src.pallas.mosaic import core as tpu_core
|
|
from jax._src.pallas.mosaic import error_handling
|
|
from jax._src.pallas.mosaic import primitives as tpu_primitives
|
|
from jax._src.pallas.mosaic import random as pl_random
|
|
from jax._src.state import discharge as state_discharge
|
|
from jax._src.state import indexing
|
|
from jax._src.state import primitives as state_primitives
|
|
from jax._src.state.types import RefBitcaster, RefReshaper
|
|
from jax._src.state.utils import dtype_bitwidth
|
|
from jax._src.typing import Array, DTypeLike
|
|
from jax._src.util import safe_map
|
|
from jax._src.util import safe_zip
|
|
from jax._src.util import split_list
|
|
from jax._src.util import unzip2
|
|
from jax.experimental.mosaic.dialects import tpu
|
|
import jax.numpy as jnp
|
|
from jaxlib.mlir.ir import Module
|
|
import numpy as np
|
|
|
|
# TODO(sharadmv): enable type checking
# mypy: ignore-errors

NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
MemorySpace = pallas_core.MemorySpace | TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
# Booleans are stored as the following type in memrefs.
BOOL_MEMREF_TYPE = np.dtype('int32')

# The value interpreted as a dynamic dimension by MLIR.
MLIR_DYNAMIC = -9223372036854775808

partial = functools.partial
map, unsafe_map = safe_map, map  # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip  # pylint: disable=redefined-builtin

@dataclasses.dataclass
class MeshContext:
  mesh_shape: tuple[int, ...]
  axis_names: tuple[str, ...]
  mesh_strides: tuple[int, ...]


# Note - On Export Placeholders
#
# Mosaic uses vector IR, which does not have a concept of dynamic
# dimensions. We need to come up with a way to represent dynamic dimensions in
# vector IR, and so we use placeholders, which are later replaced during
# specialization.
class LoweringDynamicShapeEnv:
  dim_expr_to_placeholder: dict[Any, ir.Value] = {}

  def to_placeholder(self, dim_expr: Any) -> ir.Value:
    if dim_expr not in self.dim_expr_to_placeholder:
      next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder)
      self.dim_expr_to_placeholder[dim_expr] = next_val
    return self.dim_expr_to_placeholder[dim_expr]

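
# A minimal sketch (illustrative only, not used by the lowering) of how the
# placeholder scheme above behaves: for a fresh environment, each distinct
# dimension expression maps to a unique sentinel counting down from the int32
# maximum, and repeated lookups are memoized. Note that the mapping dict is a
# class attribute, so it is shared across instances.
def _example_dynamic_shape_placeholders():
  env = LoweringDynamicShapeEnv()
  a = env.to_placeholder("batch")  # e.g. 2147483647 for a fresh environment
  b = env.to_placeholder("seq")    # e.g. 2147483646
  assert env.to_placeholder("batch") == a  # lookups are memoized
  return a, b
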
@dataclasses.dataclass
|
|
class LoweringContext:
|
|
ir_context: ir.Context
|
|
grid_sizes: tuple[int, ...] # Includes both user and vmap axes.
|
|
grid_names: tuple[Hashable, ...] | None
|
|
mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions.
|
|
user_grid_indices: Sequence[ir.Value] | None
|
|
block_shapes: list[tuple[int | pallas_core.Mapped, ...]]
|
|
name_stack: source_info_util.NameStack
|
|
mesh_context: MeshContext | None
|
|
replace = dataclasses.replace
|
|
traceback_caches: mlir.TracebackCaches
|
|
for_verification: bool
|
|
forward_compatible: bool
|
|
dynamic_shape_replacement_fn: Callable[
|
|
[tuple[jax.DimSize, ...]], tuple[int, ...]
|
|
]
|
|
|
|
@property
|
|
def grid_rank(self):
|
|
return len(self.grid_sizes)
|
|
|
|
@contextlib.contextmanager
|
|
def grid_name_context(self):
|
|
# TODO(b/355036977): generalize this across other platforms
|
|
if not self.grid_names:
|
|
yield
|
|
return
|
|
grid_names = self.grid_names
|
|
valid_grid_sizes = tuple(
|
|
d for i, d in enumerate(self.grid_sizes) if i not in self.mapped_dims
|
|
)
|
|
grid_env = zip(grid_names, valid_grid_sizes)
|
|
with jax_core.extend_axis_env_nd(grid_env):
|
|
yield
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class LoweringRuleContext:
|
|
lowering_context: LoweringContext
|
|
avals_in: Sequence[jax_core.AbstractValue]
|
|
avals_out: Sequence[jax_core.AbstractValue]
|
|
block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None]
|
|
replace = dataclasses.replace
|
|
|
|
@property
|
|
def forward_compatible(self):
|
|
return self.lowering_context.forward_compatible
|
|
|
|
|
|
def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None
|
|
) -> TPUMemorySpace:
|
|
match memory_space:
|
|
case None:
|
|
# We pick VMEM as the default one when no memory space is
|
|
# specified
|
|
return TPUMemorySpace.VMEM
|
|
case pallas_core.MemorySpace.ANY:
|
|
# Map the general ANY memory space to TPU ANY memory space
|
|
return TPUMemorySpace.ANY
|
|
case pallas_core.MemorySpace.ERROR | pallas_core.MemorySpace.INDEX:
|
|
return TPUMemorySpace.SMEM
|
|
case TPUMemorySpace():
|
|
# Leave the memory space unchanged
|
|
return memory_space
|
|
case _:
|
|
raise ValueError(f"Invalid memory space: {memory_space}")
|
|
|
|
|
|
def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None
|
|
) -> ir.Attribute:
|
|
tpu_memory_space = _memory_space_to_tpu_memory_space(memory_space)
|
|
return ir.Attribute.parse(f"#tpu.memory_space<{tpu_memory_space}>")
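
# For example (illustrative only):
#   _memory_space_to_mosaic_attribute(TPUMemorySpace.VMEM)
#   == ir.Attribute.parse("#tpu.memory_space<vmem>")
# which is the same attribute string that the load/store rules below compare
# against via str(ref_type.memory_space).
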
|
|
|
|
def _dtype_to_ir_type(dtype: jnp.dtype,
|
|
is_kernel_boundary: bool = False) -> ir.Type:
|
|
if jnp.issubdtype(dtype, tpu_core.semaphore_dtype):
|
|
if jnp.issubdtype(dtype, tpu_core.dma_semaphore):
|
|
return ir.Type.parse("!tpu.dma_semaphore")
|
|
elif jnp.issubdtype(dtype, tpu_core.semaphore):
|
|
return ir.Type.parse("!tpu.semaphore")
|
|
elif jnp.issubdtype(dtype, tpu_core.barrier_semaphore):
|
|
return ir.Type.parse("!tpu.semaphore")
|
|
else:
|
|
raise NotImplementedError
|
|
if is_kernel_boundary and jnp.issubdtype(dtype, jnp.dtype('bool')):
|
|
dtype = BOOL_MEMREF_TYPE
|
|
# TODO(justinfu): Remove after mosaic supports unsigned types.
|
|
# This conversion makes mosaic interpret all unsigned types as signed types.
|
|
type = mlir.dtype_to_ir_type(dtype)
|
|
if isinstance(type, ir.IntegerType):
|
|
return ir.IntegerType.get_signless(type.width)
|
|
else:
|
|
return type
|
|
|
|
|
|
def aval_to_ir_type(
|
|
dynamic_shape_replacement_fn,
|
|
aval,
|
|
shape=None,
|
|
memory_space: MemorySpace | None = None,
|
|
is_kernel_boundary: bool = False,
|
|
):
|
|
if isinstance(aval, tpu_core.AbstractSemaphore):
|
|
if aval.sem_type is tpu_core.SemaphoreType.DMA:
|
|
sem_type = ir.Type.parse("!tpu.dma_semaphore")
|
|
elif aval.sem_type is tpu_core.SemaphoreType.REGULAR:
|
|
sem_type = ir.Type.parse("!tpu.semaphore")
|
|
elif aval.sem_type is tpu_core.SemaphoreType.BARRIER:
|
|
sem_type = ir.Type.parse("!tpu.semaphore")
|
|
else:
|
|
raise ValueError(f"Cannot allocate {aval.sem_type}.")
|
|
memspace = _memory_space_to_mosaic_attribute(TPUMemorySpace.SEMAPHORE)
|
|
return ir.MemRefType.get((), sem_type, memory_space=memspace)
|
|
if dtypes.issubdtype(aval.dtype, dtypes.prng_key):
|
|
shape = aval.dtype._impl.key_shape
|
|
if pl_random.is_pallas_impl(aval.dtype._impl):
|
|
if memory_space is None:
|
|
memory_space = TPUMemorySpace.SMEM
|
|
if memory_space != TPUMemorySpace.SMEM:
|
|
raise ValueError(
|
|
f"PRNG keys must be stored in SMEM. Got {memory_space}"
|
|
)
|
|
memspace = _memory_space_to_mosaic_attribute(memory_space)
|
|
return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)),
|
|
memory_space=memspace)
|
|
if isinstance(aval, state.AbstractRef):
|
|
if shape is None:
|
|
shape = aval.shape
|
|
memspace = _memory_space_to_mosaic_attribute(memory_space)
|
|
shape = dynamic_shape_replacement_fn(shape)
|
|
return ir.MemRefType.get(shape,
|
|
_dtype_to_ir_type(aval.dtype, is_kernel_boundary=True),
|
|
memory_space=memspace)
|
|
if isinstance(aval, jax_core.ShapedArray):
|
|
if shape is None:
|
|
shape = aval.shape
|
|
if not shape:
|
|
return _dtype_to_ir_type(
|
|
aval.dtype, is_kernel_boundary=is_kernel_boundary)
|
|
shape = dynamic_shape_replacement_fn(shape)
|
|
return ir.VectorType.get(
|
|
shape,
|
|
_dtype_to_ir_type(aval.dtype, is_kernel_boundary=is_kernel_boundary))
|
|
raise NotImplementedError(aval)
|
|
|
|
|
|
def ir_constant(x, mlir_type=None):
|
|
if not hasattr(x, "dtype"):
|
|
if isinstance(x, int):
|
|
x = np.array(x, np.int32)
|
|
elif isinstance(x, float):
|
|
x = np.array(x, np.float32)
|
|
if not mlir_type:
|
|
mlir_type = _dtype_to_ir_type(x.dtype)
|
|
if isinstance(x, int) or np.issubdtype(x.dtype, np.integer):
|
|
return arith.constant(mlir_type, ir.IntegerAttr.get(mlir_type, int(x)))
|
|
elif isinstance(x, float) or x.dtype == np.float32:
|
|
return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x)))
|
|
elif x.dtype == jnp.bfloat16:
|
|
return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x)))
|
|
elif x.dtype == jnp.bool_:
|
|
return arith.constant(mlir_type, ir.BoolAttr.get(bool(x)))
|
|
raise NotImplementedError(x.dtype)
|
|
|
|
|
|
lowering_rules = {}
|
|
skip_mlir_conversions = set()
|
|
|
|
def _get_aval_physical_dtype_shape(aval):
|
|
dtype_physical_shape = jax_core.physical_aval(aval).shape[
|
|
len(aval.shape) :
|
|
]
|
|
return dtype_physical_shape
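
# Illustrative only: extended dtypes carry extra trailing physical dimensions.
# For example, under the default threefry PRNG implementation a key aval of
# shape (4,) has physical shape (4, 2), so the helper above returns (2,).
# (This assumes the threefry key layout; other PRNG implementations differ.)
def _example_physical_dtype_shape():
  key_aval = jax_core.ShapedArray((4,), jax.random.key(0).dtype)
  return _get_aval_physical_dtype_shape(key_aval)  # (2,) under threefry
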
|
|
|
|
|
|
def _get_arg_type(
|
|
dynamic_shape_replacement_fn: Callable[
|
|
[tuple[jax.DimSize, ...]], tuple[jax.DimSize, ...]
|
|
],
|
|
aval,
|
|
block_mapping: pallas_core.BlockMapping | None,
|
|
):
|
|
memory_space = None
|
|
if isinstance(aval, pallas_core.AbstractMemoryRef):
|
|
memory_space = aval.memory_space
|
|
# We assume unannotated memory refs are in VMEM
|
|
if memory_space is None:
|
|
memory_space = TPUMemorySpace.VMEM
|
|
if isinstance(aval, tpu_core.AbstractSemaphore):
|
|
return aval_to_ir_type(dynamic_shape_replacement_fn, aval), None
|
|
# TODO(necula): clean this None block_mapping
|
|
if block_mapping is None:
|
|
return (
|
|
aval_to_ir_type(
|
|
dynamic_shape_replacement_fn, aval, memory_space=memory_space
|
|
),
|
|
aval.shape,
|
|
)
|
|
shape = tuple(1 if b is pallas_core.mapped else b for b in block_mapping.block_shape)
|
|
return (
|
|
aval_to_ir_type(
|
|
dynamic_shape_replacement_fn,
|
|
aval,
|
|
shape=shape,
|
|
memory_space=memory_space,
|
|
),
|
|
block_mapping.block_shape,
|
|
)
|
|
|
|
|
|
def _canonicalize_dimension_semantic(
|
|
dimension_semantic: str | tpu_core.GridDimensionSemantics,
|
|
) -> str:
|
|
if isinstance(dimension_semantic, tpu_core.GridDimensionSemantics):
|
|
return dimension_semantic.value
|
|
return dimension_semantic
|
|
|
|
|
|
@dataclasses.dataclass(init=False)
|
|
class MosaicGridMapping:
|
|
grid: tuple[int, ...] | None
|
|
grid_names: tuple[Hashable, ...] | None
|
|
jaxpr: jax_core.Jaxpr
|
|
block_mappings: tuple[pallas_core.BlockMapping | None, ...]
|
|
mapped_dims: tuple[int, ...]
|
|
scalar_prefetch_types: tuple[ir.Type, ...]
|
|
operand_types: tuple[ir.Type, ...]
|
|
scratch_types: tuple[ir.Type, ...]
|
|
grid_types: tuple[ir.Type, ...]
|
|
scalar_prefetch_block_shapes: tuple[tuple[int, ...], ...]
|
|
operand_block_shapes: tuple[tuple[int, ...], ...]
|
|
scratch_block_shapes: tuple[tuple[int, ...], ...]
|
|
mesh_info: MeshInfo | None
|
|
get_grid_indices: Callable | None
|
|
|
|
def __init__(
|
|
self,
|
|
jaxpr: jax_core.Jaxpr,
|
|
grid_mapping: pallas_core.GridMapping,
|
|
dimension_semantics: tuple[str | tpu_core.GridDimensionSemantics, ...] | None,
|
|
mesh: mesh_lib.Mesh | None,
|
|
dynamic_shape_replacement_fn: Callable[
|
|
[tuple[jax.DimSize, ...]], tuple[int, ...]
|
|
],
|
|
):
|
|
self.grid = grid_mapping.grid
|
|
self.grid_names = grid_mapping.grid_names
|
|
self.jaxpr = jaxpr
|
|
self.block_mappings = grid_mapping.block_mappings
|
|
self.mapped_dims = grid_mapping.vmapped_dims
|
|
# TODO(mvoz): Generalize to not need this
|
|
user_grid = tuple(
|
|
g for i, g in enumerate(self.grid) if i not in self.mapped_dims
|
|
)
|
|
if dimension_semantics is None:
|
|
dimension_semantics = ("arbitrary",) * len(user_grid)
|
|
dimension_semantics = tuple(
|
|
_canonicalize_dimension_semantic(s) for s in dimension_semantics
|
|
)
|
|
if len(user_grid) != len(dimension_semantics):
|
|
raise ValueError(
|
|
"Must have dimension semantics for each dimension of the grid."
|
|
)
|
|
assert len(self.mapped_dims) + len(dimension_semantics) == len(
|
|
self.grid
|
|
), (
|
|
f"Misconfigured grid: {self.mapped_dims=}, {dimension_semantics=},"
|
|
f" {self.grid=}"
|
|
)
|
|
# dimension_semantics is user provided and won't take into account vmap
|
|
# dimensions. Here we add in parallel dimensions for the vmaps.
|
|
semantics_iter = iter(dimension_semantics)
|
|
self._dimension_semantics = tuple(
|
|
next(semantics_iter) if i not in self.mapped_dims else "parallel"
|
|
for i in range(len(self.grid))
|
|
)
|
|
|
|
in_avals = [invar.aval for invar in self.jaxpr.invars]
|
|
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
|
|
scalar_prefetch_avals = in_avals[grid_mapping.slice_index_ops]
|
|
operand_avals = in_avals[grid_mapping.slice_block_ops]
|
|
scratch_avals = in_avals[grid_mapping.slice_scratch_ops]
|
|
self.scalar_prefetch_types, _ = unzip2([
|
|
_get_arg_type(dynamic_shape_replacement_fn, aval, None)
|
|
for aval in scalar_prefetch_avals
|
|
])
|
|
self.scalar_prefetch_block_shapes = tuple(
|
|
aval.shape for aval in scalar_prefetch_avals)
|
|
self.operand_types, self.operand_block_shapes = unzip2([
|
|
_get_arg_type(dynamic_shape_replacement_fn, aval, block_mapping)
|
|
for aval, block_mapping in zip(operand_avals, self.block_mappings)
|
|
])
|
|
self.scratch_types, _ = unzip2([
|
|
_get_arg_type(dynamic_shape_replacement_fn, aval, None)
|
|
for aval in scratch_avals
|
|
])
|
|
self.scratch_block_shapes = tuple(
|
|
aval.shape if not isinstance(aval, tpu_core.AbstractSemaphore) else None
|
|
for aval in scratch_avals
|
|
)
|
|
self.grid_types, _ = unzip2([
|
|
_get_arg_type(
|
|
dynamic_shape_replacement_fn,
|
|
pallas_core.index_map_grid_aval,
|
|
None,
|
|
)
|
|
for _ in range(len(self.grid))
|
|
])
|
|
self._prepare_mesh_info(mesh)
|
|
|
|
if grid_mapping.get_grid_indices is None:
|
|
|
|
# Avoid using self.mapped_dims within the function, since doing so will
|
|
# introduce a self->_get_grid_indices->self reference cycle that means
|
|
# MosaicGridMapping instances can only ever be deleted by GC, rather than
|
|
# by their reference counts going to 0.
|
|
mapped_dims = self.mapped_dims
|
|
def _get_grid_indices(indices, maybe_include_mapped_dims: bool):
|
|
if maybe_include_mapped_dims:
|
|
return indices
|
|
return tuple(
|
|
idx for i, idx in enumerate(indices) if i not in mapped_dims
|
|
)
|
|
|
|
self.get_grid_indices = _get_grid_indices
|
|
else:
|
|
self.get_grid_indices = grid_mapping.get_grid_indices
|
|
|
|
def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
|
|
if not self.has_communication:
|
|
self.mesh_info = None
|
|
return
|
|
if mesh is None:
|
|
raise ValueError(
|
|
"Cannot use communication in pallas_call without shard_map."
|
|
)
|
|
axis_names = mesh.axis_names
|
|
if self.grid_names is not None:
|
|
if any(a in self.grid_names for a in axis_names):
|
|
        raise ValueError(
            "Cannot shadow mesh axis names with grid names. Mesh axis"
            f" names: {mesh.axis_names}, grid names: {self.grid_names}"
        )
|
|
# We need mesh <-> logical translation tables. Since the logical IDs are
|
|
# just linearized versions of the mesh IDs, we create those tables.
|
|
mesh_strides = pallas_utils.strides_from_shape(tuple(
|
|
mesh.shape[a] for a in axis_names
|
|
))
|
|
mesh_shape = tuple(mesh.shape.values())
|
|
self.mesh_info = MeshInfo(mesh_shape, axis_names, mesh_strides)
|
|
|
|
def maybe_compress_grid(self):
|
|
# If we have many leading parallel dimensions, we should "compress" them
|
|
# into one so we can load balance across cores as best as we can.
|
|
# TODO(sharadmv): implement this optimization
|
|
pass
|
|
|
|
@functools.cached_property
|
|
def has_communication(self) -> bool:
|
|
nonlocal_axis_names = set()
|
|
def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr):
|
|
return {
|
|
e.name
|
|
for e in jaxpr.effects
|
|
if isinstance(e, jax_core.NamedAxisEffect)
|
|
and (not self.grid_names or e.name not in self.grid_names)
|
|
}
|
|
nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr))
|
|
for bm in self.block_mappings:
|
|
if bm is not None:
|
|
nonlocal_axis_names.update(_get_nonlocal_axis_names(bm.index_map_jaxpr))
|
|
return bool(nonlocal_axis_names)
|
|
|
|
def get_extra_args(self) -> tuple[Any, ...]:
|
|
return ()
|
|
|
|
def get_dimension_semantics(self) -> ir.ArrayAttr:
|
|
|
|
def _get_semantics(s: str | None) -> str:
|
|
if s is None:
|
|
return "#tpu.dimension_semantics<arbitrary>"
|
|
return f"#tpu.dimension_semantics<{s}>"
|
|
|
|
return ir.ArrayAttr.get(
|
|
map(
|
|
ir.Attribute.parse,
|
|
map(_get_semantics, self._dimension_semantics),
|
|
)
|
|
)
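
# Worked example (illustrative only): with dimension_semantics=("parallel",
# "arbitrary") and one leading vmapped grid dimension, the canonicalized
# per-dimension semantics become ("parallel", "parallel", "arbitrary"), and
# get_dimension_semantics() emits
#   [#tpu.dimension_semantics<parallel>, #tpu.dimension_semantics<parallel>,
#    #tpu.dimension_semantics<arbitrary>]
# as the "dimension_semantics" attribute on the Mosaic kernel function.
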
|
|
|
|
@dataclasses.dataclass
|
|
class MeshInfo:
|
|
mesh_shape: tuple[int, ...]
|
|
axis_names: list[str]
|
|
mesh_strides: tuple[int, ...]
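
# A small sketch (illustrative only) of the mesh <-> logical-device-id
# translation that mesh_strides encodes: the logical ID is the row-major
# linearization of a device's mesh coordinates.
def _example_mesh_strides():
  mesh_shape = (2, 4)  # e.g. mesh axes ('x', 'y')
  strides = pallas_utils.strides_from_shape(mesh_shape)  # (4, 1)
  coords = (1, 2)  # the device at x=1, y=2
  logical_id = sum(c * s for c, s in zip(coords, strides))  # 1 * 4 + 2 = 6
  return strides, logical_id
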
|
|
|
|
|
|
def _check_block_mappings(
|
|
block_mappings: tuple[pallas_core.BlockMapping, ...],
|
|
lowering_context: mlir.LoweringRuleContext,
|
|
debug_info: jax_core.DebugInfo,
|
|
) -> None:
|
|
del lowering_context # originally needed for forward compat
|
|
for bm in block_mappings:
|
|
rank = len(bm.block_shape)
|
|
# TODO(necula): add tests for SMEM blocks with trivial windowing
|
|
# We support scalars too
|
|
if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and
|
|
bm.has_trivial_window()):
|
|
continue
|
|
if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE:
|
|
continue
|
|
|
|
def err_details():
|
|
return (f"Block spec for {bm.origin} in pallas_call {debug_info.func_src_info} "
|
|
"has block shape "
|
|
f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, "
|
|
# TODO(necula): add index_map source location info
|
|
f"and index_map {bm.index_map_jaxpr.jaxpr}, in "
|
|
f"memory space {bm.block_aval.memory_space}."
|
|
"\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec")
|
|
if rank < 1:
|
|
raise ValueError(
|
|
"The Pallas TPU lowering currently supports only blocks of "
|
|
"rank >= 1. " + err_details())
|
|
|
|
if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and
|
|
not bm.has_trivial_window()):
|
|
      raise ValueError(
          "For memory space ANY, the Pallas TPU lowering currently supports "
          "only blocks whose shape equals the overall array shape "
          "and whose index_map is trivial (returns all 0s). " + err_details())
|
|
|
|
unmapped_bs = [
|
|
1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
|
|
bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1]
|
|
if rank >= 2:
|
|
bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
|
|
else:
|
|
bs1, as1 = 1, 1
|
|
|
|
if rank >= 2:
|
|
evenly_divisible = (
|
|
(bs0 == as0 or bs0 % 128 == 0) and
|
|
(bs1 == as1 or bs1 % 8 == 0)
|
|
)
|
|
if not evenly_divisible:
|
|
extra_msg = ""
|
|
if pallas_core.dynamic_shapes_export_enabled():
|
|
extra_msg = (
|
|
" In dynamic shape export - your kernel symbolic args must be"
|
|
" annotated with constraints where the computation *after*"
|
|
" applying any grid mapping is divisible by 8 and 128"
|
|
" respectively. Ex: (mod(floordiv(m_dim, grid_size), 8) == 0))"
|
|
)
|
|
raise ValueError(
|
|
"The Pallas TPU lowering currently requires that the last two "
|
|
"dimensions of your block shape are divisible by 8 and 128 "
|
|
"respectively, or be equal to the respective dimensions of the "
|
|
"overall array. "
|
|
+ extra_msg
|
|
+ err_details()
|
|
)
|
|
else:
|
|
assert rank == 1
|
|
# bools get a bitwidth of 32 due to how mosaic handles them
|
|
if bm.array_shape_dtype.dtype == jnp.bool_:
|
|
bitwidth = 32
|
|
else:
|
|
bitwidth = lax_internal._bit_width(bm.array_shape_dtype.dtype)
|
|
packing = 32 // bitwidth
|
|
tiling_size = 128 * packing
|
|
evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0)
|
|
if not evenly_divisible:
|
|
        raise ValueError(
            "The Pallas TPU lowering currently requires that, for rank-1"
            " block shapes, either 1) the first (and only) dimension of the"
            " block shape is equal to the first (and only) dimension of the"
            " array shape, or 2) the first (and only) dimension of the block"
            f" shape is a multiple of the tiling size ({tiling_size} = 128 *"
            f" (32 // {lax_internal._bit_width(bm.array_shape_dtype.dtype)})). "
            + err_details()
        )
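
# Worked example (illustrative only) of the tiling arithmetic above: for
# rank >= 2 blocks the last two block dimensions must be multiples of (8, 128)
# unless they span the whole array, while for rank-1 blocks the 128-lane tile
# is scaled by the packing factor 32 // bitwidth.
def _example_rank1_tiling_size():
  bitwidth = lax_internal._bit_width(np.dtype('int8'))  # 8
  packing = 32 // bitwidth  # 4
  tiling_size = 128 * packing  # 512: a rank-1 int8 block must be a multiple
  return tiling_size           # of 512, or equal to the whole array length
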
|
|
|
|
|
|
def lower_jaxpr_to_module(
|
|
lowering_context: mlir.LoweringRuleContext,
|
|
ctx: ir.Context,
|
|
grid_mapping: pallas_core.GridMapping,
|
|
jaxpr: jax_core.Jaxpr,
|
|
*,
|
|
dimension_semantics: (
|
|
        tuple[str | tpu_core.GridDimensionSemantics, ...] | None
|
|
),
|
|
mesh: mesh_lib.Mesh | None = None,
|
|
for_verification: bool = False,
|
|
dynamic_shape_replacement_enabled: bool = False,
|
|
) -> tuple[Module, tuple[Any, ...]]:
|
|
# NOTE: We should bump this periodically
|
|
if is_cloud_tpu_older_than(2025, 1, 10):
|
|
raise RuntimeError(
|
|
"Pallas TPU requires a libTPU version that's at most a month old"
|
|
)
|
|
debug_info = jaxpr.debug_info
|
|
if dynamic_shape_replacement_enabled:
|
|
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
|
|
|
|
def dynamic_shape_replacement_fn(
|
|
shape: jax_core.Shape,
|
|
) -> tuple[int, ...]:
|
|
return tuple(
|
|
_mosaic_lowering_dynamic_shape_env.to_placeholder(dim_expr)
|
|
if jax_core.is_dim(dim_expr)
|
|
else dim_expr
|
|
for dim_expr in shape
|
|
)
|
|
|
|
else:
|
|
dynamic_shape_replacement_fn = lambda x: x
|
|
|
|
# Verify that we have legal block mappings to catch errors early.
|
|
_check_block_mappings(grid_mapping.block_mappings, lowering_context, debug_info)
|
|
|
|
mosaic_grid_mapping = MosaicGridMapping(
|
|
jaxpr,
|
|
grid_mapping,
|
|
dimension_semantics,
|
|
mesh,
|
|
dynamic_shape_replacement_fn,
|
|
)
|
|
mosaic_grid_mapping.maybe_compress_grid()
|
|
m = ir.Module.create()
|
|
attrs = m.operation.attributes
|
|
module_name = mlir.sanitize_name(debug_info.func_name)
|
|
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
|
sym_tab = ir.SymbolTable(m.operation)
|
|
|
|
func_op = lower_jaxpr_to_func(
|
|
ctx,
|
|
jaxpr,
|
|
mosaic_grid_mapping=mosaic_grid_mapping,
|
|
name="main",
|
|
for_verification=for_verification,
|
|
forward_compatible=lowering_context.is_forward_compat(),
|
|
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
|
|
)
|
|
m.body.append(func_op)
|
|
sym_tab.insert(func_op)
|
|
window_params = []
|
|
grid = mosaic_grid_mapping.grid
|
|
if grid:
|
|
for i, bm in enumerate(grid_mapping.block_mappings):
|
|
func_name = f"transform_{i}"
|
|
# ANY and SEMAPHORE operands don't support windowing and require empty window_params.
|
|
tpu_memory_space = _memory_space_to_tpu_memory_space(
|
|
bm.block_aval.memory_space)
|
|
if (
|
|
tpu_memory_space == tpu_core.TPUMemorySpace.ANY
|
|
or tpu_memory_space == tpu_core.TPUMemorySpace.SEMAPHORE
|
|
):
|
|
# We checked above that the block does not require windowing.
|
|
window_params.append(ir.DictAttr.get())
|
|
continue
|
|
|
|
mlir_func = lower_jaxpr_to_transform_func(
|
|
ctx,
|
|
bm.index_map_jaxpr.jaxpr,
|
|
bm.block_aval,
|
|
name=func_name,
|
|
mosaic_grid_mapping=mosaic_grid_mapping,
|
|
for_verification=for_verification,
|
|
forward_compatible=lowering_context.is_forward_compat(),
|
|
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
|
|
)
|
|
assert mlir_func.verify(), mlir_func
|
|
block_shape = [
|
|
1 if b is pallas_core.mapped else b for b in bm.block_shape
|
|
]
|
|
# If we have an extended dtype, we need to add the block shape for the
|
|
# remaining physical dtype.
|
|
block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval))
|
|
block_shape = dynamic_shape_replacement_fn(block_shape)
|
|
window_shape = ir.DenseI64ArrayAttr.get(block_shape)
|
|
block_params = dict(
|
|
window_bounds=window_shape,
|
|
transform_indices=ir.FlatSymbolRefAttr.get(func_name),
|
|
)
|
|
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
|
|
if bm.indexing_mode.padding is None:
|
|
pad_low = pad_high = [0] * len(bm.block_shape)
|
|
else:
|
|
pad_low, pad_high = map(list, zip(*bm.indexing_mode.padding))
|
|
block_params["window_kind"] = ir.Attribute.parse(
|
|
f"#tpu.element_window<{pad_low},{pad_high}>"
|
|
)
|
|
if bm.pipeline_mode is not None:
|
|
if not isinstance(bm.pipeline_mode, pallas_core.Buffered):
|
|
raise LoweringException(
|
|
f"Unsupported pipeline mode: {bm.pipeline_mode}."
|
|
)
|
|
buffer_count = bm.pipeline_mode.buffer_count
|
|
if buffer_count < 1 or buffer_count > 2:
|
|
raise LoweringException(
|
|
"Only single (1) and double (2) buffering are supported. Got"
|
|
f" {buffer_count}."
|
|
)
|
|
pipeline_mode = "synchronous" if buffer_count == 1 else "double_buffered"
|
|
block_params["pipeline_mode"] = ir.Attribute.parse(
|
|
f"#tpu.pipeline_mode<{pipeline_mode}>"
|
|
)
|
|
window_params.append(ir.DictAttr.get(block_params))
|
|
m.body.append(mlir_func)
|
|
sym_tab.insert(mlir_func)
|
|
func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
|
|
|
|
static_grid = [
|
|
MLIR_DYNAMIC if b is pallas_core.dynamic_grid_dim else b for b in grid
|
|
]
|
|
static_grid = dynamic_shape_replacement_fn(static_grid)
|
|
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
|
|
|
|
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
|
|
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
|
|
func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
|
|
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scratch_types))
|
|
func_op.attributes["dimension_semantics"] = (
|
|
mosaic_grid_mapping.get_dimension_semantics()
|
|
)
|
|
return m, mosaic_grid_mapping.get_extra_args()
|
|
|
|
|
|
def lower_jaxpr_to_transform_func(
|
|
ctx: ir.Context,
|
|
jaxpr: jax_core.Jaxpr,
|
|
aval: jax_core.AbstractValue,
|
|
*,
|
|
name: str,
|
|
mosaic_grid_mapping: MosaicGridMapping,
|
|
for_verification: bool,
|
|
forward_compatible: bool,
|
|
dynamic_shape_replacement_fn: (
|
|
Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
|
|
) = None,
|
|
) -> func.FuncOp:
|
|
num_grid = len(mosaic_grid_mapping.grid_types)
|
|
arg_types = [
|
|
*mosaic_grid_mapping.grid_types,
|
|
*mosaic_grid_mapping.scalar_prefetch_types,
|
|
]
|
|
def body_func(*args):
|
|
grid_indices, scalar_prefetch = split_list(args, [num_grid])
|
|
jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
|
|
grid_indices, maybe_include_mapped_dims=True
|
|
)
|
|
arg_block_shapes = [
|
|
*[()] * len(jaxpr_indices),
|
|
*mosaic_grid_mapping.scalar_prefetch_block_shapes,
|
|
]
|
|
|
|
mesh_info = mosaic_grid_mapping.mesh_info
|
|
if mesh_info is not None:
|
|
mesh_context = MeshContext(
|
|
mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
|
|
)
|
|
else:
|
|
mesh_context = None
|
|
lowering_context = LoweringContext(
|
|
ctx,
|
|
mosaic_grid_mapping.grid,
|
|
mosaic_grid_mapping.grid_names,
|
|
mosaic_grid_mapping.mapped_dims,
|
|
None,
|
|
arg_block_shapes,
|
|
source_info_util.NameStack(),
|
|
mesh_context=mesh_context,
|
|
traceback_caches=mlir.TracebackCaches(),
|
|
for_verification=for_verification,
|
|
forward_compatible=forward_compatible,
|
|
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
|
|
)
|
|
out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
|
|
*scalar_prefetch)
|
|
assert isinstance(aval, state.AbstractRef), aval
|
|
# If we have an extended dtype, we need to add 0s for the block indices
|
|
# for the remaining physical dtype.
|
|
out += [
|
|
ir_constant(0, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
|
|
] * len(_get_aval_physical_dtype_shape(aval.inner_aval))
|
|
return out
|
|
|
|
body_func.__name__ = name
|
|
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
|
|
try:
|
|
body.func_op.verify()
|
|
except ir.MLIRError as e:
|
|
raise error_handling.mlir_error_to_verification_error(e) from e
|
|
return body.func_op
|
|
|
|
|
|
def lower_jaxpr_to_func(
|
|
ctx: ir.Context,
|
|
jaxpr: jax_core.Jaxpr,
|
|
*,
|
|
mosaic_grid_mapping: MosaicGridMapping,
|
|
name: str,
|
|
for_verification: bool,
|
|
forward_compatible: bool,
|
|
dynamic_shape_replacement_fn: (
|
|
Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
|
|
) = None,
|
|
) -> func.FuncOp:
|
|
num_grid = len(mosaic_grid_mapping.grid_types)
|
|
num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
|
|
arg_types = [
|
|
*mosaic_grid_mapping.grid_types,
|
|
*mosaic_grid_mapping.scalar_prefetch_types,
|
|
*mosaic_grid_mapping.operand_types,
|
|
*mosaic_grid_mapping.scratch_types,
|
|
]
|
|
arg_block_shapes = [
|
|
*mosaic_grid_mapping.scalar_prefetch_block_shapes,
|
|
*mosaic_grid_mapping.operand_block_shapes,
|
|
*mosaic_grid_mapping.scratch_block_shapes,
|
|
]
|
|
def body_func(*args):
|
|
grid_indices, scalar_prefetch, operands_and_scratch = split_list(
|
|
args, [num_grid, num_scalar_prefetch])
|
|
jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
|
|
grid_indices, maybe_include_mapped_dims=False
|
|
)
|
|
mesh_info = mosaic_grid_mapping.mesh_info
|
|
if mesh_info is not None:
|
|
mesh_context = MeshContext(
|
|
mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
|
|
)
|
|
else:
|
|
mesh_context = None
|
|
lowering_context = LoweringContext(
|
|
ctx,
|
|
mosaic_grid_mapping.grid,
|
|
mosaic_grid_mapping.grid_names,
|
|
mosaic_grid_mapping.mapped_dims,
|
|
jaxpr_indices,
|
|
arg_block_shapes,
|
|
source_info_util.NameStack(),
|
|
mesh_context=mesh_context,
|
|
traceback_caches=mlir.TracebackCaches(),
|
|
for_verification=for_verification,
|
|
forward_compatible=forward_compatible,
|
|
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
|
|
)
|
|
return jaxpr_subcomp(
|
|
lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch
|
|
)
|
|
body_func.__name__ = name
|
|
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
|
|
try:
|
|
body.func_op.verify()
|
|
except ir.MLIRError as e:
|
|
raise error_handling.mlir_error_to_verification_error(e) from e
|
|
return body.func_op
|
|
|
|
|
|
def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
|
|
def f_lowered(ctx: LoweringRuleContext, *args, **params):
|
|
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
|
wrapped_fun = lu.wrap_init(
|
|
f, params,
|
|
debug_info=api_util.debug_info("mosaic lower_fun", f,
|
|
args, params))
|
|
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
|
if consts:
|
|
raise NotImplementedError
|
|
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
|
|
lowering_context = ctx.lowering_context.replace(
|
|
block_shapes=ctx.block_shapes)
|
|
out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
|
|
if not multiple_results:
|
|
return out[0]
|
|
return out
|
|
|
|
return f_lowered
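
# A minimal sketch (illustrative; the primitive and implementation names are
# hypothetical) of how lower_fun is used: a primitive whose semantics can be
# expressed with existing JAX ops is lowered by tracing a pure-JAX
# implementation and recursively lowering the resulting jaxpr.
#
#   def _my_softplus_impl(x):
#     return jnp.logaddexp(x, 0.0)
#
#   lowering_rules[my_softplus_p] = lower_fun(
#       _my_softplus_impl, multiple_results=False)
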
|
|
|
|
|
|
class LoweringException(Exception):
|
|
pass
|
|
|
|
|
|
def _compute_name_stack_updates(
|
|
old_name_stack: list[str],
|
|
new_name_stack: list[str]
|
|
) -> tuple[list[str], list[str]]:
|
|
"""Computes the popped/pushed items to the name stack after an update.
|
|
|
|
Args:
|
|
old_name_stack: The name stack prior to the update.
|
|
new_name_stack: The name stack after the update.
|
|
|
|
Returns:
|
|
popped: A list of names popped from the name stack as part of the update.
|
|
pushed: A list of names pushed to the name stack as part of the update.
|
|
"""
|
|
common_prefix_idx = 0
|
|
for i, (old, new) in enumerate(unsafe_zip(old_name_stack, new_name_stack)):
|
|
if old == new:
|
|
common_prefix_idx = i+1
|
|
else:
|
|
break
|
|
return old_name_stack[common_prefix_idx:], new_name_stack[common_prefix_idx:]
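
# Illustrative only:
#   _compute_name_stack_updates(["outer", "a"], ["outer", "b"])
# returns (["a"], ["b"]): "a" is popped and "b" is pushed while the shared
# "outer" prefix is left alone, which is what keeps the emitted
# trace_start/trace_stop ops balanced across equations.
def _example_name_stack_updates():
  popped, pushed = _compute_name_stack_updates(["outer", "a"], ["outer", "b"])
  assert (popped, pushed) == (["a"], ["b"])
  return popped, pushed
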
|
|
|
|
|
|
def jaxpr_subcomp(
|
|
ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
|
|
) -> Sequence[ir.Value]:
|
|
assert not jaxpr.constvars
|
|
env = {}
|
|
block_shape_env = {}
|
|
|
|
def read_block_shape(atom: jax_core.Atom):
|
|
if isinstance(atom, jax_core.Literal):
|
|
return None
|
|
return block_shape_env.get(atom, None)
|
|
|
|
def read_env(atom: jax_core.Atom):
|
|
return atom.val if isinstance(atom, jax_core.Literal) else env[atom]
|
|
|
|
def write_env(var: jax_core.Var, val):
|
|
is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle))
|
|
assert is_valid_type, type(val)
|
|
env[var] = val
|
|
|
|
for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
|
|
block_shape_env[invar] = bs
|
|
map(write_env, jaxpr.invars, args)
|
|
|
|
initial_name_stack = [scope.name for scope in ctx.name_stack.stack]
|
|
current_name_stack: list[str] = []
|
|
# TODO(justinfu): Handle transform scopes.
|
|
current_name_stack.extend(initial_name_stack)
|
|
for eqn in jaxpr.eqns:
|
|
invals = map(read_env, eqn.invars)
|
|
source_info = eqn.source_info.replace(
|
|
name_stack=ctx.name_stack + eqn.source_info.name_stack
|
|
)
|
|
loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info)
|
|
with (source_info_util.user_context(eqn.source_info.traceback), loc,
|
|
eqn.ctx.manager):
|
|
if eqn.primitive in lowering_rules:
|
|
if eqn.primitive not in skip_mlir_conversions:
|
|
invals = [_ensure_mlir_value(x, v.aval)
|
|
for x, v in zip(invals, eqn.invars)]
|
|
block_shapes = map(read_block_shape, eqn.invars)
|
|
rule_context = LoweringRuleContext(
|
|
ctx,
|
|
[v.aval for v in eqn.invars],
|
|
[v.aval for v in eqn.outvars],
|
|
block_shapes,
|
|
)
|
|
|
|
# Insert trace_start and trace_stop ops on named_scope boundaries.
|
|
name_stack = [scope.name for scope in source_info.name_stack.stack]
|
|
popped, pushed = _compute_name_stack_updates(
|
|
current_name_stack, name_stack)
|
|
current_name_stack = name_stack
|
|
for _ in popped:
|
|
tpu.TraceStopOp()
|
|
for name in pushed:
|
|
tpu.TraceStartOp(message=name, level=10)
|
|
|
|
try:
|
|
ans = lowering_rules[eqn.primitive](
|
|
rule_context, *invals, **eqn.params
|
|
)
|
|
except LoweringException:
|
|
raise # We only add the extra info to the innermost exception.
|
|
except Exception as e:
|
|
if not pallas_call._verbose_errors_enabled():
|
|
raise
|
|
msg = (f"{type(e).__name__}: {e}\n" +
|
|
"Additional diagnostics: \n" +
|
|
f"Failing jaxpr equation: {eqn}\n")
|
|
new_error = LoweringException(msg)
|
|
# We insert the traceback here so that the user code shows
|
|
# up in the traceback for the post-transform error.
|
|
if source_info.traceback is not None:
|
|
tb = source_info.traceback.as_python_traceback()
|
|
new_error.__traceback__ = traceback_util.filter_traceback(tb)
|
|
raise new_error from e
|
|
else:
|
|
raise NotImplementedError(
|
|
"Unimplemented primitive in Pallas TPU lowering: "
|
|
f"{eqn.primitive.name}. "
|
|
"Please file an issue on https://github.com/jax-ml/jax/issues.")
|
|
if eqn.primitive.multiple_results:
|
|
map(write_env, eqn.outvars, ans)
|
|
else:
|
|
write_env(eqn.outvars[0], ans)
|
|
|
|
# Drain the name stack at the end of a jaxpr and insert trace_stop ops.
|
|
popped, pushed = _compute_name_stack_updates(
|
|
current_name_stack, initial_name_stack)
|
|
for _ in popped:
|
|
tpu.TraceStopOp()
|
|
assert len(pushed) == 0
|
|
|
|
outvals = map(read_env, jaxpr.outvars)
|
|
outvals = [
|
|
ir_constant(x) if isinstance(var, jax_core.Literal) else x
|
|
for x, var in zip(outvals, jaxpr.outvars)
|
|
]
|
|
return outvals
|
|
|
|
|
|
def _ensure_mlir_value(val, aval):
|
|
if isinstance(val, ir.Value):
|
|
return val
|
|
if isinstance(val, KeyScalarBundle):
|
|
return val
|
|
elif isinstance(val, (np.generic, np.ndarray, int, float)):
|
|
return ir_constant(val, _dtype_to_ir_type(aval.dtype))
|
|
else:
|
|
raise RuntimeError(
|
|
f"Unsupported argument to a JAX primitive of type: {type(val)}"
|
|
)
|
|
|
|
|
|
def _get_lowering_rule(
|
|
ctx: LoweringRuleContext, ref, *idx, tree,
|
|
):
|
|
indexers = tree_util.tree_unflatten(tree, idx)
|
|
indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
|
|
# Call _load_lowering_rule (since it's more general)
|
|
ref_aval, *_ = ctx.avals_in
|
|
args_flat, args_tree = tree_util.tree_flatten((ref, indexers, None, None))
|
|
avals_flat = tree_util.tree_leaves((ref_aval, indexers_avals, None, None))
|
|
ctx = ctx.replace(
|
|
avals_in=avals_flat,
|
|
block_shapes=[ctx.block_shapes[0], *[None] * (len(avals_flat) - 1)],
|
|
)
|
|
return _load_lowering_rule(ctx, *args_flat, args_tree=args_tree)
|
|
|
|
|
|
lowering_rules[state_primitives.get_p] = _get_lowering_rule
|
|
skip_mlir_conversions.add(state_primitives.get_p)
|
|
|
|
|
|
def _swap_lowering_rule(
|
|
ctx: LoweringRuleContext,
|
|
ref,
|
|
val,
|
|
*idx,
|
|
tree
|
|
):
|
|
indexers = tree_util.tree_unflatten(tree, idx)
|
|
indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:])
|
|
# Call _masked_swap_lowering_rule (since it's more general)
|
|
ref_aval, val_aval, *_ = ctx.avals_in
|
|
args_flat, args_tree = tree_util.tree_flatten((ref, indexers, val, None))
|
|
avals_flat = tree_util.tree_leaves(
|
|
(ref_aval, indexers_avals, val_aval, None)
|
|
)
|
|
ctx = ctx.replace(
|
|
avals_in=avals_flat,
|
|
block_shapes=[ctx.block_shapes[0], *[None] * (len(avals_flat) - 1)],
|
|
)
|
|
return _masked_swap_lowering_rule(ctx, *args_flat, args_tree=args_tree)
|
|
|
|
lowering_rules[state_primitives.swap_p] = _swap_lowering_rule
|
|
skip_mlir_conversions.add(state_primitives.swap_p)
|
|
|
|
|
|
def _make_index(s):
|
|
if isinstance(s, (int, np.ndarray)):
|
|
return ir_constant(s, ir.IndexType.get())
|
|
if s.type == ir.IndexType.get():
|
|
return s
|
|
return arith.index_cast(ir.IndexType.get(), s)
|
|
|
|
|
|
def _maybe_cast_to_index(cast_to_index, x):
|
|
if cast_to_index:
|
|
return _make_index(x)
|
|
return _ensure_mlir_value(x, aval=pallas_core.index_map_grid_aval)
|
|
|
|
|
|
def _index_to_start_size_stride(
|
|
idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
|
|
) -> tuple[ir.Value, int | ir.Value, int, bool]:
|
|
assert not isinstance(idx, slice)
|
|
if isinstance(idx, indexing.Slice):
|
|
start = _maybe_cast_to_index(cast_to_index, idx.start)
|
|
size = idx.size
|
|
stride = idx.stride
|
|
squeeze = False
|
|
elif isinstance(idx, int):
|
|
start = _maybe_cast_to_index(cast_to_index, idx)
|
|
size = 1
|
|
stride = 1
|
|
squeeze = True
|
|
else:
|
|
if np.shape(idx):
|
|
raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
|
|
start = _maybe_cast_to_index(cast_to_index, idx)
|
|
size = 1
|
|
stride = 1
|
|
squeeze = True
|
|
return start, size, stride, squeeze
|
|
|
|
|
|
def _indexer_to_start_size_stride(
|
|
indexer: NDIndexer,
|
|
ref_block_shape: tuple[int | pallas_core.Mapped, ...],
|
|
*,
|
|
cast_to_index: bool,
|
|
) -> tuple[
|
|
tuple[ir.Value, ...],
|
|
tuple[int | ir.Value, ...],
|
|
tuple[int, ...],
|
|
tuple[bool, ...],
|
|
tuple[int | pallas_core.Mapped, ...],
|
|
]:
|
|
indices_iter = iter(indexer.indices)
|
|
starts, sizes, strides, squeeze_dims = [], [], [], []
|
|
for s in ref_block_shape:
|
|
start, size, stride, squeeze_dim = (
|
|
(
|
|
_maybe_cast_to_index(cast_to_index, 0),
|
|
1,
|
|
1,
|
|
True,
|
|
)
|
|
if s is pallas_core.mapped
|
|
else _index_to_start_size_stride(next(indices_iter), cast_to_index)
|
|
)
|
|
starts.append(start)
|
|
sizes.append(size)
|
|
strides.append(stride)
|
|
squeeze_dims.append(squeeze_dim)
|
|
next_index = next(indices_iter, None)
|
|
assert next_index is None, (indexer.indices, ref_block_shape)
|
|
new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
|
|
if not squeeze)
|
|
return (
|
|
tuple(starts),
|
|
tuple(sizes),
|
|
tuple(strides),
|
|
tuple(squeeze_dims),
|
|
new_ref_block_shape,
|
|
)
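
# Worked example (illustrative only): for a ref block shape of
# (pallas_core.mapped, 8, 128) and an indexer with indices
# (Slice(0, 8), Slice(32, 64)), the mapped leading dimension contributes
# (start=0, size=1, squeeze=True) without consuming an index, and the two
# slices contribute (start=0, size=8) and (start=32, size=64) with stride 1,
# so the result is starts=(0, 0, 32), sizes=(1, 8, 64),
# squeeze_dims=(True, False, False) and a new ref block shape of (8, 64).
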
|
|
|
|
|
|
def _slice_memref(
|
|
ref: ir.Value,
|
|
indexer: NDIndexer,
|
|
ref_dtype: DTypeLike,
|
|
ref_block_shape: tuple[int | pallas_core.Mapped, ...],
|
|
) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]:
|
|
assert ref_block_shape is not None
|
|
target_shape = indexer.get_indexer_shape()
|
|
starts, sizes, strides, squeeze_dims, ref_block_shape = (
|
|
_indexer_to_start_size_stride(
|
|
indexer,
|
|
ref_block_shape,
|
|
cast_to_index=False,
|
|
)
|
|
)
|
|
if not all((s is None or s == 1) for s in strides):
|
|
raise NotImplementedError("Strided slices of references are unsupported.")
|
|
dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value))
|
|
ir_dynamic_size = ir.ShapedType.get_dynamic_size()
|
|
static_sizes = tuple(s if not isinstance(s, ir.Value)
|
|
else ir_dynamic_size for s in sizes)
|
|
target_ref_ty = ir.MemRefType.get(
|
|
static_sizes,
|
|
_dtype_to_ir_type(ref_dtype),
|
|
memory_space=ref.type.memory_space,
|
|
)
|
|
out = tpu.memref_slice(target_ref_ty, ref, starts, dynamic_sizes)
|
|
if any(squeeze_dims):
|
|
# We need to squeeze out some dimensions
|
|
static_sizes = tuple(s if not isinstance(s, ir.Value)
|
|
else ir_dynamic_size for s in target_shape)
|
|
squeezed_ref_ty = ir.MemRefType.get(
|
|
static_sizes,
|
|
_dtype_to_ir_type(ref_dtype),
|
|
memory_space=ref.type.memory_space,
|
|
)
|
|
out = tpu.memref_squeeze(squeezed_ref_ty, out)
|
|
return out, ref_block_shape
|
|
|
|
|
|
def _bitcast_memref(
|
|
ref: ir.Value,
|
|
bitcaster: RefBitcaster,
|
|
ref_dtype: DTypeLike,
|
|
ref_block_shape: tuple[int | pallas_core.Mapped, ...],
|
|
) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]:
|
|
src_bitwidth = dtype_bitwidth(ref_dtype)
|
|
dst_bitwidth = dtype_bitwidth(bitcaster.dtype)
|
|
if src_bitwidth != dst_bitwidth:
|
|
if len(ref_block_shape) < 2:
|
|
raise NotImplementedError(
|
|
"Bitcast 1D ref with bitwidth change is not supported."
|
|
)
|
|
if ref_block_shape[-2] is pallas_core.mapped:
|
|
      raise NotImplementedError(
          "Bitcasting a ref whose second-minormost dimension is squeezed is"
          " not supported when the bitwidth changes."
      )
|
|
new_ref_dtype = bitcaster.dtype
|
|
target_ref_ty = ir.MemRefType.get(
|
|
bitcaster.shape,
|
|
_dtype_to_ir_type(new_ref_dtype),
|
|
memory_space=ref.type.memory_space,
|
|
)
|
|
new_ref_block_shape = list(ref_block_shape)
|
|
if (
|
|
len(new_ref_block_shape) >= 2
|
|
and new_ref_block_shape[-2] is not pallas_core.mapped
|
|
):
|
|
new_ref_block_shape[-2] = (
|
|
new_ref_block_shape[-2] * src_bitwidth // dst_bitwidth
|
|
)
|
|
return (
|
|
tpu.memref_bitcast(target_ref_ty, ref),
|
|
new_ref_dtype,
|
|
tuple(new_ref_block_shape),
|
|
)
|
|
|
|
|
|
def _reshape_memref(
|
|
ref: ir.Value,
|
|
reshaper: RefReshaper,
|
|
ref_dtype: DTypeLike,
|
|
ref_block_shape: tuple[int | pallas_core.Mapped, ...],
|
|
) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]:
|
|
if ref_dtype != reshaper.dtype:
|
|
raise ValueError(
|
|
f"Reshape a ref with dtype change: {reshaper.dtype} vs {ref_dtype}"
|
|
)
|
|
if len(ref_block_shape) < 2:
|
|
raise NotImplementedError("Reshape 1D ref is not supported.")
|
|
if (
|
|
ref_block_shape[-2] is pallas_core.mapped
|
|
or ref_block_shape[-1] is pallas_core.mapped
|
|
):
|
|
raise NotImplementedError(
|
|
"Reshape a ref with squeezed dimension on last two dimensions."
|
|
)
|
|
if np.prod(ref_block_shape) != np.prod(reshaper.shape):
|
|
raise ValueError(
|
|
f"Reshape a ref with different number of elements: {ref_block_shape} "
|
|
f"vs {reshaper.shape}"
|
|
)
|
|
target_ref_ty = ir.MemRefType.get(
|
|
reshaper.shape,
|
|
_dtype_to_ir_type(reshaper.dtype),
|
|
memory_space=ref.type.memory_space,
|
|
)
|
|
return (
|
|
tpu.memref_reshape(target_ref_ty, ref),
|
|
reshaper.shape,
|
|
)
|
|
|
|
|
|
def _transform_ref(ref, ref_dtype, ref_block_shape, transforms):
|
|
for transform in transforms:
|
|
match transform:
|
|
case NDIndexer():
|
|
ref, ref_block_shape = _slice_memref(
|
|
ref, transform, ref_dtype, ref_block_shape
|
|
)
|
|
case RefBitcaster():
|
|
ref, ref_dtype, ref_block_shape = _bitcast_memref(
|
|
ref, transform, ref_dtype, ref_block_shape
|
|
)
|
|
case RefReshaper():
|
|
ref, ref_block_shape = _reshape_memref(
|
|
ref, transform, ref_dtype, ref_block_shape
|
|
)
|
|
case _:
|
|
raise NotImplementedError(f"Unsupported transform: {transform}")
|
|
return ref, ref_block_shape
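
# Illustrative only: transforms compose left to right. For example, an int32
# ref with block shape (8, 256) that carries a RefBitcaster to int16 followed
# by an NDIndexer is first bitcast (the second-minormost dimension scales by
# 32 // 16, giving a (16, 256) int16 ref) and then sliced, with the block
# shape threaded through both steps.
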
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class KeyScalarBundle:
|
|
"""A container class for PRNG key data.
|
|
|
|
We pass around keys as a KeyScalarBundle in the lowering pass rather than
|
|
as a vector, since we want the key data to live in scalar registers rather
|
|
than vector registers. This special dataclass exists so we can return
|
|
multiple scalar values from load_op, because the load_op primitive does
|
|
not allow multiple results.
|
|
|
|
Attributes:
|
|
scalars: A list of OpResults representing scalar key data during the
|
|
lowering pass.
|
|
"""
|
|
key_shape: tuple[int, ...]
|
|
scalars: list[ir.OpResult]
|
|
|
|
def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
|
|
ref, transforms, mask, _ = args_tree.unflatten(args_flat)
|
|
ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in)
|
|
(*prev_transforms, idx) = transforms
|
|
# Select last aval, which is the one that will be used for the load.
|
|
(*_, idx_aval) = transforms_avals
|
|
|
|
if mask is not None:
|
|
raise NotImplementedError
|
|
|
|
ref_block_shape, *_ = ctx.block_shapes
|
|
ref, ref_block_shape = _transform_ref(
|
|
ref, ref_aval.dtype, ref_block_shape, prev_transforms
|
|
)
|
|
ref_type = ir.MemRefType(ref.type)
|
|
is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
|
|
(aval_out,) = ctx.avals_out
|
|
if isinstance(aval_out.dtype, prng.KeyTy) and pl_random.is_pallas_impl(
|
|
aval_out.dtype._impl
|
|
):
|
|
if not is_smem_load:
|
|
raise ValueError("PRNG keys must be loaded from SMEM. Did you set "
|
|
"the memory space to TPUMemorySpace.SMEM in the "
|
|
"BlockSpec for the PRNG key input?")
|
|
return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree)
|
|
if not is_smem_load and not ref_block_shape:
|
|
raise NotImplementedError(
|
|
"Indexing into a ()-shaped Ref not yet supported on TPU.")
|
|
if any(
|
|
(not isinstance(a, primitives.Slice) and a.shape)
|
|
for a in idx_aval.indices
|
|
):
|
|
raise ValueError("Cannot do int indexing on TPU")
|
|
starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
|
|
idx,
|
|
ref_block_shape,
|
|
cast_to_index=True,
|
|
)
|
|
need_stride = not all((s is None or s == 1) for s in strides)
|
|
if is_smem_load:
|
|
if ctx.avals_out[0].shape:
|
|
raise ValueError("Can only load scalars from SMEM")
|
|
return _maybe_cast_load_to_bool(ctx, aval_out, memref.load(ref, starts))
|
|
elif str(ref_type.memory_space) != "#tpu.memory_space<vmem>":
|
|
extra = ""
|
|
if str(ref_type.memory_space) == "#tpu.memory_space<any>":
|
|
extra = " ANY memory space can only be accessed using async_copy."
|
|
raise ValueError(
|
|
"Loads are only allowed on VMEM and SMEM references." + extra
|
|
)
|
|
load_aval = jax_core.ShapedArray(sizes, dtype=aval_out.dtype)
|
|
if need_stride:
|
|
load_val = tpu.strided_load(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
load_aval,
|
|
is_kernel_boundary=True,
|
|
),
|
|
ref,
|
|
starts,
|
|
strides,
|
|
)
|
|
else:
|
|
load_val = vector.load(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
load_aval,
|
|
is_kernel_boundary=True,
|
|
),
|
|
ref,
|
|
starts,
|
|
)
|
|
if load_aval != aval_out:
|
|
vec_type = ir.VectorType.get(aval_out.shape,
|
|
_dtype_to_ir_type(aval_out.dtype,
|
|
is_kernel_boundary=True))
|
|
load_val = vector.shape_cast(vec_type, load_val)
|
|
return _maybe_cast_load_to_bool(ctx, aval_out, load_val)
|
|
|
|
def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle:
|
|
"""Lowering rule for loading PRNG keys from SMEM.
|
|
|
|
PRNG key loads are currently lowered as a list of scalar loads from SMEM,
|
|
rather than a single vector load.
|
|
We store these scalars in a bundle type called KeyScalarBundle, which has
|
|
special case handling for functions that consume the key such as set_seed.
|
|
"""
|
|
ref, _, _, _ = args_tree.unflatten(args_flat)
|
|
(aval_out,) = ctx.avals_out
|
|
assert isinstance(aval_out.dtype, prng.KeyTy)
|
|
ref_block_shape = aval_out.dtype._impl.key_shape
|
|
|
|
if len(ref_block_shape) != 2:
|
|
raise NotImplementedError("Seed key_data must be 2D.")
|
|
if tuple(ref_block_shape) != (1, 1):
|
|
raise NotImplementedError(
|
|
f"Seed key_data of shape != (1, 1) not supported. Got: {ref_block_shape}")
|
|
|
|
load_ops = []
|
|
for i in range(ref_block_shape[0]):
|
|
idx = NDIndexer(indices=(0, i), shape=ref_block_shape,
|
|
int_indexer_shape=tuple())
|
|
starts, _, _, _, _ = _indexer_to_start_size_stride(
|
|
idx,
|
|
ref_block_shape,
|
|
cast_to_index=True,
|
|
)
|
|
load_ops.append(memref.load(ref, starts))
|
|
return KeyScalarBundle(scalars=load_ops, key_shape=tuple(ref_block_shape))
|
|
|
|
|
|
lowering_rules[primitives.load_p] = _load_lowering_rule
|
|
skip_mlir_conversions.add(primitives.load_p)
|
|
|
|
|
|
def _maybe_cast_load_to_bool(
|
|
ctx, out_aval, val: ir.Value
|
|
) -> tuple[ir.Value, jnp.dtype]:
|
|
"""Casts a memref load value to bool if the requested value is a bool.
|
|
|
|
Mosaic does not support boolean-type memrefs, since booleans
|
|
typically live in mask registers. We instead load booleans as integers from
|
|
memrefs and move them to mask registers on load using this function.
|
|
|
|
Args:
|
|
out_aval: The output aval of the load.
|
|
val: The input value.
|
|
|
|
Returns:
|
|
The loaded value, and the JAX dtype of the input value.
|
|
"""
|
|
if out_aval.dtype != jnp.bool_:
|
|
return val
|
|
load_scalar_type = _dtype_to_ir_type(BOOL_MEMREF_TYPE)
|
|
pred = _cmpsi_lowering_types[lax.ne_p]
|
|
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
|
|
const_zero = ir.IntegerAttr.get(load_scalar_type, 0)
|
|
if out_aval.shape: # Vector case.
|
|
load_vector_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
out_aval,
|
|
is_kernel_boundary=True,
|
|
)
|
|
vector_zeros = arith.ConstantOp(
|
|
load_vector_type,
|
|
ir.DenseElementsAttr.get_splat(load_vector_type, const_zero)
|
|
)
|
|
return arith.cmpi(predicate, val, vector_zeros)
|
|
else: # Scalar case.
|
|
const_zero = arith.ConstantOp(load_scalar_type, const_zero)
|
|
return arith.cmpi(predicate, val, const_zero)
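
# Illustrative only: booleans live in memrefs as int32 (BOOL_MEMREF_TYPE) and
# are turned back into a mask by comparing against zero, i.e. the same
# transformation as this pure-JAX expression.
def _example_bool_roundtrip():
  stored = jnp.array([0, 1, 1, 0], dtype=jnp.int32)  # how bools sit in memory
  recovered = stored != 0  # what the arith.cmpi above computes
  return recovered         # [False, True, True, False]
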
|
|
|
|
|
|
def _maybe_cast_store_to_memref_type(
|
|
ctx: LoweringRuleContext, expected_aval, val: ir.Value
|
|
) -> ir.Value:
|
|
"""Casts a boolean value back to an integer for storing in a memref."""
|
|
if expected_aval.dtype != jnp.bool_:
|
|
return val
|
|
int_out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
expected_aval,
|
|
is_kernel_boundary=True,
|
|
)
|
|
return arith.extui(int_out_type, val)
|
|
|
|
|
|
def _masked_swap_lowering_rule(
|
|
ctx: LoweringRuleContext, *args_flat, args_tree, **_
|
|
):
|
|
ref, transforms, val, mask = args_tree.unflatten(args_flat)
|
|
ref_aval, transforms_avals, val_aval, mask_aval = args_tree.unflatten(
|
|
ctx.avals_in
|
|
)
|
|
(*prev_transforms, idx) = transforms
|
|
(*_, idx_aval) = transforms_avals
|
|
|
|
if mask is not None:
|
|
if val_aval.dtype.itemsize != 4:
|
|
raise NotImplementedError("masked swap with non-32-bit data")
|
|
if val_aval.shape != mask_aval.shape:
|
|
raise ValueError(
|
|
"Expected value and mask to have the same shape, but got"
|
|
f" value shape {val_aval.shape} vs. mask shape {mask_aval.shape}."
|
|
)
|
|
|
|
ref_block_shape, *_ = ctx.block_shapes
|
|
ref, ref_block_shape = _transform_ref(
|
|
ref, ref_aval.dtype, ref_block_shape, prev_transforms
|
|
)
|
|
|
|
ref_type = ir.MemRefType(ref.type)
|
|
memory_space = str(ref_type.memory_space)
|
|
is_smem_store = memory_space == "#tpu.memory_space<smem>"
|
|
is_vmem_store = memory_space == "#tpu.memory_space<vmem>"
|
|
(aval_out,) = ctx.avals_out
|
|
if not isinstance(val, ir.Value):
|
|
val = ir_constant(val, mlir_type=_dtype_to_ir_type(val_aval.dtype))
|
|
if any(
|
|
(not isinstance(a, primitives.Slice) and a.shape)
|
|
for a in idx_aval.indices
|
|
):
|
|
raise ValueError("Cannot do int indexing on TPU")
|
|
if not is_smem_store and not ref_block_shape:
|
|
raise NotImplementedError(
|
|
"Indexing into a ()-shaped Ref not yet supported on TPU.")
|
|
|
|
starts, _, strides, _, _ = _indexer_to_start_size_stride(
|
|
idx,
|
|
ref_block_shape,
|
|
cast_to_index=True,
|
|
)
|
|
need_stride = not all((s is None or s == 1) for s in strides)
|
|
|
|
if is_smem_store:
|
|
if mask is not None:
|
|
raise ValueError("SMEM store does not support masks")
|
|
if val_aval.shape:
|
|
raise ValueError("Can only store scalars to SMEM")
|
|
result = memref.load(ref, starts)
|
|
result = _maybe_cast_load_to_bool(ctx, val_aval, result)
|
|
val = _maybe_cast_store_to_memref_type(ctx, val_aval, val)
|
|
memref.StoreOp(val, ref, starts)
|
|
return result
|
|
|
|
if not is_vmem_store:
|
|
extra = ""
|
|
if memory_space == "#tpu.memory_space<any>":
|
|
extra = " ANY memory space can only be accessed using async_copy."
|
|
raise ValueError(
|
|
"Loads and stores are only allowed on VMEM and SMEM references." + extra
|
|
)
|
|
|
|
# handling VMEM store below
|
|
if not val_aval.shape:
|
|
raise ValueError("Cannot store scalars to VMEM")
|
|
|
|
mem_slice_shape = list(aval_out.shape)
|
|
for i, a in enumerate(idx_aval.indices):
|
|
if not isinstance(a, primitives.Slice):
|
|
mem_slice_shape.insert(i, 1)
|
|
mem_slice_shape_iter = iter(mem_slice_shape)
|
|
mem_slice_shape = [
|
|
1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
|
|
for b in ref_block_shape
|
|
]
|
|
mem_aval = aval_out.update(
|
|
shape=tuple(mem_slice_shape), sharding=jax_core.get_cur_mesh_sharding()
|
|
)
|
|
mem_aval_shape = ctx.lowering_context.dynamic_shape_replacement_fn(
|
|
mem_aval.shape
|
|
)
|
|
mem_aval_vec_type = ir.VectorType.get(
|
|
mem_aval_shape, _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True)
|
|
)
|
|
if need_stride:
|
|
result = tpu.strided_load(mem_aval_vec_type, ref, starts, strides)
|
|
else:
|
|
result = vector.load(mem_aval_vec_type, ref, starts)
|
|
val = _maybe_cast_store_to_memref_type(ctx, val_aval, val)
|
|
if mem_aval != aval_out:
|
|
    # We are slicing a scalar, so we provide dummy indices of size 1.
|
|
result_vec_type = ir.VectorType.get(aval_out.shape,
|
|
_dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True))
|
|
result = vector.shape_cast(result_vec_type, result)
|
|
val_vec_type = ir.VectorType.get(mem_aval.shape,
|
|
_dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True))
|
|
val = vector.shape_cast(val_vec_type, val)
|
|
result = _maybe_cast_load_to_bool(ctx, val_aval, result)
|
|
|
|
if need_stride:
|
|
if mask is not None:
|
|
raise NotImplementedError("masked swap with strided store")
|
|
tpu.StridedStoreOp(val, ref, starts, strides)
|
|
else:
|
|
tpu.VectorStoreOp(val, ref, starts, [], mask=mask)
|
|
return result
|
|
|
|
|
|
lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
|
|
skip_mlir_conversions.add(primitives.swap_p)
|
|
|
|
|
|
def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values):
|
|
del ctx
|
|
for multiple in values:
|
|
val = tpu.assume_multiple(val, multiple)
|
|
return val
|
|
|
|
|
|
lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule
|
|
|
|
|
|
def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity):
|
|
def _lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
|
(x_aval,) = ctx.avals_in
|
|
if not ctx.avals_out[0].shape:
|
|
# If reducing to a scalar, we reduce by adding a leading singleton
|
|
# dimension and reducing over all other dimensions. This avoids
|
|
# the materialization of a scalar tensor by the reduction op which
|
|
# is not supported.
|
|
def _proxy_fun(val, *, axes):
|
|
val = val[jnp.newaxis, ...]
|
|
axes = [axis + 1 for axis in axes]
|
|
val = reduce_fn(val, axis=axes, keepdims=True)
|
|
# Squeeze lowers to vector.ExtractOp which will place the final
|
|
# value in a scalar register.
|
|
return jnp.squeeze(val)
|
|
proxy_lowering = lower_fun(
|
|
_proxy_fun, multiple_results=False)
|
|
return proxy_lowering(ctx, x, axes=axes)
|
|
|
|
if jnp.issubdtype(x_aval.dtype, jnp.floating):
|
|
kind = type_to_kind[jnp.floating]
|
|
val = type_to_identity[jnp.floating]
|
|
val = ir.FloatAttr.get(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
x_aval,
|
|
shape=(),
|
|
),
|
|
val,
|
|
)
|
|
elif x_aval.dtype == jnp.int32:
|
|
kind = type_to_kind[jnp.signedinteger]
|
|
val = type_to_identity[jnp.signedinteger]
|
|
val = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), val)
|
|
elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
|
|
raise NotImplementedError(
|
|
"Reductions over unsigned integers not implemented."
|
|
)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Reductions over {x_aval.dtype} not implemented.")
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
)
|
|
identity = ir.DenseElementsAttr.get_splat(out_type, val)
|
|
acc = arith.ConstantOp(out_type, identity)
|
|
return vector.multi_reduction(kind, x, acc, axes)
|
|
return _lowering_rule
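# Worked example (illustrative): a full reduction such as jnp.max(x) over an
# (8, 128) block has a ()-shaped output, so the rule above does not emit
# vector.multi_reduction directly. The proxy function instead re-traces
#   x[jnp.newaxis, ...]        -> shape (1, 8, 128)
#   reduce over axes (1, 2) with keepdims=True
#   jnp.squeeze(...)           -> scalar (via vector.ExtractOp)
# while a partial reduction like jnp.max(x, axis=1) keeps a vector result and
# lowers directly to vector.multi_reduction with the identity splat as its
# accumulator.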
|
|
|
|
|
|
REDUCE_MAX_KINDS = {
|
|
jnp.floating: vector.CombiningKind.MAXIMUMF,
|
|
jnp.signedinteger: vector.CombiningKind.MAXSI,
|
|
jnp.unsignedinteger: vector.CombiningKind.MAXUI,
|
|
}
|
|
REDUCE_MAX_IDENTITY = {
|
|
jnp.floating: float("-inf"),
|
|
jnp.signedinteger: np.iinfo(np.int32).min,
|
|
}
|
|
_reduce_max_lowering_rule = reduce_lowering_rule(
|
|
jnp.max, REDUCE_MAX_KINDS, REDUCE_MAX_IDENTITY)
|
|
lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule
|
|
|
|
|
|
REDUCE_MIN_KINDS = {
|
|
jnp.floating: vector.CombiningKind.MINIMUMF,
|
|
jnp.signedinteger: vector.CombiningKind.MINSI,
|
|
jnp.unsignedinteger: vector.CombiningKind.MINUI,
|
|
}
|
|
REDUCE_MIN_IDENTITY = {
|
|
jnp.floating: float("inf"),
|
|
jnp.signedinteger: np.iinfo(np.int32).max,
|
|
}
|
|
_reduce_min_lowering_rule = reduce_lowering_rule(
|
|
jnp.min, REDUCE_MIN_KINDS, REDUCE_MIN_IDENTITY)
|
|
lowering_rules[lax.reduce_min_p] = _reduce_min_lowering_rule
|
|
|
|
|
|
REDUCE_SUM_KINDS = {
|
|
jnp.floating: vector.CombiningKind.ADD,
|
|
jnp.signedinteger: vector.CombiningKind.ADD,
|
|
jnp.unsignedinteger: vector.CombiningKind.ADD,
|
|
}
|
|
REDUCE_SUM_IDENTITY = {
|
|
jnp.floating: 0.0,
|
|
jnp.signedinteger: 0,
|
|
}
|
|
_reduce_sum_lowering_rule = reduce_lowering_rule(
|
|
jnp.sum, REDUCE_SUM_KINDS, REDUCE_SUM_IDENTITY)
|
|
lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule
|
|
|
|
|
|
def _reduce_and_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
|
def _proxy_reduce(arg, *, axes):
|
|
# Mosaic currently only supports float reductions, so we cast the boolean
|
|
# arg to a float and use reduce_min to implement reduce_and.
|
|
# TODO(b/351017807): Implement this logic in Mosaic MultiDimReductionOp
|
|
# instead.
|
|
float_arg = jnp.where(arg, 1.0, 0.0)
|
|
return jnp.min(float_arg, axis=axes) > 0.0
|
|
proxy_lowering = lower_fun(
|
|
_proxy_reduce, multiple_results=False)
|
|
return proxy_lowering(ctx, x, axes=axes)
|
|
|
|
lowering_rules[lax.reduce_and_p] = _reduce_and_lowering_rule
|
|
|
|
|
|
def _reduce_or_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
|
def _proxy_reduce(arg, *, axes):
|
|
# Mosaic currently only supports float reductions, so we cast the boolean
|
|
# arg to a float and use reduce_max to implement reduce_or.
|
|
# TODO(b/351017807): Implement this logic in Mosaic MultiDimReductionOp
|
|
# instead.
|
|
float_arg = jnp.where(arg, 1.0, 0.0)
|
|
return jnp.max(float_arg, axis=axes) > 0.0
|
|
proxy_lowering = lower_fun(
|
|
_proxy_reduce, multiple_results=False)
|
|
return proxy_lowering(ctx, x, axes=axes)
|
|
|
|
lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule
|
|
|
|
|
|
def _broadcast_to_lowering_rule(
|
|
ctx: LoweringRuleContext, x, shape: Sequence[int]
|
|
):
|
|
raise RuntimeError(
|
|
"`broadcast_to` is a Triton-specific primitive. Please consider using"
|
|
" `jnp.broadcast_to` instead."
|
|
)
|
|
|
|
|
|
lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule
|
|
|
|
|
|
def _broadcast_in_dim_lowering_rule(
|
|
ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding
|
|
):
|
|
del sharding
|
|
(aval_in,) = ctx.avals_in
|
|
(aval_out,) = ctx.avals_out
|
|
|
|
if jnp.issubdtype(aval_in.dtype, jnp.bool_):
|
|
# Direct broadcasts for bools are not supported in Mosaic due to booleans
|
|
# living in mask registers and broadcast operating on vregs. Broadcast as an
|
|
# integer instead and cast back to a bool.
|
|
# TODO(b/351019164): Implement this logic in Mosaic BroadcastOp instead.
|
|
def _proxy_fun(val, *, shape, broadcast_dimensions):
|
|
int_val = jnp.where(val, 1, 0)
|
|
bcast_val = jax.lax.broadcast_in_dim(int_val, shape, broadcast_dimensions)
|
|
return bcast_val == 1
|
|
proxy_lowering = lower_fun(
|
|
_proxy_fun, multiple_results=False)
|
|
return proxy_lowering(
|
|
ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions)
|
|
|
|
if broadcast_dimensions:
|
|
out_shape_list = [1] * len(shape)
|
|
for i, s in zip(broadcast_dimensions, aval_in.shape):
|
|
out_shape_list[i] = s
|
|
out_shape = tuple(out_shape_list)
|
|
out_type = ir.VectorType.get(
|
|
out_shape, _dtype_to_ir_type(aval_out.dtype)
|
|
)
|
|
val = vector.shape_cast(out_type, val)
|
|
if out_shape == aval_out.shape:
|
|
return val
|
|
out_type = ir.VectorType.get(
|
|
aval_out.shape, _dtype_to_ir_type(aval_out.dtype)
|
|
)
|
|
return vector.broadcast(out_type, val)
|
|
|
|
|
|
lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule
|
|
|
|
|
|
def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
|
|
"""Converts a jax dot dimension numbers to a tpu dot dimension numbers.
|
|
|
|
Jax dot dimension numbers are given as a tuple of tuples of sequences of ints
|
|
of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
|
|
rhs_batch_dims)).
|
|
|
|
TPU dot dimension numbers are given as an MLIR definition of the form
|
|
  #tpu.dot_dimension_numbers, which is defined in the TPU dialect definition
|
|
  file, tpu.td.
|
|
"""
|
|
(contracting_dims, batch_dims) = dimension_numbers
|
|
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
|
|
lhs_batch_dims, rhs_batch_dims = batch_dims
|
|
|
|
lhs_total_dims = set(range(len(lhs_shape)))
|
|
rhs_total_dims = set(range(len(rhs_shape)))
|
|
|
|
lhs_non_contracting_dims = sorted(
|
|
lhs_total_dims - set(lhs_contracting_dims) - set(lhs_batch_dims)
|
|
)
|
|
rhs_non_contracting_dims = sorted(
|
|
rhs_total_dims - set(rhs_contracting_dims) - set(rhs_batch_dims)
|
|
)
|
|
|
|
# Create output_dim_order
|
|
# Note: we assume that the output dimensions are ordered as batch dims, lhs_non_contracting_dims,
|
|
# rhs_non_contracting_dims - this assumption is safe to make, as it is
|
|
# the same one made in jax's dot_general.
|
|
output_dim_order = []
|
|
|
|
lhs_dim_map = {dim: idx for idx, dim in enumerate(range(len(lhs_shape)))}
|
|
rhs_dim_map = {dim: idx for idx, dim in enumerate(range(len(rhs_shape)))}
|
|
|
|
for dim in lhs_batch_dims:
|
|
output_dim_order.append(0)
|
|
output_dim_order.append(lhs_dim_map[dim])
|
|
|
|
for dim in lhs_non_contracting_dims:
|
|
output_dim_order.append(0)
|
|
output_dim_order.append(lhs_dim_map[dim])
|
|
|
|
for dim in rhs_non_contracting_dims:
|
|
output_dim_order.append(1)
|
|
output_dim_order.append(rhs_dim_map[dim])
|
|
|
|
def format_dims(dims):
|
|
return "[" + ", ".join(str(d) for d in dims) + "]"
|
|
|
|
all_dims = (
|
|
lhs_contracting_dims,
|
|
rhs_contracting_dims,
|
|
lhs_non_contracting_dims,
|
|
rhs_non_contracting_dims,
|
|
output_dim_order,
|
|
lhs_batch_dims,
|
|
rhs_batch_dims,
|
|
)
|
|
tpu_dim_numbers_str = (
|
|
f"#tpu.dot_dimension_numbers<{','.join(map(format_dims, all_dims))}>"
|
|
)
|
|
|
|
return ir.Attribute.parse(tpu_dim_numbers_str)
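# Worked example (illustrative): for a plain [m, k] x [k, n] matmul, JAX uses
# dimension_numbers = (((1,), (0,)), ((), ())), for which the function above
# computes
#   lhs_contracting=[1], rhs_contracting=[0],
#   lhs_non_contracting=[0], rhs_non_contracting=[1],
#   output_dim_order=[0, 0, 1, 1], lhs_batch=[], rhs_batch=[]
# and returns the parsed attribute
#   #tpu.dot_dimension_numbers<[1],[0],[0],[1],[0, 0, 1, 1],[],[]>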
|
|
|
|
|
|
def _dot_general_lowering_rule(
|
|
ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
|
|
):
|
|
(lhs_dims, rhs_dims), _ = dimension_numbers
|
|
(aval_out,) = ctx.avals_out
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, aval_out
|
|
)
|
|
val_type = out_type.element_type
|
|
if any(
|
|
cls.isinstance(val_type)
|
|
for cls in [
|
|
ir.BF16Type,
|
|
ir.F32Type,
|
|
ir.Float8E5M2Type,
|
|
ir.Float8E4M3FNType,
|
|
ir.Float8E4M3B11FNUZType,
|
|
]
|
|
):
|
|
val = ir.FloatAttr.get(val_type, 0.0)
|
|
elif ir.IntegerType.isinstance(val_type):
|
|
val = ir.IntegerAttr.get(val_type, 0)
|
|
else:
|
|
raise NotImplementedError(ctx.avals_out[0].dtype)
|
|
if any(len(a.shape) != 2 for a in ctx.avals_in):
|
|
raise NotImplementedError(
|
|
f"Only 2D tensors supported in dot; received: {ctx.avals_in}"
|
|
)
|
|
lhs_aval, rhs_aval = ctx.avals_in
|
|
# This is really a matrix-vector product. It only looks like matrix-matrix.
|
|
if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1:
|
|
if ctx.avals_in[0].shape != ctx.avals_in[1].shape:
|
|
bcast_shape = jnp.broadcast_shapes(
|
|
ctx.avals_in[0].shape, ctx.avals_out[0].shape
|
|
)
|
|
bcast_shape = ir.VectorType.get(
|
|
list(bcast_shape), _dtype_to_ir_type(ctx.avals_out[0].dtype)
|
|
)
|
|
if ctx.avals_in[0].shape != bcast_shape:
|
|
x = vector.broadcast(bcast_shape, x)
|
|
if ctx.avals_in[1].shape != bcast_shape:
|
|
y = vector.broadcast(bcast_shape, y)
|
|
red_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
lhs_aval.update(shape=(lhs_aval.shape[0],)),
|
|
)
|
|
acc = arith.ConstantOp(
|
|
red_type, ir.DenseElementsAttr.get_splat(red_type, val)
|
|
)
|
|
red = vector.MultiDimReductionOp(
|
|
ir.Attribute.parse("#vector.kind<add>"),
|
|
arith.MulFOp(x, y),
|
|
acc,
|
|
[1]
|
|
)
|
|
return vector.shape_cast(out_type, red)
|
|
|
|
tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims(
|
|
dimension_numbers, lhs_aval.shape, rhs_aval.shape
|
|
)
|
|
|
|
if precision is not None:
|
|
if precision[0] != precision[1]:
|
|
raise NotImplementedError("Per-operand dot precision unsupported")
|
|
precision = precision[0]
|
|
if precision is None or precision == lax.Precision.DEFAULT:
|
|
precision_attr = None # That's the default in Mosaic.
|
|
elif precision == lax.Precision.HIGHEST:
|
|
precision_attr = ir.Attribute.parse(
|
|
"#tpu.contract_precision<fp32>"
|
|
)
|
|
else:
|
|
raise NotImplementedError(f"Unsupported dot precision: {precision}")
|
|
out_tile = arith.ConstantOp(
|
|
out_type, ir.DenseElementsAttr.get_splat(out_type, val)
|
|
)
|
|
return tpu.matmul(
|
|
out_type,
|
|
x,
|
|
y,
|
|
out_tile,
|
|
dimension_numbers=tpu_dot_dims,
|
|
precision=precision_attr,
|
|
)
|
|
|
|
|
|
lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule
|
|
|
|
def _convert_helper(x, *, to_dtype):
|
|
# Helper function for dtype conversion
|
|
from_dtype = x.dtype
|
|
if from_dtype == jnp.bool_:
|
|
x = x.astype(jnp.int32)
|
|
return _convert_helper(x, to_dtype=to_dtype)
|
|
if to_dtype == jnp.bool_:
|
|
# Lower float32 or (u)int32 -> bool to cmp neq %in, 0
|
|
# TODO(apaszke,mvoz): Move the upcasts for cmpi to the Mosaic canonicalizer.
|
|
if jnp.issubdtype(from_dtype, jnp.floating):
|
|
if from_dtype.itemsize < 4:
|
|
x = x.astype(jnp.float32)
|
|
elif jnp.issubdtype(from_dtype, jnp.integer):
|
|
if from_dtype.itemsize < 4:
|
|
x = x.astype(jnp.int32)
|
|
return x != jnp.asarray(0, dtype=x.dtype)
|
|
if jnp.issubdtype(from_dtype, jnp.signedinteger):
|
|
if from_dtype.itemsize < 4:
|
|
x = x.astype(jnp.int32)
|
|
if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4:
|
|
x = x.astype(jnp.float32)
|
|
return x.astype(to_dtype)
|
|
if jnp.issubdtype(from_dtype, jnp.unsignedinteger):
|
|
if from_dtype.itemsize < 4:
|
|
x = x.astype(jnp.uint32)
|
|
# unsigned -> float is unsupported. We fall through and raise at the bottom.
|
|
if not jnp.issubdtype(to_dtype, jnp.floating):
|
|
return x.astype(to_dtype)
|
|
if jnp.issubdtype(from_dtype, jnp.floating) and jnp.issubdtype(
|
|
to_dtype, jnp.signedinteger
|
|
):
|
|
if from_dtype.itemsize < 4:
|
|
x = x.astype(jnp.float32)
|
|
if to_dtype.itemsize < 4:
|
|
# Need to clip values to match XLA
|
|
minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max
|
|
x = jnp.clip(x, minval, maxval)
|
|
return x.astype(jnp.int32).astype(to_dtype)
|
|
raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}")
|
|
|
|
def _convert_element_type_lowering_rule(
|
|
ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
|
|
):
|
|
del weak_type
|
|
del sharding
|
|
out_aval = ctx.avals_out[0]
|
|
in_aval = ctx.avals_in[0]
|
|
old_dtype = in_aval.dtype
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
)
|
|
|
|
if old_dtype == new_dtype:
|
|
return x
|
|
|
|
if new_dtype.itemsize == 8:
|
|
raise NotImplementedError("64-bit types are not supported")
|
|
|
|
_from = lambda dtype: jnp.issubdtype(old_dtype, dtype)
|
|
_to = lambda dtype: jnp.issubdtype(new_dtype, dtype)
|
|
floating = jnp.floating
|
|
integer = jnp.integer
|
|
signed = jnp.signedinteger
|
|
both_32bit = old_dtype.itemsize == 4 and new_dtype.itemsize == 4
|
|
if _from(floating) and _to(floating):
|
|
if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
|
|
return arith.extf(out_type, x)
|
|
elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
|
|
return arith.truncf(out_type, x)
|
|
elif _from(integer) and _to(integer):
|
|
if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
|
|
if not (_from(signed) and _to(signed)):
|
|
raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
|
|
return arith.extsi(out_type, x)
|
|
elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
|
|
return arith.trunci(out_type, x)
|
|
elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits:
|
|
# This case triggers when casting signed to unsigned or vice versa.
|
|
return x
|
|
# TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer.
|
|
elif _from(floating) and _to(signed):
|
|
# TODO(apaszke): Remove once a month has passed, along with the
|
|
# _convert_helper float -> signed conversion above.
|
|
if not ctx.forward_compatible or both_32bit:
|
|
return arith.fptosi(out_type, x)
|
|
elif _from(signed) and _to(floating) and both_32bit:
|
|
return arith.sitofp(out_type, x)
|
|
elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:
|
|
return arith.extui(out_type, x)
|
|
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
|
|
multiple_results=False)(ctx, x)
|
|
|
|
|
|
lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule
|
|
|
|
|
|
def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions,
|
|
sharding):
|
|
if dimensions is not None:
|
|
raise NotImplementedError
|
|
if any(d is None for d in new_sizes):
|
|
raise NotImplementedError
|
|
if not ctx.avals_in[0].shape:
|
|
return vector.broadcast(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
),
|
|
x,
|
|
)
|
|
return vector.shape_cast(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
),
|
|
x,
|
|
)
|
|
|
|
|
|
lowering_rules[lax.reshape_p] = _reshape_lowering_rule
|
|
|
|
|
|
def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
|
|
del dimensions # Unused.
|
|
(aval_in,) = ctx.avals_in
|
|
(aval_out,) = ctx.avals_out
|
|
if not aval_out.shape:
|
|
if aval_out.dtype.itemsize != 4:
|
|
raise ValueError(
|
|
"Only arrays with 32-bit element types can be converted to scalars,"
|
|
f" but got: {aval_out.dtype}. Try casting the input before squeezing"
|
|
" the scalar."
|
|
)
|
|
return vector.extract(x, [], [0] * len(aval_in.shape))
|
|
return vector.shape_cast(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
),
|
|
x,
|
|
)
|
|
|
|
|
|
lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule
|
|
|
|
|
|
def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
)
|
|
return tpu.concatenate(out_type, xs, dimension=dimension)
|
|
|
|
|
|
lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule
|
|
|
|
|
|
def _split_lowering_rule(
|
|
ctx: LoweringRuleContext, x, *, sizes, axis
|
|
):
|
|
(x_aval,) = ctx.avals_in
|
|
slice_size = np.array(x_aval.shape, dtype=np.int64)
|
|
starts = np.zeros_like(slice_size)
|
|
strides = np.ones_like(slice_size)
|
|
outs = []
|
|
for size, aval_out in zip(sizes, ctx.avals_out):
|
|
slice_size[axis] = size
|
|
outs.append(
|
|
vector.extract_strided_slice(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, aval_out
|
|
),
|
|
x,
|
|
starts,
|
|
slice_size,
|
|
strides,
|
|
)
|
|
)
|
|
starts[axis] += size
|
|
return outs
|
|
|
|
lowering_rules[lax.split_p] = _split_lowering_rule
|
|
|
|
|
|
def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
|
|
sharding):
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
)
|
|
return tpu.iota(out_type, dimension=dimension)
|
|
|
|
|
|
lowering_rules[lax.iota_p] = _iota_lowering_rule
|
|
|
|
|
|
def _gather_lowering_rule(
|
|
ctx: LoweringRuleContext,
|
|
x,
|
|
indices,
|
|
*,
|
|
dimension_numbers,
|
|
slice_sizes,
|
|
unique_indices,
|
|
indices_are_sorted,
|
|
mode,
|
|
fill_value,
|
|
):
|
|
in_aval = ctx.avals_in[0]
|
|
indices_aval = ctx.avals_in[1]
|
|
out_aval = ctx.avals_out[0]
|
|
|
|
if len(in_aval.shape) != 2:
|
|
raise NotImplementedError("Only 2D gather is supported")
|
|
if pallas_utils.dtype_bitwidth(in_aval.dtype) != 32:
|
|
raise NotImplementedError("Only 32-bit gather is supported")
|
|
  if not (in_aval.shape == indices_aval.shape[:-1] == out_aval.shape):
|
|
raise ValueError("Shape mismatch in input, indices and output")
|
|
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
)
|
|
  # When lowering jnp.take_along_axis to lax.gather, we append an extra dimension
|
|
# to the end of the indices array. We should reshape it back to the original
|
|
# shape before lowering to Mosaic and rely on MLIR CSE to remove the reshapes.
|
|
assert indices_aval.shape == in_aval.shape + (1,)
|
|
recovered_indices = vector.shape_cast(
|
|
ir.VectorType.get(in_aval.shape, ir.IntegerType.get_signless(32)),
|
|
indices,
|
|
)
|
|
# Note: current support for lax.gather is still very limited.
|
|
del fill_value
|
|
if (
|
|
slice_sizes == (1, 1)
|
|
and not unique_indices
|
|
and not indices_are_sorted
|
|
and mode
|
|
in (
|
|
lax.GatherScatterMode.FILL_OR_DROP,
|
|
lax.GatherScatterMode.PROMISE_IN_BOUNDS,
|
|
)
|
|
):
|
|
if dimension_numbers == lax.GatherDimensionNumbers(
|
|
offset_dims=(),
|
|
collapsed_slice_dims=(0,),
|
|
start_index_map=(0,),
|
|
operand_batching_dims=(1,),
|
|
start_indices_batching_dims=(1,),
|
|
):
|
|
return tpu.dynamic_gather(out_type, x, recovered_indices, 0)
|
|
if dimension_numbers == lax.GatherDimensionNumbers(
|
|
offset_dims=(),
|
|
collapsed_slice_dims=(1,),
|
|
start_index_map=(1,),
|
|
operand_batching_dims=(0,),
|
|
start_indices_batching_dims=(0,),
|
|
):
|
|
return tpu.dynamic_gather(out_type, x, recovered_indices, 1)
|
|
raise NotImplementedError("Unsupported gather")
|
|
|
|
|
|
lowering_rules[lax.gather_p] = _gather_lowering_rule
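# Illustrative sketch (hypothetical kernel body, not part of this module):
# inside a Pallas TPU kernel, a 32-bit take_along_axis over a 2D block with
# indices of the same shape as the operand, e.g.
#   out_ref[...] = jnp.take_along_axis(x_ref[...], idx_ref[...], axis=0)
# traces to lax.gather with the first dimension-numbers pattern matched above
# and lowers to tpu.dynamic_gather along dimension 0; axis=1 selects the
# second pattern. Any other gather raises "Unsupported gather".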
|
|
|
|
|
|
def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
|
|
if permutation != (1, 0):
|
|
raise NotImplementedError
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
)
|
|
return vector.transpose(out_type, x, permutation)
|
|
|
|
|
|
lowering_rules[lax.transpose_p] = _transpose_lowering_rule
|
|
|
|
|
|
def _bcast(x, y, x_aval, y_aval, out_aval):
|
|
x_dtype = x_aval.dtype
|
|
y_dtype = y_aval.dtype
|
|
if y_aval.weak_type:
|
|
y_dtype = x_aval.dtype
|
|
elif x_aval.weak_type:
|
|
x_dtype = y_aval.dtype
|
|
if isinstance(x, (np.ndarray, np.number, int, float)):
|
|
if getattr(y, "type", None) == ir.IndexType.get():
|
|
mlir_type = y.type
|
|
else:
|
|
mlir_type = _dtype_to_ir_type(x_dtype)
|
|
x = ir_constant(x, mlir_type)
|
|
if isinstance(y, (np.ndarray, np.number, int, float)):
|
|
if getattr(x, "type", None) == ir.IndexType.get():
|
|
mlir_type = x.type
|
|
else:
|
|
mlir_type = _dtype_to_ir_type(y_dtype)
|
|
y = ir_constant(y, mlir_type)
|
|
out_shape = list(out_aval.shape)
|
|
if x_aval.shape != out_aval.shape:
|
|
x_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(x_dtype))
|
|
x = vector.broadcast(x_ty, x)
|
|
if y_aval.shape != out_aval.shape:
|
|
y_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(y_dtype))
|
|
y = vector.broadcast(y_ty, y)
|
|
return x, y
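# Worked example (illustrative): in the binary rules below, adding a weakly
# typed Python scalar to an (8, 128) float32 operand first materializes the
# scalar as an arith constant with the strong operand's dtype and then
# broadcasts it to vector<8x128xf32> with vector.broadcast, so e.g. arith.addf
# always sees two operands of the output shape.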
|
|
|
|
|
|
def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
(aval_out,) = ctx.avals_out
|
|
if jnp.issubdtype(aval_out.dtype, jnp.integer):
|
|
return arith.addi(x, y)
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
|
return arith.addf(x, y)
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
lowering_rules[lax.add_p] = _add_lowering_rule
|
|
skip_mlir_conversions.add(lax.add_p)
|
|
lowering_rules[ad_util.add_any_p] = _add_lowering_rule
|
|
skip_mlir_conversions.add(ad_util.add_any_p)
|
|
|
|
|
|
class FoldingError(Exception):
|
|
pass
|
|
|
|
|
|
def _fold_and_get_constant_value(x):
|
|
def _fold(x, fuel):
|
|
if fuel <= 0:
|
|
raise FoldingError("Folding depth exceeded")
|
|
op_name = getattr(x.owner, "name", None)
|
|
binop_folds = {
|
|
"arith.maxsi": max,
|
|
"arith.minsi": min,
|
|
}
|
|
if op_name == "arith.constant":
|
|
if ir.IntegerType.isinstance(x.type):
|
|
return ir.IntegerAttr(x.owner.attributes["value"]).value
|
|
elif ir.FloatType.isinstance(x.type):
|
|
return ir.FloatAttr(x.owner.attributes["value"]).value
|
|
else:
|
|
raise ValueError(f"Unsupported constant type: {x.type}")
|
|
if op_name in binop_folds:
|
|
return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands)
|
|
raise FoldingError(f"Folding not supported for {x.owner}")
|
|
|
|
try:
|
|
return _fold(x, 10)
|
|
except FoldingError:
|
|
return None
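# Worked example (illustrative): folding arith.maxsi(arith.constant 3,
# arith.constant 5) returns 5. Any value whose defining ops are not a chain of
# arith.constant / arith.maxsi / arith.minsi, or whose chain exceeds the fuel
# of 10, yields None, and callers (such as the cond lowering below) fall back
# to emitting real control flow.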
|
|
|
|
|
|
def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
(aval_out,) = ctx.avals_out
|
|
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
|
return arith.maxsi(x, y)
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
|
return arith.maxui(x, y)
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
|
|
return arith.maximumf(x, y)
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
lowering_rules[lax.max_p] = _max_lowering_rule
|
|
skip_mlir_conversions.add(lax.max_p)
|
|
|
|
|
|
def _min_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
(aval_out,) = ctx.avals_out
|
|
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
|
return arith.minsi(x, y)
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
|
return arith.minui(x, y)
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
|
|
return arith.minimumf(x, y)
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
lowering_rules[lax.min_p] = _min_lowering_rule
|
|
skip_mlir_conversions.add(lax.min_p)
|
|
|
|
|
|
def _sub_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
(aval_out,) = ctx.avals_out
|
|
if jnp.issubdtype(aval_out.dtype, jnp.integer):
|
|
return arith.subi(x, y)
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
|
return arith.subf(x, y)
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
lowering_rules[lax.sub_p] = _sub_lowering_rule
|
|
skip_mlir_conversions.add(lax.sub_p)
|
|
|
|
|
|
def _mul_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
(aval_out,) = ctx.avals_out
|
|
if jnp.issubdtype(aval_out.dtype, jnp.integer):
|
|
return arith.muli(x, y)
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
|
return arith.mulf(x, y)
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
lowering_rules[lax.mul_p] = _mul_lowering_rule
|
|
skip_mlir_conversions.add(lax.mul_p)
|
|
|
|
|
|
def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
(aval_out,) = ctx.avals_out
|
|
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
|
return arith.divsi(x, y)
|
|
if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
|
return arith.divui(x, y)
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
|
|
return arith.divf(x, y)
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
lowering_rules[lax.div_p] = _div_lowering_rule
|
|
skip_mlir_conversions.add(lax.div_p)
|
|
|
|
|
|
def _rem_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
(aval_out,) = ctx.avals_out
|
|
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
|
return arith.remsi(x, y)
|
|
if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
|
return arith.remui(x, y)
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
|
return arith.remf(x, y)
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
lowering_rules[lax.rem_p] = _rem_lowering_rule
|
|
skip_mlir_conversions.add(lax.rem_p)
|
|
|
|
|
|
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
|
|
(aval_out,) = ctx.avals_out
|
|
if jnp.issubdtype(aval_out.dtype, jnp.integer):
|
|
return math.absi(x)
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
|
return math.absf(x)
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
lowering_rules[lax.abs_p] = _abs_lowering_rule
|
|
|
|
|
|
def _neg_lowering_rule(ctx: LoweringRuleContext, x):
|
|
(x_aval,) = ctx.avals_in
|
|
new_ctx = ctx.replace(
|
|
avals_in=(jax_core.ShapedArray((), x_aval.dtype), x_aval),
|
|
block_shapes=((), *ctx.block_shapes)
|
|
)
|
|
return _sub_lowering_rule(new_ctx, np.array(0, dtype=x_aval.dtype), x)
|
|
|
|
|
|
lowering_rules[lax.neg_p] = _neg_lowering_rule
|
|
skip_mlir_conversions.add(lax.neg_p)
|
|
|
|
|
|
def _sign_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return lower_fun(
|
|
pallas_utils.sign_lowering_helper, multiple_results=False,
|
|
)(ctx, x)
|
|
|
|
|
|
lowering_rules[lax.sign_p] = _sign_lowering_rule
|
|
|
|
|
|
def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
return lower_fun(
|
|
pallas_utils.nextafter_lowering_helper, multiple_results=False,
|
|
)(ctx, x, y)
|
|
|
|
|
|
lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule
|
|
|
|
|
|
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.rsqrt(x)
|
|
|
|
|
|
lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule
|
|
|
|
|
|
def _sqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.sqrt(x)
|
|
|
|
|
|
lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule
|
|
|
|
|
|
def _square_lowering_rule(ctx: LoweringRuleContext, x):
|
|
if jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer):
|
|
return arith.muli(x, x)
|
|
return arith.mulf(x, x)
|
|
|
|
|
|
lowering_rules[lax.square_p] = _square_lowering_rule
|
|
|
|
|
|
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.exp(x)
|
|
|
|
|
|
lowering_rules[lax.exp_p] = _exp_lowering_rule
|
|
|
|
|
|
def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
# jax accepts float base (x) and integer/float exponent (y), and integer
|
|
  # exponent is cast to float.
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
)
|
|
if jnp.issubdtype(ctx.avals_in[1].dtype, jnp.integer):
|
|
y = arith.sitofp(out_type, y)
|
|
if not isinstance(x, ir.Value) and x == 2.:
|
|
return math.exp2(y)
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
return math.powf(x, y)
|
|
|
|
|
|
lowering_rules[lax.pow_p] = _pow_lowering_rule
|
|
skip_mlir_conversions.add(lax.pow_p)
|
|
|
|
|
|
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
|
|
return lower_fun(lax_internal._integer_pow, multiple_results=False)(
|
|
ctx, x, y=y)
|
|
|
|
|
|
lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
|
|
|
|
|
|
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
|
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
|
|
# here.
|
|
return lower_fun(
|
|
lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x),
|
|
multiple_results=False,
|
|
)(ctx, x)
|
|
|
|
|
|
lowering_rules[lax.exp2_p] = _exp2_lowering_rule
|
|
skip_mlir_conversions.add(lax.exp2_p)
|
|
|
|
|
|
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
|
neg_x = arith.negf(x)
|
|
exp_neg_x = math.exp(neg_x)
|
|
aval_out = ctx.avals_out[0]
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, aval_out
|
|
)
|
|
if aval_out.shape == ():
|
|
one = ir_constant(1.0, mlir_type=out_type)
|
|
else:
|
|
one = vector.broadcast(out_type, ir_constant(1.0))
|
|
denom = arith.addf(one, exp_neg_x)
|
|
return arith.divf(one, denom)
|
|
|
|
|
|
lowering_rules[lax.logistic_p] = _logistic_lowering_rule
|
|
|
|
|
|
def _sin_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.sin(x)
|
|
|
|
|
|
lowering_rules[lax.sin_p] = _sin_lowering_rule
|
|
|
|
|
|
def _cos_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.cos(x)
|
|
|
|
|
|
lowering_rules[lax.cos_p] = _cos_lowering_rule
|
|
|
|
|
|
def _tan_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.tan(x)
|
|
|
|
|
|
lowering_rules[lax.tan_p] = _tan_lowering_rule
|
|
|
|
|
|
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.tanh(x)
|
|
|
|
|
|
lowering_rules[lax.tanh_p] = _tanh_lowering_rule
|
|
|
|
|
|
def _log_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.log(x)
|
|
|
|
|
|
lowering_rules[lax.log_p] = _log_lowering_rule
|
|
|
|
|
|
def _log1p_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.log1p(x)
|
|
|
|
|
|
lowering_rules[lax.log1p_p] = _log1p_lowering_rule
|
|
|
|
|
|
def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method):
|
|
if rounding_method == 0:
|
|
return math.round(x)
|
|
elif rounding_method == 1:
|
|
return math.roundeven(x)
|
|
else:
|
|
raise NotImplementedError(f"Unsupported rounding method: {rounding_method}")
|
|
|
|
|
|
lowering_rules[lax.round_p] = _round_lowering_rule
|
|
|
|
|
|
def _ceil_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.ceil(x)
|
|
|
|
|
|
lowering_rules[lax.ceil_p] = _ceil_lowering_rule
|
|
|
|
|
|
def _floor_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.floor(x)
|
|
|
|
|
|
lowering_rules[lax.floor_p] = _floor_lowering_rule
|
|
|
|
|
|
def _clz_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return math.ctlz(x)
|
|
|
|
lowering_rules[lax.clz_p] = _clz_lowering_rule
|
|
|
|
|
|
def _population_count_lowering_rule(ctx: LoweringRuleContext, x):
|
|
aval_out = ctx.avals_out[0]
|
|
if aval_out.shape == ():
|
|
raise ValueError("Population count is not supported on scalars")
|
|
return math.ctpop(x)
|
|
|
|
lowering_rules[lax.population_count_p] = _population_count_lowering_rule
|
|
|
|
|
|
# Mapping for signed integer comparisons.
|
|
_cmpsi_lowering_types = {
|
|
lax.eq_p: arith.CmpIPredicate.eq,
|
|
lax.ne_p: arith.CmpIPredicate.ne,
|
|
lax.lt_p: arith.CmpIPredicate.slt,
|
|
lax.le_p: arith.CmpIPredicate.sle,
|
|
lax.gt_p: arith.CmpIPredicate.sgt,
|
|
lax.ge_p: arith.CmpIPredicate.sge,
|
|
}
|
|
|
|
# Mapping for unsigned integer comparisons.
|
|
_cmpui_lowering_types = {
|
|
lax.eq_p: arith.CmpIPredicate.eq,
|
|
lax.ne_p: arith.CmpIPredicate.ne,
|
|
lax.lt_p: arith.CmpIPredicate.ult,
|
|
lax.le_p: arith.CmpIPredicate.ule,
|
|
lax.gt_p: arith.CmpIPredicate.ugt,
|
|
lax.ge_p: arith.CmpIPredicate.uge,
|
|
}
|
|
|
|
# Mapping for floating point comparisons.
|
|
_cmpf_lowering_types = {
|
|
lax.eq_p: arith.CmpFPredicate.OEQ,
|
|
lax.ne_p: arith.CmpFPredicate.ONE,
|
|
lax.lt_p: arith.CmpFPredicate.OLT,
|
|
lax.le_p: arith.CmpFPredicate.OLE,
|
|
lax.gt_p: arith.CmpFPredicate.OGT,
|
|
lax.ge_p: arith.CmpFPredicate.OGE,
|
|
}
|
|
|
|
|
|
# The relationship between comparison operations on booleans and boolean
|
|
# algebra is as follows:
|
|
# eq(x, y) = !(x ^ y)
|
|
# ne(x, y) = x ^ y
|
|
# lt(x, y) = !x && y
|
|
# le(x, y) = !x || y
|
|
# gt(x, y) = x && !y
|
|
# ge(x, y) = x || !y
|
|
def _cmp_boolean_lowering_helper(primitive, x: Array, y: Array):
|
|
"""A helper function for lowering comparison operations for boolean inputs.
|
|
|
|
Args:
|
|
primitive: A JAX primitive representing a comparison operation, which is
|
|
one of the following: `lax.eq_p` (equals), `lax.ne_p` (not equals),
|
|
`lax.lt_p` (less than), `lax.le_p` (less than or equal to),
|
|
`lax.gt_p` (greater than), or `lax.ge_p` (greater than or equal to).
|
|
x: A boolean array representing the first operand in the comparison.
|
|
y: A boolean array representing the second operand in the comparison.
|
|
|
|
Returns:
|
|
A boolean array that is the result of applying the comparison operation
|
|
between `x` and `y` based on the given primitive.
|
|
|
|
Raises:
|
|
ValueError: If an unsupported comparison primitive is provided.
|
|
"""
|
|
if primitive == lax.eq_p:
|
|
return jnp.logical_not(jnp.logical_xor(x, y))
|
|
elif primitive == lax.ne_p:
|
|
return jnp.logical_xor(x, y)
|
|
elif primitive == lax.lt_p:
|
|
return jnp.logical_and(jnp.logical_not(x), y)
|
|
elif primitive == lax.le_p:
|
|
return jnp.logical_or(jnp.logical_not(x), y)
|
|
elif primitive == lax.gt_p:
|
|
return jnp.logical_and(x, jnp.logical_not(y))
|
|
elif primitive == lax.ge_p:
|
|
return jnp.logical_or(x, jnp.logical_not(y))
|
|
else:
|
|
raise ValueError(f"Unsupported comparison primitive: {primitive}")
|
|
|
|
|
|
def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
x_aval, y_aval = ctx.avals_in
|
|
if x_aval.dtype != y_aval.dtype:
|
|
raise ValueError(
|
|
f"Mixed dtype operands in cmp: {x_aval.dtype}, {y_aval.dtype}"
|
|
)
|
|
dtype = x_aval.dtype
|
|
|
|
if jnp.issubdtype(dtype, jnp.bool_):
|
|
return lower_fun(
|
|
functools.partial(_cmp_boolean_lowering_helper, primitive),
|
|
multiple_results=False,
|
|
)(ctx, x, y)
|
|
|
|
if jnp.issubdtype(dtype, jnp.integer):
|
|
is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger)
|
|
pred = (
|
|
_cmpui_lowering_types if is_uint else _cmpsi_lowering_types
|
|
)[primitive]
|
|
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
|
|
return arith.cmpi(predicate, x, y)
|
|
|
|
if jnp.issubdtype(dtype, jnp.floating):
|
|
pred = _cmpf_lowering_types[primitive]
|
|
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
|
|
return arith.cmpf(predicate, x, y)
|
|
|
|
raise NotImplementedError(f"Unsupported dtype in cmp: {dtype}")
|
|
|
|
|
|
lowering_rules[lax.eq_p] = functools.partial(_cmp_lowering_rule, lax.eq_p)
|
|
lowering_rules[lax.ne_p] = functools.partial(_cmp_lowering_rule, lax.ne_p)
|
|
lowering_rules[lax.lt_p] = functools.partial(_cmp_lowering_rule, lax.lt_p)
|
|
lowering_rules[lax.le_p] = functools.partial(_cmp_lowering_rule, lax.le_p)
|
|
lowering_rules[lax.gt_p] = functools.partial(_cmp_lowering_rule, lax.gt_p)
|
|
lowering_rules[lax.ge_p] = functools.partial(_cmp_lowering_rule, lax.ge_p)
|
|
|
|
|
|
def _and_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
|
|
return arith.andi(x, y)
|
|
|
|
|
|
lowering_rules[lax.and_p] = _and_lowering_rule
|
|
skip_mlir_conversions.add(lax.and_p)
|
|
|
|
|
|
def _is_finite_lowering_rule(ctx: LoweringRuleContext, x):
|
|
out_aval, = ctx.avals_out
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
)
|
|
return _not_lowering_rule(ctx, tpu.weird(out_type, x))
|
|
|
|
|
|
lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule
|
|
|
|
|
|
def _or_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
|
|
return arith.ori(x, y)
|
|
|
|
|
|
lowering_rules[lax.or_p] = _or_lowering_rule
|
|
skip_mlir_conversions.add(lax.or_p)
|
|
|
|
|
|
def _not_lowering_rule(ctx: LoweringRuleContext, x):
|
|
# The primitive not_p is lowered to
|
|
# https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not
|
|
# which is arithmetic for integers and logical for booleans.
|
|
# Lowering to:
|
|
# xor x, -1
|
|
# covers both cases.
|
|
out_aval = ctx.avals_out[0]
|
|
out_scalar_type = _dtype_to_ir_type(out_aval.dtype)
|
|
if not out_aval.shape:
|
|
# Create a scalar constant.
|
|
minus_one = ir_constant(-1, out_scalar_type)
|
|
else:
|
|
# Create a vector constant.
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
)
|
|
scalar_minus_one = ir.IntegerAttr.get(out_scalar_type, -1)
|
|
minus_one = arith.ConstantOp(
|
|
out_type, ir.DenseElementsAttr.get_splat(out_type, scalar_minus_one)
|
|
)
|
|
return arith.xori(x, minus_one)
|
|
|
|
|
|
lowering_rules[lax.not_p] = _not_lowering_rule
|
|
|
|
def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args):
|
|
if len(args) > 1:
|
|
raise NotImplementedError("select_n only supported with <= 2 arguments")
|
|
pred_aval, x_aval = ctx.avals_in[:2]
|
|
if pred_aval.dtype != np.dtype(np.bool_):
|
|
lower_ctx = LoweringRuleContext(
|
|
ctx.lowering_context,
|
|
avals_in=[pred_aval],
|
|
avals_out=[pred_aval.update(dtype=np.bool_)],
|
|
block_shapes=[None],
|
|
)
|
|
pred = lower_fun(lambda x: x != 0, multiple_results=False)(lower_ctx, pred)
|
|
if not args:
|
|
return x
|
|
  # At this point there are exactly two cases, x and y (checked above).
|
|
y, = args
|
|
return arith.select(pred, y, x)
|
|
|
|
|
|
lowering_rules[lax.select_n_p] = _select_n_lowering_rule
|
|
|
|
|
|
def _clamp(min, operand, max):
|
|
res = jnp.maximum(operand, min)
|
|
return jnp.minimum(res, max)
|
|
|
|
|
|
def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max):
|
|
"""Compute minimum_p(maximum_p(min, operand), max)."""
|
|
return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max)
|
|
|
|
|
|
lowering_rules[lax.clamp_p] = _clamp_lowering_rule
|
|
|
|
|
|
def _for_lowering_rule(
|
|
ctx: LoweringRuleContext,
|
|
*args,
|
|
jaxpr,
|
|
nsteps,
|
|
reverse,
|
|
unroll,
|
|
which_linear,
|
|
):
|
|
should_discharge = [
|
|
not isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
|
|
]
|
|
jaxpr, () = state_discharge.discharge_state(
|
|
jaxpr, (), should_discharge=[False, *should_discharge]
|
|
)
|
|
for i in range(nsteps):
|
|
if reverse:
|
|
i = nsteps - i - 1
|
|
i = ir_constant(i)
|
|
lowering_context = ctx.lowering_context.replace(
|
|
block_shapes=[(), *ctx.block_shapes],
|
|
)
|
|
non_ref_args = jaxpr_subcomp(lowering_context, jaxpr, i, *args)
|
|
non_ref_args_iter = iter(non_ref_args)
|
|
args = [
|
|
next(non_ref_args_iter) if s else a
|
|
for a, s in zip(args, should_discharge)
|
|
]
|
|
return args
|
|
|
|
|
|
lowering_rules[for_loop.for_p] = _for_lowering_rule
|
|
|
|
|
|
def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
|
|
jaxpr: jax_core.Jaxpr, start: int | ir.Value,
|
|
num_steps: int | ir.Value, consts, *args,
|
|
has_loop_index: bool,
|
|
unroll: int):
|
|
def _run_body(i, args):
|
|
if has_loop_index:
|
|
lowering_context = ctx.lowering_context.replace(
|
|
block_shapes=ctx.block_shapes)
|
|
args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
|
|
else:
|
|
del i
|
|
lowering_context = ctx.lowering_context.replace(
|
|
block_shapes=ctx.block_shapes[:len(consts)]
|
|
+ ctx.block_shapes[len(consts) + 1:],
|
|
)
|
|
args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
|
|
return args
|
|
|
|
if (
|
|
not isinstance(start, ir.Value)
|
|
and not isinstance(num_steps, ir.Value)
|
|
and num_steps == unroll
|
|
):
|
|
# No need for an scf.For. We can just unroll completely
|
|
for i in range(start, start + num_steps):
|
|
args = _run_body(
|
|
ir_constant(i, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))),
|
|
args,
|
|
)
|
|
return args
|
|
if unroll != 1:
|
|
raise NotImplementedError(
|
|
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
|
|
lbd = _ensure_mlir_value(start, pallas_core.index_map_grid_aval)
|
|
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pallas_core.index_map_grid_aval))
|
|
step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
|
|
for_op = scf.ForOp(lbd, ubd, step, args)
|
|
with ir.InsertionPoint(for_op.body):
|
|
iv = for_op.induction_variable
|
|
inner_args = for_op.inner_iter_args
|
|
inner_out = _run_body(iv, inner_args)
|
|
scf.YieldOp(inner_out)
|
|
return for_op.results
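# Illustrative note (hedged): a Pallas kernel body containing, say,
#   jax.lax.fori_loop(0, 8, lambda i, acc: acc + 1.0, 0.0)
# reaches this helper through the scan/while lowerings below. With unroll=1 it
# becomes an scf.for whose carry is threaded through iter_args; when the trip
# count is static and equal to `unroll`, the body is instead emitted inline
# once per step and no scf.for is created.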
|
|
|
|
|
|
def _scan_lowering_rule(
|
|
ctx: LoweringRuleContext,
|
|
*args,
|
|
jaxpr: jax_core.ClosedJaxpr,
|
|
linear: tuple[bool, ...],
|
|
length: int,
|
|
reverse: bool,
|
|
unroll: bool | int,
|
|
num_consts: int,
|
|
num_carry: int,
|
|
_split_transpose: bool,
|
|
):
|
|
del _split_transpose
|
|
# Can only handle fori_loop-like scans
|
|
num_extensive = len(args) - num_consts - num_carry
|
|
if num_extensive: raise NotImplementedError
|
|
if reverse: raise NotImplementedError
|
|
del linear, num_extensive, reverse
|
|
|
|
jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts
|
|
if jaxpr_consts: raise NotImplementedError
|
|
del jaxpr_consts
|
|
|
|
jaxpr, has_loop_index = pallas_utils.pattern_match_scan_to_fori_loop(
|
|
jaxpr, num_consts, num_carry
|
|
)
|
|
consts, args = split_list(args, [num_consts])
|
|
consts_avals, args_avals = split_list(ctx.avals_in, [num_consts])
|
|
if has_loop_index:
|
|
loop_index_start, *args = args
|
|
args_avals = args_avals[1:]
|
|
else:
|
|
loop_index_start = 0
|
|
consts = map(_ensure_mlir_value, consts, consts_avals)
|
|
args = map(_ensure_mlir_value, args, args_avals)
|
|
out = _lower_jaxpr_to_for_loop(
|
|
ctx, jaxpr, loop_index_start, length,
|
|
consts, *args, has_loop_index=has_loop_index,
|
|
unroll=unroll)
|
|
if has_loop_index:
|
|
out = [ir_constant(length,
|
|
mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))),
|
|
*out]
|
|
return out
|
|
lowering_rules[lax.scan_p] = _scan_lowering_rule
|
|
skip_mlir_conversions.add(lax.scan_p)
|
|
|
|
|
|
def _lower_while_via_fori(
|
|
ctx: LoweringRuleContext,
|
|
*args,
|
|
fori_jaxpr,
|
|
cond_nconsts,
|
|
cond_jaxpr,
|
|
body_nconsts,
|
|
body_jaxpr,
|
|
):
|
|
_, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
|
|
(lb, ub), args = carry[:2], carry[2:]
|
|
for_out = _lower_jaxpr_to_for_loop(
|
|
ctx.replace(
|
|
block_shapes=ctx.block_shapes[: body_nconsts + 1]
|
|
+ ctx.block_shapes[body_nconsts + 2 :],
|
|
),
|
|
fori_jaxpr,
|
|
lb,
|
|
arith.subi(ub, lb),
|
|
body_consts,
|
|
*args,
|
|
has_loop_index=True,
|
|
unroll=1,
|
|
)
|
|
return [ub, ub, *for_out]
|
|
|
|
|
|
def _while_lowering_rule(
|
|
ctx: LoweringRuleContext,
|
|
*args,
|
|
cond_nconsts,
|
|
cond_jaxpr,
|
|
body_nconsts,
|
|
body_jaxpr,
|
|
):
|
|
# First try to lower via a simpler fori loop, which may optimize better.
|
|
fori_jaxpr, _ = pallas_utils.pattern_match_while_to_fori_loop(
|
|
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
|
|
)
|
|
if fori_jaxpr is not None:
|
|
return _lower_while_via_fori(
|
|
ctx,
|
|
*args,
|
|
fori_jaxpr=fori_jaxpr,
|
|
cond_nconsts=cond_nconsts,
|
|
cond_jaxpr=cond_jaxpr,
|
|
body_nconsts=body_nconsts,
|
|
body_jaxpr=body_jaxpr,
|
|
)
|
|
|
|
  # If we fail to convert to a fori loop, fall back to an ordinary while loop.
|
|
cond_consts, body_consts, carry = split_list(
|
|
args, [cond_nconsts, body_nconsts]
|
|
)
|
|
cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = (
|
|
split_list(ctx.block_shapes, [cond_nconsts, body_nconsts])
|
|
)
|
|
carry_types = [a.type for a in carry]
|
|
while_op = scf.WhileOp(carry_types, carry)
|
|
|
|
before_block = while_op.before.blocks.append(*carry_types)
|
|
with ir.InsertionPoint.at_block_begin(before_block):
|
|
cond_args = [*cond_consts, *before_block.arguments]
|
|
[cond] = jaxpr_subcomp(
|
|
ctx.lowering_context.replace(
|
|
block_shapes=[*cond_const_block_shapes, *carry_block_shapes]
|
|
),
|
|
cond_jaxpr.jaxpr,
|
|
*cond_args,
|
|
)
|
|
scf.condition(cond, before_block.arguments)
|
|
|
|
after_block = while_op.after.blocks.append(*carry_types)
|
|
with ir.InsertionPoint.at_block_begin(after_block):
|
|
body_args = [*body_consts, *after_block.arguments]
|
|
loop_out = jaxpr_subcomp(
|
|
ctx.lowering_context.replace(
|
|
block_shapes=[*body_const_block_shapes, *carry_block_shapes],
|
|
),
|
|
body_jaxpr.jaxpr,
|
|
*body_args,
|
|
)
|
|
if loop_out:
|
|
scf.yield_(loop_out)
|
|
return list(while_op.results)
|
|
|
|
|
|
lowering_rules[lax.while_p] = _while_lowering_rule
|
|
|
|
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
|
|
index, *args = args
|
|
constant_index = _fold_and_get_constant_value(index)
|
|
|
|
if constant_index is not None:
|
|
return jaxpr_subcomp(
|
|
ctx.lowering_context.replace(block_shapes=ctx.block_shapes[1:]), branches[constant_index].jaxpr, *args
|
|
)
|
|
aval_to_ir_type_with_fn = functools.partial(
|
|
aval_to_ir_type, ctx.lowering_context.dynamic_shape_replacement_fn
|
|
)
|
|
out_types = map(aval_to_ir_type_with_fn, ctx.avals_out)
|
|
pred = arith.cmpi(
|
|
arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
|
|
)
|
|
if_op = scf.IfOp(pred, out_types, hasElse=True)
|
|
lowering_context = ctx.lowering_context.replace(
|
|
block_shapes=ctx.block_shapes[1:],
|
|
)
|
|
with ir.InsertionPoint(if_op.then_block):
|
|
# TODO(b/300272065): Use `scf.IndexSwitchOp` instead of a cascade of
|
|
# if/else.
|
|
if len(branches) > 2:
|
|
out = _cond_lowering_rule(
|
|
ctx,
|
|
arith.subi(index, ir_constant(1, index.type)),
|
|
*args,
|
|
branches=branches[1:],
|
|
)
|
|
else:
|
|
out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
|
|
scf.YieldOp(out)
|
|
with ir.InsertionPoint(if_op.else_block):
|
|
out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args)
|
|
scf.YieldOp(out)
|
|
return if_op.results
|
|
|
|
|
|
lowering_rules[lax.cond_p] = _cond_lowering_rule
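# Worked example (illustrative): a three-branch jax.lax.switch(i, [f0, f1, f2], x)
# lowers through the rule above as a cascade
#   scf.if (i != 0) {
#     scf.if (i - 1 != 0) { f2 } else { f1 }
#   } else { f0 }
# unless `i` folds to a compile-time constant, in which case only the selected
# branch jaxpr is inlined.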
|
|
|
|
|
|
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
|
|
lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
|
|
return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args)
|
|
|
|
|
|
lowering_rules[pjit.pjit_p] = _pjit_lowering_rule
|
|
|
|
|
|
def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
|
|
return x
|
|
lowering_rules[pjit.mesh_cast_p] = _mesh_cast_lowering_rule
|
|
|
|
|
|
def _custom_jvp_call_lowering_rule(
|
|
ctx: LoweringRuleContext,
|
|
*args,
|
|
call_jaxpr: jax_core.Jaxpr,
|
|
jvp_jaxpr_fun: lu.WrappedFun,
|
|
num_consts: int,
|
|
symbolic_zeros: bool,
|
|
):
|
|
del jvp_jaxpr_fun
|
|
if symbolic_zeros: raise NotImplementedError
|
|
if num_consts: raise NotImplementedError
|
|
if call_jaxpr.consts: raise NotImplementedError
|
|
lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
|
|
return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args)
|
|
|
|
|
|
lowering_rules[custom_derivatives.custom_jvp_call_p] = (
|
|
_custom_jvp_call_lowering_rule)
|
|
|
|
|
|
def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
|
|
del ctx, args, kwargs
|
|
# No-op debug callbacks in Mosaic for now
|
|
return []
|
|
|
|
|
|
lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule
|
|
|
|
|
|
def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
|
|
|
|
if ctx.lowering_context.user_grid_indices is None:
|
|
raise ValueError(
|
|
f"program id: {axis} was passed, but user did not provide a grid."
|
|
)
|
|
length = len(ctx.lowering_context.user_grid_indices)
|
|
if not (0 <= axis < length):
|
|
raise ValueError(
|
|
f"user passed in program id with axis: {axis}, but grid only has"
|
|
f" length: {length}"
|
|
)
|
|
return ctx.lowering_context.user_grid_indices[axis]
|
|
lowering_rules[primitives.program_id_p] = _program_id_lowering_rule
|
|
|
|
def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
|
|
mapped_axes = set(ctx.lowering_context.mapped_dims)
|
|
seen_user_axes = 0
|
|
for i in range(ctx.lowering_context.grid_rank):
|
|
seen_user_axes += int(i not in mapped_axes)
|
|
if seen_user_axes == axis + 1:
|
|
break
|
|
else:
|
|
raise ValueError(
|
|
f"user passed in program id with axis: {axis}, but grid only has"
|
|
f" length: {len(ctx.lowering_context.grid_rank)}"
|
|
)
|
|
return tpu.iteration_bound(i)
|
|
lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule
|
|
|
|
|
|
def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis):
|
|
(out_aval,) = ctx.avals_out
|
|
return tpu.repeat(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
),
|
|
x,
|
|
axis,
|
|
repeats,
|
|
)
|
|
|
|
|
|
lowering_rules[tpu_primitives.repeat_p] = _repeat_lowering_rule
|
|
|
|
|
|
def _roll_lowering_rule(
|
|
ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis
|
|
):
|
|
(out_aval,) = ctx.avals_out
|
|
return tpu.dynamic_rotate(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
),
|
|
x,
|
|
shift,
|
|
axis,
|
|
stride=stride,
|
|
stride_dimension=stride_axis,
|
|
)
|
|
|
|
|
|
lowering_rules[tpu_primitives.roll_p] = _roll_lowering_rule
|
|
|
|
|
|
def _slice_lowering_rule(
|
|
ctx: LoweringRuleContext, x, limit_indices, start_indices, strides
|
|
):
|
|
"""Lowers a slice to vector dialect."""
|
|
(aval_out,) = ctx.avals_out
|
|
out_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, aval_out
|
|
)
|
|
if strides is None:
|
|
strides = [1] * len(start_indices)
|
|
sizes = np.array(limit_indices) - np.array(start_indices)
|
|
return vector.extract_strided_slice(
|
|
out_type, x, start_indices, sizes, strides
|
|
)
|
|
|
|
|
|
lowering_rules[lax.slice_p] = _slice_lowering_rule
|
|
|
|
|
|
def _xor_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
|
|
return arith.xori(x, y)
|
|
|
|
|
|
lowering_rules[lax.xor_p] = _xor_lowering_rule
|
|
skip_mlir_conversions.add(lax.xor_p)
|
|
|
|
|
|
def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d):
|
|
x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
|
|
return arith.shli(x, d)
|
|
|
|
|
|
lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule
|
|
skip_mlir_conversions.add(lax.shift_left_p)
|
|
|
|
|
|
def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d):
|
|
x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
|
|
return arith.shrsi(x, d)
|
|
|
|
|
|
lowering_rules[lax.shift_right_arithmetic_p] = _shift_right_arithmetic_lowering_rule
|
|
skip_mlir_conversions.add(lax.shift_right_arithmetic_p)
|
|
|
|
|
|
def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d):
|
|
x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
|
|
return arith.shrui(x, d)
|
|
|
|
|
|
lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules
|
|
skip_mlir_conversions.add(lax.shift_right_logical_p)
|
|
|
|
|
|
def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
|
|
return lower_fun(
|
|
pallas_utils.erf_inv_lowering_helper, multiple_results=False,
|
|
)(ctx, x)
|
|
|
|
|
|
lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule
|
|
|
|
|
|
def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx):
|
|
if not isinstance(x.type.element_type, ir.F32Type):
|
|
raise ValueError("Only float32 is supported.")
|
|
return tpu.reciprocal(x, approx=approx)
|
|
|
|
|
|
lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule
|
|
|
|
def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
|
|
del ty
|
|
(out_aval,) = ctx.avals_out
|
|
return tpu.bitcast(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
),
|
|
x,
|
|
)
|
|
|
|
lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule
|
|
|
|
def _bitcast_convert_type_lowering_rule(
|
|
ctx: LoweringRuleContext, x, *, new_dtype):
|
|
(in_aval, ) = ctx.avals_in
|
|
(out_aval,) = ctx.avals_out
|
|
old_bitwidth = pallas_utils.dtype_bitwidth(in_aval.dtype)
|
|
new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype)
|
|
if old_bitwidth != new_bitwidth:
|
|
raise NotImplementedError("Changing bitwidths not supported.")
|
|
return tpu.bitcast(
|
|
aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
),
|
|
x,
|
|
)
|
|
lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule
|
|
|
|
|
|
def _alloc_value(
|
|
aval: jax_core.AbstractValue, *, ctx: LoweringRuleContext
|
|
) -> ir.Value:
|
|
if isinstance(aval, pallas_core.AbstractMemoryRef):
|
|
memspace = _memory_space_to_mosaic_attribute(aval.memory_space)
|
|
if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype):
|
|
assert aval.memory_space == TPUMemorySpace.SEMAPHORE
|
|
memref_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
aval,
|
|
memory_space=TPUMemorySpace.SEMAPHORE,
|
|
)
|
|
return tpu.sem_alloc(memref_type)
|
|
else:
|
|
out_type = ir.MemRefType.get(
|
|
aval.shape,
|
|
_dtype_to_ir_type(aval.dtype, is_kernel_boundary=True),
|
|
memory_space=memspace)
|
|
return memref.alloca(out_type, [], [])
|
|
elif isinstance(aval, tpu_core.AbstractSemaphore):
|
|
memref_type = aval_to_ir_type(
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
aval,
|
|
memory_space=TPUMemorySpace.SEMAPHORE,
|
|
)
|
|
return tpu.sem_alloc(memref_type)
|
|
raise NotImplementedError(f"Cannot allocate {type(aval)}.")
|
|
|
|
|
|
def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):
|
|
out_type = [
|
|
aval_to_ir_type(ctx.lowering_context.dynamic_shape_replacement_fn, aval)
|
|
for aval in ctx.avals_out
|
|
]
|
|
region = tpu.RegionOp(out_type)
|
|
in_avals = [v.aval for v in jaxpr.invars]
|
|
with ctx.lowering_context.grid_name_context():
|
|
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
|
|
with ir.InsertionPoint(region.body):
|
|
alloc_fn = functools.partial(_alloc_value, ctx=ctx)
|
|
args = map(alloc_fn, in_avals)
|
|
block_shapes = tuple(a.shape if isinstance(a, state.AbstractRef) else None
|
|
for a in in_avals)
|
|
ctx = ctx.lowering_context.replace(
|
|
block_shapes=(*ctx.block_shapes, *block_shapes)
|
|
)
|
|
out = jaxpr_subcomp(ctx, jaxpr, *consts, *args)
|
|
tpu.YieldOp(out)
|
|
return region.results
|
|
|
|
|
|
lowering_rules[primitives.run_scoped_p] = _run_scoped_lowering_rule
|
|
|
|
def _device_id_to_logical(
|
|
ctx: LoweringRuleContext, device_id,
|
|
device_id_type: tpu_primitives.DeviceIdType):
|
|
if device_id_type is tpu_primitives.DeviceIdType.MESH:
|
|
# Mesh means we are passed the mesh coordinates for the device
|
|
device_ids = tree_util.tree_leaves(device_id)
|
|
mesh_strides = ctx.lowering_context.mesh_context.mesh_strides
|
|
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
if len(device_ids) == 0:
|
|
return arith.constant(i32, 0)
|
|
return functools.reduce(
|
|
arith.addi,
|
|
(
|
|
arith.muli(a, arith.constant(i32, b))
|
|
for a, b in zip(device_ids, mesh_strides)
|
|
),
|
|
)
|
|
elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL:
|
|
return device_id
|
|
raise NotImplementedError(f"Unsupported device id type: {device_id_type}")
|
|
|
|
|
|
def _semaphore_read_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
):
  sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, transforms = tree_util.tree_unflatten(args_tree, args)
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
  return tpu.sem_read(sem)


lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule


def _semaphore_signal_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
    device_id_type: tpu_primitives.DeviceIdType,
):
  sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, transforms, value, device_id, core_index = tree_util.tree_unflatten(
      args_tree, args
  )
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
  if device_id is not None:
    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
  tpu.sem_signal(sem, value, device_id=device_id, core_id=core_index)
  return []


lowering_rules[tpu_primitives.semaphore_signal_p] = (
    _semaphore_signal_lowering_rule)


def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
  sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, transforms, value = tree_util.tree_unflatten(args_tree, args)
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
  tpu.sem_wait(sem, value)
  return []
lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule


def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                             device_id_type: tpu_primitives.DeviceIdType):
  (
      src_ref,
      src_transforms,
      dst_ref,
      dst_transforms,
      sem,
      sem_transforms,
      src_sem,
      src_sem_transforms,
      device_id,
  ) = tree_util.tree_unflatten(tree, args)
  (src_ref_aval, _, dst_ref_aval, _, sem_aval, _, src_sem_aval, _, _) = (
      tree_util.tree_unflatten(tree, ctx.avals_in)
  )
  if src_ref_aval.dtype == jnp.bool_:
    raise NotImplementedError("DMAs with bool dtypes are not supported.")
  block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
  src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2]
  src_ref, _ = _transform_ref(
      src_ref, src_ref_aval.dtype, src_ref_block_shape, src_transforms
  )
  if src_sem is not None:
    src_sem, _ = _transform_ref(
        src_sem, src_sem_aval.dtype, src_sem_aval.shape, src_sem_transforms
    )
  dst_ref, _ = _transform_ref(
      dst_ref, dst_ref_aval.dtype, dst_ref_block_shape, dst_transforms
  )
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
  if device_id is not None:
    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
  tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem,
                  device_id=device_id)

  return []
lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule


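# The flattened `tree` above always yields nine entries: source ref and
# transforms, destination ref and transforms, completion semaphore and
# transforms, an optional source semaphore and its transforms, and an optional
# device id. A non-None device id turns the copy into a remote DMA addressed
# to that device (after conversion to a logical id).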
def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                            device_id_type: tpu_primitives.DeviceIdType):
  del device_id_type
  (src, src_transforms, dst, transforms, sem, sem_transforms, _, _, _) = (
      tree_util.tree_unflatten(tree, args)
  )
  (src_aval, _, dst_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten(
      tree, ctx.avals_in
  )
  block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
  ref_block_shape = block_shapes[2]
  src, _ = _transform_ref(src, src_aval.dtype, src_aval.shape, src_transforms)
  dst, _ = _transform_ref(dst, dst_aval.dtype, ref_block_shape, transforms)
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
  if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 2, 12):
    # TODO(mvoz): Remove once six months have passed. b/395630795
    if hasattr(src_aval, "memory_space"):
      src_memory_space = _memory_space_to_mosaic_attribute(
          src_aval.memory_space)
      smem_space = ir.Attribute.parse("#tpu.memory_space<smem>")
      src_is_smem = src_memory_space == smem_space
      wait_ref = src if src_is_smem else dst
    else:
      wait_ref = dst
    # Legacy instruction backwards compatibility.
    tpu.wait_dma(sem, wait_ref)
  else:
    tpu.wait_dma2(sem, src, dst)
  return []

lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule


def _device_id_lowering_rule(ctx: LoweringRuleContext):
  return tpu.device_id()
lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule


def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
  grid_names = ctx.lowering_context.grid_names
  if grid_names and axis_name in grid_names:
    # We are querying a named axis corresponding to a grid dimension.
    return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
  # We are querying a named axis corresponding to a mesh dimension.
  device_id = tpu.device_id()
  mesh_context = ctx.lowering_context.mesh_context
  if mesh_context is None:
    raise ValueError("Mesh context is not set.")
  mesh_shape = mesh_context.mesh_shape
  axis_names = mesh_context.axis_names
  axis_index = axis_names.index(axis_name)
  axis_size = ir_constant(mesh_shape[axis_index])
  minor_divisor = ir_constant(
      np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32)
  )
  return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size)
lowering_rules[lax.axis_index_p] = _axis_index_rule


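# A worked example of the mesh branch above (illustrative shapes only): with
# mesh_shape (2, 4) and axis_names ("x", "y"), querying "x" uses
# minor_divisor = 4, so logical device ids 0..7 map to x-indices
# (id // 4) % 2 = 0, 0, 0, 0, 1, 1, 1, 1.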
def _get_barrier_semaphore_rule(ctx: LoweringRuleContext):
  memref_type = aval_to_ir_type(
      ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
  )
  return tpu.sem_barrier(memref_type)
lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule


def _delay_rule(ctx: LoweringRuleContext, nanos: int):
  tpu.delay(nanos)
  return []


lowering_rules[tpu_primitives.delay_p] = _delay_rule


def _debug_print_rule(
    ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
):
  is_scalar_inputs = [aval.shape == () for aval in ctx.avals_in]
  is_all_scalars = all(is_scalar_inputs)
  is_single_vector = len(is_scalar_inputs) == 1 and not is_scalar_inputs[0]
  if not (is_all_scalars or is_single_vector):
    raise ValueError(
        "All inputs to debug_print must be all scalars or a single vector, but"
        f" got {ctx.avals_in}"
    )

  # Scalar case.
  if is_all_scalars:
    primitives.check_debug_print_format(fmt, *args)
    if has_placeholders:
      if not all(
          isinstance(arg.type, ir.IntegerType) and arg.type.width == 32
          for arg in args
      ):
        raise TypeError(
            "All arguments must be 32-bit integers when using"
            " placeholders (`{...}`). If you need to print values of other types,"
            " remove placeholders from the format string."
        )

      # TPU expects $0, $1 etc as placeholders.
      fmt = "".join(
          f"{text}${idx}"
          for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
      )

    tpu.log(args, fmt, formatted=has_placeholders)
    return ()

  # Vector case.
  # Copy the array to vmem for logging.
  # Note that the shape of the array must be explicitly provided here. This is
  # because the underlying implementation aligns shapes to tile boundaries,
  # potentially altering the original shape and making it unrecoverable.
  if len(ctx.avals_in) != 1:
    raise ValueError(
        "Only one vector input to debug_print is supported."
    )
  (aval,) = ctx.avals_in
  (arg,) = args

  if not has_placeholders or not fmt.endswith("{}"):
    raise ValueError("For vector input, the format string must end with {}.")

  fmt = fmt[:-2]

  region = tpu.RegionOp(())
  with ir.InsertionPoint(region.body):
    element_type = _dtype_to_ir_type(aval.dtype)
    ref_type = ir.MemRefType.get(
        aval.shape,
        element_type,
        memory_space=ir.Attribute.parse("#tpu.memory_space<vmem>"),
    )
    ref = memref.alloca(ref_type, [], [])

    index_type = ir.IndexType.get()
    zero = arith.constant(index_type, 0)
    indices = [zero] * len(aval.shape)
    vector.store(arg, ref, indices)
    tpu.log_buffer(ref, aval.shape, fmt)
    tpu.yield_([])
  return ()


lowering_rules[primitives.debug_print_p] = _debug_print_rule


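# Example of the placeholder rewrite above: a format string such as
# "x: {}, y: {}" is parsed with string.Formatter and re-emitted as
# "x: $0, y: $1", the placeholder syntax expected by tpu.log.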
def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds):
  del ctx
  # In the KeyScalarBundle case we unpack the bundle and set the seed with
  # the list of scalars.
  if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle):
    tpu.prng_set_seed_32(seeds[0].scalars)
    return []
  # For integer seeds, we can set the seed directly as PRNGSeed32Op natively
  # takes in a list of integers as input.
  all_integers = all(isinstance(seed.type, ir.IntegerType) for seed in seeds)
  if not all_integers:
    seed_types = [seed.type for seed in seeds]
    raise ValueError(f"All seed data must be scalar integers. Got {seed_types}")
  tpu.prng_set_seed_32(seeds)
  return []
lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule


def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape):
  if len(shape) <= 1:
    # TODO(b/342054464): Support implicit dims for PRNGRandomBitsOp.
    raise NotImplementedError("random_bits only supports rank>=2 outputs.")
  out_aval = ctx.avals_out[0]
  out_type = aval_to_ir_type(
      ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
  )
  return tpu.prng_random_bits(out_type)
lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule


def random_seed_lowering(ctx, seeds, *, impl):
  seed_lowering = lower_fun(impl.seed, multiple_results=False)
  return seed_lowering(ctx, seeds)
lowering_rules[prng.random_seed_p] = random_seed_lowering


def random_bits_lowering(ctx, keys, *, bit_width, shape):
  assert bit_width == 32, "Only 32-bit PRNG supported."
  aval, = ctx.avals_in
  impl = aval.dtype._impl
  _proxy_fn = impl.random_bits
  if not pl_random.is_pallas_impl(impl):
    def new_lowering(key, bit_width, shape):
      key = jax.random.key_data(key).astype(jnp.uint32)
      return impl.random_bits(key, bit_width, shape)
    _proxy_fn = new_lowering
  bits_lowering = lower_fun(_proxy_fn, multiple_results=False)
  return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape)
lowering_rules[prng.random_bits_p] = random_bits_lowering


def random_fold_in_lowering(ctx, keys, msgs):
  keys_aval, _ = ctx.avals_in
  impl = keys_aval.dtype._impl
  fold_in_lowering = lower_fun(impl.fold_in, multiple_results=False)
  return fold_in_lowering(ctx, keys, msgs)
lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering


def random_unwrap_lowering(ctx, key):
  keys_aval = ctx.avals_in[0]
  impl = keys_aval.dtype._impl
  if not pl_random.is_pallas_impl(impl):
    return key
  assert isinstance(key, KeyScalarBundle)
  # Convert to a vector.
  if tuple(key.key_shape) != (1, 1):
    raise NotImplementedError(
        "Seed key_data of shape != (1, 1) not supported. "
        f"Got: {key.key_shape}")
  scalar = key.scalars[0]
  out_type = ir.VectorType.get(
      key.key_shape, _dtype_to_ir_type(jnp.dtype('int32'))
  )
  val = vector.broadcast(out_type, scalar)
  return val
lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering


def random_wrap_lowering(ctx, key_data, *, impl):
  del ctx
  if not pl_random.is_pallas_impl(impl):
    return key_data
  if isinstance(key_data.type, ir.VectorType):
    # If the key data lives in vregs, need to unpack it to sregs.
    key_data_list = []
    key_data_shape = key_data.type.shape
    if len(key_data_shape) != 2:
      raise NotImplementedError("Seed key_data must be 2D.")
    if tuple(key_data_shape) != (1, 1):
      raise NotImplementedError(
          "Seed key_data of shape != (1, 1) not supported. "
          f"Got: {key_data_shape}")
    for i in range(key_data_shape[1]):
      key_data_list.append(vector.ExtractOp(key_data, [], [0, i]))
    return KeyScalarBundle(
        scalars=key_data_list, key_shape=tuple(key_data_shape))
  if isinstance(key_data, KeyScalarBundle):
    return key_data
  else:
    raise NotImplementedError(f"key_data wrap {type(key_data)}")

lowering_rules[prng.random_wrap_p] = random_wrap_lowering


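# The two rules above are inverses for the Pallas key implementations:
# random_wrap unpacks a (1, 1) int32 vector of key data into a KeyScalarBundle
# of scalar values, and random_unwrap broadcasts the bundle's single scalar
# back into a (1, 1) int32 vector.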
def _checkify_lowering_rule(
    ctx: LoweringRuleContext, *err_args, err_tree, debug):
  if not tpu_core.runtime_assert_enabled():
    if debug:
      return []
    else:
      raise LoweringException("Non-debug check must be functionalized. "
                              "Enable runtime asserts with "
                              "--jax_pallas_enable_runtime_assert "
                              "or functionalize with checkify.check.")

  assert ctx.lowering_context.ir_context.allow_unregistered_dialects, (
      "allow_unregistered_dialects must be set to True for "
      "runtime assert check.")
  error = jax.tree.unflatten(err_tree, err_args)
  assert len(error._pred) == 1
  assert len(error._metadata) == 1
  assert len(error._payload) == 1
  pred = list(error._pred.items())[0][1]
  metadata = list(error._metadata.items())[0]
  payload = list(error._payload.items())[0][1]
  exception_tree = metadata[1]
  exception = jax.tree.unflatten(exception_tree, payload)
  assert isinstance(exception, checkify.FailedCheckError)

  # check_p has an inverted predicate compared to assert,
  # so we need to compute not(pred) here.
  out_scalar_type = _dtype_to_ir_type(jnp.dtype('bool'))
  minus_one = ir_constant(-1, out_scalar_type)
  not_pred = arith.xori(pred, minus_one)
  attrs = {"msg": ir.StringAttr.get(exception.fmt_string)}
  ir.Operation.create("cf.assert",
                      operands=(not_pred,),
                      attributes=attrs)
  return []
lowering_rules[checkify.check_p] = _checkify_lowering_rule


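# In the rule above, xor-ing the boolean predicate with an all-ones constant
# (-1) computes not(pred), so the (unregistered) cf.assert op aborts exactly
# when the stored error predicate reports a failed check.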
def _threefry2x32_lowering(ctx, k1, k2, m1, m2):
  def _lower_fun(k1, k2, m1, m2):
    with jax.named_scope("threefry2x32"):
      res = prng._threefry2x32_lowering(k1, k2, m1, m2, use_rolled_loops=False)
    return res

  threefry_lowering = lower_fun(_lower_fun, multiple_results=True)
  return threefry_lowering(ctx, k1, k2, m1, m2)


lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering


def _iota_2x32_shape_lowering(ctx, *, shape):
  total_elements = np.prod(shape)
  if total_elements > np.iinfo(jnp.int32).max:
    raise NotImplementedError(f"Iota with >{np.iinfo(jnp.int32).max} items.")

  def _lower_fun(shape):
    iota_data = jnp.zeros(shape, dtype=jnp.int32)
    multiplier = 1
    for dim in range(len(shape)-1, -1, -1):
      counts_lo = lax.broadcasted_iota(
          dtype=jnp.int32, shape=shape, dimension=dim
      )
      iota_data += counts_lo * multiplier
      multiplier *= shape[dim]
    counts_hi = jnp.zeros(shape, dtype=jnp.int32)
    return counts_hi, iota_data

  iota_lowering = lower_fun(_lower_fun, multiple_results=True)
  return iota_lowering(ctx, shape=shape)


lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering


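# The helper above returns the (hi, lo) halves of a 64-bit iota as two int32
# arrays. For example, for shape (2, 3) the lo half is the row-major count
# [[0, 1, 2], [3, 4, 5]] and the hi half is all zeros, which is why shapes
# with more than 2**31 - 1 elements are rejected up front.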
def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
  operand, padding_value = args
  padding_config = kwargs["padding_config"]

  out_type: ir.VectorType = aval_to_ir_type(
      ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_in[0]
  )
  if not isinstance(out_type, ir.VectorType):
    raise NotImplementedError("Only vector types are supported.")

  for axis, (low, high, interior) in enumerate(padding_config):
    if low == 0 and high == 0 and interior == 0:
      continue

    def _pad(val):
      shape = list(operand.type.shape)
      shape[axis] = val
      pad_vec_type = ir.VectorType.get(
          shape,
          operand.type.element_type,
      )

      if isinstance(padding_value, ir.OpResult):
        pad = vector.broadcast(pad_vec_type, padding_value)
      else:
        scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value)
        pad = arith.ConstantOp(
            pad_vec_type,
            ir.DenseElementsAttr.get_splat(
                pad_vec_type,
                scalar_attr,
            ),
        ).result
      return pad

    if low != 0:
      pad_low = _pad(low)
      new_shape = out_type.shape
      new_shape[axis] += low
      out_type = ir.VectorType.get(
          new_shape,
          out_type.element_type,
      )
      operand = tpu.concatenate(out_type, [pad_low, operand], dimension=axis)

    if high != 0:
      pad_high = _pad(high)
      new_shape = out_type.shape
      new_shape[axis] += high
      out_type = ir.VectorType.get(
          new_shape,
          out_type.element_type,
      )
      operand = tpu.concatenate(out_type, [operand, pad_high], dimension=axis)

    if interior > 0:
      raise NotImplementedError("Not implemented: interior padding")

  return operand


lowering_rules[lax.pad_p] = _pad_lowering_rule


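# Example of the rule above (illustrative shapes): for an (8, 128) f32 operand
# with padding_config ((1, 0, 0), (0, 4, 0)), the lowering concatenates a
# (1, 128) splat of the padding value in front along axis 0 and then a (9, 4)
# splat after it along axis 1, producing a (9, 132) result. Interior
# (dilation) padding is not supported and raises NotImplementedError.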
def _platform_index_lowering(
    ctx: mlir.LoweringRuleContext,
    *,
    platforms: Sequence[Sequence[str]],
    has_default: bool,
):
  for i, ps in enumerate(platforms):
    # Note: `platforms` is a Sequence[Sequence[str]], so each entry may name
    # several platforms; we check membership in each inner sequence.
    if "mosaic" in ps:
      return ir_constant(i)

  if has_default:
    return ir_constant(len(platforms))

  raise NotImplementedError(
      "No mosaic or default platform indexing rule found."
  )


lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering