[Pallas:MGPU] Allow allocating transformed refs in run_scoped

PiperOrigin-RevId: 688448592
Author: Adam Paszke, 2024-10-22 01:38:04 -07:00 (committed by jax authors)
parent ebb75db8a5
commit 84a303f32f
5 changed files with 80 additions and 16 deletions
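
In user terms: pl.run_scoped can now allocate scratch refs whose type carries memory-ref transforms (e.g. SMEM tiling and swizzling), and the kernel body receives the transformed view. A minimal usage sketch, condensed from the test added at the bottom of this change (import paths assumed to be the standard jax.experimental.pallas ones; the surrounding pallas_call setup is omitted here):

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))

def kernel(x_ref, o_ref, barrier_ref):
  def body(tmp_ref):
    # tmp_ref is backed by tiled, swizzled SMEM but is indexed as (128, 128).
    plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
    plgpu.barrier_wait(barrier_ref)
    o_ref[...] = tmp_ref[...] * 2
  # Allocating a transformed scratch ref inside run_scoped is what this change enables.
  pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))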

View File

@@ -200,7 +200,7 @@ class MemoryRef:
         self.shape, dtype, memory_space=self.memory_space
     )

-  def get_ref_aval(self) -> AbstractMemoryRef:
+  def get_ref_aval(self) -> TransformedRef | AbstractMemoryRef:
     # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
     # try to apply JAX ops to it.
     return AbstractMemoryRef(
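
Only the annotation is widened here; the base MemoryRef still returns a plain AbstractMemoryRef. The wider type lets subclasses (see GPUMemoryRef in the next file) hand back a TransformedRef instead. For orientation, a TransformedRef as used throughout this change is roughly the following pairing (a sketch only; the real class lives in jax._src.state.types, and the field names match how the rest of this diff uses it):

import dataclasses
from typing import Any

@dataclasses.dataclass
class TransformedRef:  # sketch, not the real definition
  ref: Any                          # the underlying, untransformed ref (or ref aval)
  transforms: tuple[Any, ...] = ()  # view transforms stacked on top (tiling, swizzle, indexing, ...)
# It is deliberately not a pytree, so it appears as a single opaque ref value to user code.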

View File

@@ -14,6 +14,8 @@
"""Contains GPU-specific Pallas abstractions."""
from __future__ import annotations
import abc
import collections
from collections.abc import Sequence
@@ -73,9 +75,32 @@ class GPUMemorySpace(enum.Enum):
   def __str__(self) -> str:
     return self.value

-  def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
+  def __call__(
+      self,
+      shape: tuple[int, ...],
+      dtype: jnp.dtype,
+      transforms: Sequence[MemoryRefTransform] = (),
+  ):
     # A convenience function for constructing MemoryRef types.
-    return pallas_core.MemoryRef(shape, dtype, memory_space=self)
+    return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms)
+
+
+@dataclasses.dataclass(frozen=True)
+class GPUMemoryRef(pallas_core.MemoryRef):
+  transforms: Sequence[MemoryRefTransform] = ()
+
+  def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef:
+    aval = jax_core.ShapedArray(self.shape, self.dtype)
+    for t in self.transforms:
+      aval = t(aval)
+    ref = pallas_core.TransformedRef(
+        AbstractMemoryRef(aval, memory_space=self.memory_space), ()
+    )
+    for t in reversed(self.transforms):
+      ref = t.undo(ref)
+    if not ref.transforms:
+      return ref.ref
+    return ref
+
+
 class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
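
get_ref_aval first applies the transforms to the logical aval to compute the physical shape that actually gets allocated, then undoes them in reverse order on top of the resulting ref so that the kernel body keeps addressing the logical view. A sketch of what this yields for the tiling/swizzle pair used in the new test (the tiled shape assumes the usual Mosaic GPU convention of splitting the trailing dims into outer-by-tile blocks; import path assumed):

import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu

ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
r = plgpu.SMEM((128, 128), jnp.float32, transforms=ts).get_ref_aval()
# r is a pallas_core.TransformedRef:
#   r.ref        -- aval of the physical allocation, tiled to something like
#                   (2, 4, 64, 32); swizzling does not change the shape.
#   r.transforms -- the "undo" stack that re-exposes the logical (128, 128) view.
# With an empty `transforms`, get_ref_aval returns a plain AbstractMemoryRef instead.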

View File

@@ -1382,17 +1382,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
   return new_error, results
 checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule

-# All of those shenanigans are because we can't make TransformedRef a PyTree,
-# because they should appear as atomic JAX values to the users.
-@lu.transformation
-def wrap_with_transforms(transforms, *args):
-  new_args = tuple(
-      state_types.TransformedRef(a, t) if t else a
-      for a, t in zip(args, transforms)
-  )
-  res = yield new_args, {}
-  yield res

 @weakref_lru_cache
 def _trace_kernel_to_jaxpr(
@@ -1410,7 +1399,9 @@ def _trace_kernel_to_jaxpr(
                                      kernel_avals))
   wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(fun), kernel_in_tree)
-  wrapped_kernel_fun = wrap_with_transforms(wrapped_kernel_fun, kernel_in_transforms)
+  wrapped_kernel_fun = primitives.wrap_with_transforms(
+      wrapped_kernel_fun, kernel_in_transforms
+  )
   debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
   with grid_mapping.trace_env():
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,

View File

@@ -39,6 +39,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.pallas import core as pallas_core
 from jax._src.state import discharge as state_discharge
 from jax._src.state import indexing
+from jax._src.state import types as state_types
 from jax._src.state import primitives as sp
 from jax.interpreters import mlir
 import jax.numpy as jnp
@@ -816,6 +817,20 @@ def debug_print_lowering_rule(ctx, *args, **params):
   return result

+# All of those shenanigans are because we can't make TransformedRef a PyTree,
+# because they should appear as atomic JAX values to the users.
+# TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU
+# inferred by the compiler.
+@lu.transformation
+def wrap_with_transforms(transforms, *args):
+  new_args = tuple(
+      state_types.TransformedRef(a, t) if t else a
+      for a, t in zip(args, transforms)
+  )
+  res = yield new_args, {}
+  yield res
+
+
 run_scoped_p = jax_core.Primitive("run_scoped")
 run_scoped_p.multiple_results = True
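
wrap_with_transforms is written in the lu.transformation style: the generator yields the rewritten arguments, receives the traced result, and yields it back, which lets the transforms be reattached at trace time without ever making TransformedRef a pytree. A rough eager equivalent, for readers not used to linear_util (the helper name here is made up for illustration):

from jax._src.state import types as state_types

def wrap_with_transforms_eager(fun, transforms):
  """Sketch: call `fun`, rewrapping each positional ref that has transforms."""
  def wrapped(*args):
    new_args = tuple(
        state_types.TransformedRef(a, t) if t else a
        for a, t in zip(args, transforms)
    )
    return fun(*new_args)
  return wrapped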
@@ -829,7 +844,17 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
   """
   flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
   flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
-  avals = [t.get_ref_aval() for t in flat_types]
+  # We allow ref avals to be transformed references.
+  ref_avals = [t.get_ref_aval() for t in flat_types]
+  avals = [
+      t.ref if isinstance(t, state_types.TransformedRef) else t
+      for t in ref_avals
+  ]
+  ref_transforms = tuple(
+      t.transforms if isinstance(t, state_types.TransformedRef) else ()
+      for t in ref_avals
+  )
+  flat_fun = wrap_with_transforms(flat_fun, ref_transforms)
   # Turn the function into a jaxpr. The body of run_scoped may have
   # effects (IO) on constvars (i.e. variables inherited from the
   # parent scope). Jax can't reason about effects to references that
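
The effect of the split above: the run_scoped primitive itself only ever binds plain ref avals, while the transform stacks are carried on the side and reattached to the body's arguments through wrap_with_transforms at trace time. Roughly, for one transformed SMEM type (names as in the code above):

# ref_avals[i]      -> a TransformedRef (as produced by GPUMemoryRef.get_ref_aval)
# avals[i]          == ref_avals[i].ref         # plain aval; this is what the jaxpr sees
# ref_transforms[i] == ref_avals[i].transforms  # reapplied to the body's argument
# Untransformed types pass through unchanged, with an empty () transforms entry.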

View File

@@ -287,6 +287,29 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x)

+  def test_scoped_copy_with_transforms(self):
+    ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
+    def kernel(x_ref, o_ref, barrier_ref):
+      def body(tmp_ref):
+        plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
+        plgpu.barrier_wait(barrier_ref)
+        o_ref[...] = tmp_ref[...] * 2
+      pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))
+
+    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
+    out_spec = plgpu.GPUBlockSpec(
+        (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
+    )
+    f = pl.pallas_call(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
+        in_specs=(in_spec,),
+        out_specs=out_spec,
+        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
+    )
+    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
+    np.testing.assert_array_equal(f(x), x * 2)
+
   def test_copy_with_transforms_and_indexing(self):
     def kernel(x_ref, o_ref, barrier_ref):
       for i in range(2):