Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
[Pallas/MGPU] Undo transforms before giving refs back to users
This is a second attempt at this change. The first one was rolled back because of reported failures.

Reverts 411928b9668570bbc3795522aba94cece6894881

PiperOrigin-RevId: 680943744
parent: 14ef2b6a21
commit: cac2b8d5fc
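At a high level, the change makes Pallas kernels see refs with their logical shapes even when the underlying memory has been retiled or transposed: each memory-space transform now knows how to `undo` itself by pushing an inverse `Transform` onto a `TransformedRef`. The following is a rough, self-contained sketch of that idea (simplified stand-in functions, not the actual JAX classes):

```python
# Rough sketch (stand-ins, not the actual JAX classes): a tiling transform
# changes the physical block shape, and its inverse ("untile") recovers the
# logical shape the kernel should see.
def tile_shape(shape, tiling):
  # (..., d0, d1) -> (..., d0 // t0, d1 // t1, t0, t1)
  lead, tiled = shape[:-len(tiling)], shape[-len(tiling):]
  assert all(d % t == 0 for d, t in zip(tiled, tiling))
  return lead + tuple(d // t for d, t in zip(tiled, tiling)) + tiling

def untile_shape(shape, tiling):
  # Inverse of tile_shape: collapse tile counts and tile sizes back together.
  assert shape[-len(tiling):] == tiling
  lead = shape[:len(shape) - 2 * len(tiling)]
  counts = shape[len(shape) - 2 * len(tiling):-len(tiling)]
  return lead + tuple(n * t for n, t in zip(counts, tiling))

physical = tile_shape((128, 128), (64, 64))   # (2, 2, 64, 64): what lives in SMEM
logical = untile_shape(physical, (64, 64))    # (128, 128): what the kernel sees
assert logical == (128, 128)
```

Before this change the kernel was handed a ref with the physical `(2, 2, 64, 64)` shape; after it, the kernel receives a `TransformedRef` that still reports `(128, 128)` and carries the inverse transforms for the lowering to consume.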
@@ -40,6 +40,7 @@ from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
from jax._src.state.types import TransformedRef
import jax.numpy as jnp
@@ -496,7 +497,7 @@ class BlockSpec:
    mapping = BlockMapping(
        block_shape=mapped_block_shape,
        block_aval=block_aval,
        transformed_block_aval=block_aval,  # There are no transforms by default
        index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
        index_map_src_info=index_map_src_info,
        indexing_mode=self.indexing_mode,
@@ -523,7 +524,7 @@ BlockSpecTree = Any
class MemoryRefTransform(Protocol):
  """Transforms a memory reference on load or store."""

  def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef:
  def undo(self, ref: TransformedRef) -> TransformedRef:
    raise NotImplementedError("Abstract evaluation not implemented.")
@@ -533,8 +534,10 @@ class BlockMapping:

  See the `check_invariants` method for precise specification.
  """
  # TODO(apaszke,sharadmv): Replace mapped dims in block_shape with a transform.
  # After all, it's just indexing out singleton dimensions.
  block_shape: tuple[Mapped | int, ...]
  block_aval: AbstractMemoryRef  # The block ref aval
  transformed_block_aval: AbstractMemoryRef
  index_map_jaxpr: jax_core.ClosedJaxpr
  index_map_src_info: NameAndSrcInfo
  indexing_mode: IndexingMode
@@ -546,8 +549,8 @@ class BlockMapping:
    if not config.enable_checks.value: return

    unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped)
    assert unmapped_block_shape == self.block_aval.shape, (
        self.block_shape, self.block_aval)
    assert unmapped_block_shape == self.ref_aval.shape, (
        self.block_shape, self.ref_aval.shape)
    assert len(self.block_shape) == len(self.array_shape_dtype.shape), (
        self.block_shape, self.array_shape_dtype
    )
@@ -568,12 +571,21 @@ class BlockMapping:
    return new_self

  @property
  def ref_aval(self) -> AbstractMemoryRef:
  def block_aval(self) -> AbstractMemoryRef:
    # If you hit this, make sure you take transforms into account and use either
    # ref_aval or transformed_block_aval.
    assert not self.transforms, "Lowering failed to handle transforms"
    return self.transformed_block_aval

  @property
  def ref_aval(self) -> AbstractMemoryRef | TransformedRef:
    """Returns the abstract value of the Ref after transformations."""
    block_aval = self.block_aval
    for transform in self.transforms:
      block_aval = transform(block_aval)
    return block_aval
    if not self.transforms:
      return self.transformed_block_aval
    ref = TransformedRef(self.transformed_block_aval, ())
    for transform in reversed(self.transforms):
      ref = transform.undo(ref)
    return ref

  def compute_start_indices_interpret(self, loop_idx, *args):
    discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
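A simplified sketch of what the new `ref_aval` computes (stand-in classes, not the real `jax._src.state` types): the transforms that produced the physical block are undone in reverse order of application, leaving a `TransformedRef` that presents the logical view.

```python
# Simplified sketch of the new ref_aval logic; Ref/Transform classes here are
# stand-ins, not the actual jax._src.state types.
import dataclasses
from typing import Any

@dataclasses.dataclass(frozen=True)
class TransformedRef:
  ref: Any
  transforms: tuple[Any, ...]

@dataclasses.dataclass(frozen=True)
class FakeMemoryRefTransform:
  name: str

  def undo(self, ref: TransformedRef) -> TransformedRef:
    # Each forward transform knows how to push its inverse onto the stack.
    return dataclasses.replace(
        ref, transforms=(*ref.transforms, f"inverse({self.name})")
    )

def ref_aval(transformed_block_aval, transforms):
  # Mirrors BlockMapping.ref_aval after this change: with no transforms the
  # physical aval is already the logical one; otherwise wrap it and undo the
  # transforms in reverse order.
  if not transforms:
    return transformed_block_aval
  ref = TransformedRef(transformed_block_aval, ())
  for transform in reversed(transforms):
    ref = transform.undo(ref)
  return ref

stack = (FakeMemoryRefTransform("tile"), FakeMemoryRefTransform("transpose"))
print(ref_aval("physical_aval", stack).transforms)
# -> ('inverse(transpose)', 'inverse(tile)')
```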
@@ -14,14 +14,16 @@

"""Contains GPU-specific Pallas abstractions."""

import abc
from collections.abc import Sequence
import dataclasses
import enum
from typing import Any, ClassVar, Literal, Protocol
from typing import Any, ClassVar, Literal

from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import tree_util
from jax._src.state.types import Transform
from jax._src.pallas import core as pallas_core
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
@@ -63,9 +65,15 @@ class GPUMemorySpace(enum.Enum):
    return pallas_core.MemoryRef(shape, dtype, memory_space=self)


class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
  @abc.abstractmethod
  def to_gpu_transform(self) -> mgpu.MemRefTransform:
    ...
    pass

  def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
    return aval.update(
        shape=self.to_gpu_transform().transform_shape(aval.shape)
    )


@dataclasses.dataclass(frozen=True)
@@ -79,52 +87,67 @@ class TilingTransform(MemoryRefTransform):

  tiling: tuple[int, ...]

  def __call__(
      self, block_aval: pallas_core.AbstractMemoryRef
  ) -> pallas_core.AbstractMemoryRef:
    block_shape = block_aval.shape
    old_tiled_dims = block_shape[-len(self.tiling) :]
    num_tiles = tuple(
        block_dim // tiling_dim
        for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
    )
    rem = (
        block_dim % tiling_dim
        for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
    )
    if any(rem):
      raise ValueError(
          f"Block shape {block_shape} is not divisible by tiling {self.tiling}"
      )
    new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling
    return block_aval.update(
        inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
  def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
    return dataclasses.replace(
        ref, transforms=(*ref.transforms, UntileRef(self.tiling))
    )

  def to_gpu_transform(self) -> mgpu.MemRefTransform:
    return mgpu.TileTransform(self.tiling)


@dataclasses.dataclass(frozen=True)
class UntileRef(Transform):
  tiling: tuple[int, ...]

  def transform_shape(self, shape):
    if shape is None:
      return None
    assert shape[-len(self.tiling) :] == self.tiling
    shape = shape[: -len(self.tiling)]  # Drop tiling
    return shape[: -len(self.tiling)] + tuple(
        block_dim * tiling_dim
        for block_dim, tiling_dim in zip(shape[-len(self.tiling) :], self.tiling)
    )

  def transform_dtype(self, dtype):
    return dtype


@dataclasses.dataclass(frozen=True)
class TransposeTransform(MemoryRefTransform):
  """Transpose a tiled memref."""

  permutation: tuple[int, ...]

  def __call__(
      self, block_aval: pallas_core.AbstractMemoryRef
  ) -> pallas_core.AbstractMemoryRef:
    shape = block_aval.shape  # pytype: disable=attribute-error
    return block_aval.update(
        inner_aval=block_aval.inner_aval.update(
            shape=self.to_gpu_transform().transform_shape(shape)
        )
  def __post_init__(self):
    if set(self.permutation) != set(range(len(self.permutation))):
      raise ValueError(f"Permutation {self.permutation} is not a permutation.")

  def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
    inverse = [-1] * len(self.permutation)
    for i, p in enumerate(self.permutation):
      inverse[p] = i
    return dataclasses.replace(
        ref, transforms=(*ref.transforms, TransposeRef(tuple(inverse)))
    )

  def to_gpu_transform(self) -> mgpu.MemRefTransform:
    return mgpu.TransposeTransform(self.permutation)


@dataclasses.dataclass(frozen=True)
class TransposeRef(Transform):
  permutation: tuple[int, ...]

  def transform_shape(self, shape):
    if shape is None:
      return None
    return tuple(shape[i] for i in self.permutation)

  def transform_dtype(self, dtype):
    return dtype


@dataclasses.dataclass(frozen=True)
class GPUBlockMapping(pallas_core.BlockMapping):
  swizzle: int | None = None
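The inverse permutation built in `TransposeTransform.undo` can be sanity-checked in isolation; the helper name below is hypothetical:

```python
def invert_permutation(permutation: tuple[int, ...]) -> tuple[int, ...]:
  # Same construction as TransposeTransform.undo: position i of the input ends
  # up at position permutation[i], so the inverse maps it back.
  inverse = [-1] * len(permutation)
  for i, p in enumerate(permutation):
    inverse[p] = i
  return tuple(inverse)

perm = (2, 0, 1)
shape = ("a", "b", "c")
transposed = tuple(shape[i] for i in perm)                        # ('c', 'a', 'b')
restored = tuple(transposed[i] for i in invert_permutation(perm))  # ('a', 'b', 'c')
assert restored == shape
```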
@@ -156,9 +179,14 @@ class GPUBlockSpec(pallas_core.BlockSpec):
    transforms = self.transforms
    if not isinstance(transforms, tuple):
      transforms = (transforms,)
    block_inner_aval = bm.block_aval.inner_aval
    for t in transforms:
      block_inner_aval = t(block_inner_aval)
    return GPUBlockMapping(
        block_shape=bm.block_shape,
        block_aval=bm.block_aval,
        transformed_block_aval=bm.block_aval.update(
            inner_aval=block_inner_aval
        ),
        origin=bm.origin,
        index_map_jaxpr=bm.index_map_jaxpr,
        index_map_src_info=bm.index_map_src_info,
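A shape-only walk-through of how `GPUBlockSpec` folds its transforms to obtain `transformed_block_aval` (the block size and transform parameters below are assumed for illustration, not taken from the diff):

```python
# Shape-only sketch of folding GPU transforms over a block, as GPUBlockSpec
# does when it builds transformed_block_aval. Numbers are illustrative.
def tile(shape, tiling):
  lead, tiled = shape[:-len(tiling)], shape[-len(tiling):]
  return lead + tuple(d // t for d, t in zip(tiled, tiling)) + tiling

def transpose(shape, permutation):
  return tuple(shape[i] for i in permutation)

block = (128, 64)
after_tiling = tile(block, (64, 64))                      # (2, 1, 64, 64)
after_transpose = transpose(after_tiling, (1, 0, 2, 3))   # (1, 2, 64, 64)
print(after_transpose)
```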
@@ -280,7 +280,7 @@ def lower_jaxpr_to_module(

  in_in_smem, out_in_smem = util.split_list(
      [
          bm.block_aval.memory_space in (None, gpu_core.SMEM)
          bm.transformed_block_aval.memory_space in (None, gpu_core.SMEM)
          for bm in block_mappings
      ],
      [grid_mapping.num_inputs],
@@ -290,9 +290,13 @@ def lower_jaxpr_to_module(
  in_block_mappings, out_block_mappings = util.split_list(
      block_mappings, [grid_mapping.num_inputs]
  )
  # TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps.
  # We allocate the fully transformed shapes here. All primitives have seen the
  # inverse transformation stack and will understand how to handle it.
  in_structs_smem = [
      jax.ShapeDtypeStruct(
          [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
          [max_concurrent_steps, *bm.transformed_block_aval.shape],
          bm.transformed_block_aval.dtype,
      )
      if in_smem
      else None
@@ -312,6 +316,9 @@ def lower_jaxpr_to_module(
  )
  out_structs_gmem = [*grid_mapping.out_shapes]
  # TODO(justinfu): Implement output Memref transforms
  for bm in block_mappings[grid_mapping.num_inputs :]:
    if bm.transforms:
      raise NotImplementedError("Output transforms are not supported")
  out_structs_smem = [
      jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
      if in_smem
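The SMEM staging buffers are therefore sized from the transformed block shape with one leading dimension for the number of in-flight pipeline steps; roughly (a sketch with assumed values):

```python
# Sketch: SMEM staging buffers get a leading dimension for the number of
# in-flight pipeline steps, and otherwise use the *transformed* block shape.
import jax
import jax.numpy as jnp

max_concurrent_steps = 2
transformed_block_shape = (2, 1, 64, 64)  # e.g. a (128, 64) block tiled into 64x64 tiles
smem_struct = jax.ShapeDtypeStruct(
    (max_concurrent_steps, *transformed_block_shape), jnp.float16
)
print(smem_struct.shape)  # (2, 2, 1, 64, 64)
```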
@@ -119,7 +119,7 @@ effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect)
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
def wgmma(acc, a, b, *, swizzle: int = 128):
  """Asynchronous warp group matmul.

  The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
@@ -129,24 +129,49 @@ def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
    acc: The accumulator register.
    a: The left hand side operand.
    b: The right hand side operand.
    transpose: Whether to transpose b.
    n_tile: The number of tiles to use.
    swizzle: The swizzle pattern.
  """
  if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
    raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

  ma, ka, tma, tka = a.shape
  kb, nb, tkb, tnb = b.shape
  mc, nc = acc.shape
  # TODO(apaszke): Make swizzling another transform and read it from the refs.
  if not isinstance(a, pallas_core.TransformedRef):
    raise ValueError("WGMMA inputs must be tiled references.")

  if rhs_transpose:
    kb, nb, tkb, tnb = nb, kb, tnb, tkb
  m, n = acc.shape
  m2, k = a.shape
  k2, n2 = b.shape

  if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
    raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")
  if m != m2 or n != n2 or k != k2:
    raise ValueError(
        f"Incompatible shapes for matrix multiplication: lhs={a.shape},"
        f" rhs={b.shape=}, acc={acc.shape}"
    )

  return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
  if (dtype := a.dtype) != b.dtype:
    raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")
  if not isinstance(a, pallas_core.TransformedRef):
    raise ValueError("WGMMA lhs must be a tiled reference.")
  if not isinstance(b, pallas_core.TransformedRef):
    raise ValueError("WGMMA rhs must be a tiled reference.")

  elems_128b = swizzle // dtype.itemsize
  if a.transforms != (gpu_core.UntileRef((64, elems_128b)),):
    raise ValueError(
        f"WGMMA lhs must be tiled with 64x{elems_128b} tiles for element type"
        f" {dtype}."
    )
  rhs_transpose_transform = gpu_core.TransposeRef((1, 0, 2, 3))
  rhs_tiling = gpu_core.UntileRef((elems_128b, elems_128b))
  if not (
      rhs_transpose := (b.transforms == (rhs_transpose_transform, rhs_tiling))
  ) and not (b.transforms == (rhs_tiling,)):
    raise ValueError(
        f"WGMMA rhs must be tiled with {elems_128b}x{elems_128b} tiles for"
        f" element type {dtype} (and optionally transposed)."
    )

  return wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose)


@wgmma_ref_p.def_effectful_abstract_eval
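With the `rhs_transpose` argument gone, the transpose is inferred from the transform stack carried by the rhs ref. A stand-alone sketch of that inference (plain tuples stand in for the transform objects):

```python
# Sketch of the new rhs handling in wgmma: the transpose is no longer a flag,
# it is read off the transform stack of the rhs reference.
def infer_rhs_transpose(rhs_transforms, elems_128b):
  rhs_transpose_transform = ("transpose", (1, 0, 2, 3))
  rhs_tiling = ("untile", (elems_128b, elems_128b))
  if rhs_transforms == (rhs_transpose_transform, rhs_tiling):
    return True
  if rhs_transforms == (rhs_tiling,):
    return False
  raise ValueError("rhs must be tiled (and optionally transposed)")

# float16 with swizzle=128 gives 64-element tiles: 128 // 2 == 64.
print(infer_rhs_transpose((("untile", (64, 64)),), 64))  # False
```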
@@ -39,6 +39,7 @@ from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (
    safe_map,
    safe_zip,
@@ -941,6 +942,7 @@ def _pallas_call_batching_rule(
      batched_grid_mapping,
      tuple(flat_kernel_avals),
      kernel_in_tree,
      tuple(() for _ in flat_kernel_avals),
      interpret=interpret,
  )
@@ -1136,12 +1138,27 @@ def pallas_call_checkify_rule(error: checkify.Error,
  return new_error, results
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule


# All of those shenanigans are because we can't make TransformedRef a PyTree,
# because they should appear as atomic JAX values to the users.
@lu.transformation
def wrap_with_transforms(transforms, *args):
  new_args = tuple(
      state_types.TransformedRef(a, t) if t else a
      for a, t in zip(args, transforms)
  )
  res = yield new_args, {}
  yield res


@weakref_lru_cache
def _trace_kernel_to_jaxpr(fun: Callable,
def _trace_kernel_to_jaxpr(
    fun: Callable,
    name_and_src_info: pallas_core.NameAndSrcInfo,
    grid_mapping: GridMapping,
    kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
    kernel_in_tree: tree_util.PyTreeDef,
    kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
    interpret: bool,
) -> jax_core.ClosedJaxpr:
  if interpret:
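A plain-function sketch of what `wrap_with_transforms` does (the real version is an `lu.transformation` generator): flat kernel arguments whose transform stack is non-empty are re-wrapped as `TransformedRef` before the kernel body sees them.

```python
# Plain-function stand-in for wrap_with_transforms; TransformedRef here is a
# simplified stand-in, not the actual jax._src.state type.
import dataclasses
from typing import Any

@dataclasses.dataclass(frozen=True)
class TransformedRef:
  ref: Any
  transforms: tuple[Any, ...]

def wrap_with_transforms(kernel, transforms):
  def wrapped(*args):
    new_args = tuple(
        TransformedRef(a, t) if t else a for a, t in zip(args, transforms)
    )
    return kernel(*new_args)
  return wrapped

kernel = lambda a, b: (a, b)
wrapped = wrap_with_transforms(kernel, (("tile",), ()))
print(wrapped("ref_a", "ref_b"))
# -> (TransformedRef(ref='ref_a', transforms=('tile',)), 'ref_b')
```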
@@ -1149,6 +1166,7 @@ def _trace_kernel_to_jaxpr(fun: Callable,
                                      kernel_avals))
  wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun), kernel_in_tree)
  wrapped_kernel_fun = wrap_with_transforms(wrapped_kernel_fun, kernel_in_transforms)
  debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
  with grid_mapping.trace_env():
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
|
||||
for p in in_paths)
|
||||
out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)
|
||||
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
|
||||
kernel_avals, grid_mapping = pallas_core.get_grid_mapping(
|
||||
kernel_args, grid_mapping = pallas_core.get_grid_mapping(
|
||||
grid_spec,
|
||||
flat_in_avals, in_tree, in_origins,
|
||||
flat_out_avals, out_tree, out_origins)
|
||||
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
|
||||
flat_kernel_args, kernel_in_tree = tree_util.tree_flatten(kernel_args)
|
||||
flat_kernel_avals = tuple(
|
||||
x.ref if isinstance(x, state_types.TransformedRef) else x
|
||||
for x in flat_kernel_args
|
||||
)
|
||||
# Note that only a subset of all transforms can be found here, and they are
|
||||
# never expected to contains any arrays.
|
||||
kernel_arg_transforms = tuple(
|
||||
x.transforms if isinstance(x, state_types.TransformedRef) else ()
|
||||
for x in flat_kernel_args
|
||||
)
|
||||
with pallas_core.interpret_mode_env(interpret):
|
||||
jaxpr = _trace_kernel_to_jaxpr(
|
||||
kernel, kernel_src_info,
|
||||
grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
|
||||
interpret=interpret)
|
||||
kernel, kernel_src_info, grid_mapping, tuple(flat_kernel_avals),
|
||||
kernel_in_tree, kernel_arg_transforms, interpret=interpret)
|
||||
for i_idx, o_idx in input_output_aliases.items():
|
||||
if i_idx not in range(len(flat_in_avals)):
|
||||
raise ValueError(
|
||||
|
@@ -256,3 +256,10 @@ class NDIndexer:
    # In NDIndexers, the int_indexer_shape is *always* at the front of the
    # result.
    return (*self.int_indexer_shape, *slice_shape)

  def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
    del shape  # Unused
    return self.get_indexer_shape()

  def transform_dtype(self, dtype):
    return dtype
@@ -18,8 +18,9 @@ from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import math
from typing import Any, Union
from typing import Any, Union, Protocol

from jax._src.typing import DTypeLike
from jax._src import core
from jax._src import dtypes
from jax._src import effects
@@ -105,8 +106,39 @@ class RefBitcaster:
    assert not arrays
    return cls(*metadata)

  def transform_shape(
      self, shape: tuple[int | Array, ...] | None
  ) -> tuple[int | Array, ...] | None:
    del shape  # Unused
    return self.shape

  def transform_dtype(self, dtype):
    del dtype  # Unused
    return self.dtype


class Transform(Protocol):

  def transform_shape(
      self, shape: tuple[int | Array, ...] | None
  ) -> tuple[int | Array, ...] | None:
    """Transform the shape.

    Can return None if the input shape is not known, but must return a concrete
    result when the input shape is known.
    """
    return shape

  def transform_dtype(
      self, dtype: DTypeLike | None
  ) -> DTypeLike | None:
    """Transform the dtype.

    Can return None if the input dtype is not known, but must return a concrete
    result when the input dtype is known.
    """
    return dtype

Transform = indexing.NDIndexer | RefBitcaster


@dataclasses.dataclass
class RefIndexer:
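Any object with these two methods satisfies the new `Transform` protocol, and it may be queried before the shape or dtype is known. A minimal conforming example (the `CastTo32` name is hypothetical):

```python
# Minimal object satisfying the new Transform protocol: it may be asked about
# shape/dtype before either is known.
class CastTo32:  # hypothetical example transform
  def transform_shape(self, shape):
    return shape  # does not touch the shape (None stays None)

  def transform_dtype(self, dtype):
    del dtype  # the result dtype is fixed regardless of the input
    return "float32"

t = CastTo32()
print(t.transform_shape((8, 128)), t.transform_dtype(None))  # (8, 128) float32
```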
@@ -122,30 +154,51 @@ class RefIndexer:
    return TransformedRef(self.ref_or_view, (indexer,))


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class TransformedRef:
  ref: Any
  transforms: tuple[Transform, ...]

  @property
  def is_dynamic_size(self):
    return self.transforms[-1].is_dynamic_size
    return any(not isinstance(i, int) for i in self.shape)

  @property
  def shape(self) -> tuple[int | Array, ...]:
    assert (
        len(self.transforms) > 0
    ), "Should not be able to create a trivial TransformedRef"
    if isinstance(self.transforms[-1], indexing.NDIndexer):
      return self.transforms[-1].get_indexer_shape()
    return self.transforms[-1].shape
    unprocessed, shape = 0, None
    # We first go backwards to find the first transform that knows its output
    # shape. It's possible none of them do!
    for unprocessed, t in enumerate(reversed(self.transforms), 1):
      if (shape := t.transform_shape(None)) is not None:
        unprocessed -= 1
        break
    if shape is None:
      shape = self.ref.shape
    if not unprocessed:
      return shape
    # If there are any unprocessed transforms left, we apply them to the shape
    # we've found previously.
    for t in self.transforms[-unprocessed:]:
      shape = t.transform_shape(shape)
      assert shape is not None
    return shape

  @property
  def dtype(self):
    for transform in reversed(self.transforms):
      if isinstance(transform, RefBitcaster):
        return transform.dtype
    return self.ref.dtype
    # The structure of this method is analogous to `shape`. See comments there.
    unprocessed, dtype = 0, None
    for unprocessed, t in enumerate(reversed(self.transforms), 1):
      if (dtype := t.transform_dtype(None)) is not None:
        unprocessed -= 1
        break
    if dtype is None:
      dtype = self.ref.dtype
    if not unprocessed:
      return dtype
    for t in self.transforms[-unprocessed:]:
      dtype = t.transform_dtype(dtype)
      assert dtype is not None
    return dtype

  @property
  def at(self) -> RefIndexer:
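The backward-then-forward resolution used by the new `shape` (and `dtype`) property can be exercised on its own; the sketch below uses plain callables as stand-in transforms:

```python
# Sketch of the shape resolution in TransformedRef.shape: walk from the last
# transform towards the first until one knows its output shape without help,
# then re-apply the remaining (later) transforms forward.
def resolve_shape(ref_shape, transforms):
  unprocessed, shape = 0, None
  for unprocessed, t in enumerate(reversed(transforms), 1):
    if (shape := t(None)) is not None:
      unprocessed -= 1
      break
  if shape is None:
    shape = ref_shape
  if not unprocessed:
    return shape
  for t in transforms[-unprocessed:]:
    shape = t(shape)
  return shape

# Stand-ins: an indexer knows its output shape outright; the second transform
# only knows how to rewrite a shape it is given.
index_first = lambda shape: (64, 64)
double_last = lambda shape: None if shape is None else (*shape[:-1], 2 * shape[-1])

print(resolve_shape((4, 64, 64), (index_first, double_last)))  # (64, 128)
print(resolve_shape((4, 64, 64), (double_last,)))              # (4, 64, 128)
```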
@@ -404,17 +404,13 @@ class PallasCallTest(PallasTest):
            swizzle=128,
        ),
        ],
        out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)),
        out_shape=jax.ShapeDtypeStruct((4, 2, 64, 64), jnp.float16),
        out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float16),
        grid=(2, 2),
    )
    def kernel(x_ref, o_ref):
      assert x_ref.shape == (2, 1, 64, 64), x_ref.shape
      o_ref[...] = x_ref[...]
      assert x_ref.shape == (128, 64), x_ref.shape

    x = jnp.zeros((256, 128), dtype=jnp.float16)
    result = kernel(x)
    self.assertEqual(result.shape, (4, 2, 64, 64))
    kernel.lower(jax.ShapeDtypeStruct((256, 128), jnp.float16))

  def test_fori_loop_array(self):
    @functools.partial(
@@ -473,7 +469,7 @@ class PallasCallTest(PallasTest):
    elems_128b = swizzle // jnp.dtype(dtype).itemsize
    def kernel(a_ref, b_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=rhs_transpose)
        plgpu.wgmma(acc_ref, a_ref, b_ref)
        return acc_ref[...]

      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
@@ -534,6 +530,10 @@ class PallasCallTest(PallasTest):
    tile_k = elems_128b
    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
    def kernel(a_ref, b_ref, o_ref, acc_ref):
      # Make sure tiling does not alter the shape of references
      assert a_ref.shape == (tile_m, tile_k)
      assert b_ref.shape == (tile_k, tile_n)
      assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
      plgpu.wgmma(acc_ref, a_ref, b_ref)
      plgpu.wgmma_wait(0)  # TODO(apaszke): Delay the pipeline to avoid memory races
      # TODO(apaszke): Only store in the last step. It doesn't work because we