[Pallas] Add state discharge rule for pallas_call
This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html), but not in general, since JAX and/or XLA could still introduce copies because we have value semantics. For a proper solution, we need to introduce some notion of buffer semantics into XLA/HLO and preserve it through the lowering of stateful JAX (perhaps by avoiding state discharge altogether).

PiperOrigin-RevId: 681206784
commit c34e25d6f4 (parent e361868132)
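
In concrete terms: a kernel may now close over Refs created by state_discharge.run_state, and discharging the resulting stateful pallas_call rewrites it into a pure pallas_call whose ref operands are threaded through input_output_aliases instead of being copied. A condensed sketch of the pattern, adapted from the tests added in this commit:

    import functools
    import jax
    import jax.numpy as jnp
    from jax._src.state import discharge as state_discharge
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def copy_kernel(x_ref, y_ref):  # x_ref/y_ref are captured, not passed in
      def body(sem):
        pltpu.make_async_copy(x_ref, y_ref, sem).start()
        pltpu.make_async_copy(x_ref, y_ref, sem).wait()
      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    def f_stateful(refs):
      x_ref, y_ref = refs
      # The captured refs become constvars of the kernel jaxpr; the new
      # discharge rule threads them through the pallas_call in place.
      pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref),
                     out_shape=[])()

    @jax.jit
    def f(x):
      _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
      return y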
jax/_src/pallas/mosaic/lowering.py

@@ -2762,8 +2762,10 @@ lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule
 def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                             device_id_type: tpu_primitives.DeviceIdType):
   del device_id_type
-  sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, args)
-  sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
+  (_, _, ref, transforms, sem, sem_transforms, _, _, _) = tree_util.tree_unflatten(
+      tree, args)
+  (_, _, ref_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten(
+      tree, ctx.avals_in)
   block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
   ref_block_shape = block_shapes[2]
   ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms)
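For orientation: after this change dma_start_p and dma_wait_p share one flattened operand layout, produced by the new _get_args_and_tree helper in the next hunk:

    (src_ref, src_transforms, dst_ref, dst_transforms,
     dst_sem, dst_sem_transforms, src_sem, src_sem_transforms, device_id)

The wait lowering above therefore unpacks positions 2-5 (the destination ref and semaphore plus their transforms) and ignores the rest.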
jax/_src/pallas/mosaic/primitives.py

@@ -431,18 +431,34 @@ class AsyncCopyDescriptor:
   def is_remote(self):
     return self.src_sem is not None

+  def _get_args_and_tree(self, swap_src_and_dst: bool = False):
+    if swap_src_and_dst:
+      return tree_util.tree_flatten((
+          self.dst_ref,
+          self.dst_transforms,
+          self.src_ref,
+          self.src_transforms,
+          self.src_sem,
+          self.src_sem_transforms,
+          self.dst_sem,
+          self.dst_sem_transforms,
+          self.device_id,
+      ))
+    else:
+      return tree_util.tree_flatten((
+          self.src_ref,
+          self.src_transforms,
+          self.dst_ref,
+          self.dst_transforms,
+          self.dst_sem,
+          self.dst_sem_transforms,
+          self.src_sem,
+          self.src_sem_transforms,
+          self.device_id,
+      ))
+
   def start(self):
-    flat_args, tree = tree_util.tree_flatten((
-        self.src_ref,
-        self.src_transforms,
-        self.dst_ref,
-        self.dst_transforms,
-        self.dst_sem,
-        self.dst_sem_transforms,
-        self.src_sem,
-        self.src_sem_transforms,
-        self.device_id,
-    ))
+    flat_args, tree = self._get_args_and_tree()
     dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type)

   def wait(self):
@@ -451,27 +467,20 @@ class AsyncCopyDescriptor:
     self.wait_recv()

   def wait_recv(self):
-    wait_args, tree = tree_util.tree_flatten((
-        self.dst_sem,
-        self.dst_sem_transforms,
-        self.dst_ref,
-        self.dst_transforms,
-    ))
+    flat_args, tree = self._get_args_and_tree()
     dma_wait_p.bind(
-        *wait_args, tree=tree, device_id_type=self.device_id_type
+        *flat_args, tree=tree, device_id_type=self.device_id_type
     )

   def wait_send(self):
     if not self.is_remote:
       raise ValueError("Cannot `wait_send` on a local copy.")
-    wait_args, tree = tree_util.tree_flatten((
-        self.src_sem,
-        self.src_sem_transforms,
-        self.src_ref,
-        self.src_transforms,
-    ))
+    # We swap src and dst since by default dma_wait_p waits on the dst_sem
+    # As a clean up, maybe we could modify the primitive to have a
+    # `wait_on_send` bool.
+    flat_args, tree = self._get_args_and_tree(swap_src_and_dst=True)
     dma_wait_p.bind(
-        *wait_args, tree=tree, device_id_type=self.device_id_type
+        *flat_args, tree=tree, device_id_type=self.device_id_type
     )

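The refactor leans on tree_util.tree_flatten and tree_util.tree_unflatten being exact inverses: any rule that receives the flat operands together with the treedef can reassemble the original tuple. A toy, runnable illustration (strings stand in for refs and semaphores):

    from jax import tree_util

    args = ("src_ref", (), "dst_ref", (), "dst_sem", (), None, (), None)
    flat, tree = tree_util.tree_flatten(args)
    # flat == ["src_ref", "dst_ref", "dst_sem"]; empty tuples and None
    # contribute no leaves. A rule holding (flat, tree) recovers the tuple:
    _, _, dst_ref, _, dst_sem, *_ = tree_util.tree_unflatten(tree, flat)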
@@ -689,7 +698,17 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
   del settings
   invars = eqn.invars
   tree = eqn.params["tree"]
-  sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, invars)
+  (
+      _,
+      _,
+      ref,
+      transforms,
+      sem,
+      sem_transforms,
+      _,
+      _,
+      _,
+  ) = tree_util.tree_unflatten(tree, invars)
   return pp.concat([
       pp.text("dma_wait"),
       pp.text(" "),
@@ -702,29 +721,38 @@ jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn

 def dma_wait_discharge_rule(in_avals, out_avals,
                             *args, tree, device_id_type):
+  # TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
   del out_avals, device_id_type
-  (sem, sem_transforms, ref, ref_transforms) = tree_util.tree_unflatten(
-      tree, args
-  )
-  (
-      sem_aval,
-      sem_transforms_avals,
-      _,
-      ref_transforms_avals,
-  ) = tree_util.tree_unflatten(tree, in_avals)
+  _, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
+      tree_util.tree_unflatten(tree, args))
+  (_,
+   src_ref_transforms_avals,
+   _,
+   dst_ref_transforms_avals,
+   dst_sem_aval,
+   dst_sem_transforms_avals,
+   src_sem_aval,
+   src_sem_transforms_avals,
+   device_id_aval,
+  ) = tree_util.tree_unflatten(tree, in_avals)
-  num_sem_transforms = len(tree_util.tree_leaves(sem_transforms_avals))
-  num_transforms = len(tree_util.tree_leaves(ref_transforms_avals))
-  updates = state_discharge.transform_array(ref, ref_transforms)
+  num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
+  num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
+  updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
   copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
   copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
-  sem_value = _transform_semaphore(sem, sem_transforms, sem_aval)
+  sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval)
   _, new_sem = state_discharge.transform_swap_array(
-      sem, sem_transforms, sem_value - copy_size
+      dst_sem, dst_sem_transforms, sem_value - copy_size
   )
-  new_vals = (new_sem,)  # sem
-  new_vals += (None,) * num_sem_transforms
+  new_vals = (None,)  # src_ref
+  new_vals += (None,) * len(tree_util.tree_leaves(src_ref_transforms_avals))
   new_vals += (None,)  # ref
-  new_vals += (None,) * num_transforms
+  new_vals += (None,) * num_transforms  # ref_transforms
+  new_vals += (new_sem,)  # sem
+  new_vals += (None,) * num_sem_transforms
+  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_aval))  # src_sem
+  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
+  new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval))  # device_id
   return new_vals, []
 state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)

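For readers new to the discharge machinery: a discharge rule receives the primitive's input and output avals plus its flat arguments, and returns (new_invals, outs), where new_invals has one entry per input, the updated value for any discharged Ref whose contents changed, and None otherwise. A minimal sketch for a hypothetical incr_p primitive (name and primitive invented for illustration) that bumps a ref in place:

    from jax._src.state import discharge as state_discharge

    def _incr_discharge_rule(in_avals, out_avals, ref_val):
      del in_avals, out_avals
      # One entry per input (the ref's new value); the primitive has no
      # regular outputs, so the second element is empty.
      return (ref_val + 1,), []

    # Registration, if incr_p were a real primitive:
    # state_discharge.register_discharge_rule(incr_p)(_incr_discharge_rule)

dma_wait_discharge_rule above has exactly this shape: it returns None for every operand except the destination semaphore, whose value is decremented by the copy size.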
jax/_src/pallas/pallas_call.py

@@ -30,6 +30,7 @@ from jax._src import config
 from jax._src import core as jax_core
 from jax._src import effects
+from jax._src import linear_util as lu
 from jax._src import state
 from jax._src import tree_util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -208,6 +209,7 @@ def _pallas_call_impl_interpret(
     print(discharged_jaxpr)
   out = _initialize_output_vals(grid_mapping.block_mappings_output,
                                 args, input_output_aliases)
+  # TODO(b/370563936): Fix correctness issue w/ io aliasing
   scalars = args[grid_mapping.slice_index_ops]
   block_args = args[len(scalars):]
   # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
@@ -936,7 +938,7 @@ def _pallas_call_batching_rule(
   with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
     kernel_src_info: pallas_core.SrcInfoStr = "<Wrapped outer kernel>"

-    jaxpr = _trace_kernel_to_jaxpr(
+    jaxpr, consts = _trace_kernel_to_jaxpr(
        when_wrapped_kernel,
        kernel_src_info,
        batched_grid_mapping,
@@ -945,6 +947,8 @@ def _pallas_call_batching_rule(
        tuple(() for _ in flat_kernel_avals),
        interpret=interpret,
    )
+    if consts:
+      raise NotImplementedError("consts not supported in pallas_call")

   assert ragged_axis_length is not None
   args = (ragged_axis_length, *args)
@@ -1160,7 +1164,7 @@ def _trace_kernel_to_jaxpr(
    kernel_in_tree: tree_util.PyTreeDef,
    kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
    interpret: bool,
-) -> jax_core.ClosedJaxpr:
+) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
   if interpret:
     kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
                              kernel_avals))
@@ -1174,17 +1178,18 @@ def _trace_kernel_to_jaxpr(
   if consts:
     consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
                     for c in consts]
-    raise ValueError(
-        f"The kernel function in the pallas_call {name_and_src_info} "
-        f"captures constants {consts_avals}. "
-        "You should pass them as inputs")
+    if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
+      raise ValueError(
+          f"The kernel function in the pallas_call {name_and_src_info} "
+          f"captures constants {consts_avals}. "
+          "You should pass them as inputs")

   kernel_out_tree = out_tree_thunk()
   if kernel_out_tree != tree_util.tree_structure(None):
     raise ValueError(
         f"The kernel function in the pallas_call {name_and_src_info} "
         f"should return None. It returns a PyTree: {kernel_out_tree}")
-  return jaxpr
+  return jaxpr, tuple(consts)


 _PALLAS_USE_MOSAIC_GPU = config.bool_flag(
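Net effect of this hunk together with the batching and lowering guards: _trace_kernel_to_jaxpr now returns the kernel's closed-over constants instead of unconditionally rejecting them, and only non-Ref captures still raise. Captured Refs surface as constvars of the kernel jaxpr and are prepended to the pallas_call_p bind as *consts; they must be discharged before lowering, which the new ValueError in _pallas_call_lowering below enforces. Sketch (reusing copy_kernel from the first example above):

    # Closing over Refs is now accepted:
    pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref),
                   out_shape=[])()
    # Closing over a plain array still raises:
    #   "The kernel function in the pallas_call ... captures constants ...
    #    You should pass them as inputs"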
@@ -1209,6 +1214,8 @@ def _unsupported_lowering_error(platform: str) -> Exception:
 def _pallas_call_lowering(
     ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
 ):
+  if params['jaxpr'].constvars:
+    raise ValueError('Cannot lower a pallas_call with constants.')
   if interpret:
     # If we are in interpret mode, we don't care what platform we are on.
     impl = partial(_pallas_call_impl_interpret, **params)
@@ -1286,6 +1293,133 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
     return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)


+def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any:
+  if isinstance(ref_aval, pallas_core.AbstractMemoryRef):
+    return ref_aval.memory_space
+  return pallas_core.MemorySpace.ANY
+
+
+@state_discharge.register_discharge_rule(pallas_call_p)
+def _pallas_call_state_discharge_rule(
+    avals_in,
+    avals_out,
+    *args,
+    jaxpr: jax_core.Jaxpr,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    name_and_src_info: pallas_core.NameAndSrcInfo,
+    grid_mapping: GridMapping,
+    debug: bool,
+    interpret: bool,
+    compiler_params: Any,
+    cost_estimate: CostEstimate | None,
+    out_avals: tuple[jax_core.AbstractValue, ...],
+):
+  del avals_out
+  assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
+  num_refs = len(jaxpr.constvars)
+  ref_avals, rest_in_avals = split_list(avals_in, [num_refs])
+  assert all(isinstance(ref_aval, state.AbstractRef) for ref_aval in ref_avals)
+  ref_avals = [
+      pallas_core.AbstractMemoryRef(
+          ref_aval.inner_aval, pallas_core.MemorySpace.ANY
+      )
+      for ref_aval in ref_avals
+  ]
+  ref_block_specs = [
+      pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
+  ] * num_refs
+  ref_block_mappings = [
+      block_spec.to_block_mapping(
+          origin="",  # TODO(sharadmv): enable origins for refs
+          array_aval=ref_aval.inner_aval,
+          index_map_avals=grid_mapping.index_map_avals,
+          index_map_tree=grid_mapping.index_map_tree,
+          grid=grid_mapping.grid,
+          mapped_dims=grid_mapping.mapped_dims,
+      ) for ref_aval, block_spec in zip(ref_avals, ref_block_specs)
+  ]
+  in_block_mappings, out_block_mappings = split_list(
+      grid_mapping.block_mappings, [grid_mapping.num_inputs]
+  )
+  new_block_mappings = (
+      *ref_block_mappings,
+      *in_block_mappings,
+      *ref_block_mappings,
+      *out_block_mappings,
+  )
+  new_grid_mapping = grid_mapping.replace(
+      block_mappings=new_block_mappings,
+      num_inputs=grid_mapping.num_inputs + num_refs,
+      num_outputs=grid_mapping.num_outputs + num_refs)
+  new_input_output_aliases = [
+      (i + grid_mapping.num_index_operands, i) for i in range(num_refs)
+  ]
+  for i, o in input_output_aliases:
+    new_input_output_aliases.append((i + num_refs, o + num_refs))
+  ref_out_avals = [ref_aval.inner_aval for ref_aval in ref_avals]
+  new_out_avals = (*ref_out_avals, *out_avals)
+  ref_args, dynamic_grid_bounds, index_operands, rest_args = split_list(
+      args,
+      [
+          num_refs,
+          grid_mapping.num_dynamic_grid_bounds,
+          grid_mapping.num_index_operands,
+      ],
+  )
+  def _rewritten_body(*args):
+    index_args, in_args, out_args, rest_args = split_list(
+        args, [new_grid_mapping.num_index_operands, new_grid_mapping.num_inputs,
+               new_grid_mapping.num_outputs])
+    ref_in_args, in_args = split_list(in_args, [num_refs])
+    ref_out_args, out_args = split_list(out_args, [num_refs])
+    # We don't care about ref_out_args because they are aliased to ref_in_args
+    del ref_out_args
+    jax_core.eval_jaxpr(
+        jaxpr, ref_in_args, *index_args, *in_args, *out_args, *rest_args
+    )
+    return []
+  index_map_avals, jaxpr_in_avals, jaxpr_out_avals, jaxpr_rest_avals = (
+      split_list(
+          [v.aval for v in jaxpr.invars],
+          [
+              grid_mapping.num_index_operands,
+              grid_mapping.num_inputs,
+              grid_mapping.num_outputs,
+          ],
+      )
+  )
+  new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(_rewritten_body),
+      [
+          *index_map_avals,
+          *ref_avals,
+          *jaxpr_in_avals,
+          *ref_avals,
+          *jaxpr_out_avals,
+          *jaxpr_rest_avals,
+      ],
+  )
+  out_flat = pallas_call_p.bind(
+      *consts,
+      *dynamic_grid_bounds,
+      *index_operands,
+      *ref_args,
+      *rest_args,
+      jaxpr=new_jaxpr,
+      input_output_aliases=new_input_output_aliases,
+      grid_mapping=new_grid_mapping,
+      name_and_src_info=name_and_src_info,
+      debug=debug,
+      interpret=interpret,
+      compiler_params=compiler_params,
+      cost_estimate=cost_estimate,
+      out_avals=new_out_avals,
+  )
+  refs_out, rest = split_list(out_flat, [num_refs])
+  updated_vals_in = refs_out + [None] * len(rest_in_avals)
+  return updated_vals_in, rest
+
+
 def pallas_call(
     kernel: Callable[..., None],
     out_shape: Any,
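To summarize _pallas_call_state_discharge_rule above: each constvar Ref is materialized as one extra input and one extra output in MemorySpace.ANY, aliased pairwise so the buffer is updated in place, and the rewritten body evaluates the original jaxpr with the ref operands restored to constvar position. A sketch of the resulting operand layout for num_refs = 2 and one pre-existing alias (i, o):

    inputs  : [*dynamic_grid_bounds, *index_operands, ref0, ref1, *old_inputs]
    outputs : [ref0_out, ref1_out, *old_outputs]
    aliases : (num_index_operands + 0, 0), (num_index_operands + 1, 1),
              (i + 2, o + 2)

The rule returns refs_out as the updated values for the ref inputs (None for all other inputs), which is exactly what run_state needs to consume the update without a copy.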
@@ -1440,7 +1574,7 @@ def pallas_call(
        for x in flat_kernel_args
    )
    with pallas_core.interpret_mode_env(interpret):
-      jaxpr = _trace_kernel_to_jaxpr(
+      jaxpr, consts = _trace_kernel_to_jaxpr(
          kernel, kernel_src_info, grid_mapping, tuple(flat_kernel_avals),
          kernel_in_tree, kernel_arg_transforms, interpret=interpret)
    for i_idx, o_idx in input_output_aliases.items():
@@ -1467,6 +1601,7 @@ def pallas_call(
    index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
    with pallas_core.interpret_mode_env(interpret):
      out_flat = pallas_call_p.bind(
+          *consts,
          *dynamic_grid_bounds,
          *index_args,
          *rest_args,
tests/pallas/BUILD

@@ -343,8 +343,8 @@ jax_multiplatform_test(
 )

 jax_multiplatform_test(
-    name = "tpu_pallas_mesh_test",
-    srcs = ["tpu_pallas_mesh_test.py"],
+    name = "tpu_pallas_state_test",
+    srcs = ["tpu_pallas_state_test.py"],
     enable_backends = ["tpu"],
     tags = [
         "noasan",
tests/pallas/tpu_pallas_async_test.py

@@ -20,6 +20,7 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import jax
 from jax._src import test_util as jtu
+from jax._src.state import discharge as state_discharge
 from jax.experimental import pallas as pl
 from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
@@ -755,5 +756,123 @@ class PallasCallRemoteAsyncCopyTest(parameterized.TestCase):
     np.testing.assert_array_equal(y, expected)


+def make_stateful_async_copy():
+  @jax.named_call
+  def copy_start(x_ref, o_ref) -> Future:
+
+    def copy_start_kernel(sem):
+      pltpu.make_async_copy(x_ref, o_ref, sem).start()
+    sem = pl.pallas_call(
+        copy_start_kernel,
+        out_shape=pltpu.SemaphoreType.DMA(()),
+        out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+    )()
+    return sem
+
+  @jax.named_call
+  def copy_done(x_ref, o_ref, future):
+    sem = future
+
+    def copy_done_kernel(sem):
+      pltpu.make_async_copy(x_ref, o_ref, sem).wait()
+
+    () = pl.pallas_call(
+        copy_done_kernel,
+        out_shape=(),
+        in_specs=[
+            pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+        ],
+    )(sem)
+
+  return copy_start, copy_done
+
+
+def make_stateful_async_slice(i: int):
+  @jax.named_call
+  def copy_start(x_ref, o_ref) -> Future:
+
+    def copy_start_kernel(sem):
+      pltpu.make_async_copy(x_ref.at[i], o_ref, sem).start()
+    sem = pl.pallas_call(
+        copy_start_kernel,
+        out_shape=pltpu.SemaphoreType.DMA(()),
+        out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+    )()
+    return sem
+
+  @jax.named_call
+  def copy_done(x_ref, o_ref, future):
+    sem = future
+
+    def copy_done_kernel(sem):
+      pltpu.make_async_copy(x_ref.at[i], o_ref, sem).wait()
+
+    () = pl.pallas_call(
+        copy_done_kernel,
+        out_shape=(),
+        in_specs=[
+            pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+        ],
+    )(sem)
+
+  return copy_start, copy_done
+
+
+class PallasCallStatefulAsyncTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest('DMAs only guaranteed to work on TPU v4+')
+
+  def test_basic_stateful_async_copy(self):
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_copy()
+        x_ref, y_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        copy_done(x_ref, y_ref, fut)
+      _, y = state_discharge.run_state(body)((x, y))
+      return y
+    x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_multiple_stateful_async_copy(self):
+    @jax.jit
+    def f(x):
+      y = y2 = jnp.zeros_like(x)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_copy()
+        x_ref, y_ref, y2_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        fut2 = copy_start(x_ref, y2_ref)
+        copy_done(x_ref, y_ref, fut)
+        copy_done(x_ref, y2_ref, fut2)
+      _, y, y2 = state_discharge.run_state(body)((x, y, y2))
+      return y, y2
+    x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+    y, y2 = f(x)
+    np.testing.assert_array_equal(y, x)
+    np.testing.assert_array_equal(y2, x)
+
+  def test_basic_stateful_async_slice(self):
+    @jax.jit
+    def f(x):
+      y = jnp.zeros(x.shape[1:], x.dtype)
+      def body(refs):
+        copy_start, copy_done = make_stateful_async_slice(2)
+        x_ref, y_ref = refs
+        fut = copy_start(x_ref, y_ref)
+        copy_done(x_ref, y_ref, fut)
+      _, y = state_discharge.run_state(body)((x, y))
+      return y
+    x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32)
+    y = f(x)
+    np.testing.assert_array_equal(y, x[2])
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
tests/pallas/tpu_pallas_mesh_test.py (deleted)

@@ -1,107 +0,0 @@
-# Copyright 2024 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for Pallas mesh API."""
-
-from absl.testing import absltest
-import jax
-from jax._src import test_util as jtu
-from jax._src.state import discharge as state_discharge
-from jax.experimental import pallas as pl
-from jax.experimental import shard_map
-from jax.experimental.pallas import tpu as pltpu
-import jax.numpy as jnp
-import numpy as np
-
-
-jax.config.parse_flags_with_absl()
-
-
-class ShmallasTest(jtu.JaxTestCase):
-
-  def setUp(self):
-    super().setUp()
-    if not jtu.is_device_tpu_at_least(4):
-      self.skipTest("Only supported on TPU v4+")
-
-  def test_can_create_tensorcore_mesh(self):
-    _ = pltpu.create_tensorcore_mesh("x")
-
-  def test_can_trivially_shard_map_with_pallas_mesh(self):
-    mesh = pltpu.create_tensorcore_mesh("x")
-    _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)()
-
-  def test_can_run_basic_pallas_kernel_with_shard_map(self):
-    mesh = pltpu.create_tensorcore_mesh("x")
-
-    @jax.jit
-    def f(x):
-      y = jnp.zeros_like(x)
-      def inner(refs):
-        x_ref, y_ref = refs
-        def kernel():
-          def alloc(sem):
-            pltpu.async_copy(x_ref, y_ref, sem).wait()
-          pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
-        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
-                            check_rep=False)()
-      _, y = state_discharge.run_state(inner)((x, y))
-      return y
-    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
-    y = f(x)
-    np.testing.assert_array_equal(y, x)
-
-  def test_can_query_core_index_pallas_kernel_with_shard_map(self):
-    mesh = pltpu.create_tensorcore_mesh("x")
-
-    @jax.jit
-    def f(x):
-      y = jnp.zeros_like(x)
-      def inner(refs):
-        x_ref, y_ref = refs
-        def kernel():
-          num_cores = jax.lax.psum(1, "x")
-          slc_size = 16 // num_cores
-          def alloc(x_vmem_ref, y_vmem_ref, sem):
-            core_index = jax.lax.axis_index("x")
-            slc = pl.ds(core_index * slc_size, slc_size)
-            pltpu.async_copy(
-                x_ref.at[slc],
-                x_vmem_ref,
-                sem,
-            ).wait()
-            y = x_vmem_ref[...] + jax.lax.axis_index("x")
-            y_vmem_ref[...] = y
-            pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait()
-          pl.run_scoped(
-              alloc,
-              pltpu.VMEM((slc_size, 128), x_ref.dtype),
-              pltpu.VMEM((slc_size, 128), y_ref.dtype),
-              pltpu.SemaphoreType.DMA,
-          )
-        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
-                            check_rep=False)()
-      _, y = state_discharge.run_state(inner)((x, y))
-      return y
-    num_cores = jax.devices()[0].num_cores
-    x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128))
-    expected_out = (
-        x.reshape((num_cores, -1, 128)) + jnp.arange(num_cores)[..., None, None]
-    ).reshape(x.shape)
-    y = f(x)
-    np.testing.assert_array_equal(y, expected_out)
-
-
-if __name__ == "__main__":
-  absltest.main(testLoader=jtu.JaxTestLoader())
tests/pallas/tpu_pallas_state_test.py (new file, 271 lines)
@@ -0,0 +1,271 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Pallas mesh API."""
+import functools
+from absl.testing import absltest
+import jax
+from jax._src import test_util as jtu
+from jax._src.state import discharge as state_discharge
+from jax.experimental import pallas as pl
+from jax.experimental import shard_map
+from jax.experimental.pallas import tpu as pltpu
+import jax.numpy as jnp
+import numpy as np
+
+
+jax.config.parse_flags_with_absl()
+
+
+class PallasCallStatefulTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest("Only supported on TPU v4+")
+
+  def test_basic_stateful_kernel(self):
+
+    def copy_kernel(x_ref, y_ref):
+      def body(sem):
+        pltpu.make_async_copy(x_ref, y_ref, sem).start()
+        pltpu.make_async_copy(x_ref, y_ref, sem).wait()
+      pl.run_scoped(body, pltpu.SemaphoreType.DMA)
+
+    def f_stateful(refs):
+      x_ref, y_ref = refs
+
+      pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref),
+                     out_shape=[])()
+
+    @jax.jit
+    def f(x):
+      _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
+      return y
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_basic_stateful_kernel_with_scratch_sem(self):
+
+    def copy_kernel(x_ref, y_ref, sem):
+      pltpu.make_async_copy(x_ref, y_ref, sem).start()
+      pltpu.make_async_copy(x_ref, y_ref, sem).wait()
+
+    def f_stateful(refs):
+      x_ref, y_ref = refs
+
+      pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref),
+                     scratch_shapes=[pltpu.SemaphoreType.DMA],
+                     out_shape=[])()
+
+    @jax.jit
+    def f(x):
+      _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x)))
+      return y
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_basic_stateful_kernel_with_scalar_prefetch(self):
+
+    def copy_kernel(x_ref, y_ref, index_ref, sem):
+      i = index_ref[0]
+      pltpu.make_async_copy(x_ref.at[i], y_ref, sem).start()
+      pltpu.make_async_copy(x_ref.at[i], y_ref, sem).wait()
+
+    def f_stateful(refs):
+      x_ref, y_ref = refs
+
+      pl.pallas_call(
+          functools.partial(copy_kernel, x_ref, y_ref),
+          grid_spec=pltpu.PrefetchScalarGridSpec(
+              num_scalar_prefetch=1,
+              scratch_shapes=[pltpu.SemaphoreType.DMA],
+          ),
+          out_shape=[],
+      )(jnp.array([0]))
+
+    @jax.jit
+    def f(x):
+      _, y = state_discharge.run_state(f_stateful)((x[None], jnp.zeros_like(x)))
+      return y
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_basic_stateful_kernel_with_io_aliasing(self):
+
+    def copy_kernel(x_ref, y_ref, x_old_ref, x_old_ref2, sem):
+      del x_old_ref, x_old_ref2
+      pltpu.make_async_copy(x_ref, y_ref, sem).start()
+      pltpu.make_async_copy(x_ref, y_ref, sem).wait()
+
+    def f_stateful(refs):
+      x_ref, y_ref, o_ref = refs
+
+      x = pl.pallas_call(
+          functools.partial(copy_kernel, x_ref, y_ref),
+          scratch_shapes=[pltpu.SemaphoreType.DMA],
+          out_shape=jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype),
+          input_output_aliases={0: 0},
+      )(x_ref[...])
+      o_ref[...] = x
+
+    @jax.jit
+    def f(x):
+      _, y, o = state_discharge.run_state(f_stateful)(
+          (x, jnp.zeros_like(x), jnp.zeros_like(x))
+      )
+      return y, o
+
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y, o = f(x)
+    np.testing.assert_array_equal(y, x)
+    np.testing.assert_array_equal(o, x)
+
+  def test_stateful_matmul(self):
+
+    m, k, n = 512, 512, 512
+    bm, bk, bn = 128, 128, 128
+
+    def matmul_kernel(acc_ref, x_ref, y_ref, o_ref):
+      @pl.when(pl.program_id(2) == 0)
+      def _():
+        acc_ref[...] = jnp.zeros_like(acc_ref)
+
+      acc_ref[...] += jnp.dot(
+          x_ref[...], y_ref[...], preferred_element_type=jnp.float32
+      )
+
+      @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
+      def _():
+        o_ref[...] = acc_ref[...].astype(o_ref.dtype)
+
+    def matmul(x, y):
+
+      def run_matmul(refs):
+        x_ref, y_ref, o_ref = refs
+
+        def matmul_pipeline_kernel(acc_ref):
+          pltpu.emit_pipeline(
+              functools.partial(matmul_kernel, acc_ref),
+              grid=(m // bm, n // bn, k // bk),
+              in_specs=[
+                  pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
+                  pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
+              ],
+              out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
+          )(x_ref, y_ref, o_ref)
+
+        pl.pallas_call(
+            matmul_pipeline_kernel,
+            out_shape=[],
+            scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
+        )()
+
+      _, _, o = state_discharge.run_state(run_matmul)(
+          (x, y, jnp.ones((m, n), dtype=x.dtype))
+      )
+      return o
+
+    x = jax.random.normal(jax.random.key(0), (m, k), jnp.float32)
+    y = jax.random.normal(jax.random.key(1), (k, n), jnp.float32)
+    o = matmul(x, y)
+    atol = 0
+    if jtu.is_device_tpu(6):
+      atol = 2e-5
+    np.testing.assert_allclose(o, x @ y, atol=atol)
+
+
+class ShmallasTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest("Only supported on TPU v4+")
+
+  def test_can_create_tensorcore_mesh(self):
+    _ = pltpu.create_tensorcore_mesh("x")
+
+  def test_can_trivially_shard_map_with_pallas_mesh(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+    _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)()
+
+  def test_can_run_basic_pallas_kernel_with_shard_map(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def inner(refs):
+        x_ref, y_ref = refs
+        def kernel():
+          def alloc(sem):
+            pltpu.async_copy(x_ref, y_ref, sem).wait()
+          pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)
+        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
+                            check_rep=False)()
+      _, y = state_discharge.run_state(inner)((x, y))
+      return y
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    y = f(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_can_query_core_index_pallas_kernel_with_shard_map(self):
+    mesh = pltpu.create_tensorcore_mesh("x")
+
+    @jax.jit
+    def f(x):
+      y = jnp.zeros_like(x)
+      def inner(refs):
+        x_ref, y_ref = refs
+        def kernel():
+          num_cores = jax.lax.psum(1, "x")
+          slc_size = 16 // num_cores
+          def alloc(x_vmem_ref, y_vmem_ref, sem):
+            core_index = jax.lax.axis_index("x")
+            slc = pl.ds(core_index * slc_size, slc_size)
+            pltpu.async_copy(
+                x_ref.at[slc],
+                x_vmem_ref,
+                sem,
+            ).wait()
+            y = x_vmem_ref[...] + jax.lax.axis_index("x")
+            y_vmem_ref[...] = y
+            pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait()
+          pl.run_scoped(
+              alloc,
+              pltpu.VMEM((slc_size, 128), x_ref.dtype),
+              pltpu.VMEM((slc_size, 128), y_ref.dtype),
+              pltpu.SemaphoreType.DMA,
+          )
+        shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
+                            check_rep=False)()
+      _, y = state_discharge.run_state(inner)((x, y))
+      return y
+    num_cores = jax.devices()[0].num_cores
+    x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128))
+    expected_out = (
+        x.reshape((num_cores, -1, 128)) + jnp.arange(num_cores)[..., None, None]
+    ).reshape(x.shape)
+    y = f(x)
+    np.testing.assert_array_equal(y, expected_out)
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())