diff --git a/jax/_src/core.py b/jax/_src/core.py index ebf29cf0b..61ed81cde 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1954,6 +1954,7 @@ class DArray: assert data.shape == pad_shape self._aval = aval self._data = data + shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) aval = property(lambda self: self._aval) @@ -1964,21 +1965,38 @@ class DArray: dtypestr = _short_dtype_name(self._aval.dtype) shapestr = ','.join(map(str, self.shape)) - slices = tuple(slice(int(d._data)) if type(d) is DArray and - type(d.dtype) is bint else slice(None) for d in self.shape) - data = self._data[slices] + data = self.data return f'{dtypestr}[{shapestr}] with value: {data}' + def __hash__(self) -> int: if not self.shape: return hash((self._aval, int(self._data))) raise TypeError("unhashable type: DArray") + def __eq__(self, other): if isinstance(other, DArray) and self._aval == other._aval: return self._data == other._data return False + def __len__(self): return self.shape[0] + @property + def data(self): + if not self.shape and type(self.dtype) is bint: + # special-case scalar bints + return self._data + + slices = tuple( + slice(int(d._data)) + if type(d) is DArray and type(d.dtype) is bint + else slice(None) + for d in self.shape + ) + data = self._data[slices] + return data + + pytype_aval_mappings[DArray] = \ lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, x._data) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index fbcd2c4a7..27cde6d31 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -88,6 +88,7 @@ def _jumble_flatten(jumble): elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape)) aval = jumble.aval.replace(elt_ty=elt_ty) return (lengths, jumble.data), aval + def _jumble_unflatten(aval, x): lengths, data = x new_shape = [d.replace(lengths=lengths[d.lengths - 1]) @@ -251,7 +252,10 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: - assert False + # TODO(mvoz): This is a terrible place to fall into if you pass + # a non jumble type in, make it clearer what went wrong. + assert False, f'Unexpected type in ELT? {type(x)}' + to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 0ef208f75..1d99680b6 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -112,7 +112,10 @@ class AbstractMemoryRef(state.AbstractRef): def __init__(self, inner_aval: jax_core.AbstractValue, memory_space: Any): - assert isinstance(inner_aval, jax_core.ShapedArray) + + assert isinstance( + inner_aval, jax_core.ShapedArray + ), f"Illegal ref, got {type(inner_aval)}" self.inner_aval = inner_aval self.memory_space = memory_space @@ -167,9 +170,7 @@ class PallasGridContext: mapped_dims: tuple[int, ...] def size(self, axis: int) -> int | DynamicGridDim: - valid_grid = tuple( - s for i, s in enumerate(self.grid) if i not in self.mapped_dims - ) + valid_grid = tuple(self.grid) try: size = valid_grid[axis] except IndexError as e: @@ -338,7 +339,10 @@ class BlockMapping: ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals) + assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( + self.block_shape, + self.index_map_jaxpr.out_avals, + ) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( @@ -422,6 +426,8 @@ class GridMapping: num_inputs: int num_outputs: int num_scratch_operands: int + get_grid_indices: Callable | None = None + local_grid_env: Callable | None = None def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -442,8 +448,8 @@ class GridMapping: assert len(index_map_args) >= len(self.grid) for i in range(len(self.grid)): index_map_arg = index_map_args[i] - assert index_map_arg.shape == () - assert index_map_arg.dtype == jnp.int32 + assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}" + assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}" assert len(self.vmapped_dims) <= len(self.grid) for i in self.vmapped_dims: @@ -454,8 +460,11 @@ class GridMapping: for bm in self.block_mappings: bm.check_invariants() - assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), ( + assert tuple(self.index_map_avals) == tuple( + bm.index_map_jaxpr.in_avals + ), ( self.index_map_avals, + "|", bm.index_map_jaxpr.in_avals, ) @@ -547,6 +556,25 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool: return True return jax_core.is_dim(dim) + +def _max_shape_from_aval(array_aval: jax_core.ShapedArray): + array_aval_shape = list(array_aval.shape) + for i, s in enumerate(array_aval.shape): + try: + aval = jax_core.get_aval(s) + if isinstance(aval, jax_core.DShapedArray): + array_aval_shape[i] = aval.dtype.bound + except OverflowError as e: + # Note - there are annoying cases where on 32 bit hardware, + # a flattened index space may overflow - for these cases, + # we just take the shape as is. + # In most places, this is totally sound to do. + # For ragged/jumble inputs, this will fail downstream. + return array_aval.shape + + return tuple(array_aval_shape) + + def _convert_block_spec_to_block_mapping( block_spec: BlockSpec, origin: OriginStr, @@ -575,8 +603,15 @@ def _convert_block_spec_to_block_mapping( f"array shape {array_aval.shape}.") unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape), - block_spec.memory_space) + block_array_aval = array_aval.update(shape=unmapped_block_shape) + if isinstance(array_aval, jax_core.DShapedArray): + # Get the "max" shape for the ragged array. + block_array_aval = jax_core.ShapedArray( + block_array_aval.shape, + block_array_aval.dtype, + block_array_aval.weak_type, + ) + block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space) if not jax_core.is_constant_shape(block_aval.shape): raise ValueError( @@ -609,12 +644,12 @@ def _convert_block_spec_to_block_mapping( f"{origin} must return integer scalars. Output[{i}] has type " f"{ov}.") - if consts: raise ValueError( f"Index map function {index_map_src_info} for " f"{origin} must not capture constants: {consts}") + array_aval_shape = _max_shape_from_aval(array_aval) mapping = BlockMapping( block_shape=mapped_block_shape, @@ -622,7 +657,9 @@ def _convert_block_spec_to_block_mapping( index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), index_map_src_info=index_map_src_info, indexing_mode=block_spec.indexing_mode, - array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype), + array_shape_dtype=jax.ShapeDtypeStruct( + array_aval_shape, array_aval.dtype + ), origin=origin, ) mapping.check_invariants() diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 86ce2f0b1..aee894ee1 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -298,6 +298,7 @@ class MosaicGridMapping: self.jaxpr = jaxpr self.block_mappings = grid_mapping.block_mappings self.mapped_dims = grid_mapping.vmapped_dims + # TODO(mvoz): Generalize to not need this user_grid = tuple( g for i, g in enumerate(self.grid) if i not in self.mapped_dims ) @@ -345,9 +346,19 @@ class MosaicGridMapping: for _ in range(len(self.grid)) ]) self._prepare_mesh_info(mesh) - def _get_grid_indices(indices): - return indices - self.get_grid_indices = _get_grid_indices + + if grid_mapping.get_grid_indices is None: + + def _get_grid_indices(indices, maybe_include_mapped_dims: bool): + if maybe_include_mapped_dims: + return indices + return tuple( + idx for i, idx in enumerate(indices) if i not in self.mapped_dims + ) + + self.get_grid_indices = _get_grid_indices + else: + self.get_grid_indices = grid_mapping.get_grid_indices def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): if not self.has_communication: @@ -595,7 +606,9 @@ def lower_jaxpr_to_transform_func( ] def body_func(*args): grid_indices, scalar_prefetch = split_list(args, [num_grid]) - jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=True + ) arg_block_shapes = [ *[()] * len(jaxpr_indices), *mosaic_grid_mapping.scalar_prefetch_block_shapes, @@ -663,9 +676,9 @@ def lower_jaxpr_to_func( def body_func(*args): grid_indices, scalar_prefetch, operands_and_scratch = split_list( args, [num_grid, num_scalar_prefetch]) - grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) - jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices) - if i not in mosaic_grid_mapping.mapped_dims) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=False + ) mesh_info = mosaic_grid_mapping.mesh_info if mesh_info is not None: mesh_context = MeshContext( @@ -2365,6 +2378,7 @@ lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): + if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 3a780cdca..e948ff374 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -228,6 +228,12 @@ def _pallas_call_impl_interpret( # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. We pad with NaN, to make it easier # to catch OOB accesses. + for carry_element in carry: + aval = carry_element.aval + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype) + carry_element.aval = aval + carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) @@ -247,11 +253,16 @@ def _pallas_call_impl_interpret( return i < num_iterations def body(carry): i, loop_idx, *carry_blocks = carry - local_grid_env = tuple( - pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) - if dim not in grid_mapping.vmapped_dims - ) + + if grid_mapping.local_grid_env is not None: + local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + else: + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) + carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): start_indices = [ @@ -268,8 +279,14 @@ def _pallas_call_impl_interpret( len(blocks), len(scratch_values), ) - blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, - *blocks, *scratch) + for s in scalars: + aval = jax_core.get_aval(s) + if isinstance(aval, jax_core.DShapedArray): + s.aval = aval.update(dtype=jnp.int32) + + blocks = jax_core.eval_jaxpr( + discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch + ) _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) @@ -390,19 +407,55 @@ def _pallas_call_jvp_rule( ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule -def _batch_block_mapping(grid_mapping: GridMapping, - axis_size: int, - aval: jax_core.ShapedArray, - dim: int | batching.NotMapped, - block_mapping: BlockMapping) -> BlockMapping: + +def _batch_block_mapping( + grid_mapping: GridMapping, + axis_size: int, + aval: jax_core.ShapedArray, + dim: int | batching.NotMapped, + block_mapping: BlockMapping, + for_ragged: bool, +) -> BlockMapping: def _block_map_function(new_idx, *args): - indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr, - block_mapping.index_map_jaxpr.consts, - *args) + if for_ragged: + drop_last_args = args[:-1] + else: + drop_last_args = args + + indices = jax_core.eval_jaxpr( + block_mapping.index_map_jaxpr.jaxpr, + block_mapping.index_map_jaxpr.consts, + *drop_last_args, + ) if dim is not batching.not_mapped: - indices.insert(dim, new_idx) + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + stacked_axis = dim.stacked_axis + indices.insert(stacked_axis, new_idx) + else: + indices.insert(dim, new_idx) return tuple(indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] + + if for_ragged: + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + _, _, ragged_axis_length = _ragged_axis_parts(dim) + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + idx_avals = [*idx_avals, lengths_aval] + else: + i32_aval_memref = pallas_core.AbstractMemoryRef( + jax_core.ShapedArray(([axis_size]), jnp.int32), + pallas_core.MemorySpace.INDEX, + ) + idx_avals = [*idx_avals, i32_aval_memref] + with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_block_map_function), idx_avals) @@ -411,12 +464,27 @@ def _batch_block_mapping(grid_mapping: GridMapping, new_block_shape = shape new_array_shape_dtype = block_mapping.array_shape_dtype else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + new_block_shape = shape + stacked_axis = dim.stacked_axis + new_block_shape = tuple_insert( + new_block_shape, stacked_axis, pallas_core.mapped + ) + else: + new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + + array_shape = block_mapping.array_shape_dtype.shape + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + stacked_axis = dim.stacked_axis + array_shape = tuple_insert(array_shape, stacked_axis, axis_size) + else: + array_shape = tuple_insert(array_shape, dim, axis_size) + new_array_shape_dtype = jax.ShapeDtypeStruct( - tuple_insert(block_mapping.array_shape_dtype.shape, - dim, - axis_size), - block_mapping.array_shape_dtype.dtype) + array_shape, block_mapping.array_shape_dtype.dtype + ) jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, @@ -547,6 +615,16 @@ def _batch_with_explicit_loop( return result, (0,) * len(result) +def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]: + stacked_axis = dim.stacked_axis + ragged_axes = dim.ragged_axes + if len(ragged_axes) != 1: + raise ValueError("Multiple ragged axes not yet implemented.") + ragged_axis_dim = ragged_axes[0][0] + ragged_axis_length = ragged_axes[0][1] + return stacked_axis, ragged_axis_dim, ragged_axis_length + + def _pallas_call_batching_rule( args, dims, @@ -567,8 +645,26 @@ def _pallas_call_batching_rule( return x return jnp.squeeze(x, axis=bdim) + all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)] + if len(all_ragged_axes) > 1: + raise ValueError("Multiple ragged dimensions not yet implemented.") + + if all_ragged_axes: + stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts( + all_ragged_axes[0] + ) + else: + stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None + + def get_size(i, x, d): + if not isinstance(d, batching.RaggedAxis): + return x.shape[d] + return x.aval.shape[i] + (axis_size,) = { - x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped + get_size(i=i, x=x, d=d) + for i, (x, d) in enumerate(zip(args, dims)) + if d is not batching.not_mapped } if axis_size == 1: # Why are we even vmapping? @@ -670,12 +766,27 @@ def _pallas_call_batching_rule( num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands + lengths_aval = None + if ragged_axis_length is not None: + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + # Only add a batch dimension for the avals that actually have a grid mapping. # This excludes scalar prefetch inputs (the first in the list) and scratch # operands (the last in the list). avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( - partial(_batch_block_mapping, grid_mapping, axis_size), + partial( + _batch_block_mapping, + grid_mapping, + axis_size, + for_ragged=lengths_aval is not None, + ), avals_to_batch, all_dims[num_index_operands:], block_mappings, @@ -685,15 +796,23 @@ def _pallas_call_batching_rule( grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args + + if lengths_aval: + batched_index_map_args = batched_index_map_args + (lengths_aval,) + num_index_operands += 1 + batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( (batched_index_map_args, {})) + batched_grid_mapping = grid_mapping.replace( grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), - index_map_avals=batched_index_map_avals, + index_map_avals=tuple(batched_index_map_avals), index_map_tree=batched_index_map_tree, + num_index_operands=num_index_operands, vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims), ) + if cost_estimate is not None: batched_cost_estimate = CostEstimate( flops=cost_estimate.flops * axis_size, @@ -702,6 +821,103 @@ def _pallas_call_batching_rule( ) else: batched_cost_estimate = None + + if lengths_aval: + batched_grid_mapping = batched_grid_mapping.replace( + get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, + local_grid_env=lambda loop_idx, grid: tuple( + pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid) + ), + ) + + # Note - on zero filling counterfactuals + # A debug util to produce a counterfactual version of the when + # gating, where for all values that don't pass the @when check, + # we write 0s. This is useful for debugging, as certain lowering paths + # like mosaic will write the last data as passthrough, leading to + # potentially confusing results. + debug_zero_fill_counterfactual = debug + + first_block_mapping = batched_grid_mapping.block_mappings[0] + for block_mapping in batched_grid_mapping.block_mappings: + # This invariant may already be checked elsewhere, but lets reaffirm it + assert block_mapping.block_shape == first_block_mapping.block_shape, ( + f"block_mapping.block_shape: {block_mapping.block_shape}, " + f"first_block_mapping.block_shape: {first_block_mapping.block_shape}" + ) + assert ( + block_mapping.array_shape_dtype + == first_block_mapping.array_shape_dtype + ), ( + f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype}," + " first_block_mapping.array_shape_dtype:" + f" {first_block_mapping.array_shape_dtype}" + ) + + mapped_dim_idxs = [ + i + for i, d in enumerate(first_block_mapping.block_shape) + if d is pallas_core.mapped + ] + assert len(mapped_dim_idxs) == 1 + mapped_dim_idx = mapped_dim_idxs[0] + if stacked_axis != mapped_dim_idx: + raise ValueError( + f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}" + ) + + assert ragged_axis_dim is not None, "Invariant violation" + # This is the blockspec size of the dimension + val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] + + def when_wrapped_kernel(lengths_ref, *args, **kwargs): + b_idx = jax.experimental.pallas.program_id(stacked_axis) + i_idx = ( + jax.experimental.pallas.program_id(ragged_axis_dim) + * val_at_ragged_dim + ) + b_len = lengths_ref[b_idx] + + # TODO(mvoz): Unimplemented primitive in pallas + # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) + # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") + + @jax.experimental.pallas.when(i_idx < b_len) + def f(): + # Important! This allows us to trace the inner kernel with the correct + # grid to preserve user program_id semantics. Ex: program_id(0) will + # always be analogous to program_id(1) in the outer kernel. + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) + + if debug_zero_fill_counterfactual: + + @jax.experimental.pallas.when(i_idx >= b_len) + def g(): + for arg_ref in args: + arg_ref[...] = jnp.zeros_like(arg_ref) + + kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] + flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( + list(kernel_avals) + ) + # Important! This allows us to trace the outer kernel with the correct grid + # to enable accessing the batch program_id. + with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): + kernel_src_info: pallas_core.SrcInfoStr = "" + + jaxpr = _trace_kernel_to_jaxpr( + when_wrapped_kernel, + kernel_src_info, + batched_grid_mapping, + tuple(flat_kernel_avals), + kernel_in_tree, + interpret=interpret, + ) + + assert ragged_axis_length is not None + args = (ragged_axis_length, *args) + out = pallas_call_p.bind( *dynamic_grid_args, *args, @@ -1097,12 +1313,14 @@ def pallas_call( out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths) flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore for x in flat_out_shapes] + @jax.jit def wrapped(*args): flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args) in_paths, flat_args = unzip2(flat_args_with_paths) flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) for a in flat_args) + flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype) for v in flat_out_shapes) @@ -1172,15 +1390,18 @@ def pallas_call( return wrapped -def in_path_to_input_origin(in_path: tree_util.KeyPath, - arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr: +def in_path_to_input_origin( + in_path: tree_util.KeyPath, arg_names: tuple[str, ...] | None +) -> pallas_core.OriginStr: """Converts `args[k]` into `arg_k_name`.""" if arg_names is None: return f"args{tree_util.keystr(in_path)}" if len(in_path) == 0: return "args" arg_idx, *rest_path = in_path - if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names): + if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len( + arg_names + ): return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path)) else: return f"args{tree_util.keystr(tuple(in_path))}" diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index c0cf61387..5559a0552 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -62,6 +62,29 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_test( + name = "pallas_jumble_test", + srcs = [ + "pallas_jumble_test.py", + ], + disable_configs = [ + "gpu", + "gpu_x32", + "gpu_a100", + "gpu_p100", + "gpu_p100_x32", + "gpu_h100", + ], + shard_count = { + "tpu": 1, + }, + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_test( name = "ops_test", srcs = [ diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py new file mode 100644 index 000000000..5ed15fe96 --- /dev/null +++ b/tests/pallas/pallas_jumble_test.py @@ -0,0 +1,201 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" + +from absl.testing import absltest +import jax +from jax import lax +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.interpreters import batching +from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + + +# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs. +# pylint: disable=no-value-for-parameter + +config.parse_flags_with_absl() + + +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + if jtu.test_device_matches( + ["cuda"] + ) and not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Only works on GPU with capability >= sm80") + if sys.platform == "win32" and not self.INTERPRET: + self.skipTest("Only works on non-Windows platforms") + + super().setUp() + _trace_kernel_to_jaxpr.cache_clear() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_dtype_promotion="standard") +class PallasCallRaggedVmapTest(PallasBaseTest): + + def test_vmap_jumble_over_sin_kernel(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct( + (8, col_grid_size * 128), dtype=jnp.float32 + ), + grid=(1, col_grid_size), + interpret=self.INTERPRET, + # See note - on zero filling counterfactuals + debug=True, + )(x) + + res = jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + res = res.data + total = len(ragged_shape) * row_count * col_grid_size * 128 + res_total = np.prod(res.shape) + self.assertEqual(res_total, total) + ragged_total = 0 + for dim in ragged_shape: + ragged_total += row_count * dim * 128 + # See note - on zero filling counterfactuals + self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total) + + def test_vmap_jumble_over_sin_kernel_grid_remapping(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), + grid=(1, 5), + interpret=False, + )(x) + + with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): + jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + self.skipTest("Checkify NYI") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([(128 * x) - 1 for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), + grid=(1, 5), + interpret=False, + )(x) + + with self.assertRaisesRegex( + ValueError, + "Ragged input shape must be evenly divisble by the grid" # noqa: W605 + " size at the ragged dimension 2", + ): + jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + +class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest): + INTERPRET = True + + +if __name__ == "__main__": + absltest.main()