diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 56f513d05..abbd7154d 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -200,7 +200,7 @@ class MemoryRef:
         self.shape, dtype, memory_space=self.memory_space
     )
 
-  def get_ref_aval(self) -> AbstractMemoryRef:
+  def get_ref_aval(self) -> TransformedRef | AbstractMemoryRef:
     # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
     # try to apply JAX ops to it.
     return AbstractMemoryRef(
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 88a0f887f..2e2694452 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -14,6 +14,8 @@
 
 """Contains GPU-specific Pallas abstractions."""
 
+from __future__ import annotations
+
 import abc
 import collections
 from collections.abc import Sequence
@@ -73,9 +75,32 @@ class GPUMemorySpace(enum.Enum):
   def __str__(self) -> str:
     return self.value
 
-  def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
+  def __call__(
+      self,
+      shape: tuple[int, ...],
+      dtype: jnp.dtype,
+      transforms: Sequence[MemoryRefTransform] = (),
+  ):
     # A convenience function for constructing MemoryRef types.
-    return pallas_core.MemoryRef(shape, dtype, memory_space=self)
+    return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms)
+
+
+@dataclasses.dataclass(frozen=True)
+class GPUMemoryRef(pallas_core.MemoryRef):
+  transforms: Sequence[MemoryRefTransform] = ()
+
+  def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef:
+    aval = jax_core.ShapedArray(self.shape, self.dtype)
+    for t in self.transforms:
+      aval = t(aval)
+    ref = pallas_core.TransformedRef(
+        AbstractMemoryRef(aval, memory_space=self.memory_space), ()
+    )
+    for t in reversed(self.transforms):
+      ref = t.undo(ref)
+    if not ref.transforms:
+      return ref.ref
+    return ref
 
 
 class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index d3368c43f..4db6057d2 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -1382,17 +1382,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
   return new_error, results
 checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
 
-# All of those shenanigans are because we can't make TransformedRef a PyTree,
-# because they should appear as atomic JAX values to the users.
-@lu.transformation
-def wrap_with_transforms(transforms, *args):
-  new_args = tuple(
-      state_types.TransformedRef(a, t) if t else a
-      for a, t in zip(args, transforms)
-  )
-  res = yield new_args, {}
-  yield res
-
 
 @weakref_lru_cache
 def _trace_kernel_to_jaxpr(
@@ -1410,7 +1399,9 @@ def _trace_kernel_to_jaxpr(
                            kernel_avals))
   wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(fun), kernel_in_tree)
-  wrapped_kernel_fun = wrap_with_transforms(wrapped_kernel_fun, kernel_in_transforms)
+  wrapped_kernel_fun = primitives.wrap_with_transforms(
+      wrapped_kernel_fun, kernel_in_transforms
+  )
   debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
   with grid_mapping.trace_env():
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index 97655b8df..b41ce3632 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -39,6 +39,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.pallas import core as pallas_core
 from jax._src.state import discharge as state_discharge
 from jax._src.state import indexing
+from jax._src.state import types as state_types
 from jax._src.state import primitives as sp
 from jax.interpreters import mlir
 import jax.numpy as jnp
@@ -816,6 +817,20 @@ def debug_print_lowering_rule(ctx, *args, **params):
   return result
 
 
+# All of those shenanigans are because we can't make TransformedRef a PyTree,
+# because they should appear as atomic JAX values to the users.
+# TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU
+# inferred by the compiler.
+@lu.transformation
+def wrap_with_transforms(transforms, *args):
+  new_args = tuple(
+      state_types.TransformedRef(a, t) if t else a
+      for a, t in zip(args, transforms)
+  )
+  res = yield new_args, {}
+  yield res
+
+
 run_scoped_p = jax_core.Primitive("run_scoped")
 run_scoped_p.multiple_results = True
 
@@ -829,7 +844,17 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
   """
   flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
   flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
-  avals = [t.get_ref_aval() for t in flat_types]
+  # We allow ref avals to be transformed references.
+  ref_avals = [t.get_ref_aval() for t in flat_types]
+  avals = [
+      t.ref if isinstance(t, state_types.TransformedRef) else t
+      for t in ref_avals
+  ]
+  ref_transforms = tuple(
+      t.transforms if isinstance(t, state_types.TransformedRef) else ()
+      for t in ref_avals
+  )
+  flat_fun = wrap_with_transforms(flat_fun, ref_transforms)
   # Turn the function into a jaxpr. The body of run_scoped may have
   # effects (IO) on constvars (i.e. variables inherited from the
   # parent scope). Jax can't reason about effects to references that
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 276e05031..4b86cc21f 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -287,6 +287,29 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x)
 
+  def test_scoped_copy_with_transforms(self):
+    ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
+    def kernel(x_ref, o_ref, barrier_ref):
+      def body(tmp_ref):
+        plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
+        plgpu.barrier_wait(barrier_ref)
+        o_ref[...] = tmp_ref[...] * 2
+      pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))
+
+    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
+    out_spec = plgpu.GPUBlockSpec(
+        (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
+    )
+    f = pl.pallas_call(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
+        in_specs=(in_spec,),
+        out_specs=out_spec,
+        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
+    )
+    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
+    np.testing.assert_array_equal(f(x), x * 2)
+
   def test_copy_with_transforms_and_indexing(self):
     def kernel(x_ref, o_ref, barrier_ref):
       for i in range(2):
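
For readers of this change, the snippet below restates the new test_scoped_copy_with_transforms test as a standalone usage sketch of the feature being added: plgpu.SMEM(...) now accepts transforms=, and the resulting transformed ref can be allocated through pl.run_scoped. The import paths are assumed to be the usual public aliases (jax.experimental.pallas and jax.experimental.pallas.mosaic_gpu), and running it requires a device supported by the Mosaic GPU backend.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# Tile the scratch buffer into (64, 32) blocks and apply a 128-byte swizzle,
# matching the transforms used in the new test above.
ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))

def kernel(x_ref, o_ref, barrier_ref):
  def body(tmp_ref):
    # tmp_ref is the transformed SMEM ref allocated by run_scoped below.
    plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
    plgpu.barrier_wait(barrier_ref)
    o_ref[...] = tmp_ref[...] * 2
  # New in this change: SMEM accepts transforms=, and the transformed ref
  # flows through pl.run_scoped into body.
  pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))

f = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
    in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
    out_specs=plgpu.GPUBlockSpec(
        (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
    ),
    scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), x * 2)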