Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[pallas] Improve some error messages and add API tests.
We make the following improvements:

* pytree structural disequality messages now attempt to localize the mismatch using tree_util.KeyPath.
* we generate a simpler error message for when `in_specs` is not a sequence, instead of the current PyTreeDef mismatch error.
* we generate an error message for when the index map function in a BlockSpec returns an unexpected number of results.
* added error localization to the existing shape polymorphism check that the block shapes are static.
* We check that the kernel function returns None. Without this we used to get `body_fun output and input must have same type structure` in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU, and `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0` on TPU.
* we check that the rank of the block_shape matches the rank of the overall array. Without this we used to get a `safe_zip` error. We also carry the pytree paths to localize the error.

To simplify the generation of the error messages we added a helper function `tree_util.equality_errors_pytreedef`, which is just like `tree_util.equality_errors` but takes `PyTreeDef` inputs rather than PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
parent f0e36d5083
commit a4a9499a40
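As an illustration (not part of this commit's diff; a minimal sketch assuming only the public jax.experimental.pallas API), the new kernel-return check surfaces as an eager ValueError instead of a backend-specific failure such as `assert len(jaxpr.outvars) == 0` or a Mosaic compile error:

    import jax
    import numpy as np
    from jax.experimental import pallas as pl

    def bad_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]
      return 0  # Pallas kernels must return None; this now raises a ValueError.

    x = np.arange(256, dtype=np.int32)
    try:
      pl.pallas_call(bad_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
    except ValueError as e:
      print(e)  # "The kernel function in a pallas_call should return None. ..."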
@@ -296,10 +296,12 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
def _convert_block_spec_to_block_mapping(
in_avals: Sequence[jax_core.ShapedArray],
block_spec: BlockSpec,
path: tree_util.KeyPath,
aval: jax_core.ShapedArray,
in_tree: Any,
grid: GridMappingGrid,
mapped_dims: tuple[int, ...],
what: str, # Used to localize error messages, e.g., {what}{path}
) -> BlockMapping | None:
if block_spec is no_block_spec:
return None
@@ -313,7 +315,13 @@ def _convert_block_spec_to_block_mapping(
mapped if s is None else s for s in block_shape)
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
with tracing_grid_env(grid, mapped_dims):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
if len(out_avals) != len(block_shape):
raise ValueError(
f"Index map for {what}{tree_util.keystr(path)} must return "
f"{len(aval.shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
)
@@ -327,42 +335,51 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))

def _check_static_ref_shape(ref: state.AbstractRef) -> state.AbstractRef:
shape = ref.shape
if not jax_core.is_constant_shape(shape):
# TODO(necula): thread the tree labels so that we can localize the error
raise ValueError("shape polymorphism for Pallas does not support "
f"dynamically-shaped blocks. Found block_shape: {shape}")
return ref

def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
def _get_memory_space(spec):
def _get_ref_avals(grid,
in_avals: Sequence[jax_core.ShapedArray],
in_specs: Sequence[BlockSpec],
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.ShapedArray],
out_specs: Sequence[BlockSpec],
out_paths: Sequence[tree_util.KeyPath]):
def make_ref_aval(aval: jax_core.ShapedArray,
spec: BlockSpec,
path: tree_util.KeyPath,
what: str) -> state.AbstractRef:
if spec is no_block_spec:
return None
return spec.memory_space
memory_space = None
block_shape = None
else:
memory_space = spec.memory_space
block_shape = spec.block_shape

ref_aval = AbstractMemoryRef(aval, memory_space)
if block_shape is not None:
if len(ref_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
f"must have the same number of dimensions as the array shape {ref_aval.shape}"
)
trimmed_block_shape = tuple(s for s in block_shape if s is not None)
ref_aval = ref_aval.update(
inner_aval=ref_aval.inner_aval.update(shape=trimmed_block_shape))

if not jax_core.is_constant_shape(ref_aval.shape):
raise ValueError(
"shape polymorphism for Pallas does not support "
"dynamically-shaped blocks. "
f"{what}{tree_util.keystr(path)} has block_shape: {ref_aval.shape}")
return ref_aval

in_ref_avals = [
AbstractMemoryRef(aval, _get_memory_space(in_spec))
for aval, in_spec in zip(in_avals, in_specs)
make_ref_aval(aval, in_spec, in_path, "input")
for aval, in_spec, in_path in zip(in_avals, in_specs, in_paths)
]
out_ref_avals = [
AbstractMemoryRef(aval, _get_memory_space(out_spec))
for aval, out_spec in zip(out_avals, out_specs)
make_ref_aval(aval, out_spec, out_path, "output")
for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
]
if grid is None:
in_specs = [None] * len(in_avals)
out_specs = [None] * len(out_avals)
tiled_in_ref_avals = [
_check_static_ref_shape(aval if in_spec is no_block_spec
else _tile_ref(aval, in_spec.block_shape))
for aval, in_spec in zip(in_ref_avals, in_specs)
]
tiled_out_ref_avals = [
_check_static_ref_shape(aval if out_spec is no_block_spec
else _tile_ref(aval, out_spec.block_shape))
for aval, out_spec in zip(out_ref_avals, out_specs)
]
return in_specs, tiled_in_ref_avals, out_specs, tiled_out_ref_avals
return in_specs, in_ref_avals, out_specs, out_ref_avals

class NoBlockSpec:
pass
@@ -386,6 +403,8 @@ class GridSpec:
# Be more lenient for in/out_specs
if isinstance(in_specs, list):
in_specs = tuple(in_specs)
elif in_specs is not no_block_spec and not isinstance(in_specs, Sequence):
raise ValueError(f"`in_specs` must be a tuple or a list. Found: {in_specs}")
if isinstance(out_specs, list):
out_specs = tuple(out_specs)

@@ -410,20 +429,20 @@ class GridSpec:
flat_in_specs = self.in_specs
if self.in_specs_tree != in_tree:
raise ValueError(
"Pytree specs for arguments and `in_specs` must match: "
f"{in_tree} vs. {self.in_specs_tree}")
pytreedef_mismatch_err_msg("`in_specs`", self.in_specs_tree,
"inputs", in_tree))
if self.out_specs is no_block_spec:
flat_out_specs = [no_block_spec] * len(out_avals)
else:
flat_out_specs = self.out_specs
if self.out_specs_tree != out_tree:
raise ValueError(
"Pytree specs for `out_shape` and `out_specs` must match: "
f"{out_tree} vs. {self.out_specs_tree}")
pytreedef_mismatch_err_msg("`out_specs`", self.out_specs_tree,
"`out_shape`", out_tree))
return flat_in_specs, flat_out_specs

def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
@@ -432,8 +451,8 @@ class GridSpec:
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
self.grid, in_avals, flat_in_specs, out_avals,
flat_out_specs)
self.grid, in_avals, flat_in_specs, in_paths,
out_avals, flat_out_specs, out_paths)
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
@@ -444,8 +463,10 @@ class GridSpec:
in_tree=grid_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="input",
),
in_specs,
in_paths,
in_ref_avals,
)
out_block_mappings = map(
@@ -455,8 +476,10 @@ class GridSpec:
in_tree=grid_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="output",
),
out_specs,
out_paths,
out_ref_avals,
)
grid_mapping = GridMapping(
@@ -480,3 +503,18 @@ class GridSpec:
static_self = copy.copy(self)
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds

def pytreedef_mismatch_err_msg(
what1: str, tree1: tree_util.PyTreeDef,
what2: str, tree2: tree_util.PyTreeDef) -> str:
errs = list(tree_util.equality_errors_pytreedef(tree1, tree2))
msg = []
msg.append(
f"Pytree for {what1} and {what2} do not match. "
f"There are {len(errs)} mismatches, including:")
for path, thing1, thing2, explanation in errs:
where = f"at {tree_util.keystr(path)}, " if path else ""
msg.append(
f" * {where}{what1} is a {thing1} but"
f" {what2} is a {thing2}, so {explanation}")
return "\n".join(msg)

@@ -170,7 +170,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
self.scratch_shapes = tuple(scratch_shapes)

def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
@@ -189,8 +189,8 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
in_avals, in_avals_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = (
pallas_core._get_ref_avals(
self.grid, in_avals, flat_in_specs,
out_avals, flat_out_specs))
self.grid, in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
out_avals, flat_out_specs, out_paths))
scalar_ref_avals = [
AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
TPUMemorySpace.SMEM)
@@ -207,8 +207,10 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
in_tree=index_map_in_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="input",
),
in_specs,
in_paths[num_flat_scalar_prefetch:],
in_ref_avals,
)
out_block_mappings = map(
@@ -218,8 +220,10 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
in_tree=index_map_in_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="output",
),
out_specs,
out_paths,
out_ref_avals,
)
grid_mapping = GridMapping(

@@ -22,10 +22,10 @@ import string
from typing import Any

import jax
from jax import core as jax_core
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dtypes
@@ -204,6 +204,11 @@ def ir_constant(x, mlir_type=None):
lowering_rules = {}
skip_mlir_conversions = set()

def _get_aval_physical_dtype_shape(aval):
dtype_physical_shape = jax_core.physical_aval(aval).shape[
len(aval.shape) :
]
return dtype_physical_shape

def _get_arg_type(
aval,
@@ -427,6 +432,7 @@ def lower_jaxpr_to_module(
mlir_func = lower_jaxpr_to_transform_func(
ctx,
bm.index_map_jaxpr.jaxpr,
aval,
name=func_name,
mosaic_grid_mapping=mosaic_grid_mapping,
)
@@ -434,6 +440,9 @@ def lower_jaxpr_to_module(
block_shape = [
1 if b is pl_core.mapped else b for b in bm.block_shape
]
# If we have an extended dtype, we need to add the block shape for the
# remaining physical dtype.
block_shape += list(_get_aval_physical_dtype_shape(aval.inner_aval))
window_shape = ir.DenseI64ArrayAttr.get(block_shape)
block_params = dict(
window_bounds=window_shape,
@@ -469,6 +478,7 @@ def lower_jaxpr_to_module(
def lower_jaxpr_to_transform_func(
ctx: ir.Context,
jaxpr: jax_core.Jaxpr,
aval: jax_core.AbstractValue,
*,
name: str,
mosaic_grid_mapping: MosaicGridMapping,
@@ -503,8 +513,16 @@ def lower_jaxpr_to_transform_func(
mesh_context=mesh_context,
traceback_caches=mlir.TracebackCaches(),
)
return jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
*scalar_prefetch)
assert isinstance(aval, state.AbstractRef), aval
# If we have an extended dtype, we need to add 0s for the block indices
# for the remaining physical dtype.
out += [
ir_constant(0, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
] * len(_get_aval_physical_dtype_shape(aval.inner_aval))
return out

body_func.__name__ = name
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
try:

@@ -23,7 +23,6 @@ from typing import Any
import jax
from jax import api_util
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import checkify
from jax._src import config
@@ -31,6 +30,7 @@ from jax._src import core as jax_core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -45,6 +45,7 @@ from jax._src.util import (
safe_zip,
split_list,
tuple_insert,
unzip2,
weakref_lru_cache,
)
import jax.numpy as jnp
@@ -848,6 +849,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
# for the new error inputs and outputs.
scalar_avals = map(checkify.get_shaped_aval, scalars)
error_block_specs = [no_block_spec] * num_err_vals
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
grid_avals = [
jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
# TODO(justinfu): Place these in device-specific scalar memory.
@@ -862,8 +864,9 @@ def pallas_call_checkify_rule(error: checkify.Error,
(*grid_avals, *scalar_ref_avals),
in_tree=grid_tree,
grid=grid_mapping.grid,
mapped_dims=grid_mapping.mapped_dims),
error_block_specs, error_memref_aval)
mapped_dims=grid_mapping.mapped_dims,
what="error"),
error_block_specs, error_paths, error_memref_aval)
input_block_mappings, output_block_mappings = split_list(
grid_mapping.block_mappings, [num_kernel_inputs,])
grid_mapping_with_error = grid_mapping.replace(
@@ -893,10 +896,16 @@ def pallas_call_checkify_rule(error: checkify.Error,
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule

@weakref_lru_cache
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
flat_out_avals, in_tree, out_tree, interpret: bool):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree,
flat_out_avals, out_tree)
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
flat_in_avals: Sequence[jax_core.AbstractValue],
flat_out_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
interpret: bool):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
if interpret:
avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals)
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
@@ -1058,19 +1067,25 @@ def pallas_call(
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
if isinstance(out_shape, list):
out_shape = tuple(out_shape)
flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore
for x in flat_out_shapes]
@jax.jit
def wrapped(*args):
flat_args, in_tree = tree_util.tree_flatten(args)
flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
in_paths, flat_args = unzip2(flat_args_with_paths)
flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
out_tree, interpret=interpret)
grid_mapping, jaxpr, consts, f_out_tree = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree, in_paths,
out_tree, out_paths, interpret=interpret)
if f_out_tree != tree_util.tree_flatten(None)[1]:
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {f_out_tree}")
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
jaxpr=jaxpr, name=name,

@@ -1206,11 +1206,7 @@ def explain_tracing_cache_miss(
p(f" never seen input pytree{in_tree_str}")
dont_match = [t for t, *_ in seen_keys if t != in_tree]
closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
# TODO(mattjj): make equality_errors not print type name, avoid metaclass
leaf = type('LeafMeta', (type,), dict(__repr__=lambda _: 'leaf'))('Leaf', (), {})()
this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves)
close_dummy = tree_unflatten(closest_tree, [leaf] * closest_tree.num_leaves) # type: ignore
errs = list(tree_util.equality_errors(this_dummy, close_dummy))
errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type]
p(f" closest seen input pytree has {len(errs)} mismatches, including:")
for path, thing1, thing2, explanation in errs:
fst, *path = path # type: ignore

@@ -43,7 +43,6 @@ from jax._src import config
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src.tree_util import tree_unflatten, keystr
from jax._src import util
from jax._src.sharding_impls import is_unspecified_or_auto
from jax._src.layout import Layout
@@ -590,11 +589,7 @@ class Compiled(Stage):
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
if in_tree != params.in_tree:
leaf = PytreeLeaf()
this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves)
other_dummy = tree_unflatten(
params.in_tree, [leaf] * params.in_tree.num_leaves)
errs = list(tree_util.equality_errors(this_dummy, other_dummy))
errs = list(tree_util.equality_errors_pytreedef(in_tree, params.in_tree))
msg = []
msg.append(
"Function compiled with input pytree does not match the input pytree"
@@ -603,7 +598,7 @@ class Compiled(Stage):
fst, *rest = path
base = ['args', 'kwargs'][fst.idx]
msg.append(
f" * at {base}{keystr(tuple(rest))}, seen {thing2} but now"
f" * at {base}{tree_util.keystr(tuple(rest))}, seen {thing2} but now"
f" given {thing1}, so {explanation}")
raise TypeError('\n'.join(msg))
try:
@@ -641,10 +636,6 @@ class Compiled(Stage):
return self._call(*args, **kwargs)

class PytreeLeaf:
def __repr__(self): return "pytree leaf"

class Lowered(Stage):
"""Lowering of a function specialized to argument types and values.

@@ -621,7 +621,7 @@ def equality_errors(
"""Helper to describe structural differences between two pytrees.

Args:
tree1, tree2: pytrees to compare.
tree1, tree2: pytrees known to have different structure.

Usage:

@@ -636,6 +636,15 @@ def equality_errors(
"""
yield from _equality_errors((), tree1, tree2, is_leaf)

def equality_errors_pytreedef(
tree1: PyTreeDef,
tree2: PyTreeDef) -> Iterable[tuple[KeyPath, str, str, str]]:
"""Like `equality_errors` but invoked on PyTreeDef."""
# TODO(mattjj): make equality_errors not print type name, avoid metaclass
leaf = type("LeafMeta", (type,), dict(__repr__=lambda _: "pytree leaf"))("Leaf", (), {})()
return equality_errors(tree_unflatten(tree1, [leaf] * tree1.num_leaves),
tree_unflatten(tree2, [leaf] * tree2.num_leaves))

# TODO(mattjj): maybe share some logic with _prefix_error?
def _equality_errors(path, t1, t2, is_leaf):
# If both are leaves, this isn't a structure equality error.

@@ -16,6 +16,7 @@ import contextlib
import functools
import itertools
import os
import re
import sys

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
@@ -228,6 +229,27 @@ class PallasCallTest(PallasTest):
for i in range(5):
np.testing.assert_allclose(index(x, i), x[i])

def test_pallas_call_no_outputs(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref: None, ())
self.assertAllClose((), f(a))

def test_pallas_call_out_shape_is_singleton_tuple(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=(a,))
res = f(a)
self.assertIsInstance(res, tuple)
self.assertLen(res, 1)

def test_pallas_call_out_shape_is_list(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=[a])
res = f(a)
# TODO(necula): we normalize out_shape to a tuple, we shouldn't.
self.assertIsInstance(res, tuple)

def test_hoisted_consts(self):
# See https://github.com/google/jax/issues/21557.
x = jnp.zeros(32)
@@ -441,6 +463,112 @@ class PallasCallInterpreterTest(PallasCallTest):
INTERPRET = True

class ApiErrorTest(PallasTest):

def test_pallas_kernel_args_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref: None, # Missing o_ref
out_shape=a)
with self.assertRaisesRegex(
TypeError,
"takes 1 positional argument but 2 were given"):
f(a)

@parameterized.named_parameters(
("array", 0),
("empty_tuple", ())
)
def test_pallas_call_error_kernel_returns_something(self, returns):
a = np.arange(256, dtype=np.int32)
# The kernel should not return anything
f = self.pallas_call(lambda x_ref, o1_ref, o2_ref: returns,
out_shape=(a, a))
with self.assertRaisesRegex(
ValueError,
"The kernel function in a pallas_call should return None"):
f(a)

def test_pallas_call_in_specs_not_a_sequence(self):
a = np.arange(256, dtype=np.int32)
with self.assertRaisesRegex(
ValueError,
"`in_specs` must be a tuple or a list"):
_ = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=pl.BlockSpec((4,), lambda: 0))

def test_pallas_call_in_specs_mismatch_inputs(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: 0),
pl.BlockSpec((4,), lambda: 0)])
with self.assertRaisesRegex(
ValueError,
re.compile("Pytree for `in_specs` and inputs do not match. "
"There are 1 mismatches, including:"
".* at \\[1\\], `in_specs` is a pytree leaf but "
"inputs is a.*", re.DOTALL)):
f(a, dict(a=a))

def test_pallas_call_index_map_wrong_number_of_arguments(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda i, j: 0)])
with self.assertRaisesRegex(
TypeError,
"missing 2 required positional arguments: 'i' and 'j'"):
f(a)

def test_pallas_call_index_map_wrong_number_of_results(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)

def test_pallas_call_out_specs_mismatch_shape(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=[a, a],
out_specs=[pl.BlockSpec((6,), lambda i: i)])
with self.assertRaisesRegex(
ValueError,
re.compile("Pytree for `out_specs` and `out_shape` do not match. There are 1 mismatches, including:"
".* `out_specs` is a tuple of length 1 but `out_shape` is a tuple of length 2.*", re.DOTALL)):
f(a)

def test_pallas_call_block_shape_ndim_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=[a],
in_specs=[pl.BlockSpec((1, 1), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Block shape for input\\[0\\] .* must have the same number of dimensions as the "
"array shape"):
f(a)

f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=[a],
out_specs=[pl.BlockSpec((1, 1), lambda: 0)])
with self.assertRaisesRegex(
ValueError,
"Block shape for output\\[0\\] .* must have the same number of dimensions as the "
"array shape"):
f(a)

class ApiErrorInterpreterTest(ApiErrorTest):
INTERPRET = True

class PallasControlFlowTest(PallasTest):

def setUp(self):

@@ -115,6 +115,67 @@ class PallasCallScalarPrefetchTest(PallasTPUTest):
)(s, x)
np.testing.assert_array_equal(out, x)

def test_block_spec_with_wrong_block_shape_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]

x = jnp.ones((16, 128))
with self.assertRaisesRegex(
ValueError,
'Block shape .* must have the same number of dimensions as the array shape .*'):
_ = pl.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((128,), lambda i: (i, 0))], # WRONG
out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0)),
grid=(2,),
),
out_shape=x,
interpret=self.interpret,
)(x)

def test_block_spec_with_index_map_that_accepts_wrong_number_of_args_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]

x = jnp.ones((16, 128))
with self.assertRaisesRegex(
TypeError,
'missing 1 required positional argument: \'j\''):
_ = pl.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((8, 128,), lambda i, j: (i, 0))], # WRONG
out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0),),
grid=(2,),
),
out_shape=x,
interpret=self.interpret
)(x)

def test_block_spec_with_index_map_returns_wrong_number_of_values_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]

x = jnp.ones((16, 128))
with self.assertRaisesRegex(
ValueError,
r'Index map for input\[0\] must return 2 values to match block_shape=\(8, 128\).'
' Currently returning 1 values.'):
_ = pl.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((8, 128,), lambda i: (i,))], # WRONG
out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
grid=(2,),
),
out_shape=x,
interpret=self.interpret,
)(x)

def test_vmap_scalar_prefetch(self):
def body(_, x_ref, o_ref):
o_ref[...] = x_ref[...]
@@ -138,8 +199,7 @@ class PallasCallScalarPrefetchTest(PallasTPUTest):
out_specs=pl.BlockSpec(
(x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
),
grid=8,
),
grid=8),
interpret=self.interpret,
)(s, x)
np.testing.assert_allclose(
@@ -363,7 +423,7 @@ class PallasCallDynamicGridTest(PallasTPUTest):
num_programs = pl.num_programs(0)
self.assertIsInstance(num_programs, int)
self.assertEqual(num_programs, 2)
return 0
return 0, 0
pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec((8, 128), x_index_map)],

@@ -184,7 +184,7 @@ class BlockInvarianceTest(parameterized.TestCase):

def make_kernel_body(index_map):
def body(key_ref, o_ref):
key = key_ref[0, 0]
key = key_ref[...]
samples = plrandom.sample_block(
jax.random.uniform,
key,
@@ -199,9 +199,7 @@ class BlockInvarianceTest(parameterized.TestCase):

global_key = jax_random.key(0, impl="pallas_tpu")
o_shape = jnp.ones((64, 512), dtype=jnp.float32)
key_spec = pl.BlockSpec(
(1, 1), lambda i, j: (0, 0), memory_space=pltpu.TPUMemorySpace.SMEM
)
key_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)
out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j))
result_16x128 = pl.pallas_call(
make_kernel_body(index_map=lambda i, j: (i, j)),