mirror of https://github.com/ROCm/jax.git
Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)
The plan here is to load it up with invariants and start with a really simple kernel. After that, we can slowly relax the various invariants and implement support for others. Note: the work saving here is compute only, not memory yet. A fast-follow-up CL will add memory savings via index-map rewriting.

PiperOrigin-RevId: 663752447
This commit is contained in:
parent b6306e3953
commit 24394a1b03
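For orientation, a minimal sketch of the usage this change enables, adapted from the
pallas_jumble_test.py test added in this commit (TPU or interpret mode, with
jax_dynamic_shapes enabled); the flags and internal imports follow that test:

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import core
from jax._src.interpreters import batching
from jax.experimental import pallas as pl

col_grid_size = 5
# Ragged row lengths, bounded above by col_grid_size * 128 via a bint dtype.
sizes = lax.convert_element_type(
    jnp.array([128 * x for x in [3, 1, 4]]),
    core.bint(col_grid_size * 128),
)
# A jumble: a batch of (8, n) arrays whose trailing size n varies per element.
x = jax.vmap(lambda n: jnp.ones((8, n)), out_axes=batching.jumble_axis)(sizes)

def kernel(x_ref, o_ref):
  o_ref[...] = jnp.sin(x_ref[...])

def invoke_kernel(x):
  return pl.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))],
      out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
      out_shape=jax.ShapeDtypeStruct((8, col_grid_size * 128), jnp.float32),
      grid=(1, col_grid_size),
  )(x)

# vmap over the jumble: grid steps past each element's ragged length are
# skipped (compute savings only; memory savings are left to a follow-up).
res = jax.vmap(
    invoke_kernel,
    in_axes=batching.jumble_axis,
    out_axes=batching.jumble_axis,
    axis_size=3,
)(x)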
@@ -1954,6 +1954,7 @@ class DArray:
    assert data.shape == pad_shape
    self._aval = aval
    self._data = data

  shape = property(lambda self: self._aval.shape)
  dtype = property(lambda self: self._aval.dtype)
  aval = property(lambda self: self._aval)
@@ -1964,21 +1965,38 @@ class DArray:

    dtypestr = _short_dtype_name(self._aval.dtype)
    shapestr = ','.join(map(str, self.shape))
    slices = tuple(slice(int(d._data)) if type(d) is DArray and
                   type(d.dtype) is bint else slice(None) for d in self.shape)
    data = self._data[slices]
    data = self.data
    return f'{dtypestr}[{shapestr}] with value: {data}'

  def __hash__(self) -> int:
    if not self.shape:
      return hash((self._aval, int(self._data)))
    raise TypeError("unhashable type: DArray")

  def __eq__(self, other):
    if isinstance(other, DArray) and self._aval == other._aval:
      return self._data == other._data
    return False

  def __len__(self):
    return self.shape[0]

  @property
  def data(self):
    if not self.shape and type(self.dtype) is bint:
      # special-case scalar bints
      return self._data

    slices = tuple(
        slice(int(d._data))
        if type(d) is DArray and type(d.dtype) is bint
        else slice(None)
        for d in self.shape
    )
    data = self._data[slices]
    return data


pytype_aval_mappings[DArray] = \
    lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
                             x._data)
@@ -88,6 +88,7 @@ def _jumble_flatten(jumble):
  elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape))
  aval = jumble.aval.replace(elt_ty=elt_ty)
  return (lengths, jumble.data), aval

def _jumble_unflatten(aval, x):
  lengths, data = x
  new_shape = [d.replace(lengths=lengths[d.lengths - 1])
@@ -251,7 +252,10 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
    return (BatchTracer(trace, x, spec, source_info_util.current())
            if spec is not None else x)
  else:
    assert False
    # TODO(mvoz): This is a terrible place to fall into if you pass
    # a non jumble type in, make it clearer what went wrong.
    assert False, f'Unexpected type in ELT? {type(x)}'

to_elt_handlers: dict[type, ToEltHandler] = {}

def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
@@ -112,7 +112,10 @@ class AbstractMemoryRef(state.AbstractRef):

  def __init__(self, inner_aval: jax_core.AbstractValue,
               memory_space: Any):
    assert isinstance(inner_aval, jax_core.ShapedArray)

    assert isinstance(
        inner_aval, jax_core.ShapedArray
    ), f"Illegal ref, got {type(inner_aval)}"
    self.inner_aval = inner_aval
    self.memory_space = memory_space

@@ -167,9 +170,7 @@ class PallasGridContext:
  mapped_dims: tuple[int, ...]

  def size(self, axis: int) -> int | DynamicGridDim:
    valid_grid = tuple(
        s for i, s in enumerate(self.grid) if i not in self.mapped_dims
    )
    valid_grid = tuple(self.grid)
    try:
      size = valid_grid[axis]
    except IndexError as e:
@@ -338,7 +339,10 @@ class BlockMapping:
    )

    assert not self.index_map_jaxpr.consts
    assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals)
    assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), (
        self.block_shape,
        self.index_map_jaxpr.out_avals,
    )
    assert all(ov.shape == () and
               (ov.dtype == jnp.int32 or ov.dtype == jnp.int64)
               for ov in self.index_map_jaxpr.out_avals), (
@@ -422,6 +426,8 @@ class GridMapping:
  num_inputs: int
  num_outputs: int
  num_scratch_operands: int
  get_grid_indices: Callable | None = None
  local_grid_env: Callable | None = None

  def check_invariants(self) -> None:
    if not config.enable_checks.value: return
@@ -442,8 +448,8 @@ class GridMapping:
    assert len(index_map_args) >= len(self.grid)
    for i in range(len(self.grid)):
      index_map_arg = index_map_args[i]
      assert index_map_arg.shape == ()
      assert index_map_arg.dtype == jnp.int32
      assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}"
      assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}"

    assert len(self.vmapped_dims) <= len(self.grid)
    for i in self.vmapped_dims:
@@ -454,8 +460,11 @@ class GridMapping:

    for bm in self.block_mappings:
      bm.check_invariants()
      assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), (
      assert tuple(self.index_map_avals) == tuple(
          bm.index_map_jaxpr.in_avals
      ), (
          self.index_map_avals,
          "|",
          bm.index_map_jaxpr.in_avals,
      )

@@ -547,6 +556,17 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
    return True
  return jax_core.is_dim(dim)


def _max_shape_from_aval(array_aval: jax_core.ShapedArray):
  array_aval_shape = list(array_aval.shape)
  for i, s in enumerate(array_aval.shape):
    aval = jax_core.get_aval(s)
    if isinstance(aval, jax_core.DShapedArray):
      array_aval_shape[i] = aval.dtype.bound

  return tuple(array_aval_shape)


def _convert_block_spec_to_block_mapping(
    block_spec: BlockSpec,
    origin: OriginStr,
@@ -575,8 +595,15 @@ def _convert_block_spec_to_block_mapping(
        f"array shape {array_aval.shape}.")

  unmapped_block_shape = tuple(s for s in block_shape if s is not None)
  block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape),
                                 block_spec.memory_space)
  block_array_aval = array_aval.update(shape=unmapped_block_shape)
  if isinstance(array_aval, jax_core.DShapedArray):
    # Get the "max" shape for the ragged array.
    block_array_aval = jax_core.ShapedArray(
        block_array_aval.shape,
        block_array_aval.dtype,
        block_array_aval.weak_type,
    )
  block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space)

  if not jax_core.is_constant_shape(block_aval.shape):
    raise ValueError(
@@ -609,12 +636,12 @@ def _convert_block_spec_to_block_mapping(
          f"{origin} must return integer scalars. Output[{i}] has type "
          f"{ov}.")


  if consts:
    raise ValueError(
        f"Index map function {index_map_src_info} for "
        f"{origin} must not capture constants: {consts}")

  array_aval_shape = _max_shape_from_aval(array_aval)

  mapping = BlockMapping(
      block_shape=mapped_block_shape,
@@ -622,7 +649,9 @@ def _convert_block_spec_to_block_mapping(
      index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
      index_map_src_info=index_map_src_info,
      indexing_mode=block_spec.indexing_mode,
      array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype),
      array_shape_dtype=jax.ShapeDtypeStruct(
          array_aval_shape, array_aval.dtype
      ),
      origin=origin,
  )
  mapping.check_invariants()
@@ -298,6 +298,7 @@ class MosaicGridMapping:
    self.jaxpr = jaxpr
    self.block_mappings = grid_mapping.block_mappings
    self.mapped_dims = grid_mapping.vmapped_dims
    # TODO(mvoz): Generalize to not need this
    user_grid = tuple(
        g for i, g in enumerate(self.grid) if i not in self.mapped_dims
    )
@@ -345,9 +346,19 @@ class MosaicGridMapping:
        for _ in range(len(self.grid))
    ])
    self._prepare_mesh_info(mesh)
    def _get_grid_indices(indices):
      return indices
    self.get_grid_indices = _get_grid_indices

    if grid_mapping.get_grid_indices is None:

      def _get_grid_indices(indices, maybe_include_mapped_dims: bool):
        if maybe_include_mapped_dims:
          return indices
        return tuple(
            idx for i, idx in enumerate(indices) if i not in self.mapped_dims
        )

      self.get_grid_indices = _get_grid_indices
    else:
      self.get_grid_indices = grid_mapping.get_grid_indices

  def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
    if not self.has_communication:
@@ -595,7 +606,9 @@ def lower_jaxpr_to_transform_func(
  ]
  def body_func(*args):
    grid_indices, scalar_prefetch = split_list(args, [num_grid])
    jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices)
    jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
        grid_indices, maybe_include_mapped_dims=True
    )
    arg_block_shapes = [
        *[()] * len(jaxpr_indices),
        *mosaic_grid_mapping.scalar_prefetch_block_shapes,
@@ -663,9 +676,9 @@ def lower_jaxpr_to_func(
  def body_func(*args):
    grid_indices, scalar_prefetch, operands_and_scratch = split_list(
        args, [num_grid, num_scalar_prefetch])
    grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices)
    jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices)
                          if i not in mosaic_grid_mapping.mapped_dims)
    jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
        grid_indices, maybe_include_mapped_dims=False
    )
    mesh_info = mosaic_grid_mapping.mesh_info
    if mesh_info is not None:
      mesh_context = MeshContext(
@@ -2365,6 +2378,7 @@ lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule


def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):

  if ctx.lowering_context.user_grid_indices is None:
    raise ValueError(
        f"program id: {axis} was passed, but user did not provide a grid."
@@ -228,6 +228,12 @@ def _pallas_call_impl_interpret(
  # Pad values to evenly divide into block dimensions. This matches the
  # behavior of the non-interpret mode. We pad with NaN, to make it easier
  # to catch OOB accesses.
  for carry_element in carry:
    aval = carry_element.aval
    if isinstance(aval, jax_core.DShapedArray):
      aval = jax_core.ShapedArray(aval.shape, aval.dtype)
      carry_element.aval = aval

  carry = map(_pad_values_to_block_dimension, carry, block_shapes)
  carry.extend(scratch_values)

@@ -247,11 +253,16 @@ def _pallas_call_impl_interpret(
    return i < num_iterations
  def body(carry):
    i, loop_idx, *carry_blocks = carry
    local_grid_env = tuple(
        pallas_core.GridAxis(idx, b)
        for dim, (idx, b) in enumerate(zip(loop_idx, grid))
        if dim not in grid_mapping.vmapped_dims
    )

    if grid_mapping.local_grid_env is not None:
      local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
    else:
      local_grid_env = tuple(
          pallas_core.GridAxis(idx, b)
          for dim, (idx, b) in enumerate(zip(loop_idx, grid))
          if dim not in grid_mapping.vmapped_dims
      )

    carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks])
    with pallas_core.grid_env(local_grid_env):
      start_indices = [
@@ -268,8 +279,14 @@ def _pallas_call_impl_interpret(
        len(blocks),
        len(scratch_values),
    )
    blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
                                 *blocks, *scratch)
    for s in scalars:
      aval = jax_core.get_aval(s)
      if isinstance(aval, jax_core.DShapedArray):
        s.aval = aval.update(dtype=jnp.int32)

    blocks = jax_core.eval_jaxpr(
        discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch
    )

    _, out_inout, out_scratch = split_list(
        blocks, [grid_mapping.num_index_operands, num_inout_blocks])
@@ -390,19 +407,55 @@ def _pallas_call_jvp_rule(

ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule

def _batch_block_mapping(grid_mapping: GridMapping,
                         axis_size: int,
                         aval: jax_core.ShapedArray,
                         dim: int | batching.NotMapped,
                         block_mapping: BlockMapping) -> BlockMapping:

def _batch_block_mapping(
    grid_mapping: GridMapping,
    axis_size: int,
    aval: jax_core.ShapedArray,
    dim: int | batching.NotMapped,
    block_mapping: BlockMapping,
    for_ragged: bool,
) -> BlockMapping:
  def _block_map_function(new_idx, *args):
    indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
                                  block_mapping.index_map_jaxpr.consts,
                                  *args)
    if for_ragged:
      drop_last_args = args[:-1]
    else:
      drop_last_args = args

    indices = jax_core.eval_jaxpr(
        block_mapping.index_map_jaxpr.jaxpr,
        block_mapping.index_map_jaxpr.consts,
        *drop_last_args,
    )
    if dim is not batching.not_mapped:
      indices.insert(dim, new_idx)
      if isinstance(dim, batching.RaggedAxis):
        assert for_ragged, "Ragged axis not supported for non-ragged batching."
        stacked_axis = dim.stacked_axis
        indices.insert(stacked_axis, new_idx)
      else:
        indices.insert(dim, new_idx)
    return tuple(indices)
  idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals]

  if for_ragged:
    if isinstance(dim, batching.RaggedAxis):
      assert for_ragged, "Ragged axis not supported for non-ragged batching."
      _, _, ragged_axis_length = _ragged_axis_parts(dim)
      aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
      if isinstance(aval, jax_core.DShapedArray):
        aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
      lengths_aval = pallas_core.AbstractMemoryRef(
          aval,
          pallas_core.MemorySpace.INDEX,
      )
      idx_avals = [*idx_avals, lengths_aval]
    else:
      i32_aval_memref = pallas_core.AbstractMemoryRef(
          jax_core.ShapedArray(([axis_size]), jnp.int32),
          pallas_core.MemorySpace.INDEX,
      )
      idx_avals = [*idx_avals, i32_aval_memref]

  with grid_mapping.trace_env():
    block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(_block_map_function), idx_avals)
@@ -411,12 +464,27 @@ def _batch_block_mapping(grid_mapping: GridMapping,
    new_block_shape = shape
    new_array_shape_dtype = block_mapping.array_shape_dtype
  else:
    new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
    if isinstance(dim, batching.RaggedAxis):
      assert for_ragged, "Ragged axis not supported for non-ragged batching."
      new_block_shape = shape
      stacked_axis = dim.stacked_axis
      new_block_shape = tuple_insert(
          new_block_shape, stacked_axis, pallas_core.mapped
      )
    else:
      new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)

    array_shape = block_mapping.array_shape_dtype.shape
    if isinstance(dim, batching.RaggedAxis):
      assert for_ragged, "Ragged axis not supported for non-ragged batching."
      stacked_axis = dim.stacked_axis
      array_shape = tuple_insert(array_shape, stacked_axis, axis_size)
    else:
      array_shape = tuple_insert(array_shape, dim, axis_size)

    new_array_shape_dtype = jax.ShapeDtypeStruct(
        tuple_insert(block_mapping.array_shape_dtype.shape,
                     dim,
                     axis_size),
        block_mapping.array_shape_dtype.dtype)
        array_shape, block_mapping.array_shape_dtype.dtype
    )

  jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
  return block_mapping.replace(block_shape=new_block_shape,
@@ -547,6 +615,16 @@ def _batch_with_explicit_loop(
  return result, (0,) * len(result)


def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]:
  stacked_axis = dim.stacked_axis
  ragged_axes = dim.ragged_axes
  if len(ragged_axes) != 1:
    raise ValueError("Multiple ragged axes not yet implemented.")
  ragged_axis_dim = ragged_axes[0][0]
  ragged_axis_length = ragged_axes[0][1]
  return stacked_axis, ragged_axis_dim, ragged_axis_length


def _pallas_call_batching_rule(
    args,
    dims,
@@ -567,8 +645,26 @@ def _pallas_call_batching_rule(
      return x
    return jnp.squeeze(x, axis=bdim)

  all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)]
  if len(all_ragged_axes) > 1:
    raise ValueError("Multiple ragged dimensions not yet implemented.")

  if all_ragged_axes:
    stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts(
        all_ragged_axes[0]
    )
  else:
    stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None

  def get_size(i, x, d):
    if not isinstance(d, batching.RaggedAxis):
      return x.shape[d]
    return x.aval.shape[i]

  (axis_size,) = {
      x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped
      get_size(i=i, x=x, d=d)
      for i, (x, d) in enumerate(zip(args, dims))
      if d is not batching.not_mapped
  }
  if axis_size == 1:
    # Why are we even vmapping?
@@ -670,12 +766,27 @@ def _pallas_call_batching_rule(
  num_index_operands = grid_mapping.num_index_operands
  num_scratch_operands = grid_mapping.num_scratch_operands

  lengths_aval = None
  if ragged_axis_length is not None:
    aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
    if isinstance(aval, jax_core.DShapedArray):
      aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
    lengths_aval = pallas_core.AbstractMemoryRef(
        aval,
        pallas_core.MemorySpace.INDEX,
    )

  # Only add a batch dimension for the avals that actually have a grid mapping.
  # This excludes scalar prefetch inputs (the first in the list) and scratch
  # operands (the last in the list).
  avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
  batched_block_mappings = map(
      partial(_batch_block_mapping, grid_mapping, axis_size),
      partial(
          _batch_block_mapping,
          grid_mapping,
          axis_size,
          for_ragged=lengths_aval is not None,
      ),
      avals_to_batch,
      all_dims[num_index_operands:],
      block_mappings,
@@ -685,15 +796,23 @@ def _pallas_call_batching_rule(
      grid_mapping.index_map_avals)
  assert not index_map_tree_kwargs
  batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args

  if lengths_aval:
    batched_index_map_args = batched_index_map_args + (lengths_aval,)
    num_index_operands += 1

  batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten(
      (batched_index_map_args, {}))

  batched_grid_mapping = grid_mapping.replace(
      grid=(axis_size, *grid_mapping.grid),
      block_mappings=tuple(batched_block_mappings),
      index_map_avals=batched_index_map_avals,
      index_map_avals=tuple(batched_index_map_avals),
      index_map_tree=batched_index_map_tree,
      num_index_operands=num_index_operands,
      vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims),
  )

  if cost_estimate is not None:
    batched_cost_estimate = CostEstimate(
        flops=cost_estimate.flops * axis_size,
@@ -702,6 +821,103 @@ def _pallas_call_batching_rule(
    )
  else:
    batched_cost_estimate = None

  if lengths_aval:
    batched_grid_mapping = batched_grid_mapping.replace(
        get_grid_indices=lambda indices, maybe_include_mapped_dims: indices,
        local_grid_env=lambda loop_idx, grid: tuple(
            pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid)
        ),
    )

    # Note - on zero filling counterfactuals
    # A debug util to produce a counterfactual version of the when
    # gating, where for all values that don't pass the @when check,
    # we write 0s. This is useful for debugging, as certain lowering paths
    # like mosaic will write the last data as passthrough, leading to
    # potentially confusing results.
    debug_zero_fill_counterfactual = debug

    first_block_mapping = batched_grid_mapping.block_mappings[0]
    for block_mapping in batched_grid_mapping.block_mappings:
      # This invariant may already be checked elsewhere, but lets reaffirm it
      assert block_mapping.block_shape == first_block_mapping.block_shape, (
          f"block_mapping.block_shape: {block_mapping.block_shape}, "
          f"first_block_mapping.block_shape: {first_block_mapping.block_shape}"
      )
      assert (
          block_mapping.array_shape_dtype
          == first_block_mapping.array_shape_dtype
      ), (
          f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype},"
          " first_block_mapping.array_shape_dtype:"
          f" {first_block_mapping.array_shape_dtype}"
      )

    mapped_dim_idxs = [
        i
        for i, d in enumerate(first_block_mapping.block_shape)
        if d is pallas_core.mapped
    ]
    assert len(mapped_dim_idxs) == 1
    mapped_dim_idx = mapped_dim_idxs[0]
    if stacked_axis != mapped_dim_idx:
      raise ValueError(
          f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}"
      )

    assert ragged_axis_dim is not None, "Invariant violation"
    # This is the blockspec size of the dimension
    val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim]

    def when_wrapped_kernel(lengths_ref, *args, **kwargs):
      b_idx = jax.experimental.pallas.program_id(stacked_axis)
      i_idx = (
          jax.experimental.pallas.program_id(ragged_axis_dim)
          * val_at_ragged_dim
      )
      b_len = lengths_ref[b_idx]

      # TODO(mvoz): Unimplemented primitive in pallas
      # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
      # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")

      @jax.experimental.pallas.when(i_idx < b_len)
      def f():
        # Important! This allows us to trace the inner kernel with the correct
        # grid to preserve user program_id semantics. Ex: program_id(0) will
        # always be analogous to program_id(1) in the outer kernel.
        with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
          jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs)

      if debug_zero_fill_counterfactual:

        @jax.experimental.pallas.when(i_idx >= b_len)
        def g():
          for arg_ref in args:
            arg_ref[...] = jnp.zeros_like(arg_ref)

    kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars]
    flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(
        list(kernel_avals)
    )
    # Important! This allows us to trace the outer kernel with the correct grid
    # to enable accessing the batch program_id.
    with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
      kernel_src_info: pallas_core.SrcInfoStr = "<Wrapped outer kernel>"

      jaxpr = _trace_kernel_to_jaxpr(
          when_wrapped_kernel,
          kernel_src_info,
          batched_grid_mapping,
          tuple(flat_kernel_avals),
          kernel_in_tree,
          interpret=interpret,
      )

    assert ragged_axis_length is not None
    args = (ragged_axis_length, *args)

  out = pallas_call_p.bind(
      *dynamic_grid_args,
      *args,
@@ -1097,12 +1313,14 @@ def pallas_call(
  out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths)
  flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)  # type: ignore
                     for x in flat_out_shapes]

  @jax.jit
  def wrapped(*args):
    flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
    in_paths, flat_args = unzip2(flat_args_with_paths)
    flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
                          for a in flat_args)

    flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
                           for v in flat_out_shapes)

@@ -1172,15 +1390,18 @@ def pallas_call(
  return wrapped


def in_path_to_input_origin(in_path: tree_util.KeyPath,
                            arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr:
def in_path_to_input_origin(
    in_path: tree_util.KeyPath, arg_names: tuple[str, ...] | None
) -> pallas_core.OriginStr:
  """Converts `args[k]<rest>` into `arg_k_name<rest>`."""
  if arg_names is None:
    return f"args{tree_util.keystr(in_path)}"
  if len(in_path) == 0:
    return "args"
  arg_idx, *rest_path = in_path
  if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names):
  if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(
      arg_names
  ):
    return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path))
  else:
    return f"args{tree_util.keystr(tuple(in_path))}"
@@ -62,6 +62,29 @@ jax_test(
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

jax_test(
    name = "pallas_jumble_test",
    srcs = [
        "pallas_jumble_test.py",
    ],
    disable_configs = [
        "gpu",
        "gpu_x32",
        "gpu_a100",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_h100",
    ],
    shard_count = {
        "tpu": 1,
    },
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

jax_test(
    name = "ops_test",
    srcs = [
tests/pallas/pallas_jumble_test.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

from absl.testing import absltest
import jax
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.interpreters import batching
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np


# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs.
# pylint: disable=no-value-for-parameter

config.parse_flags_with_absl()


intx = dtypes.canonicalize_dtype(jnp.int64)
floatx = dtypes.canonicalize_dtype(jnp.float64)


@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
  INTERPRET = False

  def setUp(self):
    if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
      self.skipTest("On CPU the test works only in interpret mode")
    if jtu.test_device_matches(
        ["cuda"]
    ) and not jtu.is_cuda_compute_capability_at_least("8.0"):
      self.skipTest("Only works on GPU with capability >= sm80")
    if sys.platform == "win32" and not self.INTERPRET:
      self.skipTest("Only works on non-Windows platforms")

    super().setUp()
    _trace_kernel_to_jaxpr.cache_clear()

  def pallas_call(self, *args, **kwargs):
    return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)


@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class PallasCallRaggedVmapTest(PallasBaseTest):

  def test_vmap_jumble_over_sin_kernel(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Only tested on TPU")

    row_count = 8
    col_grid_size = 5
    ragged_shape = [3, 1, 4]
    sizes = lax.convert_element_type(
        jnp.array([128 * x for x in ragged_shape]),
        core.bint(col_grid_size * 128),
    )
    x = jax.vmap(
        lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
    )(sizes)

    def kernel(x_ref, o_ref):
      o_ref[...] = jnp.sin(x_ref[...])

    def invoke_kernel(x):
      return pl.pallas_call(
          kernel,
          in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))],
          out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
          out_shape=jax.ShapeDtypeStruct(
              (8, col_grid_size * 128), dtype=jnp.float32
          ),
          grid=(1, col_grid_size),
          interpret=self.INTERPRET,
          # See note - on zero filling counterfactuals
          debug=True,
      )(x)

    res = jax.vmap(
        invoke_kernel,
        out_axes=batching.jumble_axis,
        in_axes=batching.jumble_axis,
        axis_size=3,
    )(x)

    res = res.data
    total = len(ragged_shape) * row_count * col_grid_size * 128
    res_total = np.prod(res.shape)
    self.assertEqual(res_total, total)
    ragged_total = 0
    for dim in ragged_shape:
      ragged_total += row_count * dim * 128
    # See note - on zero filling counterfactuals
    self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total)

  def test_vmap_jumble_over_sin_kernel_grid_remapping(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Only tested on TPU")

    row_count = 8
    col_grid_size = 5
    ragged_shape = [3, 1, 4]
    sizes = lax.convert_element_type(
        jnp.array([128 * x for x in ragged_shape]),
        core.bint(col_grid_size * 128),
    )
    x = jax.vmap(
        lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
    )(sizes)

    def kernel(x_ref, o_ref):
      o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2)

    def invoke_kernel(x):
      return pl.pallas_call(
          kernel,
          in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))],
          out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
          out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32),
          grid=(1, 5),
          interpret=False,
      )(x)

    with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"):
      jax.vmap(
          invoke_kernel,
          out_axes=batching.jumble_axis,
          in_axes=batching.jumble_axis,
          axis_size=3,
      )(x)

  def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Only tested on TPU")

    self.skipTest("Checkify NYI")

    row_count = 8
    col_grid_size = 5
    ragged_shape = [3, 1, 4]
    sizes = lax.convert_element_type(
        jnp.array([(128 * x) - 1 for x in ragged_shape]),
        core.bint(col_grid_size * 128),
    )
    x = jax.vmap(
        lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
    )(sizes)

    def kernel(x_ref, o_ref):
      o_ref[...] = jnp.sin(x_ref[...])

    def invoke_kernel(x):
      return pl.pallas_call(
          kernel,
          in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))],
          out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
          out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32),
          grid=(1, 5),
          interpret=False,
      )(x)

    with self.assertRaisesRegex(
        ValueError,
        "Ragged input shape must be evenly divisble by the grid"  # noqa: W605
        " size at the ragged dimension 2",
    ):
      jax.vmap(
          invoke_kernel,
          out_axes=batching.jumble_axis,
          in_axes=batching.jumble_axis,
          axis_size=3,
      )(x)


class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest):
  INTERPRET = True


if __name__ == "__main__":
  absltest.main()