# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for tracing stateful functions."""
from functools import partial
from typing import Callable
import jax
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.state import AbstractRef
from jax._src.state.primitives import ref_get
from jax._src.typing import DTypeLike
from jax._src.util import safe_map, safe_zip, split_list
# Shadow the builtins with the length-checking variants from jax._src.util so
# that unequal-length iteration raises instead of silently truncating; the
# originals stay available as unsafe_map/unsafe_zip when truncation is wanted.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip


def hoist_consts_to_refs(
    jaxpr: core.Jaxpr,
    *,
    index: int = 0,
    make_abstract_ref: Callable[[core.AbstractValue], AbstractRef] = lambda aval: AbstractRef(aval)
) -> core.Jaxpr:
  """Rewrites a jaxpr so that its constants become ``Ref`` invars.

  Args:
    jaxpr: the jaxpr whose constvars should be hoisted.
    index: position among the existing invars at which the new ``Ref``
      invars are spliced in. The default (0) inserts them *before* all
      existing invars.
    make_abstract_ref: a callable to construct an AbstractRef, or subtype
      thereof, from a constant AbstractValue.

  Returns:
    An equivalent jaxpr with no constvars; each former constant is instead
    received through a ``Ref`` invar.
  """
  if not jaxpr.constvars:
    return jaxpr  # No constants to hoist; the jaxpr is already in shape.

  const_is_ref = tuple(
      isinstance(v.aval, AbstractRef) for v in jaxpr.constvars)
  # Constants that are already Refs keep their aval; plain values get
  # wrapped via make_abstract_ref.
  ref_avals = []
  for already_ref, v in zip(const_is_ref, jaxpr.constvars):
    ref_avals.append(v.aval if already_ref else make_abstract_ref(v.aval))
  num_consts = len(ref_avals)

  new_in_avals = [v.aval for v in jaxpr.invars]
  new_in_avals[index:index] = ref_avals

  def _eval_with_refs(*flat_args):
    before, const_refs, after = split_list(flat_args, [index, num_consts])
    # Immediately read each hoisted constant back out of its `Ref` so the
    # original jaxpr sees plain values; constants that were Refs to begin
    # with pass through untouched.
    const_vals = [
        r if already_ref else ref_get(r, ())
        for already_ref, r in zip(const_is_ref, const_refs)
    ]
    return core.eval_jaxpr(jaxpr, const_vals, *before, *after)

  hoisted_jaxpr, _, residual_consts, () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_eval_with_refs, debug_info=jaxpr.debug_info),
      new_in_avals)
  assert not residual_consts, "All consts should have been converted to refs"
  return hoisted_jaxpr


def val_to_ref_aval(x) -> AbstractRef:
|
2024-12-12 09:49:06 -08:00
|
|
|
aval = core.get_aval(x)
|
2023-02-17 12:45:39 -08:00
|
|
|
if type(aval) is not core.ShapedArray:
|
2024-07-11 19:24:22 +01:00
|
|
|
raise TypeError(f"can't make ref from {x}")
|
2023-02-17 12:45:39 -08:00
|
|
|
return AbstractRef(aval)
|
[Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
This cl refactors Pallas memref indexers to transforms which can support different ref transforms: indexing, bitcast (added in this cl), reshape (to be added) and others. Like indexer, user can apply multiple transforms to same memref, eg:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```
Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
b[:,:] <- c
in () }
```
Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:52:43 -07:00
|
|
|
|
|
|
|
|
|
|
|
def dtype_bitwidth(dtype: DTypeLike) -> int:
  """Returns the width of ``dtype`` in bits.

  Integer dtypes are measured via ``iinfo`` so that sub-byte types (e.g.
  int4) report their true bit width instead of a rounded-up byte count;
  everything else falls back to ``itemsize`` in bytes times eight.
  """
  if not dtypes.isdtype(dtype, "integral"):
    return 8 * dtypes.dtype(dtype).itemsize
  return dtypes.iinfo(dtype).bits


def bitcast(x, dtype: DTypeLike):
  """Bitcasts ``x`` to ``dtype``, preserving the total number of bits.

  When the element bitwidths differ, the second-minor dimension absorbs
  the size change (a TPU layout convention): narrowing the element type
  grows that dimension, widening it shrinks it. The minor dimension is
  left untouched.
  """
  src_bits = dtype_bitwidth(x.dtype)
  dst_bits = dtype_bitwidth(dtype)
  out_shape = list(x.shape)
  if src_bits != dst_bits:
    if len(out_shape) < 2:
      raise NotImplementedError(
          "Bitcast 1D ref with bitwidth change is not supported."
      )
    # Note: this is only valid on TPU.
    if (out_shape[-2] * src_bits) % dst_bits != 0:
      raise ValueError(
          "Expected input and output shapes are the same after multiplying"
          " the second-minor dimension by the bitwidths."
      )
    out_shape[-2] = out_shape[-2] * src_bits // dst_bits
  if src_bits < dst_bits:
    # Widening: fold groups of `ratio` second-minor rows into a trailing
    # axis so bitcast_convert_type can merge them into one wider element.
    ratio = dst_bits // src_bits
    grouped = x.reshape(*x.shape[:-2], x.shape[-2] // ratio, ratio, -1)
    x = grouped.swapaxes(-1, -2)
  y = jax.lax.bitcast_convert_type(x, dtype)
  if src_bits > dst_bits:
    # Narrowing: bitcast appended a trailing axis of sub-elements; move it
    # next to the second-minor dim and flatten to the computed shape.
    y = y.swapaxes(-1, -2).reshape(out_shape)
  return y


def eval_bitcast_shape(x, dtype: DTypeLike):
  """Returns the shape ``bitcast(x, dtype)`` would produce.

  Evaluated abstractly via ``jax.eval_shape``, so no data is moved.
  """
  spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.eval_shape(lambda v: bitcast(v, dtype), spec).shape