Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)

The plan here is to load it up with invariants, and start with a really simple kernel. After that, we can slowly relax the various invariants and implement support for others.

Note - the work saving here is compute only, not memory yet. A fast-followup CL is adding memory savings via index-map rewriting

PiperOrigin-RevId: 663752447
This commit is contained in:
jax authors 2024-08-16 09:20:13 -07:00 committed by jax authors
parent b6306e3953
commit 24394a1b03
7 changed files with 560 additions and 50 deletions

View File

@ -1954,6 +1954,7 @@ class DArray:
assert data.shape == pad_shape
self._aval = aval
self._data = data
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
aval = property(lambda self: self._aval)
@ -1964,21 +1965,38 @@ class DArray:
dtypestr = _short_dtype_name(self._aval.dtype)
shapestr = ','.join(map(str, self.shape))
slices = tuple(slice(int(d._data)) if type(d) is DArray and
type(d.dtype) is bint else slice(None) for d in self.shape)
data = self._data[slices]
data = self.data
return f'{dtypestr}[{shapestr}] with value: {data}'
def __hash__(self) -> int:
if not self.shape:
return hash((self._aval, int(self._data)))
raise TypeError("unhashable type: DArray")
def __eq__(self, other):
if isinstance(other, DArray) and self._aval == other._aval:
return self._data == other._data
return False
def __len__(self):
return self.shape[0]
@property
def data(self):
if not self.shape and type(self.dtype) is bint:
# special-case scalar bints
return self._data
slices = tuple(
slice(int(d._data))
if type(d) is DArray and type(d.dtype) is bint
else slice(None)
for d in self.shape
)
data = self._data[slices]
return data
pytype_aval_mappings[DArray] = \
lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
x._data)

View File

