# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for lowering JAX to Mosaic-compatible MLIR dialects."""

from __future__ import annotations

from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import functools
import string
from typing import Any, Hashable

import jax
from jax import api_util
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import checkify
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import prng
from jax._src import source_info_util
from jax._src import state
from jax._src import traceback_util
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import math
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import error_handling
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.pallas.mosaic import random as pl_random
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
from jax._src.state.types import RefBitcaster, RefReshaper
from jax._src.state.utils import dtype_bitwidth
from jax._src.typing import Array, DTypeLike
from jax._src.util import safe_map
from jax._src.util import safe_zip
from jax._src.util import split_list
from jax._src.util import unzip2
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from jaxlib.mlir.ir import Module
import numpy as np

# TODO(sharadmv): enable type checking
# mypy: ignore-errors

NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
MemorySpace = pallas_core.MemorySpace | TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
# Booleans are stored as the following type in memrefs.
BOOL_MEMREF_TYPE = np.dtype('int32')

# The value interpreted as a dynamic dimension by MLIR.
MLIR_DYNAMIC = -9223372036854775808
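# (For reference: -9223372036854775808 is np.iinfo(np.int64).min, i.e. the
# 64-bit sentinel MLIR uses for dynamic dimensions, per the comment above.)
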
partial = functools.partial
map, unsafe_map = safe_map, map  # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip  # pylint: disable=redefined-builtin


@dataclasses.dataclass
class MeshContext:
  mesh_shape: tuple[int, ...]
  axis_names: tuple[str, ...]
  mesh_strides: tuple[int, ...]


# Note - On Export Placeholders
#
# Mosaic uses vector IR, which does not have a concept of dynamic
# dimensions. We need to come up with a way to represent dynamic dimensions in
# vector IR, and so we use placeholders, which are later replaced during
# specialization.
class LoweringDynamicShapeEnv:
  dim_expr_to_placeholder: dict[Any, ir.Value] = {}

  def to_placeholder(self, dim_expr: Any) -> ir.Value:
    if dim_expr not in self.dim_expr_to_placeholder:
      next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder)
      self.dim_expr_to_placeholder[dim_expr] = next_val
    return self.dim_expr_to_placeholder[dim_expr]

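# A minimal sketch of the placeholder scheme described above (hypothetical
# dimension expressions "m" and "n"; not executed during lowering):
#
#   env = LoweringDynamicShapeEnv()
#   env.to_placeholder("m")  # -> 2147483647, i.e. np.iinfo(np.int32).max
#   env.to_placeholder("n")  # -> 2147483646
#   env.to_placeholder("m")  # -> 2147483647 again; placeholders are cached.
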
@dataclasses.dataclass
class LoweringContext:
  ir_context: ir.Context
  grid_sizes: tuple[int, ...]  # Includes both user and vmap axes.
  grid_names: tuple[Hashable, ...] | None
  mapped_dims: tuple[int, ...]  # Indices of vmapped grid dimensions.
  user_grid_indices: Sequence[ir.Value] | None
  block_shapes: list[tuple[int | pallas_core.Mapped, ...]]
  name_stack: source_info_util.NameStack
  mesh_context: MeshContext | None
  replace = dataclasses.replace
  traceback_caches: mlir.TracebackCaches
  for_verification: bool
  forward_compatible: bool
  dynamic_shape_replacement_fn: Callable[
      [tuple[jax.DimSize, ...]], tuple[int, ...]
  ]

  @property
  def grid_rank(self):
    return len(self.grid_sizes)

  @contextlib.contextmanager
  def grid_name_context(self):
    # TODO(b/355036977): generalize this across other platforms
    if not self.grid_names:
      yield
      return
    grid_names = self.grid_names
    valid_grid_sizes = tuple(
        d for i, d in enumerate(self.grid_sizes) if i not in self.mapped_dims
    )
    grid_env = zip(grid_names, valid_grid_sizes)
    with jax_core.extend_axis_env_nd(grid_env):
      yield


@dataclasses.dataclass
class LoweringRuleContext:
  lowering_context: LoweringContext
  avals_in: Sequence[jax_core.AbstractValue]
  avals_out: Sequence[jax_core.AbstractValue]
  block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None]
  replace = dataclasses.replace

  @property
  def forward_compatible(self):
    return self.lowering_context.forward_compatible


def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None
                                      ) -> TPUMemorySpace:
  match memory_space:
    case None:
      # We pick VMEM as the default one when no memory space is
      # specified
      return TPUMemorySpace.VMEM
    case pallas_core.MemorySpace.ANY:
      # Map the general ANY memory space to TPU ANY memory space
      return TPUMemorySpace.ANY
    case pallas_core.MemorySpace.ERROR | pallas_core.MemorySpace.INDEX:
      return TPUMemorySpace.SMEM
    case TPUMemorySpace():
      # Leave the memory space unchanged
      return memory_space
    case _:
      raise ValueError(f"Invalid memory space: {memory_space}")


def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None
                                      ) -> ir.Attribute:
  tpu_memory_space = _memory_space_to_tpu_memory_space(memory_space)
  return ir.Attribute.parse(f"#tpu.memory_space<{tpu_memory_space}>")

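# Illustrative note (an assumption about TPUMemorySpace's string form, not a
# guarantee): _memory_space_to_mosaic_attribute(TPUMemorySpace.VMEM) would
# parse an attribute spelled "#tpu.memory_space<vmem>", which is what gets
# attached as the memory space of the memref types built below.
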
def _dtype_to_ir_type(dtype: jnp.dtype,
                      is_kernel_boundary: bool = False) -> ir.Type:
  if jnp.issubdtype(dtype, tpu_core.semaphore_dtype):
    if jnp.issubdtype(dtype, tpu_core.dma_semaphore):
      return ir.Type.parse("!tpu.dma_semaphore")
    elif jnp.issubdtype(dtype, tpu_core.semaphore):
      return ir.Type.parse("!tpu.semaphore")
    elif jnp.issubdtype(dtype, tpu_core.barrier_semaphore):
      return ir.Type.parse("!tpu.semaphore")
    else:
      raise NotImplementedError
  if is_kernel_boundary and jnp.issubdtype(dtype, jnp.dtype('bool')):
    dtype = BOOL_MEMREF_TYPE
  # TODO(justinfu): Remove after mosaic supports unsigned types.
  # This conversion makes mosaic interpret all unsigned types as signed types.
  type = mlir.dtype_to_ir_type(dtype)
  if isinstance(type, ir.IntegerType):
    return ir.IntegerType.get_signless(type.width)
  else:
    return type


def aval_to_ir_type(
    dynamic_shape_replacement_fn,
    aval,
    shape=None,
    memory_space: MemorySpace | None = None,
    is_kernel_boundary: bool = False,
):
  if isinstance(aval, tpu_core.AbstractSemaphore):
    if aval.sem_type is tpu_core.SemaphoreType.DMA:
      sem_type = ir.Type.parse("!tpu.dma_semaphore")
    elif aval.sem_type is tpu_core.SemaphoreType.REGULAR:
      sem_type = ir.Type.parse("!tpu.semaphore")
    elif aval.sem_type is tpu_core.SemaphoreType.BARRIER:
      sem_type = ir.Type.parse("!tpu.semaphore")
    else:
      raise ValueError(f"Cannot allocate {aval.sem_type}.")
    memspace = _memory_space_to_mosaic_attribute(TPUMemorySpace.SEMAPHORE)
    return ir.MemRefType.get((), sem_type, memory_space=memspace)
  if dtypes.issubdtype(aval.dtype, dtypes.prng_key):
    shape = aval.dtype._impl.key_shape
    if pl_random.is_pallas_impl(aval.dtype._impl):
      if memory_space is None:
        memory_space = TPUMemorySpace.SMEM
      if memory_space != TPUMemorySpace.SMEM:
        raise ValueError(
            f"PRNG keys must be stored in SMEM. Got {memory_space}"
        )
    memspace = _memory_space_to_mosaic_attribute(memory_space)
    return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)),
                             memory_space=memspace)
  if isinstance(aval, state.AbstractRef):
    if shape is None:
      shape = aval.shape
    memspace = _memory_space_to_mosaic_attribute(memory_space)
    shape = dynamic_shape_replacement_fn(shape)
    return ir.MemRefType.get(shape,
                             _dtype_to_ir_type(aval.dtype,
                                               is_kernel_boundary=True),
                             memory_space=memspace)
  if isinstance(aval, jax_core.ShapedArray):
    if shape is None:
      shape = aval.shape
    if not shape:
      return _dtype_to_ir_type(
          aval.dtype, is_kernel_boundary=is_kernel_boundary)
    shape = dynamic_shape_replacement_fn(shape)
    return ir.VectorType.get(
        shape,
        _dtype_to_ir_type(aval.dtype, is_kernel_boundary=is_kernel_boundary))
  raise NotImplementedError(aval)


def ir_constant(x, mlir_type=None):
  if not hasattr(x, "dtype"):
    if isinstance(x, int):
      x = np.array(x, np.int32)
    elif isinstance(x, float):
      x = np.array(x, np.float32)
  if not mlir_type:
    mlir_type = _dtype_to_ir_type(x.dtype)
  if isinstance(x, int) or np.issubdtype(x.dtype, np.integer):
    return arith.constant(mlir_type, ir.IntegerAttr.get(mlir_type, int(x)))
  elif isinstance(x, float) or x.dtype == np.float32:
    return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x)))
  elif x.dtype == jnp.bfloat16:
    return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x)))
  elif x.dtype == jnp.bool_:
    return arith.constant(mlir_type, ir.BoolAttr.get(bool(x)))
  raise NotImplementedError(x.dtype)


lowering_rules = {}
skip_mlir_conversions = set()


def _get_aval_physical_dtype_shape(aval):
  dtype_physical_shape = jax_core.physical_aval(aval).shape[
      len(aval.shape) :
  ]
  return dtype_physical_shape

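# Worked example for the helper above (hedged): for an aval whose dtype is a
# threefry PRNG key with a (2,)-shaped uint32 physical representation and
# whose logical shape is (8,), physical_aval(...).shape is (8, 2), so this
# returns (2,); for ordinary dtypes the physical shape adds nothing and the
# result is ().
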
def _get_arg_type(
    dynamic_shape_replacement_fn: Callable[
        [tuple[jax.DimSize, ...]], tuple[jax.DimSize, ...]
    ],
    aval,
    block_mapping: pallas_core.BlockMapping | None,
):
  memory_space = None
  if isinstance(aval, pallas_core.AbstractMemoryRef):
    memory_space = aval.memory_space
    # We assume unannotated memory refs are in VMEM
    if memory_space is None:
      memory_space = TPUMemorySpace.VMEM
  if isinstance(aval, tpu_core.AbstractSemaphore):
    return aval_to_ir_type(dynamic_shape_replacement_fn, aval), None
  # TODO(necula): clean this None block_mapping
  if block_mapping is None:
    return (
        aval_to_ir_type(
            dynamic_shape_replacement_fn, aval, memory_space=memory_space
        ),
        aval.shape,
    )
  shape = tuple(
      1 if b is pallas_core.mapped else b for b in block_mapping.block_shape
  )
  return (
      aval_to_ir_type(
          dynamic_shape_replacement_fn,
          aval,
          shape=shape,
          memory_space=memory_space,
      ),
      block_mapping.block_shape,
  )


def _canonicalize_dimension_semantic(
    dimension_semantic: str | tpu_core.GridDimensionSemantics,
) -> str:
  if isinstance(dimension_semantic, tpu_core.GridDimensionSemantics):
    return dimension_semantic.value
  return dimension_semantic


@dataclasses.dataclass(init=False)
class MosaicGridMapping:
  grid: tuple[int, ...] | None
  grid_names: tuple[Hashable, ...] | None
  jaxpr: jax_core.Jaxpr
  block_mappings: tuple[pallas_core.BlockMapping | None, ...]
  mapped_dims: tuple[int, ...]
  scalar_prefetch_types: tuple[ir.Type, ...]
  operand_types: tuple[ir.Type, ...]
  scratch_types: tuple[ir.Type, ...]
  grid_types: tuple[ir.Type, ...]
  scalar_prefetch_block_shapes: tuple[tuple[int, ...], ...]
  operand_block_shapes: tuple[tuple[int, ...], ...]
  scratch_block_shapes: tuple[tuple[int, ...], ...]
  mesh_info: MeshInfo | None
  get_grid_indices: Callable | None

  def __init__(
      self,
      jaxpr: jax_core.Jaxpr,
      grid_mapping: pallas_core.GridMapping,
      dimension_semantics: tuple[str | tpu_core.GridDimensionSemantics, ...] | None,
      mesh: mesh_lib.Mesh | None,
      dynamic_shape_replacement_fn: Callable[
          [tuple[jax.DimSize, ...]], tuple[int, ...]
      ],
  ):
    self.grid = grid_mapping.grid
    self.grid_names = grid_mapping.grid_names
    self.jaxpr = jaxpr
    self.block_mappings = grid_mapping.block_mappings
    self.mapped_dims = grid_mapping.vmapped_dims
    # TODO(mvoz): Generalize to not need this
    user_grid = tuple(
        g for i, g in enumerate(self.grid) if i not in self.mapped_dims
    )
    if dimension_semantics is None:
      dimension_semantics = ("arbitrary",) * len(user_grid)
    dimension_semantics = tuple(
        _canonicalize_dimension_semantic(s) for s in dimension_semantics
    )
    if len(user_grid) != len(dimension_semantics):
      raise ValueError(
          "Must have dimension semantics for each dimension of the grid."
      )
    assert len(self.mapped_dims) + len(dimension_semantics) == len(
        self.grid
    ), (
        f"Misconfigured grid: {self.mapped_dims=}, {dimension_semantics=},"
        f" {self.grid=}"
    )
    # dimension_semantics is user provided and won't take into account vmap
    # dimensions. Here we add in parallel dimensions for the vmaps.
    semantics_iter = iter(dimension_semantics)
    self._dimension_semantics = tuple(
        next(semantics_iter) if i not in self.mapped_dims else "parallel"
        for i in range(len(self.grid))
    )

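    # For example (hypothetical values): with grid (2, 3, 4), mapped_dims
    # (0,) introduced by vmap, and user dimension_semantics ("parallel",
    # "arbitrary"), the padded result is ("parallel", "parallel",
    # "arbitrary"): the vmapped leading axis is always treated as parallel.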
    in_avals = [invar.aval for invar in self.jaxpr.invars]
    # jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
    scalar_prefetch_avals = in_avals[grid_mapping.slice_index_ops]
    operand_avals = in_avals[grid_mapping.slice_block_ops]
    scratch_avals = in_avals[grid_mapping.slice_scratch_ops]
    self.scalar_prefetch_types, _ = unzip2([
        _get_arg_type(dynamic_shape_replacement_fn, aval, None)
        for aval in scalar_prefetch_avals
    ])
    self.scalar_prefetch_block_shapes = tuple(
        aval.shape for aval in scalar_prefetch_avals)
    self.operand_types, self.operand_block_shapes = unzip2([
        _get_arg_type(dynamic_shape_replacement_fn, aval, block_mapping)
        for aval, block_mapping in zip(operand_avals, self.block_mappings)
    ])
    self.scratch_types, _ = unzip2([
        _get_arg_type(dynamic_shape_replacement_fn, aval, None)
        for aval in scratch_avals
    ])
    self.scratch_block_shapes = tuple(
        aval.shape if not isinstance(aval, tpu_core.AbstractSemaphore) else None
        for aval in scratch_avals
    )
    self.grid_types, _ = unzip2([
        _get_arg_type(
            dynamic_shape_replacement_fn,
            pallas_core.index_map_grid_aval,
            None,
        )
        for _ in range(len(self.grid))
    ])
    self._prepare_mesh_info(mesh)

    if grid_mapping.get_grid_indices is None:

      # Avoid using self.mapped_dims within the function, since doing so will
      # introduce a self->_get_grid_indices->self reference cycle that means
      # MosaicGridMapping instances can only ever be deleted by GC, rather than
      # by their reference counts going to 0.
      mapped_dims = self.mapped_dims

      def _get_grid_indices(indices, maybe_include_mapped_dims: bool):
        if maybe_include_mapped_dims:
          return indices
        return tuple(
            idx for i, idx in enumerate(indices) if i not in mapped_dims
        )

      self.get_grid_indices = _get_grid_indices
    else:
      self.get_grid_indices = grid_mapping.get_grid_indices

  def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
    if not self.has_communication:
      self.mesh_info = None
      return
    if mesh is None:
      raise ValueError(
          "Cannot use communication in pallas_call without shard_map."
      )
    axis_names = mesh.axis_names
    if self.grid_names is not None:
      if any(a in self.grid_names for a in axis_names):
        raise ValueError(
            "Cannot shadow mesh axis names with grid names. mesh axis"
            f" names: {mesh.axis_names}, grid names: {self.grid_names}"
        )
    # We need mesh <-> logical translation tables. Since the logical IDs are
    # just linearized versions of the mesh IDs, we create those tables.
    mesh_strides = pallas_utils.strides_from_shape(tuple(
        mesh.shape[a] for a in axis_names
    ))
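    # As an illustration (hypothetical mesh): with axis_names ("x", "y") and
    # mesh shape {"x": 2, "y": 4}, mesh_strides is (4, 1), so the logical ID
    # of mesh position (xi, yi) is xi * 4 + yi.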
    mesh_shape = tuple(mesh.shape.values())
    self.mesh_info = MeshInfo(mesh_shape, axis_names, mesh_strides)

  def maybe_compress_grid(self):
    # If we have many leading parallel dimensions, we should "compress" them
    # into one so we can load balance across cores as best as we can.
    # TODO(sharadmv): implement this optimization
    pass

  @functools.cached_property
  def has_communication(self) -> bool:
    nonlocal_axis_names = set()
    def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr):
      return {
          e.name
          for e in jaxpr.effects
          if isinstance(e, jax_core.NamedAxisEffect)
          and (not self.grid_names or e.name not in self.grid_names)
      }
    nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr))
    for bm in self.block_mappings:
      if bm is not None:
        nonlocal_axis_names.update(_get_nonlocal_axis_names(bm.index_map_jaxpr))
    return bool(nonlocal_axis_names)

  def get_extra_args(self) -> tuple[Any, ...]:
    return ()

  def get_dimension_semantics(self) -> ir.ArrayAttr:

    def _get_semantics(s: str | None) -> str:
      if s is None:
        return "#tpu.dimension_semantics<arbitrary>"
      return f"#tpu.dimension_semantics<{s}>"

    return ir.ArrayAttr.get(
        map(
            ir.Attribute.parse,
            map(_get_semantics, self._dimension_semantics),
        )
    )


@dataclasses.dataclass
class MeshInfo:
  mesh_shape: tuple[int, ...]
  axis_names: list[str]
  mesh_strides: tuple[int, ...]


def _check_block_mappings(
    block_mappings: tuple[pallas_core.BlockMapping, ...],
    lowering_context: mlir.LoweringRuleContext,
    debug_info: jax_core.DebugInfo,
) -> None:
  del lowering_context  # originally needed for forward compat
  for bm in block_mappings:
    rank = len(bm.block_shape)
    # TODO(necula): add tests for SMEM blocks with trivial windowing
    # We support scalars too
    if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and
        bm.has_trivial_window()):
      continue
    if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE:
      continue

    def err_details():
      return (f"Block spec for {bm.origin} in pallas_call {debug_info.func_src_info} "
              "has block shape "
              f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, "
              # TODO(necula): add index_map source location info
              f"and index_map {bm.index_map_jaxpr.jaxpr}, in "
              f"memory space {bm.block_aval.memory_space}."
              "\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec")
    if rank < 1:
      raise ValueError(
          "The Pallas TPU lowering currently supports only blocks of "
          "rank >= 1. " + err_details())

    if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and
        not bm.has_trivial_window()):
      raise ValueError(
          "In memory space ANY, the Pallas TPU lowering currently supports "
          "only blocks having the same block shape as the array shape "
          "and a trivial index_map (returning all 0s)." + err_details())

    unmapped_bs = [
        1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
    bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1]
    if rank >= 2:
      bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
    else:
      bs1, as1 = 1, 1

    if rank >= 2:
      evenly_divisible = (
          (bs0 == as0 or bs0 % 128 == 0) and
          (bs1 == as1 or bs1 % 8 == 0)
      )
      if not evenly_divisible:
        extra_msg = ""
        if pallas_core.dynamic_shapes_export_enabled():
          extra_msg = (
              " In dynamic shape export, your kernel's symbolic args must be"
              " annotated with constraints such that the computation *after*"
              " applying any grid mapping is divisible by 8 and 128"
              " respectively. Ex: mod(floordiv(m_dim, grid_size), 8) == 0"
          )
        raise ValueError(
            "The Pallas TPU lowering currently requires that the last two "
            "dimensions of your block shape are divisible by 8 and 128 "
            "respectively, or be equal to the respective dimensions of the "
            "overall array. "
            + extra_msg
            + err_details()
        )
    else:
      assert rank == 1
      # bools get a bitwidth of 32 due to how mosaic handles them
      if bm.array_shape_dtype.dtype == jnp.bool_:
        bitwidth = 32
      else:
        bitwidth = lax_internal._bit_width(bm.array_shape_dtype.dtype)
      packing = 32 // bitwidth
      tiling_size = 128 * packing
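      # e.g. an 8-bit dtype gives bitwidth 8, packing 4, and tiling_size 512;
      # float32 gives packing 1 and tiling_size 128.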
      evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0)
      if not evenly_divisible:
        raise ValueError(
            "The Pallas TPU lowering currently requires that, for rank 1"
            " block shapes, either 1) the first (and only) dimension of the"
            " block shape is equal to the first (and only) dimension of the"
            " array shape, or 2) the first (and only) dimension of the block"
            f" shape is a multiple of the tiling size ({tiling_size} = 128 *"
            f" (32 // {lax_internal._bit_width(bm.array_shape_dtype.dtype)}))"
            " of the array shape. "
            + err_details()
        )

def lower_jaxpr_to_module(
    lowering_context: mlir.LoweringRuleContext,
    ctx: ir.Context,
    grid_mapping: pallas_core.GridMapping,
    jaxpr: jax_core.Jaxpr,
    *,
    dimension_semantics: (
        tuple[str | tpu_core.GridDimensionSemantics, ...] | None
    ),
    mesh: mesh_lib.Mesh | None = None,
    for_verification: bool = False,
    dynamic_shape_replacement_enabled: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
  # NOTE: We should bump this periodically
  if is_cloud_tpu_older_than(2025, 1, 10):
    raise RuntimeError(
        "Pallas TPU requires a libTPU version that's at most a month old"
    )
  debug_info = jaxpr.debug_info
  if dynamic_shape_replacement_enabled:
    _mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()

    def dynamic_shape_replacement_fn(
        shape: jax_core.Shape,
    ) -> tuple[int, ...]:
      return tuple(
          _mosaic_lowering_dynamic_shape_env.to_placeholder(dim_expr)
          if jax_core.is_dim(dim_expr)
          else dim_expr
          for dim_expr in shape
      )

  else:
    dynamic_shape_replacement_fn = lambda x: x

  # Verify that we have legal block mappings to catch errors early.
  _check_block_mappings(
      grid_mapping.block_mappings, lowering_context, debug_info)

  mosaic_grid_mapping = MosaicGridMapping(
      jaxpr,
      grid_mapping,
      dimension_semantics,
      mesh,
      dynamic_shape_replacement_fn,
  )
  mosaic_grid_mapping.maybe_compress_grid()
  m = ir.Module.create()
  attrs = m.operation.attributes
  module_name = mlir.sanitize_name(debug_info.func_name)
  attrs["sym_name"] = ir.StringAttr.get(module_name)
  sym_tab = ir.SymbolTable(m.operation)

  func_op = lower_jaxpr_to_func(
      ctx,
      jaxpr,
      mosaic_grid_mapping=mosaic_grid_mapping,
      name="main",
      for_verification=for_verification,
      forward_compatible=lowering_context.is_forward_compat(),
      dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
  )
  m.body.append(func_op)
  sym_tab.insert(func_op)
  window_params = []
  grid = mosaic_grid_mapping.grid
  if grid:
    for i, bm in enumerate(grid_mapping.block_mappings):
      func_name = f"transform_{i}"
      # ANY and SEMAPHORE operands don't support windowing and require empty
      # window_params.
      tpu_memory_space = _memory_space_to_tpu_memory_space(
          bm.block_aval.memory_space)
      if (
          tpu_memory_space == tpu_core.TPUMemorySpace.ANY
          or tpu_memory_space == tpu_core.TPUMemorySpace.SEMAPHORE
      ):
        # We checked above that the block does not require windowing.
        window_params.append(ir.DictAttr.get())
        continue

      mlir_func = lower_jaxpr_to_transform_func(
          ctx,
          bm.index_map_jaxpr.jaxpr,
          bm.block_aval,
          name=func_name,
          mosaic_grid_mapping=mosaic_grid_mapping,
          for_verification=for_verification,
          forward_compatible=lowering_context.is_forward_compat(),
          dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
      )
      assert mlir_func.verify(), mlir_func
      block_shape = [
          1 if b is pallas_core.mapped else b for b in bm.block_shape
      ]
      # If we have an extended dtype, we need to add the block shape for the
      # remaining physical dtype.
      block_shape += list(
          _get_aval_physical_dtype_shape(bm.block_aval.inner_aval))
      block_shape = dynamic_shape_replacement_fn(block_shape)
      window_shape = ir.DenseI64ArrayAttr.get(block_shape)
      block_params = dict(
          window_bounds=window_shape,
          transform_indices=ir.FlatSymbolRefAttr.get(func_name),
      )
      if isinstance(bm.indexing_mode, pallas_core.Unblocked):
        if bm.indexing_mode.padding is None:
          pad_low = pad_high = [0] * len(bm.block_shape)
        else:
          pad_low, pad_high = map(list, zip(*bm.indexing_mode.padding))
        block_params["window_kind"] = ir.Attribute.parse(
            f"#tpu.element_window<{pad_low},{pad_high}>"
        )
      if bm.pipeline_mode is not None:
        if not isinstance(bm.pipeline_mode, pallas_core.Buffered):
          raise LoweringException(
              f"Unsupported pipeline mode: {bm.pipeline_mode}."
          )
        buffer_count = bm.pipeline_mode.buffer_count
        if buffer_count < 1 or buffer_count > 2:
          raise LoweringException(
              "Only single (1) and double (2) buffering are supported. Got"
              f" {buffer_count}."
          )
        pipeline_mode = "synchronous" if buffer_count == 1 else "double_buffered"
        block_params["pipeline_mode"] = ir.Attribute.parse(
            f"#tpu.pipeline_mode<{pipeline_mode}>"
        )
      window_params.append(ir.DictAttr.get(block_params))
      m.body.append(mlir_func)
      sym_tab.insert(mlir_func)
    func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)

    static_grid = [
        MLIR_DYNAMIC if b is pallas_core.dynamic_grid_dim else b for b in grid
    ]
    static_grid = dynamic_shape_replacement_fn(static_grid)
    func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)

  func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64),
      len(mosaic_grid_mapping.scalar_prefetch_types))
  func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scratch_types))
  func_op.attributes["dimension_semantics"] = (
      mosaic_grid_mapping.get_dimension_semantics()
  )
  return m, mosaic_grid_mapping.get_extra_args()

def lower_jaxpr_to_transform_func(
    ctx: ir.Context,
    jaxpr: jax_core.Jaxpr,
    aval: jax_core.AbstractValue,
    *,
    name: str,
    mosaic_grid_mapping: MosaicGridMapping,
    for_verification: bool,
    forward_compatible: bool,
    dynamic_shape_replacement_fn: (
        Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
    ) = None,
) -> func.FuncOp:
  num_grid = len(mosaic_grid_mapping.grid_types)
  arg_types = [
      *mosaic_grid_mapping.grid_types,
      *mosaic_grid_mapping.scalar_prefetch_types,
  ]
  def body_func(*args):
    grid_indices, scalar_prefetch = split_list(args, [num_grid])
    jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
        grid_indices, maybe_include_mapped_dims=True
    )
    arg_block_shapes = [
        *[()] * len(jaxpr_indices),
        *mosaic_grid_mapping.scalar_prefetch_block_shapes,
    ]

    mesh_info = mosaic_grid_mapping.mesh_info
    if mesh_info is not None:
      mesh_context = MeshContext(
          mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
      )
    else:
      mesh_context = None
    lowering_context = LoweringContext(
        ctx,
        mosaic_grid_mapping.grid,
        mosaic_grid_mapping.grid_names,
        mosaic_grid_mapping.mapped_dims,
        None,
        arg_block_shapes,
        source_info_util.NameStack(),
        mesh_context=mesh_context,
        traceback_caches=mlir.TracebackCaches(),
        for_verification=for_verification,
        forward_compatible=forward_compatible,
        dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
    )
    out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
                        *scalar_prefetch)
    assert isinstance(aval, state.AbstractRef), aval
    # If we have an extended dtype, we need to add 0s for the block indices
    # for the remaining physical dtype.
    out += [
        ir_constant(0, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
    ] * len(_get_aval_physical_dtype_shape(aval.inner_aval))
    return out

  body_func.__name__ = name
  body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  try:
    body.func_op.verify()
  except ir.MLIRError as e:
    raise error_handling.mlir_error_to_verification_error(e) from e
  return body.func_op

def lower_jaxpr_to_func(
    ctx: ir.Context,
    jaxpr: jax_core.Jaxpr,
    *,
    mosaic_grid_mapping: MosaicGridMapping,
    name: str,
    for_verification: bool,
    forward_compatible: bool,
    dynamic_shape_replacement_fn: (
        Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
    ) = None,
) -> func.FuncOp:
  num_grid = len(mosaic_grid_mapping.grid_types)
  num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
  arg_types = [
      *mosaic_grid_mapping.grid_types,
      *mosaic_grid_mapping.scalar_prefetch_types,
      *mosaic_grid_mapping.operand_types,
      *mosaic_grid_mapping.scratch_types,
  ]
  arg_block_shapes = [
      *mosaic_grid_mapping.scalar_prefetch_block_shapes,
      *mosaic_grid_mapping.operand_block_shapes,
      *mosaic_grid_mapping.scratch_block_shapes,
  ]
  def body_func(*args):
    grid_indices, scalar_prefetch, operands_and_scratch = split_list(
        args, [num_grid, num_scalar_prefetch])
    jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
        grid_indices, maybe_include_mapped_dims=False
    )
    mesh_info = mosaic_grid_mapping.mesh_info
    if mesh_info is not None:
      mesh_context = MeshContext(
          mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
      )
    else:
      mesh_context = None
    lowering_context = LoweringContext(
        ctx,
        mosaic_grid_mapping.grid,
        mosaic_grid_mapping.grid_names,
        mosaic_grid_mapping.mapped_dims,
        jaxpr_indices,
        arg_block_shapes,
        source_info_util.NameStack(),
        mesh_context=mesh_context,
        traceback_caches=mlir.TracebackCaches(),
        for_verification=for_verification,
        forward_compatible=forward_compatible,
        dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
    )
    return jaxpr_subcomp(
        lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch
    )
  body_func.__name__ = name
  body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  try:
    body.func_op.verify()
  except ir.MLIRError as e:
    raise error_handling.mlir_error_to_verification_error(e) from e
  return body.func_op

def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
  def f_lowered(ctx: LoweringRuleContext, *args, **params):
    f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
    wrapped_fun = lu.wrap_init(
        f, params,
        debug_info=api_util.debug_info("mosaic lower_fun", f,
                                       args, params))
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
    if consts:
      raise NotImplementedError
    jaxpr = pe.convert_constvars_jaxpr(jaxpr)
    lowering_context = ctx.lowering_context.replace(
        block_shapes=ctx.block_shapes)
    out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
    if not multiple_results:
      return out[0]
    return out

  return f_lowered

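# (Hypothetical usage, for orientation: lowering rules defined later in this
# module are often registered via
#   lowering_rules[some_primitive] = lower_fun(some_jax_impl,
#                                              multiple_results=False)
# so a primitive with a pure-JAX implementation is lowered by re-tracing it.)
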
class LoweringException(Exception):
  pass


def _compute_name_stack_updates(
    old_name_stack: list[str],
    new_name_stack: list[str]
) -> tuple[list[str], list[str]]:
  """Computes the popped/pushed items to the name stack after an update.

  Args:
    old_name_stack: The name stack prior to the update.
    new_name_stack: The name stack after the update.

  Returns:
    popped: A list of names popped from the name stack as part of the update.
    pushed: A list of names pushed to the name stack as part of the update.
  """
  common_prefix_idx = 0
  for i, (old, new) in enumerate(unsafe_zip(old_name_stack, new_name_stack)):
    if old == new:
      common_prefix_idx = i + 1
    else:
      break
  return old_name_stack[common_prefix_idx:], new_name_stack[common_prefix_idx:]

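# Example of the prefix-based update above:
#   _compute_name_stack_updates(["a", "b", "c"], ["a", "d"])
# returns (["b", "c"], ["d"]): "b" and "c" are popped, then "d" is pushed.
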
def jaxpr_subcomp(
    ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
) -> Sequence[ir.Value]:
  assert not jaxpr.constvars
  env = {}
  block_shape_env = {}

  def read_block_shape(atom: jax_core.Atom):
    if isinstance(atom, jax_core.Literal):
      return None
    return block_shape_env.get(atom, None)

  def read_env(atom: jax_core.Atom):
    return atom.val if isinstance(atom, jax_core.Literal) else env[atom]

  def write_env(var: jax_core.Var, val):
    is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle))
    assert is_valid_type, type(val)
    env[var] = val

  for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
    block_shape_env[invar] = bs
  map(write_env, jaxpr.invars, args)

  initial_name_stack = [scope.name for scope in ctx.name_stack.stack]
  current_name_stack: list[str] = []
  # TODO(justinfu): Handle transform scopes.
  current_name_stack.extend(initial_name_stack)
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    source_info = eqn.source_info.replace(
        name_stack=ctx.name_stack + eqn.source_info.name_stack
    )
    loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info)
    with (source_info_util.user_context(eqn.source_info.traceback), loc,
          eqn.ctx.manager):
      if eqn.primitive in lowering_rules:
        if eqn.primitive not in skip_mlir_conversions:
          invals = [_ensure_mlir_value(x, v.aval)
                    for x, v in zip(invals, eqn.invars)]
        block_shapes = map(read_block_shape, eqn.invars)
        rule_context = LoweringRuleContext(
            ctx,
            [v.aval for v in eqn.invars],
            [v.aval for v in eqn.outvars],
            block_shapes,
        )

        # Insert trace_start and trace_stop ops on named_scope boundaries.
        name_stack = [scope.name for scope in source_info.name_stack.stack]
        popped, pushed = _compute_name_stack_updates(
            current_name_stack, name_stack)
        current_name_stack = name_stack
        for _ in popped:
          tpu.TraceStopOp()
        for name in pushed:
          tpu.TraceStartOp(message=name, level=10)

        try:
          ans = lowering_rules[eqn.primitive](
              rule_context, *invals, **eqn.params
          )
        except LoweringException:
          raise  # We only add the extra info to the innermost exception.
        except Exception as e:
          if not pallas_call._verbose_errors_enabled():
            raise
          msg = (f"{type(e).__name__}: {e}\n" +
                 "Additional diagnostics: \n" +
                 f"Failing jaxpr equation: {eqn}\n")
          new_error = LoweringException(msg)
          # We insert the traceback here so that the user code shows
          # up in the traceback for the post-transform error.
          if source_info.traceback is not None:
            tb = source_info.traceback.as_python_traceback()
            new_error.__traceback__ = traceback_util.filter_traceback(tb)
          raise new_error from e
      else:
        raise NotImplementedError(
            "Unimplemented primitive in Pallas TPU lowering: "
            f"{eqn.primitive.name}. "
            "Please file an issue on https://github.com/jax-ml/jax/issues.")
      if eqn.primitive.multiple_results:
        map(write_env, eqn.outvars, ans)
      else:
        write_env(eqn.outvars[0], ans)

  # Drain the name stack at the end of a jaxpr and insert trace_stop ops.
  popped, pushed = _compute_name_stack_updates(
      current_name_stack, initial_name_stack)
  for _ in popped:
    tpu.TraceStopOp()
  assert len(pushed) == 0

  outvals = map(read_env, jaxpr.outvars)
  outvals = [
      ir_constant(x) if isinstance(var, jax_core.Literal) else x
      for x, var in zip(outvals, jaxpr.outvars)
  ]
  return outvals


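# Sketch of how jaxpr_subcomp dispatches equations: each primitive must have an
# entry in `lowering_rules`, registered roughly like this (hypothetical
# primitive, shown only for illustration):
#
#   def _my_prim_lowering_rule(ctx: LoweringRuleContext, x, **params):
#     ...  # emit MLIR ops, return ir.Value(s)
#   lowering_rules[my_prim_p] = _my_prim_lowering_rule
#   skip_mlir_conversions.add(my_prim_p)  # only if the rule wants raw invals
#
# Primitives without a registered rule hit the NotImplementedError above.
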
def _ensure_mlir_value(val, aval):
  if isinstance(val, ir.Value):
    return val
  if isinstance(val, KeyScalarBundle):
    return val
  elif isinstance(val, (np.generic, np.ndarray, int, float)):
    return ir_constant(val, _dtype_to_ir_type(aval.dtype))
  else:
    raise RuntimeError(
        f"Unsupported argument to a JAX primitive of type: {type(val)}"
    )


def _get_lowering_rule(
    ctx: LoweringRuleContext, ref, *idx, tree,
):
  indexers = tree_util.tree_unflatten(tree, idx)
  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
  # Call _load_lowering_rule (since it's more general)
  ref_aval, *_ = ctx.avals_in
  args_flat, args_tree = tree_util.tree_flatten((ref, indexers, None, None))
  avals_flat = tree_util.tree_leaves((ref_aval, indexers_avals, None, None))
  ctx = ctx.replace(
      avals_in=avals_flat,
      block_shapes=[ctx.block_shapes[0], *[None] * (len(avals_flat) - 1)],
  )
  return _load_lowering_rule(ctx, *args_flat, args_tree=args_tree)


lowering_rules[state_primitives.get_p] = _get_lowering_rule
skip_mlir_conversions.add(state_primitives.get_p)


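# In a Pallas kernel, plain indexed reads such as `x_ref[...]` trace to
# state_primitives.get_p and land in the rule above, which repacks the operands
# as (ref, indexers, mask=None, other=None) and defers to _load_lowering_rule.
# Minimal kernel exercising this path (illustrative sketch only):
#
#   def kernel(x_ref, o_ref):
#     o_ref[...] = x_ref[...]   # get_p for the read, swap_p for the write
#
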
def _swap_lowering_rule(
    ctx: LoweringRuleContext,
    ref,
    val,
    *idx,
    tree
):
  indexers = tree_util.tree_unflatten(tree, idx)
  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:])
  # Call _masked_swap_lowering_rule (since it's more general)
  ref_aval, val_aval, *_ = ctx.avals_in
  args_flat, args_tree = tree_util.tree_flatten((ref, indexers, val, None))
  avals_flat = tree_util.tree_leaves(
      (ref_aval, indexers_avals, val_aval, None)
  )
  ctx = ctx.replace(
      avals_in=avals_flat,
      block_shapes=[ctx.block_shapes[0], *[None] * (len(avals_flat) - 1)],
  )
  return _masked_swap_lowering_rule(ctx, *args_flat, args_tree=args_tree)


lowering_rules[state_primitives.swap_p] = _swap_lowering_rule
skip_mlir_conversions.add(state_primitives.swap_p)


def _make_index(s):
  if isinstance(s, (int, np.ndarray)):
    return ir_constant(s, ir.IndexType.get())
  if s.type == ir.IndexType.get():
    return s
  return arith.index_cast(ir.IndexType.get(), s)


def _maybe_cast_to_index(cast_to_index, x):
  if cast_to_index:
    return _make_index(x)
  return _ensure_mlir_value(x, aval=pallas_core.index_map_grid_aval)


def _index_to_start_size_stride(
    idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
) -> tuple[ir.Value, int | ir.Value, int, bool]:
  assert not isinstance(idx, slice)
  if isinstance(idx, indexing.Slice):
    start = _maybe_cast_to_index(cast_to_index, idx.start)
    size = idx.size
    stride = idx.stride
    squeeze = False
  elif isinstance(idx, int):
    start = _maybe_cast_to_index(cast_to_index, idx)
    size = 1
    stride = 1
    squeeze = True
  else:
    if np.shape(idx):
      raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
    start = _maybe_cast_to_index(cast_to_index, idx)
    size = 1
    stride = 1
    squeeze = True
  return start, size, stride, squeeze


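# Value mapping performed by _index_to_start_size_stride (comment-only sketch,
# assuming `pl.ds(start, size)` constructs an indexing.Slice with stride 1):
#
#   indexing.Slice(0, 8)            -> (start=0, size=8, stride=1, squeeze=False)
#   3 (a Python int)                -> (start=3, size=1, stride=1, squeeze=True)
#   i (a dynamic ()-shaped value)   -> (start=i, size=1, stride=1, squeeze=True)
#
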
def _indexer_to_start_size_stride(
    indexer: NDIndexer,
    ref_block_shape: tuple[int | pallas_core.Mapped, ...],
    *,
    cast_to_index: bool,
) -> tuple[
    tuple[ir.Value, ...],
    tuple[int | ir.Value, ...],
    tuple[int, ...],
    tuple[bool, ...],
    tuple[int | pallas_core.Mapped, ...],
]:
  indices_iter = iter(indexer.indices)
  starts, sizes, strides, squeeze_dims = [], [], [], []
  for s in ref_block_shape:
    start, size, stride, squeeze_dim = (
        (
            _maybe_cast_to_index(cast_to_index, 0),
            1,
            1,
            True,
        )
        if s is pallas_core.mapped
        else _index_to_start_size_stride(next(indices_iter), cast_to_index)
    )
    starts.append(start)
    sizes.append(size)
    strides.append(stride)
    squeeze_dims.append(squeeze_dim)
  next_index = next(indices_iter, None)
  assert next_index is None, (indexer.indices, ref_block_shape)
  new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
                              if not squeeze)
  return (
      tuple(starts),
      tuple(sizes),
      tuple(strides),
      tuple(squeeze_dims),
      new_ref_block_shape,
  )


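# Example of _indexer_to_start_size_stride over a block shape with a squeezed
# ("mapped") leading dimension (comment-only sketch):
#
#   ref_block_shape = (pallas_core.mapped, 8, 128)
#   indexer.indices = (indexing.Slice(0, 8), indexing.Slice(0, 128))
#   # -> starts=(0, 0, 0), sizes=(1, 8, 128), strides=(1, 1, 1),
#   #    squeeze_dims=(True, False, False), new_ref_block_shape=(8, 128)
#
# Mapped dimensions consume no index and are always squeezed.
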
def _slice_memref(
    ref: ir.Value,
    indexer: NDIndexer,
    ref_dtype: DTypeLike,
    ref_block_shape: tuple[int | pallas_core.Mapped, ...],
) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]:
  assert ref_block_shape is not None
  target_shape = indexer.get_indexer_shape()
  starts, sizes, strides, squeeze_dims, ref_block_shape = (
      _indexer_to_start_size_stride(
          indexer,
          ref_block_shape,
          cast_to_index=False,
      )
  )
  if not all((s is None or s == 1) for s in strides):
    raise NotImplementedError("Strided slices of references are unsupported.")
  dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value))
  ir_dynamic_size = ir.ShapedType.get_dynamic_size()
  static_sizes = tuple(s if not isinstance(s, ir.Value)
                       else ir_dynamic_size for s in sizes)
  target_ref_ty = ir.MemRefType.get(
      static_sizes,
      _dtype_to_ir_type(ref_dtype),
      memory_space=ref.type.memory_space,
  )
  out = tpu.memref_slice(target_ref_ty, ref, starts, dynamic_sizes)
  if any(squeeze_dims):
    # We need to squeeze out some dimensions
    static_sizes = tuple(s if not isinstance(s, ir.Value)
                         else ir_dynamic_size for s in target_shape)
    squeezed_ref_ty = ir.MemRefType.get(
        static_sizes,
        _dtype_to_ir_type(ref_dtype),
        memory_space=ref.type.memory_space,
    )
    out = tpu.memref_squeeze(squeezed_ref_ty, out)
  return out, ref_block_shape


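# Note on dynamic sizes in _slice_memref: a slice like `pl.ds(0, size)` where
# `size` is a value computed in the kernel arrives here as an ir.Value entry in
# `sizes`. It becomes a dynamic dimension in the target MemRefType and is
# passed to tpu.memref_slice via `dynamic_sizes`; static sizes stay static.
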
def _bitcast_memref(
    ref: ir.Value,
    bitcaster: RefBitcaster,
    ref_dtype: DTypeLike,
    ref_block_shape: tuple[int | pallas_core.Mapped, ...],
) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]:
  src_bitwidth = dtype_bitwidth(ref_dtype)
  dst_bitwidth = dtype_bitwidth(bitcaster.dtype)
  if src_bitwidth != dst_bitwidth:
    if len(ref_block_shape) < 2:
      raise NotImplementedError(
          "Bitcast 1D ref with bitwidth change is not supported."
      )
    if ref_block_shape[-2] is pallas_core.mapped:
      raise NotImplementedError(
          "Bitcast a ref whose 2nd minormost dimension is squeezed when"
          " bitwidth changes."
      )
  new_ref_dtype = bitcaster.dtype
  target_ref_ty = ir.MemRefType.get(
      bitcaster.shape,
      _dtype_to_ir_type(new_ref_dtype),
      memory_space=ref.type.memory_space,
  )
  new_ref_block_shape = list(ref_block_shape)
  if (
      len(new_ref_block_shape) >= 2
      and new_ref_block_shape[-2] is not pallas_core.mapped
  ):
    new_ref_block_shape[-2] = (
        new_ref_block_shape[-2] * src_bitwidth // dst_bitwidth
    )
  return (
      tpu.memref_bitcast(target_ref_ty, ref),
      new_ref_dtype,
      tuple(new_ref_block_shape),
  )


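# Block-shape arithmetic in _bitcast_memref when the bitwidth changes
# (illustrative): bitcasting an int32 ref with block shape (8, 128) to int16
# scales the second-minor dimension by 32 // 16, yielding block shape (16, 128).
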
def _reshape_memref(
    ref: ir.Value,
    reshaper: RefReshaper,
    ref_dtype: DTypeLike,
    ref_block_shape: tuple[int | pallas_core.Mapped, ...],
) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]:
  if ref_dtype != reshaper.dtype:
    raise ValueError(
        f"Reshape a ref with dtype change: {reshaper.dtype} vs {ref_dtype}"
    )
  if len(ref_block_shape) < 2:
    raise NotImplementedError("Reshape 1D ref is not supported.")
  if (
      ref_block_shape[-2] is pallas_core.mapped
      or ref_block_shape[-1] is pallas_core.mapped
  ):
    raise NotImplementedError(
        "Reshape a ref with squeezed dimension on last two dimensions."
    )
  if np.prod(ref_block_shape) != np.prod(reshaper.shape):
    raise ValueError(
        f"Reshape a ref with different number of elements: {ref_block_shape} "
        f"vs {reshaper.shape}"
    )
  target_ref_ty = ir.MemRefType.get(
      reshaper.shape,
      _dtype_to_ir_type(reshaper.dtype),
      memory_space=ref.type.memory_space,
  )
  return (
      tpu.memref_reshape(target_ref_ty, ref),
      reshaper.shape,
  )


def _transform_ref(ref, ref_dtype, ref_block_shape, transforms):
  for transform in transforms:
    match transform:
      case NDIndexer():
        ref, ref_block_shape = _slice_memref(
            ref, transform, ref_dtype, ref_block_shape
        )
      case RefBitcaster():
        ref, ref_dtype, ref_block_shape = _bitcast_memref(
            ref, transform, ref_dtype, ref_block_shape
        )
      case RefReshaper():
        ref, ref_block_shape = _reshape_memref(
            ref, transform, ref_dtype, ref_block_shape
        )
      case _:
        raise NotImplementedError(f"Unsupported transform: {transform}")
  return ref, ref_block_shape


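# _transform_ref folds a user-level transform chain such as
# `ref.bitcast(type1).at[slice1]` into a sequence of tpu.memref_bitcast /
# tpu.memref_slice / tpu.memref_reshape ops, threading the evolving ref,
# dtype, and block shape through the loop above.
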
@dataclasses.dataclass(frozen=True)
class KeyScalarBundle:
  """A container class for PRNG key data.

  We pass around keys as a KeyScalarBundle in the lowering pass rather than
  as a vector, since we want the key data to live in scalar registers rather
  than vector registers. This special dataclass exists so we can return
  multiple scalar values from load_op, because the load_op primitive does
  not allow multiple results.

  Attributes:
    key_shape: The shape of the key data.
    scalars: A list of OpResults representing scalar key data during the
      lowering pass.
  """
  key_shape: tuple[int, ...]
  scalars: list[ir.OpResult]


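# Constructed by _prng_key_load_lowering_rule below, e.g. (sketch):
#   KeyScalarBundle(key_shape=(1, 1), scalars=[<scalar memref.load results>])
# Rules that consume the key (e.g. seeding) read `scalars` directly instead of
# a vector operand.
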
def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
  ref, transforms, mask, _ = args_tree.unflatten(args_flat)
  ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in)
  (*prev_transforms, idx) = transforms
  # Select last aval, which is the one that will be used for the load.
  (*_, idx_aval) = transforms_avals

  if mask is not None:
    raise NotImplementedError

  ref_block_shape, *_ = ctx.block_shapes
  ref, ref_block_shape = _transform_ref(
      ref, ref_aval.dtype, ref_block_shape, prev_transforms
  )
  ref_type = ir.MemRefType(ref.type)
  is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
  (aval_out,) = ctx.avals_out
  if isinstance(aval_out.dtype, prng.KeyTy) and pl_random.is_pallas_impl(
      aval_out.dtype._impl
  ):
    if not is_smem_load:
      raise ValueError("PRNG keys must be loaded from SMEM. Did you set "
                       "the memory space to TPUMemorySpace.SMEM in the "
                       "BlockSpec for the PRNG key input?")
    return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree)
  if not is_smem_load and not ref_block_shape:
    raise NotImplementedError(
        "Indexing into a ()-shaped Ref not yet supported on TPU.")
  if any(
      (not isinstance(a, primitives.Slice) and a.shape)
      for a in idx_aval.indices
  ):
    raise ValueError("Cannot do int indexing on TPU")
  starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
      idx,
      ref_block_shape,
      cast_to_index=True,
  )
  need_stride = not all((s is None or s == 1) for s in strides)
  if is_smem_load:
    if ctx.avals_out[0].shape:
      raise ValueError("Can only load scalars from SMEM")
    return _maybe_cast_load_to_bool(ctx, aval_out, memref.load(ref, starts))
  elif str(ref_type.memory_space) != "#tpu.memory_space<vmem>":
    extra = ""
    if str(ref_type.memory_space) == "#tpu.memory_space<any>":
      extra = " ANY memory space can only be accessed using async_copy."
    raise ValueError(
        "Loads are only allowed on VMEM and SMEM references." + extra
    )
  load_aval = jax_core.ShapedArray(sizes, dtype=aval_out.dtype)
  if need_stride:
    load_val = tpu.strided_load(
        aval_to_ir_type(
            ctx.lowering_context.dynamic_shape_replacement_fn,
            load_aval,
            is_kernel_boundary=True,
        ),
        ref,
        starts,
        strides,
    )
  else:
    load_val = vector.load(
        aval_to_ir_type(
            ctx.lowering_context.dynamic_shape_replacement_fn,
            load_aval,
            is_kernel_boundary=True,
        ),
        ref,
        starts,
    )
  if load_aval != aval_out:
    vec_type = ir.VectorType.get(aval_out.shape,
                                 _dtype_to_ir_type(aval_out.dtype,
                                                   is_kernel_boundary=True))
    load_val = vector.shape_cast(vec_type, load_val)
  return _maybe_cast_load_to_bool(ctx, aval_out, load_val)


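# Load routing implemented above (summary sketch):
#   * SMEM refs: scalar-only loads via memref.load (+ optional bool cast).
#   * VMEM refs: vector.load, or tpu.strided_load when any stride > 1, followed
#     by a vector.shape_cast if squeezed dimensions were dropped.
#   * PRNG key avals: delegated to _prng_key_load_lowering_rule below.
#   * Any other memory space (e.g. ANY): rejected; use async_copy instead.
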
def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle:
  """Lowering rule for loading PRNG keys from SMEM.

  PRNG key loads are currently lowered as a list of scalar loads from SMEM,
  rather than a single vector load.
  We store these scalars in a bundle type called KeyScalarBundle, which has
  special case handling for functions that consume the key such as set_seed.
  """
  ref, _, _, _ = args_tree.unflatten(args_flat)
  (aval_out,) = ctx.avals_out
  assert isinstance(aval_out.dtype, prng.KeyTy)
  ref_block_shape = aval_out.dtype._impl.key_shape

  if len(ref_block_shape) != 2:
    raise NotImplementedError("Seed key_data must be 2D.")
  if tuple(ref_block_shape) != (1, 1):
    raise NotImplementedError(
        f"Seed key_data of shape != (1, 1) not supported. Got: {ref_block_shape}")

  load_ops = []
  for i in range(ref_block_shape[0]):
    idx = NDIndexer(indices=(0, i), shape=ref_block_shape,
                    int_indexer_shape=tuple())
    starts, _, _, _, _ = _indexer_to_start_size_stride(
        idx,
        ref_block_shape,
        cast_to_index=True,
    )
    load_ops.append(memref.load(ref, starts))
  return KeyScalarBundle(scalars=load_ops, key_shape=tuple(ref_block_shape))


lowering_rules[primitives.load_p] = _load_lowering_rule
skip_mlir_conversions.add(primitives.load_p)


def _maybe_cast_load_to_bool(
    ctx, out_aval, val: ir.Value
) -> ir.Value:
  """Casts a memref load value to bool if the requested value is a bool.

  Mosaic does not support boolean-type memrefs, since booleans
  typically live in mask registers. We instead load booleans as integers from
  memrefs and move them to mask registers on load using this function.

  Args:
    out_aval: The output aval of the load.
    val: The input value.

  Returns:
    The loaded value, cast to a boolean vector or scalar when needed.
  """
  if out_aval.dtype != jnp.bool_:
    return val
  load_scalar_type = _dtype_to_ir_type(BOOL_MEMREF_TYPE)
  pred = _cmpsi_lowering_types[lax.ne_p]
  predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
  const_zero = ir.IntegerAttr.get(load_scalar_type, 0)
  if out_aval.shape:  # Vector case.
    load_vector_type = aval_to_ir_type(
        ctx.lowering_context.dynamic_shape_replacement_fn,
        out_aval,
        is_kernel_boundary=True,
    )
    vector_zeros = arith.ConstantOp(
        load_vector_type,
        ir.DenseElementsAttr.get_splat(load_vector_type, const_zero)
    )
    return arith.cmpi(predicate, val, vector_zeros)
  else:  # Scalar case.
    const_zero = arith.ConstantOp(load_scalar_type, const_zero)
    return arith.cmpi(predicate, val, const_zero)


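# Bool handling pattern used above (sketch): booleans are stored in memrefs as
# integers (BOOL_MEMREF_TYPE), so a "bool load" lowers to an integer load
# followed by `arith.cmpi ne, %val, 0`, which moves the result into a mask
# register. Stores go the other way via _maybe_cast_store_to_memref_type
# (arith.extui back to the integer element type).
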
def _maybe_cast_store_to_memref_type(
    ctx: LoweringRuleContext, expected_aval, val: ir.Value
) -> ir.Value:
  """Casts a boolean value back to an integer for storing in a memref."""
  if expected_aval.dtype != jnp.bool_:
    return val
  int_out_type = aval_to_ir_type(
      ctx.lowering_context.dynamic_shape_replacement_fn,
      expected_aval,
      is_kernel_boundary=True,
  )
  return arith.extui(int_out_type, val)


def _masked_swap_lowering_rule(
    ctx: LoweringRuleContext, *args_flat, args_tree, **_
):
  ref, transforms, val, mask = args_tree.unflatten(args_flat)
  ref_aval, transforms_avals, val_aval, mask_aval = args_tree.unflatten(
      ctx.avals_in
  )
  (*prev_transforms, idx) = transforms
  (*_, idx_aval) = transforms_avals

  if mask is not None:
    if val_aval.dtype.itemsize != 4:
      raise NotImplementedError("masked swap with non-32-bit data")
    if val_aval.shape != mask_aval.shape:
      raise ValueError(
          "Expected value and mask to have the same shape, but got"
          f" value shape {val_aval.shape} vs. mask shape {mask_aval.shape}."
      )

  ref_block_shape, *_ = ctx.block_shapes
  ref, ref_block_shape = _transform_ref(
      ref, ref_aval.dtype, ref_block_shape, prev_transforms
  )

  ref_type = ir.MemRefType(ref.type)
  memory_space = str(ref_type.memory_space)
  is_smem_store = memory_space == "#tpu.memory_space<smem>"
  is_vmem_store = memory_space == "#tpu.memory_space<vmem>"
  (aval_out,) = ctx.avals_out
  if not isinstance(val, ir.Value):
    val = ir_constant(val, mlir_type=_dtype_to_ir_type(val_aval.dtype))
  if any(
      (not isinstance(a, primitives.Slice) and a.shape)
      for a in idx_aval.indices
  ):
    raise ValueError("Cannot do int indexing on TPU")
  if not is_smem_store and not ref_block_shape:
    raise NotImplementedError(
        "Indexing into a ()-shaped Ref not yet supported on TPU.")

  starts, _, strides, _, _ = _indexer_to_start_size_stride(
      idx,
      ref_block_shape,
      cast_to_index=True,
  )
  need_stride = not all((s is None or s == 1) for s in strides)

  if is_smem_store:
    if mask is not None:
      raise ValueError("SMEM store does not support masks")
    if val_aval.shape:
      raise ValueError("Can only store scalars to SMEM")
    result = memref.load(ref, starts)
    result = _maybe_cast_load_to_bool(ctx, val_aval, result)
    val = _maybe_cast_store_to_memref_type(ctx, val_aval, val)
    memref.StoreOp(val, ref, starts)
    return result

  if not is_vmem_store:
    extra = ""
    if memory_space == "#tpu.memory_space<any>":
      extra = " ANY memory space can only be accessed using async_copy."
    raise ValueError(
        "Loads and stores are only allowed on VMEM and SMEM references." + extra
    )

  # Handle the VMEM store below.
  if not val_aval.shape:
    raise ValueError("Cannot store scalars to VMEM")

  mem_slice_shape = list(aval_out.shape)
  for i, a in enumerate(idx_aval.indices):
    if not isinstance(a, primitives.Slice):
      mem_slice_shape.insert(i, 1)
  mem_slice_shape_iter = iter(mem_slice_shape)
  mem_slice_shape = [
      1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
      for b in ref_block_shape
  ]
  mem_aval = aval_out.update(
      shape=tuple(mem_slice_shape), sharding=jax_core.get_cur_mesh_sharding()
  )
  mem_aval_shape = ctx.lowering_context.dynamic_shape_replacement_fn(
      mem_aval.shape
  )
  mem_aval_vec_type = ir.VectorType.get(
      mem_aval_shape, _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True)
  )
  if need_stride:
    result = tpu.strided_load(mem_aval_vec_type, ref, starts, strides)
  else:
    result = vector.load(mem_aval_vec_type, ref, starts)
  val = _maybe_cast_store_to_memref_type(ctx, val_aval, val)
  if mem_aval != aval_out:
    # We are slicing a scalar, so we provide dummy 1 indices.
    result_vec_type = ir.VectorType.get(
        aval_out.shape,
        _dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True))
    result = vector.shape_cast(result_vec_type, result)
    val_vec_type = ir.VectorType.get(
        mem_aval.shape,
        _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True))
    val = vector.shape_cast(val_vec_type, val)
  result = _maybe_cast_load_to_bool(ctx, val_aval, result)

  if need_stride:
    if mask is not None:
      raise NotImplementedError("masked swap with strided store")
    tpu.StridedStoreOp(val, ref, starts, strides)
  else:
    tpu.VectorStoreOp(val, ref, starts, [], mask=mask)
  return result


lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
skip_mlir_conversions.add(primitives.swap_p)


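# Masked swaps (mask is not None) are only supported above for 32-bit VMEM
# vector stores: unstrided stores lower to tpu.VectorStoreOp with the mask
# operand, strided stores use tpu.StridedStoreOp and reject masks, and SMEM
# stores reject masks entirely.
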
def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values):
  del ctx
  for multiple in values:
    val = tpu.assume_multiple(val, multiple)
  return val


lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule


def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity):
  def _lowering_rule(ctx: LoweringRuleContext, x, *, axes):
    (x_aval,) = ctx.avals_in
    if not ctx.avals_out[0].shape:
      # If reducing to a scalar, we reduce by adding a leading singleton
      # dimension and reducing over all other dimensions. This avoids
      # the materialization of a scalar tensor by the reduction op which
      # is not supported.
      def _proxy_fun(val, *, axes):
        val = val[jnp.newaxis, ...]
        axes = [axis + 1 for axis in axes]
        val = reduce_fn(val, axis=axes, keepdims=True)
        # Squeeze lowers to vector.ExtractOp which will place the final
        # value in a scalar register.
        return jnp.squeeze(val)
      proxy_lowering = lower_fun(
          _proxy_fun, multiple_results=False)
      return proxy_lowering(ctx, x, axes=axes)

    if jnp.issubdtype(x_aval.dtype, jnp.floating):
      kind = type_to_kind[jnp.floating]
      val = type_to_identity[jnp.floating]
      val = ir.FloatAttr.get(
          aval_to_ir_type(
              ctx.lowering_context.dynamic_shape_replacement_fn,
              x_aval,
              shape=(),
          ),
          val,
      )
    elif x_aval.dtype == jnp.int32:
      kind = type_to_kind[jnp.signedinteger]
      val = type_to_identity[jnp.signedinteger]
      val = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), val)
    elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
      raise NotImplementedError(
          "Reductions over unsigned integers not implemented."
      )
    else:
      raise NotImplementedError(
          f"Reductions over {x_aval.dtype} not implemented.")
    out_type = aval_to_ir_type(
        ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
    )
    identity = ir.DenseElementsAttr.get_splat(out_type, val)
    acc = arith.ConstantOp(out_type, identity)
    return vector.multi_reduction(kind, x, acc, axes)
  return _lowering_rule


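# Scalar reductions (empty output shape) are handled above by a proxy: add a
# leading singleton axis, reduce with keepdims=True, then jnp.squeeze, which
# lowers to vector.ExtractOp and leaves the result in a scalar register.
# Full-rank reductions lower directly to vector.multi_reduction, seeded with
# the identity element from the *_IDENTITY tables below.
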
REDUCE_MAX_KINDS = {
    jnp.floating: vector.CombiningKind.MAXIMUMF,
    jnp.signedinteger: vector.CombiningKind.MAXSI,
    jnp.unsignedinteger: vector.CombiningKind.MAXUI,
}
REDUCE_MAX_IDENTITY = {
    jnp.floating: float("-inf"),
    jnp.signedinteger: np.iinfo(np.int32).min,
}
_reduce_max_lowering_rule = reduce_lowering_rule(
    jnp.max, REDUCE_MAX_KINDS, REDUCE_MAX_IDENTITY)
lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule


REDUCE_MIN_KINDS = {
    jnp.floating: vector.CombiningKind.MINIMUMF,
    jnp.signedinteger: vector.CombiningKind.MINSI,
    jnp.unsignedinteger: vector.CombiningKind.MINUI,
}
REDUCE_MIN_IDENTITY = {
    jnp.floating: float("inf"),
    jnp.signedinteger: np.iinfo(np.int32).max,
}
_reduce_min_lowering_rule = reduce_lowering_rule(
    jnp.min, REDUCE_MIN_KINDS, REDUCE_MIN_IDENTITY)
lowering_rules[lax.reduce_min_p] = _reduce_min_lowering_rule


REDUCE_SUM_KINDS = {
    jnp.floating: vector.CombiningKind.ADD,
    jnp.signedinteger: vector.CombiningKind.ADD,
    jnp.unsignedinteger: vector.CombiningKind.ADD,
}
REDUCE_SUM_IDENTITY = {
    jnp.floating: 0.0,
    jnp.signedinteger: 0,
}
_reduce_sum_lowering_rule = reduce_lowering_rule(
    jnp.sum, REDUCE_SUM_KINDS, REDUCE_SUM_IDENTITY)
lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule


def _reduce_and_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  def _proxy_reduce(arg, *, axes):
    # Mosaic currently only supports float reductions, so we cast the boolean
    # arg to a float and use reduce_min to implement reduce_and.
    # TODO(b/351017807): Implement this logic in Mosaic MultiDimReductionOp
    # instead.
    float_arg = jnp.where(arg, 1.0, 0.0)
    return jnp.min(float_arg, axis=axes) > 0.0
  proxy_lowering = lower_fun(
      _proxy_reduce, multiple_results=False)
  return proxy_lowering(ctx, x, axes=axes)


lowering_rules[lax.reduce_and_p] = _reduce_and_lowering_rule


def _reduce_or_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  def _proxy_reduce(arg, *, axes):
    # Mosaic currently only supports float reductions, so we cast the boolean
    # arg to a float and use reduce_max to implement reduce_or.
    # TODO(b/351017807): Implement this logic in Mosaic MultiDimReductionOp
    # instead.
    float_arg = jnp.where(arg, 1.0, 0.0)
    return jnp.max(float_arg, axis=axes) > 0.0
  proxy_lowering = lower_fun(
      _proxy_reduce, multiple_results=False)
  return proxy_lowering(ctx, x, axes=axes)


lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule


def _broadcast_to_lowering_rule(
    ctx: LoweringRuleContext, x, shape: Sequence[int]
):
  raise RuntimeError(
      "`broadcast_to` is a Triton-specific primitive. Please consider using"
      " `jnp.broadcast_to` instead."
  )


lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule


def _broadcast_in_dim_lowering_rule(
    ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding
):
  del sharding
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out

  if jnp.issubdtype(aval_in.dtype, jnp.bool_):
    # Direct broadcasts for bools are not supported in Mosaic due to booleans
    # living in mask registers and broadcast operating on vregs. Broadcast as an
    # integer instead and cast back to a bool.
    # TODO(b/351019164): Implement this logic in Mosaic BroadcastOp instead.
    def _proxy_fun(val, *, shape, broadcast_dimensions):
      int_val = jnp.where(val, 1, 0)
      bcast_val = jax.lax.broadcast_in_dim(int_val, shape, broadcast_dimensions)
      return bcast_val == 1
    proxy_lowering = lower_fun(
        _proxy_fun, multiple_results=False)
    return proxy_lowering(
        ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions)

  if broadcast_dimensions:
    out_shape_list = [1] * len(shape)
    for i, s in zip(broadcast_dimensions, aval_in.shape):
      out_shape_list[i] = s
    out_shape = tuple(out_shape_list)
    out_type = ir.VectorType.get(
        out_shape, _dtype_to_ir_type(aval_out.dtype)
    )
    val = vector.shape_cast(out_type, val)
    if out_shape == aval_out.shape:
      return val
  out_type = ir.VectorType.get(
      aval_out.shape, _dtype_to_ir_type(aval_out.dtype)
  )
  return vector.broadcast(out_type, val)


lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule


def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
|
|
|
|
"""Converts a jax dot dimension numbers to a tpu dot dimension numbers.
|
|
|
|
|
|
|
|
Jax dot dimension numbers are given as a tuple of tuples of sequences of ints
|
|
|
|
of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
|
|
|
|
rhs_batch_dims)).
|
|
|
|
|
|
|
|
TPU dot dimension numbers are given as an MLIR definition of the form
|
|
|
|
#tpu.dot_dimension_numbers - which can be found in the tpu dilect definition
|
|
|
|
# file, tpu.td .
|
|
|
|
"""
|
|
|
|
(contracting_dims, batch_dims) = dimension_numbers
|
|
|
|
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
|
|
|
|
lhs_batch_dims, rhs_batch_dims = batch_dims
|
|
|
|
|
|
|
|
lhs_total_dims = set(range(len(lhs_shape)))
|
|
|
|
rhs_total_dims = set(range(len(rhs_shape)))
|
|
|
|
|
|
|
|
lhs_non_contracting_dims = sorted(
|
|
|
|
lhs_total_dims - set(lhs_contracting_dims) - set(lhs_batch_dims)
|
|
|
|
)
|
|
|
|
rhs_non_contracting_dims = sorted(
|
|
|
|
rhs_total_dims - set(rhs_contracting_dims) - set(rhs_batch_dims)
|
|
|
|
)
|
|
|
|
|
|
|
|
# Create output_dim_order
|
|
|
|
# Note: we assume that the output dimensions are ordered as batch dims, lhs_non_contracting_dims,
|
|
|
|
# rhs_non_contracting_dims - this assumption is safe to make, as it is
|
|
|
|
# the same one made in jax's dot_general.
|
|
|
|
output_dim_order = []
|
|
|
|
|
|
|
|
lhs_dim_map = {dim: idx for idx, dim in enumerate(range(len(lhs_shape)))}
|
|
|
|
rhs_dim_map = {dim: idx for idx, dim in enumerate(range(len(rhs_shape)))}
|
|
|
|
|
|
|
|
for dim in lhs_batch_dims:
|
|
|
|
output_dim_order.append(0)
|
|
|
|
output_dim_order.append(lhs_dim_map[dim])
|
|
|
|
|
|
|
|
for dim in lhs_non_contracting_dims:
|
|
|
|
output_dim_order.append(0)
|
|
|
|
output_dim_order.append(lhs_dim_map[dim])
|
|
|
|
|
|
|
|
for dim in rhs_non_contracting_dims:
|
|
|
|
output_dim_order.append(1)
|
|
|
|
output_dim_order.append(rhs_dim_map[dim])
|
|
|
|
|
|
|
|
def format_dims(dims):
|
|
|
|
return "[" + ", ".join(str(d) for d in dims) + "]"
|
|
|
|
|
|
|
|
all_dims = (
|
|
|
|
lhs_contracting_dims,
|
|
|
|
rhs_contracting_dims,
|
|
|
|
lhs_non_contracting_dims,
|
|
|
|
rhs_non_contracting_dims,
|
|
|
|
output_dim_order,
|
|
|
|
lhs_batch_dims,
|
|
|
|
rhs_batch_dims,
|
|
|
|
)
|
|
|
|
tpu_dim_numbers_str = (
|
|
|
|
f"#tpu.dot_dimension_numbers<{','.join(map(format_dims, all_dims))}>"
|
|
|
|
)
|
|
|
|
|
|
|
|
return ir.Attribute.parse(tpu_dim_numbers_str)
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _dot_general_lowering_rule(
|
|
|
|
ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
|
|
|
|
):
|
|
|
|
(lhs_dims, rhs_dims), _ = dimension_numbers
|
|
|
|
(aval_out,) = ctx.avals_out
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, aval_out
|
|
|
|
)
|
2023-11-09 11:09:15 -08:00
|
|
|
val_type = out_type.element_type
|
2024-06-07 08:16:09 -07:00
|
|
|
if any(
|
|
|
|
cls.isinstance(val_type)
|
|
|
|
for cls in [
|
|
|
|
ir.BF16Type,
|
|
|
|
ir.F32Type,
|
|
|
|
ir.Float8E5M2Type,
|
|
|
|
ir.Float8E4M3FNType,
|
2025-02-20 10:43:46 -08:00
|
|
|
ir.Float8E4M3B11FNUZType,
|
2024-06-07 08:16:09 -07:00
|
|
|
]
|
|
|
|
):
|
2023-11-09 11:09:15 -08:00
|
|
|
val = ir.FloatAttr.get(val_type, 0.0)
|
|
|
|
elif ir.IntegerType.isinstance(val_type):
|
|
|
|
val = ir.IntegerAttr.get(val_type, 0)
|
2023-08-01 16:42:26 -07:00
|
|
|
else:
|
|
|
|
raise NotImplementedError(ctx.avals_out[0].dtype)
|
|
|
|
if any(len(a.shape) != 2 for a in ctx.avals_in):
|
2024-05-02 10:58:42 -07:00
|
|
|
raise NotImplementedError(
|
|
|
|
f"Only 2D tensors supported in dot; received: {ctx.avals_in}"
|
|
|
|
)
|
2024-11-07 15:30:36 -08:00
|
|
|
lhs_aval, rhs_aval = ctx.avals_in
|
2023-08-01 16:42:26 -07:00
|
|
|
# This is really a matrix-vector product. It only looks like matrix-matrix.
|
|
|
|
if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1:
|
|
|
|
if ctx.avals_in[0].shape != ctx.avals_in[1].shape:
|
|
|
|
bcast_shape = jnp.broadcast_shapes(
|
|
|
|
ctx.avals_in[0].shape, ctx.avals_out[0].shape
|
|
|
|
)
|
|
|
|
bcast_shape = ir.VectorType.get(
|
2024-01-11 06:32:57 -08:00
|
|
|
list(bcast_shape), _dtype_to_ir_type(ctx.avals_out[0].dtype)
|
2023-08-01 16:42:26 -07:00
|
|
|
)
|
|
|
|
if ctx.avals_in[0].shape != bcast_shape:
|
2024-12-17 02:16:27 -08:00
|
|
|
x = vector.broadcast(bcast_shape, x)
|
2023-08-01 16:42:26 -07:00
|
|
|
if ctx.avals_in[1].shape != bcast_shape:
|
2024-12-17 02:16:27 -08:00
|
|
|
y = vector.broadcast(bcast_shape, y)
|
2025-01-14 20:33:34 -08:00
|
|
|
red_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn,
|
|
|
|
lhs_aval.update(shape=(lhs_aval.shape[0],)),
|
|
|
|
)
|
2023-08-01 16:42:26 -07:00
|
|
|
acc = arith.ConstantOp(
|
|
|
|
red_type, ir.DenseElementsAttr.get_splat(red_type, val)
|
|
|
|
)
|
|
|
|
red = vector.MultiDimReductionOp(
|
|
|
|
ir.Attribute.parse("#vector.kind<add>"),
|
|
|
|
arith.MulFOp(x, y),
|
|
|
|
acc,
|
2024-08-14 08:42:00 -07:00
|
|
|
[1]
|
2023-08-01 16:42:26 -07:00
|
|
|
)
|
2024-10-18 16:13:46 -07:00
|
|
|
return vector.shape_cast(out_type, red)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
2024-11-07 15:30:36 -08:00
|
|
|
tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims(
|
|
|
|
dimension_numbers, lhs_aval.shape, rhs_aval.shape
|
|
|
|
)
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
if precision is not None:
|
|
|
|
if precision[0] != precision[1]:
|
|
|
|
raise NotImplementedError("Per-operand dot precision unsupported")
|
|
|
|
precision = precision[0]
|
|
|
|
if precision is None or precision == lax.Precision.DEFAULT:
|
2023-10-12 07:37:22 -07:00
|
|
|
precision_attr = None # That's the default in Mosaic.
|
2023-08-01 16:42:26 -07:00
|
|
|
elif precision == lax.Precision.HIGHEST:
|
2023-10-12 07:37:22 -07:00
|
|
|
precision_attr = ir.Attribute.parse(
|
2023-08-01 16:42:26 -07:00
|
|
|
"#tpu.contract_precision<fp32>"
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
raise NotImplementedError(f"Unsupported dot precision: {precision}")
|
2023-10-12 07:37:22 -07:00
|
|
|
out_tile = arith.ConstantOp(
|
|
|
|
out_type, ir.DenseElementsAttr.get_splat(out_type, val)
|
|
|
|
)
|
2024-10-18 16:13:46 -07:00
|
|
|
return tpu.matmul(
|
2024-11-07 15:30:36 -08:00
|
|
|
out_type,
|
|
|
|
x,
|
|
|
|
y,
|
|
|
|
out_tile,
|
|
|
|
dimension_numbers=tpu_dot_dims,
|
|
|
|
precision=precision_attr,
|
2023-10-12 07:37:22 -07:00
|
|
|
)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule
|
|
|
|
|
2024-01-18 19:02:02 -08:00
|
|
|
def _convert_helper(x, *, to_dtype):
|
|
|
|
# Helper function for dtype conversion
|
|
|
|
from_dtype = x.dtype
|
2025-01-09 07:59:27 -08:00
|
|
|
if from_dtype == jnp.bool_:
|
2024-01-18 19:02:02 -08:00
|
|
|
x = x.astype(jnp.int32)
|
|
|
|
return _convert_helper(x, to_dtype=to_dtype)
|
2025-01-09 07:59:27 -08:00
|
|
|
if to_dtype == jnp.bool_:
|
|
|
|
# Lower float32 or (u)int32 -> bool to cmp neq %in, 0
|
|
|
|
# TODO(apaszke,mvoz): Move the upcasts for cmpi to the Mosaic canonicalizer.
|
|
|
|
if jnp.issubdtype(from_dtype, jnp.floating):
|
|
|
|
if from_dtype.itemsize < 4:
|
|
|
|
x = x.astype(jnp.float32)
|
|
|
|
elif jnp.issubdtype(from_dtype, jnp.integer):
|
|
|
|
if from_dtype.itemsize < 4:
|
|
|
|
x = x.astype(jnp.int32)
|
|
|
|
return x != jnp.asarray(0, dtype=x.dtype)
|
2024-06-10 18:07:33 -07:00
|
|
|
if jnp.issubdtype(from_dtype, jnp.signedinteger):
|
2024-01-18 19:02:02 -08:00
|
|
|
if from_dtype.itemsize < 4:
|
|
|
|
x = x.astype(jnp.int32)
|
|
|
|
if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4:
|
|
|
|
x = x.astype(jnp.float32)
|
|
|
|
return x.astype(to_dtype)
|
2024-08-22 23:56:32 -07:00
|
|
|
if jnp.issubdtype(from_dtype, jnp.unsignedinteger):
|
|
|
|
if from_dtype.itemsize < 4:
|
|
|
|
x = x.astype(jnp.uint32)
|
2025-01-10 04:34:32 -08:00
|
|
|
# unsigned -> float is unsupported. We fall through and raise at the bottom.
|
|
|
|
if not jnp.issubdtype(to_dtype, jnp.floating):
|
|
|
|
return x.astype(to_dtype)
|
|
|
|
if jnp.issubdtype(from_dtype, jnp.floating) and jnp.issubdtype(
|
|
|
|
to_dtype, jnp.signedinteger
|
|
|
|
):
|
|
|
|
if from_dtype.itemsize < 4:
|
2024-08-22 23:56:32 -07:00
|
|
|
x = x.astype(jnp.float32)
|
2025-01-10 04:34:32 -08:00
|
|
|
if to_dtype.itemsize < 4:
|
|
|
|
# Need to clip values to match XLA
|
|
|
|
minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max
|
|
|
|
x = jnp.clip(x, minval, maxval)
|
|
|
|
return x.astype(jnp.int32).astype(to_dtype)
|
2024-01-18 19:02:02 -08:00
|
|
|
raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}")
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
def _convert_element_type_lowering_rule(
|
2024-07-09 07:32:38 -07:00
|
|
|
ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
|
2023-08-01 16:42:26 -07:00
|
|
|
):
|
2023-08-22 22:01:44 -07:00
|
|
|
del weak_type
|
2024-07-09 07:32:38 -07:00
|
|
|
del sharding
|
2023-08-22 22:01:44 -07:00
|
|
|
out_aval = ctx.avals_out[0]
|
2024-09-27 02:14:55 -07:00
|
|
|
in_aval = ctx.avals_in[0]
|
|
|
|
old_dtype = in_aval.dtype
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
|
|
)
|
2024-06-10 18:07:33 -07:00
|
|
|
|
2023-08-22 22:01:44 -07:00
|
|
|
if old_dtype == new_dtype:
|
|
|
|
return x
|
2024-10-09 12:00:14 -07:00
|
|
|
|
|
|
|
if new_dtype.itemsize == 8:
|
|
|
|
raise NotImplementedError("64-bit types are not supported")
|
|
|
|
|
2025-01-09 07:59:27 -08:00
|
|
|
_from = lambda dtype: jnp.issubdtype(old_dtype, dtype)
|
|
|
|
_to = lambda dtype: jnp.issubdtype(new_dtype, dtype)
|
|
|
|
floating = jnp.floating
|
|
|
|
integer = jnp.integer
|
|
|
|
signed = jnp.signedinteger
|
|
|
|
both_32bit = old_dtype.itemsize == 4 and new_dtype.itemsize == 4
|
|
|
|
if _from(floating) and _to(floating):
|
2024-01-18 19:02:02 -08:00
|
|
|
if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.extf(out_type, x)
|
2024-01-18 19:02:02 -08:00
|
|
|
elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.truncf(out_type, x)
|
2025-01-09 07:59:27 -08:00
|
|
|
elif _from(integer) and _to(integer):
|
2024-01-18 19:02:02 -08:00
|
|
|
if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
|
2025-01-10 04:34:32 -08:00
|
|
|
if not (_from(signed) and _to(signed)):
|
|
|
|
raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.extsi(out_type, x)
|
2024-01-18 19:02:02 -08:00
|
|
|
elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.trunci(out_type, x)
|
2024-08-23 15:18:14 -07:00
|
|
|
elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits:
|
|
|
|
# This case triggers when casting signed to unsigned or vice versa.
|
2024-08-22 23:56:32 -07:00
|
|
|
return x
|
2025-01-09 07:59:27 -08:00
|
|
|
# TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer.
|
2025-01-17 06:59:55 -08:00
|
|
|
elif _from(floating) and _to(signed):
|
|
|
|
# TODO(apaszke): Remove once a month has passed, along with the
|
|
|
|
# _convert_helper float -> signed conversion above.
|
|
|
|
if not ctx.forward_compatible or both_32bit:
|
|
|
|
return arith.fptosi(out_type, x)
|
2025-01-10 04:34:32 -08:00
|
|
|
elif _from(signed) and _to(floating) and both_32bit:
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.sitofp(out_type, x)
|
2025-01-09 07:59:27 -08:00
|
|
|
elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:
|
2024-01-18 19:02:02 -08:00
|
|
|
return arith.extui(out_type, x)
|
|
|
|
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
|
|
|
|
multiple_results=False)(ctx, x)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule
|
|
|
|
|
|
|
|
|
2024-11-25 18:14:30 -08:00
|
|
|
def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions,
|
|
|
|
sharding):
|
2023-08-01 16:42:26 -07:00
|
|
|
if dimensions is not None:
|
|
|
|
raise NotImplementedError
|
|
|
|
if any(d is None for d in new_sizes):
|
|
|
|
raise NotImplementedError
|
2023-09-27 13:33:04 -07:00
|
|
|
if not ctx.avals_in[0].shape:
|
2025-01-14 20:33:34 -08:00
|
|
|
return vector.broadcast(
|
|
|
|
aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
|
|
),
|
|
|
|
x,
|
|
|
|
)
|
|
|
|
return vector.shape_cast(
|
|
|
|
aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
|
|
),
|
|
|
|
x,
|
|
|
|
)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.reshape_p] = _reshape_lowering_rule
|
|
|
|
|
|
|
|
|
2023-09-07 03:44:18 -07:00
|
|
|
def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
|
|
|
|
del dimensions # Unused.
|
2023-11-07 07:15:23 -08:00
|
|
|
(aval_in,) = ctx.avals_in
|
|
|
|
(aval_out,) = ctx.avals_out
|
|
|
|
if not aval_out.shape:
|
2024-08-08 07:20:57 -07:00
|
|
|
if aval_out.dtype.itemsize != 4:
|
|
|
|
raise ValueError(
|
|
|
|
"Only arrays with 32-bit element types can be converted to scalars,"
|
|
|
|
f" but got: {aval_out.dtype}. Try casting the input before squeezing"
|
|
|
|
" the scalar."
|
|
|
|
)
|
2024-10-18 16:13:46 -07:00
|
|
|
return vector.extract(x, [], [0] * len(aval_in.shape))
|
2025-01-14 20:33:34 -08:00
|
|
|
return vector.shape_cast(
|
|
|
|
aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
|
|
),
|
|
|
|
x,
|
|
|
|
)
|
2023-09-07 03:44:18 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
|
|
)
|
2024-10-18 16:13:46 -07:00
|
|
|
return tpu.concatenate(out_type, xs, dimension=dimension)
|
2023-09-07 03:44:18 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule
|
|
|
|
|
|
|
|
|
2024-12-17 10:05:58 -08:00
|
|
|
def _split_lowering_rule(
|
|
|
|
ctx: LoweringRuleContext, x, *, sizes, axis
|
|
|
|
):
|
|
|
|
(x_aval,) = ctx.avals_in
|
|
|
|
slice_size = np.array(x_aval.shape, dtype=np.int64)
|
|
|
|
starts = np.zeros_like(slice_size)
|
|
|
|
strides = np.ones_like(slice_size)
|
|
|
|
outs = []
|
|
|
|
for size, aval_out in zip(sizes, ctx.avals_out):
|
|
|
|
slice_size[axis] = size
|
|
|
|
outs.append(
|
|
|
|
vector.extract_strided_slice(
|
2025-01-14 20:33:34 -08:00
|
|
|
aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, aval_out
|
|
|
|
),
|
|
|
|
x,
|
|
|
|
starts,
|
|
|
|
slice_size,
|
|
|
|
strides,
|
2024-12-17 10:05:58 -08:00
|
|
|
)
|
|
|
|
)
|
|
|
|
starts[axis] += size
|
|
|
|
return outs
|
|
|
|
|
|
|
|
lowering_rules[lax.split_p] = _split_lowering_rule
|
|
|
|
|
|
|
|
|
2024-10-17 21:16:18 -07:00
|
|
|
def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
|
|
|
|
sharding):
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
|
|
)
|
2024-10-18 16:13:46 -07:00
|
|
|
return tpu.iota(out_type, dimension=dimension)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.iota_p] = _iota_lowering_rule
|
|
|
|
|
|
|
|
|
2025-01-31 21:26:28 -08:00
|
|
|
def _gather_lowering_rule(
|
|
|
|
ctx: LoweringRuleContext,
|
|
|
|
x,
|
|
|
|
indices,
|
|
|
|
*,
|
|
|
|
dimension_numbers,
|
|
|
|
slice_sizes,
|
|
|
|
unique_indices,
|
|
|
|
indices_are_sorted,
|
|
|
|
mode,
|
|
|
|
fill_value,
|
|
|
|
):
|
|
|
|
in_aval = ctx.avals_in[0]
|
|
|
|
indices_aval = ctx.avals_in[1]
|
|
|
|
out_aval = ctx.avals_out[0]
|
|
|
|
|
|
|
|
if len(in_aval.shape) != 2:
|
|
|
|
raise NotImplementedError("Only 2D gather is supported")
|
|
|
|
if pallas_utils.dtype_bitwidth(in_aval.dtype) != 32:
|
|
|
|
raise NotImplementedError("Only 32-bit gather is supported")
|
|
|
|
if in_aval.shape != indices_aval.shape[:-1] != out_aval.shape:
|
|
|
|
raise ValueError("Shape mismatch in input, indices and output")
|
|
|
|
|
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
|
|
)
|
|
|
|
# During lowering jnp.take_along_axis to lax.gather, we append extra dimension
|
|
|
|
# to the end of the indices array. We should reshape it back to the original
|
|
|
|
# shape before lowering to Mosaic and rely on MLIR CSE to remove the reshapes.
|
|
|
|
assert indices_aval.shape == in_aval.shape + (1,)
|
|
|
|
recovered_indices = vector.shape_cast(
|
|
|
|
ir.VectorType.get(in_aval.shape, ir.IntegerType.get_signless(32)),
|
|
|
|
indices,
|
|
|
|
)
|
|
|
|
# Note: current support for lax.gather is still very limited.
|
|
|
|
del fill_value
|
|
|
|
if (
|
|
|
|
slice_sizes == (1, 1)
|
|
|
|
and not unique_indices
|
|
|
|
and not indices_are_sorted
|
2025-02-03 22:05:08 -08:00
|
|
|
and mode
|
|
|
|
in (
|
|
|
|
lax.GatherScatterMode.FILL_OR_DROP,
|
|
|
|
lax.GatherScatterMode.PROMISE_IN_BOUNDS,
|
|
|
|
)
|
2025-01-31 21:26:28 -08:00
|
|
|
):
|
|
|
|
if dimension_numbers == lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(),
|
|
|
|
collapsed_slice_dims=(0,),
|
|
|
|
start_index_map=(0,),
|
|
|
|
operand_batching_dims=(1,),
|
|
|
|
start_indices_batching_dims=(1,),
|
|
|
|
):
|
|
|
|
return tpu.dynamic_gather(out_type, x, recovered_indices, 0)
|
|
|
|
if dimension_numbers == lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(),
|
|
|
|
collapsed_slice_dims=(1,),
|
|
|
|
start_index_map=(1,),
|
|
|
|
operand_batching_dims=(0,),
|
|
|
|
start_indices_batching_dims=(0,),
|
|
|
|
):
|
|
|
|
return tpu.dynamic_gather(out_type, x, recovered_indices, 1)
|
|
|
|
raise NotImplementedError("Unsupported gather")
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.gather_p] = _gather_lowering_rule
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
|
|
|
|
if permutation != (1, 0):
|
|
|
|
raise NotImplementedError
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
|
|
)
|
2024-10-18 16:13:46 -07:00
|
|
|
return vector.transpose(out_type, x, permutation)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.transpose_p] = _transpose_lowering_rule
|
|
|
|
|
|
|
|
|
2023-09-06 02:14:42 -07:00
|
|
|
def _bcast(x, y, x_aval, y_aval, out_aval):
|
2024-03-25 08:59:59 -07:00
|
|
|
x_dtype = x_aval.dtype
|
|
|
|
y_dtype = y_aval.dtype
|
|
|
|
if y_aval.weak_type:
|
|
|
|
y_dtype = x_aval.dtype
|
|
|
|
elif x_aval.weak_type:
|
|
|
|
x_dtype = y_aval.dtype
|
2023-09-27 13:33:04 -07:00
|
|
|
if isinstance(x, (np.ndarray, np.number, int, float)):
|
2024-03-25 08:59:59 -07:00
|
|
|
if getattr(y, "type", None) == ir.IndexType.get():
|
2023-09-06 02:14:42 -07:00
|
|
|
mlir_type = y.type
|
|
|
|
else:
|
2024-03-25 08:59:59 -07:00
|
|
|
mlir_type = _dtype_to_ir_type(x_dtype)
|
2023-09-06 02:14:42 -07:00
|
|
|
x = ir_constant(x, mlir_type)
|
2023-09-27 13:33:04 -07:00
|
|
|
if isinstance(y, (np.ndarray, np.number, int, float)):
|
2024-03-25 08:59:59 -07:00
|
|
|
if getattr(x, "type", None) == ir.IndexType.get():
|
2023-09-06 02:14:42 -07:00
|
|
|
mlir_type = x.type
|
|
|
|
else:
|
2024-03-25 08:59:59 -07:00
|
|
|
mlir_type = _dtype_to_ir_type(y_dtype)
|
2023-09-06 02:14:42 -07:00
|
|
|
y = ir_constant(y, mlir_type)
|
|
|
|
out_shape = list(out_aval.shape)
|
|
|
|
if x_aval.shape != out_aval.shape:
|
2024-03-25 08:59:59 -07:00
|
|
|
x_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(x_dtype))
|
2024-12-17 02:16:27 -08:00
|
|
|
x = vector.broadcast(x_ty, x)
|
2023-09-06 02:14:42 -07:00
|
|
|
if y_aval.shape != out_aval.shape:
|
2024-03-25 08:59:59 -07:00
|
|
|
y_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(y_dtype))
|
2024-12-17 02:16:27 -08:00
|
|
|
y = vector.broadcast(y_ty, y)
|
2023-09-06 02:14:42 -07:00
|
|
|
return x, y
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
|
|
(aval_out,) = ctx.avals_out
|
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.integer):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.addi(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.addf(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.add_p] = _add_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.add_p)
|
2024-01-29 21:41:40 -08:00
|
|
|
lowering_rules[ad_util.add_any_p] = _add_lowering_rule
|
|
|
|
skip_mlir_conversions.add(ad_util.add_any_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
2024-12-22 00:50:12 -08:00
|
|
|
class FoldingError(Exception):
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
def _fold_and_get_constant_value(x):
|
|
|
|
def _fold(x, fuel):
|
|
|
|
if fuel <= 0:
|
|
|
|
raise FoldingError("Folding depth exceeded")
|
|
|
|
op_name = getattr(x.owner, "name", None)
|
|
|
|
binop_folds = {
|
|
|
|
"arith.maxsi": max,
|
|
|
|
"arith.minsi": min,
|
|
|
|
}
|
|
|
|
if op_name == "arith.constant":
|
|
|
|
if ir.IntegerType.isinstance(x.type):
|
|
|
|
return ir.IntegerAttr(x.owner.attributes["value"]).value
|
|
|
|
elif ir.FloatType.isinstance(x.type):
|
|
|
|
return ir.FloatAttr(x.owner.attributes["value"]).value
|
|
|
|
else:
|
|
|
|
raise ValueError(f"Unsupported constant type: {x.type}")
|
|
|
|
if op_name in binop_folds:
|
|
|
|
return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands)
|
|
|
|
raise FoldingError(f"Folding not supported for {x.owner}")
|
|
|
|
|
|
|
|
try:
|
|
|
|
return _fold(x, 10)
|
|
|
|
except FoldingError:
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
|
|
(aval_out,) = ctx.avals_out
|
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.maxsi(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.maxui(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.maximumf(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.max_p] = _max_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.max_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
2023-09-27 12:34:31 -07:00
|
|
|
def _min_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
|
|
(aval_out,) = ctx.avals_out
|
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.minsi(x, y)
|
2023-09-27 12:34:31 -07:00
|
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.minui(x, y)
|
2023-09-27 12:34:31 -07:00
|
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.minimumf(x, y)
|
2023-09-27 12:34:31 -07:00
|
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.min_p] = _min_lowering_rule
|
|
|
|
skip_mlir_conversions.add(lax.min_p)
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _sub_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
|
|
(aval_out,) = ctx.avals_out
|
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.integer):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.subi(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.subf(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.sub_p] = _sub_lowering_rule
|
2024-11-07 15:01:12 -08:00
|
|
|
skip_mlir_conversions.add(lax.sub_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _mul_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
|
|
(aval_out,) = ctx.avals_out
|
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.integer):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.muli(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.mulf(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.mul_p] = _mul_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.mul_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
|
|
(aval_out,) = ctx.avals_out
|
2024-10-31 14:37:02 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.divsi(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.divui(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.divf(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.div_p] = _div_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.div_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _rem_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
|
|
|
(aval_out,) = ctx.avals_out
|
2024-10-22 11:01:07 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.remsi(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.remui(x, y)
|
2024-10-22 11:01:07 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.remf(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.rem_p] = _rem_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.rem_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
|
|
|
|
(aval_out,) = ctx.avals_out
|
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.integer):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.absi(x)
|
2023-09-06 02:14:42 -07:00
|
|
|
if jnp.issubdtype(aval_out.dtype, jnp.floating):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.absf(x)
|
2023-08-01 16:42:26 -07:00
|
|
|
raise NotImplementedError(aval_out.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.abs_p] = _abs_lowering_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _neg_lowering_rule(ctx: LoweringRuleContext, x):
|
|
|
|
(x_aval,) = ctx.avals_in
|
|
|
|
new_ctx = ctx.replace(
|
|
|
|
avals_in=(jax_core.ShapedArray((), x_aval.dtype), x_aval),
|
|
|
|
block_shapes=((), *ctx.block_shapes)
|
|
|
|
)
|
|
|
|
return _sub_lowering_rule(new_ctx, np.array(0, dtype=x_aval.dtype), x)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.neg_p] = _neg_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.neg_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
2024-07-25 18:04:44 +08:00
|
|
|
def _sign_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-09-09 14:48:50 -07:00
|
|
|
return lower_fun(
|
|
|
|
pallas_utils.sign_lowering_helper, multiple_results=False,
|
|
|
|
)(ctx, x)
|
2024-07-25 18:04:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.sign_p] = _sign_lowering_rule
|
|
|
|
|
|
|
|
|
2024-10-31 17:33:50 -07:00
|
|
|
def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
return lower_fun(
|
|
|
|
pallas_utils.nextafter_lowering_helper, multiple_results=False,
|
|
|
|
)(ctx, x, y)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.rsqrt(x)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule
|
|
|
|
|
|
|
|
|
2023-09-28 12:39:07 -07:00
|
|
|
def _sqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.sqrt(x)
|
2023-09-28 12:39:07 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule
|
|
|
|
|
|
|
|
|
2024-11-13 11:14:16 +02:00
|
|
|
def _square_lowering_rule(ctx: LoweringRuleContext, x):
|
|
|
|
if jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer):
|
|
|
|
return arith.muli(x, x)
|
|
|
|
return arith.mulf(x, x)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.square_p] = _square_lowering_rule
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.exp(x)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.exp_p] = _exp_lowering_rule
|
|
|
|
|
|
|
|
|
2023-08-22 19:41:30 -07:00
|
|
|
def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
|
2024-10-23 09:37:23 -07:00
|
|
|
# jax accepts float base (x) and integer/float exponent (y), and integer
|
|
|
|
# exponent is casted to float.
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
|
|
|
|
)
|
2024-10-23 09:37:23 -07:00
|
|
|
if jnp.issubdtype(ctx.avals_in[1].dtype, jnp.integer):
|
|
|
|
y = arith.sitofp(out_type, y)
|
2023-08-22 19:41:30 -07:00
|
|
|
if not isinstance(x, ir.Value) and x == 2.:
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.exp2(y)
|
2023-09-27 13:33:04 -07:00
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.powf(x, y)
|
2023-08-22 19:41:30 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.pow_p] = _pow_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.pow_p)
|
2023-08-22 19:41:30 -07:00
|
|
|
|
|
|
|
|
2023-09-15 16:00:19 -07:00
|
|
|
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
|
|
|
|
return lower_fun(lax_internal._integer_pow, multiple_results=False)(
|
|
|
|
ctx, x, y=y)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
|
|
|
|
|
|
|
|
|
2023-08-16 17:04:28 -07:00
|
|
|
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
|
|
|
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
|
|
|
|
# here.
|
2024-12-18 05:57:53 -08:00
|
|
|
return lower_fun(
|
|
|
|
lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x),
|
|
|
|
multiple_results=False,
|
|
|
|
)(ctx, x)
|
2023-08-16 17:04:28 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.exp2_p] = _exp2_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.exp2_p)
|
|
|
|
|
2023-08-16 17:04:28 -07:00
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
neg_x = arith.negf(x)
|
|
|
|
exp_neg_x = math.exp(neg_x)
|
2023-08-01 16:42:26 -07:00
|
|
|
aval_out = ctx.avals_out[0]
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, aval_out
|
|
|
|
)
|
2023-09-27 13:33:04 -07:00
|
|
|
if aval_out.shape == ():
|
|
|
|
one = ir_constant(1.0, mlir_type=out_type)
|
|
|
|
else:
|
2024-12-17 02:16:27 -08:00
|
|
|
one = vector.broadcast(out_type, ir_constant(1.0))
|
2024-10-18 16:13:46 -07:00
|
|
|
denom = arith.addf(one, exp_neg_x)
|
|
|
|
return arith.divf(one, denom)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.logistic_p] = _logistic_lowering_rule
|
|
|
|
|
|
|
|
|
2023-09-27 13:33:04 -07:00
|
|
|
def _sin_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.sin(x)
|
2023-09-27 13:33:04 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.sin_p] = _sin_lowering_rule
|
|
|
|
|
|
|
|
|
2024-09-30 16:11:36 -07:00
|
|
|
def _cos_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.cos(x)
|
2024-09-30 16:11:36 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.cos_p] = _cos_lowering_rule
|
|
|
|
|
|
|
|
|
2024-10-01 15:08:51 -07:00
|
|
|
def _tan_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.tan(x)
|
2024-10-01 15:08:51 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.tan_p] = _tan_lowering_rule
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.tanh(x)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.tanh_p] = _tanh_lowering_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _log_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.log(x)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.log_p] = _log_lowering_rule
|
|
|
|
|
2023-09-06 02:14:42 -07:00
|
|
|
|
2023-09-19 19:24:46 -07:00
|
|
|
def _log1p_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.log1p(x)
|
2023-09-19 19:24:46 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.log1p_p] = _log1p_lowering_rule
|
2024-04-02 02:54:45 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method):
|
|
|
|
if rounding_method == 0:
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.round(x)
|
2024-04-02 02:54:45 -07:00
|
|
|
elif rounding_method == 1:
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.roundeven(x)
|
2024-04-02 02:54:45 -07:00
|
|
|
else:
|
|
|
|
raise NotImplementedError(f"Unsupported rounding method: {rounding_method}")
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.round_p] = _round_lowering_rule
|
2023-09-19 19:24:46 -07:00
|
|
|
|
|
|
|
|
2024-10-03 16:22:33 -07:00
|
|
|
def _ceil_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.ceil(x)
|
2024-10-03 16:22:33 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.ceil_p] = _ceil_lowering_rule
|
|
|
|
|
|
|
|
|
2024-10-03 02:51:46 -07:00
|
|
|
def _floor_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.floor(x)
|
2024-10-03 02:51:46 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.floor_p] = _floor_lowering_rule
|
|
|
|
|
|
|
|
|
2024-10-09 13:45:42 -07:00
|
|
|
def _clz_lowering_rule(ctx: LoweringRuleContext, x):
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.ctlz(x)
|
2024-10-09 13:45:42 -07:00
|
|
|
|
|
|
|
lowering_rules[lax.clz_p] = _clz_lowering_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _population_count_lowering_rule(ctx: LoweringRuleContext, x):
|
|
|
|
aval_out = ctx.avals_out[0]
|
|
|
|
if aval_out.shape == ():
|
|
|
|
raise ValueError("Population count is not supported on scalars")
|
2024-10-18 16:13:46 -07:00
|
|
|
return math.ctpop(x)
|
2024-10-09 13:45:42 -07:00
|
|
|
|
|
|
|
lowering_rules[lax.population_count_p] = _population_count_lowering_rule
|
|
|
|
|
|
|
|
|
2024-10-09 09:28:07 -07:00
|
|
|
# Mapping for signed integer comparisons.
|
|
|
|
_cmpsi_lowering_types = {
|
|
|
|
lax.eq_p: arith.CmpIPredicate.eq,
|
|
|
|
lax.ne_p: arith.CmpIPredicate.ne,
|
|
|
|
lax.lt_p: arith.CmpIPredicate.slt,
|
|
|
|
lax.le_p: arith.CmpIPredicate.sle,
|
|
|
|
lax.gt_p: arith.CmpIPredicate.sgt,
|
|
|
|
lax.ge_p: arith.CmpIPredicate.sge,
|
2023-08-01 16:42:26 -07:00
|
|
|
}
|
|
|
|
|
2024-10-09 09:28:07 -07:00
|
|
|
# Mapping for unsigned integer comparisons.
|
|
|
|
_cmpui_lowering_types = {
|
|
|
|
lax.eq_p: arith.CmpIPredicate.eq,
|
|
|
|
lax.ne_p: arith.CmpIPredicate.ne,
|
|
|
|
lax.lt_p: arith.CmpIPredicate.ult,
|
|
|
|
lax.le_p: arith.CmpIPredicate.ule,
|
|
|
|
lax.gt_p: arith.CmpIPredicate.ugt,
|
|
|
|
lax.ge_p: arith.CmpIPredicate.uge,
|
|
|
|
}
|
|
|
|
|
|
|
|
# Mapping for floating point comparisons.
|
2023-08-01 16:42:26 -07:00
|
|
|
_cmpf_lowering_types = {
|
2024-10-09 09:28:07 -07:00
|
|
|
lax.eq_p: arith.CmpFPredicate.OEQ,
|
|
|
|
lax.ne_p: arith.CmpFPredicate.ONE,
|
|
|
|
lax.lt_p: arith.CmpFPredicate.OLT,
|
|
|
|
lax.le_p: arith.CmpFPredicate.OLE,
|
|
|
|
lax.gt_p: arith.CmpFPredicate.OGT,
|
|
|
|
lax.ge_p: arith.CmpFPredicate.OGE,
|
2023-08-01 16:42:26 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2024-12-09 08:22:56 -08:00
|
|
|
# The relationship between comparison operations on booleans and boolean
|
|
|
|
# algebra is as follows:
|
|
|
|
# eq(x, y) = !(x ^ y)
|
|
|
|
# ne(x, y) = x ^ y
|
|
|
|
# lt(x, y) = !x && y
|
|
|
|
# le(x, y) = !x || y
|
|
|
|
# gt(x, y) = x && !y
|
|
|
|
# ge(x, y) = x || !y
|
|
|
|
def _cmp_boolean_lowering_helper(primitive, x: Array, y: Array):
|
|
|
|
"""A helper function for lowering comparison operations for boolean inputs.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
primitive: A JAX primitive representing a comparison operation, which is
|
|
|
|
one of the following: `lax.eq_p` (equals), `lax.ne_p` (not equals),
|
|
|
|
`lax.lt_p` (less than), `lax.le_p` (less than or equal to),
|
|
|
|
`lax.gt_p` (greater than), or `lax.ge_p` (greater than or equal to).
|
|
|
|
x: A boolean array representing the first operand in the comparison.
|
|
|
|
y: A boolean array representing the second operand in the comparison.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A boolean array that is the result of applying the comparison operation
|
|
|
|
between `x` and `y` based on the given primitive.
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
ValueError: If an unsupported comparison primitive is provided.
|
|
|
|
"""
|
|
|
|
if primitive == lax.eq_p:
|
|
|
|
return jnp.logical_not(jnp.logical_xor(x, y))
|
|
|
|
elif primitive == lax.ne_p:
|
|
|
|
return jnp.logical_xor(x, y)
|
|
|
|
elif primitive == lax.lt_p:
|
|
|
|
return jnp.logical_and(jnp.logical_not(x), y)
|
|
|
|
elif primitive == lax.le_p:
|
|
|
|
return jnp.logical_or(jnp.logical_not(x), y)
|
|
|
|
elif primitive == lax.gt_p:
|
|
|
|
return jnp.logical_and(x, jnp.logical_not(y))
|
|
|
|
elif primitive == lax.ge_p:
|
|
|
|
return jnp.logical_or(x, jnp.logical_not(y))
|
|
|
|
else:
|
|
|
|
raise ValueError(f"Unsupported comparison primitive: {primitive}")
|
|
|
|
|
|
|
|
|
|
|
|
def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y):
|
2023-09-06 02:14:42 -07:00
|
|
|
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
2023-08-01 16:42:26 -07:00
|
|
|
x_aval, y_aval = ctx.avals_in
|
2024-10-09 09:28:07 -07:00
|
|
|
if x_aval.dtype != y_aval.dtype:
|
|
|
|
raise ValueError(
|
|
|
|
f"Mixed dtype operands in cmp: {x_aval.dtype}, {y_aval.dtype}"
|
|
|
|
)
|
|
|
|
dtype = x_aval.dtype
|
2024-08-06 05:36:55 -07:00
|
|
|
|
2024-10-09 09:28:07 -07:00
|
|
|
if jnp.issubdtype(dtype, jnp.bool_):
|
2024-12-09 08:22:56 -08:00
|
|
|
return lower_fun(
|
|
|
|
functools.partial(_cmp_boolean_lowering_helper, primitive),
|
|
|
|
multiple_results=False,
|
|
|
|
)(ctx, x, y)
|
2024-10-09 09:28:07 -07:00
|
|
|
|
|
|
|
if jnp.issubdtype(dtype, jnp.integer):
|
|
|
|
is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger)
|
2024-12-09 08:22:56 -08:00
|
|
|
pred = (
|
|
|
|
_cmpui_lowering_types if is_uint else _cmpsi_lowering_types
|
|
|
|
)[primitive]
|
2023-08-01 16:42:26 -07:00
|
|
|
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
|
2024-10-15 12:23:55 -07:00
|
|
|
return arith.cmpi(predicate, x, y)
|
2024-10-09 09:28:07 -07:00
|
|
|
|
|
|
|
if jnp.issubdtype(dtype, jnp.floating):
|
2024-12-09 08:22:56 -08:00
|
|
|
pred = _cmpf_lowering_types[primitive]
|
2023-08-01 16:42:26 -07:00
|
|
|
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
|
2024-10-15 12:23:55 -07:00
|
|
|
return arith.cmpf(predicate, x, y)
|
2024-10-09 09:28:07 -07:00
|
|
|
|
|
|
|
raise NotImplementedError(f"Unsupported dtype in cmp: {dtype}")
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.eq_p] = functools.partial(_cmp_lowering_rule, lax.eq_p)
|
|
|
|
lowering_rules[lax.ne_p] = functools.partial(_cmp_lowering_rule, lax.ne_p)
|
|
|
|
lowering_rules[lax.lt_p] = functools.partial(_cmp_lowering_rule, lax.lt_p)
|
|
|
|
lowering_rules[lax.le_p] = functools.partial(_cmp_lowering_rule, lax.le_p)
|
|
|
|
lowering_rules[lax.gt_p] = functools.partial(_cmp_lowering_rule, lax.gt_p)
|
|
|
|
lowering_rules[lax.ge_p] = functools.partial(_cmp_lowering_rule, lax.ge_p)
|
|
|
|
|
|
|
|
|
2024-03-25 08:59:59 -07:00
|
|
|
def _and_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.andi(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.and_p] = _and_lowering_rule
|
2024-03-25 08:59:59 -07:00
|
|
|
skip_mlir_conversions.add(lax.and_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
2024-12-10 16:37:25 -08:00
|
|
|
def _is_finite_lowering_rule(ctx: LoweringRuleContext, x):
|
|
|
|
out_aval, = ctx.avals_out
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
|
|
)
|
2024-12-10 16:37:25 -08:00
|
|
|
return _not_lowering_rule(ctx, tpu.weird(out_type, x))
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule
|
|
|
|
|
|
|
|
|
2024-03-25 08:59:59 -07:00
|
|
|
def _or_lowering_rule(ctx: LoweringRuleContext, x, y):
|
|
|
|
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.ori(x, y)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.or_p] = _or_lowering_rule
|
2024-03-25 08:59:59 -07:00
|
|
|
skip_mlir_conversions.add(lax.or_p)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
2024-01-09 09:58:04 -08:00
|
|
|
|
|
|
|
def _not_lowering_rule(ctx: LoweringRuleContext, x):
|
|
|
|
# The primitive not_p is lowered to
|
|
|
|
# https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not
|
|
|
|
# which is arithmetic for integers and logical for booleans.
|
|
|
|
# Lowering to:
|
|
|
|
# xor x, -1
|
|
|
|
# covers both cases.
|
|
|
|
out_aval = ctx.avals_out[0]
|
2024-07-15 17:58:27 -07:00
|
|
|
out_scalar_type = _dtype_to_ir_type(out_aval.dtype)
|
2024-01-09 09:58:04 -08:00
|
|
|
if not out_aval.shape:
|
|
|
|
# Create a scalar constant.
|
|
|
|
minus_one = ir_constant(-1, out_scalar_type)
|
|
|
|
else:
|
|
|
|
# Create a vector constant.
|
2025-01-14 20:33:34 -08:00
|
|
|
out_type = aval_to_ir_type(
|
|
|
|
ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
|
|
|
|
)
|
2024-01-09 09:58:04 -08:00
|
|
|
scalar_minus_one = ir.IntegerAttr.get(out_scalar_type, -1)
|
|
|
|
minus_one = arith.ConstantOp(
|
|
|
|
out_type, ir.DenseElementsAttr.get_splat(out_type, scalar_minus_one)
|
|
|
|
)
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.xori(x, minus_one)
|
2024-01-09 09:58:04 -08:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.not_p] = _not_lowering_rule
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args):
|
|
|
|
if len(args) > 1:
|
|
|
|
raise NotImplementedError("select_n only supported with <= 2 arguments")
|
|
|
|
pred_aval, x_aval = ctx.avals_in[:2]
|
|
|
|
if pred_aval.dtype != np.dtype(np.bool_):
|
|
|
|
lower_ctx = LoweringRuleContext(
|
|
|
|
ctx.lowering_context,
|
|
|
|
avals_in=[pred_aval],
|
|
|
|
avals_out=[pred_aval.update(dtype=np.bool_)],
|
|
|
|
block_shapes=[None],
|
|
|
|
)
|
|
|
|
pred = lower_fun(lambda x: x != 0, multiple_results=False)(lower_ctx, pred)
|
|
|
|
if not args:
|
|
|
|
return x
|
2023-09-06 02:14:42 -07:00
|
|
|
# Assume x and y, which we check above.
|
2023-08-01 16:42:26 -07:00
|
|
|
y, = args
|
2024-10-18 16:13:46 -07:00
|
|
|
return arith.select(pred, y, x)
|
2023-08-01 16:42:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.select_n_p] = _select_n_lowering_rule
|
|
|
|
|
2023-09-27 13:33:04 -07:00
|
|
|
|
|
|
|
def _clamp(min, operand, max):
|
|
|
|
res = jnp.maximum(operand, min)
|
|
|
|
return jnp.minimum(res, max)
|
|
|
|
|
|
|
|
|
|
|
|
def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max):
|
|
|
|
"""Compute minimum_p(maximum_p(min, operand), max)."""
|
|
|
|
return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max)
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[lax.clamp_p] = _clamp_lowering_rule
|
|
|
|
|
|
|
|
|
2023-08-01 16:42:26 -07:00
|
|
|
def _for_lowering_rule(
|
|
|
|
ctx: LoweringRuleContext,
|
|
|
|
*args,
|
|
|
|
jaxpr,
|
|
|
|
nsteps,
|
|
|
|
reverse,
|
|
|
|
unroll,
|
|
|
|
which_linear,
|
|
|
|
):
|
|
|
|
should_discharge = [
|
|
|
|
not isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
|
|
|
|
]
|
|
|
|
jaxpr, () = state_discharge.discharge_state(
|
|
|
|
jaxpr, (), should_discharge=[False, *should_discharge]
|
|
|
|
)
|
|
|
|
for i in range(nsteps):
|
|
|
|
if reverse:
|
|
|
|
i = nsteps - i - 1
|
|
|
|
i = ir_constant(i)
|
|
|
|
lowering_context = ctx.lowering_context.replace(
|
|
|
|
block_shapes=[(), *ctx.block_shapes],
|
|
|
|
)
|
|
|
|
non_ref_args = jaxpr_subcomp(lowering_context, jaxpr, i, *args)
|
|
|
|
non_ref_args_iter = iter(non_ref_args)
|
|
|
|
args = [
|
|
|
|
next(non_ref_args_iter) if s else a
|
|
|
|
for a, s in zip(args, should_discharge)
|
|
|
|
]
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
|
lowering_rules[for_loop.for_p] = _for_lowering_rule
|
|
|
|
|
|
|
|
|
2023-12-07 22:54:32 -08:00
|
|
|
def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
|
2024-04-17 02:33:36 -07:00
|
|
|
jaxpr: jax_core.Jaxpr, start: int | ir.Value,
|
|
|
|
num_steps: int | ir.Value, consts, *args,
|
2023-12-07 22:54:32 -08:00
|
|
|
has_loop_index: bool,
|
|
|
|
unroll: int):
|
|
|
|
def _run_body(i, args):
|
|
|
|
if has_loop_index:
|
|
|
|
lowering_context = ctx.lowering_context.replace(
|
|
|
|
block_shapes=ctx.block_shapes)
|
|
|
|
args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
|
|
|
|
else:
|
|
|
|
del i
|
|
|
|
lowering_context = ctx.lowering_context.replace(
|
|
|
|
block_shapes=ctx.block_shapes[:len(consts)]
|
|
|
|
+ ctx.block_shapes[len(consts) + 1:],
|
|
|
|
)
|
|
|
|
args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
|
|
|
|
return args
|
2024-01-18 21:43:59 -08:00
|
|
|
|
2024-04-17 02:33:36 -07:00
|
|
|
if (
|
|
|
|
not isinstance(start, ir.Value)
|
|
|
|
and not isinstance(num_steps, ir.Value)
|
|
|
|
and num_steps == unroll
|
|
|
|
):
|
2023-12-07 22:54:32 -08:00
|
|
|
# No need for an scf.For. We can just unroll completely
|
|
|
|
for i in range(start, start + num_steps):
|
|
|
|
args = _run_body(
|
2024-01-11 06:32:57 -08:00
|
|
|
ir_constant(i, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))),
|
2023-12-07 22:54:32 -08:00
|
|
|
args,
|
|
|
|
)
|
|
|
|
return args
|
|
|
|
if unroll != 1:
|
|
|
|
raise NotImplementedError(
|
|
|
|
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
|
2024-08-05 04:23:15 -07:00
|
|
|
lbd = _ensure_mlir_value(start, pallas_core.index_map_grid_aval)
|
|
|
|
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pallas_core.index_map_grid_aval))
|
2024-01-11 06:32:57 -08:00
|
|
|
step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
|
2023-12-07 22:54:32 -08:00
|
|
|
for_op = scf.ForOp(lbd, ubd, step, args)
|
|
|
|
with ir.InsertionPoint(for_op.body):
|
|
|
|
iv = for_op.induction_variable
|
|
|
|
inner_args = for_op.inner_iter_args
|
|
|
|
inner_out = _run_body(iv, inner_args)
|
|
|
|
scf.YieldOp(inner_out)
|
|
|
|
return for_op.results
|
|
|
|
|
|
|
|
|
2023-08-04 13:43:04 -07:00
|
|
|
def _scan_lowering_rule(
|
|
|
|
ctx: LoweringRuleContext,
|
|
|
|
*args,
|
2024-08-06 22:32:46 -07:00
|
|
|
jaxpr: jax_core.ClosedJaxpr,
|
2023-08-04 13:43:04 -07:00
|
|
|
linear: tuple[bool, ...],
|
|
|
|
length: int,
|
|
|
|
reverse: bool,
|
2024-04-17 02:33:36 -07:00
|
|
|
unroll: bool | int,
|
2023-08-04 13:43:04 -07:00
|
|
|
num_consts: int,
|
|
|
|
num_carry: int,
|
2024-03-28 10:54:02 -07:00
|
|
|
_split_transpose: bool,
|
2023-08-04 13:43:04 -07:00
|
|
|
):
|
2024-03-28 10:54:02 -07:00
|
|
|
del _split_transpose
|
2023-08-04 13:43:04 -07:00
|
|
|
# Can only handle fori_loop-like scans
|
|
|
|
num_extensive = len(args) - num_consts - num_carry
|
|
|
|
if num_extensive: raise NotImplementedError
|
|
|
|
if reverse: raise NotImplementedError
|
2023-12-07 22:54:32 -08:00
|
|
|
del linear, num_extensive, reverse
|
2023-08-04 13:43:04 -07:00
|
|
|
|
|
|
|
jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts
|
|
|
|
if jaxpr_consts: raise NotImplementedError
|
|
|
|
del jaxpr_consts
|
|
|
|
|
2024-04-17 02:33:36 -07:00
|
|
|
jaxpr, has_loop_index = pallas_utils.pattern_match_scan_to_fori_loop(
|
|
|
|
jaxpr, num_consts, num_carry
|
|
|
|
)
|
2023-08-04 13:43:04 -07:00
|
|
|
consts, args = split_list(args, [num_consts])
|
2024-01-18 21:43:59 -08:00
|
|
|
consts_avals, args_avals = split_list(ctx.avals_in, [num_consts])
|
2023-08-04 13:43:04 -07:00
|
|
|
if has_loop_index:
|
|
|
|
loop_index_start, *args = args
|
2024-01-18 21:43:59 -08:00
|
|
|
args_avals = args_avals[1:]
|
2023-08-04 13:43:04 -07:00
|
|
|
else:
|
|
|
|
loop_index_start = 0
|
2024-01-18 21:43:59 -08:00
|
|
|
consts = map(_ensure_mlir_value, consts, consts_avals)
|
|
|
|
args = map(_ensure_mlir_value, args, args_avals)
|
2023-12-07 22:54:32 -08:00
|
|
|
out = _lower_jaxpr_to_for_loop(
|
|
|
|
ctx, jaxpr, loop_index_start, length,
|
|
|
|
consts, *args, has_loop_index=has_loop_index,
|
|
|
|
unroll=unroll)
|
2023-08-04 13:43:04 -07:00
|
|
|
if has_loop_index:
|
|
|
|
out = [ir_constant(length,
|
2024-01-11 06:32:57 -08:00
|
|
|
mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))),
|
2023-08-04 13:43:04 -07:00
|
|
|
*out]
|
|
|
|
return out
|
|
|
|
lowering_rules[lax.scan_p] = _scan_lowering_rule
|
2023-09-06 02:14:42 -07:00
|
|
|
skip_mlir_conversions.add(lax.scan_p)
|
|
|
|
|
2023-08-04 13:43:04 -07:00
|
|
|
|
2024-04-18 11:03:01 -07:00
|
|
|
def _lower_while_via_fori(
|
2024-02-12 18:05:31 -08:00
|
|
|
ctx: LoweringRuleContext,
|
|
|
|
*args,
|
2024-04-18 11:03:01 -07:00
|
|
|
fori_jaxpr,
|
2024-02-12 18:05:31 -08:00
|
|
|
cond_nconsts,
|
|
|
|
cond_jaxpr,
|
|
|
|
body_nconsts,
|
|
|
|
body_jaxpr,
|
|
|
|
):
|
|
|
|
_, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
|
|
|
|
(lb, ub), args = carry[:2], carry[2:]
|
|
|
|
for_out = _lower_jaxpr_to_for_loop(
|
|
|
|
ctx.replace(
|
|
|
|
block_shapes=ctx.block_shapes[: body_nconsts + 1]
|
|
|
|
+ ctx.block_shapes[body_nconsts + 2 :],
|
|
|
|
),
|
2024-04-18 11:03:01 -07:00
|
|
|
fori_jaxpr,
|
2024-02-12 18:05:31 -08:00
|
|
|
lb,
|
2024-04-17 02:33:36 -07:00
|
|
|
arith.subi(ub, lb),
|
2024-02-12 18:05:31 -08:00
|
|
|
body_consts,
|
|
|
|
*args,
|
|
|
|
has_loop_index=True,
|
|
|
|
unroll=1,
|
|
|
|
)
|
|
|
|
return [ub, ub, *for_out]
|
|
|
|
|
|
|
|
|
2024-04-18 11:03:01 -07:00
|
|
|
def _while_lowering_rule(
|
|
|
|
ctx: LoweringRuleContext,
|
|
|
|
*args,
|
|
|
|
cond_nconsts,
|
|
|
|
cond_jaxpr,
|
|
|
|
body_nconsts,
|
|
|
|
body_jaxpr,
|
|
|
|
):
|
|
|
|
# First try to lower via a simpler fori loop, which may optimize better.
|
2024-08-06 22:32:46 -07:00
|
|
|
fori_jaxpr, _ = pallas_utils.pattern_match_while_to_fori_loop(
|
2024-04-18 11:03:01 -07:00
|
|
|
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
|
|
|
|
)
|
|
|
|
if fori_jaxpr is not None:
|
|
|
|
return _lower_while_via_fori(
|
|
|
|
ctx,
|
|
|
|
*args,
|
|
|
|
fori_jaxpr=fori_jaxpr,
|
|
|
|
cond_nconsts=cond_nconsts,
|
|
|
|
cond_jaxpr=cond_jaxpr,
|
|
|
|
body_nconsts=body_nconsts,
|
|
|
|
body_jaxpr=body_jaxpr,
|
|
|
|
)
|
|
|
|
|
|
|
|
# If we fail conversion to fori, fallback to an ordinary while loop.
|
|
|
|
cond_consts, body_consts, carry = split_list(
|
|
|
|
args, [cond_nconsts, body_nconsts]
|
|
|
|
)
|
|
|
|
cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = (
|
|
|
|
split_list(ctx.block_shapes, [cond_nconsts, body_nconsts])
|
|
|
|
)
|
|
|
|
carry_types = [a.type for a in carry]
|
2024-08-06 22:32:46 -07:00
|
|
|
while_op = scf.WhileOp(carry_types, carry)
|
2024-04-18 11:03:01 -07:00
|
|
|
|
2024-08-06 22:32:46 -07:00
|
|
|
before_block = while_op.before.blocks.append(*carry_types)
|
2024-04-18 11:03:01 -07:00
|
|
|
with ir.InsertionPoint.at_block_begin(before_block):
|
2024-08-06 22:32:46 -07:00
|
|
|
cond_args = [*cond_consts, *before_block.arguments]
|
2024-04-18 11:03:01 -07:00
|
|
|
[cond] = jaxpr_subcomp(
|
|
|
|
ctx.lowering_context.replace(
|
|
|
|
block_shapes=[*cond_const_block_shapes, *carry_block_shapes]
|
|
|
|
),
|
|
|
|
cond_jaxpr.jaxpr,
|
|
|
|
*cond_args,
|
|
|
|
)
|
|
|
|
scf.condition(cond, before_block.arguments)
|
|
|
|
|
2024-08-06 22:32:46 -07:00
|
|
|
after_block = while_op.after.blocks.append(*carry_types)
|
2024-04-18 11:03:01 -07:00
|
|
|
with ir.InsertionPoint.at_block_begin(after_block):
|
2024-08-06 22:32:46 -07:00
|
|
|
body_args = [*body_consts, *after_block.arguments]
|
2024-04-18 11:03:01 -07:00
|
|
|
loop_out = jaxpr_subcomp(
|
|
|
|
ctx.lowering_context.replace(
|
|
|
|
block_shapes=[*body_const_block_shapes, *carry_block_shapes],
|
|
|
|
),
|
|
|
|
body_jaxpr.jaxpr,
|
2024-08-06 22:32:46 -07:00
|
|
|
*body_args,
|
2024-04-18 11:03:01 -07:00
|
|
|
)
|
2024-08-06 22:32:46 -07:00
|
|
|
if loop_out:
|
|
|
|
scf.yield_(loop_out)
|
|
|
|
return list(while_op.results)
|
2024-04-18 11:03:01 -07:00
|
|
|
|
|
|
|
|
2024-02-12 18:05:31 -08:00
|
|
|
lowering_rules[lax.while_p] = _while_lowering_rule


def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
  index, *args = args
  constant_index = _fold_and_get_constant_value(index)
  if constant_index is not None:
    return jaxpr_subcomp(
        ctx.lowering_context.replace(block_shapes=ctx.block_shapes[1:]),
        branches[constant_index].jaxpr,
        *args,
    )
  aval_to_ir_type_with_fn = functools.partial(
      aval_to_ir_type, ctx.lowering_context.dynamic_shape_replacement_fn
  )
  out_types = map(aval_to_ir_type_with_fn, ctx.avals_out)
  pred = arith.cmpi(
      arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
  )
  if_op = scf.IfOp(pred, out_types, hasElse=True)
  lowering_context = ctx.lowering_context.replace(
      block_shapes=ctx.block_shapes[1:],
  )
  with ir.InsertionPoint(if_op.then_block):
    # TODO(b/300272065): Use `scf.IndexSwitchOp` instead of a cascade of
    # if/else.
    if len(branches) > 2:
      out = _cond_lowering_rule(
          ctx,
          arith.subi(index, ir_constant(1, index.type)),
          *args,
          branches=branches[1:],
      )
    else:
      out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
    scf.YieldOp(out)
  with ir.InsertionPoint(if_op.else_block):
    out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args)
    scf.YieldOp(out)
  return if_op.results


lowering_rules[lax.cond_p] = _cond_lowering_rule
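
# Illustrative sketch (added comment, not part of the original module):
# `jax.lax.cond` and `jax.lax.switch` inside a Pallas kernel reach this rule;
# a switch over three branches becomes a cascade of two `scf.if` ops, as noted
# in the TODO above. Names and shapes are assumptions.
#
#   def kernel(i_ref, x_ref, o_ref):
#     o_ref[...] = jax.lax.switch(
#         i_ref[0],
#         [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: x - 1.0],
#         x_ref[...],
#     )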


def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
  lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
  return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args)


lowering_rules[pjit.pjit_p] = _pjit_lowering_rule


def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
  return x


lowering_rules[pjit.mesh_cast_p] = _mesh_cast_lowering_rule


def _custom_jvp_call_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    call_jaxpr: jax_core.Jaxpr,
    jvp_jaxpr_fun: lu.WrappedFun,
    num_consts: int,
    symbolic_zeros: bool,
):
  del jvp_jaxpr_fun
  if symbolic_zeros: raise NotImplementedError
  if num_consts: raise NotImplementedError
  if call_jaxpr.consts: raise NotImplementedError
  lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
  return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args)


lowering_rules[custom_derivatives.custom_jvp_call_p] = (
    _custom_jvp_call_lowering_rule)


def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
  del ctx, args, kwargs
  # No-op debug callbacks in Mosaic for now
  return []


lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule


def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
  if ctx.lowering_context.user_grid_indices is None:
    raise ValueError(
        f"program id: {axis} was passed, but user did not provide a grid."
    )
  length = len(ctx.lowering_context.user_grid_indices)
  if not (0 <= axis < length):
    raise ValueError(
        f"user passed in program id with axis: {axis}, but grid only has"
        f" length: {length}"
    )
  return ctx.lowering_context.user_grid_indices[axis]


lowering_rules[primitives.program_id_p] = _program_id_lowering_rule


def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
  mapped_axes = set(ctx.lowering_context.mapped_dims)
  seen_user_axes = 0
  for i in range(ctx.lowering_context.grid_rank):
    seen_user_axes += int(i not in mapped_axes)
    if seen_user_axes == axis + 1:
      break
  else:
    raise ValueError(
        f"user passed in program id with axis: {axis}, but grid only has"
        f" length: {ctx.lowering_context.grid_rank}"
    )
  return tpu.iteration_bound(i)


lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule
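
# Illustrative sketch (added comment, not part of the original module): how
# the two grid-query rules above are typically exercised from a kernel. The
# `pl` alias, block shapes, and grid are assumptions.
#
#   def kernel(x_ref, o_ref):
#     i = pl.program_id(0)     # lowered by _program_id_lowering_rule
#     n = pl.num_programs(0)   # lowered by _num_programs_lowering_rule
#     o_ref[...] = x_ref[...] * n + i
#
#   y = pl.pallas_call(
#       kernel,
#       grid=(8,),
#       in_specs=[pl.BlockSpec((1, 128), lambda i: (i, 0))],
#       out_specs=pl.BlockSpec((1, 128), lambda i: (i, 0)),
#       out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
#   )(x)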


def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis):
  (out_aval,) = ctx.avals_out
  return tpu.repeat(
      aval_to_ir_type(
          ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
      ),
      x,
      axis,
      repeats,
  )


lowering_rules[tpu_primitives.repeat_p] = _repeat_lowering_rule


def _roll_lowering_rule(
    ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis
):
  (out_aval,) = ctx.avals_out
  return tpu.dynamic_rotate(
      aval_to_ir_type(
          ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
      ),
      x,
      shift,
      axis,
      stride=stride,
      stride_dimension=stride_axis,
  )


lowering_rules[tpu_primitives.roll_p] = _roll_lowering_rule


def _slice_lowering_rule(
    ctx: LoweringRuleContext, x, limit_indices, start_indices, strides
):
  """Lowers a slice to vector dialect."""
  (aval_out,) = ctx.avals_out
  out_type = aval_to_ir_type(
      ctx.lowering_context.dynamic_shape_replacement_fn, aval_out
  )
  if strides is None:
    strides = [1] * len(start_indices)
  sizes = np.array(limit_indices) - np.array(start_indices)
  return vector.extract_strided_slice(
      out_type, x, start_indices, sizes, strides
  )


lowering_rules[lax.slice_p] = _slice_lowering_rule


def _xor_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return arith.xori(x, y)


lowering_rules[lax.xor_p] = _xor_lowering_rule
skip_mlir_conversions.add(lax.xor_p)


def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d):
  x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
  return arith.shli(x, d)


lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule
skip_mlir_conversions.add(lax.shift_left_p)


def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d):
  x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
  return arith.shrsi(x, d)


lowering_rules[lax.shift_right_arithmetic_p] = _shift_right_arithmetic_lowering_rule
skip_mlir_conversions.add(lax.shift_right_arithmetic_p)


def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d):
  x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
  return arith.shrui(x, d)


lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules
skip_mlir_conversions.add(lax.shift_right_logical_p)


def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
  return lower_fun(
      pallas_utils.erf_inv_lowering_helper, multiple_results=False,
  )(ctx, x)


lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule


def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx):
  if not isinstance(x.type.element_type, ir.F32Type):
    raise ValueError("Only float32 is supported.")
  return tpu.reciprocal(x, approx=approx)


lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule


def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
  del ty
  (out_aval,) = ctx.avals_out
  return tpu.bitcast(
      aval_to_ir_type(
          ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
      ),
      x,
  )


lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule


def _bitcast_convert_type_lowering_rule(
    ctx: LoweringRuleContext, x, *, new_dtype):
  (in_aval, ) = ctx.avals_in
  (out_aval,) = ctx.avals_out
  old_bitwidth = pallas_utils.dtype_bitwidth(in_aval.dtype)
  new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype)
  if old_bitwidth != new_bitwidth:
    raise NotImplementedError("Changing bitwidths not supported.")
  return tpu.bitcast(
      aval_to_ir_type(
          ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
      ),
      x,
  )


lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule


def _alloc_value(
    aval: jax_core.AbstractValue, *, ctx: LoweringRuleContext
) -> ir.Value:
  if isinstance(aval, pallas_core.AbstractMemoryRef):
    memspace = _memory_space_to_mosaic_attribute(aval.memory_space)
    if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype):
      assert aval.memory_space == TPUMemorySpace.SEMAPHORE
      memref_type = aval_to_ir_type(
          ctx.lowering_context.dynamic_shape_replacement_fn,
          aval,
          memory_space=TPUMemorySpace.SEMAPHORE,
      )
      return tpu.sem_alloc(memref_type)
    else:
      out_type = ir.MemRefType.get(
          aval.shape,
          _dtype_to_ir_type(aval.dtype, is_kernel_boundary=True),
          memory_space=memspace)
      return memref.alloca(out_type, [], [])
  elif isinstance(aval, tpu_core.AbstractSemaphore):
    memref_type = aval_to_ir_type(
        ctx.lowering_context.dynamic_shape_replacement_fn,
        aval,
        memory_space=TPUMemorySpace.SEMAPHORE,
    )
    return tpu.sem_alloc(memref_type)
  raise NotImplementedError(f"Cannot allocate {type(aval)}.")


def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):
  out_type = [
      aval_to_ir_type(ctx.lowering_context.dynamic_shape_replacement_fn, aval)
      for aval in ctx.avals_out
  ]
  region = tpu.RegionOp(out_type)
  in_avals = [v.aval for v in jaxpr.invars]
  with ctx.lowering_context.grid_name_context():
    jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  with ir.InsertionPoint(region.body):
    alloc_fn = functools.partial(_alloc_value, ctx=ctx)
    args = map(alloc_fn, in_avals)
    block_shapes = tuple(a.shape if isinstance(a, state.AbstractRef) else None
                         for a in in_avals)
    ctx = ctx.lowering_context.replace(
        block_shapes=(*ctx.block_shapes, *block_shapes)
    )
    out = jaxpr_subcomp(ctx, jaxpr, *consts, *args)
    tpu.YieldOp(out)
  return region.results


lowering_rules[primitives.run_scoped_p] = _run_scoped_lowering_rule
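
# Illustrative sketch (added comment, not part of the original module):
# `run_scoped` is the user-facing way to obtain the scratch values allocated
# by _alloc_value above. The `pltpu` alias
# (`from jax.experimental.pallas import tpu as pltpu`) and the shapes are
# assumptions.
#
#   def kernel(x_ref, o_ref):
#     def scoped(scratch, sem):
#       pltpu.async_copy(x_ref, scratch, sem).wait()
#       o_ref[...] = scratch[...] * 2.0
#     pltpu.run_scoped(
#         scoped,
#         pltpu.VMEM((8, 128), jnp.float32),
#         pltpu.SemaphoreType.DMA,
#     )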


def _device_id_to_logical(
    ctx: LoweringRuleContext, device_id,
    device_id_type: tpu_primitives.DeviceIdType):
  if device_id_type is tpu_primitives.DeviceIdType.MESH:
    # Mesh means we are passed the mesh coordinates for the device.
    device_ids = tree_util.tree_leaves(device_id)
    mesh_strides = ctx.lowering_context.mesh_context.mesh_strides

    i32 = ir.IntegerType.get_signless(32)
    if len(device_ids) == 0:
      return arith.constant(i32, 0)
    return functools.reduce(
        arith.addi,
        (
            arith.muli(a, arith.constant(i32, b))
            for a, b in zip(device_ids, mesh_strides)
        ),
    )
  elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL:
    return device_id
  raise NotImplementedError(f"Unsupported device id type: {device_id_type}")


def _semaphore_read_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
):
  sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, transforms = tree_util.tree_unflatten(args_tree, args)
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
  return tpu.sem_read(sem)


lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule


def _semaphore_signal_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
    device_id_type: tpu_primitives.DeviceIdType,
):
  sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, transforms, value, device_id, core_index = tree_util.tree_unflatten(
      args_tree, args
  )
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
  if device_id is not None:
    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
  tpu.sem_signal(sem, value, device_id=device_id, core_id=core_index)
  return []


lowering_rules[tpu_primitives.semaphore_signal_p] = (
    _semaphore_signal_lowering_rule)


def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
  sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem, transforms, value = tree_util.tree_unflatten(args_tree, args)
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
  tpu.sem_wait(sem, value)
  return []


lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule
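
# Illustrative sketch (added comment, not part of the original module): the
# semaphore rules above back the user-facing pltpu semaphore helpers, e.g.
# with a scratch semaphore (names assumed):
#
#   def kernel(x_ref, o_ref, sem):
#     pltpu.semaphore_signal(sem, 1)   # lowers to tpu.sem_signal
#     pltpu.semaphore_wait(sem, 1)     # lowers to tpu.sem_wait
#     o_ref[...] = x_ref[...]
#
# where `sem` would come from a scratch allocation such as
# pltpu.SemaphoreType.REGULAR.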


def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                             device_id_type: tpu_primitives.DeviceIdType):
  (
      src_ref,
      src_transforms,
      dst_ref,
      dst_transforms,
      sem,
      sem_transforms,
      src_sem,
      src_sem_transforms,
      device_id,
  ) = tree_util.tree_unflatten(tree, args)
  (src_ref_aval, _, dst_ref_aval, _, sem_aval, _, src_sem_aval, _, _) = (
      tree_util.tree_unflatten(tree, ctx.avals_in)
  )
  if src_ref_aval.dtype == jnp.bool_:
    raise NotImplementedError("DMAs with bool dtypes are not supported.")
  block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
  src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2]
  src_ref, _ = _transform_ref(
      src_ref, src_ref_aval.dtype, src_ref_block_shape, src_transforms
  )
  if src_sem is not None:
    src_sem, _ = _transform_ref(
        src_sem, src_sem_aval.dtype, src_sem_aval.shape, src_sem_transforms
    )
  dst_ref, _ = _transform_ref(
      dst_ref, dst_ref_aval.dtype, dst_ref_block_shape, dst_transforms
  )
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
  if device_id is not None:
    device_id = _device_id_to_logical(ctx, device_id, device_id_type)
  tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem,
                  device_id=device_id)

  return []


lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule


def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                            device_id_type: tpu_primitives.DeviceIdType):
  del device_id_type
  (src, src_transforms, dst, transforms, sem, sem_transforms, _, _, _) = (
      tree_util.tree_unflatten(tree, args)
  )
  (src_aval, _, dst_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten(
      tree, ctx.avals_in
  )
  block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
  ref_block_shape = block_shapes[2]
  src, _ = _transform_ref(src, src_aval.dtype, src_aval.shape, src_transforms)
  dst, _ = _transform_ref(dst, dst_aval.dtype, ref_block_shape, transforms)
  sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
  if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 2, 12):
    # TODO(mvoz): Remove once six months have passed. b/395630795
    if hasattr(src_aval, "memory_space"):
      src_memory_space = _memory_space_to_mosaic_attribute(src_aval.memory_space)
      smem_space = ir.Attribute.parse("#tpu.memory_space<smem>")
      src_is_smem = src_memory_space == smem_space
      wait_ref = src if src_is_smem else dst
    else:
      wait_ref = dst
    # Legacy instruction backwards compatibility.
    tpu.wait_dma(sem, wait_ref)
  else:
    tpu.wait_dma2(sem, src, dst)
  return []


lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule
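
# Illustrative sketch (added comment, not part of the original module): a
# typical copy that flows through _dma_start_lowering_rule and then
# _dma_wait_lowering_rule, using the pltpu async-copy helper (names and
# memory spaces assumed):
#
#   def kernel(hbm_ref, o_ref, scratch, sem):
#     copy = pltpu.make_async_copy(hbm_ref, scratch, sem)
#     copy.start()   # lowers to tpu.enqueue_dma
#     copy.wait()    # lowers to tpu.wait_dma / tpu.wait_dma2
#     o_ref[...] = scratch[...]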


def _device_id_lowering_rule(ctx: LoweringRuleContext):
  return tpu.device_id()


lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule


def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
  grid_names = ctx.lowering_context.grid_names
  if grid_names and axis_name in grid_names:
    # We are querying a named axis corresponding to a grid dimension.
    return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
  # We are querying a named axis corresponding to a mesh dimension.
  device_id = tpu.device_id()
  mesh_context = ctx.lowering_context.mesh_context
  if mesh_context is None:
    raise ValueError("Mesh context is not set.")
  mesh_shape = mesh_context.mesh_shape
  axis_names = mesh_context.axis_names
  axis_index = axis_names.index(axis_name)
  axis_size = ir_constant(mesh_shape[axis_index])
  minor_divisor = ir_constant(
      np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32)
  )
  return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size)


lowering_rules[lax.axis_index_p] = _axis_index_rule
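
# Worked example (added comment, values assumed): for mesh_shape = (2, 4) with
# axis_names = ("x", "y") and device_id = 5, the remsi(divsi(...)) above gives
#   lax.axis_index("y") -> (5 // 1) % 4 = 1
#   lax.axis_index("x") -> (5 // 4) % 2 = 1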


def _get_barrier_semaphore_rule(ctx: LoweringRuleContext):
  memref_type = aval_to_ir_type(
      ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
  )
  return tpu.sem_barrier(memref_type)


lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule


def _delay_rule(ctx: LoweringRuleContext, nanos: int):
  tpu.delay(nanos)
  return []


lowering_rules[tpu_primitives.delay_p] = _delay_rule


def _debug_print_rule(
    ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
):
  is_scalar_inputs = [aval.shape == () for aval in ctx.avals_in]
  is_all_scalars = all(is_scalar_inputs)
  is_single_vector = len(is_scalar_inputs) == 1 and not is_scalar_inputs[0]
  if not (is_all_scalars or is_single_vector):
    raise ValueError(
        "All inputs to debug_print must be all scalars or a single vector, but"
        f" got {ctx.avals_in}"
    )

  # Scalar case.
  if is_all_scalars:
    primitives.check_debug_print_format(fmt, *args)
    if has_placeholders:
      if not all(
          isinstance(arg.type, ir.IntegerType) and arg.type.width == 32
          for arg in args
      ):
        raise TypeError(
            "All arguments must be 32-bit integers when using"
            " placeholders (`{...}`). If you need to print values of other types,"
            " remove placeholders from the format string."
        )

      # TPU expects $0, $1 etc as placeholders.
      fmt = "".join(
          f"{text}${idx}"
          for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
      )

    tpu.log(args, fmt, formatted=has_placeholders)
    return ()

  # Vector case.
  # Copy the array to vmem for logging.
  # Note that the shape of the array must be explicitly provided here. This is
  # because the underlying implementation aligns shapes to tile boundaries,
  # potentially altering the original shape and making it unrecoverable.
  if len(ctx.avals_in) != 1:
    raise ValueError(
        "Only one vector input to debug_print is supported."
    )
  (aval,) = ctx.avals_in
  (arg,) = args

  if not has_placeholders or not fmt.endswith("{}"):
    raise ValueError("For vector input, the format string must end with {}.")

  fmt = fmt[:-2]

  region = tpu.RegionOp(())
  with ir.InsertionPoint(region.body):
    element_type = _dtype_to_ir_type(aval.dtype)
    ref_type = ir.MemRefType.get(
        aval.shape,
        element_type,
        memory_space=ir.Attribute.parse("#tpu.memory_space<vmem>"),
    )
    ref = memref.alloca(ref_type, [], [])

    index_type = ir.IndexType.get()
    zero = arith.constant(index_type, 0)
    indices = [zero] * len(aval.shape)
    vector.store(arg, ref, indices)
    tpu.log_buffer(ref, aval.shape, fmt)
    tpu.yield_([])
  return ()


lowering_rules[primitives.debug_print_p] = _debug_print_rule
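
# Illustrative sketch (added comment, not part of the original module): the
# two forms accepted by the rule above, via the user-facing pl.debug_print
# (values assumed):
#
#   pl.debug_print("i={} n={}", i, n)        # scalar case: 32-bit int args
#   pl.debug_print("block: {}", x_ref[...])  # single-vector case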


def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds):
  del ctx
  # In the KeyScalarBundle case we unpack the bundle and set the seed with
  # the list of scalars.
  if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle):
    tpu.prng_set_seed_32(seeds[0].scalars)
    return []
  # For integer seeds, we can set the seed directly, as PRNGSeed32Op natively
  # takes in a list of integers as input.
  all_integers = all(isinstance(seed.type, ir.IntegerType) for seed in seeds)
  if not all_integers:
    seed_types = [seed.type for seed in seeds]
    raise ValueError(f"All seed data must be scalar integers. Got {seed_types}")
  tpu.prng_set_seed_32(seeds)
  return []


lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule


def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape):
  if len(shape) <= 1:
    # TODO(b/342054464): Support implicit dims for PRNGRandomBitsOp.
    raise NotImplementedError("random_bits only supports rank>=2 outputs.")
  out_aval = ctx.avals_out[0]
  out_type = aval_to_ir_type(
      ctx.lowering_context.dynamic_shape_replacement_fn, out_aval
  )
  return tpu.prng_random_bits(out_type)


lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule
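
# Illustrative sketch (added comment, not part of the original module):
# seeding the on-chip PRNG and drawing raw 32-bit random bits with the TPU
# primitives lowered above (the `pltpu` alias and shapes are assumptions):
#
#   def kernel(seed_ref, o_ref):
#     pltpu.prng_seed(seed_ref[0])
#     o_ref[...] = pltpu.prng_random_bits((8, 128))  # rank >= 2 required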


def random_seed_lowering(ctx, seeds, *, impl):
  seed_lowering = lower_fun(impl.seed, multiple_results=False)
  return seed_lowering(ctx, seeds)


lowering_rules[prng.random_seed_p] = random_seed_lowering


def random_bits_lowering(ctx, keys, *, bit_width, shape):
  assert bit_width == 32, "Only 32-bit PRNG supported."
  aval, = ctx.avals_in
  impl = aval.dtype._impl
  _proxy_fn = impl.random_bits
  if not pl_random.is_pallas_impl(impl):
    def new_lowering(key, bit_width, shape):
      key = jax.random.key_data(key).astype(jnp.uint32)
      return impl.random_bits(key, bit_width, shape)
    _proxy_fn = new_lowering
  bits_lowering = lower_fun(_proxy_fn, multiple_results=False)
  return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape)


lowering_rules[prng.random_bits_p] = random_bits_lowering


def random_fold_in_lowering(ctx, keys, msgs):
  keys_aval, _ = ctx.avals_in
  impl = keys_aval.dtype._impl
  fold_in_lowering = lower_fun(impl.fold_in, multiple_results=False)
  return fold_in_lowering(ctx, keys, msgs)


lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering


def random_unwrap_lowering(ctx, key):
  keys_aval = ctx.avals_in[0]
  impl = keys_aval.dtype._impl
  if not pl_random.is_pallas_impl(impl):
    return key
  assert isinstance(key, KeyScalarBundle)
  # Convert to a vector.
  if tuple(key.key_shape) != (1, 1):
    raise NotImplementedError(
        "Seed key_data of shape != (1, 1) not supported. "
        f"Got: {key.key_shape}")
  scalar = key.scalars[0]
  out_type = ir.VectorType.get(
      key.key_shape, _dtype_to_ir_type(jnp.dtype('int32'))
  )
  val = vector.broadcast(out_type, scalar)
  return val


lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering


def random_wrap_lowering(ctx, key_data, *, impl):
  del ctx
  if not pl_random.is_pallas_impl(impl):
    return key_data
  if isinstance(key_data.type, ir.VectorType):
    # If the key data lives in vregs, need to unpack it to sregs.
    key_data_list = []
    key_data_shape = key_data.type.shape
    if len(key_data_shape) != 2:
      raise NotImplementedError("Seed key_data must be 2D.")
    if tuple(key_data_shape) != (1, 1):
      raise NotImplementedError(
          "Seed key_data of shape != (1, 1) not supported. "
          f"Got: {key_data_shape}")
    for i in range(key_data_shape[1]):
      key_data_list.append(vector.ExtractOp(key_data, [], [0, i]))
    return KeyScalarBundle(
        scalars=key_data_list, key_shape=tuple(key_data_shape))
  if isinstance(key_data, KeyScalarBundle):
    return key_data
  else:
    raise NotImplementedError(f"key_data wrap {type(key_data)}")


lowering_rules[prng.random_wrap_p] = random_wrap_lowering


def _checkify_lowering_rule(
    ctx: LoweringRuleContext, *err_args, err_tree, debug):
  if not tpu_core.runtime_assert_enabled():
    if debug:
      return []
    else:
      raise LoweringException("Non-debug check must be functionalized. "
                              "Enable runtime asserts with "
                              "--jax_pallas_enable_runtime_assert "
                              "or functionalize with checkify.check.")

  assert ctx.lowering_context.ir_context.allow_unregistered_dialects, (
      "allow_unregistered_dialects must be set to True for "
      "runtime assert check.")
  error = jax.tree.unflatten(err_tree, err_args)
  assert len(error._pred) == 1
  assert len(error._metadata) == 1
  assert len(error._payload) == 1
  pred = list(error._pred.items())[0][1]
  metadata = list(error._metadata.items())[0]
  payload = list(error._payload.items())[0][1]
  exception_tree = metadata[1]
  exception = jax.tree.unflatten(exception_tree, payload)
  assert isinstance(exception, checkify.FailedCheckError)

  # check_p has an inverted predicate compared to assert,
  # so we need to compute not(pred) here.
  out_scalar_type = _dtype_to_ir_type(jnp.dtype('bool'))
  minus_one = ir_constant(-1, out_scalar_type)
  not_pred = arith.xori(pred, minus_one)
  attrs = {"msg": ir.StringAttr.get(exception.fmt_string)}
  ir.Operation.create("cf.assert",
                      operands=(not_pred,),
                      attributes=attrs)
  return []


lowering_rules[checkify.check_p] = _checkify_lowering_rule


def _threefry2x32_lowering(ctx, k1, k2, m1, m2):
  def _lower_fun(k1, k2, m1, m2):
    with jax.named_scope("threefry2x32"):
      res = prng._threefry2x32_lowering(k1, k2, m1, m2, use_rolled_loops=False)
    return res

  threefry_lowering = lower_fun(_lower_fun, multiple_results=True)
  return threefry_lowering(ctx, k1, k2, m1, m2)


lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering


def _iota_2x32_shape_lowering(ctx, *, shape):
  total_elements = np.prod(shape)
  if total_elements > np.iinfo(jnp.int32).max:
    raise NotImplementedError(f"Iota with >{np.iinfo(jnp.int32).max} items.")

  def _lower_fun(shape):
    iota_data = jnp.zeros(shape, dtype=jnp.int32)
    multiplier = 1
    for dim in range(len(shape)-1, -1, -1):
      counts_lo = lax.broadcasted_iota(
          dtype=jnp.int32, shape=shape, dimension=dim
      )
      iota_data += counts_lo * multiplier
      multiplier *= shape[dim]
    counts_hi = jnp.zeros(shape, dtype=jnp.int32)
    return counts_hi, iota_data

  iota_lowering = lower_fun(_lower_fun, multiple_results=True)
  return iota_lowering(ctx, shape=shape)


lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering
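
# Worked example (added comment, values assumed): for shape = (2, 3) the
# fallback above produces the row-major linear iota
#   iota_data = [[0, 1, 2], [3, 4, 5]]
# and counts_hi = zeros((2, 3), int32), since the total element count is
# required to fit in an int32.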


def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
  operand, padding_value = args
  padding_config = kwargs["padding_config"]

  out_type: ir.VectorType = aval_to_ir_type(
      ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_in[0]
  )
  if not isinstance(out_type, ir.VectorType):
    raise NotImplementedError("Only vector types are supported.")

  for axis, (low, high, interior) in enumerate(padding_config):
    if low == 0 and high == 0 and interior == 0:
      continue

    def _pad(val):
      shape = list(operand.type.shape)
      shape[axis] = val
      pad_vec_type = ir.VectorType.get(
          shape,
          operand.type.element_type,
      )

      if isinstance(padding_value, ir.OpResult):
        pad = vector.broadcast(pad_vec_type, padding_value)
      else:
        scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value)
        pad = arith.ConstantOp(
            pad_vec_type,
            ir.DenseElementsAttr.get_splat(
                pad_vec_type,
                scalar_attr,
            ),
        ).result
      return pad

    if low != 0:
      pad_low = _pad(low)
      new_shape = out_type.shape
      new_shape[axis] += low
      out_type = ir.VectorType.get(
          new_shape,
          out_type.element_type,
      )
      operand = tpu.concatenate(out_type, [pad_low, operand], dimension=axis)

    if high != 0:
      pad_high = _pad(high)
      new_shape = out_type.shape
      new_shape[axis] += high
      out_type = ir.VectorType.get(
          new_shape,
          out_type.element_type,
      )
      operand = tpu.concatenate(out_type, [operand, pad_high], dimension=axis)

    if interior > 0:
      raise NotImplementedError("Not implemented: interior padding")

  return operand


lowering_rules[lax.pad_p] = _pad_lowering_rule


def _platform_index_lowering(
    ctx: mlir.LoweringRuleContext,
    *,
    platforms: Sequence[Sequence[str]],
    has_default: bool,
):
  for i, ps in enumerate(platforms):
    # Note: slightly odd structure here, as platforms is a seq[seq[str]].
    if "mosaic" in ps:
      return ir_constant(i)

  if has_default:
    return ir_constant(len(platforms))

  raise NotImplementedError(
      "No mosaic or default platform indexing rule found."
  )


lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering