[Pallas] Add state discharge rule for pallas_call

This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html), but not in general, since JAX and/or XLA can still introduce copies because we have value semantics. For a proper solution, we need to introduce some notion of buffer semantics to XLA/HLO and preserve it through the lowering of stateful JAX (maybe by avoiding state discharge altogether).

PiperOrigin-RevId: 681206784
Sharad Vikram, 2024-10-01 16:29:59 -07:00 (committed by jax authors)
parent e361868132
commit c34e25d6f4
7 changed files with 608 additions and 160 deletions
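
For context, the user-facing pattern this unlocks looks roughly like the following sketch, condensed from the tests added in this change (like those tests, it assumes a TPU v4+ runtime): a kernel may now close over refs created by `run_state`, and discharging the `pallas_call` threads those refs through as aliased inputs and outputs rather than copying them.

```python
import jax
import jax.numpy as jnp
from jax._src.state import discharge as state_discharge
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def copy_start(x_ref, o_ref):
  # The kernel closes over x_ref/o_ref; its only formal output is the DMA
  # semaphore, which acts as a "future" for the in-flight copy.
  def kernel(sem):
    pltpu.make_async_copy(x_ref, o_ref, sem).start()
  return pl.pallas_call(
      kernel,
      out_shape=pltpu.SemaphoreType.DMA(()),
      out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
  )()


def copy_done(x_ref, o_ref, sem):
  def kernel(sem):
    pltpu.make_async_copy(x_ref, o_ref, sem).wait()
  pl.pallas_call(
      kernel,
      out_shape=(),
      in_specs=[pl.BlockSpec(memory_space=pltpu.SEMAPHORE)],
  )(sem)


@jax.jit
def f(x):
  def body(refs):
    x_ref, y_ref = refs
    sem = copy_start(x_ref, y_ref)
    copy_done(x_ref, y_ref, sem)
  _, y = state_discharge.run_state(body)((x, jnp.zeros_like(x)))
  return y
```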


@@ -2762,8 +2762,10 @@ lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule
 def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                             device_id_type: tpu_primitives.DeviceIdType):
   del device_id_type
-  sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, args)
-  sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
+  (_, _, ref, transforms, sem, sem_transforms, _, _, _) = tree_util.tree_unflatten(
+      tree, args)
+  (_, _, ref_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten(
+      tree, ctx.avals_in)
   block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
   ref_block_shape = block_shapes[2]
   ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms)
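
For orientation, the flattened argument order assumed by `dma_start_p`/`dma_wait_p` after this change is `(src_ref, src_transforms, dst_ref, dst_transforms, dst_sem, dst_sem_transforms, src_sem, src_sem_transforms, device_id)`. A small round-trip sketch, with placeholder strings standing in for refs and semaphores:

```python
import jax.tree_util as tree_util

# Placeholder leaves standing in for refs/semaphores; empty tuples stand in
# for absent transforms and contribute no leaves.
args = ("src_ref", (), "dst_ref", (), "dst_sem", (), "src_sem", (), "dev_id")
flat, tree = tree_util.tree_flatten(args)
# The wait lowering above picks out the *destination* ref and semaphore:
(_, _, ref, transforms, sem, sem_transforms, _, _, _) = (
    tree_util.tree_unflatten(tree, flat))
assert (ref, sem) == ("dst_ref", "dst_sem")
```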


@@ -431,18 +431,34 @@ class AsyncCopyDescriptor:
   def is_remote(self):
     return self.src_sem is not None

+  def _get_args_and_tree(self, swap_src_and_dst: bool = False):
+    if swap_src_and_dst:
+      return tree_util.tree_flatten((
+          self.dst_ref,
+          self.dst_transforms,
+          self.src_ref,
+          self.src_transforms,
+          self.src_sem,
+          self.src_sem_transforms,
+          self.dst_sem,
+          self.dst_sem_transforms,
+          self.device_id,
+      ))
+    else:
+      return tree_util.tree_flatten((
+          self.src_ref,
+          self.src_transforms,
+          self.dst_ref,
+          self.dst_transforms,
+          self.dst_sem,
+          self.dst_sem_transforms,
+          self.src_sem,
+          self.src_sem_transforms,
+          self.device_id,
+      ))
+
   def start(self):
-    flat_args, tree = tree_util.tree_flatten((
-        self.src_ref,
-        self.src_transforms,
-        self.dst_ref,
-        self.dst_transforms,
-        self.dst_sem,
-        self.dst_sem_transforms,
-        self.src_sem,
-        self.src_sem_transforms,
-        self.device_id,
-    ))
+    flat_args, tree = self._get_args_and_tree()
     dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type)

   def wait(self):
@@ -451,27 +467,20 @@ class AsyncCopyDescriptor:
       self.wait_recv()

   def wait_recv(self):
-    wait_args, tree = tree_util.tree_flatten((
-        self.dst_sem,
-        self.dst_sem_transforms,
-        self.dst_ref,
-        self.dst_transforms,
-    ))
+    flat_args, tree = self._get_args_and_tree()
     dma_wait_p.bind(
-        *wait_args, tree=tree, device_id_type=self.device_id_type
+        *flat_args, tree=tree, device_id_type=self.device_id_type
     )

   def wait_send(self):
     if not self.is_remote:
       raise ValueError("Cannot `wait_send` on a local copy.")
-    wait_args, tree = tree_util.tree_flatten((
-        self.src_sem,
-        self.src_sem_transforms,
-        self.src_ref,
-        self.src_transforms,
-    ))
+    # We swap src and dst since by default dma_wait_p waits on the dst_sem.
+    # As a clean-up, maybe we could modify the primitive to have a
+    # `wait_on_send` bool.
+    flat_args, tree = self._get_args_and_tree(swap_src_and_dst=True)
     dma_wait_p.bind(
-        *wait_args, tree=tree, device_id_type=self.device_id_type
+        *flat_args, tree=tree, device_id_type=self.device_id_type
    )
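
The swap in `wait_send` is worth spelling out: `dma_wait_p` always waits on whatever sits in the `dst_sem` slot of the flattened tuple, so waiting on the send side is expressed by flattening with source and destination exchanged. A self-contained toy model (placeholder strings, not the real descriptor class):

```python
import jax.tree_util as tree_util

# Toy model of _get_args_and_tree: flatten a reduced tuple, optionally
# exchanging src and dst so that the dst_sem slot holds the send semaphore.
def get_args_and_tree(src_ref, src_sem, dst_ref, dst_sem,
                      swap_src_and_dst=False):
  if swap_src_and_dst:
    src_ref, dst_ref = dst_ref, src_ref
    src_sem, dst_sem = dst_sem, src_sem
  return tree_util.tree_flatten((src_ref, dst_ref, dst_sem, src_sem))

flat, _ = get_args_and_tree("s_ref", "s_sem", "d_ref", "d_sem",
                            swap_src_and_dst=True)
assert flat[2] == "s_sem"  # the dst_sem slot now carries the send semaphore
```
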
@@ -689,7 +698,17 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
   del settings
   invars = eqn.invars
   tree = eqn.params["tree"]
-  sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, invars)
+  (
+      _,
+      _,
+      ref,
+      transforms,
+      sem,
+      sem_transforms,
+      _,
+      _,
+      _,
+  ) = tree_util.tree_unflatten(tree, invars)
   return pp.concat([
       pp.text("dma_wait"),
       pp.text(" "),
@@ -702,29 +721,38 @@ jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
 def dma_wait_discharge_rule(in_avals, out_avals,
                             *args, tree, device_id_type):
+  # TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
   del out_avals, device_id_type
-  (sem, sem_transforms, ref, ref_transforms) = tree_util.tree_unflatten(
-      tree, args
-  )
-  (
-      sem_aval,
-      sem_transforms_avals,
-      _,
-      ref_transforms_avals,
-  ) = tree_util.tree_unflatten(tree, in_avals)
+  _, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
+      tree_util.tree_unflatten(tree, args))
+  (_,
+   src_ref_transforms_avals,
+   _,
+   dst_ref_transforms_avals,
+   dst_sem_aval,
+   dst_sem_transforms_avals,
+   src_sem_aval,
+   src_sem_transforms_avals,
+   device_id_aval,
+  ) = tree_util.tree_unflatten(tree, in_avals)
-  num_sem_transforms = len(tree_util.tree_leaves(sem_transforms_avals))
-  num_transforms = len(tree_util.tree_leaves(ref_transforms_avals))
-  updates = state_discharge.transform_array(ref, ref_transforms)
+  num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
+  num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
+  updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
   copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
   copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
-  sem_value = _transform_semaphore(sem, sem_transforms, sem_aval)
+  sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval)
   _, new_sem = state_discharge.transform_swap_array(
-      sem, sem_transforms, sem_value - copy_size
+      dst_sem, dst_sem_transforms, sem_value - copy_size
   )
-  new_vals = (new_sem,)  # sem
-  new_vals += (None,) * num_sem_transforms
-  new_vals += (None,)  # ref
-  new_vals += (None,) * num_transforms
+  new_vals = (None,)  # src_ref
+  new_vals += (None,) * len(tree_util.tree_leaves(src_ref_transforms_avals))
+  new_vals += (None,)  # ref
+  new_vals += (None,) * num_transforms  # ref_transforms
+  new_vals += (new_sem,)  # sem
+  new_vals += (None,) * num_sem_transforms
+  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_aval))  # src_sem
+  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
+  new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval))  # device_id
   return new_vals, []

 state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
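
For readers unfamiliar with the discharge-rule contract used here: a rule receives the flat input avals and values and returns one entry per input (the new value for a mutated ref, `None` for untouched inputs) plus the primitive's ordinary outputs. A toy sketch under that assumption, for a hypothetical `accum_p` primitive that adds `x` into `ref`:

```python
# Toy sketch of the contract (hypothetical primitive, illustrative only).
def accum_discharge_rule(in_avals, out_avals, ref_val, x):
  del in_avals, out_avals
  new_vals = (ref_val + x,  # ref: mutated, so return its new value
              None)         # x: read-only input, left as None
  return new_vals, []       # accum_p has no ordinary outputs
```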


@@ -30,6 +30,7 @@ from jax._src import config
 from jax._src import core as jax_core
 from jax._src import effects
 from jax._src import linear_util as lu
+from jax._src import state
 from jax._src import tree_util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -208,6 +209,7 @@ def _pallas_call_impl_interpret(
     print(discharged_jaxpr)
   out = _initialize_output_vals(grid_mapping.block_mappings_output,
                                 args, input_output_aliases)
+  # TODO(b/370563936): Fix correctness issue w/ io aliasing
   scalars = args[grid_mapping.slice_index_ops]
   block_args = args[len(scalars):]
   # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
@@ -936,7 +938,7 @@ def _pallas_call_batching_rule(
   with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
     kernel_src_info: pallas_core.SrcInfoStr = "<Wrapped outer kernel>"

-    jaxpr = _trace_kernel_to_jaxpr(
+    jaxpr, consts = _trace_kernel_to_jaxpr(
         when_wrapped_kernel,
         kernel_src_info,
         batched_grid_mapping,
@@ -945,6 +947,8 @@ def _pallas_call_batching_rule(
         tuple(() for _ in flat_kernel_avals),
        interpret=interpret,
     )
+    if consts:
+      raise NotImplementedError("consts not supported in pallas_call")

   assert ragged_axis_length is not None
   args = (ragged_axis_length, *args)
@@ -1160,7 +1164,7 @@ def _trace_kernel_to_jaxpr(
     kernel_in_tree: tree_util.PyTreeDef,
     kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
     interpret: bool,
-) -> jax_core.ClosedJaxpr:
+) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
   if interpret:
     kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
                              kernel_avals))
@@ -1174,17 +1178,18 @@ def _trace_kernel_to_jaxpr(
   if consts:
     consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
                     for c in consts]
-    raise ValueError(
-        f"The kernel function in the pallas_call {name_and_src_info} "
-        f"captures constants {consts_avals}. "
-        "You should pass them as inputs")
+    if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
+      raise ValueError(
+          f"The kernel function in the pallas_call {name_and_src_info} "
+          f"captures constants {consts_avals}. "
+          "You should pass them as inputs")

   kernel_out_tree = out_tree_thunk()
   if kernel_out_tree != tree_util.tree_structure(None):
     raise ValueError(
         f"The kernel function in the pallas_call {name_and_src_info} "
         f"should return None. It returns a PyTree: {kernel_out_tree}")
-  return jaxpr
+  return jaxpr, tuple(consts)


 _PALLAS_USE_MOSAIC_GPU = config.bool_flag(
@@ -1209,6 +1214,8 @@ def _unsupported_lowering_error(platform: str) -> Exception:
 def _pallas_call_lowering(
     ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
 ):
+  if params['jaxpr'].constvars:
+    raise ValueError('Cannot lower a pallas_call with constants.')
   if interpret:
     # If we are in interpret mode, we don't care what platform we are on.
     impl = partial(_pallas_call_impl_interpret, **params)
@@ -1286,6 +1293,133 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
     return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)


 def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any:
   if isinstance(ref_aval, pallas_core.AbstractMemoryRef):
     return ref_aval.memory_space
   return pallas_core.MemorySpace.ANY


+@state_discharge.register_discharge_rule(pallas_call_p)
+def _pallas_call_state_discharge_rule(
+    avals_in,
+    avals_out,
+    *args,
+    jaxpr: jax_core.Jaxpr,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    name_and_src_info: pallas_core.NameAndSrcInfo,
+    grid_mapping: GridMapping,
+    debug: bool,
+    interpret: bool,
+    compiler_params: Any,
+    cost_estimate: CostEstimate | None,
+    out_avals: tuple[jax_core.AbstractValue, ...],
+):
+  del avals_out
+  assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
+  num_refs = len(jaxpr.constvars)
+  ref_avals, rest_in_avals = split_list(avals_in, [num_refs])
+  assert all(isinstance(ref_aval, state.AbstractRef) for ref_aval in ref_avals)
+  ref_avals = [
+      pallas_core.AbstractMemoryRef(
+          ref_aval.inner_aval, pallas_core.MemorySpace.ANY
+      )
+      for ref_aval in ref_avals
+  ]
+  ref_block_specs = [
+      pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
+  ] * num_refs
+  ref_block_mappings = [
+      block_spec.to_block_mapping(
+          origin="",  # TODO(sharadmv): enable origins for refs
+          array_aval=ref_aval.inner_aval,
+          index_map_avals=grid_mapping.index_map_avals,
+          index_map_tree=grid_mapping.index_map_tree,
+          grid=grid_mapping.grid,
+          mapped_dims=grid_mapping.mapped_dims,
+      ) for ref_aval, block_spec in zip(ref_avals, ref_block_specs)
+  ]
+  in_block_mappings, out_block_mappings = split_list(
+      grid_mapping.block_mappings, [grid_mapping.num_inputs]
+  )
+  new_block_mappings = (
+      *ref_block_mappings,
+      *in_block_mappings,
+      *ref_block_mappings,
+      *out_block_mappings,
+  )
+  new_grid_mapping = grid_mapping.replace(
+      block_mappings=new_block_mappings,
+      num_inputs=grid_mapping.num_inputs + num_refs,
+      num_outputs=grid_mapping.num_outputs + num_refs)
+  new_input_output_aliases = [
+      (i + grid_mapping.num_index_operands, i) for i in range(num_refs)
+  ]
+  for i, o in input_output_aliases:
+    new_input_output_aliases.append((i + num_refs, o + num_refs))
+  ref_out_avals = [ref_aval.inner_aval for ref_aval in ref_avals]
+  new_out_avals = (*ref_out_avals, *out_avals)
+  ref_args, dynamic_grid_bounds, index_operands, rest_args = split_list(
+      args,
+      [
+          num_refs,
+          grid_mapping.num_dynamic_grid_bounds,
+          grid_mapping.num_index_operands,
+      ],
+  )
+  def _rewritten_body(*args):
+    index_args, in_args, out_args, rest_args = split_list(
+        args, [new_grid_mapping.num_index_operands, new_grid_mapping.num_inputs,
+               new_grid_mapping.num_outputs])
+    ref_in_args, in_args = split_list(in_args, [num_refs])
+    ref_out_args, out_args = split_list(out_args, [num_refs])
+    # We don't care about ref_out_args because they are aliased to ref_in_args
+    del ref_out_args
+    jax_core.eval_jaxpr(
+        jaxpr, ref_in_args, *index_args, *in_args, *out_args, *rest_args
+    )
+    return []
+  index_map_avals, jaxpr_in_avals, jaxpr_out_avals, jaxpr_rest_avals = (
+      split_list(
+          [v.aval for v in jaxpr.invars],
+          [
+              grid_mapping.num_index_operands,
+              grid_mapping.num_inputs,
+              grid_mapping.num_outputs,
+          ],
+      )
+  )
+  new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(_rewritten_body),
+      [
+          *index_map_avals,
+          *ref_avals,
+          *jaxpr_in_avals,
+          *ref_avals,
+          *jaxpr_out_avals,
+          *jaxpr_rest_avals,
+      ],
+  )
+  out_flat = pallas_call_p.bind(
+      *consts,
+      *dynamic_grid_bounds,
+      *index_operands,
+      *ref_args,
+      *rest_args,
+      jaxpr=new_jaxpr,
+      input_output_aliases=new_input_output_aliases,
+      grid_mapping=new_grid_mapping,
+      name_and_src_info=name_and_src_info,
+      debug=debug,
+      interpret=interpret,
+      compiler_params=compiler_params,
+      cost_estimate=cost_estimate,
+      out_avals=new_out_avals,
+  )
+  refs_out, rest = split_list(out_flat, [num_refs])
+  updated_vals_in = refs_out + [None] * len(rest_in_avals)
+  return updated_vals_in, rest
+
+
 def pallas_call(
     kernel: Callable[..., None],
     out_shape: Any,
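
The alias bookkeeping in the rule above is easy to check by hand. A tiny worked example (assumed configuration: two closed-over refs, no index operands, and one pre-existing alias `{0: 0}`):

```python
# Worked example of new_input_output_aliases from the rule above.
num_refs, num_index_operands = 2, 0
input_output_aliases = [(0, 0)]

new_aliases = [(i + num_index_operands, i) for i in range(num_refs)]
for i, o in input_output_aliases:
  new_aliases.append((i + num_refs, o + num_refs))

# Each ref becomes an input aliased to the corresponding new output, and the
# pre-existing alias is shifted past the ref slots.
assert new_aliases == [(0, 0), (1, 1), (2, 2)]
```
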
@@ -1440,7 +1574,7 @@ def pallas_call(
       for x in flat_kernel_args
   )
   with pallas_core.interpret_mode_env(interpret):
-    jaxpr = _trace_kernel_to_jaxpr(
+    jaxpr, consts = _trace_kernel_to_jaxpr(
         kernel, kernel_src_info, grid_mapping, tuple(flat_kernel_avals),
         kernel_in_tree, kernel_arg_transforms, interpret=interpret)
   for i_idx, o_idx in input_output_aliases.items():
@@ -1467,6 +1601,7 @@ def pallas_call(
   index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
   with pallas_core.interpret_mode_env(interpret):
     out_flat = pallas_call_p.bind(
+        *consts,
        *dynamic_grid_bounds,
        *index_args,
        *rest_args,


@@ -343,8 +343,8 @@ jax_multiplatform_test(
 )

 jax_multiplatform_test(
-    name = "tpu_pallas_mesh_test",
-    srcs = ["tpu_pallas_mesh_test.py"],
+    name = "tpu_pallas_state_test",
+    srcs = ["tpu_pallas_state_test.py"],
     enable_backends = ["tpu"],
     tags = [
         "noasan",


@@ -20,6 +20,7 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import jax
 from jax._src import test_util as jtu
+from jax._src.state import discharge as state_discharge
 from jax.experimental import pallas as pl
 from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
@@ -755,5 +756,123 @@ class PallasCallRemoteAsyncCopyTest(parameterized.TestCase):
     np.testing.assert_array_equal(y, expected)


+def make_stateful_async_copy():
+  @jax.named_call
+  def copy_start(x_ref, o_ref) -> Future:
+    def copy_start_kernel(sem):
+      pltpu.make_async_copy(x_ref, o_ref, sem).start()
+    sem = pl.pallas_call(
+        copy_start_kernel,
+        out_shape=pltpu.SemaphoreType.DMA(()),
+        out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+    )()
+    return sem
+
+  @jax.named_call
+  def copy_done(x_ref, o_ref, future):
+    sem = future
+    def copy_done_kernel(sem):
+      pltpu.make_async_copy(x_ref, o_ref, sem).wait()
+    () = pl.pallas_call(
+        copy_done_kernel,
+        out_shape=(),
+        in_specs=[
+            pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+        ],
+    )(sem)
+  return copy_start, copy_done
+
+
+def make_stateful_async_slice(i: int):
+  @jax.named_call
+  def copy_start(x_ref, o_ref) -> Future:
+    def copy_start_kernel(sem):
+      pltpu.make_async_copy(x_ref.at[i], o_ref, sem).start()
+    sem = pl.pallas_call(
+        copy_start_kernel,
+        out_shape=pltpu.SemaphoreType.DMA(()),
+        out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+    )()
+    return sem
+
+  @jax.named_call
+  def copy_done(x_ref, o_ref, future):
+    sem = future
+    def copy_done_kernel(sem):
+      pltpu.make_async_copy(x_ref.at[i], o_ref, sem).wait()
+    () = pl.pallas_call(
+        copy_done_kernel,
+        out_shape=(),
+        in_specs=[
+            pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+        ],
+    )(sem)
+  return copy_start, copy_done
+
+
+class PallasCallStatefulAsyncTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest('DMAs only guaranteed to work on TPU v4+')
+
+  def test_basic_stateful_async_copy(self):
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_copy()
+        x_ref, y_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        copy_done(x_ref, y_ref, fut)
+      _, y = state_discharge.run_state(body)((x, y))
+      return y
+    x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_multiple_stateful_async_copy(self):
+    @jax.jit
+    def f(x):
+      y = y2 = jnp.zeros_like(x)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_copy()
+        x_ref, y_ref, y2_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        fut2 = copy_start(x_ref, y2_ref)
+        copy_done(x_ref, y_ref, fut)
+        copy_done(x_ref, y2_ref, fut2)
+      _, y, y2 = state_discharge.run_state(body)((x, y, y2))
+      return y, y2
+    x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+    y, y2 = f(x)
+    np.testing.assert_array_equal(y, x)
+    np.testing.assert_array_equal(y2, x)
+
+  def test_basic_stateful_async_slice(self):
+    @jax.jit
+    def f(x):
+      y = jnp.zeros(x.shape[1:], x.dtype)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_slice(2)
+        x_ref, y_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        copy_done(x_ref, y_ref, fut)
+      _, y = state_discharge.run_state(body)((x, y))
+      return y
+    x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32)
+    y = f(x)
+    np.testing.assert_array_equal(y, x[2])
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())


@@ -1,107 +0,0 @@
-# Copyright 2024 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for Pallas mesh API."""
-
-from absl.testing import absltest
-import jax
-from jax._src import test_util as jtu
-from jax._src.state import discharge as state_discharge
-from jax.experimental import pallas as pl
-from jax.experimental import shard_map
-from jax.experimental.pallas import tpu as pltpu
-import jax.numpy as jnp
-import numpy as np
-
-jax.config.parse_flags_with_absl()
-
-
-class ShmallasTest(jtu.JaxTestCase):
-
-  def setUp(self):
-    super().setUp()
-    if not jtu.is_device_tpu_at_least(4):
-      self.skipTest("Only supported on TPU v4+")
-
-  def test_can_create_tensorcore_mesh(self):
-    _ = pltpu.create_tensorcore_mesh("x")
-
-  def test_can_trivially_shard_map_with_pallas_mesh(self):
-    mesh = pltpu.create_tensorcore_mesh("x")
-    _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)()
-
-  def test_can_run_basic_pallas_kernel_with_shard_map(self):
-    mesh = pltpu.create_tensorcore_mesh("x")
-
-    @jax.jit
-    def f(x):
-      y = jnp.zeros_like(x)
-      def inner(refs):
-        x_ref, y_ref = refs
-        def kernel():
-          def alloc(sem):
-            pltpu.async_copy(x_ref, y_ref, sem).wait()
-          pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
-        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
-                            check_rep=False)()
-      _, y = state_discharge.run_state(inner)((x, y))
-      return y
-    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
-    y = f(x)
-    np.testing.assert_array_equal(y, x)
-
-  def test_can_query_core_index_pallas_kernel_with_shard_map(self):
-    mesh = pltpu.create_tensorcore_mesh("x")
-
-    @jax.jit
-    def f(x):
-      y = jnp.zeros_like(x)
-      def inner(refs):
-        x_ref, y_ref = refs
-        def kernel():
-          num_cores = jax.lax.psum(1, "x")
-          slc_size = 16 // num_cores
-          def alloc(x_vmem_ref, y_vmem_ref, sem):
-            core_index = jax.lax.axis_index("x")
-            slc = pl.ds(core_index * slc_size, slc_size)
-            pltpu.async_copy(
-                x_ref.at[slc],
-                x_vmem_ref,
-                sem,
-            ).wait()
-            y = x_vmem_ref[...] + jax.lax.axis_index("x")
-            y_vmem_ref[...] = y
-            pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait()
-          pl.run_scoped(
-              alloc,
-              pltpu.VMEM((slc_size, 128), x_ref.dtype),
-              pltpu.VMEM((slc_size, 128), y_ref.dtype),
-              pltpu.SemaphoreType.DMA,
-          )
-        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
-                            check_rep=False)()
-      _, y = state_discharge.run_state(inner)((x, y))
-      return y
-    num_cores = jax.devices()[0].num_cores
-    x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128))
-    expected_out = (
-        x.reshape((num_cores, -1, 128)) + jnp.arange(num_cores)[..., None, None]
-    ).reshape(x.shape)
-    y = f(x)
-    np.testing.assert_array_equal(y, expected_out)
-
-if __name__ == "__main__":
-  absltest.main(testLoader=jtu.JaxTestLoader())


@@ -0,0 +1,271 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Pallas mesh API."""
+
+import functools
+from absl.testing import absltest
+import jax
+from jax._src import test_util as jtu
+from jax._src.state import discharge as state_discharge
+from jax.experimental import pallas as pl
+from jax.experimental import shard_map
+from jax.experimental.pallas import tpu as pltpu
+import jax.numpy as jnp
+import numpy as np
+
+jax.config.parse_flags_with_absl()
+
+
+class PallasCallStatefulTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest("Only supported on TPU v4+")
+
+  def test_basic_stateful_kernel(self):
+
+    def copy_kernel(x_ref, y_ref):
+      def body(sem):
+        pltpu.make_async_copy(x_ref, y_ref, sem).start()
+        pltpu.make_async_copy(x_ref, y_ref, sem).wait()
+      pl.run_scoped(body, pltpu.SemaphoreType.DMA)
+
+    def f_stateful(refs):
+      x_ref, y_ref = refs
+      pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref),
+                     out_shape=[])()
+
+    @jax.jit
+    def f(x):
+      _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
+      return y
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_basic_stateful_kernel_with_scratch_sem(self):
+
+    def copy_kernel(x_ref, y_ref, sem):
+      pltpu.make_async_copy(x_ref, y_ref, sem).start()
+      pltpu.make_async_copy(x_ref, y_ref, sem).wait()
+
+    def f_stateful(refs):
+      x_ref, y_ref = refs
+      pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref),
+                     scratch_shapes=[pltpu.SemaphoreType.DMA],
+                     out_shape=[])()
+
+    @jax.jit
+    def f(x):
+      _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
+      return y
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_basic_stateful_kernel_with_scalar_prefetch(self):
+
+    def copy_kernel(x_ref, y_ref, index_ref, sem):
+      i = index_ref[0]
+      pltpu.make_async_copy(x_ref.at[i], y_ref, sem).start()
+      pltpu.make_async_copy(x_ref.at[i], y_ref, sem).wait()
+
+    def f_stateful(refs):
+      x_ref, y_ref = refs
+      pl.pallas_call(
+          functools.partial(copy_kernel, x_ref, y_ref),
+          grid_spec=pltpu.PrefetchScalarGridSpec(
+              num_scalar_prefetch=1,
+              scratch_shapes=[pltpu.SemaphoreType.DMA],
+          ),
+          out_shape=[],
+      )(jnp.array([0]))
+
+    @jax.jit
+    def f(x):
+      _, y = state_discharge.run_state(f_stateful)((x[None], jnp.zeros_like(x)))
+      return y
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_basic_stateful_kernel_with_io_aliasing(self):
+
+    def copy_kernel(x_ref, y_ref, x_old_ref, x_old_ref2, sem):
+      del x_old_ref, x_old_ref2
+      pltpu.make_async_copy(x_ref, y_ref, sem).start()
+      pltpu.make_async_copy(x_ref, y_ref, sem).wait()
+
+    def f_stateful(refs):
+      x_ref, y_ref, o_ref = refs
+      x = pl.pallas_call(
+          functools.partial(copy_kernel, x_ref, y_ref),
+          scratch_shapes=[pltpu.SemaphoreType.DMA],
+          out_shape=jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype),
+          input_output_aliases={0: 0},
+      )(x_ref[...])
+      o_ref[...] = x
+
+    @jax.jit
+    def f(x):
+      _, y, o = state_discharge.run_state(f_stateful)(
+          (x, jnp.zeros_like(x), jnp.zeros_like(x))
+      )
+      return y, o
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y, o = f(x)
+    np.testing.assert_array_equal(y, x)
+    np.testing.assert_array_equal(o, x)
+
+  def test_stateful_matmul(self):
+
+    m, k, n = 512, 512, 512
+    bm, bk, bn = 128, 128, 128
+
+    def matmul_kernel(acc_ref, x_ref, y_ref, o_ref):
+      @pl.when(pl.program_id(2) == 0)
+      def _():
+        acc_ref[...] = jnp.zeros_like(acc_ref)
+
+      acc_ref[...] += jnp.dot(
+          x_ref[...], y_ref[...], preferred_element_type=jnp.float32
+      )
+
+      @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
+      def _():
+        o_ref[...] = acc_ref[...].astype(o_ref.dtype)
+
+    def matmul(x, y):
+
+      def run_matmul(refs):
+        x_ref, y_ref, o_ref = refs
+
+        def matmul_pipeline_kernel(acc_ref):
+          pltpu.emit_pipeline(
+              functools.partial(matmul_kernel, acc_ref),
+              grid=(m // bm, n // bn, k // bk),
+              in_specs=[
+                  pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
+                  pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
+              ],
+              out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
+          )(x_ref, y_ref, o_ref)
+
+        pl.pallas_call(
+            matmul_pipeline_kernel,
+            out_shape=[],
+            scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
+        )()
+
+      _, _, o = state_discharge.run_state(run_matmul)(
+          (x, y, jnp.ones((m, n), dtype=x.dtype))
+      )
+      return o
+
+    x = jax.random.normal(jax.random.key(0), (m, k), jnp.float32)
+    y = jax.random.normal(jax.random.key(1), (k, n), jnp.float32)
+    o = matmul(x, y)
+    atol = 0
+    if jtu.is_device_tpu(6):
+      atol = 2e-5
+    np.testing.assert_allclose(o, x @ y, atol=atol)
+
+
+class ShmallasTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest("Only supported on TPU v4+")
+
+  def test_can_create_tensorcore_mesh(self):
+    _ = pltpu.create_tensorcore_mesh("x")
+
+  def test_can_trivially_shard_map_with_pallas_mesh(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+    _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)()
+
+  def test_can_run_basic_pallas_kernel_with_shard_map(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def inner(refs):
+        x_ref, y_ref = refs
+        def kernel():
+          def alloc(sem):
+            pltpu.async_copy(x_ref, y_ref, sem).wait()
+          pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
+        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
+                            check_rep=False)()
+      _, y = state_discharge.run_state(inner)((x, y))
+      return y
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_can_query_core_index_pallas_kernel_with_shard_map(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def inner(refs):
+        x_ref, y_ref = refs
+        def kernel():
+          num_cores = jax.lax.psum(1, "x")
+          slc_size = 16 // num_cores
+          def alloc(x_vmem_ref, y_vmem_ref, sem):
+            core_index = jax.lax.axis_index("x")
+            slc = pl.ds(core_index * slc_size, slc_size)
+            pltpu.async_copy(
+                x_ref.at[slc],
+                x_vmem_ref,
+                sem,
+            ).wait()
+            y = x_vmem_ref[...] + jax.lax.axis_index("x")
+            y_vmem_ref[...] = y
+            pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait()
+          pl.run_scoped(
+              alloc,
+              pltpu.VMEM((slc_size, 128), x_ref.dtype),
+              pltpu.VMEM((slc_size, 128), y_ref.dtype),
+              pltpu.SemaphoreType.DMA,
+          )
+        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
+                            check_rep=False)()
+      _, y = state_discharge.run_state(inner)((x, y))
+      return y
+    num_cores = jax.devices()[0].num_cores
+    x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128))
+    expected_out = (
+        x.reshape((num_cores, -1, 128)) + jnp.arange(num_cores)[..., None, None]
+    ).reshape(x.shape)
+    y = f(x)
+    np.testing.assert_array_equal(y, expected_out)
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())