@ -88,6 +88,7 @@ def _jumble_flatten(jumble):
elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape))
aval = jumble.aval.replace(elt_ty=elt_ty)
return (lengths, jumble.data), aval
def _jumble_unflatten(aval, x):
lengths, data = x
new_shape = [d.replace(lengths=lengths[d.lengths - 1])
@ -251,7 +252,10 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
return (BatchTracer(trace, x, spec, source_info_util.current())
if spec is not None else x)
else:
assert False
# TODO(mvoz): This is a terrible place to fall into if you pass
# a non jumble type in, make it clearer what went wrong.
assert False, f'Unexpected type in ELT? {type(x)}'
to_elt_handlers: dict[type, ToEltHandler] = {}
def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,

View File

@ -112,7 +112,10 @@ class AbstractMemoryRef(state.AbstractRef):
def __init__(self, inner_aval: jax_core.AbstractValue,
memory_space: Any):
assert isinstance(inner_aval, jax_core.ShapedArray)
assert isinstance(
inner_aval, jax_core.ShapedArray
), f"Illegal ref, got {type(inner_aval)}"
self.inner_aval = inner_aval
self.memory_space = memory_space
@ -167,9 +170,7 @@ class PallasGridContext:
mapped_dims: tuple[int, ...]
def size(self, axis: int) -> int | DynamicGridDim:
valid_grid = tuple(
s for i, s in enumerate(self.grid) if i not in self.mapped_dims
)
valid_grid = tuple(self.grid)
try:
size = valid_grid[axis]
except IndexError as e:
@ -338,7 +339,10 @@ class BlockMapping:
)
assert not self.index_map_jaxpr.consts
assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals)
assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), (
self.block_shape,
self.index_map_jaxpr.out_avals,
)
assert all(ov.shape == () and
(ov.dtype == jnp.int32 or ov.dtype == jnp.int64)
for ov in self.index_map_jaxpr.out_avals), (
@ -422,6 +426,8 @@ class GridMapping:
num_inputs: int
num_outputs: int
num_scratch_operands: int
get_grid_indices: Callable | None = None
local_grid_env: Callable | None = None
def check_invariants(self) -> None:
if not config.enable_checks.value: return
@ -442,8 +448,8 @@ class GridMapping:
assert len(index_map_args) >= len(self.grid)
for i in range(len(self.grid)):
index_map_arg = index_map_args[i]
assert index_map_arg.shape == ()
assert index_map_arg.dtype == jnp.int32
assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}"
assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}"
assert len(self.vmapped_dims) <= len(self.grid)
for i in self.vmapped_dims:
@ -454,8 +460,11 @@ class GridMapping:
for bm in self.block_mappings:
bm.check_invariants()
assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), (
assert tuple(self.index_map_avals) == tuple(
bm.index_map_jaxpr.in_avals
), (
self.index_map_avals,
"|",
bm.index_map_jaxpr.in_avals,
)
@ -547,6 +556,17 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
return True
return jax_core.is_dim(dim)
def _max_shape_from_aval(array_aval: jax_core.ShapedArray):
array_aval_shape = list(array_aval.shape)
for i, s in enumerate(array_aval.shape):
aval = jax_core.get_aval(s)
if isinstance(aval, jax_core.DShapedArray):
array_aval_shape[i] = aval.dtype.bound
return tuple(array_aval_shape)
def _convert_block_spec_to_block_mapping(
block_spec: BlockSpec,
origin: OriginStr,
@ -575,8 +595,15 @@ def _convert_block_spec_to_block_mapping(
f"array shape {array_aval.shape}.")
unmapped_block_shape = tuple(s for s in block_shape if s is not None)
block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape),
block_spec.memory_space)
block_array_aval = array_aval.update(shape=unmapped_block_shape)
if isinstance(array_aval, jax_core.DShapedArray):
# Get the "max" shape for the ragged array.
block_array_aval = jax_core.ShapedArray(
block_array_aval.shape,
block_array_aval.dtype,
block_array_aval.weak_type,
)
block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space)
if not jax_core.is_constant_shape(block_aval.shape):
raise ValueError(
@ -609,12 +636,12 @@ def _convert_block_spec_to_block_mapping(
f"{origin} must return integer scalars. Output[{i}] has type "
f"{ov}.")
if consts:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must not capture constants: {consts}")
array_aval_shape = _max_shape_from_aval(array_aval)
mapping = BlockMapping(
block_shape=mapped_block_shape,
@ -622,7 +649,9 @@ def _convert_block_spec_to_block_mapping(
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=block_spec.indexing_mode,
array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype),
array_shape_dtype=jax.ShapeDtypeStruct(
array_aval_shape, array_aval.dtype
),
origin=origin,
)
mapping.check_invariants()

View File

@ -298,6 +298,7 @@ class MosaicGridMapping:
self.jaxpr = jaxpr
self.block_mappings = grid_mapping.block_mappings
self.mapped_dims = grid_mapping.vmapped_dims
# TODO(mvoz): Generalize to not need this
user_grid = tuple(
g for i, g in enumerate(self.grid) if i not in self.mapped_dims
)
@ -345,9 +346,19 @@ class MosaicGridMapping:
for _ in range(len(self.grid))
])
self._prepare_mesh_info(mesh)
def _get_grid_indices(indices):
return indices
self.get_grid_indices = _get_grid_indices
if grid_mapping.get_grid_indices is None:
def _get_grid_indices(indices, maybe_include_mapped_dims: bool):
if maybe_include_mapped_dims:
return indices
return tuple(
idx for i, idx in enumerate(indices) if i not in self.mapped_dims
)
self.get_grid_indices = _get_grid_indices
else:
self.get_grid_indices = grid_mapping.get_grid_indices
def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
if not self.has_communication:
@ -595,7 +606,9 @@ def lower_jaxpr_to_transform_func(
]
def body_func(*args):
grid_indices, scalar_prefetch = split_list(args, [num_grid])
jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices)
jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
grid_indices, maybe_include_mapped_dims=True
)
arg_block_shapes = [
*[()] * len(jaxpr_indices),
*mosaic_grid_mapping.scalar_prefetch_block_shapes,
@ -663,9 +676,9 @@ def lower_jaxpr_to_func(
def body_func(*args):
grid_indices, scalar_prefetch, operands_and_scratch = split_list(
args, [num_grid, num_scalar_prefetch])
grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices)
jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices)
if i not in mosaic_grid_mapping.mapped_dims)
jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
grid_indices, maybe_include_mapped_dims=False
)
mesh_info = mosaic_grid_mapping.mesh_info
if mesh_info is not None:
mesh_context = MeshContext(
@ -2365,6 +2378,7 @@ lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule
def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
if ctx.lowering_context.user_grid_indices is None:
raise ValueError(
f"program id: {axis} was passed, but user did not provide a grid."

View File

@ -228,6 +228,12 @@ def _pallas_call_impl_interpret(
# Pad values to evenly divide into block dimensions. This matches the
# behavior of the non-interpret mode. We pad with NaN, to make it easier
# to catch OOB accesses.
for carry_element in carry:
aval = carry_element.aval
if isinstance(aval, jax_core.DShapedArray):
aval = jax_core.ShapedArray(aval.shape, aval.dtype)
carry_element.aval = aval
carry = map(_pad_values_to_block_dimension, carry, block_shapes)
carry.extend(scratch_values)
@ -247,11 +253,16 @@ def _pallas_call_impl_interpret(
return i < num_iterations
def body(carry):
i, loop_idx, *carry_blocks = carry
local_grid_env = tuple(
pallas_core.GridAxis(idx, b)
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
if dim not in grid_mapping.vmapped_dims
)
if grid_mapping.local_grid_env is not None:
local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
else:
local_grid_env = tuple(
pallas_core.GridAxis(idx, b)
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
if dim not in grid_mapping.vmapped_dims
)
carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks])
with pallas_core.grid_env(local_grid_env):
start_indices = [
@ -268,8 +279,14 @@ def _pallas_call_impl_interpret(
len(blocks),
len(scratch_values),
)
blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
for s in scalars:
aval = jax_core.get_aval(s)
if isinstance(aval, jax_core.DShapedArray):
s.aval = aval.update(dtype=jnp.int32)
blocks = jax_core.eval_jaxpr(
discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch
)
_, out_inout, out_scratch = split_list(
blocks, [grid_mapping.num_index_operands, num_inout_blocks])
@ -390,19 +407,55 @@ def _pallas_call_jvp_rule(
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
def _batch_block_mapping(grid_mapping: GridMapping,
axis_size: int,
aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping) -> BlockMapping:
def _batch_block_mapping(
grid_mapping: GridMapping,
axis_size: int,
aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping,
for_ragged: bool,
) -> BlockMapping:
def _block_map_function(new_idx, *args):
indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
block_mapping.index_map_jaxpr.consts,
*args)
if for_ragged:
drop_last_args = args[:-1]
else:
drop_last_args = args
indices = jax_core.eval_jaxpr(
block_mapping.index_map_jaxpr.jaxpr,
block_mapping.index_map_jaxpr.consts,
*drop_last_args,
)
if dim is not batching.not_mapped:
indices.insert(dim, new_idx)
if isinstance(dim, batching.RaggedAxis):
assert for_ragged, "Ragged axis not supported for non-ragged batching."
stacked_axis = dim.stacked_axis
indices.insert(stacked_axis, new_idx)
else:
indices.insert(dim, new_idx)
return tuple(indices)
idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals]
if for_ragged:
if isinstance(dim, batching.RaggedAxis):
assert for_ragged, "Ragged axis not supported for non-ragged batching."
_, _, ragged_axis_length = _ragged_axis_parts(dim)
aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
if isinstance(aval, jax_core.DShapedArray):
aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
lengths_aval = pallas_core.AbstractMemoryRef(
aval,
pallas_core.MemorySpace.INDEX,
)
idx_avals = [*idx_avals, lengths_aval]
else:
i32_aval_memref = pallas_core.AbstractMemoryRef(
jax_core.ShapedArray(([axis_size]), jnp.int32),
pallas_core.MemorySpace.INDEX,
)
idx_avals = [*idx_avals, i32_aval_memref]
with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
@ -411,12 +464,27 @@ def _batch_block_mapping(grid_mapping: GridMapping,
new_block_shape = shape
new_array_shape_dtype = block_mapping.array_shape_dtype
else:
new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
if isinstance(dim, batching.RaggedAxis):
assert for_ragged, "Ragged axis not supported for non-ragged batching."
new_block_shape = shape
stacked_axis = dim.stacked_axis
new_block_shape = tuple_insert(
new_block_shape, stacked_axis, pallas_core.mapped
)
else:
new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
array_shape = block_mapping.array_shape_dtype.shape
if isinstance(dim, batching.RaggedAxis):
assert for_ragged, "Ragged axis not supported for non-ragged batching."
stacked_axis = dim.stacked_axis
array_shape = tuple_insert(array_shape, stacked_axis, axis_size)
else:
array_shape = tuple_insert(array_shape, dim, axis_size)
new_array_shape_dtype = jax.ShapeDtypeStruct(
tuple_insert(block_mapping.array_shape_dtype.shape,
dim,
axis_size),
block_mapping.array_shape_dtype.dtype)
array_shape, block_mapping.array_shape_dtype.dtype
)
jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
return block_mapping.replace(block_shape=new_block_shape,
@ -547,6 +615,16 @@ def _batch_with_explicit_loop(
return result, (0,) * len(result)
def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]:
stacked_axis = dim.stacked_axis
ragged_axes = dim.ragged_axes
if len(ragged_axes) != 1:
raise ValueError("Multiple ragged axes not yet implemented.")
ragged_axis_dim = ragged_axes[0][0]
ragged_axis_length = ragged_axes[0][1]
return stacked_axis, ragged_axis_dim, ragged_axis_length
def _pallas_call_batching_rule(
args,
dims,
@ -567,8 +645,26 @@ def _pallas_call_batching_rule(
return x
return jnp.squeeze(x, axis=bdim)
all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)]
if len(all_ragged_axes) > 1:
raise ValueError("Multiple ragged dimensions not yet implemented.")
if all_ragged_axes:
stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts(
all_ragged_axes[0]
)
else:
stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None
def get_size(i, x, d):
if not isinstance(d, batching.RaggedAxis):
return x.shape[d]
return x.aval.shape[i]
(axis_size,) = {
x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped
get_size(i=i, x=x, d=d)
for i, (x, d) in enumerate(zip(args, dims))
if d is not batching.not_mapped
}
if axis_size == 1:
# Why are we even vmapping?
@ -670,12 +766,27 @@ def _pallas_call_batching_rule(
num_index_operands = grid_mapping.num_index_operands
num_scratch_operands = grid_mapping.num_scratch_operands
lengths_aval = None
if ragged_axis_length is not None:
aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
if isinstance(aval, jax_core.DShapedArray):
aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
lengths_aval = pallas_core.AbstractMemoryRef(
aval,
pallas_core.MemorySpace.INDEX,
)
# Only add a batch dimension for the avals that actually have a grid mapping.
# This excludes scalar prefetch inputs (the first in the list) and scratch
# operands (the last in the list).
avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
batched_block_mappings = map(
partial(_batch_block_mapping, grid_mapping, axis_size),
partial(
_batch_block_mapping,
grid_mapping,
axis_size,
for_ragged=lengths_aval is not None,
),
avals_to_batch,
all_dims[num_index_operands:],
block_mappings,
@ -685,15 +796,23 @@ def _pallas_call_batching_rule(
grid_mapping.index_map_avals)
assert not index_map_tree_kwargs
batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args
if lengths_aval:
batched_index_map_args = batched_index_map_args + (lengths_aval,)
num_index_operands += 1
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten(
(batched_index_map_args, {}))
batched_grid_mapping = grid_mapping.replace(
grid=(axis_size, *grid_mapping.grid),
block_mappings=tuple(batched_block_mappings),
index_map_avals=batched_index_map_avals,
index_map_avals=tuple(batched_index_map_avals),
index_map_tree=batched_index_map_tree,
num_index_operands=num_index_operands,
vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims),
)
if cost_estimate is not None:
batched_cost_estimate = CostEstimate(
flops=cost_estimate.flops * axis_size,
@ -702,6 +821,103 @@ def _pallas_call_batching_rule(
)
else:
batched_cost_estimate = None
if lengths_aval:
batched_grid_mapping = batched_grid_mapping.replace(
get_grid_indices=lambda indices, maybe_include_mapped_dims: indices,
local_grid_env=lambda loop_idx, grid: tuple(
pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid)
),
)
# Note - on zero filling counterfactuals
# A debug util to produce a counterfactual version of the when
# gating, where for all values that don't pass the @when check,
# we write 0s. This is useful for debugging, as certain lowering paths
# like mosaic will write the last data as passthrough, leading to
# potentially confusing results.
debug_zero_fill_counterfactual = debug
first_block_mapping = batched_grid_mapping.block_mappings[0]
for block_mapping in batched_grid_mapping.block_mappings:
# This invariant may already be checked elsewhere, but lets reaffirm it
assert block_mapping.block_shape == first_block_mapping.block_shape, (
f"block_mapping.block_shape: {block_mapping.block_shape}, "
f"first_block_mapping.block_shape: {first_block_mapping.block_shape}"
)
assert (
block_mapping.array_shape_dtype
== first_block_mapping.array_shape_dtype
), (
f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype},"
" first_block_mapping.array_shape_dtype:"
f" {first_block_mapping.array_shape_dtype}"
)
mapped_dim_idxs = [
i
for i, d in enumerate(first_block_mapping.block_shape)
if d is pallas_core.mapped
]
assert len(mapped_dim_idxs) == 1
mapped_dim_idx = mapped_dim_idxs[0]
if stacked_axis != mapped_dim_idx:
raise ValueError(
f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}"
)
assert ragged_axis_dim is not None, "Invariant violation"
# This is the blockspec size of the dimension
val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim]
def when_wrapped_kernel(lengths_ref, *args, **kwargs):
b_idx = jax.experimental.pallas.program_id(stacked_axis)
i_idx = (
jax.experimental.pallas.program_id(ragged_axis_dim)
* val_at_ragged_dim
)
b_len = lengths_ref[b_idx]
# TODO(mvoz): Unimplemented primitive in pallas
# b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
# checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")
@jax.experimental.pallas.when(i_idx < b_len)
def f():
# Important! This allows us to trace the inner kernel with the correct
# grid to preserve user program_id semantics. Ex: program_id(0) will
# always be analogous to program_id(1) in the outer kernel.
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs)
if debug_zero_fill_counterfactual:
@jax.experimental.pallas.when(i_idx >= b_len)
def g():
for arg_ref in args:
arg_ref[...] = jnp.zeros_like(arg_ref)
kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars]
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(
list(kernel_avals)
)
# Important! This allows us to trace the outer kernel with the correct grid
# to enable accessing the batch program_id.
with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
kernel_src_info: pallas_core.SrcInfoStr = "<Wrapped outer kernel>"
jaxpr = _trace_kernel_to_jaxpr(
when_wrapped_kernel,
kernel_src_info,
batched_grid_mapping,
tuple(flat_kernel_avals),
kernel_in_tree,
interpret=interpret,
)
assert ragged_axis_length is not None
args = (ragged_axis_length, *args)
out = pallas_call_p.bind(
*dynamic_grid_args,
*args,
@ -1097,12 +1313,14 @@ def pallas_call(
out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore
for x in flat_out_shapes]
@jax.jit
def wrapped(*args):
flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
in_paths, flat_args = unzip2(flat_args_with_paths)
flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
@ -1172,15 +1390,18 @@ def pallas_call(
return wrapped
def in_path_to_input_origin(in_path: tree_util.KeyPath,
arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr:
def in_path_to_input_origin(
in_path: tree_util.KeyPath, arg_names: tuple[str, ...] | None
) -> pallas_core.OriginStr:
"""Converts `args[k]<rest>` into `arg_k_name<rest>`."""
if arg_names is None:
return f"args{tree_util.keystr(in_path)}"
if len(in_path) == 0:
return "args"
arg_idx, *rest_path = in_path
if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names):
if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(
arg_names
):
return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path))
else:
return f"args{tree_util.keystr(tuple(in_path))}"

