Reverts c2a3c0bb80434d89053b43648f88eba22b9bf1fa

PiperOrigin-RevId: 640524004
Malcolm Reynolds authored on 2024-06-05 07:50:12 -07:00; committed by jax authors
parent e09cda8fa9
commit 1669b99505
6 changed files with 36 additions and 313 deletions


@@ -113,5 +113,5 @@ py_library(
"//jax:api_util",
"//jax:util",
"//jax/_src/pallas",
] + py_deps("numpy"),
],
)


@@ -13,7 +13,6 @@
# limitations under the License.
"""Module for emitting custom TPU pipelines within a Pallas call."""
from __future__ import annotations
import dataclasses
import enum
@@ -30,7 +29,6 @@ from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
SMEM = tpu_core.TPUMemorySpace.SMEM
@@ -46,9 +44,6 @@ PipelineBlockSpecs = Union[Sequence[pallas_core.BlockSpec], Any]
PipelineRefs = Union[Sequence[REF], Any]
# TODO(sharadmv): make this a parameter and make it queryable from the Device.
_TILING = (8, 128)
def _broadcast_pytree_to(from_pytree, to_pytree):
"""Broadcast a prefix pytree to a given full tree."""
proxy = object()
@@ -68,73 +63,14 @@ def _broadcast_pytree_to(from_pytree, to_pytree):
return tree_util.tree_unflatten(treedef, broadcast_leaves)
def _get_tpu_generation() -> int:
kind = jax.devices()[0].device_kind
if kind.endswith(' lite'):
kind = kind[:-len(' lite')]
assert kind[:-1] == "TPU v", kind
return int(kind[-1])
def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]:
# For an n-dimensional shape, returns (8, 128) for the last 2 dimensions
# and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
# (2, 3, 128, 128) -> (1, 1, 8, 128).
if len(shape) < 2:
raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")
leading_dims, final_dims = shape[:-2], shape[-2:]
# We want to find the minimum power of 2 that fits the second-minor dimension
# of shape, with maximum value 8.
second_minor, _ = final_dims
packing = 4 // dtype.itemsize
max_tiling = _TILING[0]
second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing
while second_minor_tiling < min(second_minor, max_tiling):
second_minor_tiling *= 2
return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1])
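As a rough, plain-Python sketch of the tiling arithmetic above (the helper name sketch_make_tiling and its explicit tpu_generation argument are illustrative, not part of this module):

import numpy as np

def sketch_make_tiling(shape, dtype, tpu_generation=4):
  # Ones for the leading dims, a power-of-two second-minor tile capped at 8,
  # and 128 for the minor dim, mirroring _make_tiling above.
  dtype = np.dtype(dtype)
  second_minor = shape[-2]
  packing = 4 // dtype.itemsize
  tile = (1 + int(tpu_generation < 4)) * packing
  while tile < min(second_minor, 8):
    tile *= 2
  return (1,) * (len(shape) - 2) + (tile, 128)

assert sketch_make_tiling((256, 256), np.float32) == (8, 128)
assert sketch_make_tiling((2, 3, 128, 128), np.float32) == (1, 1, 8, 128)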
def _mod(a, n):
""""Calculates a mod n for positive and negative a with |a| <= n."""
return lax.rem(a + n, n)
def _round_up_to_nearest_multiple(s: int, multiple: int) -> int:
if s % multiple == 0:
return s
# Subtract off the remainder, then add multiple
return s - s % multiple + multiple
def _make_ds(
idx: jax.Array | int, size: jax.Array | int
) -> pl.Slice:
def _make_ds(idx, size):
"""Make a DMA slice with mosaic size hints."""
out = pl.ds(idx * size, size)
assert isinstance(out, pl.Slice)
return out
def _make_block_slice(
block_index: jax.Array, block_size: int, size: int, tiling: int
) -> pl.Slice | slice:
# Computes a slice given a block index and block size. In the default case,
# we return slice(block_index * block_size, (block_index + 1) * block_size).
# However, if the block size does not evenly divide the total size of the ref
# and we are selecting the last block, we need to pick the smallest multiple of
# the tiling size that still contains the block.
if size % block_size == 0:
return _make_ds(block_index, block_size)
if block_size % tiling != 0:
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
num_blocks = pl.cdiv(size, block_size)
is_last = block_index == num_blocks - 1
rounded_size = jnp.where(
is_last,
_round_up_to_nearest_multiple(size % block_size, tiling),
block_size,
)
rounded_size = pl.multiple_of(rounded_size, tiling)
return pl.ds(block_index * block_size, rounded_size)
return pl.ds(pl.multiple_of(idx * size, size), size)
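A small worked example of the last-block rounding in _make_block_slice above, assuming a size of 600 or 601, a block_size of 256, and a second-minor tiling of 8 (the helper below just restates _round_up_to_nearest_multiple):

def round_up(s, multiple):
  return s if s % multiple == 0 else s - s % multiple + multiple

block_size, tiling = 256, 8
for size in (600, 601):
  remainder = size % block_size
  # Extent chosen for the last block index along this dimension.
  print(size, round_up(remainder, tiling))
# 600 -> 88 (88 is already a multiple of 8); 601 -> 96 (89 rounded up).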
def _tuples_differ(xs, ys):
@@ -323,110 +259,49 @@ class BufferedRef:
if self.memory_space == VMEM: return
self.current_slot[0] = self.next_slot[0]
def get_dma_slice(self, src_shape, src_dtype, grid_indices):
# We need to handle blocks that might go OOB in the src array. An in-bounds
# block looks like this (for array shape (600, 600) and block shape
# (256, 256)):
#
# +--------------+------------------|
# | Block (0,0) | |
# | (256, 256) | |
# +--------------+ |
# | A (600, 600) |
# | |
# +---------------------------------+
#
# For in-bounds blocks, we don't need to do anything special.
# An out-of-bounds block looks like this:
#
# +--------------+------------------|
# | |
# | |
# + |
# | A (600, 600) |
# +--------------+ |
# | Block (2,0) | |
# + --------------------------------|
# | XXXXXXXXXX |
# +--------------+
# where the X's indicate where the block is out of bounds.
#
# When we have an out-of-bounds block like this, we need to truncate it to
# a tile boundary (tiles are (8, 128) along the two minormost dimensions).
# In this case, we'll have a block that is indexing the
# 512:768 elements of A along the first dimension. We need to convert 768
# into 600 (600 % 8 == 0), so our indexing will look like this:
# +--------------+------------------|
# | |
# | |
# + |
# | A (600, 600) |
# +--------------+ |
# | Block (2,0) | |
# + --------------------------------|
# where it is now an (88, 256)-sized block.
#
# Suppose A is now (601, 600); instead of picking an (88, 256)-sized block
# for the last iteration on that dimension, we will pick the next highest
# tile multiple, i.e. (96, 256).
if len(src_shape) < 2:
raise NotImplementedError("Must use >1D values.")
tiling = _make_tiling(src_shape, src_dtype)
block_shape = tuple(1 if b is None else b for b in self.block_shape)
block_indices = self.compute_index(*grid_indices)
return jax.tree.map(
_make_block_slice, block_indices, block_shape, src_shape, tiling
)
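A rough per-dimension sketch of what get_dma_slice ends up producing for the (601, 600) example in the comment, assuming block shape (256, 256), block index (2, 0), and tiling (8, 128); sketch_block_slice is an illustrative stand-in for _make_block_slice:

def sketch_block_slice(index, block, size, tile):
  start = index * block
  remaining = size - start
  if remaining >= block:
    return (start, block)                  # in-bounds block
  # Last, out-of-bounds block: round its extent up to a tile multiple.
  return (start, -(-remaining // tile) * tile)

src = [sketch_block_slice(i, b, s, t)
       for i, b, s, t in zip((2, 0), (256, 256), (601, 600), (8, 128))]
dst = [(0, length) for _, length in src]
assert src == [(512, 96), (0, 256)]        # HBM side: truncated and rounded
assert dst == [(0, 96), (0, 256)]          # VMEM side: same-sized prefix of the slot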
def copy_in(self, src_ref, grid_indices):
"""Starts copy of HBM dma slice into the current slot."""
assert self.is_input
if self.memory_space == VMEM: return
dma_slice = self.compute_slice(grid_indices)
next_slot = lax.rem(self.current_slot[0] + 1, 2)
self.next_slot[0] = next_slot
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
tpu_primitives.make_async_copy(
src_ref.at[src_slice],
self.vmem_ref.at[next_slot].at[dst_slice],
src_ref.at[dma_slice],
self.vmem_ref.at[next_slot],
self.sem_recv).start()
def copy_out(self, dst_ref, grid_indices):
"""Starts copy of HBM dma slice from the current slot."""
assert self.is_output
if self.memory_space == VMEM: return
dma_slice = self.compute_slice(grid_indices)
slot = self.current_slot[0]
self.next_slot[0] = lax.rem(slot + 1, 2)
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
self.vmem_ref.at[slot].at[src_slice],
dst_ref.at[dst_slice],
self.vmem_ref.at[slot],
dst_ref.at[dma_slice],
self.sem_send).start()
def wait_in(self, src_ref, grid_indices):
"""Waits for input copy to finish."""
assert self.is_input
if self.memory_space == VMEM: return
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
dma_slice = self.compute_slice(grid_indices)
tpu_primitives.make_async_copy(
src_ref.at[src_slice], # nb: doesn't matter
self.vmem_ref.at[self.current_slot[0]].at[dst_slice], # only dst shape is important
src_ref.at[dma_slice], # nb: doesn't matter
self.vmem_ref.at[self.current_slot[0]], # only dst shape is important
self.sem_recv).wait()
def wait_out(self, dst_ref, grid_indices):
"""Waits for output copy to finish."""
assert self.is_output
if self.memory_space == VMEM: return
dma_slice = self.compute_slice(grid_indices)
prev_slot = lax.rem(self.current_slot[0] + 1, 2)
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter
dst_ref.at[dst_slice], # only dst shape is important
self.vmem_ref.at[prev_slot], # nb: doesn't matter
dst_ref.at[dma_slice], # only dst shape is important
self.sem_send).wait()
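A toy sketch of the two-slot rotation that copy_in/copy_out rely on: the next slot is always (current + 1) % 2, so successive grid steps ping-pong the DMA target and the compute source between slots 0 and 1.

current_slot = 0
for step in range(4):
  next_slot = (current_slot + 1) % 2       # lax.rem(current_slot + 1, 2) in the kernel
  print(f"step {step}: DMA into slot {next_slot}, compute from slot {current_slot}")
  current_slot = next_slot                 # mirrors current_slot[0] = next_slot[0]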
# Accumulator methods


@@ -15,9 +15,7 @@
"""Pallas utility functions."""
from __future__ import annotations
from typing import overload
import jax
from jax import lax
from jax._src import core as jax_core
from jax._src.util import split_list
@@ -34,26 +32,9 @@ def when(condition):
lax.cond(condition, f, lambda: None)
return _wrapped
@overload
def cdiv(a: int, b: int) -> int:
...
@overload
def cdiv(a: int, b: jax.Array) -> jax.Array:
...
@overload
def cdiv(a: jax.Array, b: int) -> jax.Array:
...
@overload
def cdiv(a: jax.Array, b: jax.Array) -> jax.Array:
...
def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array:
if isinstance(a, int) and isinstance(b, int):
return (a + b - 1) // b
return lax.div(a + b - 1, b)
return (a + b - 1) // b
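A quick illustration of the ceiling-division identity that cdiv uses, with cdiv_example as an illustrative stand-in for plain integer inputs:

def cdiv_example(a, b):
  return (a + b - 1) // b                  # ceil(a / b) without floating point

assert cdiv_example(7, 3) == 3
assert cdiv_example(9, 3) == 3             # exact multiples are unchanged
assert cdiv_example(600, 256) == 3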
def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:


@@ -119,7 +119,7 @@ def mha_forward_kernel(
# Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q)
upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k)
else:
upper_bound = pl.cdiv(seq_len, block_k)
upper_bound = pl.cdiv(seq_len, block_k) # type: ignore
o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))
if residual_refs:


@@ -210,12 +210,11 @@ jax_test(
"gpu",
],
main = "pallas_pipeline_tpu_test.py",
shard_count = 2,
deps = [
"//jax:extend",
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
] + py_deps("hypothesis"),
],
)
jax_test(


@@ -26,24 +26,6 @@ from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
try:
import hypothesis as hp
import hypothesis.strategies as hps
CAN_USE_HYPOTHESIS = True
except (ModuleNotFoundError, ImportError):
CAN_USE_HYPOTHESIS = False
if CAN_USE_HYPOTHESIS:
hp.settings.register_profile(
'deterministic',
database=None,
derandomize=True,
deadline=None,
max_examples=50,
print_blob=True,
)
hp.settings.load_profile('deterministic')
jax.config.parse_flags_with_absl()
@@ -73,46 +55,20 @@ def basic_matmul_kernel(
out_ref,
acc_scratch_ref,
*,
k: int,
acc_steps: int,
):
k_index = pl.program_id(2)
num_k = pl.num_programs(2)
bk = lhs_ref.shape[1]
@pl.when(k_index == 0)
@pl.when(pl.program_id(2) == 0)
def _zero_acc():
acc_scratch_ref[...] = jnp.zeros(
acc_scratch_ref.shape, acc_scratch_ref.dtype)
divisible_k = k % bk == 0
if divisible_k:
acc_scratch_ref[...] += jnp.dot(
lhs_ref[...],
rhs_ref[...],
preferred_element_type=acc_scratch_ref.dtype,
)
else:
def _last_block():
accum_dtype = acc_scratch_ref.dtype
lhs_mask = k_index * bk + jax.lax.broadcasted_iota(jnp.int32, lhs_ref.shape, 1) < k
rhs_mask = k_index * bk + jax.lax.broadcasted_iota(jnp.int32, rhs_ref.shape, 0) < k
dtype = lhs_ref.dtype
lhs = lhs_ref[...].astype(accum_dtype)
lhs = jnp.where(lhs_mask, lhs, 0).astype(dtype)
rhs = rhs_ref[...].astype(accum_dtype)
rhs = jnp.where(rhs_mask, rhs, 0).astype(dtype)
acc_scratch_ref[...] += jnp.dot(
lhs, rhs, preferred_element_type=acc_scratch_ref.dtype)
def _not_last_block():
acc_scratch_ref[...] += jnp.dot(
lhs_ref[...],
rhs_ref[...],
preferred_element_type=acc_scratch_ref.dtype,
)
jax.lax.cond(
k_index == num_k - 1, _last_block, _not_last_block
)
acc_scratch_ref[...] += jnp.dot(
lhs_ref[...],
rhs_ref[...],
preferred_element_type=acc_scratch_ref.dtype,
)
@pl.when(k_index == num_k - 1)
@pl.when(pl.program_id(2) == acc_steps - 1)
def _reduce_out():
out_ref[...] = acc_scratch_ref[...].astype(out_ref.dtype)
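A minimal NumPy sketch (not the Pallas kernel itself) of the last-K-block masking implemented by the removed branch above, assuming k = 10 padded out to three blocks of bk = 4; zeroing the lhs columns and rhs rows at or beyond k makes the padded tail contribute nothing to the accumulator:

import numpy as np

rng = np.random.default_rng(0)
k, bk = 10, 4
lhs = rng.standard_normal((8, 12))         # 12 = 3 * bk columns; last 2 are padding
rhs = rng.standard_normal((12, 8))

acc = np.zeros((8, 8))
for k_index in range(3):
  lhs_blk = lhs[:, k_index * bk:(k_index + 1) * bk]
  rhs_blk = rhs[k_index * bk:(k_index + 1) * bk, :]
  cols = k_index * bk + np.arange(bk)      # analogous to broadcasted_iota
  lhs_blk = np.where(cols[None, :] < k, lhs_blk, 0.0)
  rhs_blk = np.where(cols[:, None] < k, rhs_blk, 0.0)
  acc += lhs_blk @ rhs_blk

np.testing.assert_allclose(acc, lhs[:, :k] @ rhs[:k, :])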
@@ -252,8 +208,9 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
sharded_k = k // num_devices
inner_grid = (n // tn, m // tm, sharded_k // tk)
acc_steps = (sharded_k // tk)
inner_kernel = partial(basic_matmul_kernel, k=sharded_k)
inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps)
inner_allocs = [
pltpu.BufferedRef.input(
@@ -549,8 +506,9 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
sharded_k = k // num_devices
half_m = m // 2
inner_grid = (n // tn, half_m // tm, sharded_k // tk)
acc_steps = (sharded_k // tk)
inner_kernel = partial(basic_matmul_kernel, k=sharded_k)
inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps)
inner_allocs = [
pltpu.BufferedRef.input(
@@ -789,9 +747,10 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
sharded_k = k // num_devices
inner_grid = (n // tn, sharded_m // tm, sharded_k // tk)
outer_steps = num_devices // 2
acc_steps = sharded_k // tk
reduce_grid = (sharded_m // tm,)
inner_kernel = partial(basic_matmul_kernel, k=sharded_k)
inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps)
def reduce_kernel(
out_ref, # [tm, tn]
@@ -1073,7 +1032,9 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
sharded_k = k // num_devices
inner_grid = (n // tn, half_m // tm, sharded_k // tk)
outer_steps = num_devices
inner_kernel = partial(basic_matmul_kernel, k=sharded_k)
acc_steps = sharded_k // tk
inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps)
inner_allocs = [
pltpu.BufferedRef.input(
@@ -1302,98 +1263,5 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
)
if CAN_USE_HYPOTHESIS:
@partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int):
m, k = x.shape
_, n = y.shape
def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref):
# Partition iteration space according to number of cores
core_index = pl.program_id(0)
num_cores = pl.num_programs(0)
num_m_iters = pl.cdiv(m, bm)
# Floor divide to get number of iterations per core
iterations_per_core = jax.lax.div(num_m_iters, num_cores)
# Last core gets the remainder of iterations
num_iters_on_this_core = jnp.where(
core_index == num_cores - 1,
iterations_per_core + jax.lax.rem(num_m_iters, num_cores),
iterations_per_core,
)
m_offset = core_index * iterations_per_core
grid = (num_iters_on_this_core, pl.cdiv(n, bn), pl.cdiv(k, bk))
def run(acc_scratch_ref):
pltpu.emit_pipeline(
partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k),
in_specs=[
pl.BlockSpec(lambda i, j, k: (m_offset + i, k), (bm, bk)),
pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn)),
],
out_specs=pl.BlockSpec(lambda i, j, k: (m_offset + i, j), (bm, bn)),
grid=grid,
)(x_hbm_ref, y_hbm_ref, o_hbm_ref)
accum_dtype = (
jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32
)
pltpu.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype))
num_cores = jax.devices()[0].num_cores
return pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
in_specs=[
pl.BlockSpec(memory_space=pltpu.ANY),
pl.BlockSpec(memory_space=pltpu.ANY),
],
out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
grid=(num_cores,),
)(x, y)
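A tiny sketch of the per-core row-block split used in the kernel above, assuming 10 row blocks over 4 cores: every core takes floor(10 / 4) = 2 blocks, and the last core also absorbs the 10 % 4 = 2 leftover blocks.

num_m_iters, num_cores = 10, 4
iterations_per_core = num_m_iters // num_cores
for core_index in range(num_cores):
  extra = num_m_iters % num_cores if core_index == num_cores - 1 else 0
  m_offset = core_index * iterations_per_core
  print(f"core {core_index}: row blocks [{m_offset}, {m_offset + iterations_per_core + extra})")
# core 0: [0, 2)   core 1: [2, 4)   core 2: [4, 6)   core 3: [6, 10)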
class PaddedPipelineEmitterTest(parameterized.TestCase):
def setUp(self):
super().setUp()
if not jtu.is_device_tpu_at_least(4):
self.skipTest('Only TPU v4+ allowed.')
@hp.given(
hps.sampled_from(['float32', 'bfloat16', 'int8']),
hps.integers(1, 1024),
hps.integers(1, 1024),
hps.integers(1, 1024),
hps.sampled_from([8, 16, 32, 128, 256, 512]),
hps.sampled_from([128, 256, 512]),
hps.sampled_from([128, 256, 512]),
hps.integers(0, 4),
)
def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed):
hp.assume(bm <= m)
hp.assume(bn <= n)
hp.assume(bk <= k)
if dtype == 'bfloat16':
hp.assume(bm >= 16)
if dtype == 'int8':
hp.assume(bm >= 32)
hp.assume(jtu.is_device_tpu_at_least(5))
k1, k2 = jax.random.split(jax.random.key(seed))
x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype)
y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype)
out = matmul(x, y, bm=bm, bk=bk, bn=bn)
expected = x @ y
atol = rtol = 1e-5
if dtype == 'bfloat16':
out = out.astype('float32')
expected = expected.astype('float32')
atol = rtol = 1e-2
np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())