2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2021-09-08 09:00:23 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2024-06-26 14:44:52 -04:00
|
|
|
from collections.abc import Callable, Iterable, Sequence
|
2022-05-06 16:28:24 +01:00
|
|
|
import inspect
|
2021-09-08 09:00:23 -07:00
|
|
|
import operator
|
2024-04-23 14:49:11 -07:00
|
|
|
from functools import partial, lru_cache
|
[better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.
When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` as None, and then the whole
debug_info ended up being `None`, throwing away even
available information.
We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-20 17:17:44 +01:00
|
|
|
import re
|
2024-06-26 14:44:52 -04:00
|
|
|
from typing import Any
|
2021-09-08 09:00:23 -07:00
|
|
|
|
2022-12-16 20:59:41 -08:00
|
|
|
from jax._src import core
|
2024-12-18 22:11:25 +00:00
|
|
|
from jax._src import config
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax._src import dtypes
|
2024-12-18 22:11:25 +00:00
|
|
|
from jax._src.state.types import AbstractRef
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax._src.tree_util import (
|
2024-04-23 17:37:52 -07:00
|
|
|
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
|
2025-02-07 10:15:47 +02:00
|
|
|
treedef_children, generate_key_paths, broadcast_prefix,
|
2023-03-29 14:54:24 -07:00
|
|
|
prefix_errors)
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax._src.tree_util import _replace_nones
|
2022-12-20 14:49:27 -08:00
|
|
|
from jax._src import linear_util as lu
|
make mlir arg and result names work with pmap
This is a follow-up on #15080 to restore (and indeed fix!) how pmap builds a
jaxpr with debug info (i.e. parameter names and result paths). The difference
with the machinery in #15080 is just to deal with pmap being final-style (i.e.
build the jaxpr at the last second, well after pytrees have been flattened away
and transformations have been applied), whereas the machinery for pjit in
#15080 is initial-style (i.e. the jaxpr is formed immediately). As you can
imagine, plumbing for the former is a bit more long-range and subtle.
The main idea here is that we need to annotate and maintain debug info on the
lu.WrappedFun instance, which we first form at the api.py level, then pass
through transformations (which can either update or drop debug info), then
finally hand off to the impl rule to be traced to a jaxpr. It makes sense as an
annotation, parallel with the in_type annotation used for dynamic shapes,
because the debug info has to be updated as transformations are applied, since
they might e.g. add tangent inputs and outputs.
In more detail: with an initial-style higher-order primitive (like pjit), a
jaxpr is formed immediately. Transformations, like autodiff, are
jaxpr-to-jaxpr, and so those transformations (like ad.jvp_jaxpr) need to return
a new jaxpr either with updated debug info or no debug info at all. (The initial
implementation in #15080 doesn't provide updated debug info in any of those
jaxpr-to-jaxpr transformation functions, so the debug info is only applied to
the jaxpr and then lowered to MLIR when the pjit is at the top level.)
For final-style, like pmap here, instead of transformations being
jaxpr-to-jaxpr, they're WrappedFun-to-WrappedFun. And so, analogously,
transformations, like ad.JVPTrace.process_map, would need to produce a
WrappedFun with updated debug info or no debug info at all. (Also analogously
to #15080, this PR only implements enough for the debug info to be preserved
for top-level pmaps.)
This PR doesn't yet delete the trace-time debug info in partial_eval.py. But
that'll happen too!
2023-03-17 17:45:41 -07:00
|
|
|
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
|
2024-04-23 14:49:11 -07:00
|
|
|
Unhashable, safe_zip)
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax._src import traceback_util
|
2021-09-08 09:00:23 -07:00
|
|
|
# Register this file with traceback_util's exclusion list.
traceback_util.register_exclusion(__file__)

# Rebind `map` to `safe_map` for the rest of this module; every later use of
# `map` in this file therefore goes through `safe_map`, not the builtin.
map = safe_map
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def _ensure_index(x: Any) -> int | tuple[int, ...]:
|
2021-09-08 09:00:23 -07:00
|
|
|
"""Ensure x is either an index or a tuple of indices."""
|
2022-09-12 12:28:14 -07:00
|
|
|
x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
|
2021-09-08 09:00:23 -07:00
|
|
|
try:
|
|
|
|
return operator.index(x)
|
|
|
|
except TypeError:
|
|
|
|
return tuple(map(operator.index, x))
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
|
2021-09-08 09:00:23 -07:00
|
|
|
"""Convert x to a tuple of indices."""
|
2022-09-12 12:28:14 -07:00
|
|
|
x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
|
2021-09-08 09:00:23 -07:00
|
|
|
try:
|
|
|
|
return (operator.index(x),)
|
|
|
|
except TypeError:
|
|
|
|
return tuple(map(operator.index, x))
|
|
|
|
|
|
|
|
def _ensure_str(x: str) -> str:
|
|
|
|
if not isinstance(x, str):
|
|
|
|
raise TypeError(f"argument is not a string: {x}")
|
|
|
|
return x
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
|
2021-09-08 09:00:23 -07:00
|
|
|
"""Convert x to a tuple of strings."""
|
|
|
|
if isinstance(x, str):
|
|
|
|
return (x,)
|
|
|
|
else:
|
|
|
|
return tuple(map(_ensure_str, x))
|
|
|
|
|
2024-11-12 22:39:26 -08:00
|
|
|
@lu.transformation_with_aux2
def flatten_fun(f: Callable, store: lu.Store,
                in_tree: PyTreeDef, *args_flat):
  """Call ``f`` on flat inputs and return its output flattened.

  ``args_flat`` is unflattened with ``in_tree`` into an ``(args, kwargs)``
  pair, ``f(*args, **kwargs)`` is called, and the result is flattened. The
  output tree definition is recorded in ``store``; the flat leaves are
  returned.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  out_leaves, out_tree = tree_flatten(f(*positional, **keyword))
  store.store(out_tree)
  return out_leaves
|
2021-09-08 09:00:23 -07:00
|
|
|
|
|
|
|
def apply_flat_fun(fun, io_tree, *py_args):
  """Apply a flat function to pytree positional arguments.

  Args:
    fun: callable taking flat leaves and returning flat output leaves.
    io_tree: pair ``(in_tree_expected, out_tree)`` of tree definitions.
    *py_args: positional arguments; flattened together with an empty kwargs
      dict, i.e. as ``(py_args, {})``.

  Returns:
    The output of ``fun`` unflattened with ``out_tree``.

  Raises:
    TypeError: if the structure of ``(py_args, {})`` differs from
      ``in_tree_expected``.
  """
  in_tree_expected, out_tree = io_tree
  leaves, in_tree = tree_flatten((py_args, {}))
  if in_tree != in_tree_expected:
    raise TypeError(f"Expected {in_tree_expected}, got {in_tree}")
  return tree_unflatten(out_tree, fun(*leaves))
|
|
|
|
|
2024-11-12 22:39:26 -08:00
|
|
|
@lu.transformation_with_aux2
def flatten_fun_nokwargs(f: Callable, store: lu.Store,
                         in_tree: PyTreeDef, *args_flat):
  """Call ``f`` on flat positional inputs and return its output flattened.

  ``args_flat`` is unflattened with ``in_tree`` into the positional arguments
  of ``f`` (no keyword arguments). The output tree definition is recorded in
  ``store``; the flat leaves are returned.
  """
  positional = tree_unflatten(in_tree, args_flat)
  out_leaves, out_tree = tree_flatten(f(*positional))
  store.store(out_tree)
  return out_leaves
|
2021-09-08 09:00:23 -07:00
|
|
|
|
|
|
|
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Apply a flat function to a pytree of positional arguments.

  Args:
    fun: callable taking flat leaves and returning flat output leaves.
    io_tree: pair ``(in_tree_expected, out_tree)`` of tree definitions.
    py_args: the arguments pytree, flattened as-is.

  Returns:
    The output of ``fun`` unflattened with ``out_tree``.

  Raises:
    TypeError: if the structure of ``py_args`` differs from
      ``in_tree_expected``.
  """
  in_tree_expected, out_tree = io_tree
  leaves, in_tree = tree_flatten(py_args)
  if in_tree != in_tree_expected:
    raise TypeError(f"Expected {in_tree_expected}, got {in_tree}")
  return tree_unflatten(out_tree, fun(*leaves))
|
|
|
|
|
2024-11-12 22:39:26 -08:00
|
|
|
@lu.transformation_with_aux2
def flatten_fun_nokwargs2(f, store, in_tree, *args_flat):
  """Like a no-kwargs flatten wrapper, for functions returning ``(ans, aux)``.

  ``args_flat`` is unflattened with ``in_tree`` and ``f`` is called on the
  result. ``f`` must return a two-element list or tuple; both elements are
  flattened separately. The pair of output tree definitions is recorded in
  ``store`` and the pair of flat leaf lists is returned.

  Raises:
    TypeError: if ``f`` does not return a two-element list or tuple.
  """
  py_args = tree_unflatten(in_tree, args_flat)
  pair = f(*py_args)
  is_two_element = isinstance(pair, (list, tuple)) and len(pair) == 2
  if not is_two_element:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(pair)} with value {pair!r}")
  primary, auxiliary = pair
  primary_leaves, primary_tree = tree_flatten(primary)
  auxiliary_leaves, auxiliary_tree = tree_flatten(auxiliary)
  store.store((primary_tree, auxiliary_tree))
  return primary_leaves, auxiliary_leaves
|
2021-09-08 09:00:23 -07:00
|
|
|
|
2022-01-25 14:35:23 -08:00
|
|
|
class _HashableWithStrictTypeEquality:
|
|
|
|
"""Box object used when comparing static arguments as a jit key.
|
|
|
|
|
|
|
|
Requires exact type equality using `is` and value equality."""
|
|
|
|
__slots__ = ["val"]
|
|
|
|
|
|
|
|
def __init__(self, val):
|
|
|
|
self.val = val
|
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
return hash(self.val)
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
return type(self.val) is type(other.val) and self.val == other.val
|
|
|
|
|
2022-05-06 16:28:24 +01:00
|
|
|
_POSITIONAL_ARGUMENTS = (
|
|
|
|
inspect.Parameter.POSITIONAL_ONLY,
|
|
|
|
inspect.Parameter.POSITIONAL_OR_KEYWORD
|
|
|
|
)
|
|
|
|
|
2024-03-20 14:32:25 -07:00
|
|
|
def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None:
|
2022-05-06 16:28:24 +01:00
|
|
|
"""
|
|
|
|
Validate that the argnums are sensible for a given function.
|
|
|
|
|
|
|
|
For functions that accept a variable number of positions arguments
|
|
|
|
(`f(..., *args)`) all positive argnums are considered valid.
|
|
|
|
"""
|
|
|
|
n_pos_args = 0
|
|
|
|
for param in sig.parameters.values():
|
|
|
|
if param.kind in _POSITIONAL_ARGUMENTS:
|
|
|
|
n_pos_args += 1
|
|
|
|
|
|
|
|
elif param.kind is inspect.Parameter.VAR_POSITIONAL:
|
|
|
|
# We can have any number of positional arguments
|
|
|
|
return
|
|
|
|
|
|
|
|
if argnums and (-min(argnums) > n_pos_args or max(argnums) >= n_pos_args):
|
2024-04-12 13:08:31 -07:00
|
|
|
raise ValueError(f"Jitted function has {argnums_name}={argnums}, "
|
|
|
|
f"but only accepts {n_pos_args} positional arguments.")
|
2022-05-06 16:28:24 +01:00
|
|
|
|
|
|
|
_INVALID_KEYWORD_ARGUMENTS = (
|
|
|
|
inspect.Parameter.POSITIONAL_ONLY,
|
|
|
|
inspect.Parameter.VAR_POSITIONAL
|
|
|
|
)
|
|
|
|
|
2024-03-20 14:32:25 -07:00
|
|
|
|
2022-05-06 16:28:24 +01:00
|
|
|
_KEYWORD_ARGUMENTS = (
|
|
|
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
|
|
inspect.Parameter.KEYWORD_ONLY,
|
|
|
|
)
|
2024-03-20 14:32:25 -07:00
|
|
|
def _validate_argnames(
|
2024-04-08 21:34:26 -07:00
|
|
|
sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str
|
2024-03-20 14:32:25 -07:00
|
|
|
) -> None:
|
2022-05-06 16:28:24 +01:00
|
|
|
"""
|
|
|
|
Validate that the argnames are sensible for a given function.
|
|
|
|
|
|
|
|
For functions that accept a variable keyword arguments
|
|
|
|
(`f(..., **kwargs)`) all argnames are considered valid except those
|
|
|
|
marked as position-only (`f(pos_only, /, ...)`).
|
|
|
|
"""
|
|
|
|
var_kwargs = False
|
2023-06-23 15:11:37 -07:00
|
|
|
valid_kwargs: set[str] = set()
|
|
|
|
invalid_kwargs: set[str] = set()
|
2022-05-06 16:28:24 +01:00
|
|
|
for param_name, param in sig.parameters.items():
|
|
|
|
if param.kind in _KEYWORD_ARGUMENTS:
|
|
|
|
valid_kwargs.add(param_name)
|
|
|
|
|
|
|
|
elif param.kind is inspect.Parameter.VAR_KEYWORD:
|
|
|
|
var_kwargs = True
|
|
|
|
|
|
|
|
elif param.kind in _INVALID_KEYWORD_ARGUMENTS:
|
|
|
|
invalid_kwargs.add(param_name)
|
|
|
|
|
|
|
|
# Check whether any kwargs are invalid due to position only
|
2024-04-12 13:08:31 -07:00
|
|
|
if invalid_argnames := (invalid_kwargs & set(argnames)):
|
|
|
|
raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
|
|
|
|
f"in {argnames_name}. These are positional-only")
|
2022-05-06 16:28:24 +01:00
|
|
|
|
|
|
|
# Takes any kwargs
|
|
|
|
if var_kwargs:
|
|
|
|
return
|
|
|
|
|
|
|
|
# Check that all argnames exist on function
|
2024-04-12 13:08:31 -07:00
|
|
|
if invalid_argnames := (set(argnames) - valid_kwargs):
|
|
|
|
raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
|
|
|
|
f"in {argnames_name}. Function does not take these args.")
|
2022-05-06 16:28:24 +01:00
|
|
|
|
|
|
|
|
2021-07-19 13:11:38 -04:00
|
|
|
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
  """Split ``args`` into dynamic and fixed parts and partially apply ``f``.

  Args:
    f: the function to wrap; its ``__name__`` is used in the error message.
    dyn_argnums: index or sequence of indices of the arguments that stay
      dynamic. Negative indices are resolved against ``len(args)``.
    args: the full positional argument list.
    require_static_args_hashable: if true, every non-dynamic argument must be
      hashable and is boxed in ``_HashableWithStrictTypeEquality``; otherwise
      non-dynamic arguments are boxed in ``Unhashable``.

  Returns:
    A pair ``(partial_f, dyn_args)`` where ``partial_f`` is the result of
    ``_argnums_partial`` and ``dyn_args`` are the arguments selected by
    ``dyn_argnums``.

  Raises:
    ValueError: if hashability is required and a fixed argument is not
      hashable.
  """
  dyn_argnums = _ensure_index_tuple(dyn_argnums)
  dyn_argnums = _ensure_inbounds(False, len(args), dyn_argnums)
  static_items = [(i, arg) for i, arg in enumerate(args)
                  if i not in dyn_argnums]
  if not require_static_args_hashable:
    fixed_args = [Unhashable(arg) for _, arg in static_items]
  else:
    fixed_args = []
    for i, arg in static_items:
      if not is_hashable(arg):
        raise ValueError(
            "Non-hashable static arguments are not supported, as this can lead "
            f"to unexpected cache-misses. Static argument (index {i}) of type "
            f"{type(arg)} for function {f.__name__} is non-hashable.")
      fixed_args.append(_HashableWithStrictTypeEquality(arg))
  dyn_args = tuple(args[i] for i in dyn_argnums)
  partial_f = _argnums_partial(f, dyn_argnums, tuple(fixed_args))
  return partial_f, dyn_args
|
2021-09-08 09:00:23 -07:00
|
|
|
|
2025-01-23 11:38:06 -08:00
|
|
|
|
|
|
|
def prepend_static_args(f, static_args):
  """Wrap ``f`` so that ``static_args`` are passed before its other arguments.

  Each static argument is boxed in ``Unhashable`` before being handed to
  ``_prepend_static_args``.
  """
  boxed_args = tuple([Unhashable(arg) for arg in static_args])
  return _prepend_static_args(f, boxed_args)
|
|
|
|
|
|
|
|
|
|
|
|
@lu.transformation2
def _prepend_static_args(f, static_args, *args, **kwargs):
  """Call ``f`` with the unboxed ``static_args`` followed by ``args``.

  ``static_args`` holds box objects; the ``val`` attribute of each is what is
  passed to ``f``.
  """
  unboxed = tuple(box.val for box in static_args)
  return f(*unboxed, *args, **kwargs)
|
|
|
|
|
|
|
|
|
2022-05-11 11:18:25 -07:00
|
|
|
def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int]
|
2023-06-23 15:11:37 -07:00
|
|
|
) -> tuple[int, ...]:
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
"""Ensure argnum is within bounds. Also resolves negative argnums."""
|
2022-05-11 11:18:25 -07:00
|
|
|
result = []
|
|
|
|
for i in argnums:
|
|
|
|
if i >= num_args and allow_invalid: continue
|
|
|
|
if not -num_args <= i < num_args:
|
|
|
|
raise ValueError(
|
|
|
|
"Positional argument indices, e.g. for `static_argnums`, must have "
|
|
|
|
"value greater than or equal to -len(args) and less than len(args), "
|
|
|
|
f"but got value {i} for len(args) == {num_args}.")
|
2022-05-06 16:28:24 +01:00
|
|
|
result.append(i % num_args) # Resolve negative
|
2022-05-11 11:18:25 -07:00
|
|
|
return tuple(result)
|
|
|
|
|
2022-05-06 16:28:24 +01:00
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
                           args: tuple[Any, ...], *, allow_invalid: bool):
  """Version of ``argnums_partial`` that checks hashability of static_argnums.

  Args:
    f: the wrapped function; its ``__name__`` is used in the error message.
    static_argnums: indices of the positional arguments to fix. Negative
      indices are resolved against ``len(args)``.
    args: the full positional argument list.
    allow_invalid: if true, indices ``>= len(args)`` are ignored instead of
      raising.

  Returns:
    ``(f, args)`` unchanged when ``static_argnums`` is empty. Otherwise a pair
    ``(partial_f, dyn_args)`` where ``partial_f`` is the result of
    ``_argnums_partial`` with the static arguments boxed in
    ``_HashableWithStrictTypeEquality`` and ``dyn_args`` are the remaining
    arguments in order.

  Raises:
    ValueError: if a static argument is not hashable, or an index is out of
      bounds and not ignored.
  """
  if not static_argnums:
    return f, args
  # TODO(shoyer): set allow_invalid=True permanently after static_argnames.
  # _ensure_inbounds either drops or rejects every index >= len(args), so all
  # indices used below are valid positions in `args`; no further bounds check
  # is needed in the loop.
  static_argnums = _ensure_inbounds(allow_invalid, len(args), static_argnums)
  dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
  dyn_args = tuple(args[i] for i in dyn_argnums)

  fixed_args = []
  for i in static_argnums:
    static_arg = args[i]
    if not is_hashable(static_arg):
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(static_arg)} for function {f.__name__} is non-hashable.")
    fixed_args.append(_HashableWithStrictTypeEquality(static_arg))

  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
|
|
|
|
2024-11-12 22:39:26 -08:00
|
|
|
@lu.transformation2
def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs):
  """Call ``_fun`` with dynamic and fixed positional arguments interleaved.

  Dynamic arguments are placed at the positions listed in ``_dyn_argnums``;
  the remaining positions are filled, in order, with the ``val`` of each box
  in ``_fixed_args``. All fixed arguments must be consumed.
  """
  hole = object()  # marks positions not claimed by a dynamic argument
  slots = [hole] * (len(_fixed_args) + len(dyn_args))
  for position, value in zip(_dyn_argnums, dyn_args):
    slots[position] = value
  remaining_fixed = iter(_fixed_args)
  full_args = [next(remaining_fixed).val if slot is hole else slot
               for slot in slots]
  # Every fixed argument must have been used to fill a hole.
  assert next(remaining_fixed, hole) is hole
  return _fun(*full_args, **kwargs)
|
2021-09-08 09:00:23 -07:00
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
                            kwargs: dict[str, Any]):
  """Fix the keyword arguments named in ``static_argnames``.

  Args:
    f: the wrapped function; its ``__name__`` is used in the error message.
    static_argnames: names of the keyword arguments to fix.
    kwargs: all keyword arguments of the call.

  Returns:
    ``(f, kwargs)`` unchanged when ``static_argnames`` is empty. Otherwise a
    pair ``(partial_f, dyn_kwargs)`` where ``partial_f`` is the result of
    ``_argnames_partial`` with the static keyword arguments boxed in
    ``Hashable`` and ``dyn_kwargs`` holds the remaining keyword arguments.

  Raises:
    ValueError: if a static keyword argument is not hashable.
  """
  if not static_argnames:
    return f, kwargs

  dyn_kwargs = {}
  for k, arg in kwargs.items():
    if k not in static_argnames:
      dyn_kwargs[k] = arg

  fixed_kwargs: dict[str, Any] = {}
  for k, arg in kwargs.items():
    if k in dyn_kwargs:
      continue
    try:
      hash(arg)
    except TypeError:
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (name {k}) of type "
          f"{type(arg)} for function {f.__name__} is non-hashable.")
    else:
      fixed_kwargs[k] = Hashable(arg)

  wrapped_fixed = WrapKwArgs(fixed_kwargs)
  return _argnames_partial(f, wrapped_fixed), dyn_kwargs
|
|
|
|
|
2024-11-12 22:39:26 -08:00
|
|
|
@lu.transformation2
def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
  """Call ``_fun`` with the fixed keyword arguments merged into the call.

  ``_fixed_kwargs.val`` maps names to boxes whose ``val`` is the actual
  argument. Entries in ``dyn_kwargs`` take precedence over fixed ones with the
  same name.
  """
  unboxed = {name: box.val for name, box in _fixed_kwargs.val.items()}
  merged = {**unboxed, **dyn_kwargs}
  return _fun(*args, **merged)
|
2021-09-08 09:00:23 -07:00
|
|
|
|
|
|
|
|
2024-04-23 14:49:11 -07:00
|
|
|
@lru_cache(maxsize=4096)
def donation_vector(donate_argnums, donate_argnames, in_tree,
                    kws: bool = True) -> tuple[bool, ...]:
  """Returns a tuple with a boolean value for each leaf in args and kwargs.

  What if a user specifies donate_argnums but calls the function with kwargs
  or vice-versa? In that case, in `resolve_argnums` using the signature of the
  function, the counterpart (donate_argnames or donate_argnums respectively) is
  calculated so when this function is called both donate_argnums and
  donate_argnames are available. This allows JAX to donate kwargs when only
  donate_argnums is specified and vice-versa.

  When both donate_argnums and donate_argnames are specified, only the args and
  kwargs specified are donated.

  Args:
    donate_argnums: container of positional indices to donate.
    donate_argnames: container of keyword names to donate.
    in_tree: tree definition of ``(args, kwargs)`` when ``kws`` is true, or of
      the positional arguments alone when it is false.
    kws: whether ``in_tree`` includes a kwargs subtree.

  Returns:
    One flag per leaf: positional leaves first, then keyword leaves.
  """
  if kws:
    args_tree, kwargs_tree = treedef_children(in_tree)
  else:
    args_tree, kwargs_tree = in_tree, None

  flags: list[bool] = []
  for position, subtree in enumerate(args_tree.children()):
    flags += [position in donate_argnums] * subtree.num_leaves
  if kwargs_tree is None:
    return tuple(flags)

  # The second element of the kwargs node data holds the keyword names.
  names = kwargs_tree.node_data()[1]  # type: ignore
  for name, subtree in safe_zip(names, kwargs_tree.children()):
    flags += [name in donate_argnames] * subtree.num_leaves
  return tuple(flags)
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]:
  """Shifts donate to account for static.

  >>> rebase_donate_argnums((3, 4), (0, 1))
  (1, 2)

  Args:
    donate_argnums: An iterable of ints.
    static_argnums: An iterable of ints.

  Returns:
    A tuple of unique, sorted integer values based on donate_argnums with each
    element offset to account for static_argnums.

  Raises:
    ValueError: if an index appears in both ``donate_argnums`` and
      ``static_argnums``.
  """
  if not (static_argnums or donate_argnums):
    return tuple(sorted(donate_argnums))

  static_argnums = sorted(set(static_argnums))
  donate_argnums = sorted(set(donate_argnums))
  n_static = len(static_argnums)
  skipped = 0  # number of static indices smaller than the current donated one
  rebased = []
  for donated in donate_argnums:
    while skipped < n_static and static_argnums[skipped] < donated:
      skipped += 1
    if skipped < n_static and static_argnums[skipped] == donated:
      raise ValueError(f"`static_argnums` {static_argnums} and "
                       f"`donate_argnums` {donate_argnums} cannot intersect.")
    rebased.append(donated - skipped)
  return tuple(rebased)
|
|
|
|
|
2021-07-19 13:11:38 -04:00
|
|
|
|
|
|
|
def is_hashable(arg):
  """Return True if ``hash(arg)`` succeeds, False if it raises ``TypeError``."""
  try:
    hash(arg)
  except TypeError:
    return False
  else:
    return True
|
|
|
|
|
2021-09-08 09:00:23 -07:00
|
|
|
|
2025-03-10 11:37:50 -07:00
|
|
|
# Unique placeholder object; flatten_axes uses it as the leaf value when it
# builds a dummy pytree from a treedef.
SENTINEL = object()
|
|
|
|
|
|
|
|
|
2021-09-08 09:00:23 -07:00
|
|
|
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
  """Expands a prefix axis spec to one entry per leaf of `treedef`.

  Args:
    name: label for the spec; only used to build the error message.
    treedef: the tree definition the spec is matched against.
    axis_tree: the axis spec, expected to be a tree prefix of `treedef`.
    kws: if true, the error message is restricted to the first component of
      both `treedef` and `axis_tree`.
    tupled_args: if true, the error message gets an extra hint about wrapping
      the spec in a tuple.

  Returns:
    A list with `treedef.num_leaves` entries, one per leaf.

  Raises:
    ValueError: if mapping over the spec and the dummy tree fails.
  """
  # given an axis spec tree axis_tree (a pytree with integers and Nones at the
  # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
  # the given treedef, build a complete axis spec tree with the same structure
  # and return the flattened result
  # TODO(mattjj,phawkins): improve this implementation
  # `proxy` stands in for None in the spec (presumably substituted by
  # _replace_nones -- confirm against that helper); it is mapped back to None
  # before returning.
  proxy = object()
  dummy = tree_unflatten(treedef, [SENTINEL] * treedef.num_leaves)
  axes = []
  # Repeats spec entry `i` once for every leaf of the matching subtree `x`.
  add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
  try:
    tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    if kws:
      # if keyword arguments are included in the tree, we adapt the error
      # message only to be about the positional arguments
      treedef, _ = treedef_children(treedef)
      axis_tree, _ = axis_tree
    hint = ""
    if tupled_args:
      hint += (f" Note that {name} that are non-trivial pytrees should always be "
               f"wrapped in a tuple representing the argument list.")
      if len(treedef.children()) == 1:
        # Retry with the spec wrapped in a singleton tuple; if that works, the
        # missing tuple is the likely mistake, so say so in the hint.
        try:
          flatten_axes(name, treedef, (axis_tree,))
        except ValueError:
          pass  # That's not the issue.
        else:
          hint += (f" In particular, you're passing in a single argument which "
                   f"means that {name} might need to be wrapped in "
                   f"a singleton tuple.")
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.{hint}") from None
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
|
|
|
|
|
2023-03-29 14:54:24 -07:00
|
|
|
def flat_out_axes(
    f: lu.WrappedFun, out_spec: Any
) -> tuple[lu.WrappedFun, Callable]:
  """Wraps `f` with `_flat_out_axes` for the flattened `out_spec`.

  Returns the wrapped function together with a `HashableFunction` around the
  auxiliary output, whose closure is the spec's leaves and treedef.
  """
  spec_leaves, spec_treedef = tree_flatten(out_spec)
  frozen_leaves = tuple(spec_leaves)
  f, out_axes_thunk = _flat_out_axes(f, frozen_leaves, spec_treedef)
  hashable_thunk = HashableFunction(
      out_axes_thunk, closure=(frozen_leaves, spec_treedef))
  return f, hashable_thunk
|
|
|
|
|
2024-11-12 22:39:26 -08:00
|
|
|
@lu.transformation_with_aux2
def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs):
  """Calls `_fun` and stores the out-axes spec broadcast to its output.

  The spec is rebuilt from `_treedef` and `_leaves`, broadcast against the
  output of `_fun` (treating None as a leaf), and the flattened result is put
  into `_store`. The output of `_fun` is returned unchanged.

  Raises:
    ValueError: if the spec cannot be broadcast to the output.
  """
  result = _fun(*args, **kwargs)
  out_spec = tree_unflatten(_treedef, _leaves)

  def none_is_leaf(x):
    return x is None

  try:
    flat_spec = tuple(broadcast_prefix(out_spec, result, is_leaf=none_is_leaf))
  except ValueError:
    first_error, *_ = prefix_errors(out_spec, result)
    # TODO(mattjj): currently hardcoded for pmap; generalize to vmap in followup
    base_msg, = first_error('pmap out_axes').args
    raise ValueError(
        base_msg +
        "\n\nThe full pytree is the output of the pmapped function. Ensure "
        "that the `out_axes` argument to `pmap` is a pytree prefix of the "
        "pmapped function's output.") from None
  _store.store(flat_spec)
  return result
|
2022-12-22 08:40:36 -08:00
|
|
|
|
|
|
|
def check_callable(fun):
  """Raises TypeError unless `fun` is an acceptable callable.

  Rejected are: staticmethod objects, non-callable values, and generator
  functions.
  """
  # In Python 3.10+, the only thing stopping us from supporting staticmethods
  # is that we can't take weak references to them, which the C++ JIT requires.
  if isinstance(fun, staticmethod):
    problem = f"staticmethod arguments are not supported, got {fun}"
  elif not callable(fun):
    problem = f"Expected a callable value, got {fun}"
  elif inspect.isgeneratorfunction(fun):
    problem = f"Expected a function, got a generator function: {fun}"
  else:
    return
  raise TypeError(problem)
|
|
|
|
|
|
|
|
# Shorthand for the parameter kind that can be passed either by position or by
# keyword; infer_argnums_and_argnames only considers parameters of this kind.
_POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
|
|
|
|
|
|
|
|
def infer_argnums_and_argnames(
    sig: inspect.Signature,
    argnums: int | Iterable[int] | None,
    argnames: str | Iterable[str] | None,
) -> tuple[tuple[int, ...], tuple[str, ...]]:
  """Fills in whichever of argnums/argnames is missing using `sig`.

  If both are None, two empty tuples are returned. If both are given, they are
  only normalized to tuples. If exactly one is given, the other is derived
  from the positional-or-keyword parameters of `sig`.
  """
  nums_given = argnums is not None
  names_given = argnames is not None
  if not (nums_given or names_given):
    return (), ()

  if nums_given and names_given:
    return _ensure_index_tuple(argnums), _ensure_str_tuple(argnames)

  # Only positional-or-keyword parameters take part in the inference.
  candidates = [
      (index, name)
      for index, (name, param) in enumerate(sig.parameters.items())
      if param.kind == _POSITIONAL_OR_KEYWORD
  ]
  if names_given:
    argnames = _ensure_str_tuple(argnames)
    argnums = tuple(index for index, name in candidates if name in argnames)
  else:
    argnums = _ensure_index_tuple(argnums)
    argnames = tuple(name for index, name in candidates if index in argnums)
  return argnums, argnames
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_argnums(
    fun: Callable,
    signature: inspect.Signature | None,
    donate_argnums: int | Sequence[int] | None,
    donate_argnames: str | Iterable[str] | None,
    static_argnums: int | Sequence[int] | None,
    static_argnames: str | Iterable[str] | None,
) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]:
  """Validates and completes the argnum/argname specification for a jit.

  * fills in any missing pieces (e.g., names given numbers, or vice versa),
  * validates the argument names/numbers against the function signature,
  * validates that donated and static arguments don't intersect.
  * rebases the donated arguments so they index into the dynamic arguments,
    (after static arguments have been removed), in the order that parameters
    are passed into the compiled function.

  Returns:
    The tuple (donate_argnums, donate_argnames, static_argnums,
    static_argnames).

  Raises:
    ValueError: if `signature` is None and `donate_argnames` is given, or if
      static and donated arguments intersect.
  """
  if signature is not None:
    # Infer argnums and argnames according to docstring
    # If nums is None and names is not None, then nums are inferred from the
    # names and vice-versa.
    static_argnums, static_argnames = infer_argnums_and_argnames(
        signature, static_argnums, static_argnames)
    donate_argnums, donate_argnames = infer_argnums_and_argnames(
        signature, donate_argnums, donate_argnames)

    # Validation
    _validate_argnums(signature, static_argnums, "static_argnums")
    _validate_argnames(signature, static_argnames, "static_argnames")
    _validate_argnums(signature, donate_argnums, "donate_argnums")
    _validate_argnames(signature, donate_argnames, "donate_argnames")
  else:
    # Some built-in functions don't support signature.
    # See: https://github.com/python/cpython/issues/73485
    # In this case no validation is done
    if static_argnums is None:
      static_argnums = ()
    else:
      static_argnums = _ensure_index_tuple(static_argnums)
    if static_argnames is None:
      static_argnames = ()
    else:
      static_argnames = _ensure_str_tuple(static_argnames)
    if donate_argnums is None:
      donate_argnums = ()
    else:
      donate_argnums = _ensure_index_tuple(donate_argnums)
    if donate_argnames is not None:
      raise ValueError(f"Getting the signature of function {fun} failed. "
                       "Pass donate_argnums instead of donate_argnames.")
    donate_argnames = ()

  # Compensate for static argnums absorbing args
  _assert_no_intersection(static_argnames, donate_argnames)
  donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
  return donate_argnums, donate_argnames, static_argnums, static_argnames
|
|
|
|
|
|
|
|
|
2024-03-20 14:32:25 -07:00
|
|
|
def _assert_no_intersection(static_argnames, donate_argnames):
  """Raises ValueError if a name is both in static and in donate argnames."""
  shared = set(static_argnames) & set(donate_argnames)
  if not shared:
    return
  raise ValueError(
      "static_argnames and donate_argnames cannot intersect. Argument names "
      f"{shared} appear in both static_argnames and donate_argnames")
|
2022-12-22 08:40:36 -08:00
|
|
|
|
|
|
|
|
2024-08-13 00:29:33 -07:00
|
|
|
def resolve_kwargs(fun: Callable, args, kwargs) -> tuple[Any, ...]:
  """Binds `args` and `kwargs` to `fun`'s signature and returns positionals.

  Defaults are applied before the positional arguments are extracted.

  Raises:
    TypeError: if a keyword argument passed by the caller cannot be turned
      into a positional argument (e.g. it is keyword-only), or if the
      arguments do not bind to the signature at all.
  """
  if isinstance(fun, partial):
    # functools.partial should have an opaque signature.
    def fun(*args, **kwargs):
      return None

  bound = inspect.signature(fun).bind(*args, **kwargs)
  bound.apply_defaults()
  # Anything still keyword-bound that the caller passed explicitly (rather
  # than a filled-in default) could not be resolved to a position.
  unresolved = [name for name in bound.kwargs if name in kwargs]
  if unresolved:
    raise TypeError(
        "The following keyword arguments could not be resolved to positions: "
        + ", ".join(unresolved)
    )
  return bound.args
|
|
|
|
|
|
|
|
|
2021-09-08 09:00:23 -07:00
|
|
|
def _dtype(x):
  """Returns `dtypes.result_type(x)`, retrying with `x.dtype` on ValueError."""
  try:
    resolved = dtypes.result_type(x)
  except ValueError:
    resolved = dtypes.result_type(x.dtype)
  return resolved
|
|
|
|
|
2024-03-21 17:54:08 -07:00
|
|
|
|
2021-09-08 09:00:23 -07:00
|
|
|
def api_hook(fun, tag: str):
  """Identity hook that makes it easier to monkey-patch APIs in JAX.

  By default it does nothing: `fun` is returned unchanged and `tag` is
  ignored. It can be monkey-patched to do other things.
  """
  return fun
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
|
2025-01-15 21:36:38 +00:00
|
|
|
|
2025-01-31 22:23:20 +02:00
|
|
|
def debug_info(
    traced_for: str,
    fun: Callable,
    args: Sequence[Any],
    kwargs: dict[str, Any],
    *,
    static_argnums: Sequence[int] = (),
    static_argnames: Sequence[str] = (),
    result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
    # TODO(necula): check if we really need this, e.g., to speed up tracing?
    sourceinfo: str | None = None,
    signature: inspect.Signature | None = None,
) -> core.DebugInfo:
  """Constructs core.DebugInfo for a function given example args and kwargs.

  `args` and `kwargs` are example positional and keyword arguments, used with
  `inspect.Signature` to get the names of arguments. The arguments that are
  considered static for tracing purposes should be included, and designated
  using `static_argnums` and `static_argnames`.

  If `sourceinfo` or `signature` is None, it is computed from `fun`.

  See docstring for linear_util.DebugInfo.
  """
  src_info = fun_sourceinfo(fun) if sourceinfo is None else sourceinfo
  fn_sig = fun_signature(fun) if signature is None else signature
  names = _non_static_arg_names(
      fn_sig, args, kwargs, static_argnums, static_argnames)
  return core.DebugInfo(traced_for, src_info, names, result_paths_thunk)
|
2025-01-15 21:36:38 +00:00
|
|
|
|
|
|
|
|
2024-02-15 13:48:49 -08:00
|
|
|
def fun_signature(fun: Callable) -> inspect.Signature | None:
  """Returns `inspect.signature(fun)`, or None if it cannot be obtained."""
  try:
    sig = inspect.signature(fun)
  except (ValueError, TypeError):
    return None
  return sig
|
|
|
|
|
2025-02-06 12:44:38 +02:00
|
|
|
def save_wrapped_fun_sourceinfo(wrapper: Callable,
                                wrapped: Callable | core.DebugInfo) -> None:
  """Records the source info of `wrapped` on `wrapper`.

  The info is taken from `wrapped.func_src_info` when `wrapped` is a
  `core.DebugInfo`, and from `fun_sourceinfo(wrapped)` when it is a callable.
  It is stored in the `__fun_sourceinfo__` attribute of `wrapper`.
  """
  # Prefer this to functools.wraps because it does not create a reference to
  # the wrapped function.
  if isinstance(wrapped, core.DebugInfo):
    src_info = wrapped.func_src_info
  elif callable(wrapped):
    src_info = fun_sourceinfo(wrapped)
  else:
    assert False, wrapped  # Unreachable
  setattr(wrapper, "__fun_sourceinfo__", src_info)
|
[better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.
When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.
We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-20 17:17:44 +01:00
|
|
|
|
|
|
|
# Matches strings of the form "<built-in function NAME>" and captures NAME;
# fun_sourceinfo applies it to str(fun) when `fun` has no `__code__`.
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
|
2025-01-09 13:26:29 +02:00
|
|
|
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
# TODO(mattjj): make this function internal to this module
|
[better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.
When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.
We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-20 17:17:44 +01:00
|
|
|
def fun_sourceinfo(fun: Callable) -> str:
  """Returns a short description of `fun` for use in debugging info.

  The result is, in order of preference:
    * the `__fun_sourceinfo__` attribute of `fun`, if it is set;
    * "NAME at FILENAME:LINENO", after stripping `functools.partial` layers
      and unwrapping with `inspect.unwrap`, if the function has a `__code__`;
    * the name of a built-in function, parsed from `str(fun)`;
    * "<unknown>" otherwise.
  """
  # See DebugInfo.func_src_info
  res = getattr(fun, "__fun_sourceinfo__", None)
  if res is not None: return res
  while isinstance(fun, partial):
    fun = fun.func
  fun = inspect.unwrap(fun)
  try:
    filename = fun.__code__.co_filename
    lineno = fun.__code__.co_firstlineno
    return f"{fun.__name__} at {filename}:{lineno}"
  except AttributeError:
    try:
      fun_str = str(fun)
    except Exception:
      # A user-defined __str__ may fail in arbitrary ways; that must not break
      # the computation of debugging info. KeyboardInterrupt and SystemExit
      # are deliberately not swallowed.
      return "<unknown>"
    # By contract, the function name has no spaces; also, we want to avoid
    # fun_sourceinfo of the form "<object Foo at 0x1234>", because it makes
    # lowering non-deterministic.
    if m := _fun_name_re.match(fun_str):
      return m.group(1)
    return "<unknown>"
|
|
|
|
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
|
2025-01-15 21:36:38 +00:00
|
|
|
def _non_static_arg_names(fn_signature: inspect.Signature | None,
                          args: Sequence[Any], kwargs: dict[str, Any],
                          static_argnums: Sequence[int],
                          static_argnames: Sequence[str],
                          ) -> tuple[str, ...]:
  """Computes the flattened names of the non-static arguments.

  Positional arguments whose index is in `static_argnums` (after the indices
  have been passed through `_ensure_inbounds`) and keyword arguments whose key
  is in `static_argnames` are masked out and contribute no names.

  Args:
    fn_signature: the signature of the function being called, if available.
    args: the positional arguments of the call.
    kwargs: the keyword arguments of the call.
    static_argnums: indices of the positional arguments that are static.
    static_argnames: names of the keyword arguments that are static.

  Returns:
    One name per non-static leaf of the arguments. When `fn_signature` is given
    and `args`/`kwargs` bind to it, each name starts with the name of the
    top-level parameter. Otherwise, including when binding fails, the names
    look like `args[0]`, `args[1]`, etc. for positional arguments, followed by
    names prefixed with `kwargs` for keyword arguments.
  """
  # A unique marker that takes the place of every static argument, so that
  # static arguments still occupy their slot when binding to the signature.
  masked = object()
  static_positions = _ensure_inbounds(True, len(args), static_argnums)
  static_keys = set(static_argnames)
  masked_args = [masked if pos in static_positions else arg
                 for pos, arg in enumerate(args)]
  masked_kwargs = {key: masked if key in static_keys else val
                   for key, val in kwargs.items()}

  def leaf_names(prefix, tree):
    # Names of all the leaves of `tree`, skipping the masked (static) ones.
    return [f'{prefix}{lu._clean_keystr_arg_names(path)}'
            for path, leaf in generate_key_paths(tree)
            if leaf is not masked]

  bound = None
  if fn_signature is not None:
    try:
      bound = fn_signature.bind(*masked_args, **masked_kwargs)
    except (ValueError, TypeError):
      # The call does not match the signature; use the generic names below.
      bound = None

  if bound is not None:
    names = []
    for param_name, value in bound.arguments.items():
      names.extend(leaf_names(param_name, value))
    return tuple(names)

  return tuple(leaf_names('args', masked_args) +
               leaf_names('kwargs', masked_kwargs))
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in follow-up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
|
2024-02-13 16:45:27 -08:00
|
|
|
def hoist_obj_attrs(f, flat_args):
  """Splits registered attribute-carrying objects out of `flat_args`.

  An argument is hoisted when its exact type (subclasses do not count) is in
  `_class_with_attrs`. Hoisted arguments are wrapped in `_HashableByObjectId`
  and handed to `_argnums_partial` together with the positions of the
  remaining arguments.

  Args:
    f: the function object forwarded to `_argnums_partial`.
    flat_args: the flat sequence of arguments to split.

  Returns:
    A pair of the result of `_argnums_partial` and the list of the arguments
    that were not hoisted, in their original order.
  """
  hoisted = []
  kept_positions = []
  kept_args = []
  for pos, arg in enumerate(flat_args):
    if type(arg) not in _class_with_attrs:
      kept_positions.append(pos)
      kept_args.append(arg)
      continue
    hoisted.append(_HashableByObjectId(arg))
  partial_f = _argnums_partial(f, tuple(kept_positions), tuple(hoisted))
  return partial_f, kept_args
|
|
|
|
|
|
|
|
class _HashableByObjectId:
  """Wrapper that hashes and compares a value by object identity.

  Two wrappers are equal iff they wrap the very same object (`is`), and the
  hash is the `id` of the wrapped object. This makes any object, including
  unhashable ones, usable where a hashable value is required.
  """
  __slots__ = ['val']

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return id(self.val)

  def __eq__(self, other):
    # The hash is a plain integer id, so a dict/set lookup can end up comparing
    # a wrapper against an unrelated object; decline the comparison instead of
    # failing with AttributeError on the missing `val` attribute.
    if not isinstance(other, _HashableByObjectId):
      return NotImplemented
    return self.val is other.val
|
|
|
|
|
2024-06-26 14:44:52 -04:00
|
|
|
def register_class_with_attrs(t: type) -> None:
  """Adds the type `t` to the `_class_with_attrs` registry.

  Registering the same type more than once has no additional effect.
  """
  _class_with_attrs.add(t)


# The types registered through `register_class_with_attrs`.
_class_with_attrs: set[type] = set()
|
2024-12-18 22:11:25 +00:00
|
|
|
|
|
|
|
# TODO(mattjj): make this function faster
|
2025-02-13 22:06:18 -08:00
|
|
|
def _check_no_aliased_ref_args(dbg: core.DebugInfo, avals, args) -> None:
  """Checks that no mutable array is passed as more than one argument.

  Must only be called when the `mutable_array_checks` config option is on.

  Args:
    dbg: debug info used to describe the offending arguments in the error
      message. If it is falsy, the arguments are described by their flat
      indices instead.
    avals: the abstract values of the arguments, parallel to `args`.
    args: the flat arguments.

  Raises:
    ValueError: if two arguments whose abstract value is an `AbstractRef` have
      the same referent object (as given by `core.get_referent`).
  """
  assert config.mutable_array_checks.value
  # Maps the id of each referent to the index of the first argument with it.
  refs: dict[int, int] = {}
  for i, (a, x) in enumerate(zip(avals, args)):
    if (isinstance(a, AbstractRef) and
        (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
      # Choose only the location part of the message based on `dbg`. A bare
      # `... if dbg else ...` after the implicitly concatenated strings would
      # select between the whole message and the fallback fragment.
      if dbg:
        where = (
            f"when tracing {dbg.func_src_info} for {dbg.traced_for} "
            f"the mutable array reference of type {a.str_short()} appeared at both "
            f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}.")
      else:
        where = (
            f"the mutable array reference of type {a.str_short()} appeared "
            f"at both flat index {dup_idx} and flat index {i}.")
      raise ValueError(
          "only one reference to a mutable array may be passed as an argument "
          f"to a function, but {where}") from None
|
|
|
|
|
2025-02-13 22:06:18 -08:00
|
|
|
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> None:
  """Checks that no closed-over mutable array is also passed as an argument.

  Must only be called when the `mutable_array_checks` config option is on.

  Args:
    dbg: debug info used to describe the offending argument in the error
      message. If it is falsy, the argument is described by its flat index
      instead.
    consts: the closed-over constants.
    args: the flat arguments.

  Raises:
    ValueError: if the referent of an argument (as given by
      `core.get_referent`) is also the referent of a constant whose abstract
      value is an `AbstractRef`.
  """
  assert config.mutable_array_checks.value
  refs: set[int] = {id(core.get_referent(c)) for c in consts
                    if isinstance(core.get_aval(c), AbstractRef)}
  for i, x in enumerate(args):
    if id(core.get_referent(x)) in refs:
      a = core.shaped_abstractify(x)
      # Build the message per case. A bare `... if dbg else ...` after the
      # implicitly concatenated strings would replace the whole message by the
      # fallback fragment, which also has to be an f-string to interpolate `i`.
      if dbg:
        msg = (
            f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
            f"array reference of type {a.str_short()} was both closed over and "
            f"passed as the argument "
            f"{dbg.safe_arg_names(len(args))[i]}")
      else:
        msg = (
            f"a mutable array reference of type {a.str_short()} was both "
            f"closed over and passed as the argument at flat index {i}")
      raise ValueError(msg)
|