# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for lowering JAX to Mosaic-compatible MLIR dialects."""

from __future__ import annotations
from collections.abc import Callable, Sequence
import dataclasses
import functools
import string
from typing import Any

import jax
from jax import core as jax_core
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import prng
from jax._src import source_info_util
from jax._src import state
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import math
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import core as pl_core
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
from jax._src.util import safe_map
from jax._src.util import safe_zip
from jax._src.util import split_list
from jax._src.util import unzip2
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from jaxlib.mlir.ir import Module
import numpy as np

# TODO(sharadmv): enable type checking
# mypy: ignore-errors

NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM

# The value interpreted as a dynamic dimension by MLIR.
MLIR_DYNAMIC = -9223372036854775808

partial = functools.partial
map, unsafe_map = safe_map, map  # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip  # pylint: disable=redefined-builtin

UNSIGNED_TO_SIGNED = {
    np.dtype('uint8'): np.dtype('int8'),
    np.dtype('uint16'): np.dtype('int16'),
    np.dtype('uint32'): np.dtype('int32'),
    np.dtype('uint64'): np.dtype('int64'),
}


@dataclasses.dataclass
class MeshContext:
  mesh_shape: tuple[int, ...]
  axis_names: tuple[str, ...]
  mesh_strides: tuple[int, ...]


@dataclasses.dataclass
class LoweringContext:
  ir_context: ir.Context
  grid_rank: int  # Includes both user and vmap axes.
  mapped_dims: tuple[int, ...]  # Indices of vmapped grid dimensions.
  user_grid_indices: Sequence[ir.Value] | None
  block_shapes: list[tuple[int | pl_core.Mapped, ...]]
  name_stack: source_info_util.NameStack
  mesh_context: MeshContext | None
  replace = dataclasses.replace
  traceback_caches: mlir.TracebackCaches


@dataclasses.dataclass
class LoweringRuleContext:
  lowering_context: LoweringContext
  avals_in: Sequence[jax_core.AbstractValue]
  avals_out: Sequence[jax_core.AbstractValue]
  block_shapes: list[tuple[int | pl_core.Mapped, ...]] | None

  replace = dataclasses.replace


def _memory_space_to_tpu_memspace(memory_space: TPUMemorySpace | None
                                  ) -> ir.Attribute:
  if memory_space is None:
    memory_space = VMEM
  return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>")


def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type:
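  """Returns the MLIR type Mosaic uses for the given JAX dtype.

  Semaphore dtypes map to the corresponding ``!tpu`` dialect types; all other
  dtypes defer to ``mlir.dtype_to_ir_type``, with integer types rewritten as
  signless (Mosaic currently treats unsigned integers as signed).
  """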
  if jnp.issubdtype(dtype, tpu_core.semaphore_dtype):
    if jnp.issubdtype(dtype, tpu_core.dma_semaphore):
      return ir.Type.parse("!tpu.dma_semaphore")
    elif jnp.issubdtype(dtype, tpu_core.semaphore):
      return ir.Type.parse("!tpu.semaphore")
    elif jnp.issubdtype(dtype, tpu_core.barrier_semaphore):
      return ir.Type.parse("!tpu.semaphore")
    else:
      raise NotImplementedError
  # TODO(justinfu): Remove after mosaic supports unsigned types.
  # This conversion makes mosaic interpret all unsigned types as signed types.
  type = mlir.dtype_to_ir_type(dtype)
  if isinstance(type, ir.IntegerType):
    return ir.IntegerType.get_signless(type.width)
  else:
    return type


def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None):
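  """Converts a JAX abstract value to the MLIR type Mosaic expects.

  Semaphores become 0-d memrefs of ``!tpu`` semaphore types, PRNG-key and Ref
  avals become memrefs in the requested (or default) memory space, and plain
  ``ShapedArray``s become vectors (or a scalar type when the shape is empty).
  """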
  if isinstance(aval, tpu_core.AbstractSemaphore):
    if aval.sem_type is tpu_core.SemaphoreType.DMA:
      sem_type = ir.Type.parse("!tpu.dma_semaphore")
    elif aval.sem_type is tpu_core.SemaphoreType.REGULAR:
      sem_type = ir.Type.parse("!tpu.semaphore")
    elif aval.sem_type is tpu_core.SemaphoreType.BARRIER:
      sem_type = ir.Type.parse("!tpu.semaphore")
    else:
      raise ValueError(f"Cannot allocate {aval.sem_type}.")
    memspace = _memory_space_to_tpu_memspace(TPUMemorySpace.SEMAPHORE)
    return ir.MemRefType.get((), sem_type, memory_space=memspace)
  if dtypes.issubdtype(aval.dtype, dtypes.prng_key):
    shape = aval.dtype._impl.key_shape
    if memory_space is None:
      memory_space = TPUMemorySpace.SMEM
    if memory_space != TPUMemorySpace.SMEM:
      raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}")
    memspace = _memory_space_to_tpu_memspace(memory_space)
    return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)),
                             memory_space=memspace)
  if isinstance(aval, state.AbstractRef):
    if shape is None:
      shape = aval.shape
    memspace = _memory_space_to_tpu_memspace(memory_space)
    return ir.MemRefType.get(shape, _dtype_to_ir_type(aval.dtype),
                             memory_space=memspace)
  if isinstance(aval, jax_core.ShapedArray):
    if shape is None:
      shape = aval.shape
    if not shape:
      return _dtype_to_ir_type(aval.dtype)
    return ir.VectorType.get(shape, _dtype_to_ir_type(aval.dtype))
  raise NotImplementedError(aval)


def ir_constant(x, mlir_type=None):
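  """Materializes a Python or NumPy scalar as an MLIR ``arith.constant``."""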
  if not hasattr(x, "dtype"):
    if isinstance(x, int):
      x = np.array(x, np.int32)
    elif isinstance(x, float):
      x = np.array(x, np.float32)
  if not mlir_type:
    mlir_type = _dtype_to_ir_type(x.dtype)
  if isinstance(x, int) or x.dtype in (np.int32, np.uint32, np.int8):
    return arith.ConstantOp(mlir_type, ir.IntegerAttr.get(mlir_type, int(x))
                            ).result
  elif isinstance(x, float) or x.dtype == np.float32:
    return arith.ConstantOp(
        mlir_type, ir.FloatAttr.get(mlir_type, float(x))
    ).result
  elif x.dtype == jnp.bfloat16:
    return arith.ConstantOp(
        mlir_type, ir.FloatAttr.get(mlir_type, float(x))
    ).result
  elif x.dtype == jnp.bool_:
    return arith.ConstantOp(
        mlir_type, ir.BoolAttr.get(bool(x))
    ).result
  raise NotImplementedError(x.dtype)


lowering_rules = {}
skip_mlir_conversions = set()


def _get_arg_type(
    aval,
    block_mapping: pl_core.BlockMapping | None,
):
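  """Returns the MLIR argument type and block shape for a kernel operand aval."""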
  memory_space = None
  if isinstance(aval, pl_core.AbstractMemoryRef):
    memory_space = aval.memory_space
    # We assume unannotated memory refs are in VMEM
    if memory_space is None:
      memory_space = TPUMemorySpace.VMEM
  if isinstance(aval, tpu_core.AbstractSemaphore):
    return aval_to_ir_type(aval), None
  if block_mapping is None:
    return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
  shape = tuple(1 if b is pl_core.mapped else b for b in block_mapping.block_shape)
  return (
      aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
      block_mapping.block_shape,
  )


@dataclasses.dataclass(init=False)
class MosaicGridMapping:
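  """Precomputed grid, block-shape and MLIR type information for lowering.

  Bundles everything derived from a ``GridMapping`` (argument types, block
  shapes, dimension semantics and optional mesh info) that the lowering
  functions below need when building the Mosaic module.
  """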
  grid: tuple[int, ...] | None
  jaxpr: jax_core.Jaxpr
  block_mappings: tuple[pl_core.BlockMapping | None, ...]
  mapped_dims: tuple[int, ...]
  scalar_prefetch_types: tuple[ir.Type, ...]
  operand_types: tuple[ir.Type, ...]
  scratch_types: tuple[ir.Type, ...]
  grid_types: tuple[ir.Type, ...]
  scalar_prefetch_block_shapes: tuple[tuple[int, ...], ...]
  operand_block_shapes: tuple[tuple[int, ...], ...]
  scratch_block_shapes: tuple[tuple[int, ...], ...]
  mesh_info: MeshInfo | None
  get_grid_indices: Callable | None

  def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping,
               dimension_semantics: tuple[str, ...] | None,
               mesh: mesh_lib.Mesh | None):
    self.grid = grid_mapping.grid
    self.jaxpr = jaxpr
    self.block_mappings = grid_mapping.block_mappings
    self.mapped_dims = grid_mapping.mapped_dims
    num_scalar_prefetch = grid_mapping.num_index_operands
    num_scratch = grid_mapping.num_scratch_operands
    # jaxpr has signature [*scalar_prefetch, *in_ops, *out_ops, *scratch]
    num_operands = (
        len(self.jaxpr.invars)
        - num_scalar_prefetch
        - num_scratch
    )
    user_grid = tuple(
        g for i, g in enumerate(self.grid) if i not in self.mapped_dims
    )
    if dimension_semantics is None:
      dimension_semantics = ("arbitrary",) * len(user_grid)
    if len(user_grid) != len(dimension_semantics):
      raise ValueError(
          "Must have dimension semantics for each dimension of the grid."
      )
    if num_operands != len(self.block_mappings):
      raise ValueError("Must have block mappings for each operand.")
    assert len(self.mapped_dims) + len(dimension_semantics) == len(
        self.grid
    ), (
        f"Misconfigured grid: {self.mapped_dims=}, {dimension_semantics=},"
        f" {self.grid=}"
    )
    # dimension_semantics is user provided and won't take into account vmap
    # dimensions. Here we add in parallel dimensions for the vmaps.
    semantics_iter = iter(dimension_semantics)
    self._dimension_semantics = tuple(
        next(semantics_iter) if i not in self.mapped_dims else "parallel"
        for i in range(len(self.grid))
    )

    in_avals = [invar.aval for invar in self.jaxpr.invars]
    scalar_prefetch_avals, operand_avals, scratch_avals = split_list(
        in_avals, [num_scalar_prefetch, num_operands]
    )
    self.scalar_prefetch_types, _ = unzip2([
        _get_arg_type(aval, None)
        for aval in scalar_prefetch_avals])
    self.scalar_prefetch_block_shapes = tuple(
        aval.shape for aval in scalar_prefetch_avals)
    self.operand_types, self.operand_block_shapes = unzip2([
        _get_arg_type(aval, block_mapping)
        for aval, block_mapping in zip(operand_avals, self.block_mappings)])
    self.scratch_types, _ = unzip2([
        _get_arg_type(aval, None) for aval in scratch_avals])
    self.scratch_block_shapes = tuple(
        aval.shape if not isinstance(aval, tpu_core.AbstractSemaphore) else None
        for aval in scratch_avals
    )
    self.grid_types, _ = unzip2([
        _get_arg_type(jax_core.ShapedArray((), jnp.int32), None)
        for _ in range(len(self.grid))
    ])
    self._prepare_mesh_info(mesh)

    def _get_grid_indices(indices):
      return indices
    self.get_grid_indices = _get_grid_indices

  def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
    if not self.has_communication:
      self.mesh_info = None
      return
    if mesh is None:
      raise ValueError(
          "Cannot use communication in pallas_call without shard_map."
      )
    axis_names = mesh.axis_names
    # We need mesh <-> logical translation tables. Since the logical IDs are
    # just linearized versions of the mesh IDs, we create those tables.
    mesh_strides = pallas_utils.strides_from_shape(tuple(
        mesh.shape[a] for a in axis_names
    ))
    self.mesh_info = MeshInfo(mesh.device_ids.shape, axis_names, mesh_strides)

  def maybe_compress_grid(self):
    # If we have many leading parallel dimensions, we should "compress" them
    # into one so we can load balance across cores as best as we can.
    # TODO(sharadmv): implement this optimization
    pass

  @functools.cached_property
  def has_communication(self) -> bool:
    return bool(jax_core.used_axis_names_jaxpr(self.jaxpr))

  def get_extra_args(self) -> tuple[Any, ...]:
    return ()

  def get_dimension_semantics(self) -> ir.ArrayAttr:

    def _get_semantics(s: str | None) -> str:
      if s is None:
        return "#tpu.dimension_semantics<arbitrary>"
      return f"#tpu.dimension_semantics<{s}>"

    return ir.ArrayAttr.get(
        map(
            ir.Attribute.parse,
            map(_get_semantics, self._dimension_semantics),
        )
    )


@dataclasses.dataclass
class MeshInfo:
  mesh_shape: tuple[int, ...]
  axis_names: list[str]
  mesh_strides: tuple[int, ...]


def lower_jaxpr_to_module(
    ctx: ir.Context,
    grid_mapping: pl_core.GridMapping,
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    jaxpr: jax_core.Jaxpr,
    dimension_semantics: tuple[str | None, ...] | None,
    mesh: mesh_lib.Mesh | None = None
) -> tuple[Module, tuple[Any, ...]]:
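  """Lowers a Pallas kernel jaxpr to a Mosaic MLIR module.

  Builds the ``main`` kernel function, one ``transform_<i>`` index-map
  function per block mapping, and the window/iteration-bound attributes that
  Mosaic reads off the module.
  """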
  mosaic_grid_mapping = MosaicGridMapping(
      jaxpr, grid_mapping, dimension_semantics, mesh)
  mosaic_grid_mapping.maybe_compress_grid()
  m = ir.Module.create()
  sym_tab = ir.SymbolTable(m.operation)
  func_op = lower_jaxpr_to_func(ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
                                name="main")
  m.body.append(func_op)
  sym_tab.insert(func_op)
  window_params = []
  grid = mosaic_grid_mapping.grid
  if grid:
    invars = jaxpr.invars
    if grid_mapping.num_scratch_operands > 0:
      invars = invars[
          grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
    else:
      invars = invars[grid_mapping.num_index_operands:]
    avals = tuple(v.aval for v in invars)
    block_operand_shapes = (
        *in_shapes[grid_mapping.num_index_operands :],
        *out_shapes,
    )
    assert len(block_operand_shapes) == len(grid_mapping.block_mappings)
    for i, (full_ty, bm, aval) in enumerate(
        zip(block_operand_shapes, grid_mapping.block_mappings, avals)
    ):
      func_name = f"transform_{i}"
      if bm is None:
        raise NotImplementedError(
            "BlockSpecs are required on TPU when grid is specified"
        )
      if bm.index_map_jaxpr.consts:
        raise NotImplementedError("Index map jaxpr with consts not supported.")
      # ANY operands don't support windowing and require empty window_params.
      if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
        # We may not require windowing if our block_shape matches the original
        # shape or the dimensions are mapped.
        requires_windowing = any(
            b != s
            for b, s in zip(bm.block_shape, full_ty.shape)
            if not (b is pl_core.mapped and s == 1)
        )
        if np.prod(grid) != 1:
          for atom in bm.index_map_jaxpr.jaxpr.outvars:
            if requires_windowing:
              break
            requires_windowing = not (
                isinstance(atom, jax_core.Literal) and atom.val == 0
            )
        if requires_windowing:
          raise NotImplementedError(
              "Operands placed in the TPUMemorySpace.ANY memory space don't"
              " support windowing (i.e. non-trivial block_shape or index_map)."
          )
        window_params.append(ir.DictAttr.get())
        continue
      mlir_func = lower_jaxpr_to_transform_func(
          ctx,
          bm.index_map_jaxpr.jaxpr,
          name=func_name,
          mosaic_grid_mapping=mosaic_grid_mapping,
      )
      assert mlir_func.verify(), mlir_func
      block_shape = [
          1 if b is pl_core.mapped else b for b in bm.block_shape
      ]
      window_shape = ir.DenseI64ArrayAttr.get(block_shape)
      block_params = dict(
          window_bounds=window_shape,
          transform_indices=ir.FlatSymbolRefAttr.get(func_name),
      )
      if isinstance(bm.indexing_mode, pl_core.Unblocked):
        if bm.indexing_mode.padding is None:
          pad_low = pad_high = [0] * len(bm.block_shape)
        else:
          pad_low, pad_high = map(list, zip(*bm.indexing_mode.padding))
        block_params["window_kind"] = ir.Attribute.parse(
            f"#tpu.element_window<{pad_low},{pad_high}>"
        )
      window_params.append(ir.DictAttr.get(block_params))
      m.body.append(mlir_func)
      sym_tab.insert(mlir_func)
    func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
    static_grid = [
        MLIR_DYNAMIC if b is pl_core.dynamic_grid_dim else b for b in grid
    ]
    func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)

  func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
  func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scratch_types))
  func_op.attributes["dimension_semantics"] = (
      mosaic_grid_mapping.get_dimension_semantics()
  )
  return m, mosaic_grid_mapping.get_extra_args()


def lower_jaxpr_to_transform_func(
    ctx: ir.Context,
    jaxpr: jax_core.Jaxpr,
    *,
    name: str,
    mosaic_grid_mapping: MosaicGridMapping,
) -> func.FuncOp:
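  """Lowers an index-map jaxpr to a ``func.FuncOp`` taking grid indices and scalar prefetch arguments."""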
  num_grid = len(mosaic_grid_mapping.grid_types)
  arg_types = [
      *mosaic_grid_mapping.grid_types,
      *mosaic_grid_mapping.scalar_prefetch_types,
  ]
  def body_func(*args):
    grid_indices, scalar_prefetch = split_list(args, [num_grid])
    jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices)
    arg_block_shapes = [
        *[()] * len(jaxpr_indices),
        *mosaic_grid_mapping.scalar_prefetch_block_shapes,
    ]

    mesh_info = mosaic_grid_mapping.mesh_info
    if mesh_info is not None:
      mesh_context = MeshContext(
          mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
      )
    else:
      mesh_context = None
    lowering_context = LoweringContext(
        ctx,
        len(mosaic_grid_mapping.grid),
        mosaic_grid_mapping.mapped_dims,
        None,
        arg_block_shapes,
        source_info_util.NameStack(),
        mesh_context=mesh_context,
        traceback_caches=mlir.TracebackCaches(),
    )
    return jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
                         *scalar_prefetch)
  body_func.__name__ = name
  body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  try:
    body.func_op.verify()
  except Exception as e:
    raise LoweringException(
        f"Body failed to verify: {body.func_op}.\nThis is an internal error."
        " Please report a bug at:"
        " https://github.com/google/jax/issues/new?assignees=sharadmv."
    ) from e
  return body.func_op


def lower_jaxpr_to_func(
    ctx: ir.Context,
    jaxpr: jax_core.Jaxpr,
    *,
    mosaic_grid_mapping: MosaicGridMapping,
    name: str,
) -> func.FuncOp:
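  """Lowers the kernel jaxpr to a ``func.FuncOp`` taking grid indices, scalar prefetch, operands and scratch."""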
  num_grid = len(mosaic_grid_mapping.grid_types)
  num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
  arg_types = [
      *mosaic_grid_mapping.grid_types,
      *mosaic_grid_mapping.scalar_prefetch_types,
      *mosaic_grid_mapping.operand_types,
      *mosaic_grid_mapping.scratch_types,
  ]
  arg_block_shapes = [
      *mosaic_grid_mapping.scalar_prefetch_block_shapes,
      *mosaic_grid_mapping.operand_block_shapes,
      *mosaic_grid_mapping.scratch_block_shapes,
  ]
  def body_func(*args):
    grid_indices, scalar_prefetch, operands_and_scratch = split_list(
        args, [num_grid, num_scalar_prefetch])
    grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices)
    jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices)
                          if i not in mosaic_grid_mapping.mapped_dims)
    mesh_info = mosaic_grid_mapping.mesh_info
    if mesh_info is not None:
      mesh_context = MeshContext(
          mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
      )
    else:
      mesh_context = None
    lowering_context = LoweringContext(
        ctx,
        len(mosaic_grid_mapping.grid),
        mosaic_grid_mapping.mapped_dims,
        jaxpr_indices,
        arg_block_shapes,
        source_info_util.NameStack(),
        mesh_context=mesh_context,
        traceback_caches=mlir.TracebackCaches(),
    )
    return jaxpr_subcomp(
        lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch
    )
  body_func.__name__ = name
  body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  try:
    body.func_op.verify()
  except Exception as e:
    raise LoweringException(
        f"Body failed to verify: {body.func_op}.\nThis is an internal error."
        " Please report a bug at:"
        " https://github.com/google/jax/issues/new?assignees=sharadmv."
    ) from e
  return body.func_op


def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
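  """Turns a traceable JAX function into a Pallas TPU lowering rule.

  The returned callable traces ``fun`` to a jaxpr at lowering time and lowers
  it with ``jaxpr_subcomp``, so primitives without a hand-written rule can
  reuse an existing JAX implementation.

  Example (illustrative registration; ``some_primitive`` and
  ``some_reference_impl`` are placeholders, not names defined here)::

    lowering_rules[some_primitive] = lower_fun(
        some_reference_impl, multiple_results=False)
  """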
  def f_lowered(ctx: LoweringRuleContext, *args, **params):
    f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
    wrapped_fun = lu.wrap_init(f, params)
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
    if consts:
      raise NotImplementedError
    jaxpr = pe.convert_constvars_jaxpr(jaxpr)
    lowering_context = ctx.lowering_context.replace(
        block_shapes=ctx.block_shapes)
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
    if not multiple_results:
      return out[0]
    return out

  return f_lowered


class LoweringException(Exception):
  pass


def _compute_name_stack_updates(
    old_name_stack: list[str],
    new_name_stack: list[str]
) -> tuple[list[str], list[str]]:
  """Computes the popped/pushed items to the name stack after an update.

  Args:
    old_name_stack: The name stack prior to the update.
    new_name_stack: The name stack after the update.

  Returns:
    popped: A list of names popped from the name stack as part of the update.
    pushed: A list of names pushed to the name stack as part of the update.
  """
  common_prefix_idx = 0
  for i, (old, new) in enumerate(unsafe_zip(old_name_stack, new_name_stack)):
    if old == new:
      common_prefix_idx = i+1
    else:
      break
  return old_name_stack[common_prefix_idx:], new_name_stack[common_prefix_idx:]


def jaxpr_subcomp(
    ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
) -> Sequence[ir.Value]:
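  """Lowers each equation of ``jaxpr`` using the registered lowering rules.

  Also inserts TPU trace start/stop ops at named-scope boundaries so profiler
  traces mirror the JAX name stack.
  """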
  assert not jaxpr.constvars
  env = {}
  block_shape_env = {}

  def read_block_shape(atom: jax_core.Atom):
    if isinstance(atom, jax_core.Literal):
      return None
    return block_shape_env.get(atom, None)

  def read_env(atom: jax_core.Atom):
    return atom.val if isinstance(atom, jax_core.Literal) else env[atom]

  def write_env(var: jax_core.Var, val):
    is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle))
    assert is_valid_type, type(val)
    env[var] = val

  for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
    block_shape_env[invar] = bs
  map(write_env, jaxpr.invars, args)

  initial_name_stack = [scope.name for scope in ctx.name_stack.stack]
  current_name_stack: list[str] = []
  # TODO(justinfu): Handle transform scopes.
  current_name_stack.extend(initial_name_stack)
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    source_info = eqn.source_info.replace(
        name_stack=ctx.name_stack + eqn.source_info.name_stack
    )
    loc = mlir._source_info_to_location(
        ctx, eqn.primitive, eqn.params, source_info
    )
    with source_info_util.user_context(eqn.source_info.traceback), loc:
      if eqn.primitive in lowering_rules:
        if eqn.primitive not in skip_mlir_conversions:
          invals = [_ensure_mlir_value(x, v.aval)
                    for x, v in zip(invals, eqn.invars)]
        block_shapes = map(read_block_shape, eqn.invars)
        rule_context = LoweringRuleContext(
            ctx,
            [v.aval for v in eqn.invars],
            [v.aval for v in eqn.outvars],
            block_shapes,
        )

        # Insert trace_start and trace_stop ops on named_scope boundaries.
        name_stack = [scope.name for scope in source_info.name_stack.stack]
        popped, pushed = _compute_name_stack_updates(
            current_name_stack, name_stack)
        current_name_stack = name_stack
        for _ in popped:
          tpu.TraceStopOp()
        for name in pushed:
          tpu.TraceStartOp(message=name, level=10)

        try:
          ans = lowering_rules[eqn.primitive](
              rule_context, *invals, **eqn.params
          )
        except LoweringException:
          raise  # We only add the extra info to the innermost exception.
        except Exception as e:
          raise LoweringException(
              f"Exception while lowering eqn:\n  {eqn}\nWith context:\n "
              f" {rule_context}\nWith inval"
              f" shapes={map(lambda t: getattr(t, 'shape', None), invals)}\nWith"
              " inval"
              f" types={map(lambda t: getattr(t, 'type', None), invals)}\nIn"
              f" jaxpr:\n{jaxpr}"
              f"\nException: {e}"
          ) from e
      else:
        raise NotImplementedError(
            "Unimplemented primitive in Pallas TPU lowering: "
            f"{eqn.primitive.name}. "
            "Please file an issue on https://github.com/google/jax/issues.")
      if eqn.primitive.multiple_results:
        map(write_env, eqn.outvars, ans)
      else:
        write_env(eqn.outvars[0], ans)

  # Drain the name stack at the end of a jaxpr and insert trace_stop ops.
  popped, pushed = _compute_name_stack_updates(
      current_name_stack, initial_name_stack)
  for _ in popped:
    tpu.TraceStopOp()
  assert len(pushed) == 0

  outvals = map(read_env, jaxpr.outvars)
  outvals = [
      ir_constant(x) if isinstance(var, jax_core.Literal) else x
      for x, var in zip(outvals, jaxpr.outvars)
  ]
  return outvals


def _ensure_mlir_value(val, aval):
  if isinstance(val, ir.Value):
    return val
  if isinstance(val, KeyScalarBundle):
    return val
  elif isinstance(val, (np.generic, np.ndarray, int, float)):
    return ir_constant(val, _dtype_to_ir_type(aval.dtype))
  else:
    raise RuntimeError(
        f"Unsupported argument to a JAX primitive of type: {type(val)}"
    )


def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,
                                      non_slice_idx_avals, indexed_dims):
  non_slice_idx_iter = iter(zip(non_slice_idx, non_slice_idx_avals))
  splatted_idx_idx_avals = tuple(
      next(non_slice_idx_iter)
      if indexed
      else (primitives.Slice(0, s), primitives.Slice(0, s))
      for s, indexed in zip(ref_aval.shape, indexed_dims)
  )
  splatted_idx, splatted_idx_avals = unzip2(splatted_idx_idx_avals)
  if non_slice_idx:
    (int_indexer_shape,) = {idx_aval.shape for idx_aval in splatted_idx_avals
                            if not isinstance(idx_aval, primitives.Slice)}
  else:
    int_indexer_shape = ()
  nd_indexer = NDIndexer(splatted_idx, ref_aval.shape, int_indexer_shape)
  nd_indexer_avals = NDIndexer(splatted_idx_avals, ref_aval.shape,
                               int_indexer_shape)
  return nd_indexer, nd_indexer_avals


def _get_lowering_rule(
    ctx: LoweringRuleContext, ref, *idx, tree,
):
  indexers = tree_util.tree_unflatten(tree, idx)
  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
  # Call _load_lowering_rule (since it's more general)
  ref_aval, *_ = ctx.avals_in
  args_flat, args_tree = tree_util.tree_flatten((ref, indexers, None, None))
  avals_flat = tree_util.tree_leaves((ref_aval, indexers_avals, None, None))
  ctx = ctx.replace(
      avals_in=avals_flat,
      block_shapes=[ctx.block_shapes[0], *[None] * (len(avals_flat) - 1)],
  )
  return _load_lowering_rule(ctx, *args_flat, args_tree=args_tree)


lowering_rules[state_primitives.get_p] = _get_lowering_rule
skip_mlir_conversions.add(state_primitives.get_p)


def _swap_lowering_rule(
    ctx: LoweringRuleContext,
    ref,
    val,
    *idx,
    tree
):
  indexers = tree_util.tree_unflatten(tree, idx)
  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:])
  # Call _masked_swap_lowering_rule (since it's more general)
  ref_aval, val_aval, *_ = ctx.avals_in
  args_flat, args_tree = tree_util.tree_flatten((ref, indexers, val, None))
  avals_flat = tree_util.tree_leaves(
      (ref_aval, indexers_avals, val_aval, None)
  )
  ctx = ctx.replace(
      avals_in=avals_flat,
      block_shapes=[ctx.block_shapes[0], *[None] * (len(avals_flat) - 1)],
  )
  return _masked_swap_lowering_rule(ctx, *args_flat, args_tree=args_tree)


lowering_rules[state_primitives.swap_p] = _swap_lowering_rule
skip_mlir_conversions.add(state_primitives.swap_p)


def _make_index(s):
  if isinstance(s, (int, np.ndarray)):
    return ir_constant(s, ir.IndexType.get())
  if s.type == ir.IndexType.get():
    return s
  return arith.IndexCastOp(ir.IndexType.get(), s).result


def _maybe_cast_to_index(cast_to_index, x):
  if cast_to_index:
    return _make_index(x)
  return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32))


def _index_to_start_size_stride(
    idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
) -> tuple[ir.Value, int | ir.Value, int, bool]:
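  """Splits a single index into (start, size, stride, squeeze) components."""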
  assert not isinstance(idx, slice)
  if isinstance(idx, indexing.Slice):
    start = _maybe_cast_to_index(cast_to_index, idx.start)
    size = idx.size
    stride = idx.stride
    squeeze = False
  elif isinstance(idx, int):
    start = _maybe_cast_to_index(cast_to_index, idx)
    size = 1
    stride = 1
    squeeze = True
  else:
    if np.shape(idx):
      raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
    start = _maybe_cast_to_index(cast_to_index, idx)
    size = 1
    stride = 1
    squeeze = True
  return start, size, stride, squeeze


def _indexer_to_start_size_stride(
    indexer: NDIndexer,
    ref_block_shape: tuple[int | pl_core.Mapped, ...],
    *,
    cast_to_index: bool,
) -> tuple[
    tuple[ir.Value, ...],
    tuple[int | ir.Value, ...],
    tuple[int, ...],
    tuple[bool, ...],
    tuple[int | pl_core.Mapped, ...],
]:
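  """Expands an NDIndexer over a Ref block shape.

  Returns per-dimension starts, sizes, strides and squeeze flags, plus the
  block shape that remains after squeezing out integer-indexed dimensions.
  """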
  indices_iter = iter(indexer.indices)
  starts, sizes, strides, squeeze_dims = [], [], [], []
  for s in ref_block_shape:
    start, size, stride, squeeze_dim = (
        (
            _maybe_cast_to_index(cast_to_index, 0),
            1,
            1,
            True,
        )
        if s is pl_core.mapped
        else _index_to_start_size_stride(next(indices_iter), cast_to_index)
    )
    starts.append(start)
    sizes.append(size)
    strides.append(stride)
    squeeze_dims.append(squeeze_dim)
  next_index = next(indices_iter, None)
  assert next_index is None, (indexer.indices, ref_block_shape)
  new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
                              if not squeeze)
  return (
      tuple(starts),
      tuple(sizes),
      tuple(strides),
      tuple(squeeze_dims),
      new_ref_block_shape,
  )


def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
                  indexer: NDIndexer,
                  ref_block_shape: tuple[int | pl_core.Mapped, ...]
                  ) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...]]:
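  """Slices ``ref`` according to ``indexer``, squeezing integer-indexed dims.

  Returns the sliced memref and the updated Ref block shape.
  """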
  assert ref_block_shape is not None
  target_shape = indexer.get_indexer_shape()
  starts, sizes, strides, squeeze_dims, ref_block_shape = (
      _indexer_to_start_size_stride(
          indexer,
          ref_block_shape,
          cast_to_index=False,
      )
  )
  if not all((s is None or s == 1) for s in strides):
    raise NotImplementedError("Strided slices of references are unsupported.")
  dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value))
  ir_dynamic_size = ir.ShapedType.get_dynamic_size()
  static_sizes = tuple(s if not isinstance(s, ir.Value)
                       else ir_dynamic_size for s in sizes)
  target_ref_ty = ir.MemRefType.get(
      static_sizes, _dtype_to_ir_type(ref_aval.dtype),
      memory_space=ref.type.memory_space)
  out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result
  if any(squeeze_dims):
    # We need to squeeze out some dimensions
    static_sizes = tuple(s if not isinstance(s, ir.Value)
                         else ir_dynamic_size for s in target_shape)
    squeezed_ref_ty = ir.MemRefType.get(
        static_sizes, _dtype_to_ir_type(ref_aval.dtype),
        memory_space=ref.type.memory_space)
    out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result
  return out, ref_block_shape


def _index_ref(ref, ref_aval, ref_block_shape, indexers):
  for indexer in indexers:
    ref, ref_block_shape = _slice_memref(ref, ref_aval, indexer,
                                         ref_block_shape)
  return ref, ref_block_shape


@dataclasses.dataclass(frozen=True)
class KeyScalarBundle:
  """A container class for PRNG key data.

  We pass around keys as a KeyScalarBundle in the lowering pass rather than
  as a vector, since we want the key data to live in scalar registers rather
  than vector registers. This special dataclass exists so we can return
  multiple scalar values from load_op, because the load_op primitive does
  not allow multiple results.

  Attributes:
    scalars: A list of OpResults representing scalar key data during the
      lowering pass.
  """
  scalars: list[ir.OpResult]


def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
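  """Lowering rule for ``primitives.load_p``.

  Handles scalar loads from SMEM, (optionally strided) vector loads, and PRNG
  key loads, which are delegated to ``_prng_key_load_lowering_rule``.
  """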
|
2024-01-02 21:53:30 -08:00
|
|
|
ref, indexers, mask, _ = args_tree.unflatten(args_flat)
|
|
|
|
ref_aval, indexers_avals, _, _ = args_tree.unflatten(ctx.avals_in)
|
|
|
|
(*slice_indexers, idx) = indexers
|
|
|
|
# Select last aval, which is the one that will be used for the load.
|
|
|
|
(*_, idx_aval) = indexers_avals
|
2023-10-10 14:38:54 -07:00
|
|
|
|
|
|
|
if mask is not None:
|
|
|
|
raise NotImplementedError
|
|
|
|
|
2024-01-02 21:53:30 -08:00
|
|
|
ref_block_shape, *_ = ctx.block_shapes
|
[Pallas TPU] Add support for dynamic sized (tile aligned) DMAs
This change adds limited support for dynamic sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. This slice can be used in a view that is used in a DMA. For example:
```python
def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
size = size_smem_ref[0]
pltpu.async_copy(
x_hbm_ref.at[pl.ds(0, size)],
o_hbm_ref.at[pl.ds(0, size)], sem).wait()
```
We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA.
We're augmenting the Slice class to support dynamic sizes, which are then plumbed down into indexing primitives like pl.load, and pltpu.async_copy.
However, values and Refs in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we won't touch any abstract evaluation rules (since they will still all consume and return ShapedArrays). We can, however, do dynamically
sized DMAs between statically shaped Refs. While this isn't arbitrary dynamic shapes, we hope this enables some new interesting kernels.
PiperOrigin-RevId: 618322737
2024-03-22 16:58:45 -07:00
|
|
|
ref, ref_block_shape = _index_ref(
|
2024-01-11 06:32:57 -08:00
|
|
|
ref, ref_aval, ref_block_shape, slice_indexers)
|
2023-08-01 16:42:26 -07:00
|
|
|
ref_type = ir.MemRefType(ref.type)
|
|
|
|
is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
|
|
|
|
ref_aval, *_ = ctx.avals_in
|
|
|
|
(aval_out,) = ctx.avals_out
|
2024-06-12 14:36:31 -07:00
|
|
|
if isinstance(aval_out.dtype, prng.KeyTy):
|
|
|
|
if not is_smem_load:
|
|
|
|
raise ValueError("PRNG keys must be loaded from SMEM. Did you set "
|
|
|
|
"the memory space to TPUMemorySpace.SMEM in the "
|
|
|
|
"BlockSpec for the PRNG key input?")
|
|
|
|
return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree)
|
2023-09-13 15:32:01 -07:00
|
|
|
if not is_smem_load and not ref_block_shape:
|
2023-08-10 21:18:10 -07:00
|
|
|
raise NotImplementedError(
|
|
|
|
"Indexing into a ()-shaped Ref not yet supported on TPU.")
|
2023-08-01 16:42:26 -07:00
|
|
|
if any(
|
2023-10-10 14:38:54 -07:00
|
|
|
(not isinstance(a, primitives.Slice) and a.shape)
|
2023-08-01 16:42:26 -07:00
|
|
|
for a in idx_aval.indices
|
|
|
|
):
|
|
|
|
raise ValueError("Cannot do int indexing on TPU")
|
2024-03-14 16:31:23 -07:00
|
|
|
starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
|
|
|
|
idx,
|
|
|
|
ref_block_shape,
|
|
|
|
cast_to_index=True,
|
2023-08-01 16:42:26 -07:00
|
|
|
)
|
2024-03-14 16:31:23 -07:00
|
|
|
need_stride = not all((s is None or s == 1) for s in strides)
|
2024-01-02 21:53:30 -08:00
|
|
|
load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype)
|
2023-08-01 16:42:26 -07:00
|
|
|
if is_smem_load:
|
|
|
|
if ctx.avals_out[0].shape:
|
2023-09-07 04:25:44 -07:00
|
|
|
raise ValueError("Can only load scalars from SMEM")
|
2024-01-02 21:53:30 -08:00
|
|
|
return memref.LoadOp(ref, starts).result
|
2024-03-14 16:31:23 -07:00
|
|
|
if need_stride:
|
|
|
|
load_val = tpu.StridedLoadOp(
|
|
|
|
aval_to_ir_type(load_aval), ref, starts, strides
|
|
|
|
).result
|
2023-08-01 16:42:26 -07:00
|
|
|
else:
|
2024-01-02 21:53:30 -08:00
|
|
|
load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, starts).result
|
2023-08-01 16:42:26 -07:00
|
|
|
if load_aval == aval_out:
|
|
|
|
return load_val
|
2023-08-10 21:18:10 -07:00
|
|
|
vec_type = ir.VectorType.get(aval_out.shape,
|
2024-01-11 06:32:57 -08:00
|
|
|
_dtype_to_ir_type(aval_out.dtype))
|
2023-08-10 21:18:10 -07:00
|
|
|
return vector.ShapeCastOp(vec_type, load_val).result
|
2023-08-01 16:42:26 -07:00
|
|
|
|
2024-06-12 14:36:31 -07:00
|
|
|
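# For example, a key input whose BlockSpec memory space is TPUMemorySpace.SMEM
# reaches the rule below through _load_lowering_rule and comes back as a
# KeyScalarBundle wrapping scalar memref.LoadOp results rather than as a
# single vector value.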
def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle:
  """Lowering rule for loading PRNG keys from SMEM.

  PRNG key loads are currently lowered as a list of scalar loads from SMEM,
  rather than a single vector load.
  We store these scalars in a bundle type called KeyScalarBundle, which has
  special case handling for functions that consume the key such as set_seed.
  """
  ref, _, _, _ = args_tree.unflatten(args_flat)
  (aval_out,) = ctx.avals_out
  assert isinstance(aval_out.dtype, prng.KeyTy)
  ref_block_shape = aval_out.dtype._impl.key_shape

  if len(ref_block_shape) != 2:
    raise NotImplementedError("Seed key_data must be 2D.")
  if tuple(ref_block_shape) != (1, 1):
    raise NotImplementedError(
        f"Seed key_data of shape != (1, 1) not supported. Got: {ref_block_shape}")

  load_ops = []
  for i in range(ref_block_shape[0]):
    idx = NDIndexer(indices=(0, i), shape=ref_block_shape,
                    int_indexer_shape=tuple())
    starts, _, _, _, _ = _indexer_to_start_size_stride(
        idx,
        ref_block_shape,
        cast_to_index=True,
    )
    load_ops.append(memref.LoadOp(ref, starts).result)
  return KeyScalarBundle(scalars=load_ops)


lowering_rules[primitives.load_p] = _load_lowering_rule
skip_mlir_conversions.add(primitives.load_p)

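# swap_p is lowered as a read-modify-write: the rule below loads the current
# contents (the value returned to the caller) and then stores the new value,
# using memref ops for SMEM scalars and vector/strided ops otherwise.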
def _masked_swap_lowering_rule(
    ctx: LoweringRuleContext, *args_flat, args_tree, **_
):
  ref, indexers, val, mask = args_tree.unflatten(args_flat)
  ref_aval, indexers_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in)
  (*slice_indexers, idx) = indexers
  (*_, idx_aval) = indexers_avals

  if mask is not None:
    raise NotImplementedError

  ref_block_shape, *_ = ctx.block_shapes
  ref, ref_block_shape = _index_ref(
      ref, ref_aval, ref_block_shape, slice_indexers)

  ref_type = ir.MemRefType(ref.type)
  is_smem_store = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
  (aval_out,) = ctx.avals_out
  if not isinstance(val, ir.Value):
    val = ir_constant(val, mlir_type=_dtype_to_ir_type(val_aval.dtype))
  if any(
      (not isinstance(a, primitives.Slice) and a.shape)
      for a in idx_aval.indices
  ):
    raise ValueError("Cannot do int indexing on TPU")
  if not is_smem_store and not ref_block_shape:
    raise NotImplementedError(
        "Indexing into a ()-shaped Ref not yet supported on TPU.")

  starts, _, strides, _, _ = _indexer_to_start_size_stride(
      idx,
      ref_block_shape,
      cast_to_index=True,
  )
  need_stride = not all((s is None or s == 1) for s in strides)
  if is_smem_store:
    if val_aval.shape:
      raise ValueError("Can only store scalars to SMEM")
    result = memref.LoadOp(ref, starts).result
    memref.StoreOp(val, ref, starts)
    return result
  mem_slice_shape = list(aval_out.shape)
  for i, a in enumerate(idx_aval.indices):
    if not isinstance(a, primitives.Slice):
      mem_slice_shape.insert(i, 1)
  mem_slice_shape_iter = iter(mem_slice_shape)
  mem_slice_shape = [
      1 if b is pl_core.mapped else next(mem_slice_shape_iter)
      for b in ref_block_shape
  ]
  mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
  mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
                                        _dtype_to_ir_type(mem_aval.dtype))
  if need_stride:
    result = tpu.StridedLoadOp(mem_aval_vec_type, ref, starts, strides).result
  else:
    result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
  if mem_aval != aval_out:
    # We are slicing a scalar, so provide dummy 1 indices.
    result_vec_type = ir.VectorType.get(aval_out.shape,
                                        _dtype_to_ir_type(aval_out.dtype))
    result = vector.ShapeCastOp(result_vec_type, result).result
    val_vec_type = ir.VectorType.get(mem_aval.shape,
                                     _dtype_to_ir_type(mem_aval.dtype))
    val = vector.ShapeCastOp(val_vec_type, val).result
  if need_stride:
    tpu.StridedStoreOp(val, ref, starts, strides)
  else:
    vector.StoreOp(val, ref, starts)
  return result


lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
skip_mlir_conversions.add(primitives.swap_p)

def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values):
  del ctx
  for multiple in values:
    val = tpu.assume_multiple(val, multiple)
  return val


lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule

def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  (x_aval,) = ctx.avals_in
  if not ctx.avals_out[0].shape:
    raise NotImplementedError(
        "Cannot lower reductions to scalar. Reduce to one element vector"
        " instead, using keepdims=True."
    )

  out_type = aval_to_ir_type(ctx.avals_out[0])
  if jnp.issubdtype(x_aval.dtype, jnp.floating):
    kind = vector.CombiningKind.MAXIMUMF
    val = ir.FloatAttr.get(ir.F32Type.get(), float("-inf"))
    identity = ir.DenseElementsAttr.get_splat(out_type, val)
  elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
    kind = ir.Attribute.parse("#vector.kind<maxsi>")
    raise NotImplementedError
  elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
    kind = ir.Attribute.parse("#vector.kind<maxui>")
    raise NotImplementedError
  acc = arith.ConstantOp(out_type, identity)
  op = vector.MultiDimReductionOp(
      kind,
      x,
      acc,
      ir.ArrayAttr.get(
          [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
      ),
  )
  return op.result


lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule

def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  (x_aval,) = ctx.avals_in
  if not ctx.avals_out[0].shape:
    raise NotImplementedError(
        "Cannot lower reductions to scalar. Reduce to one element vector"
        " instead, using keepdims=True."
    )

  out_type = aval_to_ir_type(ctx.avals_out[0])
  if jnp.issubdtype(x_aval.dtype, jnp.floating):
    kind = ir.Attribute.parse("#vector.kind<add>")
    val = ir.FloatAttr.get(ir.F32Type.get(), 0.0)
    identity = ir.DenseElementsAttr.get_splat(out_type, val)
  elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
    kind = ir.Attribute.parse("#vector.kind<add>")
    raise NotImplementedError
  elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
    kind = ir.Attribute.parse("#vector.kind<add>")
    raise NotImplementedError
  acc = arith.ConstantOp(out_type, identity)
  op = vector.MultiDimReductionOp(
      kind,
      x,
      acc,
      ir.ArrayAttr.get(
          [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
      ),
  )
  return op.result


lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule

def _broadcast_in_dim_lowering_rule(
    ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions
):
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out
  if broadcast_dimensions:
    out_shape_list = [1] * len(shape)
    for i, s in zip(broadcast_dimensions, aval_in.shape):
      out_shape_list[i] = s
    out_shape = tuple(out_shape_list)
    out_type = ir.VectorType.get(
        out_shape, _dtype_to_ir_type(aval_out.dtype)
    )
    val = vector.ShapeCastOp(out_type, val).result
    if out_shape == aval_out.shape:
      return val
  out_type = ir.VectorType.get(
      aval_out.shape, _dtype_to_ir_type(aval_out.dtype)
  )
  return vector.BroadcastOp(out_type, val).result


lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule

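# For example, a plain jnp.dot of two 2D operands has
# dimension_numbers (((1,), (0,)), ((), ())), so lhs_dims == (1,) and
# rhs_dims == (0,) below, which selects transpose_lhs=False and
# transpose_rhs=False for the emitted tpu.MatmulOp.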
def _dot_general_lowering_rule(
    ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
):
  (lhs_dims, rhs_dims), _ = dimension_numbers
  (aval_out,) = ctx.avals_out
  out_type = aval_to_ir_type(aval_out)
  val_type = out_type.element_type
  if any(
      cls.isinstance(val_type)
      for cls in [
          ir.BF16Type,
          ir.F32Type,
          ir.Float8E5M2Type,
          ir.Float8E4M3FNType,
      ]
  ):
    val = ir.FloatAttr.get(val_type, 0.0)
  elif ir.IntegerType.isinstance(val_type):
    val = ir.IntegerAttr.get(val_type, 0)
  else:
    raise NotImplementedError(ctx.avals_out[0].dtype)
  if any(len(a.shape) != 2 for a in ctx.avals_in):
    raise NotImplementedError(
        f"Only 2D tensors supported in dot; received: {ctx.avals_in}"
    )
  lhs_aval, _ = ctx.avals_in
  # This is really a matrix-vector product. It only looks like matrix-matrix.
  if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1:
    if ctx.avals_in[0].shape != ctx.avals_in[1].shape:
      bcast_shape = jnp.broadcast_shapes(
          ctx.avals_in[0].shape, ctx.avals_out[0].shape
      )
      bcast_shape = ir.VectorType.get(
          list(bcast_shape), _dtype_to_ir_type(ctx.avals_out[0].dtype)
      )
      if ctx.avals_in[0].shape != bcast_shape:
        x = vector.BroadcastOp(bcast_shape, x)
      if ctx.avals_in[1].shape != bcast_shape:
        y = vector.BroadcastOp(bcast_shape, y)
    red_type = aval_to_ir_type(lhs_aval.update(shape=(lhs_aval.shape[0],)))
    acc = arith.ConstantOp(
        red_type, ir.DenseElementsAttr.get_splat(red_type, val)
    )
    red = vector.MultiDimReductionOp(
        ir.Attribute.parse("#vector.kind<add>"),
        arith.MulFOp(x, y),
        acc,
        ir.ArrayAttr.get(
            [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 1)]
        ),
    )
    return vector.ShapeCastOp(out_type, red).result

  if lhs_dims == (1,):
    transpose_lhs = False
  elif lhs_dims == (0,):
    transpose_lhs = True
  else:
    raise NotImplementedError
  if rhs_dims == (0,):
    transpose_rhs = False
  elif rhs_dims == (1,):
    transpose_rhs = True
  else:
    raise NotImplementedError
  if precision is not None:
    if precision[0] != precision[1]:
      raise NotImplementedError("Per-operand dot precision unsupported")
    precision = precision[0]
  if precision is None or precision == lax.Precision.DEFAULT:
    precision_attr = None  # That's the default in Mosaic.
  elif precision == lax.Precision.HIGHEST:
    precision_attr = ir.Attribute.parse(
        "#tpu.contract_precision<fp32>"
    )
  else:
    raise NotImplementedError(f"Unsupported dot precision: {precision}")
  out_tile = arith.ConstantOp(
      out_type, ir.DenseElementsAttr.get_splat(out_type, val)
  )
  op = tpu.MatmulOp(
      out_type, x, y, out_tile,
      transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs,
      precision=precision_attr
  )
  return op.result


lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule

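# _convert_helper is the jax-level fallback used by the convert_element_type
# rule further below for casts with no direct MLIR op. For example, a
# float32 -> int8 cast first clips to [iinfo(int8).min, iinfo(int8).max] to
# match XLA, casts to int32, and only then narrows to int8.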
def _convert_helper(x, *, to_dtype):
  # Helper function for dtype conversion
  from_dtype = x.dtype
  if jnp.issubdtype(from_dtype, jnp.dtype("bool")):
    x = x.astype(jnp.int32)
    return _convert_helper(x, to_dtype=to_dtype)
  if jnp.issubdtype(from_dtype, jnp.signedinteger):
    if from_dtype.itemsize < 4:
      x = x.astype(jnp.int32)
    if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4:
      x = x.astype(jnp.float32)
    return x.astype(to_dtype)
  if jnp.issubdtype(from_dtype, jnp.floating):
    if jnp.issubdtype(to_dtype, jnp.signedinteger):
      if from_dtype.itemsize < 4:
        x = x.astype(jnp.float32)
      if to_dtype.itemsize < 4:
        # Need to clip values to match XLA
        minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max
        x = jnp.clip(x, minval, maxval)
        return x.astype(jnp.int32).astype(to_dtype)
      return x.astype(to_dtype)
    elif jnp.issubdtype(to_dtype, np.dtype("bool")):
      x = x.astype(jnp.int32)
      return x.astype(jnp.float32)
  raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}")

def _convert_element_type_lowering_rule(
    ctx: LoweringRuleContext, x, *, new_dtype, weak_type
):
  del weak_type
  out_aval = ctx.avals_out[0]
  old_dtype = ctx.avals_in[0].dtype
  out_type = aval_to_ir_type(out_aval)

  # TODO(justinfu): Remove after mosaic supports unsigned types.
  # This conversion makes mosaic interpret all unsigned types as signed types.
  if np.issubdtype(new_dtype, jnp.unsignedinteger):
    new_dtype = UNSIGNED_TO_SIGNED[new_dtype]
  if old_dtype == new_dtype:
    return x
  if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype(
      new_dtype, jnp.floating
  ):
    if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
      return arith.ExtFOp(out_type, x).result
    elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
      return arith.TruncFOp(out_type, x).result
  elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype(
      new_dtype, jnp.signedinteger
  ):
    if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
      return arith.ExtSIOp(out_type, x).result
    elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
      return arith.TruncIOp(out_type, x).result
  elif jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype(
      new_dtype, jnp.signedinteger
  ) and old_dtype.itemsize == new_dtype.itemsize == 4:
    return arith.FPToSIOp(out_type, x).result
  elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype(
      new_dtype, jnp.floating
  ) and old_dtype.itemsize == new_dtype.itemsize == 4:
    return arith.SIToFPOp(out_type, x).result
  elif (
      old_dtype == jnp.bool_
      and jnp.issubdtype(new_dtype, jnp.integer)
      and new_dtype.itemsize == 4
  ):
    return arith.extui(out_type, x)
  elif (
      jnp.issubdtype(old_dtype, jnp.integer)
      and new_dtype == jnp.bool_
      and old_dtype.itemsize == 4
  ):
    return arith.TruncIOp(out_type, x).result
  return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
                   multiple_results=False)(ctx, x)


lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule

def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions):
  if dimensions is not None:
    raise NotImplementedError
  if any(d is None for d in new_sizes):
    raise NotImplementedError
  if not ctx.avals_in[0].shape:
    return vector.BroadcastOp(aval_to_ir_type(ctx.avals_out[0]), x).result
  return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result


lowering_rules[lax.reshape_p] = _reshape_lowering_rule


def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
  del dimensions  # Unused.
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out
  if not aval_out.shape:
    return vector.ExtractOp(x, [], [0] * len(aval_in.shape)).result
  return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result


lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule


def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
  return tpu.ConcatenateOp(
      aval_to_ir_type(ctx.avals_out[0]), xs, dimension=dimension
  ).result


lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule


def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension):
  out_type = aval_to_ir_type(ctx.avals_out[0])
  return tpu.IotaOp(out_type, dimension=dimension).result


lowering_rules[lax.iota_p] = _iota_lowering_rule


def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
  if permutation != (1, 0):
    raise NotImplementedError
  out_type = aval_to_ir_type(ctx.avals_out[0])
  return vector.TransposeOp(out_type, x, permutation).result


lowering_rules[lax.transpose_p] = _transpose_lowering_rule

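# _bcast normalizes binary-op operands: Python/NumPy scalars are materialized
# with ir_constant (weak types take the other operand's dtype), and any
# operand whose aval shape differs from the output shape is broadcast with
# vector.BroadcastOp. For example, adding the literal 1.0 to an (8, 128)
# float32 vector materializes the scalar as an IR constant and broadcasts it
# to (8, 128).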
def _bcast(x, y, x_aval, y_aval, out_aval):
  x_dtype = x_aval.dtype
  y_dtype = y_aval.dtype
  if y_aval.weak_type:
    y_dtype = x_aval.dtype
  elif x_aval.weak_type:
    x_dtype = y_aval.dtype
  if isinstance(x, (np.ndarray, np.number, int, float)):
    if getattr(y, "type", None) == ir.IndexType.get():
      mlir_type = y.type
    else:
      mlir_type = _dtype_to_ir_type(x_dtype)
    x = ir_constant(x, mlir_type)
  if isinstance(y, (np.ndarray, np.number, int, float)):
    if getattr(x, "type", None) == ir.IndexType.get():
      mlir_type = x.type
    else:
      mlir_type = _dtype_to_ir_type(y_dtype)
    y = ir_constant(y, mlir_type)
  out_shape = list(out_aval.shape)
  if x_aval.shape != out_aval.shape:
    x_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(x_dtype))
    x = vector.BroadcastOp(x_ty, x)
  if y_aval.shape != out_aval.shape:
    y_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(y_dtype))
    y = vector.BroadcastOp(y_ty, y)
  return x, y

def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  if jnp.issubdtype(aval_out.dtype, jnp.integer):
    return arith.AddIOp(x, y).result
  if jnp.issubdtype(aval_out.dtype, jnp.floating):
    return arith.AddFOp(x, y).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.add_p] = _add_lowering_rule
skip_mlir_conversions.add(lax.add_p)
lowering_rules[ad_util.add_any_p] = _add_lowering_rule
skip_mlir_conversions.add(ad_util.add_any_p)


def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
    return arith.MaxSIOp(x, y).result
  elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
    return arith.MaxUIOp(x, y).result
  elif jnp.issubdtype(aval_out.dtype, jnp.floating):
    return arith.MaximumFOp(x, y).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.max_p] = _max_lowering_rule
skip_mlir_conversions.add(lax.max_p)


def _min_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
    return arith.MinSIOp(x, y).result
  elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
    return arith.MinUIOp(x, y).result
  elif jnp.issubdtype(aval_out.dtype, jnp.floating):
    return arith.MinimumFOp(x, y).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.min_p] = _min_lowering_rule
skip_mlir_conversions.add(lax.min_p)


def _sub_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  if jnp.issubdtype(aval_out.dtype, jnp.integer):
    return arith.SubIOp(x, y).result
  if jnp.issubdtype(aval_out.dtype, jnp.floating):
    return arith.SubFOp(x, y).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.sub_p] = _sub_lowering_rule
skip_mlir_conversions.add(lax.sub_p)


def _mul_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  if jnp.issubdtype(aval_out.dtype, jnp.integer):
    return arith.MulIOp(x, y).result
  if jnp.issubdtype(aval_out.dtype, jnp.floating):
    return arith.MulFOp(x, y).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.mul_p] = _mul_lowering_rule
skip_mlir_conversions.add(lax.mul_p)


def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  if jnp.issubdtype(aval_out.dtype, jnp.integer):
    return arith.DivSIOp(x, y).result
  if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
    return arith.DivUIOp(x, y).result
  elif jnp.issubdtype(aval_out.dtype, jnp.floating):
    return arith.DivFOp(x, y).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.div_p] = _div_lowering_rule
skip_mlir_conversions.add(lax.div_p)


def _rem_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  if jnp.issubdtype(aval_out.dtype, jnp.integer):
    return arith.RemSIOp(x, y).result
  if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
    return arith.RemUIOp(x, y).result
  elif jnp.issubdtype(aval_out.dtype, jnp.floating):
    return arith.RemFOp(x, y).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.rem_p] = _rem_lowering_rule
skip_mlir_conversions.add(lax.rem_p)


def _abs_lowering_rule(ctx: LoweringRuleContext, x):
  (aval_out,) = ctx.avals_out
  if jnp.issubdtype(aval_out.dtype, jnp.integer):
    return math.AbsIOp(x).result
  if jnp.issubdtype(aval_out.dtype, jnp.floating):
    return math.AbsFOp(x).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.abs_p] = _abs_lowering_rule


def _neg_lowering_rule(ctx: LoweringRuleContext, x):
  (x_aval,) = ctx.avals_in
  new_ctx = ctx.replace(
      avals_in=(jax_core.ShapedArray((), x_aval.dtype), x_aval),
      block_shapes=((), *ctx.block_shapes)
  )
  return _sub_lowering_rule(new_ctx, np.array(0, dtype=x_aval.dtype), x)


lowering_rules[lax.neg_p] = _neg_lowering_rule
skip_mlir_conversions.add(lax.neg_p)


def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
  return math.RsqrtOp(x).result


lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule


def _sqrt_lowering_rule(ctx: LoweringRuleContext, x):
  return math.SqrtOp(x).result


lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule

def _exp_lowering_rule(ctx: LoweringRuleContext, x):
  return math.ExpOp(x).result


lowering_rules[lax.exp_p] = _exp_lowering_rule


def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
  if not isinstance(x, ir.Value) and x == 2.:
    return math.Exp2Op(y).result
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  return math.PowFOp(x, y).result


lowering_rules[lax.pow_p] = _pow_lowering_rule
skip_mlir_conversions.add(lax.pow_p)


def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
  return lower_fun(lax_internal._integer_pow, multiple_results=False)(
      ctx, x, y=y)


lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule


def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
  # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
  # here.
  return lower_fun(lambda x: jnp.exp(np.log(2) * x), multiple_results=False)(
      ctx, x)


lowering_rules[lax.exp2_p] = _exp2_lowering_rule
skip_mlir_conversions.add(lax.exp2_p)

def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
  neg_x = arith.NegFOp(x).result
  exp_neg_x = math.ExpOp(neg_x).result
  aval_out = ctx.avals_out[0]
  out_type = aval_to_ir_type(aval_out)
  if aval_out.shape == ():
    one = ir_constant(1.0, mlir_type=out_type)
  else:
    one = vector.BroadcastOp(out_type, ir_constant(1.0))
  denom = arith.AddFOp(one, exp_neg_x).result
  return arith.DivFOp(one, denom).result


lowering_rules[lax.logistic_p] = _logistic_lowering_rule


def _sin_lowering_rule(ctx: LoweringRuleContext, x):
  return math.SinOp(x).result


lowering_rules[lax.sin_p] = _sin_lowering_rule


def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
  return math.TanhOp(x).result


lowering_rules[lax.tanh_p] = _tanh_lowering_rule


def _log_lowering_rule(ctx: LoweringRuleContext, x):
  return math.LogOp(x).result


lowering_rules[lax.log_p] = _log_lowering_rule


def _log1p_lowering_rule(ctx: LoweringRuleContext, x):
  return math.Log1pOp(x).result


lowering_rules[lax.log1p_p] = _log1p_lowering_rule


def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method):
  if rounding_method == 0:
    return math.RoundOp(x).result
  elif rounding_method == 1:
    return math.RoundEvenOp(x).result
  else:
    raise NotImplementedError(f"Unsupported rounding method: {rounding_method}")


lowering_rules[lax.round_p] = _round_lowering_rule

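# For example, lax.lt_p maps to predicate 2 ("slt") in the arith.cmpi table
# and to predicate 4 ("olt") in the arith.cmpf table defined below.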
# See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithcmpi-arithcmpiop for
# the mapping from comparison type to integer predicates for int comparisons.
_cmpi_lowering_types = {
    lax.eq_p: 0,
    lax.ne_p: 1,
    lax.lt_p: 2,
    lax.le_p: 3,
    lax.gt_p: 4,
    lax.ge_p: 5,
}

# See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithcmpf-arithcmpfop for
# the mapping from comparison type to integer predicate for float comparisons.
_cmpf_lowering_types = {
    lax.eq_p: 1,
    lax.ne_p: 6,
    lax.lt_p: 4,
    lax.le_p: 5,
    lax.gt_p: 2,
    lax.ge_p: 3,
}


def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  x_aval, y_aval = ctx.avals_in
  dtypes = x_aval.dtype, y_aval.dtype
  if all(jnp.issubdtype(dtype, jnp.integer) for dtype in dtypes):
    pred = _cmpi_lowering_types[prim]
    predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
    return arith.CmpIOp(predicate, x, y).result
  elif all(jnp.issubdtype(dtype, jnp.floating) for dtype in dtypes):
    pred = _cmpf_lowering_types[prim]
    predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
    return arith.CmpFOp(predicate, x, y).result
  raise NotImplementedError("Mixed dtype operands in cmp")


lowering_rules[lax.eq_p] = functools.partial(_cmp_lowering_rule, lax.eq_p)
lowering_rules[lax.ne_p] = functools.partial(_cmp_lowering_rule, lax.ne_p)
lowering_rules[lax.lt_p] = functools.partial(_cmp_lowering_rule, lax.lt_p)
lowering_rules[lax.le_p] = functools.partial(_cmp_lowering_rule, lax.le_p)
lowering_rules[lax.gt_p] = functools.partial(_cmp_lowering_rule, lax.gt_p)
lowering_rules[lax.ge_p] = functools.partial(_cmp_lowering_rule, lax.ge_p)

def _and_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return arith.AndIOp(x, y).result


lowering_rules[lax.and_p] = _and_lowering_rule
skip_mlir_conversions.add(lax.and_p)


def _or_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return arith.OrIOp(x, y).result


lowering_rules[lax.or_p] = _or_lowering_rule
skip_mlir_conversions.add(lax.or_p)


def _not_lowering_rule(ctx: LoweringRuleContext, x):
  # The primitive not_p is lowered to
  # https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not
  # which is arithmetic for integers and logical for booleans.
  # Lowering to:
  # xor x, -1
  # covers both cases.
  out_aval = ctx.avals_out[0]
  out_scalar_type = mlir.dtype_to_ir_type(out_aval.dtype)
  if not out_aval.shape:
    # Create a scalar constant.
    minus_one = ir_constant(-1, out_scalar_type)
  else:
    # Create a vector constant.
    out_type = aval_to_ir_type(out_aval)
    scalar_minus_one = ir.IntegerAttr.get(out_scalar_type, -1)
    minus_one = arith.ConstantOp(
        out_type, ir.DenseElementsAttr.get_splat(out_type, scalar_minus_one)
    )
  return arith.XOrIOp(x, minus_one).result


lowering_rules[lax.not_p] = _not_lowering_rule

def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args):
  if len(args) > 1:
    raise NotImplementedError("select_n only supported with <= 2 arguments")
  pred_aval, x_aval = ctx.avals_in[:2]
  if pred_aval.dtype != np.dtype(np.bool_):
    lower_ctx = LoweringRuleContext(
        ctx.lowering_context,
        avals_in=[pred_aval],
        avals_out=[pred_aval.update(dtype=np.bool_)],
        block_shapes=[None],
    )
    pred = lower_fun(lambda x: x != 0, multiple_results=False)(lower_ctx, pred)
  if not args:
    return x
  # Assume x and y, which we check above.
  y, = args
  return arith.SelectOp(pred, y, x).result


lowering_rules[lax.select_n_p] = _select_n_lowering_rule


def _clamp(min, operand, max):
  res = jnp.maximum(operand, min)
  return jnp.minimum(res, max)


def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max):
  """Compute minimum_p(maximum_p(min, operand), max)."""
  return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max)


lowering_rules[lax.clamp_p] = _clamp_lowering_rule

def _for_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    jaxpr,
    nsteps,
    reverse,
    unroll,
    which_linear,
):
  should_discharge = [
      not isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
  ]
  jaxpr, () = state_discharge.discharge_state(
      jaxpr, (), should_discharge=[False, *should_discharge]
  )
  for i in range(nsteps):
    if reverse:
      i = nsteps - i - 1
    i = ir_constant(i)
    lowering_context = ctx.lowering_context.replace(
        block_shapes=[(), *ctx.block_shapes],
    )
    non_ref_args = jaxpr_subcomp(lowering_context, jaxpr, i, *args)
    non_ref_args_iter = iter(non_ref_args)
    args = [
        next(non_ref_args_iter) if s else a
        for a, s in zip(args, should_discharge)
    ]
  return args


lowering_rules[for_loop.for_p] = _for_lowering_rule

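# The helper below emits either a fully unrolled sequence of body calls (when
# start and num_steps are static and num_steps == unroll) or a single scf.for
# from start to start + num_steps with step 1; other unroll factors are
# rejected.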
def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
                             jaxpr: jax_core.Jaxpr, start: int | ir.Value,
                             num_steps: int | ir.Value, consts, *args,
                             has_loop_index: bool,
                             unroll: int):
  def _run_body(i, args):
    if has_loop_index:
      lowering_context = ctx.lowering_context.replace(
          block_shapes=ctx.block_shapes)
      args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
    else:
      del i
      lowering_context = ctx.lowering_context.replace(
          block_shapes=ctx.block_shapes[:len(consts)]
          + ctx.block_shapes[len(consts) + 1:],
      )
      args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
    return args

  if (
      not isinstance(start, ir.Value)
      and not isinstance(num_steps, ir.Value)
      and num_steps == unroll
  ):
    # No need for an scf.For. We can just unroll completely
    for i in range(start, start + num_steps):
      args = _run_body(
          ir_constant(i, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))),
          args,
      )
    return args
  if unroll != 1:
    raise NotImplementedError(
        f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
  i32 = jax_core.ShapedArray((), jnp.int32)
  lbd = _ensure_mlir_value(start, i32)
  ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, i32))
  step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
  for_op = scf.ForOp(lbd, ubd, step, args)
  with ir.InsertionPoint(for_op.body):
    iv = for_op.induction_variable
    inner_args = for_op.inner_iter_args
    inner_out = _run_body(iv, inner_args)
    scf.YieldOp(inner_out)
  return for_op.results

def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext,
                                      jaxpr: jax_core.Jaxpr, start: int,
                                      num_steps: int, consts, *args,
                                      has_loop_index: bool):
  for i in range(start, start + num_steps):
    if has_loop_index:
      lowering_context = ctx.lowering_context.replace(
          block_shapes=ctx.block_shapes)
      args = jaxpr_subcomp(
          lowering_context, jaxpr, *consts,
          ir_constant(i, mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))),
          *args)
    else:
      lowering_context = ctx.lowering_context.replace(
          block_shapes=ctx.block_shapes[:len(consts)]
          + ctx.block_shapes[len(consts) + 1:],
      )
      args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
  return args

def _scan_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    jaxpr: jax_core.Jaxpr,
    linear: tuple[bool, ...],
    length: int,
    reverse: bool,
    unroll: bool | int,
    num_consts: int,
    num_carry: int,
    _split_transpose: bool,
):
  del _split_transpose
  # Can only handle fori_loop-like scans
  num_extensive = len(args) - num_consts - num_carry
  if num_extensive: raise NotImplementedError
  if reverse: raise NotImplementedError
  del linear, num_extensive, reverse

  jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts
  if jaxpr_consts: raise NotImplementedError
  del jaxpr_consts

  jaxpr, has_loop_index = pallas_utils.pattern_match_scan_to_fori_loop(
      jaxpr, num_consts, num_carry
  )
  consts, args = split_list(args, [num_consts])
  consts_avals, args_avals = split_list(ctx.avals_in, [num_consts])
  if has_loop_index:
    loop_index_start, *args = args
    args_avals = args_avals[1:]
  else:
    loop_index_start = 0
  consts = map(_ensure_mlir_value, consts, consts_avals)
  args = map(_ensure_mlir_value, args, args_avals)
  out = _lower_jaxpr_to_for_loop(
      ctx, jaxpr, loop_index_start, length,
      consts, *args, has_loop_index=has_loop_index,
      unroll=unroll)
  if has_loop_index:
    out = [ir_constant(length,
                       mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))),
           *out]
  return out


lowering_rules[lax.scan_p] = _scan_lowering_rule
skip_mlir_conversions.add(lax.scan_p)

def _lower_while_via_fori(
    ctx: LoweringRuleContext,
    *args,
    fori_jaxpr,
    cond_nconsts,
    cond_jaxpr,
    body_nconsts,
    body_jaxpr,
):
  _, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
  (lb, ub), args = carry[:2], carry[2:]
  for_out = _lower_jaxpr_to_for_loop(
      ctx.replace(
          block_shapes=ctx.block_shapes[: body_nconsts + 1]
          + ctx.block_shapes[body_nconsts + 2 :],
      ),
      fori_jaxpr,
      lb,
      arith.subi(ub, lb),
      body_consts,
      *args,
      has_loop_index=True,
      unroll=1,
  )
  return [ub, ub, *for_out]

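# while_p lowering first pattern-matches the cond/body pair to a fori loop
# (lowered via _lower_jaxpr_to_for_loop above); only when that fails does it
# emit a general scf.while with explicit "before" (condition) and "after"
# (body) blocks.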
def _while_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    cond_nconsts,
    cond_jaxpr,
    body_nconsts,
    body_jaxpr,
):
  # First try to lower via a simpler fori loop, which may optimize better.
  fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
      cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
  )
  if fori_jaxpr is not None:
    return _lower_while_via_fori(
        ctx,
        *args,
        fori_jaxpr=fori_jaxpr,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr,
    )

  # If we fail conversion to fori, fallback to an ordinary while loop.
  cond_consts, body_consts, carry = split_list(
      args, [cond_nconsts, body_nconsts]
  )
  cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = (
      split_list(ctx.block_shapes, [cond_nconsts, body_nconsts])
  )
  cond_const_types = [a.type for a in cond_consts]
  body_const_types = [a.type for a in body_consts]
  carry_types = [a.type for a in carry]
  all_types = [*cond_const_types, *body_const_types, *carry_types]
  while_op = scf.WhileOp(all_types, args)

  before_block = while_op.before.blocks.append(*all_types)
  cond_consts_, _, carry_ = split_list(
      before_block.arguments,
      [cond_nconsts, body_nconsts],
  )
  cond_args = [*cond_consts_, *carry_]
  with ir.InsertionPoint.at_block_begin(before_block):
    [cond] = jaxpr_subcomp(
        ctx.lowering_context.replace(
            block_shapes=[*cond_const_block_shapes, *carry_block_shapes]
        ),
        cond_jaxpr.jaxpr,
        *cond_args,
    )
    scf.condition(cond, before_block.arguments)

  after_block = while_op.after.blocks.append(*all_types)
  cond_consts_, body_consts_, carry_ = split_list(
      after_block.arguments,
      [cond_nconsts, body_nconsts],
  )
  all_args = [*cond_consts_, *body_consts_, *carry_]
  cond_const_args, body_const_args, carry_args = split_list(
      all_args, [cond_nconsts, body_nconsts]
  )
  with ir.InsertionPoint.at_block_begin(after_block):
    loop_out = jaxpr_subcomp(
        ctx.lowering_context.replace(
            block_shapes=[*body_const_block_shapes, *carry_block_shapes],
        ),
        body_jaxpr.jaxpr,
        *body_const_args,
        *carry_args,
    )
    all_handles = [*cond_const_args, *body_const_args, *loop_out]
    if all_handles:
      scf.yield_(all_handles)

  all_out = list(while_op.results_)
  return all_out[cond_nconsts + body_nconsts :]


lowering_rules[lax.while_p] = _while_lowering_rule

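# cond_p with more than two branches is lowered as a cascade of scf.if ops:
# the rule below recurses with the index decremented by 1, so e.g. a
# three-branch lax.switch becomes two nested scf.if regions.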
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):
  index, *args = args
  out_types = map(aval_to_ir_type, ctx.avals_out)
  pred = arith.CmpIOp(
      arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
  ).result
  if_op = scf.IfOp(pred, out_types, hasElse=True)
  lowering_context = ctx.lowering_context.replace(
      block_shapes=ctx.block_shapes[1:],
  )
  with ir.InsertionPoint(if_op.then_block):
    # TODO(b/300272065): Use `scf.IndexSwitchOp` instead of a cascade of
    # if/else.
    if len(branches) > 2:
      out = _cond_lowering_rule(
          ctx,
          arith.SubIOp(index, ir_constant(1, index.type)).result,
          *args,
          branches=branches[1:],
          linear=linear,
      )
    else:
      out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
    scf.YieldOp(out)
  with ir.InsertionPoint(if_op.else_block):
    out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args)
    scf.YieldOp(out)
  return if_op.results


lowering_rules[lax.cond_p] = _cond_lowering_rule


def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
  lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
  return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args)


lowering_rules[pjit.pjit_p] = _pjit_lowering_rule

def _custom_jvp_call_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    call_jaxpr: jax_core.Jaxpr,
    jvp_jaxpr_thunk: Callable,
    num_consts: int,
    symbolic_zeros: bool,
):
  del jvp_jaxpr_thunk
  if symbolic_zeros: raise NotImplementedError
  if num_consts: raise NotImplementedError
  if call_jaxpr.consts: raise NotImplementedError
  lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
  return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args)


lowering_rules[custom_derivatives.custom_jvp_call_p] = (
    _custom_jvp_call_lowering_rule)


def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
  del ctx, args, kwargs
  # No-op debug callbacks in Mosaic for now
  return []


lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule

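# program_id_p lowers to the grid index value recorded for the requested user
# axis, while num_programs_p (further below) lowers to tpu.iteration_bound for
# the corresponding non-mapped grid dimension.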
def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
  if ctx.lowering_context.user_grid_indices is None:
    raise ValueError(
        f"program id: {axis} was passed, but user did not provide a grid."
    )
  length = len(ctx.lowering_context.user_grid_indices)
  if not (0 <= axis < length):
    raise ValueError(
        f"user passed in program id with axis: {axis}, but grid only has"
        f" length: {length}"
    )
  return ctx.lowering_context.user_grid_indices[axis]


lowering_rules[primitives.program_id_p] = _program_id_lowering_rule


def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
  mapped_axes = set(ctx.lowering_context.mapped_dims)
  seen_user_axes = 0
  for i in range(ctx.lowering_context.grid_rank):
    seen_user_axes += int(i not in mapped_axes)
    if seen_user_axes == axis + 1:
      break
  else:
    raise ValueError(
        f"user passed in program id with axis: {axis}, but grid only has"
        f" length: {ctx.lowering_context.grid_rank}"
    )
  return tpu.iteration_bound(i)


lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule

def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis):
  (out_aval,) = ctx.avals_out
  return tpu.RepeatOp(aval_to_ir_type(out_aval), x, axis, repeats).result


lowering_rules[tpu_primitives.repeat_p] = _repeat_lowering_rule


def _roll_lowering_rule(
    ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis
):
  (out_aval,) = ctx.avals_out
  return tpu.DynamicRotateOp(
      aval_to_ir_type(out_aval),
      x,
      shift,
      axis,
      stride=stride,
      stride_dimension=stride_axis,
  ).result


lowering_rules[tpu_primitives.roll_p] = _roll_lowering_rule


def _slice_lowering_rule(
    ctx: LoweringRuleContext, x, limit_indices, start_indices, strides
):
  """Lowers a slice to vector dialect."""
  (aval_out,) = ctx.avals_out
  if strides is None:
    strides = [1] * len(start_indices)
  sizes = np.array(limit_indices) - np.array(start_indices)
  op = vector.ExtractStridedSliceOp(
      aval_to_ir_type(aval_out), x, start_indices, sizes, strides
  )
  return op.result


lowering_rules[lax.slice_p] = _slice_lowering_rule

def _xor_lowering_rule(ctx: LoweringRuleContext, x, y):
|
2024-03-25 08:59:59 -07:00
|
|
|
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
|
2023-08-01 16:42:26 -07:00
|
|
|
return arith.XOrIOp(x, y).result
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.xor_p] = _xor_lowering_rule
|
2024-03-25 08:59:59 -07:00
|
|
|
skip_mlir_conversions.add(lax.xor_p)


def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d):
  x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
  return arith.ShLIOp(x, d).result


lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule
skip_mlir_conversions.add(lax.shift_left_p)


def _shift_right_logical_lowering_rule(ctx: LoweringRuleContext, x, d):
  x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
  return arith.ShRUIOp(x, d).result


lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rule
skip_mlir_conversions.add(lax.shift_right_logical_p)


def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
  del ty
  (out_aval,) = ctx.avals_out
  return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result


lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule


def _bitcast_convert_type_lowering_rule(
    ctx: LoweringRuleContext, x, *, new_dtype):
  (in_aval,) = ctx.avals_in
  (out_aval,) = ctx.avals_out
  if in_aval.dtype.itemsize != new_dtype.itemsize:
    raise NotImplementedError(
        "bitcast_convert_type between dtypes of different widths is not"
        " supported.")
  return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result


lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule
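
# Illustrative usage sketch: only same-width bitcasts are lowered here, e.g.
# reinterpreting float32 data as int32 inside a kernel:
#
#   bits = jax.lax.bitcast_convert_type(x, jnp.int32)   # f32 -> i32, same width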


def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
  if isinstance(aval, pl_core.AbstractMemoryRef):
    memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>")
    if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype):
      assert aval.memory_space == TPUMemorySpace.SEMAPHORE
      memref_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE)
      return tpu.AllocaSemaphoreOp(memref_type).result
    else:
      out_type = ir.MemRefType.get(
          aval.shape, _dtype_to_ir_type(aval.dtype), memory_space=memspace)
      return memref.AllocaOp(out_type, [], []).result
  elif isinstance(aval, tpu_core.AbstractSemaphore):
    memref_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE)
    return tpu.AllocaSemaphoreOp(memref_type).result
  raise NotImplementedError(f"Cannot allocate {type(aval)}.")


def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):
  region = tpu.RegionOp()
  in_avals = [v.aval for v in jaxpr.invars]
  jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  with ir.InsertionPoint(region.body):
    # Allocate a value (memref or semaphore) for each scoped invar, then lower
    # the body jaxpr inside the region with the extended block shapes.
    args = map(_alloc_value, in_avals)
    block_shapes = tuple(a.shape if isinstance(a, state.AbstractRef) else None
                         for a in in_avals)
    ctx = ctx.lowering_context.replace(
        block_shapes=(*ctx.block_shapes, *block_shapes)
    )
    jaxpr_subcomp(ctx, jaxpr, *consts, *args)
    tpu.YieldOp([])
  return []


lowering_rules[tpu_primitives.run_scoped_p] = _run_scoped_lowering_rule
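
# Kernel-level usage sketch (assumes the `pltpu.run_scoped`, `pltpu.VMEM` and
# `pltpu.SemaphoreType` aliases from `jax.experimental.pallas.tpu`):
#
#   def body(scratch_ref, sem):
#     scratch_ref[...] = jnp.zeros_like(scratch_ref)
#     ...
#   pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32),
#                    pltpu.SemaphoreType.DMA)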


def _device_id_to_logical(
    ctx: LoweringRuleContext, device_id,
    device_id_type: tpu_primitives.DeviceIdType):
  if device_id_type is tpu_primitives.DeviceIdType.MESH:
    # Mesh means we are passed the mesh coordinates for the device.
    device_ids = tree_util.tree_leaves(device_id)
    mesh_strides = ctx.lowering_context.mesh_context.mesh_strides
    def _linearize_mesh_indices(*indices):
      return sum(a * b for a, b in zip(indices, mesh_strides))
    lower_ctx = LoweringRuleContext(
        lowering_context=ctx.lowering_context,
        avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids),
        avals_out=[jax_core.ShapedArray((), jnp.int32)],
        block_shapes=(None,) * len(device_ids),
    )
    return lower_fun(_linearize_mesh_indices, multiple_results=False)(
        lower_ctx, *device_ids)
  elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL:
    return device_id
  raise NotImplementedError(f"Unsupported device id type: {device_id_type}")


def _semaphore_read_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
):
  sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, indexers = tree_util.tree_unflatten(args_tree, args)
  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
  return tpu.SemaphoreReadOp(sem).result


lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule


def _semaphore_signal_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
    device_id_type: tpu_primitives.DeviceIdType,
):
  sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, indexers, value, device_id, core_index = tree_util.tree_unflatten(
      args_tree, args)
  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
  if device_id is not None:
    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
  return tpu.SemaphoreSignalOp(
      sem, value, device_id=device_id, core_id=core_index
  ).results


lowering_rules[tpu_primitives.semaphore_signal_p] = (
    _semaphore_signal_lowering_rule)


def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
  sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, indexers, value = tree_util.tree_unflatten(args_tree, args)
  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
  return tpu.SemaphoreWaitOp(sem, value).results


lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule
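
# Kernel-level usage sketch (assumes the `pltpu.semaphore_signal` /
# `pltpu.semaphore_wait` wrappers around these primitives):
#
#   pltpu.semaphore_signal(sem, 1)   # producer bumps the semaphore
#   pltpu.semaphore_wait(sem, 1)     # consumer blocks until the count arrives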


def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                             device_id_type: tpu_primitives.DeviceIdType):
  (
      src_ref,
      src_indexers,
      dst_ref,
      dst_indexers,
      sem,
      sem_indexers,
      src_sem,
      src_sem_indexers,
      device_id,
  ) = tree_util.tree_unflatten(tree, args)
  (src_ref_aval, _, dst_ref_aval, _, sem_aval, _, src_sem_aval, _, _) = (
      tree_util.tree_unflatten(tree, ctx.avals_in)
  )
  block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
  src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2]
  src_ref, _ = _index_ref(
      src_ref, src_ref_aval, src_ref_block_shape, src_indexers
  )
  if src_sem is not None:
    src_sem, _ = _index_ref(
        src_sem, src_sem_aval, src_sem_aval.shape, src_sem_indexers)
  dst_ref, _ = _index_ref(
      dst_ref, dst_ref_aval, dst_ref_block_shape, dst_indexers
  )
  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
  if device_id is not None:
    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
  return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem,
                          device_id=device_id).results


lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule


def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                            device_id_type: tpu_primitives.DeviceIdType):
  del device_id_type
  sem, sem_indexers, ref, indexers = tree_util.tree_unflatten(tree, args)
  sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
  block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
  ref_block_shape = block_shapes[2]
  ref, _ = _index_ref(
      ref, ref_aval, ref_block_shape, indexers
  )
  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
  return tpu.WaitDMAOp(sem, ref).results


lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule
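
# Kernel-level usage sketch from the dynamic-size DMA change: `pl.ds` sizes may
# be runtime values (tile-aligned), and the copy is awaited via the returned
# handle:
#
#   def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
#     size = size_smem_ref[0]
#     pltpu.async_copy(
#         x_hbm_ref.at[pl.ds(0, size)],
#         o_hbm_ref.at[pl.ds(0, size)], sem).wait()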


def _device_id_lowering_rule(ctx: LoweringRuleContext):
  return tpu.DeviceIdOp().result


lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule


def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str):
  device_id = tpu.DeviceIdOp().result
  mesh_shape = ctx.lowering_context.mesh_context.mesh_shape
  axis_names = ctx.lowering_context.mesh_context.axis_names
  axis_index = axis_names.index(axis_name)
  axis_size = ir_constant(mesh_shape[axis_index])
  minor_divisor = ir_constant(
      np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32)
  )
  return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size)


lowering_rules[lax.axis_index_p] = _axis_index_rule
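
# The index along `axis_name` is recovered from the logical device id with
# mixed-radix arithmetic:
#   index = (device_id // prod(mesh_shape[axis + 1:])) % mesh_shape[axis]
# For example, with mesh_shape = (2, 4) and device_id = 6 this gives
# axis 0 -> 6 // 4 % 2 == 1 and axis 1 -> 6 // 1 % 4 == 2.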


def _get_barrier_semaphore_rule(ctx: LoweringRuleContext):
  memref_type = aval_to_ir_type(ctx.avals_out[0])
  return tpu.GetBarrierSemaphoreOp(memref_type).result


lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule


def _delay_rule(ctx: LoweringRuleContext, nanos: int):
  return tpu.DelayOp(nanos).results


lowering_rules[tpu_primitives.delay_p] = _delay_rule


def _debug_print_rule(
    ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
):
  primitives.check_debug_print_format(fmt, *args)
  if has_placeholders:
    if not all(
        isinstance(arg.type, ir.IntegerType) and arg.type.width == 32
        for arg in args
    ):
      raise TypeError(
          "All arguments must be 32-bit integers when using"
          " placeholders (`{...}`). If you need to print values of other types,"
          " remove placeholders from the format string."
      )

    # TPU expects $0, $1 etc as placeholders.
    tpu_fmt = "".join(
        f"{text}${idx}"
        for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
    )
  else:
    tpu_fmt = fmt
  tpu.log(args, tpu_fmt, formatted=has_placeholders)
  return ()


lowering_rules[primitives.debug_print_p] = _debug_print_rule
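
# Kernel-level usage sketch (assumes the `pl.debug_print` wrapper): only
# 32-bit integer scalars may be interpolated through `{}` placeholders on TPU.
#
#   pl.debug_print("step = {}", i)     # i must be a 32-bit integer scalar
#   pl.debug_print("entered kernel")   # no placeholders, prints as-is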


def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds):
  del ctx
  # In the KeyScalarBundle case we unpack the bundle and set the seed with
  # the list of scalars.
  if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle):
    return tpu.PRNGSeed32Op(seeds[0].scalars).results
  # For integer seeds, we can set the seed directly, since PRNGSeed32Op
  # natively takes a list of integers as input.
  all_integers = all(isinstance(seed.type, ir.IntegerType) for seed in seeds)
  if not all_integers:
    seed_types = [seed.type for seed in seeds]
    raise ValueError(f"All seed data must be scalar integers. Got {seed_types}")
  return tpu.PRNGSeed32Op(seeds).results


lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule


def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape):
  if len(shape) <= 1:
    # TODO(b/342054464): Support implicit dims for PRNGRandomBitsOp.
    raise NotImplementedError("random_bits only supports rank>=2 outputs.")
  out_aval = ctx.avals_out[0]
  out_type = aval_to_ir_type(out_aval)
  return tpu.PRNGRandomBitsOp(out_type).result


lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule
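
# Kernel-level usage sketch (assumes the `pltpu.prng_seed` /
# `pltpu.prng_random_bits` wrappers around these primitives):
#
#   pltpu.prng_seed(seed_ref[0])
#   bits = pltpu.prng_random_bits((8, 128))   # uint32 bits; rank >= 2 required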


def random_seed_lowering(ctx, seeds, *, impl):
  seed_lowering = lower_fun(impl.seed, multiple_results=False)
  return seed_lowering(ctx, seeds)


lowering_rules[prng.random_seed_p] = random_seed_lowering


def random_bits_lowering(ctx, keys, *, bit_width, shape):
  assert bit_width == 32, "Only 32-bit PRNG supported."
  aval, = ctx.avals_in
  impl = aval.dtype._impl
  bits_lowering = lower_fun(impl.random_bits, multiple_results=False)
  return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape)


lowering_rules[prng.random_bits_p] = random_bits_lowering


def random_fold_in_lowering(ctx, keys, msgs):
  keys_aval, _ = ctx.avals_in
  impl = keys_aval.dtype._impl
  fold_in_lowering = lower_fun(impl.fold_in, multiple_results=False)
  return fold_in_lowering(ctx, keys, msgs)


lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering


def random_unwrap_lowering(ctx, key):
  del ctx, key
  raise NotImplementedError("key_data not implemented.")


lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering


def random_wrap_lowering(ctx, key_data, *, impl):
  del ctx, impl
  if isinstance(key_data.type, ir.VectorType):
    # If the key data lives in vregs, we need to unpack it into sregs.
    key_data_list = []
    key_data_shape = key_data.type.shape
    if len(key_data_shape) != 2:
      raise NotImplementedError("Seed key_data must be 2D.")
    if tuple(key_data_shape) != (1, 1):
      raise NotImplementedError(
          "Seed key_data of shape != (1, 1) not supported. "
          f"Got: {key_data_shape}")
    for i in range(key_data_shape[1]):
      key_data_list.append(vector.ExtractOp(key_data, [], [0, i]))
    return KeyScalarBundle(scalars=key_data_list)
  if isinstance(key_data, KeyScalarBundle):
    return key_data
  else:
    raise NotImplementedError(f"key_data wrap {type(key_data)}")


lowering_rules[prng.random_wrap_p] = random_wrap_lowering