mirror of https://github.com/ROCm/jax.git
[Pallas:MGPU] Allow allocating transformed refs in run_scoped
PiperOrigin-RevId: 688448592
This commit is contained in:
parent ebb75db8a5
commit 84a303f32f
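What this change enables, in short: scratch references allocated with pl.run_scoped can now carry Mosaic GPU memory transforms (tiling, swizzling), just like refs declared through GPUBlockSpec. Below is a minimal usage sketch adapted from the test added at the bottom of this diff; the shapes, tile sizes, the doubling, and the import paths are taken from that test and are illustrative rather than authoritative.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# Layout transforms for the scratch buffer: tile into (64, 32) blocks,
# then apply a 128-byte swizzle.
ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))

def kernel(x_ref, o_ref, barrier_ref):
  def body(tmp_ref):
    # tmp_ref lives in SMEM with the transformed physical layout, but the
    # kernel body still indexes it with its logical (128, 128) shape.
    plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
    plgpu.barrier_wait(barrier_ref)
    o_ref[...] = tmp_ref[...] * 2
  # New in this commit: SMEM(...) accepts transforms=, and the resulting
  # (possibly transformed) ref can be allocated inside run_scoped.
  pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))

f = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
    in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
    out_specs=plgpu.GPUBlockSpec(
        (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM),
    scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), x * 2)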
@@ -200,7 +200,7 @@ class MemoryRef:
         self.shape, dtype, memory_space=self.memory_space
     )
 
-  def get_ref_aval(self) -> AbstractMemoryRef:
+  def get_ref_aval(self) -> TransformedRef | AbstractMemoryRef:
     # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
     # try to apply JAX ops to it.
     return AbstractMemoryRef(
@@ -14,6 +14,8 @@
 """Contains GPU-specific Pallas abstractions."""
 
 from __future__ import annotations
 
 import abc
 import collections
 from collections.abc import Sequence
@@ -73,9 +75,32 @@ class GPUMemorySpace(enum.Enum):
   def __str__(self) -> str:
     return self.value
 
-  def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
+  def __call__(
+      self,
+      shape: tuple[int, ...],
+      dtype: jnp.dtype,
+      transforms: Sequence[MemoryRefTransform] = (),
+  ):
     # A convenience function for constructing MemoryRef types.
-    return pallas_core.MemoryRef(shape, dtype, memory_space=self)
+    return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms)
+
+
+@dataclasses.dataclass(frozen=True)
+class GPUMemoryRef(pallas_core.MemoryRef):
+  transforms: Sequence[MemoryRefTransform] = ()
+
+  def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef:
+    aval = jax_core.ShapedArray(self.shape, self.dtype)
+    for t in self.transforms:
+      aval = t(aval)
+    ref = pallas_core.TransformedRef(
+        AbstractMemoryRef(aval, memory_space=self.memory_space), ()
+    )
+    for t in reversed(self.transforms):
+      ref = t.undo(ref)
+    if not ref.transforms:
+      return ref.ref
+    return ref
 
 
 class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
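For readers skimming get_ref_aval above: the aval is first pushed through every transform so that the allocation matches the physical (tiled, swizzled) layout, and the transforms are then undone on the resulting ref so the kernel body still addresses the logical shape. Here is a self-contained toy sketch of that round trip; it uses plain Python stand-ins rather than the real Pallas transform classes, and the (64, 32) tiling of a (128, 128) buffer is only an assumed example.

from dataclasses import dataclass

@dataclass(frozen=True)
class TileTransform:
  tile: tuple[int, int]

  def apply(self, shape):
    # Logical -> physical (tiled) shape, as applied to the aval.
    (m, n), (tm, tn) = shape, self.tile
    return (m // tm, n // tn, tm, tn)

  def undo(self, shape):
    # Physical -> logical shape, as layered back onto the ref.
    gm, gn, tm, tn = shape
    return (gm * tm, gn * tn)

logical = (128, 128)
t = TileTransform((64, 32))
physical = t.apply(logical)          # (2, 4, 64, 32): what actually gets allocated
assert t.undo(physical) == logical   # what the kernel body indexes into
print(physical)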
@@ -1382,17 +1382,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
   return new_error, results
 checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
 
-# All of those shenanigans are because we can't make TransformedRef a PyTree,
-# because they should appear as atomic JAX values to the users.
-@lu.transformation
-def wrap_with_transforms(transforms, *args):
-  new_args = tuple(
-      state_types.TransformedRef(a, t) if t else a
-      for a, t in zip(args, transforms)
-  )
-  res = yield new_args, {}
-  yield res
-
 
 @weakref_lru_cache
 def _trace_kernel_to_jaxpr(
@@ -1410,7 +1399,9 @@ def _trace_kernel_to_jaxpr(
                                      kernel_avals))
   wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(fun), kernel_in_tree)
-  wrapped_kernel_fun = wrap_with_transforms(wrapped_kernel_fun, kernel_in_transforms)
+  wrapped_kernel_fun = primitives.wrap_with_transforms(
+      wrapped_kernel_fun, kernel_in_transforms
+  )
   debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
   with grid_mapping.trace_env():
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
@@ -39,6 +39,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.pallas import core as pallas_core
 from jax._src.state import discharge as state_discharge
 from jax._src.state import indexing
 from jax._src.state import types as state_types
 from jax._src.state import primitives as sp
 from jax.interpreters import mlir
 import jax.numpy as jnp
@@ -816,6 +817,20 @@ def debug_print_lowering_rule(ctx, *args, **params):
   return result
 
 
+# All of those shenanigans are because we can't make TransformedRef a PyTree,
+# because they should appear as atomic JAX values to the users.
+# TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU
+# inferred by the compiler.
+@lu.transformation
+def wrap_with_transforms(transforms, *args):
+  new_args = tuple(
+      state_types.TransformedRef(a, t) if t else a
+      for a, t in zip(args, transforms)
+  )
+  res = yield new_args, {}
+  yield res
+
+
 run_scoped_p = jax_core.Primitive("run_scoped")
 run_scoped_p.multiple_results = True
 
@@ -829,7 +844,17 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
   """
   flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
   flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
-  avals = [t.get_ref_aval() for t in flat_types]
+  # We allow ref avals to be transformed references.
+  ref_avals = [t.get_ref_aval() for t in flat_types]
+  avals = [
+      t.ref if isinstance(t, state_types.TransformedRef) else t
+      for t in ref_avals
+  ]
+  ref_transforms = tuple(
+      t.transforms if isinstance(t, state_types.TransformedRef) else ()
+      for t in ref_avals
+  )
+  flat_fun = wrap_with_transforms(flat_fun, ref_transforms)
   # Turn the function into a jaxpr. The body of run_scoped may have
   # effects (IO) on constvars (i.e. variables inherited from the
   # parent scope). Jax can't reason about effects to references that
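The two hunks above move wrap_with_transforms next to run_scoped and teach run_scoped to split a possibly transformed ref aval into the plain ref (used for the actual allocation) and its transform list (re-applied to the argument right before the user's body runs), because TransformedRef cannot be made a PyTree. A toy sketch of that split-and-rewrap flow, using stand-in classes rather than the real jax._src.state types:

from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class ToyTransformedRef:
  ref: Any
  transforms: tuple

def split(aval):
  # Base value for allocation, transforms remembered for later re-wrapping.
  if isinstance(aval, ToyTransformedRef):
    return aval.ref, aval.transforms
  return aval, ()

def wrap_with_transforms(body, transforms):
  # Re-wrap each flat argument that had transforms before calling the body.
  def wrapped(*args):
    new_args = tuple(
        ToyTransformedRef(a, t) if t else a for a, t in zip(args, transforms)
    )
    return body(*new_args)
  return wrapped

# Usage: one "transformed" and one plain ref aval.
avals = [ToyTransformedRef("smem_buffer", ("tile", "swizzle")), "gmem_buffer"]
bases, transform_lists = zip(*(split(a) for a in avals))
body = wrap_with_transforms(lambda *refs: print(refs), transform_lists)
body(*bases)  # the body sees the transformed view again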
@@ -287,6 +287,29 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x)
 
+  def test_scoped_copy_with_transforms(self):
+    ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
+
+    def kernel(x_ref, o_ref, barrier_ref):
+      def body(tmp_ref):
+        plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
+        plgpu.barrier_wait(barrier_ref)
+        o_ref[...] = tmp_ref[...] * 2
+      pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))
+
+    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
+    out_spec = plgpu.GPUBlockSpec(
+        (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
+    )
+    f = pl.pallas_call(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
+        in_specs=(in_spec,),
+        out_specs=out_spec,
+        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
+    )
+    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
+    np.testing.assert_array_equal(f(x), x * 2)
+
   def test_copy_with_transforms_and_indexing(self):
     def kernel(x_ref, o_ref, barrier_ref):
       for i in range(2):