[Pallas] Make num_programs return an int if the grid is not dynamic

PiperOrigin-RevId: 644149441
Sharad Vikram authored 2024-06-17 15:17:52 -07:00; committed by jax authors
parent 1de2756c7e
commit 9499de4358
6 changed files with 205 additions and 57 deletions
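In short: pl.num_programs(axis) now returns a plain Python int whenever the queried grid dimension is statically known, and only returns a traced jax.Array for dimensions whose size is supplied at run time. A minimal sketch of the new behavior, assembled from the tests added at the bottom of this commit (standard public pallas import assumed; a dynamic grid dimension needs a backend that supports it, such as the TPU or interpret paths these tests exercise):

import jax
from jax.experimental import pallas as pl

def kernel(_):
  n0 = pl.num_programs(0)  # first grid dimension is the static int 2 -> plain Python int
  assert isinstance(n0, int) and n0 == 2
  n1 = pl.num_programs(1)  # second grid dimension is a traced value -> jax.Array
  assert isinstance(n1, jax.Array)

@jax.jit
def outer(x):
  # Under jit, x is a tracer, so the second grid dimension stays dynamic.
  pl.pallas_call(kernel, out_shape=None, grid=(2, x))()

outer(2)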

View File

@@ -20,6 +20,7 @@ import copy
import contextlib
import dataclasses
import functools
import threading
from typing import Any, Callable, Union
import jax
@@ -33,10 +34,16 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
import jax.numpy as jnp
class DynamicGridDim:
pass
dynamic_grid_dim = DynamicGridDim()
partial = functools.partial
Grid = tuple[Union[int, jax_core.Array, None], ...] # None indicates that the bound is dynamic.
DynamicGrid = tuple[Union[int, jax_core.Array], ...]
Grid = tuple[Union[int, jax_core.Array], ...]
StaticGrid = tuple[int, ...]
GridMappingGrid = tuple[Union[int, DynamicGridDim], ...]
split_list = util.split_list
map, unsafe_map = util.safe_map, map
@@ -84,6 +91,39 @@ def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped
@dataclasses.dataclass(frozen=True)
class PallasGridContext:
grid: GridMappingGrid
mapped_dims: tuple[int, ...]
def size(self, axis: int) -> int | DynamicGridDim:
valid_grid = tuple(
s for i, s in enumerate(self.grid) if i not in self.mapped_dims
)
try:
size = valid_grid[axis]
except IndexError as e:
raise ValueError(
f"Axis {axis} is out of bounds for grid {self.grid}"
) from e
return size
@dataclasses.dataclass
class PallasTracingEnv(threading.local):
grid_context: PallasGridContext | None = None
_pallas_tracing_env = PallasTracingEnv()
def axis_frame() -> PallasGridContext:
# This is like jax_core.axis_frame, except there should only ever be one
# active PallasGridAxisName for a particular main_trace because we cannot
# nest pallas_calls.
env = _pallas_tracing_env
assert env.grid_context is not None
return env.grid_context
@dataclasses.dataclass(frozen=True)
class GridAxis:
index: jax.Array
@@ -176,9 +216,20 @@ class BlockMapping:
replace = dataclasses.replace
@contextlib.contextmanager
def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
assert all(i is dynamic_grid_dim or isinstance(i, int) for i in grid)
old_grid_context = _pallas_tracing_env.grid_context
try:
_pallas_tracing_env.grid_context = PallasGridContext(grid, mapped_dims)
yield
finally:
_pallas_tracing_env.grid_context = old_grid_context
@dataclasses.dataclass(frozen=True)
class GridMapping:
grid: Grid
grid: GridMappingGrid
block_mappings: tuple[BlockMapping | None, ...]
mapped_dims: tuple[int, ...] = ()
num_index_operands: int = 0
@@ -190,7 +241,7 @@ class GridMapping:
@property
def num_dynamic_grid_bounds(self):
return sum(b is None for b in self.grid)
return sum(b is dynamic_grid_dim for b in self.grid)
@property
def static_grid(self) -> StaticGrid:
@@ -198,6 +249,11 @@ class GridMapping:
raise ValueError("Expected a grid with fully static bounds")
return self.grid # type: ignore
@contextlib.contextmanager
def trace_env(self):
with tracing_grid_env(self.grid, self.mapped_dims):
yield
def _preprocess_grid(grid: Grid | int | None) -> Grid:
if grid is None:
@@ -208,8 +264,12 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
def _convert_block_spec_to_block_mapping(
in_avals: Sequence[jax_core.ShapedArray], block_spec: BlockSpec,
aval: jax_core.ShapedArray, in_tree: Any,
in_avals: Sequence[jax_core.ShapedArray],
block_spec: BlockSpec,
aval: jax_core.ShapedArray,
in_tree: Any,
grid: GridMappingGrid,
mapped_dims: tuple[int, ...],
) -> BlockMapping | None:
if block_spec is no_block_spec:
return None
@@ -222,11 +282,13 @@ def _convert_block_spec_to_block_mapping(
block_shape = tuple(
mapped if s is None else s for s in block_shape)
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
with tracing_grid_env(grid, mapped_dims):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
)
def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
) -> state.AbstractRef:
if block_shape is None:
@@ -267,6 +329,7 @@ class NoBlockSpec:
pass
no_block_spec = NoBlockSpec()
@dataclasses.dataclass(init=False, unsafe_hash=True)
class GridSpec:
grid: Grid
@@ -323,6 +386,10 @@ class GridSpec:
def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
dynamic_grid_dim if d is None else d for d in self.grid
)
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
@@ -332,13 +399,29 @@ class GridSpec:
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
in_block_mappings = map(
partial(_convert_block_spec_to_block_mapping, grid_avals,
in_tree=grid_tree), in_specs, in_ref_avals)
partial(
_convert_block_spec_to_block_mapping,
grid_avals,
in_tree=grid_tree,
grid=grid_mapping_grid,
mapped_dims=(),
),
in_specs,
in_ref_avals,
)
out_block_mappings = map(
partial(_convert_block_spec_to_block_mapping, grid_avals,
in_tree=grid_tree), out_specs, out_ref_avals)
partial(
_convert_block_spec_to_block_mapping,
grid_avals,
in_tree=grid_tree,
grid=grid_mapping_grid,
mapped_dims=(),
),
out_specs,
out_ref_avals,
)
grid_mapping = GridMapping(
self.grid, (*in_block_mappings, *out_block_mappings)
grid_mapping_grid, (*in_block_mappings, *out_block_mappings) # type: ignore
)
jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
@@ -346,11 +429,15 @@ class GridSpec:
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
def unzip_dynamic_grid_bounds(self) -> tuple[GridSpec, tuple[Any, ...]]:
static_grid = tuple(d if isinstance(d, int) else None for d in self.grid)
def unzip_dynamic_grid_bounds(
self,
) -> tuple[GridSpec, tuple[Any, ...]]:
static_grid = tuple(
d if isinstance(d, int) else None for d in self.grid
)
dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(self)
static_self.grid = static_grid
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds

View File

@@ -166,6 +166,10 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
pallas_core.dynamic_grid_dim if d is None else d for d in self.grid
)
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
self.scratch_shapes)
@@ -191,15 +195,29 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
((*grid_avals, *scalar_avals), {})
)
in_block_mappings = map(
partial(_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals),
in_tree=index_map_in_tree), in_specs, in_ref_avals)
partial(
_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals),
in_tree=index_map_in_tree,
grid=grid_mapping_grid,
mapped_dims=(),
),
in_specs,
in_ref_avals,
)
out_block_mappings = map(
partial(_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals),
in_tree=index_map_in_tree), out_specs, out_ref_avals)
partial(
_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals),
in_tree=index_map_in_tree,
grid=grid_mapping_grid,
mapped_dims=(),
),
out_specs,
out_ref_avals,
)
grid_mapping = GridMapping(
grid=self.grid,
grid=grid_mapping_grid, # type: ignore
block_mappings=(*in_block_mappings, *out_block_mappings),
mapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,

View File

@@ -451,7 +451,9 @@ def lower_jaxpr_to_module(
m.body.append(mlir_func)
sym_tab.insert(mlir_func)
func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
static_grid = [MLIR_DYNAMIC if b is None else b for b in grid]
static_grid = [
MLIR_DYNAMIC if b is pl_core.dynamic_grid_dim else b for b in grid
]
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
@@ -1021,7 +1023,6 @@ def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree
return KeyScalarBundle(scalars=load_ops)
lowering_rules[primitives.load_p] = _load_lowering_rule
skip_mlir_conversions.add(primitives.load_p)

View File

@@ -138,11 +138,13 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
# will do.
dynamic_grid_args_iter = iter(dynamic_grid_args)
grid = tuple(
a if a is not None else next(dynamic_grid_args_iter)
a if a is not pallas_core.dynamic_grid_dim
else next(dynamic_grid_args_iter)
for a in grid_mapping.grid
)
assert next(dynamic_grid_args_iter, None) is None
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
with grid_mapping.trace_env():
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
oi_map = {v: k for k, v in input_output_aliases}
@@ -330,7 +332,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
def _batch_block_mapping(grid_mapping: GridMapping, aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping | None) -> BlockMapping:
def _block_map_function(new_idx, *args):
@@ -345,11 +347,12 @@ def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
return tuple(indices)
i32_aval = jax_core.ShapedArray((), jnp.int32)
if block_mapping is None:
idx_avals = [i32_aval] * (len(grid) + 1)
idx_avals = [i32_aval] * (len(grid_mapping.grid) + 1)
else:
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
shape = aval.shape if block_mapping is None else block_mapping.block_shape
if dim is batching.not_mapped:
new_block_shape = shape
@@ -628,7 +631,7 @@ def _pallas_call_batching_rule(
# operands (the last in the list).
avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
batched_block_mappings = map(
partial(_batch_block_mapping, grid_mapping.grid),
partial(_batch_block_mapping, grid_mapping),
avals_to_batch,
all_dims[num_index_operands:],
block_mappings,
@@ -711,14 +714,16 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), jaxpr_in_tree)
debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals, debug)
if consts:
jaxpr = _hoist_consts_to_refs(jaxpr)
# Pad ``block_mappings`` to account for the hoisted constants.
grid_mapping = grid_mapping.replace(
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
num_constant_operands=len(consts),
)
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
if consts:
jaxpr = _hoist_consts_to_refs(jaxpr)
# Pad ``block_mappings`` to account for the hoisted constants.
grid_mapping = grid_mapping.replace(
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
num_constant_operands=len(consts),
)
return grid_mapping, jaxpr, consts, out_tree_thunk()
def _extract_function_name(f: Callable, name: str | None) -> str:

View File

@@ -56,38 +56,35 @@ def program_id_bind(*, axis: int):
grid_env = pallas_core.current_grid_env()
if grid_env:
return grid_env[axis].index
frame = pallas_core.axis_frame()
# Query the size of the axis to make sure it's a valid axis (and error
# otherwise).
_ = frame.size(axis)
return jax_core.Primitive.bind(program_id_p, axis=axis)
program_id_p.def_custom_bind(program_id_bind)
def _program_id_impl(*, axis: int):
grid_env = pallas_core.current_grid_env()
assert grid_env
return grid_env[axis].index
program_id_p.def_impl(_program_id_impl)
def _program_id_abstract_eval(**_):
return jax_core.ShapedArray((), jnp.int32)
program_id_p.def_abstract_eval(_program_id_abstract_eval)
num_programs_p = jax_core.Primitive("num_programs")
def num_programs(axis: int) -> jax.Array:
def num_programs(axis: int) -> int | jax.Array:
"""Returns the size of the grid along the given axis."""
return num_programs_p.bind(axis=axis)
@num_programs_p.def_custom_bind
def _num_programs_bind(*, axis: int):
# We might be using a local grid env
grid_env = pallas_core.current_grid_env()
if grid_env:
return jnp.asarray(grid_env[axis].size, dtype=jnp.int32)
return jax_core.Primitive.bind(num_programs_p, axis=axis)
@num_programs_p.def_impl
def _num_programs_impl(*, axis: int):
grid_env = pallas_core.current_grid_env()
assert grid_env
return jnp.asarray(grid_env[axis].size, dtype=jnp.int32)
return grid_env[axis].size
# Otherwise, we look up the size of the grid in the axis env
frame = pallas_core.axis_frame()
size = frame.size(axis)
if size is pallas_core.dynamic_grid_dim:
return jax_core.Primitive.bind(num_programs_p, axis=axis)
return size
@num_programs_p.def_abstract_eval
def _num_programs_abstract_eval(**_):

View File

@@ -339,6 +339,46 @@ class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
class PallasCallDynamicGridTest(PallasTPUTest):
def test_can_query_grid_statically_via_num_programs(self):
def kernel(_):
num_programs = pl.num_programs(0)
self.assertIsInstance(num_programs, int)
self.assertEqual(num_programs, 2)
pl.pallas_call(kernel, out_shape=None, grid=(2,))()
def test_can_query_grid_statically_via_num_programs_in_block_spec(self):
def kernel(*_):
pass
def x_index_map(_):
num_programs = pl.num_programs(0)
self.assertIsInstance(num_programs, int)
self.assertEqual(num_programs, 2)
return 0
pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec(x_index_map, (8, 128))],
out_shape=None,
grid=(2,),
)(jnp.ones((8, 128)))
def test_dynamic_grid_has_dynamic_size(self):
def kernel(_):
num_programs = pl.num_programs(0)
self.assertIsInstance(num_programs, int, msg=type(num_programs))
self.assertEqual(num_programs, 2)
num_programs = pl.num_programs(1)
self.assertIsInstance(num_programs, jax.Array)
@jax.jit
def outer(x):
pl.pallas_call(kernel, out_shape=None, grid=(2, x))()
outer(2)
def test_dynamic_grid(self):
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
@@ -496,7 +536,7 @@ class PallasCallDynamicGridTest(PallasTPUTest):
out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32),
)()
self.assertEqual(dynamic_kernel(4), 8)
self.assertEqual(dynamic_kernel(np.int32(4)), 8)
@parameterized.parameters(range(1, 4))
def test_vmap_num_programs(self, num_vmaps):
@@ -540,7 +580,7 @@ class PallasCallDynamicGridTest(PallasTPUTest):
)(x)
x = np.arange(4 * 8 * 128., dtype=np.int32).reshape((4 * 8, 128))
np.testing.assert_array_equal(dynamic_kernel(4, x), x[8:16])
np.testing.assert_array_equal(dynamic_kernel(np.int32(4), x), x[8:16])
class PallasCallInterpretDynamicGridTest(PallasCallDynamicGridTest):