View File

@ -62,6 +62,29 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
name = "pallas_jumble_test",
srcs = [
"pallas_jumble_test.py",
],
disable_configs = [
"gpu",
"gpu_x32",
"gpu_a100",
"gpu_p100",
"gpu_p100_x32",
"gpu_h100",
],
shard_count = {
"tpu": 1,
},
deps = [
"//jax:pallas",
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
name = "ops_test",
srcs = [

View File

@ -0,0 +1,201 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
from absl.testing import absltest
import jax
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.interpreters import batching
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs.
# pylint: disable=no-value-for-parameter
config.parse_flags_with_absl()
intx = dtypes.canonicalize_dtype(jnp.int64)
floatx = dtypes.canonicalize_dtype(jnp.float64)
@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
def setUp(self):
if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
if jtu.test_device_matches(
["cuda"]
) and not jtu.is_cuda_compute_capability_at_least("8.0"):
self.skipTest("Only works on GPU with capability >= sm80")
if sys.platform == "win32" and not self.INTERPRET:
self.skipTest("Only works on non-Windows platforms")
super().setUp()
_trace_kernel_to_jaxpr.cache_clear()
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class PallasCallRaggedVmapTest(PallasBaseTest):
def test_vmap_jumble_over_sin_kernel(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Only tested on TPU")
row_count = 8
col_grid_size = 5
ragged_shape = [3, 1, 4]
sizes = lax.convert_element_type(
jnp.array([128 * x for x in ragged_shape]),
core.bint(col_grid_size * 128),
)
x = jax.vmap(
lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
)(sizes)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.sin(x_ref[...])
def invoke_kernel(x):
return pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))],
out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
out_shape=jax.ShapeDtypeStruct(
(8, col_grid_size * 128), dtype=jnp.float32
),
grid=(1, col_grid_size),
interpret=self.INTERPRET,
# See note - on zero filling counterfactuals
debug=True,
)(x)
res = jax.vmap(
invoke_kernel,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x)
res = res.data
total = len(ragged_shape) * row_count * col_grid_size * 128
res_total = np.prod(res.shape)
self.assertEqual(res_total, total)
ragged_total = 0
for dim in ragged_shape:
ragged_total += row_count * dim * 128
# See note - on zero filling counterfactuals
self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total)
def test_vmap_jumble_over_sin_kernel_grid_remapping(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Only tested on TPU")
row_count = 8
col_grid_size = 5
ragged_shape = [3, 1, 4]
sizes = lax.convert_element_type(
jnp.array([128 * x for x in ragged_shape]),
core.bint(col_grid_size * 128),
)
x = jax.vmap(
lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
)(sizes)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2)
def invoke_kernel(x):
return pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))],
out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32),
grid=(1, 5),
interpret=False,
)(x)
with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"):
jax.vmap(
invoke_kernel,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x)
def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Only tested on TPU")
self.skipTest("Checkify NYI")
row_count = 8
col_grid_size = 5
ragged_shape = [3, 1, 4]
sizes = lax.convert_element_type(
jnp.array([(128 * x) - 1 for x in ragged_shape]),
core.bint(col_grid_size * 128),
)
x = jax.vmap(
lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
)(sizes)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.sin(x_ref[...])
def invoke_kernel(x):
return pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))],
out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32),
grid=(1, 5),
interpret=False,
)(x)
with self.assertRaisesRegex(
ValueError,
"Ragged input shape must be evenly divisble by the grid" # noqa: W605
" size at the ragged dimension 2",
):
jax.vmap(
invoke_kernel,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x)
class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest):
INTERPRET = True
if __name__ == "__main__":
absltest.main()