From c34e25d6f4ea972bf2fce13b3fbe9f68cd5d0224 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Tue, 1 Oct 2024 16:29:59 -0700
Subject: [PATCH] [Pallas] Add state discharge rule for pallas_call

This enables us to avoid spurious copies in the cases outlined in
[the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html),
but not in general, since JAX and/or XLA could introduce copies because we
have value semantics. For a proper solution, we need to introduce some notion
of buffer semantics to XLA/HLO and preserve it through the lowering of
stateful JAX (maybe by avoiding state discharge altogether).

PiperOrigin-RevId: 681206784
---
 jax/_src/pallas/mosaic/lowering.py    |   6 +-
 jax/_src/pallas/mosaic/primitives.py  | 110 +++++++----
 jax/_src/pallas/pallas_call.py        | 151 +++++++++++++-
 tests/pallas/BUILD                    |   4 +-
 tests/pallas/tpu_pallas_async_test.py | 119 +++++++++++
 tests/pallas/tpu_pallas_mesh_test.py  | 107 ----------
 tests/pallas/tpu_pallas_state_test.py | 271 ++++++++++++++++++++++++++
 7 files changed, 608 insertions(+), 160 deletions(-)
 delete mode 100644 tests/pallas/tpu_pallas_mesh_test.py
 create mode 100644 tests/pallas/tpu_pallas_state_test.py
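For readers outside the Pallas codebase: "state discharge" means rewriting a
jaxpr that mutates Refs into a pure one that threads buffer values in and out.
A minimal standalone sketch (a hypothetical example, not part of the patch)
using the same `run_state` entry point the new tests below rely on:

import jax.numpy as jnp
from jax._src.state import discharge as state_discharge

def body(refs):
  x_ref, y_ref = refs
  y_ref[...] = x_ref[...] + 1.0  # in-place update through a Ref

x = jnp.zeros((8, 128), jnp.float32)
_, y = state_discharge.run_state(body)((x, jnp.zeros_like(x)))
# `y` is the final value of y_ref.

Before this change, a pallas_call inside `body` could not close over x_ref or
y_ref, so expressing a kernel this way forced the values through copies; the
new discharge rule below re-binds such refs as input/output-aliased operands
instead.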
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 56392cf77..f630e6eff 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -2762,8 +2762,10 @@ lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule
 def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                             device_id_type: tpu_primitives.DeviceIdType):
   del device_id_type
-  sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, args)
-  sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
+  (_, _, ref, transforms, sem, sem_transforms, _, _, _) = tree_util.tree_unflatten(
+      tree, args)
+  (_, _, ref_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten(
+      tree, ctx.avals_in)
   block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
   ref_block_shape = block_shapes[2]
   ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms)
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index aab214a2d..7aab30ffc 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -431,18 +431,34 @@ class AsyncCopyDescriptor:
   def is_remote(self):
     return self.src_sem is not None
 
+  def _get_args_and_tree(self, swap_src_and_dst: bool = False):
+    if swap_src_and_dst:
+      return tree_util.tree_flatten((
+          self.dst_ref,
+          self.dst_transforms,
+          self.src_ref,
+          self.src_transforms,
+          self.src_sem,
+          self.src_sem_transforms,
+          self.dst_sem,
+          self.dst_sem_transforms,
+          self.device_id,
+      ))
+    else:
+      return tree_util.tree_flatten((
+          self.src_ref,
+          self.src_transforms,
+          self.dst_ref,
+          self.dst_transforms,
+          self.dst_sem,
+          self.dst_sem_transforms,
+          self.src_sem,
+          self.src_sem_transforms,
+          self.device_id,
+      ))
+
   def start(self):
-    flat_args, tree = tree_util.tree_flatten((
-        self.src_ref,
-        self.src_transforms,
-        self.dst_ref,
-        self.dst_transforms,
-        self.dst_sem,
-        self.dst_sem_transforms,
-        self.src_sem,
-        self.src_sem_transforms,
-        self.device_id,
-    ))
+    flat_args, tree = self._get_args_and_tree()
     dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type)
 
   def wait(self):
@@ -451,27 +467,20 @@ class AsyncCopyDescriptor:
     self.wait_recv()
 
   def wait_recv(self):
-    wait_args, tree = tree_util.tree_flatten((
-        self.dst_sem,
-        self.dst_sem_transforms,
-        self.dst_ref,
-        self.dst_transforms,
-    ))
+    flat_args, tree = self._get_args_and_tree()
     dma_wait_p.bind(
-        *wait_args, tree=tree, device_id_type=self.device_id_type
+        *flat_args, tree=tree, device_id_type=self.device_id_type
     )
 
   def wait_send(self):
     if not self.is_remote:
       raise ValueError("Cannot `wait_send` on a local copy.")
-    wait_args, tree = tree_util.tree_flatten((
-        self.src_sem,
-        self.src_sem_transforms,
-        self.src_ref,
-        self.src_transforms,
-    ))
+    # We swap src and dst since by default dma_wait_p waits on the dst_sem.
+    # As a cleanup, maybe we could modify the primitive to have a
+    # `wait_on_send` bool.
+    flat_args, tree = self._get_args_and_tree(swap_src_and_dst=True)
     dma_wait_p.bind(
-        *wait_args, tree=tree, device_id_type=self.device_id_type
+        *flat_args, tree=tree, device_id_type=self.device_id_type
     )
@@ -689,7 +698,17 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
   del settings
   invars = eqn.invars
   tree = eqn.params["tree"]
-  sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, invars)
+  (
+      _,
+      _,
+      ref,
+      transforms,
+      sem,
+      sem_transforms,
+      _,
+      _,
+      _,
+  ) = tree_util.tree_unflatten(tree, invars)
   return pp.concat([
       pp.text("dma_wait"),
       pp.text(" "),
@@ -702,29 +721,38 @@ jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
 
 
 def dma_wait_discharge_rule(in_avals, out_avals, *args, tree, device_id_type):
+  # TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
   del out_avals, device_id_type
-  (sem, sem_transforms, ref, ref_transforms) = tree_util.tree_unflatten(
-      tree, args
-  )
-  (
-      sem_aval,
-      sem_transforms_avals,
+  _, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
+      tree_util.tree_unflatten(tree, args))
+  (_,
+      src_ref_transforms_avals,
       _,
-      ref_transforms_avals,
+      dst_ref_transforms_avals,
+      dst_sem_aval,
+      dst_sem_transforms_avals,
+      src_sem_aval,
+      src_sem_transforms_avals,
+      device_id_aval,
   ) = tree_util.tree_unflatten(tree, in_avals)
-  num_sem_transforms = len(tree_util.tree_leaves(sem_transforms_avals))
-  num_transforms = len(tree_util.tree_leaves(ref_transforms_avals))
-  updates = state_discharge.transform_array(ref, ref_transforms)
+  num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
+  num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
+  updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
   copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
   copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
-  sem_value = _transform_semaphore(sem, sem_transforms, sem_aval)
+  sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval)
   _, new_sem = state_discharge.transform_swap_array(
-      sem, sem_transforms, sem_value - copy_size
+      dst_sem, dst_sem_transforms, sem_value - copy_size
   )
-  new_vals = (new_sem,)  # sem
-  new_vals += (None,) * num_sem_transforms
+  new_vals = (None,)  # src_ref
+  new_vals += (None,) * len(tree_util.tree_leaves(src_ref_transforms_avals))
   new_vals += (None,)  # ref
-  new_vals += (None,) * num_transforms
+  new_vals += (None,) * num_transforms  # ref_transforms
+  new_vals += (new_sem,)  # sem
+  new_vals += (None,) * num_sem_transforms
+  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_aval))  # src_sem
+  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
+  new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval))  # device_id
   return new_vals, []
 state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
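Both DMA primitives now thread the same nine-slot tuple (src_ref,
src_transforms, dst_ref, dst_transforms, dst_sem, dst_sem_transforms, src_sem,
src_sem_transforms, device_id) through `bind`, which is why every rule above
unflattens nine slots and why `wait_send` can simply swap the src/dst halves.
A standalone sketch of the flatten/unflatten round trip, with placeholder
strings standing in for real refs and semaphores:

from jax import tree_util

# None marks the src_sem/device_id that a local (non-remote) copy lacks.
args = ("src_ref", (), "dst_ref", (), "dst_sem", (), None, (), None)
flat_args, tree = tree_util.tree_flatten(args)
# Nones and empty transform tuples contribute no leaves, so only real
# operands reach bind(); the treedef remembers where everything goes.
assert flat_args == ["src_ref", "dst_ref", "dst_sem"]
(src_ref, src_transforms, dst_ref, dst_transforms, dst_sem,
 dst_sem_transforms, src_sem, src_sem_transforms, device_id) = (
     tree_util.tree_unflatten(tree, flat_args))
assert src_sem is None and device_id is None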
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 48fe3302e..6114afd02 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -30,6 +30,7 @@ from jax._src import config
 from jax._src import core as jax_core
 from jax._src import effects
 from jax._src import linear_util as lu
+from jax._src import state
 from jax._src import tree_util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -208,6 +209,7 @@ def _pallas_call_impl_interpret(
     print(discharged_jaxpr)
   out = _initialize_output_vals(grid_mapping.block_mappings_output,
                                 args, input_output_aliases)
+  # TODO(b/370563936): Fix correctness issue w/ io aliasing
   scalars = args[grid_mapping.slice_index_ops]
   block_args = args[len(scalars):]
   # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
@@ -936,7 +938,7 @@ def _pallas_call_batching_rule(
   with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
     kernel_src_info: pallas_core.SrcInfoStr = ""
 
-    jaxpr = _trace_kernel_to_jaxpr(
+    jaxpr, consts = _trace_kernel_to_jaxpr(
         when_wrapped_kernel,
         kernel_src_info,
         batched_grid_mapping,
@@ -945,6 +947,8 @@ def _pallas_call_batching_rule(
         tuple(() for _ in flat_kernel_avals),
         interpret=interpret,
     )
+    if consts:
+      raise NotImplementedError("consts not supported in pallas_call")
 
   assert ragged_axis_length is not None
   args = (ragged_axis_length, *args)
@@ -1160,7 +1164,7 @@ def _trace_kernel_to_jaxpr(
     kernel_in_tree: tree_util.PyTreeDef,
     kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
     interpret: bool,
-) -> jax_core.ClosedJaxpr:
+) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
   if interpret:
     kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
                              kernel_avals))
@@ -1174,17 +1178,18 @@ def _trace_kernel_to_jaxpr(
   if consts:
     consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
                     for c in consts]
-    raise ValueError(
-        f"The kernel function in the pallas_call {name_and_src_info} "
-        f"captures constants {consts_avals}. "
-        "You should pass them as inputs")
+    if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
+      raise ValueError(
+          f"The kernel function in the pallas_call {name_and_src_info} "
+          f"captures constants {consts_avals}. "
+          "You should pass them as inputs")
 
   kernel_out_tree = out_tree_thunk()
   if kernel_out_tree != tree_util.tree_structure(None):
     raise ValueError(
         f"The kernel function in the pallas_call {name_and_src_info} "
         f"should return None. It returns a PyTree: {kernel_out_tree}")
-  return jaxpr
+  return jaxpr, tuple(consts)
 
 
 _PALLAS_USE_MOSAIC_GPU = config.bool_flag(
@@ -1209,6 +1214,8 @@ def _unsupported_lowering_error(platform: str) -> Exception:
 def _pallas_call_lowering(
     ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
 ):
+  if params['jaxpr'].constvars:
+    raise ValueError('Cannot lower a pallas_call with constants.')
   if interpret:
     # If we are in interpret mode, we don't care what platform we are on.
     impl = partial(_pallas_call_impl_interpret, **params)
@@ -1286,6 +1293,133 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
     return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
 
 
+def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any:
+  if isinstance(ref_aval, pallas_core.AbstractMemoryRef):
+    return ref_aval.memory_space
+  return pallas_core.MemorySpace.ANY
+
+
+@state_discharge.register_discharge_rule(pallas_call_p)
+def _pallas_call_state_discharge_rule(
+    avals_in,
+    avals_out,
+    *args,
+    jaxpr: jax_core.Jaxpr,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    name_and_src_info: pallas_core.NameAndSrcInfo,
+    grid_mapping: GridMapping,
+    debug: bool,
+    interpret: bool,
+    compiler_params: Any,
+    cost_estimate: CostEstimate | None,
+    out_avals: tuple[jax_core.AbstractValue, ...],
+):
+  del avals_out
+  assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
+  num_refs = len(jaxpr.constvars)
+  ref_avals, rest_in_avals = split_list(avals_in, [num_refs])
+  assert all(isinstance(ref_aval, state.AbstractRef) for ref_aval in ref_avals)
+  ref_avals = [
+      pallas_core.AbstractMemoryRef(
+          ref_aval.inner_aval, pallas_core.MemorySpace.ANY
+      )
+      for ref_aval in ref_avals
+  ]
+  ref_block_specs = [
+      pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
+  ] * num_refs
+  ref_block_mappings = [
+      block_spec.to_block_mapping(
+          origin="",  # TODO(sharadmv): enable origins for refs
+          array_aval=ref_aval.inner_aval,
+          index_map_avals=grid_mapping.index_map_avals,
+          index_map_tree=grid_mapping.index_map_tree,
+          grid=grid_mapping.grid,
+          mapped_dims=grid_mapping.mapped_dims,
+      ) for ref_aval, block_spec in zip(ref_avals, ref_block_specs)
+  ]
+  in_block_mappings, out_block_mappings = split_list(
+      grid_mapping.block_mappings, [grid_mapping.num_inputs]
+  )
+  new_block_mappings = (
+      *ref_block_mappings,
+      *in_block_mappings,
+      *ref_block_mappings,
+      *out_block_mappings,
+  )
+  new_grid_mapping = grid_mapping.replace(
+      block_mappings=new_block_mappings,
+      num_inputs=grid_mapping.num_inputs + num_refs,
+      num_outputs=grid_mapping.num_outputs + num_refs)
+  new_input_output_aliases = [
+      (i + grid_mapping.num_index_operands, i) for i in range(num_refs)
+  ]
+  for i, o in input_output_aliases:
+    new_input_output_aliases.append((i + num_refs, o + num_refs))
+  ref_out_avals = [ref_aval.inner_aval for ref_aval in ref_avals]
+  new_out_avals = (*ref_out_avals, *out_avals)
+  ref_args, dynamic_grid_bounds, index_operands, rest_args = split_list(
+      args,
+      [
+          num_refs,
+          grid_mapping.num_dynamic_grid_bounds,
+          grid_mapping.num_index_operands,
+      ],
+  )
+  def _rewritten_body(*args):
+    index_args, in_args, out_args, rest_args = split_list(
+        args, [new_grid_mapping.num_index_operands, new_grid_mapping.num_inputs,
+               new_grid_mapping.num_outputs])
+    ref_in_args, in_args = split_list(in_args, [num_refs])
+    ref_out_args, out_args = split_list(out_args, [num_refs])
+    # We don't care about ref_out_args because they are aliased to ref_in_args
+    del ref_out_args
+    jax_core.eval_jaxpr(
+        jaxpr, ref_in_args, *index_args, *in_args, *out_args, *rest_args
+    )
+    return []
+  index_map_avals, jaxpr_in_avals, jaxpr_out_avals, jaxpr_rest_avals = (
+      split_list(
+          [v.aval for v in jaxpr.invars],
+          [
+              grid_mapping.num_index_operands,
+              grid_mapping.num_inputs,
+              grid_mapping.num_outputs,
+          ],
+      )
+  )
+  new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(_rewritten_body),
+      [
+          *index_map_avals,
+          *ref_avals,
+          *jaxpr_in_avals,
+          *ref_avals,
+          *jaxpr_out_avals,
+          *jaxpr_rest_avals,
+      ],
+  )
+  out_flat = pallas_call_p.bind(
+      *consts,
+      *dynamic_grid_bounds,
+      *index_operands,
+      *ref_args,
+      *rest_args,
+      jaxpr=new_jaxpr,
+      input_output_aliases=new_input_output_aliases,
+      grid_mapping=new_grid_mapping,
+      name_and_src_info=name_and_src_info,
+      debug=debug,
+      interpret=interpret,
+      compiler_params=compiler_params,
+      cost_estimate=cost_estimate,
+      out_avals=new_out_avals,
+  )
+  refs_out, rest = split_list(out_flat, [num_refs])
+  updated_vals_in = refs_out + [None] * len(rest_in_avals)
+  return updated_vals_in, rest
+
+
 def pallas_call(
     kernel: Callable[..., None],
     out_shape: Any,
@@ -1440,7 +1574,7 @@ def pallas_call(
       for x in flat_kernel_args
   )
   with pallas_core.interpret_mode_env(interpret):
-    jaxpr = _trace_kernel_to_jaxpr(
+    jaxpr, consts = _trace_kernel_to_jaxpr(
         kernel, kernel_src_info,
         grid_mapping, tuple(flat_kernel_avals),
         kernel_in_tree, kernel_arg_transforms,
        interpret=interpret)
   for i_idx, o_idx in input_output_aliases.items():
@@ -1467,6 +1601,7 @@ def pallas_call(
     index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
     with pallas_core.interpret_mode_env(interpret):
       out_flat = pallas_call_p.bind(
+          *consts,
           *dynamic_grid_bounds,
           *index_args,
           *rest_args,
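The discharge rule above leans on `split_list` to carve the flat operand list
into refs, dynamic grid bounds, index operands, and the rest, then re-binds
pallas_call with the refs appearing as both inputs and outputs, paired up in
`new_input_output_aliases` so the buffers are updated in place rather than
copied. A standalone sketch of the partitioning step (operand names and
counts are hypothetical):

from jax._src.util import split_list

# Pretend pallas_call received two discharged refs, one dynamic grid bound,
# and one index operand ahead of its ordinary inputs.
args = ["ref0", "ref1", "grid_bound", "index", "x", "y"]
ref_args, dynamic_grid_bounds, index_operands, rest_args = split_list(
    args, [2, 1, 1])
assert ref_args == ["ref0", "ref1"]
assert dynamic_grid_bounds == ["grid_bound"]
assert index_operands == ["index"]
assert rest_args == ["x", "y"]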
diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD
index df81f3cb6..eb05bfa30 100644
--- a/tests/pallas/BUILD
+++ b/tests/pallas/BUILD
@@ -343,8 +343,8 @@ jax_multiplatform_test(
 )
 
 jax_multiplatform_test(
-    name = "tpu_pallas_mesh_test",
-    srcs = ["tpu_pallas_mesh_test.py"],
+    name = "tpu_pallas_state_test",
+    srcs = ["tpu_pallas_state_test.py"],
     enable_backends = ["tpu"],
     tags = [
         "noasan",
diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py
index 4f9d591db..ef8d3ea89 100644
--- a/tests/pallas/tpu_pallas_async_test.py
+++ b/tests/pallas/tpu_pallas_async_test.py
@@ -20,6 +20,7 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import jax
 from jax._src import test_util as jtu
+from jax._src.state import discharge as state_discharge
 from jax.experimental import pallas as pl
 from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
@@ -755,5 +756,123 @@ class PallasCallRemoteAsyncCopyTest(parameterized.TestCase):
     np.testing.assert_array_equal(y, expected)
 
 
+def make_stateful_async_copy():
+  @jax.named_call
+  def copy_start(x_ref, o_ref) -> Future:
+
+    def copy_start_kernel(sem):
+      pltpu.make_async_copy(x_ref, o_ref, sem).start()
+    sem = pl.pallas_call(
+        copy_start_kernel,
+        out_shape=pltpu.SemaphoreType.DMA(()),
+        out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+    )()
+    return sem
+
+  @jax.named_call
+  def copy_done(x_ref, o_ref, future):
+    sem = future
+
+    def copy_done_kernel(sem):
+      pltpu.make_async_copy(x_ref, o_ref, sem).wait()
+
+    () = pl.pallas_call(
+        copy_done_kernel,
+        out_shape=(),
+        in_specs=[
+            pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+        ],
+    )(sem)
+
+  return copy_start, copy_done
+
+
+def make_stateful_async_slice(i: int):
+  @jax.named_call
+  def copy_start(x_ref, o_ref) -> Future:
+
+    def copy_start_kernel(sem):
+      pltpu.make_async_copy(x_ref.at[i], o_ref, sem).start()
+    sem = pl.pallas_call(
+        copy_start_kernel,
+        out_shape=pltpu.SemaphoreType.DMA(()),
+        out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+    )()
+    return sem
+
+  @jax.named_call
+  def copy_done(x_ref, o_ref, future):
+    sem = future
+
+    def copy_done_kernel(sem):
+      pltpu.make_async_copy(x_ref.at[i], o_ref, sem).wait()
+
+    () = pl.pallas_call(
+        copy_done_kernel,
+        out_shape=(),
+        in_specs=[
+            pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+        ],
+    )(sem)
+
+  return copy_start, copy_done
+
+
+class PallasCallStatefulAsyncTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest('DMAs only guaranteed to work on TPU v4+')
+
+  def test_basic_stateful_async_copy(self):
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_copy()
+        x_ref, y_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        copy_done(x_ref, y_ref, fut)
+      _, y = state_discharge.run_state(body)((x, y))
+      return y
+    x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_multiple_stateful_async_copy(self):
+    @jax.jit
+    def f(x):
+      y = y2 = jnp.zeros_like(x)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_copy()
+        x_ref, y_ref, y2_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        fut2 = copy_start(x_ref, y2_ref)
+        copy_done(x_ref, y_ref, fut)
+        copy_done(x_ref, y2_ref, fut2)
+      _, y, y2 = state_discharge.run_state(body)((x, y, y2))
+      return y, y2
+    x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+    y, y2 = f(x)
+    np.testing.assert_array_equal(y, x)
+    np.testing.assert_array_equal(y2, x)
+
+  def test_basic_stateful_async_slice(self):
+    @jax.jit
+    def f(x):
+      y = jnp.zeros(x.shape[1:], x.dtype)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_slice(2)
+        x_ref, y_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        copy_done(x_ref, y_ref, fut)
+      _, y = state_discharge.run_state(body)((x, y))
+      return y
+    x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32)
+    y = f(x)
+    np.testing.assert_array_equal(y, x[2])
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
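The copy_start/copy_done pairs above split a DMA into an issue half and a wait
half so that unrelated work can run while the copy is in flight. A hedged
sketch of that overlap, reusing make_stateful_async_copy from the test file
above (the doubling of z is purely illustrative):

import jax
import jax.numpy as jnp
from jax._src.state import discharge as state_discharge

@jax.jit
def overlapped(x, z):
  copy_start, copy_done = make_stateful_async_copy()
  def body(refs):
    x_ref, y_ref, z_ref = refs
    fut = copy_start(x_ref, y_ref)   # issue the DMA
    z_ref[...] = z_ref[...] * 2.0    # compute while the copy is in flight
    copy_done(x_ref, y_ref, fut)     # block only when the result is needed
  _, y, z = state_discharge.run_state(body)((x, jnp.zeros_like(x), z))
  return y, z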
diff --git a/tests/pallas/tpu_pallas_mesh_test.py b/tests/pallas/tpu_pallas_mesh_test.py
deleted file mode 100644
index 0df759aec..000000000
--- a/tests/pallas/tpu_pallas_mesh_test.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright 2024 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Tests for Pallas mesh API.""" - -from absl.testing import absltest -import jax -from jax._src import test_util as jtu -from jax._src.state import discharge as state_discharge -from jax.experimental import pallas as pl -from jax.experimental import shard_map -from jax.experimental.pallas import tpu as pltpu -import jax.numpy as jnp -import numpy as np - - -jax.config.parse_flags_with_absl() - - -class ShmallasTest(jtu.JaxTestCase): - - def setUp(self): - super().setUp() - if not jtu.is_device_tpu_at_least(4): - self.skipTest("Only supported on TPU v4+") - - def test_can_create_tensorcore_mesh(self): - _ = pltpu.create_tensorcore_mesh("x") - - def test_can_trivially_shard_map_with_pallas_mesh(self): - mesh = pltpu.create_tensorcore_mesh("x") - _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)() - - def test_can_run_basic_pallas_kernel_with_shard_map(self): - mesh = pltpu.create_tensorcore_mesh("x") - - @jax.jit - def f(x): - y = jnp.zeros_like(x) - def inner(refs): - x_ref, y_ref = refs - def kernel(): - def alloc(sem): - pltpu.async_copy(x_ref, y_ref, sem).wait() - pl.run_scoped(alloc, pltpu.SemaphoreType.DMA) - shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, - check_rep=False)() - _, y = state_discharge.run_state(inner)((x, y)) - return y - x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) - y = f(x) - np.testing.assert_array_equal(y, x) - - def test_can_query_core_index_pallas_kernel_with_shard_map(self): - mesh = pltpu.create_tensorcore_mesh("x") - - @jax.jit - def f(x): - y = jnp.zeros_like(x) - def inner(refs): - x_ref, y_ref = refs - def kernel(): - num_cores = jax.lax.psum(1, "x") - slc_size = 16 // num_cores - def alloc(x_vmem_ref, y_vmem_ref, sem): - core_index = jax.lax.axis_index("x") - slc = pl.ds(core_index * slc_size, slc_size) - pltpu.async_copy( - x_ref.at[slc], - x_vmem_ref, - sem, - ).wait() - y = x_vmem_ref[...] + jax.lax.axis_index("x") - y_vmem_ref[...] = y - pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait() - pl.run_scoped( - alloc, - pltpu.VMEM((slc_size, 128), x_ref.dtype), - pltpu.VMEM((slc_size, 128), y_ref.dtype), - pltpu.SemaphoreType.DMA, - ) - shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, - check_rep=False)() - _, y = state_discharge.run_state(inner)((x, y)) - return y - num_cores = jax.devices()[0].num_cores - x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) - expected_out = ( - x.reshape((num_cores, -1, 128)) + jnp.arange(num_cores)[..., None, None] - ).reshape(x.shape) - y = f(x) - np.testing.assert_array_equal(y, expected_out) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py new file mode 100644 index 000000000..b017cac2f --- /dev/null +++ b/tests/pallas/tpu_pallas_state_test.py @@ -0,0 +1,271 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Pallas mesh API.""" +import functools +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax._src.state import discharge as state_discharge +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +class PallasCallStatefulTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Only supported on TPU v4+") + + def test_basic_stateful_kernel(self): + + def copy_kernel(x_ref, y_ref): + def body(sem): + pltpu.make_async_copy(x_ref, y_ref, sem).start() + pltpu.make_async_copy(x_ref, y_ref, sem).wait() + pl.run_scoped(body, pltpu.SemaphoreType.DMA) + + def f_stateful(refs): + x_ref, y_ref = refs + + pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref), + out_shape=[])() + + @jax.jit + def f(x): + _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x))) + return y + + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_basic_stateful_kernel_with_scratch_sem(self): + + def copy_kernel(x_ref, y_ref, sem): + pltpu.make_async_copy(x_ref, y_ref, sem).start() + pltpu.make_async_copy(x_ref, y_ref, sem).wait() + + def f_stateful(refs): + x_ref, y_ref = refs + + pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref), + scratch_shapes=[pltpu.SemaphoreType.DMA], + out_shape=[])() + + @jax.jit + def f(x): + _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x))) + return y + + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_basic_stateful_kernel_with_scalar_prefetch(self): + + def copy_kernel(x_ref, y_ref, index_ref, sem): + i = index_ref[0] + pltpu.make_async_copy(x_ref.at[i], y_ref, sem).start() + pltpu.make_async_copy(x_ref.at[i], y_ref, sem).wait() + + def f_stateful(refs): + x_ref, y_ref = refs + + pl.pallas_call( + functools.partial(copy_kernel, x_ref, y_ref), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + scratch_shapes=[pltpu.SemaphoreType.DMA], + ), + out_shape=[], + )(jnp.array([0])) + + @jax.jit + def f(x): + _, y = state_discharge.run_state(f_stateful)((x[None], jnp.zeros_like(x))) + return y + + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_basic_stateful_kernel_with_io_aliasing(self): + + def copy_kernel(x_ref, y_ref, x_old_ref, x_old_ref2, sem): + del x_old_ref, x_old_ref2 + pltpu.make_async_copy(x_ref, y_ref, sem).start() + pltpu.make_async_copy(x_ref, y_ref, sem).wait() + + def f_stateful(refs): + x_ref, y_ref, o_ref = refs + + x = pl.pallas_call( + functools.partial(copy_kernel, x_ref, y_ref), + scratch_shapes=[pltpu.SemaphoreType.DMA], + out_shape=jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype), + input_output_aliases={0: 0}, + )(x_ref[...]) + o_ref[...] 
+
+    @jax.jit
+    def f(x):
+      _, y, o = state_discharge.run_state(f_stateful)(
+          (x, jnp.zeros_like(x), jnp.zeros_like(x))
+      )
+      return y, o
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y, o = f(x)
+    np.testing.assert_array_equal(y, x)
+    np.testing.assert_array_equal(o, x)
+
+  def test_stateful_matmul(self):
+
+    m, k, n = 512, 512, 512
+    bm, bk, bn = 128, 128, 128
+
+    def matmul_kernel(acc_ref, x_ref, y_ref, o_ref):
+      @pl.when(pl.program_id(2) == 0)
+      def _():
+        acc_ref[...] = jnp.zeros_like(acc_ref)
+
+      acc_ref[...] += jnp.dot(
+          x_ref[...], y_ref[...], preferred_element_type=jnp.float32
+      )
+
+      @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
+      def _():
+        o_ref[...] = acc_ref[...].astype(o_ref.dtype)
+
+    def matmul(x, y):
+
+      def run_matmul(refs):
+        x_ref, y_ref, o_ref = refs
+
+        def matmul_pipeline_kernel(acc_ref):
+          pltpu.emit_pipeline(
+              functools.partial(matmul_kernel, acc_ref),
+              grid=(m // bm, n // bn, k // bk),
+              in_specs=[
+                  pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
+                  pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
+              ],
+              out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
+          )(x_ref, y_ref, o_ref)
+
+        pl.pallas_call(
+            matmul_pipeline_kernel,
+            out_shape=[],
+            scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
+        )()
+
+      _, _, o = state_discharge.run_state(run_matmul)(
+          (x, y, jnp.ones((m, n), dtype=x.dtype))
+      )
+      return o
+
+    x = jax.random.normal(jax.random.key(0), (m, k), jnp.float32)
+    y = jax.random.normal(jax.random.key(1), (k, n), jnp.float32)
+    o = matmul(x, y)
+    atol = 0
+    if jtu.is_device_tpu(6):
+      atol = 2e-5
+    np.testing.assert_allclose(o, x @ y, atol=atol)
+
+
+class ShmallasTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest("Only supported on TPU v4+")
+
+  def test_can_create_tensorcore_mesh(self):
+    _ = pltpu.create_tensorcore_mesh("x")
+
+  def test_can_trivially_shard_map_with_pallas_mesh(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+    _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)()
+
+  def test_can_run_basic_pallas_kernel_with_shard_map(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def inner(refs):
+        x_ref, y_ref = refs
+        def kernel():
+          def alloc(sem):
+            pltpu.async_copy(x_ref, y_ref, sem).wait()
+          pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
+        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
+                            check_rep=False)()
+      _, y = state_discharge.run_state(inner)((x, y))
+      return y
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_can_query_core_index_pallas_kernel_with_shard_map(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def inner(refs):
+        x_ref, y_ref = refs
+        def kernel():
+          num_cores = jax.lax.psum(1, "x")
+          slc_size = 16 // num_cores
+          def alloc(x_vmem_ref, y_vmem_ref, sem):
+            core_index = jax.lax.axis_index("x")
+            slc = pl.ds(core_index * slc_size, slc_size)
+            pltpu.async_copy(
+                x_ref.at[slc],
+                x_vmem_ref,
+                sem,
+            ).wait()
+            y = x_vmem_ref[...] + jax.lax.axis_index("x")
+            y_vmem_ref[...] = y
+            pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait()
+          pl.run_scoped(
+              alloc,
+              pltpu.VMEM((slc_size, 128), x_ref.dtype),
+              pltpu.VMEM((slc_size, 128), y_ref.dtype),
+              pltpu.SemaphoreType.DMA,
+          )
+        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
+                            check_rep=False)()
+      _, y = state_discharge.run_state(inner)((x, y))
+      return y
+    num_cores = jax.devices()[0].num_cores
+    x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128))
+    expected_out = (
+        x.reshape((num_cores, -1, 128)) + jnp.arange(num_cores)[..., None, None]
+    ).reshape(x.shape)
+    y = f(x)
+    np.testing.assert_array_equal(y, expected_out)
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
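A closing worked check of the tiling used by test_stateful_matmul above; all
numbers come from the test itself:

# Pipeline grid: (i, j) walk the output tiles; the trailing axis is the
# k reduction.
m = k = n = 512
bm = bk = bn = 128
grid = (m // bm, n // bn, k // bk)
assert grid == (4, 4, 4)
# x blocks are indexed (i, k) and y blocks (k, j); the (i, j) output block
# is written once, on the last reduction step (pl.program_id(2) ==
# pl.num_programs(2) - 1), after the VMEM scratch accumulator has summed
# all four partial products.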