2022-12-15 20:34:43 -08:00
|
|
|
# Copyright 2018 The JAX Authors.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-12-08 20:10:08 +00:00
|
|
|
from collections import Counter, defaultdict, deque, namedtuple
|
2024-06-26 14:44:52 -04:00
|
|
|
from collections.abc import (Callable, Collection, Generator, Hashable,
|
|
|
|
Iterable, Iterator, Set, Sequence, MutableSet,
|
2024-03-04 05:41:29 -08:00
|
|
|
MutableMapping)
|
2024-05-18 08:45:31 -07:00
|
|
|
from contextlib import contextmanager, ExitStack
|
2022-12-15 20:34:43 -08:00
|
|
|
from dataclasses import dataclass
|
|
|
|
import functools
|
|
|
|
from functools import partial, partialmethod, total_ordering
|
|
|
|
import gc
|
|
|
|
import inspect
|
|
|
|
import itertools as it
|
2023-02-28 12:40:30 -08:00
|
|
|
import math
|
2022-12-15 20:34:43 -08:00
|
|
|
import operator
|
|
|
|
import threading
|
|
|
|
import types
|
2024-06-26 14:44:52 -04:00
|
|
|
from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar,
|
2023-12-11 13:59:29 +00:00
|
|
|
cast, overload, Union)
|
2022-12-15 20:34:43 -08:00
|
|
|
import warnings
|
|
|
|
from weakref import ref
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2024-06-13 13:14:27 -07:00
|
|
|
from jax._src import deprecations
|
2022-12-15 20:34:43 -08:00
|
|
|
from jax._src import dtypes
|
2023-10-09 07:28:18 -07:00
|
|
|
from jax._src import config
|
2023-02-01 17:50:00 -08:00
|
|
|
from jax._src import effects
|
2024-05-17 16:31:23 -07:00
|
|
|
from jax._src import compute_on
|
2023-03-28 18:30:36 -07:00
|
|
|
from jax._src.errors import (
|
2023-06-21 01:41:45 -07:00
|
|
|
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
|
2023-03-28 18:30:36 -07:00
|
|
|
TracerIntegerConversionError, UnexpectedTracerError)
|
2022-12-20 14:49:27 -08:00
|
|
|
from jax._src import linear_util as lu
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
from jax._src import source_info_util
|
2023-02-28 12:40:30 -08:00
|
|
|
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
|
2022-12-15 20:34:43 -08:00
|
|
|
tuple_delete, as_hashable_function,
|
2023-04-19 17:01:05 -07:00
|
|
|
HashableFunction, HashableWrapper, weakref_lru_cache,
|
2024-01-26 13:52:47 +00:00
|
|
|
partition_list, StrictABCMeta)
|
2022-12-15 20:34:43 -08:00
|
|
|
import jax._src.pretty_printer as pp
|
|
|
|
from jax._src.lib import jax_jit
|
|
|
|
from jax._src import traceback_util
|
2023-05-17 07:58:19 -07:00
|
|
|
from jax._src.typing import Array, DimSize, Shape
|
2022-12-15 20:34:43 -08:00
|
|
|
from jax._src import typing
|
2024-04-10 13:14:43 -07:00
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
# Register this module with traceback_util's exclusion list (presumably so
# its frames are filtered out of user-facing tracebacks).
traceback_util.register_exclusion(__file__)

# Within this module, `zip` and `map` refer to the safe_* variants imported
# from jax._src.util; the builtin versions stay reachable as unsafe_zip and
# unsafe_map.
unsafe_zip, zip = zip, safe_zip
unsafe_map, map = map, safe_map
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Integer config flag: how many stack frames JAX tracer error messages show.
# The default is `config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5)`,
# i.e. an environment-variable override (presumably falling back to 5).
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag(
    'jax_tracer_error_num_traceback_frames',
    config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
    help='Set the number of stack frames in JAX tracer error messages.',
)
|
|
|
|
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
# -------------------- jaxprs --------------------
|
|
|
|
|
2023-02-01 17:50:00 -08:00
|
|
|
# Module-level aliases for the effect types and the empty `no_effects` value
# defined in the `effects` module.
Effect, Effects, EffectTypeSet = (
    effects.Effect, effects.Effects, effects.EffectTypeSet)
no_effects: Effects = effects.no_effects
|
2022-12-15 20:34:43 -08:00
|
|
|
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo, etc.). That unification will come in follow-up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnums). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* adding some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
class JaxprDebugInfo(NamedTuple):
  """Debug metadata that may be attached to a `Jaxpr`.

  Fields:
    traced_for: name of what traced the function, e.g. 'jit' or 'scan'.
    func_src_info: description of the traced function, e.g.
      f'{fun.__name__} at {filename}:{lineno}', or None when unavailable.
    arg_names: one entry per jaxpr input variable, e.g. ('args[0]', ...);
      individual entries may be None.
    result_paths: one entry per jaxpr output, e.g. ('[0]', '[1]', ...).
  """
  traced_for: str
  func_src_info: str | None
  arg_names: tuple[str | None, ...]
  result_paths: tuple[str, ...]
|
2023-03-02 09:58:14 -08:00
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
class Jaxpr:
  """Program representation made of constvars, invars, equations and outvars.

  Fields live in private slots and are exposed through getter-only
  properties; `replace` builds a new `Jaxpr` with some fields overridden.
  """
  __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
               '_effects', '_debug_info']

  _constvars: list[Var]
  _invars: list[Var]
  _outvars: list[Atom]
  _eqns: list[JaxprEqn]
  _effects: Effects
  _debug_info: JaxprDebugInfo | None

  @property
  def constvars(self) -> list[Var]:
    """Variables introduced for constants."""
    return self._constvars

  @property
  def invars(self) -> list[Var]:
    """Input variables (the non-constant inputs)."""
    return self._invars

  @property
  def outvars(self) -> list[Atom]:
    """Output atoms."""
    return self._outvars

  @property
  def eqns(self) -> list[JaxprEqn]:
    """Equations, in the order they were supplied."""
    return self._eqns

  @property
  def effects(self) -> Effects:
    """Effects recorded for this jaxpr."""
    return self._effects

  @property
  def debug_info(self) -> JaxprDebugInfo | None:
    """Optional debug metadata (argument names, result paths, ...)."""
    return self._debug_info

  def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
               outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
               effects: Effects = no_effects,
               debug_info: JaxprDebugInfo | None = None):
    """Builds a Jaxpr.

    Args:
      constvars: list of variables introduced for constants. Array constants are
        replaced with such variables while scalar constants are kept inline.
      invars: list of input variables. Together, `constvars` and `invars` are
        the inputs to the Jaxpr.
      outvars: list of output atoms.
      eqns: list of equations.
      effects: set of effects. The effects on a jaxpr are a superset of the
        union of the effects for each equation.
      debug_info: optional JaxprDebugInfo. When given, its `arg_names` must
        have one entry per invar and its `result_paths` one entry per outvar
        (checked with an assertion).
    """
    # Each sequence is copied into a fresh list, so later changes to the
    # caller's sequences do not affect this jaxpr.
    self._constvars, self._invars = list(constvars), list(invars)
    self._outvars, self._eqns = list(outvars), list(eqns)
    # `effects` and `debug_info` are stored as given, without copying.
    self._effects, self._debug_info = effects, debug_info
    # Invariant: debug info, when present, lines up with invars and outvars.
    assert not (debug_info and (
        len(debug_info.arg_names) != len(invars) or
        len(debug_info.result_paths) != len(outvars)))

  def __str__(self):
    return str(self.pretty_print())

  __repr__ = __str__

  def pretty_print(self, *, source_info=False, print_shapes=True,
                   custom_pp_eqn_rules=True, name_stack=False,
                   print_effects: bool = False, **kwargs):
    """Renders this jaxpr as text via `pp_toplevel_jaxpr`.

    Remaining keyword arguments are forwarded to the resulting document's
    `format` method.
    """
    return pp_toplevel_jaxpr(
        self,
        source_info=source_info,
        print_shapes=print_shapes,
        custom_pp_eqn_rules=custom_pp_eqn_rules,
        name_stack=name_stack,
        print_effects=print_effects,
    ).format(**kwargs)

  def _repr_pretty_(self, p, cycle):
    # IPython `_repr_pretty_` hook: emit the colorized pretty-printed form.
    return p.text(self.pretty_print(use_color=True))

  def replace(self, **kwargs):
    """Returns a new Jaxpr with the given fields overridden.

    Accepted keywords are constvars, invars, outvars, eqns, effects and
    debug_info; any field not given is taken from `self`.

    Raises:
      ValueError: if any other keyword argument is supplied (checked after
        the new Jaxpr has been constructed).
    """
    fields = {}
    for field_name in ("constvars", "invars", "outvars", "eqns", "effects",
                       "debug_info"):
      fields[field_name] = kwargs.pop(field_name, getattr(self, field_name))
    new_jaxpr = Jaxpr(**fields)
    if kwargs:
      raise ValueError(f"Unknown keyword arguments: {kwargs}")
    return new_jaxpr
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def join_effects(*effects: Effects) -> Effects:
  """Returns the union of the given effect sets.

  With at least one argument the result is a newly built set; with no
  arguments the module-level `no_effects` value itself is returned.
  """
  if not effects:
    return no_effects
  return set().union(*effects)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  """Yield every Jaxpr stored among the values of the ``params`` dict.

  A tuple value is searched element by element; any other value is checked
  directly. A ClosedJaxpr contributes its underlying ``.jaxpr``. Found jaxprs
  are not searched recursively.
  """
  for param in params.values():
    for candidate in (param if isinstance(param, tuple) else (param,)):
      if isinstance(candidate, Jaxpr):
        yield candidate
      elif isinstance(candidate, ClosedJaxpr):
        yield candidate.jaxpr
|
|
|
|
|
|
|
|
|
|
|
|
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Yield the Jaxprs found in the params of each of ``jaxpr``'s equations.

  Only one level deep: the yielded sub-jaxprs are not themselves searched.
  """
  eqn_params = (eqn.params for eqn in jaxpr.eqns)
  for params in eqn_params:
    yield from jaxprs_in_params(params)
|
|
|
|
|
|
|
|
|
|
|
|
class ClosedJaxpr:
  """A Jaxpr paired with the values bound to its constvars.

  ``consts`` is stored as a list with one entry per ``jaxpr.constvars``. Most
  accessors simply forward to the wrapped Jaxpr.
  """
  __slots__ = ['__weakref__', '_jaxpr', '_consts']

  _jaxpr: Jaxpr
  _consts: list[Any]

  @property
  def jaxpr(self):
    return self._jaxpr

  @property
  def consts(self):
    return self._consts

  def __init__(self, jaxpr: Jaxpr, consts: Sequence):
    assert len(consts) == len(jaxpr.constvars)
    # assert not any(isinstance(c, Tracer) for c in consts)  # TODO(mattjj): enable
    self._jaxpr = jaxpr
    # Copy into a fresh list so later changes to the caller's sequence are not seen.
    self._consts = list(consts)

  @property
  def in_avals(self):
    """Abstract values of the jaxpr's input variables, in order."""
    return [invar.aval for invar in self.jaxpr.invars]

  @property
  def out_avals(self):
    """Abstract values of the jaxpr's outputs, in order."""
    return [outvar.aval for outvar in self.jaxpr.outvars]

  @property
  def literals(self):
    return self.consts  # backwards compatible alias

  @property
  def eqns(self):
    return self.jaxpr.eqns

  @property
  def effects(self) -> Effects:
    return self.jaxpr.effects

  def map_jaxpr(self, f):
    """Return a ClosedJaxpr of ``f(self.jaxpr)`` that keeps the same consts."""
    new_jaxpr = f(self.jaxpr)
    return ClosedJaxpr(new_jaxpr, self.consts)

  def replace(self, *, jaxpr=None, consts=None):
    """Return a new ClosedJaxpr; a field given as ``None`` is kept from self."""
    if jaxpr is None:
      jaxpr = self.jaxpr
    if consts is None:
      consts = self.consts
    return ClosedJaxpr(jaxpr, consts)

  def __str__(self):
    return str(self.jaxpr)

  def __repr__(self):
    return repr(self.jaxpr)

  def pretty_print(self, *, source_info=False, print_shapes=True,
                   name_stack=False, custom_pp_eqn_rules=True,
                   print_effects=False, **kwargs):
    """Pretty-print the wrapped Jaxpr, forwarding every option unchanged."""
    options = dict(source_info=source_info, print_shapes=print_shapes,
                   name_stack=name_stack,
                   custom_pp_eqn_rules=custom_pp_eqn_rules,
                   print_effects=print_effects)
    return self.jaxpr.pretty_print(**options, **kwargs)

  def _repr_pretty_(self, p, cycle):
    # IPython-style pretty-printing hook; `cycle` is not consulted.
    colored_text = self.pretty_print(use_color=True)
    return p.text(colored_text)
|
|
|
|
|
|
|
|
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
  """Evaluate ``closed_jaxpr`` on ``args`` with ``eval_jaxpr``.

  Wrapped in ``curry``, so ``jaxpr_as_fun(closed_jaxpr)`` yields a callable
  that takes the positional arguments.
  """
  jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
  return eval_jaxpr(jaxpr, consts, *args)
|
|
|
|
|
|
|
|
|
2024-05-18 08:45:31 -07:00
|
|
|
class JaxprEqnContext:
  """Settings recorded with an equation and re-installable via ``manager``.

  It holds a compute type and a ``threefry_partitionable`` flag. Entering
  ``manager`` re-activates both settings.
  """

  def __init__(self, compute_type: str | None, threefry_partitionable: bool):
    self.compute_type = compute_type
    self.threefry_partitionable = threefry_partitionable
    # (context-manager factory, value) pairs that `manager` enters in order.
    self._managers = [
        (compute_on.extend_compute_type, self.compute_type),
        (config.threefry_partitionable.__call__, self.threefry_partitionable),
    ]

  @property
  @contextmanager
  def manager(self):
    """Context manager that enters every recorded setting, in order."""
    with ExitStack() as stack:
      for make_cm, setting in self._managers:
        stack.enter_context(make_cm(setting))
      yield

  def __repr__(self):
    fields = (f"compute_type={self.compute_type},"
              f"threefry_partitionable={self.threefry_partitionable}")
    return f"JaxprEqnContext({fields})"
|
2024-05-21 22:42:28 -07:00
|
|
|
|
2024-05-17 15:58:25 -07:00
|
|
|
|
Use a class with `__slots__` instead of a NamedTuple in JaxprEqn and SourceInfo, which are two tuples we build frequently.
Surprisingly this is faster. With Python 3.12:
```
In [1]: from typing import NamedTuple
In [2]: class C(NamedTuple):
...: a: int
...: b: int
...: c: int
...: d: int
...: e: int
...: f: int
...: g: int
...:
In [3]: class D:
...: __slots__ = ('a', 'b', 'c', 'd', 'e', 'f', 'g')
...: def __init__(self, a, b, c, d, e, f, g):
...: self.a = a
...: self.b = b
...: self.c = c
...: self.d = d
...: self.e = e
...: self.f = f
...: self.g = g
...:
In [4]: %timeit D(1, 2, 3, 4, 5, 6, 7)
158 ns ± 0.458 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [5]: %timeit C(1, 2, 3, 4, 5, 6, 7)
236 ns ± 0.498 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [6]: %timeit D(1, 2, 3, 4, 5, 6, 7)
159 ns ± 2.13 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [7]: %timeit C(1, 2, 3, 4, 5, 6, 7)
235 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```
No behavioral changes intended.
PiperOrigin-RevId: 648556436
2024-07-01 19:18:15 -07:00
|
|
|
class JaxprEqn:
  """A single equation of a jaxpr: ``outvars = primitive[params] invars``.

  Also records the equation's effects, its source info and the
  JaxprEqnContext it was created under.
  """
  invars: list[Atom]
  outvars: list[Var]
  primitive: Primitive
  params: dict[str, Any]
  effects: Effects
  source_info: source_info_util.SourceInfo
  ctx: JaxprEqnContext

  # A plain class with __slots__ is slightly cheaper to build than a NamedTuple.
  __slots__ = ['invars', 'outvars', 'primitive', 'params', 'effects',
               'source_info', 'ctx']

  def __init__(self, invars, outvars, primitive, params, effects, source_info,
               ctx):
    self.invars = invars
    self.outvars = outvars
    self.primitive = primitive
    self.params = params
    self.effects = effects
    self.source_info = source_info
    self.ctx = ctx

  def __repr__(self):
    doc = pp_eqn(self, JaxprPpContext(), JaxprPpSettings())
    return str(doc).rstrip()

  def replace(
      self,
      invars: list[Atom] | None = None,
      outvars: list[Var] | None = None,
      primitive: Primitive | None = None,
      params: dict[str, Any] | None = None,
      effects: Effects | None = None,
      source_info: source_info_util.SourceInfo | None = None,
      ctx: JaxprEqnContext | None = None
  ):
    """Return a new JaxprEqn; a field passed as ``None`` is copied from self."""
    updates = dict(invars=invars, outvars=outvars, primitive=primitive,
                   params=params, effects=effects, source_info=source_info,
                   ctx=ctx)
    fields = {name: getattr(self, name) if new is None else new
              for name, new in updates.items()}
    return JaxprEqn(**fields)
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
                  ctx=None):
  """Construct a JaxprEqn, filling in defaults from the ambient state.

  A falsy ``source_info`` (e.g. ``None``) is replaced by
  ``source_info_util.new_source_info()``. A falsy ``ctx`` is replaced by a
  JaxprEqnContext built from the current compute type and the current
  ``threefry_partitionable`` config value. When ``enable_checks`` is on,
  asserts that invars are Var/Literal and outvars are Var.
  """
  if not source_info:
    source_info = source_info_util.new_source_info()
  if not ctx:
    ctx = JaxprEqnContext(compute_on.current_compute_type(),
                          config.threefry_partitionable.value)
  if config.enable_checks.value:
    assert all(isinstance(x, (Var, Literal)) for x in invars)
    assert all(isinstance(v, Var) for v in outvars)
  return JaxprEqn(invars, outvars, primitive, params, effects, source_info, ctx)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.
Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)
But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.
So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.
In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:
```
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
# b:f32[] = sin a
# c:f32[] = sin b
# d:f32[] = sin c
# in (d,) }
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```
Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!
So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```
PiperOrigin-RevId: 607664497
2024-02-16 05:56:45 -08:00
|
|
|
# Global creation counter; each Var takes the next value as its `count`.
_var_counter = it.count()


@total_ordering
class Var:
  """A variable in a jaxpr, carrying a print suffix and an abstract value.

  Identity (not ``count``) is what distinguishes variables; ``count`` only
  provides an ordering by creation.
  """
  __slots__ = ["count", "suffix", "aval"]

  count: int
  suffix: str
  aval: AbstractValue

  def __init__(self, suffix: str, aval: AbstractValue):
    self.count = next(_var_counter)
    self.suffix = suffix
    self.aval = raise_to_shaped(aval)

  # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not
  # care about variable ordering, but the downstream package kfac_jax does.
  def __lt__(self, other):
    # total_ordering derives <=, >, >= from this and identity-based __eq__.
    return self.count < other.count

  def __repr__(self):
    aval_text = self.aval.str_short()
    return f'Var(id={id(self)}){self.suffix}:{aval_text}'
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
|
Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.
Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)
But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.
So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.
In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:
```
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
# b:f32[] = sin a
# c:f32[] = sin b
# d:f32[] = sin c
# in (d,) }
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```
Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!
So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```
PiperOrigin-RevId: 607664497
2024-02-16 05:56:45 -08:00
|
|
|
def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]:
  """Return a factory that makes a fresh, distinct Var per call.

  The factory takes an abstract value; every Var it creates carries
  ``suffix``, which is appended when the variable is printed.
  """
  make_var = partial(Var, suffix)
  return make_var
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
class DropVar(Var):
  """Stand-in for a bound variable whose value is never read.

  In a jaxpr, ``dropvar`` can appear in place of a bound variable to indicate
  that the assignment is dropped, i.e. that an expression's output value will
  never be read. In that sense it is not a variable, but it is convenient to
  treat it as a special case of one. Its ``aval`` is similarly inexact.
  """

  def __init__(self, aval: AbstractValue):
    super().__init__('', aval)  # drop variables never carry a suffix

  def __repr__(self):
    return '_'
|
|
|
|
|
|
|
|
class Literal:
  """A constant operand embedded directly in a jaxpr equation.

  ``hash`` caches a hash of ``val`` when one can be computed. For unhashable
  values whose type is not in ``literalable_types``, the slot is left unset.
  """
  __slots__ = ["val", "aval", "hash"]

  val: Any
  aval: AbstractValue
  hash: int | None

  def __init__(self, val, aval):
    self.val = val
    self.aval = aval
    try:
      self.hash = hash(val)
    except TypeError:
      # Unhashable value: only registered literalable types get a value-based
      # fallback hash (None if even that fails); otherwise `hash` stays unset.
      if type(val) in literalable_types:
        try:
          self.hash = hash((val.item(), val.dtype))
        except (TypeError, AttributeError, ValueError):
          self.hash = None

  # Literal objects themselves are deliberately unhashable.
  __hash__ = None  # type: ignore

  def __repr__(self):
    # Whether the `hash` slot was filled decides between the terse and verbose forms.
    return f'{self.val}' if hasattr(self, 'hash') else f'Literal(val={self.val})'


# Types whose unhashable instances get a fallback hash in Literal.__init__.
literalable_types: set[type] = set()
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
# An equation operand: either a Var (looked up in the environment) or a
# Literal (value carried inline).
Atom = Union[Var, Literal]
|
|
|
|
|
|
|
|
class Primitive:
  """A named operation applied to arguments through ``bind``.

  Rules are attached per instance:
  - ``def_impl`` sets the evaluation rule.
  - ``def_abstract_eval`` / ``def_effectful_abstract_eval`` set the
    abstract-evaluation rule.
  - ``def_custom_bind`` replaces ``bind`` itself.
  """
  name: str
  # set for multi-output primitives.
  multiple_results: bool = False
  # set for call primitives processed in final style.
  call_primitive: bool = False
  # set for map primitives processed in final style.
  map_primitive: bool = False

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return f'{self.name}'

  def bind(self, *args, **params):
    """Apply the primitive to ``args`` under the top trace found among them."""
    assert (not config.enable_checks.value or
            all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
    top_trace = find_top_trace(args)
    return self.bind_with_trace(top_trace, args, params)

  def bind_with_trace(self, trace, args, params):
    """Raise ``args`` into ``trace``, process them, and lower the result(s)."""
    with pop_level(trace.level):
      raised_args = map(trace.full_raise, args)
      result = trace.process_primitive(self, raised_args, params)
    if self.multiple_results:
      return map(full_lower, result)
    return full_lower(result)

  def def_impl(self, impl):
    """Install ``impl`` as the evaluation rule; returns ``impl``."""
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    """Install an effect-free abstract-eval rule; returns it unchanged.

    The stored rule is wrapped so that it also reports ``no_effects``.
    """
    self.abstract_eval = _effect_free_abstract_eval(abstract_eval)
    return abstract_eval

  def def_effectful_abstract_eval(self, effectful_abstract_eval):
    """Install a rule that itself returns ``(out, effects)``; returns it."""
    self.abstract_eval = effectful_abstract_eval
    return effectful_abstract_eval

  def def_custom_bind(self, bind):
    """Replace this instance's ``bind`` with ``bind``; returns it."""
    self.bind = bind
    return bind

  def impl(self, *args, **params):
    raise NotImplementedError(
        f"Evaluation rule for '{self.name}' not implemented")

  def abstract_eval(self, *args, **params):
    raise NotImplementedError(
        f"Abstract evaluation for '{self.name}' not implemented")

  def get_bind_params(self, params):
    """Split eqn params into (subfuns, bind params); by default no subfuns."""
    return [], params
|
|
|
|
|
|
|
|
|
|
|
|
def _effect_free_abstract_eval(abstract_eval):
  """Adapt a rule returning only avals to the ``(avals, effects)`` convention."""
  def wrapped(*args, **kwargs):
    out_avals = abstract_eval(*args, **kwargs)
    return out_avals, no_effects
  return wrapped
|
|
|
|
|
|
|
|
# -------------------- lifting --------------------
|
|
|
|
|
|
|
|
# TODO(mattjj): replace this approach with a primitive-keyed table of rules
def traverse_jaxpr_params(f, params):
  """Apply ``f`` to each Jaxpr/ClosedJaxpr param and return a dict of results.

  The result maps param name to ``f``'s return value. A value that is a tuple
  or list is searched element by element; other values are checked directly.
  Only objects whose exact type is Jaxpr or ClosedJaxpr are considered, so
  subclasses are skipped. If one param holds several jaxprs, ``f`` is called on
  each, but only the result for the last one is kept under that param's name.
  """
  return {name: f(p)
          for name, param in params.items()
          for p in (param if isinstance(param, (tuple, list)) else [param])
          if type(p) in (Jaxpr, ClosedJaxpr)}
|
|
|
|
|
|
|
|
|
2024-02-15 12:27:13 -08:00
|
|
|
def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[Any]:
  """Interpret ``jaxpr`` by binding each equation's primitive in order.

  ``consts`` and ``args`` are bound to ``jaxpr.constvars`` and
  ``jaxpr.invars``. Each equation is bound inside a user context carrying its
  name stack and inside ``eqn.ctx.manager``. Its traceback is installed only
  when ``propagate_source_info`` is true. After every equation,
  ``clean_up_dead_vars`` is called with the environment and the ``last_used``
  table; presumably it drops values that are no longer needed. Returns the
  values of ``jaxpr.outvars``.
  """
  env: dict[Var, Any] = {}

  def read(v: Atom) -> Any:
    # Literals carry their value inline; variables are looked up in `env`.
    return v.val if isinstance(v, Literal) else env[v]

  def write(v: Var, val: Any) -> None:
    if config.enable_checks.value and not config.dynamic_shapes.value:
      assert typecheck(v.aval, val), (v.aval, val)
    env[v] = val

  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  last_uses = last_used(jaxpr)
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    subfuns, bind_params = prim.get_bind_params(eqn.params)
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    traceback = eqn.source_info.traceback if propagate_source_info else None
    user_ctx = source_info_util.user_context(traceback, name_stack=name_stack)
    with user_ctx, eqn.ctx.manager:
      ans = prim.bind(*subfuns, *map(read, eqn.invars), **bind_params)
    if prim.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
    clean_up_dead_vars(eqn, env, last_uses)
  return map(read, jaxpr.outvars)
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------- tracing --------------------
|
|
|
|
|
2023-03-14 09:55:38 -07:00
|
|
|
# Type variable for the Tracer subclass a given Trace produces; used as the
# return type of Trace.full_raise / pure / lift / sublift.
TracerType = TypeVar('TracerType', bound='Tracer')
|
2022-12-15 20:34:43 -08:00
|
|
|
|
2023-03-14 09:55:38 -07:00
|
|
|
class Trace(Generic[TracerType]):
|
2022-12-15 20:34:43 -08:00
|
|
|
__slots__ = ['main', 'level', 'sublevel']
|
|
|
|
|
|
|
|
main: MainTrace
|
|
|
|
level: int
|
|
|
|
sublevel: Sublevel
|
|
|
|
|
|
|
|
def __init__(self, main: MainTrace, sublevel: Sublevel) -> None:
  """Attach this trace to ``main`` at the given ``sublevel``."""
  self.main = main
  self.level = main.level  # copied from `main`; compared in full_raise
  self.sublevel = sublevel
|
|
|
|
|
2023-03-14 09:55:38 -07:00
|
|
|
def full_raise(self, val) -> TracerType:
|
2022-12-15 20:34:43 -08:00
|
|
|
if not isinstance(val, Tracer):
|
2024-02-13 07:18:35 -08:00
|
|
|
# This check is only applied to non-Tracers, because the hasattr() is
|
|
|
|
# expensive (Tracer.__getattr__) in the common case that val is a Tracer.
|
|
|
|
if hasattr(val, "dimension_as_value"): # Used for shape_poly._DimExpr
|
|
|
|
val = val.dimension_as_value()
|
|
|
|
if not isinstance(val, Tracer):
|
|
|
|
return self.pure(val)
|
|
|
|
else:
|
|
|
|
return self.pure(val)
|
2022-12-15 20:34:43 -08:00
|
|
|
val._assert_live()
|
|
|
|
level = self.level
|
|
|
|
sublevel = self.sublevel
|
|
|
|
if val._trace.main is self.main:
|
|
|
|
if val._trace.sublevel == sublevel:
|
2023-03-14 09:55:38 -07:00
|
|
|
return cast(TracerType, val)
|
2022-12-15 20:34:43 -08:00
|
|
|
elif val._trace.sublevel < sublevel:
|
|
|
|
return self.sublift(val)
|
|
|
|
else:
|
|
|
|
raise escaped_tracer_error(
|
|
|
|
val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}")
|
|
|
|
elif val._trace.level < level:
|
|
|
|
if val._trace.sublevel > sublevel:
|
|
|
|
raise escaped_tracer_error(
|
|
|
|
val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}")
|
|
|
|
return self.lift(val)
|
|
|
|
elif val._trace.level > level:
|
|
|
|
raise escaped_tracer_error(
|
|
|
|
val, f"Can't lift level {val} to {self}")
|
|
|
|
else: # val._trace.level == self.level:
|
|
|
|
raise escaped_tracer_error(
|
|
|
|
val, f"Different traces at same level: {val}, {self}")
|
|
|
|
|
2023-03-14 09:55:38 -07:00
|
|
|
def pure(self, val) -> TracerType:
|
2022-12-15 20:34:43 -08:00
|
|
|
raise NotImplementedError("must override")
|
|
|
|
|
2023-03-14 09:55:38 -07:00
|
|
|
def lift(self, tracer) -> TracerType:
|
2022-12-15 20:34:43 -08:00
|
|
|
raise NotImplementedError("must override")
|
|
|
|
|
2023-03-14 09:55:38 -07:00
|
|
|
def sublift(self, tracer) -> TracerType:
|
2022-12-15 20:34:43 -08:00
|
|
|
raise NotImplementedError("must override")
|
|
|
|
|
|
|
|
def process_primitive(self, primitive, tracers, params):
|
|
|
|
raise NotImplementedError("must override")
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
return '{}(level={}/{})'.format(
|
|
|
|
self.__class__.__name__, self.level, self.sublevel)
|
|
|
|
|
|
|
|
def process_call(self, call_primitive, f, tracers, params):
|
|
|
|
msg = (f"{type(self)} must override process_call to handle call-like "
|
|
|
|
"primitives")
|
|
|
|
raise NotImplementedError(msg)
|
|
|
|
|
|
|
|
def process_map(self, map_primitive, f, tracers, params):
|
|
|
|
msg = (f"{type(self)} must override process_map to handle map-like "
|
|
|
|
"primitives")
|
|
|
|
raise NotImplementedError(msg)
|
|
|
|
|
2023-02-17 14:03:28 -08:00
|
|
|
def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
|
|
|
|
symbolic_zeros):
|
2022-12-15 20:34:43 -08:00
|
|
|
msg = (f"{type(self)} must override process_custom_jvp_call "
|
|
|
|
"to handle custom_jvp primitives")
|
|
|
|
raise NotImplementedError(msg)
|
|
|
|
|
|
|
|
def process_custom_transpose(self, prim, call, tracers, **params):
|
|
|
|
msg = (f"{type(self)} must override process_custom_transpose "
|
|
|
|
"to handle custom_transpose_call primitives")
|
|
|
|
raise NotImplementedError(msg)
|
|
|
|
|
2023-03-24 14:42:19 -07:00
|
|
|
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
|
|
|
|
out_trees, symbolic_zeros):
|
2022-12-15 20:34:43 -08:00
|
|
|
msg = (f"{type(self)} must override process_custom_vjp_call "
|
|
|
|
"to handle custom_vjp primitives")
|
|
|
|
raise NotImplementedError(msg)
|
|
|
|
|
|
|
|
|
|
|
|
def raise_as_much_as_possible(tracer) -> Tracer:
  """Raise ``tracer`` into every trace from the dynamic MainTrace upward.

  The portion of the trace stack below the dynamic MainTrace is ignored. In
  the remaining portion, the value is raised into each trace whose level is
  greater than the tracer's own (non-tracers are raised into every one).
  """
  trace_stack_state = thread_local_state.trace_state.trace_stack
  mains = trace_stack_state.stack
  dynamic_main = trace_stack_state.dynamic
  # Find the dynamic MainTrace by identity (MainTrace.__eq__ is structural).
  start = next(pos for pos, main in enumerate(mains) if main is dynamic_main)
  for main in mains[start:]:
    candidate = main.with_cur_sublevel()
    if not isinstance(tracer, Tracer) or tracer._trace.level < candidate.level:
      tracer = candidate.full_raise(tracer)
  return tracer
|
|
|
|
|
|
|
|
|
|
|
|
def escaped_tracer_error(tracer, detail=None):
  """Build an UnexpectedTracerError for a tracer that escaped its trace.

  Args:
    tracer: the escaped tracer. If it has ``_debug_info`` and/or
      ``_line_info`` attributes, they are used to enrich the message.
    detail: optional extra text, appended on its own line at the end.

  Returns:
    An ``UnexpectedTracerError`` instance; it is returned, not raised.
  """
  num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value
  msg = ('Encountered an unexpected tracer. A function transformed by JAX '
         'had a side effect, allowing for a reference to an intermediate value '
         f'with type {tracer.aval.str_short()} wrapped in a '
         f'{type(tracer).__name__} to escape the scope of the transformation.\n'
         'JAX transformations require that functions explicitly return their '
         'outputs, and disallow saving intermediate values to global state.')
  dbg = getattr(tracer, '_debug_info', None)
  if dbg is not None:
    msg += ('\nThe function being traced when the value leaked was '
            f'{dbg.func_src_info} traced for {dbg.traced_for}.')
  line_info = getattr(tracer, '_line_info', None)
  if line_info is not None:
    divider = '\n' + '-'*30 + '\n'
    msg += divider
    msg += ('The leaked intermediate value was created on line '
            f'{source_info_util.summarize(line_info)}. ')
    msg += divider
    if num_frames > 0:
      msg += (f'When the value was created, the final {num_frames} stack '
              'frames (most recent last) excluding JAX-internal frames were:')
      msg += divider + source_info_util.summarize(
          line_info, num_frames=num_frames) + divider
  msg += ('\nTo catch the leak earlier, try setting the environment variable '
          'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
          'manager.')
  if detail:
    # Start the detail on its own line so it is not fused onto the
    # preceding "...context manager." sentence.
    msg += f'\nDetail: {detail}'
  return UnexpectedTracerError(msg)
|
|
|
|
|
|
|
|
|
2023-09-19 09:00:19 -07:00
|
|
|
def check_scalar_conversion(arr: Array):
  """Raise TypeError unless ``arr`` has zero dimensions."""
  if arr.ndim <= 0:
    return
  raise TypeError("Only scalar arrays can be converted to Python scalars; "
                  f"got {arr.ndim=}")
|
2023-09-19 09:00:19 -07:00
|
|
|
|
|
|
|
|
|
|
|
def check_integer_conversion(arr: Array):
  """Raise TypeError unless ``arr`` is a 0-d array with an integer dtype."""
  is_scalar = arr.shape == ()
  # The dtype is only inspected for 0-d inputs (short-circuit preserved).
  if is_scalar and dtypes.issubdtype(arr.dtype, np.integer):
    return
  raise TypeError("Only integer scalar arrays can be converted to a scalar index.")
|
|
|
|
|
|
|
|
|
2024-03-06 11:30:48 -08:00
|
|
|
def check_bool_conversion(arr: Array):
  """Raise ValueError unless ``arr`` holds exactly one element."""
  num_elements = arr.size
  if num_elements == 0:
    raise ValueError("The truth value of an empty array is ambiguous. Use"
                     " `array.size > 0` to check that an array is not empty.")
  if num_elements > 1:
    raise ValueError("The truth value of an array with more than one element"
                     " is ambiguous. Use a.any() or a.all()")
|
2023-09-19 09:00:19 -07:00
|
|
|
|
|
|
|
|
2023-11-03 09:56:33 -07:00
|
|
|
def _aval_property(name):
  """Return a read-only property that reads attribute ``name`` off ``self.aval``."""
  def _forward(self):
    return getattr(self.aval, name)
  return property(_forward)
|
|
|
|
|
2023-09-19 09:00:19 -07:00
|
|
|
|
2024-01-26 13:52:47 +00:00
|
|
|
class Tracer(typing.Array, metaclass=StrictABCMeta):
  """Base class for values being traced by a Trace.

  ``dtype``/``ndim``/``size``/``shape`` and other unknown attributes are
  forwarded to ``self.aval`` (see ``__getattr__``). Conversions to Python
  scalars go through ``self.aval``'s ``_bool``/``_int``/... hooks after a
  shape/dtype check. APIs that need materialized data raise
  ``ConcretizationTypeError`` (or ``AttributeError`` where ``hasattr()``
  compatibility matters).
  """
  __array_priority__ = 1000
  __slots__ = ['_trace', '_line_info']

  dtype = _aval_property('dtype')
  ndim = _aval_property('ndim')
  size = _aval_property('size')
  shape = _aval_property('shape')

  def __hash__(self):
    # TODO(jakevdp) finalize this deprecation and set __hash__ = None
    # Warning added 2024-06-13
    if deprecations.is_accelerated('tracer-hash'):
      raise TypeError(f"unhashable type: {type(self)}")
    # Use FutureWarning rather than DeprecationWarning because hash is likely
    # not called directly by the user, so we want to warn at all stacklevels.
    warnings.warn(
        f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an"
        " error in a future JAX release.", category=FutureWarning)
    return super().__hash__()

  def __init__(self, trace: Trace):
    self._trace = trace

  def _error_repr(self):
    """Short description of this tracer used in error messages."""
    if self.aval is None:
      return f"traced array with aval {self.aval}"
    return f"traced array with shape {raise_to_shaped(self.aval).str_short()}"

  def __array__(self, *args, **kw):
    raise TracerArrayConversionError(self)

  def __dlpack__(self, *args, **kw):
    raise ConcretizationTypeError(self,
      f"The __dlpack__() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def tolist(self):
    raise ConcretizationTypeError(self,
      f"The tolist() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def tobytes(self, order="C"):
    del order
    raise ConcretizationTypeError(self,
      f"The tobytes() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def __iter__(self):
    return iter(self.aval._iter(self))

  def __reversed__(self):
    return iter(self[::-1])

  def __len__(self):
    return self.aval._len(self)

  @property
  def sharding(self):
    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
    # we raise an AttributeError so that hasattr() and getattr() work as expected.
    raise AttributeError(self,
      f"The 'sharding' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def device(self):
    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
    # we raise an AttributeError so that hasattr() and getattr() work as expected.
    raise AttributeError(self,
      f"The 'device' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def addressable_shards(self):
    raise ConcretizationTypeError(self,
      f"The 'addressable_shards' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def at(self):
    return self.aval.at.fget(self)

  @property
  def aval(self):
    raise NotImplementedError("must override")

  def _assert_live(self) -> None:
    pass  # Override for liveness checking

  def get_referent(self) -> Any:
    return self  # Override for object equivalence checking

  def __bool__(self):
    check_bool_conversion(self)
    return self.aval._bool(self)

  def __int__(self):
    check_scalar_conversion(self)
    return self.aval._int(self)

  def __float__(self):
    check_scalar_conversion(self)
    return self.aval._float(self)

  def __complex__(self):
    check_scalar_conversion(self)
    return self.aval._complex(self)

  def __hex__(self):
    check_integer_conversion(self)
    return self.aval._hex(self)

  def __oct__(self):
    check_integer_conversion(self)
    return self.aval._oct(self)

  def __index__(self):
    check_integer_conversion(self)
    # __index__ must *return* the integer. `raise`-ing the result of
    # aval._index (as before) turned every successful conversion into
    # "TypeError: exceptions must derive from BaseException".
    return self.aval._index(self)

  # raises a useful error on attempts to pickle a Tracer.
  def __reduce__(self):
    raise ConcretizationTypeError(
      self, ("The error occurred in the __reduce__ method, which may "
             "indicate an attempt to serialize/pickle a traced value."))

  # raises the better error message from ShapedArray
  def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)

  # NumPy also only looks up special methods on classes.
  def __array_module__(self, types): return self.aval._array_module(self, types)

  def __getattr__(self, name):
    # if the aval property raises an AttributeError, gets caught here
    assert not config.enable_checks.value or name != "aval"
    try:
      attr = getattr(self.aval, name)
    except AttributeError as err:
      raise AttributeError(
          f"{self.__class__.__name__} has no attribute {name}"
      ) from err
    else:
      # aval_property / aval_method wrappers are bound to this tracer;
      # anything else is returned from the aval unchanged.
      t = type(attr)
      if t is aval_property:
        return attr.fget(self)
      elif t is aval_method:
        return types.MethodType(attr.fun, self)
      else:
        return attr

  def _pretty_print(self):
    base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
    contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
                 else pp.text(repr(attr))) for name, attr in self._contents()]
    if contents:
      base = pp.group(pp.nest(2, pp.concat([
        base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
          pp.text(f'{name} = ') + pp_payload
          for name, pp_payload in contents])
      ])))
    return base

  def __repr__(self):
    return self._pretty_print().format()

  def _contents(self):
    try:
      return [(name, getattr(self, name)) for name in self.__slots__]
    except AttributeError:
      return ()

  def _origin_msg(self) -> str:
    return ""

  # Methods that are only valid for materialized arrays
  def addressable_data(self, index):
    raise ConcretizationTypeError(self,
      f"The addressable_data() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def block_until_ready(self):
    # Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
    raise AttributeError(self,
      f"The 'block_until_ready' method is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def copy_to_host_async(self):
    # Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
    raise AttributeError(self,
      f"The 'copy_to_host_async' method is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  def delete(self):
    raise ConcretizationTypeError(self,
      f"The delete() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def devices(self):
    raise ConcretizationTypeError(self,
      f"The devices() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def global_shards(self):
    raise ConcretizationTypeError(self,
      f"The global_shards property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def is_deleted(self):
    raise ConcretizationTypeError(self,
      f"The is_deleted() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def is_fully_addressable(self):
    raise ConcretizationTypeError(self,
      f"The is_fully_addressable property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def is_fully_replicated(self):
    raise ConcretizationTypeError(self,
      f"The is_fully_replicated property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def on_device_size_in_bytes(self):
    raise ConcretizationTypeError(self,
      f"The on_device_size_in_bytes() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def traceback(self):
    raise ConcretizationTypeError(self,
      f"The traceback property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def unsafe_buffer_pointer(self):
    raise ConcretizationTypeError(self,
      f"The unsafe_buffer_pointer() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")
|
2023-02-10 09:42:32 -08:00
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
# Marker wrappers that let an aval expose properties / methods which
# Tracer.__getattr__ binds to the tracer instance (matched by exact type).
aval_property = namedtuple("aval_property", "fget")
aval_method = namedtuple("aval_method", "fun")
|
|
|
|
|
|
|
|
|
|
|
|
class EvalTrace(Trace):
  """Trace that evaluates primitives directly by calling their ``impl``."""
  # See comments in https://github.com/google/jax/pull/3370
  def pure(self, x): return x
  lift = sublift = pure

  def process_primitive(self, primitive, tracers, params):
    if not config.debug_key_reuse.value:
      return primitive.impl(*tracers, **params)
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks  # pytype: disable=import-error
    return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)

  def process_call(self, primitive, f, tracers, params):
    if not config.debug_key_reuse.value:
      return primitive.impl(f, *tracers, **params)
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks  # pytype: disable=import-error
    return call_impl_with_key_reuse_checks(primitive, primitive.impl, f, *tracers, **params)

  process_map = process_call

  def process_custom_transpose(self, primitive, call, tracers, **_):
    del primitive, _
    with new_sublevel():
      return call.call_wrapped(*tracers)

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **_):
    del primitive, jvp, _  # Unused: eager evaluation just runs `fun`.
    with new_sublevel():
      return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_):  # pytype: disable=signature-mismatch
    del primitive, fwd, bwd, _  # Unused: eager evaluation just runs `fun`.
    with new_sublevel():
      return fun.call_wrapped(*tracers)
|
|
|
|
|
|
|
|
|
|
|
|
class MainTrace:
  """An entry of the trace stack: a level, a Trace subclass and its payload.

  Equality is structural (level, trace type and payload); the hash only uses
  level and trace type, which keeps equal objects hashing equally.
  """
  level: int
  trace_type: type[Trace]
  payload: dict[str, Any]

  def __init__(self, level, trace_type, **payload) -> None:
    self.level, self.trace_type, self.payload = level, trace_type, payload

  def __repr__(self) -> str:
    return "MainTrace({},{})".format(self.level, self.trace_type.__name__)

  def __hash__(self) -> int:
    key = (self.level, self.trace_type)
    return hash(key)

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, MainTrace):
      return False
    return (self.level == other.level
            and self.trace_type == other.trace_type
            and self.payload == other.payload)

  def with_cur_sublevel(self):
    """Instantiate ``trace_type`` for this main at the current sublevel."""
    sublevel = cur_sublevel()
    return self.trace_type(self, sublevel, **self.payload)
|
|
|
|
|
|
|
|
class TraceStack:
  """Stack of MainTraces plus a pointer to the current dynamic MainTrace.

  A fresh stack holds a single level-0 ``EvalTrace`` main, which is also the
  dynamic one.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack: list[MainTrace]
  dynamic: MainTrace

  def __init__(self):
    eval_trace = MainTrace(0, EvalTrace)
    self.stack = [eval_trace]
    self.dynamic = eval_trace

  def next_level(self) -> int:
    """Level a newly pushed MainTrace would get (its index in ``stack``)."""
    return len(self.stack)

  def push(self, main_trace: MainTrace) -> None:
    self.stack.append(main_trace)

  def pop(self) -> None:
    self.stack.pop()

  def __repr__(self) -> str:
    # Join the per-entry lines (innermost first); interpolating the bare
    # `map` object would render as "<map object at 0x...>".
    stack_str = ''.join(map(' {}\n'.format, self.stack[::-1]))
    return f'Trace stack\n{stack_str}\n{self.dynamic}'

  def copy(self):
    """Shallow copy: a new list of the same MainTrace objects."""
    new = self.__new__(TraceStack)
    new.stack = self.stack[:]
    new.dynamic = self.dynamic
    return new
|
|
|
|
|
|
|
|
|
|
|
|
@total_ordering
class Sublevel:
  """Ordered wrapper around an integer sublevel.

  Comparisons against anything that is not exactly a ``Sublevel`` return
  False rather than NotImplemented.
  """

  def __init__(self, level: int):
    self.level = level

  def __repr__(self):
    return str(self.level)

  def __eq__(self, other):
    if type(other) is not Sublevel:
      return False
    return self.level == other.level

  def __lt__(self, other):
    if type(other) is not Sublevel:
      return False
    return self.level < other.level
|
|
|
|
|
|
|
|
|
|
|
|
AxisEnvFrame = namedtuple('AxisEnvFrame', 'name size main_trace')
AxisName = Hashable

# Unique sentinel object; compare against it with `is`.
no_axis_name = object()
|
|
|
|
|
|
|
|
class TraceState:
  """Tracing state: trace stack, sublevel stack and axis environment."""
  trace_stack: TraceStack
  substack: list[Sublevel]
  axis_env: list[AxisEnvFrame]

  def __init__(self) -> None:
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.axis_env = []

  def copy(self):
    """Copy with fresh containers; the contained entries are shared."""
    clone = TraceState.__new__(TraceState)
    clone.trace_stack = self.trace_stack.copy()
    clone.substack = list(self.substack)
    clone.axis_env = list(self.axis_env)
    return clone
|
|
|
|
|
|
|
|
|
|
|
|
def _update_thread_local_jit_state(dynamic):
  """Record ``(dynamic.level, dynamic.trace_type)`` as the jit dynamic trace state."""
  config.update_thread_local_jit_state(
      dynamic_trace_state=(dynamic.level, dynamic.trace_type))
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
|
|
|
|
# Tracer state is reached through a thread-local object, so separate threads
# can trace concurrently; handing traced values between threads is not allowed.
class ThreadLocalState(threading.local):
  """Per-thread holder; each thread gets its own freshly built TraceState."""

  def __init__(self):
    self.trace_state = TraceState()


thread_local_state = ThreadLocalState()
|
|
|
|
|
|
|
|
|
|
|
|
def _initialize_jax_jit_thread_local_state():
  """Initialize the C++ jit thread-local context for the calling thread.

  Registered below as the C++ initialization callback; it is expected to run
  when a thread (e.g. one spawned by the user) has no jit context yet. If
  ``extra_jit_context`` is already set, nothing is done.

  Kept out of `config.py` to avoid circular imports.
  """
  if jax_jit.thread_local_state().extra_jit_context is not None:
    return
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  config.update_thread_local_jit_state(
      dynamic_trace_state=(dynamic.level, dynamic.trace_type))


jax_jit.set_thread_local_state_initialization_callback(
    _initialize_jax_jit_thread_local_state)
|
|
|
|
|
|
|
|
def trace_state_clean() -> bool:
  """Return True iff this thread's trace state equals a freshly built one."""
  state = thread_local_state.trace_state
  base_main = MainTrace(0, EvalTrace)
  return (state.substack == [Sublevel(0)]
          and state.axis_env == []
          and state.trace_stack.stack == [base_main]
          and state.trace_stack.dynamic == base_main)
|
|
|
|
|
|
|
|
def reset_trace_state() -> bool:
  """Reset this thread's trace state; return whether it was already clean."""
  if trace_state_clean():
    return True
  thread_local_state.trace_state.__init__()
  return False
|
|
|
|
|
|
|
|
def cur_sublevel() -> Sublevel:
  """Return the top entry of this thread's sublevel stack."""
  substack = thread_local_state.trace_state.substack
  return substack[-1]
|
|
|
|
|
|
|
|
TRACER_LEAK_DEBUGGER_WARNING = """\
|
|
|
|
JAX check_tracer_leaks behavior can trigger false positives when used with a debugger.
|
|
|
|
To avoid false positives and silence this warning, you can disable thread tracing using
|
|
|
|
the following:
|
|
|
|
|
|
|
|
import threading
|
|
|
|
threading.current_thread().pydev_do_not_trace = True
|
|
|
|
"""
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
def maybe_find_leaked_tracers(x: MainTrace | Sublevel | None
                              ) -> list[Tracer]:
  """Return the Tracers that (via a Trace) still reference ``x``.

  The search is two hops through ``gc.get_referrers``: Traces referring to
  ``x``, then Tracers referring to any of those Traces. The result may be
  empty, e.g. when JAX itself still holds ``x`` in a closure and the user
  leaked nothing.
  """
  if not getattr(threading.current_thread(), 'pydev_do_not_trace', True):
    warnings.warn(TRACER_LEAK_DEBUGGER_WARNING)
  # Collect first so objects kept alive only by reference cycles disappear;
  # unreachable leaked tracers cannot interact with user code anyway.
  gc.collect()
  traces = [ref_ for ref_ in gc.get_referrers(x) if isinstance(ref_, Trace)]
  return [ref_ for ref_ in gc.get_referrers(*traces)
          if isinstance(ref_, Tracer)]
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def leaked_tracer_error(name: str, t, tracers: list[Tracer]) -> Exception:
  """Build (but do not raise) an Exception describing leaked tracers.

  Args:
    name: label for what leaked, interpolated into the message (e.g. "trace").
    t: the leaked object itself (e.g. a MainTrace), interpolated via str().
    tracers: non-empty list of leaked tracers; each is rendered together with
      its ``_origin_msg()`` and a referrer chain from ``_why_alive``.

  Returns:
    An ``Exception`` whose message lists every leaked tracer.
  """
  assert tracers
  # Ignore the `tracers` list itself while walking referrer chains: it
  # trivially refers to every leaked tracer.
  why = partial(_why_alive, {id(tracers)})
  msgs = '\n\n'.join(f'{tracers[i]}{tracers[i]._origin_msg()}{why(tracers[i])}'
                     for i in range(len(tracers)))
  return Exception(f'Leaked {name} {t}. Leaked tracer(s):\n\n{msgs}\n')
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def _why_alive(ignore_ids: set[int], x: Any) -> str:
  """Describe one chain of gc referrers that keeps ``x`` alive.

  Starting from ``x``, repeatedly follows the first (non-ignored) referrer
  reported by ``gc.get_referrers`` and emits one line per hop. The walk stops
  at a module, at an object already visited (a cycle), or when no referrers
  remain.

  Args:
    ignore_ids: ids of objects never reported as referrers.
    x: the object whose liveness is being explained.

  Returns:
    The lines joined with '\\n' and prefixed by '\\n', or '' if no hop was found.
  """
  parents = lambda x: [r for r in gc.get_referrers(x) if id(r) not in ignore_ids]
  child, lines, seen = x, [], set()
  while (id(child) not in seen and type(child) is not types.ModuleType
         and parents(child)):
    parent = parents(child)[0]  # just pick one parent

    # For namespaces (like modules and class instances) and closures, the
    # references may form a simple chain: e.g. instance refers to its own
    # __dict__ which refers to child, or function refers to its __closure__
    # which refers to cells which refer to child. In these cases, we can provide
    # a more intuitive description by collapsing the chain into a single
    # parent->child jump. We do that by setting `parent` here to be a
    # grandparent (or great-grandparent) of `child`, and then handling that case
    # in _why_alive_container_info. See example:
    # https://github.com/google/jax/pull/13022#discussion_r1008456599
    # To prevent this collapsing behavior, just comment out this code block.
    if (isinstance(parent, dict) and
        getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
      parent = parents(parent)[0]
    elif type(parent) is types.CellType:
      # cell -> closure tuple -> function: jump straight to the function.
      parent = parents(parents(parent)[0])[0]

    line = f'<{type(child).__name__} {id(child)}> is referred to by '
    lines.append(line + _why_alive_container_info(parent, id(child)))
    seen.add(id(child))
    child = parent
  return '\n' + '\n'.join(lines) if lines else ''
|
|
|
|
|
|
|
|
def _why_alive_container_info(container, obj_id) -> str:
  """Describe how ``container`` refers to the object whose id is ``obj_id``.

  Renders one hop of a ``_why_alive`` referrer chain, e.g. ``<list 1>[0]``,
  ``<Foo 2>.attr``, ``<dict 3>['k']`` or
  ``<function 4> (f) closed-over variable x``. When the specific reference
  cannot be identified (e.g. the object is only a dict *key*), the container
  description alone is returned instead of an empty "at indices"/"at keys"
  list.

  Args:
    container: an object found to refer to the target.
    obj_id: ``id()`` of the referenced object.

  Returns:
    A human-readable description string.
  """
  name = f'<{type(container).__name__} {id(container)}>'
  if type(container) is types.ModuleType:
    name = getattr(container, '__name__', name)
  if type(container) is types.FunctionType:
    name_ = getattr(container, '__name__', '<no-name>')
    closure = inspect.getclosurevars(container)
    keys = [k for k, v in dict(closure.nonlocals, **closure.globals).items()
            if id(v) == obj_id]
    if len(keys) == 1: return f'{name} ({name_}) closed-over variable {keys[0]}'
    elif len(keys) > 1: return (f'{name} in closed-over variables ' +
                                ', '.join(map(repr, keys)))
  if hasattr(container, '__dict__'):
    keys = [k for k in vars(container) if id(vars(container)[k]) == obj_id]
    if len(keys) == 1: return f'{name}.{keys[0]}'
    elif len(keys) > 1: return f'{name} in vars ' + ', '.join(map(repr, keys))
  if isinstance(container, (list, tuple)):
    idxs = [i for i, x in enumerate(container) if id(x) == obj_id]
    if len(idxs) == 1: return f'{name}[{idxs[0]}]'
    elif len(idxs) > 1: return f'{name} at indices ' + ', '.join(map(str, idxs))
  if isinstance(container, dict):
    # Only values are matched; an object held solely as a key falls through.
    keys = [k for k in container if id(container[k]) == obj_id]
    if len(keys) == 1: return f'{name}[{keys[0]!r}]'
    elif len(keys) > 1: return f'{name} at keys ' + ', '.join(map(repr, keys))
  if isinstance(container, types.ModuleType):
    return f' named {container.__name__}'
  return name
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def new_main(trace_type: type[Trace], dynamic: bool = False,
             **payload) -> Generator[MainTrace, None, None]:
  """Push a new MainTrace at the next level for the duration of the context.

  Args:
    trace_type: the Trace subclass the new MainTrace instantiates.
    dynamic: if True, the new MainTrace also becomes the dynamic one (and the
      thread-local jit state is updated) until the context exits.
    **payload: stored on the MainTrace and passed on by ``with_cur_sublevel``.

  Yields:
    The new MainTrace.

  Raises:
    Exception: from ``leaked_tracer_error`` when ``config.check_tracer_leaks``
      is enabled and Tracers still reference the MainTrace after a normal
      exit.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  level = stack.next_level()
  main = MainTrace(level, trace_type, **payload)
  stack.push(main)
  if dynamic:
    prev_dynamic, stack.dynamic = stack.dynamic, main
    _update_thread_local_jit_state(stack.dynamic)

  try:
    yield main
  finally:
    stack.pop()
    if dynamic:
      stack.dynamic = prev_dynamic
      _update_thread_local_jit_state(stack.dynamic)

  # Runs only on normal exit: an exception raised inside the context
  # propagates out of the `finally` above before reaching this check.
  if config.check_tracer_leaks.value:
    # Drop our own strong reference so the weakref only stays alive if
    # something else still holds the MainTrace.
    t = ref(main)
    del main
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
|
|
|
|
|
2023-08-31 17:30:34 -07:00
|
|
|
@contextmanager
def new_dynamic(level: int) -> Generator[None, None, None]:
  """Makes the MainTrace at position ``level`` the dynamic main for the context.

  The previous dynamic main is restored on exit. The thread-local jit state is
  refreshed after each switch.
  """
  trace_stack = thread_local_state.trace_state.trace_stack
  saved_dynamic = trace_stack.dynamic
  trace_stack.dynamic = trace_stack.stack[level]
  _update_thread_local_jit_state(trace_stack.dynamic)
  try:
    yield
  finally:
    trace_stack.dynamic = saved_dynamic
    _update_thread_local_jit_state(trace_stack.dynamic)
|
|
|
|
|
|
|
|
def dynamic_level() -> int:
  """Returns the level of the current thread's dynamic MainTrace."""
  trace_stack = thread_local_state.trace_state.trace_stack
  current_dynamic = trace_stack.dynamic
  return current_dynamic.level
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
@contextmanager
def new_base_main(trace_type: type[Trace],
                  **payload) -> Generator[MainTrace, None, None]:
  """Temporarily replaces the base (level 0) MainTrace with a new one.

  A ``MainTrace(0, trace_type, **payload)`` is installed both as
  ``stack.stack[0]`` and as ``stack.dynamic``. Both previous values are
  restored when the context exits.

  When ``config.check_tracer_leaks`` is enabled and the body exits normally,
  this checks whether the MainTrace is still referenced elsewhere. If so, it
  raises ``leaked_tracer_error`` for any tracers reported by
  ``maybe_find_leaked_tracers``.

  Yields:
    The new base MainTrace.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  main = MainTrace(0, trace_type, **payload)
  prev_dynamic, stack.dynamic = stack.dynamic, main
  prev_base, stack.stack[0] = stack.stack[0], main
  _update_thread_local_jit_state(stack.dynamic)
  try:
    yield main
  finally:
    stack.dynamic = prev_dynamic
    stack.stack[0] = prev_base
    _update_thread_local_jit_state(stack.dynamic)

  if config.check_tracer_leaks.value:
    # Drop our own strong reference first, so the weakref stays alive only if
    # something else still holds on to the MainTrace.
    t = ref(main)
    del main
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
|
|
|
|
|
2024-04-30 14:11:27 -07:00
|
|
|
@contextmanager
def pop_level(level: int):
  """Temporarily truncates the thread-local trace stack to ``stack[:level]``.

  The original list object is restored when the context exits. A ``level`` of
  0 is treated as a no-op: the stack is left untouched rather than emptied.
  """
  if level == 0:
    # No-op path: yield exactly once so this still works as a context manager.
    return (yield)
  prev, thread_local_state.trace_state.trace_stack.stack = \
      thread_local_state.trace_state.trace_stack.stack, \
      thread_local_state.trace_state.trace_stack.stack[:level]
  try:
    yield
  finally:
    thread_local_state.trace_state.trace_stack.stack = prev
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
@contextmanager
def ensure_compile_time_eval():
  """Context manager to ensure evaluation at trace/compile time (or error).

  Some JAX APIs like :func:`jax.jit` and :func:`jax.lax.scan` involve staging,
  i.e., delaying the evaluation of numerical expressions (like :mod:`jax.numpy`
  function applications) so that instead of performing those computations
  eagerly while evaluating the corresponding Python expressions, their
  computation is carried out separately, e.g. after optimized compilation. But
  this delay can be undesirable. For example, numerical values might be needed
  to evaluate Python control flow and so their evaluation cannot be delayed. As
  another example, it may be beneficial to ensure compile time evaluation (or
  "constant folding") for performance reasons.

  This context manager ensures that JAX computations are evaluated eagerly. If
  eager evaluation is not possible, a ``ConcretizationTypeError`` is raised.

  Here's a contrived example::

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      with jax.ensure_compile_time_eval():
        y = jnp.sin(3.0)
        z = jnp.sin(y)
        z_positive = z > 0
      if z_positive:  # z_positive is usable in Python control flow
        return jnp.sin(x)
      else:
        return jnp.cos(x)

  Here's a real-world example from https://github.com/google/jax/issues/3974::

    import jax
    import jax.numpy as jnp
    from jax import random

    @jax.jit
    def jax_fn(x):
      with jax.ensure_compile_time_eval():
        y = random.randint(random.key(0), (1000,1000), 0, 100)
        y2 = y @ y
      x2 = jnp.sum(y2) * x
      return x2

  A similar behavior can often be achieved simply by 'hoisting' the constant
  expression out of the corresponding staging API::

    y = random.randint(random.key(0), (1000,1000), 0, 100)

    @jax.jit
    def jax_fn(x):
      y2 = y @ y
      x2 = jnp.sum(y2)*x
      return x2

  But in some cases it can be more convenient to use this context manager.
  """
  # Install an EvalTrace as both the base and the dynamic main for the
  # duration of the context (see new_base_main).
  with new_base_main(EvalTrace):
    yield
eval_context = ensure_compile_time_eval  # alias, backward compatibility
|
|
|
|
|
|
|
|
@contextmanager
def new_sublevel() -> Generator[None, None, None]:
  """Pushes a new Sublevel onto the thread-local substack for the context.

  The Sublevel is numbered by the current substack depth and is popped on
  exit. When ``config.check_tracer_leaks`` is enabled and the body exits
  normally, a lingering reference to the Sublevel is reported through
  ``leaked_tracer_error`` if ``maybe_find_leaked_tracers`` finds tracers.
  """
  sublevel = Sublevel(len(thread_local_state.trace_state.substack))
  thread_local_state.trace_state.substack.append(sublevel)
  try:
    yield
  finally:
    thread_local_state.trace_state.substack.pop()

  if config.check_tracer_leaks.value:
    # Drop our own strong reference first, so the weakref stays alive only if
    # something else still holds on to the Sublevel.
    t = ref(sublevel)
    del sublevel
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers:
        raise leaked_tracer_error("sublevel", t(), leaked_tracers)
|
|
|
|
|
|
|
|
def full_lower(val):
  """Returns ``val.full_lower()`` for Tracers and ``val`` unchanged otherwise."""
  return val.full_lower() if isinstance(val, Tracer) else val
|
|
|
|
|
2024-02-23 08:41:04 -08:00
|
|
|
|
|
|
|
def _get_trace_level(t: Tracer) -> int:
  """Sort key for tracers: the level of the trace that ``t`` belongs to."""
  owning_trace = t._trace
  return owning_trace.level
|
|
|
|
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
def find_top_trace(xs) -> Trace:
  """Returns the highest-level trace relevant to the values in ``xs``.

  The candidate is the MainTrace of the highest-level Tracer in ``xs``. The
  thread's dynamic MainTrace wins instead if there are no Tracers or if its
  level is strictly higher. The result is that main's trace at the current
  sublevel.
  """
  tracers = (x for x in xs if isinstance(x, Tracer))
  top_tracer = max(tracers, default=None, key=_get_trace_level)
  top_main = None
  if top_tracer is not None:
    top_tracer._assert_live()
    top_main = top_tracer._trace.main
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  if top_main is None or dynamic.level > top_main.level:
    top_main = dynamic
  return top_main.with_cur_sublevel()
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def get_referent(x: Any) -> Any:
  """Returns ``x.get_referent()`` when ``x`` is a Tracer, else ``x`` itself."""
  if not isinstance(x, Tracer):
    return x
  return x.get_referent()
|
|
|
|
|
|
|
|
def same_referent(x: Any, y: Any) -> bool:
  """Checks whether ``x`` and ``y`` resolve to the very same referent object."""
  referent_x = get_referent(x)
  referent_y = get_referent(y)
  return referent_x is referent_y
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def dedup_referents(itr: Iterable[Any]) -> list[Any]:
  """Deduplicates ``itr`` by referent, keeping the last item per referent.

  Referents are keyed via ``HashableWrapper``. Result order follows the first
  occurrence of each referent.
  """
  by_referent = {}
  for item in itr:
    by_referent[HashableWrapper(get_referent(item))] = item
  return list(by_referent.values())
|
|
|
|
|
2023-04-19 17:01:05 -07:00
|
|
|
def definitely_equal(x, y):
  """Tests whether ``x`` and ``y`` are known to be equal.

  If either value is a Tracer, this compares referents by identity. Otherwise
  identical objects are equal, and anything else falls back to ``x == y``. An
  ``InconclusiveDimensionOperation`` from ``==`` is treated as "not equal".
  """
  if any(isinstance(v, Tracer) for v in (x, y)):
    return same_referent(x, y)
  if x is y:
    return True
  try:
    return x == y
  except InconclusiveDimensionOperation:
    return False
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
# -------------------- abstract values --------------------
|
|
|
|
|
|
|
|
class AbstractValue:
  """Base class for abstract values (avals).

  ``at_least_vspace``, ``join`` and ``update`` must be overridden by
  subclasses; the remaining methods provide defaults.
  """
  __slots__: list[str] = []

  def at_least_vspace(self):
    raise NotImplementedError("must override")

  def __repr__(self):
    # Instances without a __dict__ (e.g. of subclasses that only use
    # __slots__) fall back to the bare class name.
    try:
      kv_pairs = (f'{k}={v}' for k, v in self.__dict__.items())
      return '{}({})'.format(self.__class__.__name__, ','.join(kv_pairs))
    except AttributeError:
      return self.__class__.__name__

  def strip_weak_type(self) -> AbstractValue:
    """Default: no weak type to strip, so return ``self``."""
    return self

  def strip_named_shape(self) -> AbstractValue:
    """Default: no named shape to strip, so return ``self``."""
    return self

  def join(self, other):
    raise NotImplementedError("must override")

  def update(self, **kwargs):
    raise NotImplementedError("must override")

  def str_short(self, short_dtypes=False):
    """Short human-readable form; defaults to ``str(self)``."""
    return str(self)
|
|
|
|
|
|
|
|
|
|
|
|
# For type signatures involving dynamic shapes, we use lists of abstract values
# which may contain (reverse) de Bruijn indices in their shapes.
class DBIdx(NamedTuple):
  # Position of the binder a shape entry refers to; e.g.
  # `_jaxpr_type_to_callable_annotation` numbers jaxpr constvars then invars.
  val: int
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class InDBIdx:
  # Index used in OutputType shapes; presumably refers to an input binder
  # (TODO confirm against the code that constructs it).
  val: int
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class OutDBIdx:
  # Index used in OutputType shapes; presumably refers to another output
  # (TODO confirm against the code that constructs it).
  val: int
|
|
|
|
|
|
|
|
# For annotating input types of callables (i.e. linear_util.WrappedFuns), we use
# a sequence of pairs where the first element of each pair is an AbstractValue
# (possibly containing DBIdx instances in its shape) and the second is a boolean
# indicating whether that argument is explicit (i.e. passed to the callable).
InputType = tuple[tuple[AbstractValue, bool], ...]  # DBIdx in shapes
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
# For annotating jaxpr output types, we use a sequence of pairs where the first
# element of each pair is an AbstractValue (possibly containing InDBIdx and/or
# OutDBIdx instances in its shape) and the second is a boolean indicating
# whether that output is explicit (i.e. returned by the callable).
OutputType = tuple[tuple[AbstractValue, bool], ...]  # InDBIdx / OutDBIdx shapes
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
|
|
|
|
def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
  """Builds an InputType describing ``jaxpr.invars``.

  Binders are numbered over constvars followed by invars. In a DShapedArray
  aval, any shape entry that is one of those binders is replaced by its DBIdx
  position. Other avals are kept as they are. Every entry is marked explicit.
  """
  binders = (*jaxpr.constvars, *jaxpr.invars)
  idxs = {v: DBIdx(i) for i, v in enumerate(binders)}
  annotated = []
  for v in jaxpr.invars:
    aval = v.aval
    if type(aval) is DShapedArray:
      aval = aval.update(shape=tuple(idxs.get(d, d) for d in aval.shape))  # type: ignore
    annotated.append((aval, True))
  return tuple(annotated)
|
|
|
|
|
|
|
|
# Presumably the bottom element of the abstract-value lattice (no usages are
# visible here; TODO confirm).
class Bot(AbstractValue): pass

bot = Bot()  # shared singleton instance
|
|
|
|
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
def lattice_join(x: AbstractValue | None,
                 y: AbstractValue | None) -> AbstractValue:
  """Joins two avals, where ``None`` acts as an identity element.

  The more specific aval's ``join`` is called with the other one. Raises
  TypeError(x, y) when the two avals are unrelated.
  """
  if x is None:
    return cast(AbstractValue, y)
  if y is None:
    return cast(AbstractValue, x)
  if isinstance(x, type(y)):
    return y.join(x)
  if isinstance(y, type(x)):
    return x.join(y)
  if isinstance(x, DShapedArray) and isinstance(y, ShapedArray):
    # TODO(mattjj): remove this special case after dynamic shapes are integrated
    return x.join(y)
  raise TypeError(x, y)
|
|
|
|
|
|
|
|
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
# This is an alias of Any, so it documents intent only; nothing is checked.
Value = Any
|
|
|
|
|
|
|
|
def valid_jaxtype(x) -> bool:
  """True if ``concrete_aval(x)`` succeeds, False if it raises TypeError."""
  try:
    concrete_aval(x)
  except TypeError:
    return False
  return True
|
|
|
|
|
|
|
|
def check_valid_jaxtype(x):
  """Raises TypeError unless ``x`` is a valid JAX type (see valid_jaxtype)."""
  if valid_jaxtype(x):
    return
  raise TypeError(
      f"Value {x!r} of type {type(x)} is not a valid JAX type")
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
|
|
|
|
def concrete_aval(x):
  """Returns the concrete aval of a non-Tracer value ``x``.

  Walks ``type(x).__mro__`` and uses the first truthy handler registered in
  ``pytype_aval_mappings``. Failing that, objects with ``__jax_array__`` are
  converted and retried. Otherwise raises TypeError.
  """
  handlers = (pytype_aval_mappings.get(typ) for typ in type(x).__mro__)
  handler = next((h for h in handlers if h), None)
  if handler is not None:
    return handler(x)
  if hasattr(x, '__jax_array__'):
    return concrete_aval(x.__jax_array__())
  raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
                  "type")
|
|
|
|
|
|
|
|
|
|
|
|
def get_aval(x):
  """Returns ``x.aval`` for Tracers and ``concrete_aval(x)`` for anything else."""
  return x.aval if isinstance(x, Tracer) else concrete_aval(x)
|
|
|
|
|
|
|
|
|
|
|
|
def concretization_function_error(fun, suggest_astype=False):
  """Builds an aval method that raises when ``fun`` is applied to a tracer.

  The returned ``error(self, arg)`` always raises:

  * for ``bool``: TracerBoolConversionError(arg);
  * for ``hex``, ``oct``, ``operator.index``: TracerIntegerConversionError(arg);
  * otherwise: ConcretizationTypeError(arg, context), where the context names
    ``fun``.

  Args:
    fun: the Python conversion callable being intercepted. It need not have a
      ``__name__``; if missing, ``fun`` itself is used in the message.
    suggest_astype: if True, the ConcretizationTypeError context also suggests
      ``x.astype(...)`` / ``jnp.array(x, ...)`` for dtype conversion.

  Returns:
    The ``error`` function described above.
  """
  fname = getattr(fun, "__name__", fun)
  fname_context = f"The problem arose with the `{fname}` function. "
  if suggest_astype:
    # Use `fname` rather than `fun.__name__`: the getattr fallback above exists
    # for callables without a __name__, which would otherwise raise
    # AttributeError right here.
    fname_context += ("If trying to convert the data type of a value, "
                      f"try using `x.astype({fname})` "
                      f"or `jnp.array(x, {fname})` instead.")
  if fun is bool:
    def error(self, arg):
      raise TracerBoolConversionError(arg)
  elif fun in (hex, oct, operator.index):
    def error(self, arg):
      raise TracerIntegerConversionError(arg)
  else:
    def error(self, arg):
      raise ConcretizationTypeError(arg, fname_context)
  return error
|
|
|
|
|
|
|
|
def concrete_or_error(force: Any, val: Any, context=""):
  """Like force(val), but gives the context in the error message.

  A ``force`` of None means the identity. For a Tracer, ``force`` is applied
  to the concrete value in its ConcreteArray aval. If the aval is not
  concrete, ConcretizationTypeError(val, context) is raised.
  """
  if force is None:
    force = lambda x: x
  if not isinstance(val, Tracer):
    return force(val)
  if isinstance(val.aval, ConcreteArray):
    return force(val.aval.val)
  raise ConcretizationTypeError(val, context)
|
|
|
|
|
2023-06-30 12:31:47 +03:00
|
|
|
def concrete_dim_or_error(val: Any, context=""):
  """Like concrete_or_error(operator.index), allowing symbolic dimensions.

  Symbolic dimensions (per ``is_symbolic_dim``) are returned unchanged.
  """
  if not is_symbolic_dim(val):
    return concrete_or_error(operator.index, val, context=context)
  return val
|
2022-12-15 20:34:43 -08:00
|
|
|
|
2023-07-24 14:29:37 -07:00
|
|
|
### Extended dtypes
#
# Extended dtypes are JAX-specific dtypes that allow us to represent logical
# arrays of element types that do not have an obvious direct correspondence
# to ("physical") arrays of basic types in a compiler. In particular, their
# element types differ from those of XLA and NumPy (e.g. int32). These dtypes
# are only known to JAX. Their implementation is determined by:
# a) an object representing the extended dtype, accessible via the `dtype`
#    attribute on corresponding JAX arrays and, internally, on avals such
#    as ShapedArrays that correspond to such JAX arrays;
# b) a set of rules, available via a private attribute on the extended dtype
#    object in (a).
# The rules in (b) tell JAX internals how to ground out the element
# type for interaction with the compiler and runtime, e.g. when lowering
# to the compiler's language.

@overload
def physical_aval(aval: ShapedArray) -> ShapedArray: ...
@overload
def physical_aval(aval: DShapedArray) -> DShapedArray: ...
@overload  # TODO(frostig): remove this case
def physical_aval(aval: AbstractValue) -> AbstractValue: ...

def physical_aval(aval):
  """Returns the physical aval backing an aval with an extended dtype.

  Avals whose dtype is not an ExtendedDType are returned unchanged.
  Otherwise, the dtype's rules provide a physical element aval (a
  ShapedArray). The result is the same aval type as the input, with shape
  ``(*aval.shape, *element.shape)`` and the element's dtype.
  """
  aval_dtype = getattr(aval, 'dtype', None)
  if not (aval_dtype and isinstance(aval_dtype, dtypes.ExtendedDType)):
    return aval
  ctor = type(aval)
  aval_shape = getattr(aval, 'shape', None)
  assert aval_shape is not None, (ctor, aval)
  elt_aval = aval_dtype._rules.physical_element_aval(aval_dtype)
  assert type(elt_aval) is ShapedArray
  physical_shape = (*aval_shape, *elt_aval.shape)
  return ctor(physical_shape, elt_aval.dtype)  # pytype: disable=wrong-arg-count
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
def _short_dtype_name(dtype) -> str:
|
Use an isinstance check rather than dtypes.issubdtype to check whether the dtype in an aval is an extended dtype.
We don't need the full generality of issubdtype, and this is slightly faster. This operation is very common (e.g., for every aval construction, even with a non-extended dtype).
On my laptop:
```
In [18]: d = jnp.dtype(jnp.int32)
In [20]: %timeit jax.dtypes.issubdtype(d, jax.dtypes.extended)
490 ns ± 2.78 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [22]: %timeit isinstance(d, jax._src.dtypes.ExtendedDType)
78.3 ns ± 0.111 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
```
PiperOrigin-RevId: 606616884
2024-02-13 07:37:14 -08:00
|
|
|
if isinstance(dtype, dtypes.ExtendedDType):
|
2022-12-15 20:34:43 -08:00
|
|
|
return str(dtype)
|
|
|
|
else:
|
|
|
|
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
|
|
|
|
.replace('int' , 'i').replace('complex', 'c'))
|
|
|
|
|
|
|
|
def _dtype_object(dtype):
|
Use an isinstance check rather than dtypes.issubdtype to check whether the dtype in an aval is an extended dtype.
We don't need the full generality of issubdtype, and this is slightly faster. This operation is very common (e.g., for every aval construction, even with a non-extended dtype).
On my laptop:
```
In [18]: d = jnp.dtype(jnp.int32)
In [20]: %timeit jax.dtypes.issubdtype(d, jax.dtypes.extended)
490 ns ± 2.78 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [22]: %timeit isinstance(d, jax._src.dtypes.ExtendedDType)
78.3 ns ± 0.111 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
```
PiperOrigin-RevId: 606616884
2024-02-13 07:37:14 -08:00
|
|
|
return dtype if isinstance(dtype, dtypes.ExtendedDType) else np.dtype(dtype)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
class UnshapedArray(AbstractValue):
  """Abstract array value that carries only a dtype and a weak_type flag.

  Accessing ``shape`` raises TypeError (see the property below).
  """
  __slots__ = ['dtype', 'weak_type']
  array_abstraction_level = 4

  def __init__(self, dtype, weak_type=False):
    self.dtype = _dtype_object(dtype)
    self.weak_type = weak_type

  def update(self, dtype=None, weak_type=None):
    """Returns a new UnshapedArray; None arguments keep the current value."""
    if dtype is None:
      dtype = self.dtype
    if weak_type is None:
      weak_type = self.weak_type
    return UnshapedArray(dtype, weak_type)

  def __eq__(self, other):
    return (type(self) is type(other) and self.dtype == other.dtype and
            self.weak_type == other.weak_type)

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.dtype, self.weak_type))

  def __repr__(self):
    return '{}({}{})'.format(self.__class__.__name__, self.str_short(),
                             ", weak_type=True" if self.weak_type else "")

  # Conversion hooks: each raises an informative error (built by
  # concretization_function_error) because an abstract value has no concrete
  # value to convert. Presumably dispatched to by the matching Tracer dunder
  # methods (not visible here).
  _bool = concretization_function_error(bool)
  _int = concretization_function_error(int, True)
  _float = concretization_function_error(float, True)
  _complex = concretization_function_error(complex, True)
  _hex = concretization_function_error(hex)
  _oct = concretization_function_error(oct)
  _index = concretization_function_error(operator.index)

  def at_least_vspace(self) -> AbstractValue:
    """Tangent-space aval: same kind of aval, with the tangent dtype."""
    return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
                         self.weak_type)

  def join(self, other):
    """Joins with an aval of the same dtype.

    The result is weak only if both inputs are. Raises TypeError on a dtype
    mismatch.
    """
    if self.dtype == other.dtype:
      if self.weak_type == other.weak_type:
        return self
      else:
        return UnshapedArray(self.dtype, weak_type=False)
    else:
      raise TypeError(self, other)

  def str_short(self, short_dtypes=False) -> str:
    return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name

  def strip_weak_type(self):
    """Returns a copy of the aval with weak_type=False."""
    return self.update(weak_type=False)

  @property
  def shape(self):
    msg = ("UnshapedArray has no shape. Please open an issue at "
           "https://github.com/google/jax/issues because it's unexpected for "
           "UnshapedArray instances to ever be produced.")
    raise TypeError(msg)
|
|
|
|
|
2024-04-18 11:09:02 -07:00
|
|
|
def _canonicalize_dimension(dim: DimSize) -> DimSize:
  """Canonicalizes a single shape dimension.

  Integer-like values are returned as ``operator.index(dim)``. The following
  are returned unchanged:

  * scalar integer- or bint-typed Tracers, when ``config.dynamic_shapes`` is
    on (non-scalar or non-integer Tracers raise a new TypeError instead);
  * scalar bint DArrays, when ``config.dynamic_shapes`` is on;
  * anything accepted by ``is_dim``.

  Any other value re-raises the TypeError from ``operator.index``.
  """
  # Dimensions are most commonly integral (by far), so we check that first.
  try:
    return operator.index(dim)
  except TypeError as e:
    # Kept so it can be re-raised if none of the other cases below apply.
    type_error = e
  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
                               or isinstance(dim.dtype, bint))):
      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
    return dim
  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
        type(dim._aval.dtype) is bint and not dim._aval.shape):
    return dim
  elif is_dim(dim):
    return dim
  else:
    raise type_error
|
|
|
|
|
|
|
|
def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.
    context: optional text appended to the error message on failure.

  Returns:
    A tuple of canonical dimension values.

  Raises:
    TypeError: if ``shape`` or any of its dimensions is invalid. The message
      is built by ``_invalid_shape_error``.
  """
  try:
    canonical = tuple(unsafe_map(_canonicalize_dimension, shape))
  except TypeError:
    canonical = None
  if canonical is None:
    # Raised outside the `except` clause so the original TypeError is not
    # attached as the new error's context.
    raise _invalid_shape_error(shape, context)
  return canonical
|
|
|
|
|
|
|
|
def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
  """Canonicalizes and checks for errors in a user-provided shape dimension value.

  Args:
    d: a Python value that represents a dimension.
    context: optional text appended to the error message if ``d`` is invalid.

  Returns:
    A canonical dimension value.
  """
  return canonicalize_shape((d,), context)[0]
|
|
|
|
|
|
|
|
def _invalid_shape_error(shape: Shape, context: str=""):
  """Builds (but does not raise) the TypeError reported for an invalid shape.

  Args:
    shape: the user-provided shape that failed canonicalization. It may not
      be iterable at all (e.g. a bare scalar); in that case no per-dimension
      hints are added.
    context: optional text appended to the message.

  Returns:
    A TypeError whose message describes the expected shape format. With
    static shapes, it also hints at ``static_argnums`` when some dimension is
    a non-concrete Tracer. It appends each Tracer dimension's
    ``_origin_msg()`` when available.
  """
  if config.dynamic_shapes.value:
    msg = ("Shapes must be 1D sequences of integer scalars, "
           f"got {shape}")
  else:
    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
           f"got {shape}.")
  if context:
    msg += f" {context}."
  # `shape` may be non-iterable (that can be why canonicalization failed), or
  # a one-shot iterator. Iterate it exactly once, and never let that raise
  # here, so callers get this descriptive error rather than
  # "'int' object is not iterable".
  try:
    dims = tuple(shape)
  except TypeError:
    dims = ()
  if not config.dynamic_shapes.value and any(
      isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
      and not isinstance(get_aval(x), ConcreteArray) for x in dims):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  for x in dims:
    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
      msg += x._origin_msg()

  return TypeError(msg)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
class ShapedArray(UnshapedArray):
  """Abstract array value with a canonicalized shape, a dtype and a weak_type.

  ``named_shape`` maps axis names to sizes. It is copied into a fresh dict on
  construction.
  """
  __slots__ = ['shape', 'named_shape']
  array_abstraction_level = 2

  def __init__(self, shape, dtype, weak_type=False, named_shape=None):
    self.shape = canonicalize_shape(shape)
    self.dtype = _dtype_object(dtype)
    self.weak_type = weak_type
    self.named_shape = {} if named_shape is None else dict(named_shape)

  def update(self, shape=None, dtype=None, weak_type=None, named_shape=None):
    """Returns a new ShapedArray; None arguments keep the current value."""
    if shape is None:
      shape = self.shape
    if dtype is None:
      dtype = self.dtype
    if weak_type is None:
      weak_type = self.weak_type
    if named_shape is None:
      named_shape = self.named_shape
    return ShapedArray(shape, dtype, weak_type, named_shape)

  ndim = property(lambda self: len(self.shape))
  # A literal zero-sized int dimension short-circuits to 0, without
  # multiplying in any other (possibly symbolic) dimensions.
  size = property(lambda self:
                  0 if any(type(d) is int and d == 0 for d in self.shape)
                  else math.prod(self.shape))

  broadcast: ClassVar[aval_method | None] = None
  transpose: ClassVar[aval_method | None] = None
  reshape: ClassVar[aval_method | None] = None
  _iter: ClassVar[staticmethod | None] = None

  def __eq__(self, other):
    return (type(self) is type(other)
            and self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type
            and self.named_shape == other.named_shape)

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    # __eq__ compares named_shape as a dict, which ignores insertion order, so
    # the hash must ignore it as well: equal avals must hash equally. A
    # tuple(named_shape.items()) would depend on insertion order.
    return hash((self.shape, self.dtype, self.weak_type,
                 frozenset(self.named_shape.items())))

  def at_least_vspace(self):
    """Tangent-space aval: same shape and named_shape, with the tangent dtype."""
    return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                       self.weak_type, self.named_shape)

  def join(self, other):
    """Joins with another shaped aval.

    If shapes are definitely equal and dtypes match, the result keeps this
    shape, is weak only if both inputs are, and has joined named shapes. If
    only the dtypes match, the result is an UnshapedArray. Otherwise raises
    TypeError.
    """
    if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
      weak_type = self.weak_type and other.weak_type
      named_shape = join_named_shapes(self.named_shape, other.named_shape)
      return self.update(weak_type=weak_type, named_shape=named_shape)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    else:
      raise TypeError(self, other)

  def str_short(self, short_dtypes=False):
    dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
    dt_str = dt_str.replace('void', 'float0')
    shapestr = ','.join(map(str, self.shape))
    if self.named_shape:
      named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items())
      return f'{dt_str}[{shapestr};{named_shapestr}]'
    else:
      return f'{dt_str}[{shapestr}]'

  def strip_named_shape(self):
    return self.update(named_shape={})

  def _len(self, ignored_tracer):
    try:
      return self.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error
|
|
|
|
|
|
|
|
|
|
|
|
def _forward_to_value(self, fun, ignored_tracer, *args):
|
|
|
|
return fun(self.val, *args)
|
|
|
|
|
|
|
|
|
|
|
|
class ConcreteArray(ShapedArray):
  """ShapedArray that also carries a concrete value ``val``.

  The shape is ``np.shape(val)``. If ``weak_type`` is None, it is derived with
  ``dtypes.is_weakly_typed(val)``.
  """
  __slots__ = ['val']
  array_abstraction_level = 0

  def __init__(self, dtype, val, weak_type=None):
    super().__init__(
        np.shape(val), dtype,
        weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
    dtypes.check_valid_dtype(self.dtype)
    # Note: canonicalized self.dtype doesn't necessarily match self.val
    assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
    self.val = val

  def update(self, dtype=None, val=None, weak_type=None):
    """Returns a new ConcreteArray; None arguments keep the current value."""
    dtype = self.dtype if dtype is None else dtype
    val = self.val if val is None else val
    weak_type = self.weak_type if weak_type is None else weak_type
    return ConcreteArray(dtype, val, weak_type)

  def __eq__(self, other):
    # Same type, dtype, shape and weak_type, and element-wise equal values.
    if (type(self) is type(other) and self.dtype == other.dtype
        and self.shape == other.shape and self.weak_type == other.weak_type):
      with eval_context():  # in case self.val is an Array
        return (self.val == other.val).all()
    else:
      return False

  def __hash__(self):
    # NOTE(review): hashes by the identity of `val` while __eq__ compares
    # values, so equal ConcreteArrays wrapping distinct value objects hash
    # differently. Presumably intentional, since array values are not
    # hashable; confirm before relying on set/dict membership.
    return id(self.val)

  def join(self, other) -> AbstractValue:
    """Joins with another aval.

    Equal avals join to ``self``. Same shape and dtype join to a ShapedArray.
    Same dtype only joins to an UnshapedArray. Otherwise raises TypeError.
    """
    if self == other:
      return self
    elif self.shape == other.shape and self.dtype == other.dtype:
      weak_type = self.weak_type and other.weak_type
      named_shape = join_named_shapes(self.named_shape, other.named_shape)
      return ShapedArray(
          self.shape, self.dtype, weak_type=weak_type, named_shape=named_shape)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype,
                           weak_type=self.weak_type and other.weak_type)
    else:
      raise TypeError(self, other)

  def str_short(self, short_dtypes=False) -> str:
    dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
    return f'{self.val}, dtype={dt_str}'

  # With a concrete value available, these conversions simply forward to it.
  _bool = partialmethod(_forward_to_value, bool)
  _int = partialmethod(_forward_to_value, int)
  _hex = partialmethod(_forward_to_value, hex)
  _oct = partialmethod(_forward_to_value, oct)
  _index = partialmethod(_forward_to_value, operator.index)

  # float() and complex() still raise, like on abstract values.
  _float = concretization_function_error(float, True)
  _complex = concretization_function_error(complex, True)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def primal_dtype_to_tangent_dtype(primal_dtype):
  """Returns the dtype used for tangents of values with ``primal_dtype``.

  Extended dtypes defer to their rules' ``tangent_dtype``. Inexact (floating
  or complex) dtypes are their own tangent dtype. All other dtypes map to
  ``dtypes.float0``.
  """
  if isinstance(primal_dtype, dtypes.ExtendedDType):
    return primal_dtype._rules.tangent_dtype(primal_dtype)
  if dtypes.issubdtype(primal_dtype, np.inexact):
    return primal_dtype
  return dtypes.float0
|
|
|
|
|
|
|
|
|
|
|
|
# Dynamic shape stuff below here! We keep the abstract values distinct just so
# as not to interfere with any static shape machinery.

# We have a convention of reusing AbstractValues as types, even though we could
# make a distinction and use abstract values during tracing only. This reuse
# becomes a bit more extreme with DShapedArrays. A DShapedArray's shape
# attribute is a tuple which can contain several different types: int, DArray
# (scalar and with dtype of bint type), Tracer (while tracing), Var (when used
# as jaxpr type annotations), or DBIdx/InDBIdx/OutDBIdx (when used in InputType
# or OutputType). We could reduce this polymorphism if it seems cleaner, though
# it's kind of convenient!
class DShapedArray(UnshapedArray):
  """Shaped abstract value whose shape entries may be dynamic (see above).

  Unlike ShapedArray, ``shape`` and ``dtype`` are stored exactly as given
  (no canonicalization), and there is no named_shape.
  """
  __slots__ = ['shape']
  shape: tuple[AxisSize, ...]  # noqa: F821
  array_abstraction_level: int = 3

  def __init__(self, shape, dtype, weak_type=False):
    self.shape = shape
    self.dtype = dtype
    self.weak_type = weak_type

  ndim = property(lambda self: len(self.shape))
  # A literal zero-sized int dimension short-circuits to 0, without
  # multiplying in any dynamic dimensions.
  size = property(lambda self:
                  0 if any(type(d) is int and d == 0 for d in self.shape)
                  else math.prod(self.shape))

  def str_short(self, short_dtypes=False) -> str:
    del short_dtypes  # ignored
    shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
    dtype = _short_dtype_name(self.dtype)
    return f'{dtype}[{shape}]'
  __str__ = __repr__ = str_short

  def update(self, shape=None, dtype=None, weak_type=None):
    """Returns a new DShapedArray; None arguments keep the current value."""
    if shape is None:
      shape = self.shape
    if dtype is None:
      dtype = self.dtype
    if weak_type is None:
      weak_type = self.weak_type
    return DShapedArray(shape, dtype, weak_type)

  def _len(self, tracer):
    return self.shape[0]

  def __eq__(self, other):
    return (type(self) is type(other)
            and self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    return hash((self.shape, self.dtype, self.weak_type))

  def join(self, other):
    """Joins with another aval.

    Definitely-equal shapes with matching dtypes keep this shape, and the
    result is weak only if both inputs are. Matching dtypes alone give an
    UnshapedArray. Otherwise raises TypeError.
    """
    if (definitely_equal_shape(self.shape, other.shape) and
        self.dtype == other.dtype):
      weak_type = self.weak_type and other.weak_type
      return self.update(weak_type=weak_type)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    else:
      raise TypeError(self, other)

  def at_least_vspace(self):
    """Tangent-space aval: same shape, with the tangent dtype."""
    return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                        self.weak_type)
|
|
|
|
|
|
|
|
class DConcreteArray(DShapedArray):
  """A ``DShapedArray`` that additionally records a concrete value in ``val``."""
  __slots__ = ['val']
  array_abstraction_level = 1
  def __init__(self, shape, dtype, weak_type, val):
    super().__init__(shape, dtype, weak_type)
    self.val = val  # the concrete value this abstract value describes
|
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
# Registry from a Python type to a function computing the abstract value of an
# instance of that type. Entries are added with statements of the form
# `pytype_aval_mappings[SomeType] = ...`.
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = dict()
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
|
|
|
|
class DArray:
  """A concrete array value whose abstract value is a ``DShapedArray``.

  ``_data`` holds the padded buffer: along every axis whose size is a scalar
  bint-typed ``DArray``, the buffer has that dtype's ``bound`` entries.
  """
  _aval: DShapedArray
  _data: Any  # standard array type

  def __init__(self, aval, data):
    padded = tuple(
        d.dtype.bound if (type(d) is DArray and type(d.dtype) is bint) else d
        for d in aval.shape)
    assert data.shape == padded
    self._aval = aval
    self._data = data

  @property
  def shape(self):
    return self._aval.shape

  @property
  def dtype(self):
    return self._aval.dtype

  @property
  def aval(self):
    return self._aval

  def __repr__(self) -> str:
    if not self.shape and type(self.dtype) is bint:
      # special-case scalar bints
      return f'{int(self._data)}{{≤{self.dtype.bound}}}'
    dtypestr = _short_dtype_name(self._aval.dtype)
    shapestr = ','.join(map(str, self.shape))
    # Show only the valid prefix along axes whose size is a bint DArray.
    valid = tuple(
        slice(int(d._data)) if (type(d) is DArray and type(d.dtype) is bint)
        else slice(None)
        for d in self.shape)
    return f'{dtypestr}[{shapestr}] with value: {self._data[valid]}'

  def __hash__(self) -> int:
    if self.shape:
      raise TypeError("unhashable type: DArray")
    return hash((self._aval, int(self._data)))

  def __eq__(self, other):
    same_aval = isinstance(other, DArray) and self._aval == other._aval
    return self._data == other._data if same_aval else False

  def __len__(self):
    return self.shape[0]
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def _darray_aval(x):
  """Abstract value of a DArray: a DConcreteArray that also carries its buffer."""
  aval = x._aval
  return DConcreteArray(aval.shape, aval.dtype, aval.weak_type, x._data)

pytype_aval_mappings[DArray] = _darray_aval
|
|
|
|
|
2023-12-21 17:43:31 -08:00
|
|
|
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
  """Extended dtype for a bounded integer, printed as ``bint{≤bound}``.

  Scalar DArrays with this dtype can appear as dynamic axis sizes; buffers are
  padded to ``bound`` along such axes.
  """
  bound: int  # inclusive upper bound, as shown by the `≤` in `name`

  @property
  def type(self) -> type:
    return dtypes.extended

  @property
  def name(self) -> str:
    return f'bint{{≤{self.bound}}}'

  def __str__(self) -> str:
    return self.name
|
|
|
|
|
|
|
|
# Union of the kinds of values that may appear as an axis size in a
# DShapedArray's shape.
AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx]
|
|
|
|
|
|
|
|
|
2024-03-01 11:07:45 -08:00
|
|
|
class MutableArray:
  """A concrete mutable buffer paired with its abstract value.

  Indexing and item assignment are delegated to the ``_getitem``/``_setitem``
  methods of the abstract value returned by ``get_aval(self)``.
  """
  _aval: ShapedArray
  _buf: Array

  def __init__(self, aval, buf):
    self._aval = aval
    self._buf = buf

  @property
  def aval(self):
    return self._aval

  @property
  def shape(self):
    return self._aval.shape

  @property
  def dtype(self):
    return self._aval.dtype

  def __getitem__(self, idx):
    return get_aval(self)._getitem(self, idx)

  def __setitem__(self, idx, x):
    return get_aval(self)._setitem(self, idx, x)

  def __repr__(self) -> str:
    return 'Mutable' + repr(self[...])
|
2024-03-01 11:07:45 -08:00
|
|
|
# The abstract value of a MutableArray is the aval it was constructed with.
pytype_aval_mappings[MutableArray] = lambda x: x._aval
|
|
|
|
|
|
|
|
def mutable_array(init_val):
  """Creates a mutable array initialized from ``init_val`` by binding ``mutable_array_p``."""
  return mutable_array_p.bind(init_val)
mutable_array_p = Primitive('mutable_array')
|
|
|
|
|
2024-04-28 21:16:13 -04:00
|
|
|
class InternalMutableArrayEffect(effects.Effect):
  """Effect reported by the abstract evaluation rule of ``mutable_array_p``."""

internal_mutable_array_effect = InternalMutableArrayEffect()
# Register the effect type in `effects.control_flow_allowed_effects`
# (presumably so that it is permitted inside control-flow primitives).
effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect)
|
2024-04-02 10:30:35 -07:00
|
|
|
|
|
|
|
@mutable_array_p.def_effectful_abstract_eval
def mutable_array_abstract_eval(init_aval):
  """Result is an ``AbstractRef`` to ``init_aval``, with the internal mutable-array effect."""
  from jax._src.state.types import AbstractRef  # pytype: disable=import-error
  effs = {internal_mutable_array_effect}
  return AbstractRef(init_aval), effs
|
2024-04-02 10:30:35 -07:00
|
|
|
|
2024-03-01 11:07:45 -08:00
|
|
|
@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
  """Eager impl: wraps ``init_val`` in a MutableArray typed by an AbstractRef."""
  from jax._src.state.types import AbstractRef  # pytype: disable=import-error
  ref_aval = AbstractRef(raise_to_shaped(get_aval(init_val)))
  return MutableArray(ref_aval, init_val)
|
|
|
|
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
class AbstractToken(AbstractValue):
  """Abstract value of a token; ``abstract_token`` is the shared instance."""

  def join(self, other):
    """Joins with another abstract value; only another AbstractToken is allowed.

    Raises:
      TypeError: if ``other`` is not an ``AbstractToken``.
    """
    if isinstance(other, AbstractToken):
      return self
    # Raise explicitly rather than `assert False`: asserts are stripped under
    # `python -O`, which would make this silently return None.
    raise TypeError(f"Cannot join {self} with {other}")

  def str_short(self, short_dtypes=False): return 'Tok'

  def at_least_vspace(self): return self

abstract_token: AbstractToken = AbstractToken()
|
|
|
|
|
2024-04-18 11:09:02 -07:00
|
|
|
# Singleton shaped array used by all abstract tokens when shape/dtype is needed:
# a zero-length 1-D array of bool, i.e. shape (0,) and dtype bool.
token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
# Concrete token object
class Token:
  """Concrete runtime token wrapping an underlying buffer."""
  # The underlying data wrapped by the token; it can be threaded into and
  # out of computations to build up data dependencies.
  _buf: Array
  def __init__(self, buf):
    self._buf = buf
  def block_until_ready(self):
    """Blocks until the wrapped buffer is ready (delegates to ``_buf``)."""
    self._buf.block_until_ready()
|
2022-12-15 20:34:43 -08:00
|
|
|
# Every Token maps to the shared `abstract_token` singleton.
pytype_aval_mappings[Token] = lambda _: abstract_token
|
|
|
|
|
|
|
|
|
|
|
|
def raise_to_shaped(aval: AbstractValue, weak_type=None):
  """Returns the shaped-level abstract value for ``aval``.

  A ShapedArray is returned unchanged when ``weak_type`` is None. Otherwise
  ``weak_type`` defaults to ``aval.weak_type`` (False if absent), and the
  first class in ``type(aval).__mro__`` with a truthy entry in
  ``raise_to_shaped_mappings`` decides the result.

  Raises:
    TypeError: if no rule is registered for any class in the MRO.
  """
  cls = type(aval)
  if weak_type is None:
    if cls is ShapedArray:
      return aval
    weak_type = getattr(aval, 'weak_type', False)
  for base in cls.__mro__:
    rule = raise_to_shaped_mappings.get(base)
    if rule:
      return rule(aval, weak_type)
  raise TypeError(cls)
|
|
|
|
|
2024-07-22 11:20:15 +00:00
|
|
|
# Per-class rules used by `raise_to_shaped`; each rule takes (aval, weak_type).
# AbstractToken, Bot and UnshapedArray are returned unchanged (weak_type is
# ignored). ShapedArray is rebuilt with the requested weak_type, and
# DConcreteArray becomes a DShapedArray (dropping its value).
# NOTE(review): DShapedArray itself has no entry, so it falls back to the
# UnshapedArray rule via its MRO and ignores weak_type; confirm intended.
raise_to_shaped_mappings: dict[type, Callable] = {
    AbstractToken: lambda aval, _: aval,
    Bot: lambda aval, _: aval,
    UnshapedArray: lambda aval, _: aval,
    ShapedArray: lambda aval, weak_type: ShapedArray(
        aval.shape, aval.dtype, weak_type, aval.named_shape
    ),
    DConcreteArray: lambda aval, weak_type: DShapedArray(
        aval.shape, aval.dtype, weak_type
    ),
}
|
|
|
|
|
|
|
|
### Operations on shapes and dimension sizes.
|
|
|
|
|
|
|
|
class InconclusiveDimensionOperation(Exception):
  """Raised when we cannot conclusively compute with symbolic dimensions."""
|
|
|
|
|
2023-07-11 14:03:52 +01:00
|
|
|
def is_symbolic_dim(v: Any) -> bool:
  """Returns whether ``v`` is a symbolic dimension used for shape polymorphism.

  Symbolic dimensions are recognized by having a ``dimension_as_value``
  attribute. This check should be needed rarely: symbolic dimensions
  overload all operators and should just work.
  """
  try:
    v.dimension_as_value
  except AttributeError:
    return False
  return True
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def is_constant_dim(d: DimSize) -> bool:
  """Returns whether ``d`` is a static integer constant.

  ``d`` counts as constant when ``operator.index(d)`` succeeds, e.g. for
  Python ints and NumPy integer scalars.
  """
  try:
    operator.index(d)
  except Exception:
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt and SystemExit are not swallowed.
    return False
  return True
|
|
|
|
|
2022-12-29 03:00:03 -08:00
|
|
|
def is_dim(v: Any) -> bool:
  """Returns whether ``v`` can serve as a dimension (symbolic or constant)."""
  if is_symbolic_dim(v):
    return True
  return is_constant_dim(v)
|
2022-12-29 03:00:03 -08:00
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
def is_constant_shape(s: Shape) -> bool:
  """Returns whether every dimension of the shape ``s`` is a static constant."""
  return not any(not is_constant_dim(d) for d in s)
|
|
|
|
|
2023-06-30 12:31:47 +03:00
|
|
|
def definitely_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
  """Returns True if ``d1`` is definitely equal to some element of ``dlist``."""
  for candidate in dlist:
    if definitely_equal(d1, candidate):
      return True
  return False
|
2022-12-15 20:34:43 -08:00
|
|
|
|
2023-06-10 17:33:27 -04:00
|
|
|
def definitely_equal_shape(s1: Shape, s2: Shape) -> bool:
  """Checks that two shapes are guaranteed to be element-wise equal.

  With dynamic shapes this may return False even when the shapes turn out to
  be equal at runtime.
  """
  if len(s1) != len(s2):
    return False
  return all(definitely_equal(a, b) for a, b in zip(s1, s2))
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
  """Returns an integer "i" s.t., i * size(s2) == size(s1).

  Raises InconclusiveDimensionOperation if there is no such integer."""
  total1 = math.prod(s1)
  total2 = math.prod(s2)
  # Checking equality first also handles both sizes being 0.
  if definitely_equal(total1, total2):
    return 1
  quotient, remainder = divmod(total1, total2)
  if isinstance(remainder, Tracer) or remainder != 0:
    raise InconclusiveDimensionOperation(
        f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}. "
        f"The remainder {remainder} should be 0.")
  return quotient
|
2022-12-15 20:34:43 -08:00
|
|
|
|
2023-04-19 17:01:05 -07:00
|
|
|
def cancel_divide_tracers(num, denom):
  """Computes prod(num) / prod(denom) by cancelling matching tracer sizes.

  Tracer entries of ``denom`` are cancelled against ``definitely_equal``
  tracer entries of ``num`` (see ``_cancel_divide``), and the remaining
  non-tracer sizes are divided.

  Returns:
    The quotient, or None if there are no tracers, if some tracer in
    ``denom`` cannot be cancelled, or if the non-tracer part of ``num`` is
    not evenly divisible by the non-tracer part of ``denom``.
  """
  partition = lambda l: partition_list([isinstance(d, Tracer) for d in l], l)
  num, num_tracers = partition(num)
  denom, denom_tracers = partition(denom)
  if not (num_tracers or denom_tracers):
    return None
  factor = _cancel_divide(num_tracers, denom_tracers)
  if factor is None:
    return None
  size1 = math.prod(num)
  size2 = math.prod(denom)
  if size1 == size2:
    ratio = 1
  elif size2 != 0:
    ratio, remainder = divmod(size1, size2)
    if remainder != 0:
      # Not evenly divisible: flooring would silently give a wrong size.
      return None
  else:
    return None
  return factor * ratio
|
|
|
|
|
|
|
|
def _cancel_divide(num, denom):
  """Cancels each element of ``denom`` against a definitely-equal one in ``num``.

  Returns the product of the elements of ``num`` left over, or None if some
  element of ``denom`` has no match.
  """
  remaining = list(num)
  for d in denom:
    for idx, n in enumerate(remaining):
      if definitely_equal(d, n):
        del remaining[idx]
        break
    else:
      return None  # couldn't cancel
  return math.prod(remaining)
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
def is_empty_shape(s: Shape) -> bool:
  """Returns True if some dimension of ``s`` is definitely 0."""
  for d in s:
    if definitely_equal(d, 0):
      return True
  return False
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
  """Size of a dimension of size ``d`` after dilation: max(0, 1 + dilation * (d - 1)).

  Assumes dilation >= 1.
  """
  if not definitely_equal(dilation, 1):
    return max_dim(1 + dilation * (d - 1), 0)
  return d  # fast path: no dilation
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
  """Number of windows: max(0, (d - window_size) // window_stride + 1).

  Returns 0 when d < window_size. We assume window_size >= 1 and
  window_stride >= 1.
  """
  # When d < window_size the floor division is negative, hence the clamp.
  num_windows = (d - window_size) // window_stride + 1
  return max_dim(num_windows, 0)
|
[shape_poly] Add support for max0 for symbolic dimensions.
There are a few cases when JAX computes `max(v, 0)`, most
notably when computing the sizes of strided access,
dilated convolutions and padding, and for the size
of jnp.arange.
Until now these cases were supported
for shape polymorphism only when we can tell statically
that the size is >= 0. Here we add support to the
symbolic expressions for a `non_negative` operator,
which essentially implements `max(v, 0)` and with this
we can now support the general case for `jnp.arange`, with
simpler code.
We could add a general `max` operator, and we may do so in the
future, but for now `non_negative` suffices.
Note that this fixes a couple of bugs
* for core.dilated_dim we had the code "if d == 0 then 0 else ..."
but this works only if we can tell statically that `d == 0`, and
it produced wrong results when `d` was symbolic and could take
the value 0.
* for core.stride_dim we did not handle correctly the case when
`d < window_size`.
Handling the above fundamentally requires a `max(d, 0)` operation.
2023-07-12 12:46:47 +03:00
|
|
|
|
2023-12-13 10:14:27 +01:00
|
|
|
# TODO(necula): Deprecated Jan 2024, to be removed.
def non_negative_dim(d: DimSize) -> DimSize:
  """max(d, 0).

  Equivalent to ``max_dim(d, 0)``; deprecated per the TODO above.
  """
  return max_dim(d, 0)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
2023-12-07 09:37:13 +01:00
|
|
|
def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
  """Like min(d1, d2) but for both constant and symbolic dimensions."""
  if not is_constant_dim(d1):
    return d1.min(d2)  # type: ignore[union-attr]
  if is_constant_dim(d2):
    return min(d1, d2)
  # Only d2 is symbolic: use its reflected operation.
  return d2.rmin(d1)  # type: ignore[union-attr]
|
2023-12-07 09:37:13 +01:00
|
|
|
|
|
|
|
def max_dim(d1: DimSize, d2: DimSize) -> DimSize:
  """Like max(d1, d2) but for both constant and symbolic dimensions."""
  if not is_constant_dim(d1):
    return d1.max(d2)  # type: ignore[union-attr]
  if is_constant_dim(d2):
    return max(d1, d2)
  # Only d2 is symbolic: use its reflected operation.
  return d2.rmax(d1)  # type: ignore[union-attr]
|
2023-12-07 09:37:13 +01:00
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
def dimension_as_value(d: DimSize):
  """Turns a dimension size into a JAX array.

  Ints, NumPy int32/int64 scalars and tracers are returned unchanged. This is
  the identity function for constant dimensions. Symbolic dimensions use
  their ``dimension_as_value`` method, and anything else goes through
  ``operator.index``.

  Has the same abstract value as Python constants.
  """
  if isinstance(d, (int, Tracer, np.int32, np.int64)):
    return d
  if is_symbolic_dim(d):  # e.g. shape_poly._DimPolynomial
    return d.dimension_as_value()
  return operator.index(d)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
class SomeTracer:
  """Placeholder printed as ``[dynamic]`` in place of a tracer in error messages."""
  __slots__ = ()

  def __repr__(self):
    return "[dynamic]"
|
|
|
|
|
|
|
|
def replace_tracer_for_error_message(obj):
  """Returns a ``SomeTracer`` placeholder if ``obj`` is a Tracer, else ``obj``."""
  # TODO(mattjj): Many ideas for improving this. Crawl the stack and see if
  # there are user variables whose value is == to this object? Or search
  # parameters of functions being transformed, at least? Or at least assign
  # short unique ids to them?
  return SomeTracer() if isinstance(obj, Tracer) else obj
|
|
|
|
|
2023-03-29 12:09:47 +02:00
|
|
|
def evaluate_shape(shape: Shape, dim_vars: Sequence[str],
                   *dim_values: Array) -> Sequence[Array]:
  """Evaluates a shape possibly containing non-constants.

  Args:
    shape: the shape to evaluate.
    dim_vars: the dimension variables names that may appear in `shape`.
    dim_values: the dimension values corresponding to `dim_vars`.

  Returns:
    a tuple with one entry per dimension of `shape`. Constant dimensions are
    returned as ints (via `operator.index`). Other dimensions are evaluated
    with their `_evaluate` method in the environment mapping `dim_vars` to
    `dim_values`, presumably yielding JAX values of type `dim_value_dtype`.
  """
  env = dict(zip(dim_vars, dim_values))

  def eval_one_dim(d: DimSize):
    try:
      return operator.index(d)
    except Exception:
      # Not a constant; presumably a symbolic dimension expression (_DimExpr).
      # Catch `Exception` instead of using a bare `except:` so that
      # KeyboardInterrupt/SystemExit propagate.
      pass
    return d._evaluate(env)  # type: ignore

  return tuple(eval_one_dim(d) for d in shape)
|
|
|
|
|
|
|
|
def dim_value_dtype():
  """The dtype to be used for dimension values."""
  return dtypes.canonicalize_dtype(np.int64)

def dim_constant(ct: int):
  """Returns ``ct`` as a NumPy scalar of type ``dim_value_dtype()``."""
  dtype = dim_value_dtype()
  assert dtype in (np.int32, np.int64)
  for scalar_type in (np.int32, np.int64):
    if dtype == scalar_type:
      return scalar_type(ct)

def dim_value_aval() -> AbstractValue:
  """Weakly-typed scalar abstract value with dtype ``dim_value_dtype()``."""
  return ShapedArray((), dim_value_dtype(), weak_type=True)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
# ------------------- Named shapes -------------------
|
|
|
|
|
|
|
|
|
|
|
|
class NamedShape:
  """A shape made of positional dimensions plus named axis sizes.

  Positional dimensions are canonicalized with ``canonicalize_shape``; named
  axes are kept in a ``name -> size`` dict in keyword order.
  """
  def __init__(self, *args, **kwargs):
    self.__positional = canonicalize_shape(args)
    # TODO: Assert that kwargs match axis env?
    self.__named = dict(kwargs)

  @property
  def rank(self):
    """Total number of positional and named dimensions."""
    return len(self.__positional) + len(self.__named)

  @property
  def positional_rank(self):
    return len(self.__positional)

  @property
  def named_rank(self):
    return len(self.__named)

  @property
  def positional(self):
    return self.__positional

  @property
  def names(self):
    return self.__named.keys()

  @property
  def named_sizes(self):
    return self.__named.values()

  @property
  def named_items(self):
    return self.__named.items()

  def __getitem__(self, idx):
    """Looks up a positional dimension by integer index, else a named axis."""
    try:
      idx = operator.index(idx)
      return self.__positional[idx]
    except TypeError:
      pass
    return self.__named[idx]

  @property
  def total(self):
    """Product of all positional and named sizes."""
    total = 1
    for s in self.__positional: total *= s
    for s in self.__named.values(): total *= s
    return total

  def __str__(self):
    if not self.__named:
      return str(self.__positional)
    # Join positional and named entries in one list so that no stray leading
    # ", " appears when there are no positional dimensions.
    entries = [*map(str, self.__positional),
               *(f'{k}={v}' for k, v in self.__named.items())]
    return f"({', '.join(entries)})"

  def __eq__(self, other):
    if isinstance(other, NamedShape):
      return (self.__positional, self.__named) == (other.__positional, other.__named)
    if isinstance(other, tuple):
      return not self.__named and self.__positional == other
    return False

  def __hash__(self):
    named = frozenset(self.__named.items())
    return hash((self.__positional, named))
|
|
|
|
|
|
|
|
def join_named_shapes(*named_shapes):
  """Merges ``name -> size`` mappings into one dict.

  Raises:
    TypeError: if the same axis name appears with different sizes.
  """
  joined = {}
  for mapping in named_shapes:
    for axis, sz in mapping.items():
      previous = joined.setdefault(axis, sz)
      if previous != sz:
        raise TypeError(
            f"Axis name {axis} used with inconsistent sizes: {joined[axis]} != {sz}")
  return joined
|
|
|
|
|
|
|
|
# TODO: Make canonicalize_shape return named shapes?
def as_named_shape(shape) -> NamedShape:
  """Coerces an int, a sequence of dimensions, or a NamedShape to a NamedShape."""
  if isinstance(shape, NamedShape):
    return shape
  dims = (shape,) if isinstance(shape, int) else shape
  return NamedShape(*dims)
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------- Call -------------------
|
|
|
|
|
|
|
|
class CallPrimitive(Primitive):
  """Primitive whose ``bind`` takes a function to call along with its arguments."""
  multiple_results = True
  call_primitive = True

  def bind(self, fun, *args, **params):
    continuation, trace, wrapped_fun, tracers, params = (
        call_bind_with_continuation(self, fun, *args, **params))
    return continuation(trace.process_call(self, wrapped_fun, tracers, params))

  def get_bind_params(self, params):
    """Turns the ``call_jaxpr`` param back into a callable for re-binding."""
    bind_params = dict(params)
    jaxpr = bind_params.pop('call_jaxpr')
    subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
    if config.dynamic_shapes.value:
      subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
    return [subfun], bind_params
|
|
|
|
|
|
|
|
def call_bind_with_continuation(primitive: CallPrimitive, fun, *args, **params):
  """Prepares a call-primitive bind at the top trace of ``args``.

  Returns:
    ``(continuation, top_trace, wrapped_fun, tracers, params)``. ``tracers``
    are ``args`` raised into ``top_trace``. The continuation applies the
    env-trace todos collected by ``process_env_traces_call`` to the outputs
    and lowers them with ``full_lower``.
  """
  trace = find_top_trace(args)
  wrapped, env_todos = process_env_traces_call(
      fun, primitive, trace.level, tuple(params.items()))
  raised = map(trace.full_raise, args)
  wrapped = lu.annotate(wrapped, fun.in_type)

  def continuation(outs):
    return map(full_lower, apply_todos(env_todos(), outs))

  return continuation, trace, wrapped, raised, params
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def process_env_traces_call(primitive: CallPrimitive, level: int,
                            params_tuple: tuple, *args):
  """Post-processes outputs that are tracers of traces above ``level``.

  Per the ``lu.transformation_with_aux`` protocol, the first ``yield`` runs the
  wrapped function on ``args`` and receives its outputs. Then, repeatedly, the
  output tracer maximizing ``_get_trace_level`` is chosen, all outputs are
  raised into its trace, and that trace's ``post_process_call`` is applied.
  The collected todo callbacks are returned as the auxiliary output.
  """
  outs = yield args, {}
  params = dict(params_tuple)
  todo = []
  while True:
    # Outputs still belonging to traces above the one the call was bound at.
    tracers = [x for x in outs if isinstance(x, Tracer) and x._trace.level > level]
    if not tracers:
      break
    ans = max(tracers, key=_get_trace_level)
    trace = ans._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo = trace.post_process_call(primitive, outs, params)
    todo.append(cur_todo)
  yield outs, tuple(todo)  # Ensure the aux output is immutable
|
|
|
|
|
|
|
|
def apply_todos(todos, outs):
  """Applies ``todos`` to ``outs`` from last to first, lowering after each step."""
  for todo in reversed(list(todos)):
    outs = map(full_lower, todo(outs))
  return outs
|
|
|
|
|
|
|
|
|
|
|
|
def call_impl(f: lu.WrappedFun, *args, **params):
  """Impl rule for call primitives: calls ``f`` on ``args`` inside ``new_sublevel()``."""
  del params  # params parameterize the call primitive, not the function
  with new_sublevel():
    return f.call_wrapped(*args)

# The basic call primitive; `call` is its bind method.
call_p: CallPrimitive = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)
|
|
|
|
|
|
|
|
|
|
|
|
class ClosedCallPrimitive(CallPrimitive):
  """Call primitive whose ``call_jaxpr`` carries its own consts.

  The param is accessed via ``.jaxpr``/``.consts`` (presumably a ClosedJaxpr).
  """
  def get_bind_params(self, params):
    bind_params = dict(params)
    closed_jaxpr = bind_params.pop('call_jaxpr')
    subfun = lu.wrap_init(
        partial(eval_jaxpr, closed_jaxpr.jaxpr, closed_jaxpr.consts))
    return [subfun], bind_params

closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)
# The abstract eval reads output avals and effects straight off `call_jaxpr`.
closed_call_p.def_effectful_abstract_eval(
    lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects))
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
# Primitives that count as using outfeed; consulted by primitive_uses_outfeed.
outfeed_primitives: set[Primitive] = set()

def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
  """Finds if there are outfeed primitives anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if primitive_uses_outfeed(eqn.primitive, eqn.params):
      return True
  return False
|
|
|
|
|
|
|
|
def _param_uses_outfeed(param):
  """Returns True if ``param`` is a (Closed)Jaxpr that uses outfeed."""
  if type(param) is Jaxpr:
    return jaxpr_uses_outfeed(param)
  if type(param) is ClosedJaxpr:
    return jaxpr_uses_outfeed(param.jaxpr)
  return False
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def primitive_uses_outfeed(prim: Primitive, params: dict) -> bool:
  """Returns True if ``prim`` is an outfeed primitive or any of its params is a
  (Closed)Jaxpr, directly or inside a tuple, that uses outfeed."""
  if prim in outfeed_primitives:
    return True
  for value in params.values():
    candidates = value if isinstance(value, tuple) else (value,)
    if any(_param_uses_outfeed(c) for c in candidates):
      return True
  return False
|
|
|
|
|
|
|
|
# ------------------- Map -------------------
|
|
|
|
|
|
|
|
class MapPrimitive(Primitive):
  """Primitive that maps a function over axes given by the ``in_axes`` param."""
  multiple_results = True
  map_primitive = True

  def bind(self, fun, *args, **params):
    assert len(params['in_axes']) == len(args)
    return map_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    return trace.process_map(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_map(self, out_tracers, params)

  def get_bind_params(self, params):
    """Turns ``call_jaxpr`` into a callable and ``out_axes`` into a thunk."""
    bind_params = dict(params)
    jaxpr = bind_params.pop('call_jaxpr')
    subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
    out_axes = bind_params.pop('out_axes')
    bind_params['out_axes_thunk'] = HashableFunction(lambda: out_axes,
                                                     closure=out_axes)
    return [subfun], bind_params
|
|
|
|
|
|
|
|
|
|
|
|
def map_bind_with_continuation(primitive: MapPrimitive, fun, *args,
                               out_axes_thunk, **params):
  """Prepare a bind of the map `primitive` without running it.

  Returns ``(continuation, top_trace, fun, tracers, params)``:
  - `fun` is wrapped by `process_env_traces_map`.
  - `tracers` are `args` raised to `top_trace`.
  - `params` carries a replacement ``out_axes_thunk``.

  The caller runs ``primitive.process(top_trace, fun, tracers, params)`` and
  feeds its outputs to `continuation`, as `map_bind` does.
  """
  # The new thunk depends deterministically on the old thunk and the wrapped
  # function. Any caching already has to include the wrapped function as part
  # of the key, so we only use the previous thunk for equality checks.
  @as_hashable_function(closure=out_axes_thunk)
  def new_out_axes_thunk():
    base_axes = out_axes_thunk()
    # `todo_and_xforms` is bound further down; it is only looked up when this
    # thunk is called (presumably after the wrapped `fun` has run).
    _, xforms = todo_and_xforms()
    return functools.reduce(lambda axes, xform: xform(axes), xforms, base_axes)

  bind_params = dict(params, out_axes_thunk=new_out_axes_thunk)
  top_trace = find_top_trace(args)
  fun, todo_and_xforms = process_env_traces_map(
      fun, primitive, top_trace and top_trace.level, tuple(bind_params.items()))
  tracers = map(top_trace.full_raise, args)

  def map_bind_continuation(outs):
    # Replay the post-processing steps recorded by `process_env_traces_map`.
    env_trace_todo, _ = todo_and_xforms()
    return map(full_lower, apply_todos(env_trace_todo, outs))

  return map_bind_continuation, top_trace, fun, tracers, bind_params
|
|
|
|
|
|
|
|
|
|
|
|
def map_bind(primitive: MapPrimitive, fun, *args, **params):
  """Bind a map primitive eagerly.

  The primitive is processed on the top trace chosen by
  `map_bind_with_continuation`, and the returned continuation is applied to
  the results.
  """
  cont, trace, wrapped_fun, tracers, bind_params = map_bind_with_continuation(
      primitive, fun, *args, **params)
  outs = primitive.process(trace, wrapped_fun, tracers, bind_params)
  return cont(outs)
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def process_env_traces_map(primitive: MapPrimitive, level: int | None,
                           params_tuple: tuple, *args):
  """Post-process outputs of a mapped function that belong to higher traces.

  After the wrapped function runs, Tracer outputs whose trace level exceeds
  `level` (any Tracer when `level` is None) are repeatedly handed to
  ``primitive.post_process``. Each round uses the highest such trace.

  The aux output is ``(todos, out_axes_transforms)``, with one entry per round.
  """
  outs = yield args, {}
  params = dict(params_tuple)
  todos, xforms = [], []

  def escaped_tracers(vals):
    return [v for v in vals if isinstance(v, Tracer)
            and (level is None or v._trace.level > level)]

  while (pending := escaped_tracers(outs)):
    highest = max(pending, key=_get_trace_level)
    trace = highest._trace.main.with_cur_sublevel()
    raised = map(trace.full_raise, outs)
    outs, (todo, xform) = primitive.post_process(trace, raised, params)
    todos.append(todo)
    xforms.append(xform)
  yield outs, (tuple(todos), tuple(xforms))
|
|
|
|
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
def mapped_aval(size: AxisSize, axis: int | None,
                aval: AbstractValue) -> AbstractValue:
  """Map `aval` over `axis` (of length `size`) via its registered handler.

  Dispatches on ``type(aval)`` to the first entry of the pair in
  `aval_mapping_handlers`. Raises TypeError when no handler is registered.
  """
  map_handler, _ = aval_mapping_handlers.get(type(aval), (None, None))
  if map_handler is None:
    raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")
  return map_handler(size, axis, aval)


def unmapped_aval(size: AxisSize, axis_name, axis: int | None,
                  aval: AbstractValue) -> AbstractValue:
  """Undo `mapped_aval`: reinsert `axis` (of length `size`) named `axis_name`.

  Dispatches on ``type(aval)`` to the second entry of the pair in
  `aval_mapping_handlers`. Raises TypeError when no handler is registered.
  """
  _, unmap_handler = aval_mapping_handlers.get(type(aval), (None, None))
  if unmap_handler is None:
    raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
  return unmap_handler(size, axis_name, axis, aval)
|
|
|
|
|
|
|
|
|
|
|
|
def _map_shaped_array(
    size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
  """Remove dimension `axis` (which must have length `size`) from `aval`.

  With ``axis=None`` the aval is returned unchanged.
  """
  assert axis is None or aval.shape[axis] == size
  if axis is None:
    return aval
  # TODO: Extend the named shape
  mapped_shape = tuple_delete(aval.shape, axis)
  return ShapedArray(mapped_shape, aval.dtype,
                     named_shape=aval.named_shape, weak_type=aval.weak_type)


def _unmap_shaped_array(
    size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
  ) -> ShapedArray:
  """Counterpart of `_map_shaped_array`.

  Inserts a dimension of length `size` at `axis` and drops `axis_name` from
  the named shape if present. With ``axis=None`` only the named shape is
  updated. A non-int, non-None `axis` raises TypeError.
  """
  named_shape = dict(aval.named_shape)
  named_shape.pop(axis_name, None)  # TODO: make this mandatory
  if axis is None:
    return aval.update(named_shape=named_shape)
  if type(axis) is not int:
    raise TypeError(axis)
  unmapped_shape = tuple_insert(aval.shape, axis, size)
  return ShapedArray(unmapped_shape, aval.dtype,
                     named_shape=named_shape, weak_type=aval.weak_type)
|
|
|
|
|
|
|
|
def _map_dshaped_array(
    size: AxisSize, axis: int | None, aval: DShapedArray) -> DShapedArray:
  """Dynamic-shape analogue of `_map_shaped_array`.

  Removes dimension `axis`. `size` is not checked here.
  """
  if axis is None:
    return aval
  mapped_shape = tuple_delete(aval.shape, axis)
  return DShapedArray(mapped_shape, aval.dtype, aval.weak_type)


def _unmap_dshaped_array(
    size: AxisSize, axis_name: AxisName, axis: int | None, aval: DShapedArray
  ) -> DShapedArray:
  """Dynamic-shape analogue of `_unmap_shaped_array`.

  Inserts `size` at `axis`. `axis_name` is unused. A non-int, non-None `axis`
  raises TypeError.
  """
  if axis is None:
    return aval
  if type(axis) is not int:
    raise TypeError(axis)
  unmapped_shape = tuple_insert(aval.shape, axis, size)
  return DShapedArray(unmapped_shape, aval.dtype, weak_type=aval.weak_type)
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
# A (map_handler, unmap_handler) pair, consumed by `mapped_aval` and
# `unmapped_aval` respectively.
AvalMapHandlerPair = tuple[Callable, Callable]

# Handlers keyed on the exact aval type.
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
    DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
    ShapedArray: (_map_shaped_array, _unmap_shaped_array),
    ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
    # Tokens pass through both mapping and unmapping unchanged.
    AbstractToken: (lambda _size, _axis, aval: aval,
                    lambda _size, _axis_name, _axis, aval: aval),
}
|
|
|
|
|
|
|
|
@contextmanager
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
  """Bind `axis_name` (of length `size`) in this thread's axis environment for
  the duration of the `with` block.

  An `AxisEnvFrame` is pushed onto ``trace_state.axis_env`` on entry and
  popped on exit. After each change, the thread-local jit state is refreshed
  with the frames whose name is not `no_axis_name`.
  """
  frame = AxisEnvFrame(axis_name, size, tag)
  ts = thread_local_state.trace_state

  def sync_jit_state():
    config.update_thread_local_jit_state(
        axis_env_state=tuple(f for f in ts.axis_env
                             if f.name is not no_axis_name))

  ts.axis_env.append(frame)
  # The sync happens inside `try` so the pushed frame is still popped if the
  # sync itself raises.
  try:
    sync_jit_state()
    yield
  finally:
    ts.axis_env.pop()
    sync_jit_state()
|
|
|
|
|
|
|
|
@contextmanager
def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None):
  """Bind several ``(axis_name, size)`` pairs in this thread's axis environment
  for the duration of the `with` block.

  One `AxisEnvFrame` per pair (all sharing `tag`) is pushed on entry, and the
  same number is popped on exit. After each change, the thread-local jit state
  is refreshed with the frames whose name is not `no_axis_name`.
  """
  frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
  ts = thread_local_state.trace_state

  def sync_jit_state():
    config.update_thread_local_jit_state(
        axis_env_state=tuple(f for f in ts.axis_env
                             if f.name is not no_axis_name))

  ts.axis_env.extend(frames)
  # The sync happens inside `try` so the pushed frames are still popped if the
  # sync itself raises.
  try:
    sync_jit_state()
    yield
  finally:
    for _ in frames:
      ts.axis_env.pop()
    sync_jit_state()
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def stash_axis_env():
  """Promise that a function or with-suite does not depend implicitly on axis env.

  The thread's axis environment (and the matching thread-local jit state) is
  emptied for the duration of the block and restored afterwards. If the
  promise is broken, looking up an axis name via `axis_frame` raises a
  NameError about an unbound axis name.
  """
  ts = thread_local_state.trace_state
  prev_axis_env, ts.axis_env = ts.axis_env, []
  # Clear the jit state inside `try` so the previous axis env is restored even
  # if the update itself raises.
  try:
    config.update_thread_local_jit_state(axis_env_state=())
    yield
  finally:
    ts.axis_env = prev_axis_env
    config.update_thread_local_jit_state(
        axis_env_state=tuple(f for f in ts.axis_env
                             if f.name is not no_axis_name))
|
|
|
|
|
|
|
|
|
|
|
|
# When a mapped function is given no axis name, we generate a name object based
|
|
|
|
# on the id of the function object. Collisions aren't important because this
|
|
|
|
# name can't be used in collectives, as user code never gets a ref to this
|
|
|
|
# object. We don't want to use the function object itself because that might
|
|
|
|
# persist references to the function object.
|
|
|
|
# TODO(mattjj): revisit this unique axis name strategy
|
|
|
|
@total_ordering
|
|
|
|
class _TempAxisName:
|
|
|
|
|
|
|
|
def __init__(self, obj):
|
|
|
|
self.id = id(obj)
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
return f'<axis {hex(self.id)}>'
|
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
return hash(self.id)
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
return type(other) is _TempAxisName and self.id == other.id
|
|
|
|
|
|
|
|
def __lt__(self, other):
|
|
|
|
return type(other) is _TempAxisName and self.id < other.id
|
|
|
|
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
def axis_frame(axis_name: AxisName, main_trace: MainTrace | None = None
               ) -> AxisEnvFrame:
  """Return the most recently pushed axis-env frame named `axis_name`.

  If `main_trace` is given, a frame matches only if its ``main_trace`` is that
  exact object. Raises NameError listing the available non-temporary axis
  names when no frame matches.
  """
  frames = thread_local_state.trace_state.axis_env
  found = next((f for f in reversed(frames)
                if f.name == axis_name
                and (main_trace is None or f.main_trace is main_trace)),
               None)
  if found is not None:
    return found
  named_axes = [f.name for f in reversed(frames)
                if not isinstance(f.name, _TempAxisName)]
  raise NameError(
      f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
      f'by pmap) are available to collective operations: {named_axes}')
|
|
|
|
|
|
|
|
|
2024-03-04 05:41:29 -08:00
|
|
|
@dataclass(frozen=True)
class NamedAxisEffect(effects.Effect):
  """A side-effect introducing a new named axis into the current scope."""

  name: AxisName  # the axis name this effect refers to


# Register NamedAxisEffect with each of these effect-type registries
# (control flow, custom derivatives, lowering, remat), in this order.
for _allowed_effects in (effects.control_flow_allowed_effects,
                         effects.custom_derivatives_allowed_effects,
                         effects.lowerable_effects,
                         effects.remat_allowed_effects):
  _allowed_effects.add_type(NamedAxisEffect)
del _allowed_effects
|
|
|
|
|
|
|
|
|
|
|
|
def filter_named_axis_effects(
    effects: Effects, names: Collection[AxisName]
) -> Effects:
  """Return a new set of `effects` without the NamedAxisEffects whose name is
  in `names`. All other effects are kept."""
  kept = set()
  for eff in effects:
    if isinstance(eff, NamedAxisEffect) and eff.name in names:
      continue
    kept.add(eff)
  return kept


def remove_named_axis_effects(
    jaxpr: Jaxpr, names: Collection[AxisName]
) -> Jaxpr:
  """Return `jaxpr` with the NamedAxisEffects for `names` removed.

  The input jaxpr itself is returned when `names` or its effects are empty.
  """
  if names and jaxpr.effects:
    return jaxpr.replace(
        effects=filter_named_axis_effects(jaxpr.effects, names))
  return jaxpr
|
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
# Equation parameters, keyed by parameter name.
ParamDict = dict[str, Any]
# An axis-name substitution: maps one axis name to the tuple of names that
# replace it (e.g. `NameGatheringSubst.__call__` returns a 1-tuple of its input).
AxisSubst = Callable[[AxisName], tuple[AxisName, ...]]
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
class NameGatheringSubst:
  """Identity `AxisSubst` that records every axis name passed to it.

  `used_axis_names` runs `subst_axis_names` with one of these to collect the
  axis names a primitive's params use.
  """

  def __init__(self):
    self.axis_names = set()

  def __call__(self, axis_name):
    # Record the name, then "substitute" it with itself.
    self.axis_names.add(axis_name)
    return axis_name,
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def used_axis_names(primitive: Primitive, params: ParamDict) -> set[AxisName]:
  """Collect the axis names referenced by `primitive` applied with `params`."""
  gatherer = NameGatheringSubst()
  # Only the recording side effect matters; the substituted params are dropped.
  subst_axis_names(primitive, params, gatherer)
  return gatherer.axis_names
|
|
|
|
|
|
|
|
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst, traverse: bool = True) -> ParamDict:
  """Apply the axis-name substitution `subst` to `params` of `primitive`.

  A rule registered in `axis_substitution_rules` takes precedence. Otherwise,
  when `traverse` is true, every Jaxpr/ClosedJaxpr-valued param is rewritten
  with `subst_axis_names_jaxpr`. For a MapPrimitive, the primitive's own
  ``axis_name`` is shadowed (left unchanged) inside its body. The original
  `params` object is returned when nothing needs rewriting.
  """
  if primitive in axis_substitution_rules:
    return axis_substitution_rules[primitive](params, subst, traverse)
  if not traverse:
    return params

  # Default implementation: substitute names in all jaxpr parameters.
  if isinstance(primitive, MapPrimitive):
    # `params['axis_name']` is read lazily, only when a name is substituted.
    def shadowed_subst(name):
      if name == params['axis_name']:
        return (name,)
      return subst(name)
  else:
    shadowed_subst = subst

  jaxpr_keys = [k for k, v in params.items()
                if isinstance(v, (Jaxpr, ClosedJaxpr))]
  if not jaxpr_keys:
    return params
  return {**params,
          **{k: subst_axis_names_jaxpr(params[k], shadowed_subst)
             for k in jaxpr_keys}}
|
|
|
|
|
|
|
|
class DuplicateAxisNameError(Exception):
  """Raised when an axis-name substitution maps two named axes of a variable
  to the same name.

  Attributes:
    var: the offending variable.
    eqn: the equation defining `var`. It is ``None`` until filled in by
      `subst_axis_names_eqn`.
  """

  def __init__(self, var):
    self.var, self.eqn = var, None
|
|
|
|
|
2024-03-04 05:41:29 -08:00
|
|
|
def subst_axis_names_effects(effects: Set[Effect], subst: AxisSubst) -> Set[Effect]:
  """Apply `subst` to the names carried by NamedAxisEffects.

  Each NamedAxisEffect becomes one NamedAxisEffect per name returned by
  `subst` (possibly none). All other effects are copied unchanged into a new
  set.
  """
  substituted: set[Effect] = set()
  for eff in effects:
    if not isinstance(eff, NamedAxisEffect):
      substituted.add(eff)
      continue
    substituted.update(NamedAxisEffect(n) for n in subst(eff.name))
  return substituted
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var:
  """Return the replacement for binder `v` after axis-name substitution.

  The replacement is recorded in `var_map`:
  - DropVars are returned untouched and not recorded.
  - A `v` whose aval has no ``named_shape`` maps to itself.
  - Otherwise a fresh Var is created whose named shape uses the substituted
    names, with sizes taken from `axis_frame`.

  Raises DuplicateAxisNameError if the substitution produces a repeated name.
  """
  # Var identity is load-bearing, so we can't have duplicates!
  if isinstance(v, DropVar):
    return v
  assert v not in var_map
  aval = v.aval
  if hasattr(aval, 'named_shape'):
    names = [new for old in aval.named_shape for new in subst(old)]
    sizes = {name: axis_frame(name).size for name in names}
    if len(sizes) != len(names):
      raise DuplicateAxisNameError(v)
    replacement = Var(v.suffix, aval.update(named_shape=sizes))
  else:
    replacement = v
  var_map[v] = replacement
  return replacement
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn:
  """Rewrite one equation under the axis-name substitution `subst`.

  Inputs are looked up in `var_map` (Literals pass through). Outputs get fresh
  binders, which are recorded in `var_map`. Params and effects are
  substituted. A DuplicateAxisNameError from an output is annotated with
  `eqn` before propagating.
  """
  new_invars: list[Atom] = [
      v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]
  try:
    new_outvars = [subst_axis_names_var(v, subst, var_map) for v in eqn.outvars]
  except DuplicateAxisNameError as err:
    err.eqn = eqn
    raise
  new_params = subst_axis_names(eqn.primitive, eqn.params, subst)
  new_effects = subst_axis_names_effects(eqn.effects, subst)
  return eqn.replace(invars=new_invars, outvars=new_outvars,
                     params=new_params, effects=new_effects)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
  """Rebuild `jaxpr` with every axis name passed through `subst`.

  A ClosedJaxpr input yields a ClosedJaxpr with the same consts. Binders are
  processed in order: invars, constvars, then each equation's outputs.
  """
  consts = jaxpr.consts if isinstance(jaxpr, ClosedJaxpr) else None
  inner = jaxpr.jaxpr if isinstance(jaxpr, ClosedJaxpr) else jaxpr
  var_map: dict[Var, Var] = {}

  def rebind(binders):
    return [subst_axis_names_var(v, subst, var_map) for v in binders]

  new_invars = rebind(inner.invars)  # type: ignore[union-attr]
  new_constvars = rebind(inner.constvars)  # type: ignore[union-attr]
  new_eqns = [subst_axis_names_eqn(e, subst, var_map) for e in inner.eqns]  # type: ignore[union-attr]
  new_outvars: list[Atom] = [
      v if isinstance(v, Literal) else var_map[v]
      for v in inner.outvars]  # type: ignore[union-attr]
  new_effects = subst_axis_names_effects(inner.effects, subst)  # type: ignore[union-attr]
  rebuilt = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, new_effects)
  return rebuilt if consts is None else ClosedJaxpr(rebuilt, consts)
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
def used_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr):
  """Return the set of axis names carried by `jaxpr`'s NamedAxisEffects."""
  named = (e for e in jaxpr.effects if isinstance(e, NamedAxisEffect))
  return {e.name for e in named}


def subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
  """Apply `subst` to `jaxpr`.

  A `NameGatheringSubst` is handled with a fast path: the names are read off
  the jaxpr's effects and the jaxpr is returned unchanged.
  """
  if not isinstance(subst, NameGatheringSubst):
    return do_subst_axis_names_jaxpr(jaxpr, subst)
  # Common case: only gathering names, so no rebuild is needed.
  subst.axis_names.update(used_axis_names_jaxpr(jaxpr))
  return jaxpr
|
|
|
|
|
2023-01-18 10:17:01 -08:00
|
|
|
def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects):
  """Return `jaxpr` with its inner jaxpr's effects replaced by `effects`."""
  # Freeze the effects so they are hashable for the cached helper below.
  frozen_effects = frozenset(effects)
  return _replace_jaxpr_effects(jaxpr, frozen_effects)


@weakref_lru_cache
def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: frozenset[Effect]):
  # Cached on (jaxpr, effects); presumably weakly referencing `jaxpr`.
  inner = jaxpr.jaxpr.replace(effects=set(effects))
  return jaxpr.replace(jaxpr=inner)
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
# Per-primitive overrides consulted first by `subst_axis_names`: each rule is
# called as rule(params, subst, traverse) and returns the substituted params.
axis_substitution_rules: dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
# ------------------- AxisPrimitive -------------------
|
|
|
|
# Primitives that store axis names in params and want those axis names to
# participate in dispatch should subclass AxisPrimitive.

class AxisPrimitive(Primitive):
  def bind(self, *args, **params):
    """Bind on the top trace, unless an axis used by `params` was bound by a
    main trace at the same or higher level.

    In that case the bind happens on that main trace (at its current
    sublevel). Among the used axes, the main trace with the highest ``level``
    wins.
    """
    trace = find_top_trace(args)
    axis_mains = (axis_frame(name).main_trace
                  for name in used_axis_names(self, params))
    axis_main = max(axis_mains, default=None,
                    key=lambda t: getattr(t, 'level', -1))
    if axis_main and axis_main.level >= trace.level:
      trace = axis_main.with_cur_sublevel()
    return self.bind_with_trace(trace, args, params)
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------- Jaxpr checking -------------------
|
|
|
|
|
|
|
|
def typecheck(aval: AbstractValue, x) -> bool:
  """Return True if the value `x` conforms to the abstract value `aval`."""
  x_aval = get_aval(x)
  return typecompat(aval, x_aval)


def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
  """Determine whether `aval` conforms to `aval_ref`.

  Ignores weak_type and named_shape, other than to check that an axis name isn't
  used with different sizes. Any TypeError raised while joining or matching
  the avals is reported as "not compatible".
  """
  try:
    joined = lattice_join(aval_ref, aval)
    return typematch(aval_ref, joined)
  except TypeError:
    return False


def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
  """Determine whether `aval1` and `aval2` are equivalent.

  Ignores weak_type and named_shape, other than to check that an axis name isn't
  used with different sizes.
  """
  if aval1 == aval2:
    return True
  # Unequal avals may still represent the same type: types are compared at the
  # shaped level, and weak type tags and (for now) named shape components are
  # not considered part of the type.
  if isinstance(aval1, ShapedArray) and isinstance(aval2, ShapedArray):
    # Bonus check: join the named shapes to surface axes with inconsistent
    # sizes.
    join_named_shapes(aval1.named_shape, aval2.named_shape)
  lhs = raise_to_shaped(aval1, weak_type=False).strip_named_shape()
  rhs = raise_to_shaped(aval2, weak_type=False).strip_named_shape()
  return lhs == rhs
|
|
|
|
|
|
|
|
class JaxprTypeError(TypeError):
  """Raised when a jaxpr fails type checking.

  Inside `_check_jaxpr` the args may be ``(msg, eqn_idx)``; elsewhere they are
  ``(msg,)``.
  """
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
# Primitive-specific type rules used by `_check_jaxpr`. Each rule is called as
# rule(ctx_factory, *in_atoms, **params) and returns (out_type, effects).
custom_typechecks: dict[Primitive, Callable] = {}


def _check_closed_call(_, *in_atoms, call_jaxpr):
  """Type rule for `closed_call_p`.

  Checks the operands against the callee's input avals and returns the
  callee's output avals and effects. Raises JaxprTypeError on an operand-count
  mismatch or on incompatible operand types.
  """
  in_avals = [x.aval for x in in_atoms]
  # Check arity explicitly. A plain builtin `map` would silently truncate to
  # the shorter input, and this module's `map` (the builtin is used as
  # `unsafe_map` elsewhere) may fail with a non-JaxprTypeError instead. Either
  # way the mismatch would not be reported as a type error.
  if len(in_avals) != len(call_jaxpr.in_avals):
    raise JaxprTypeError(
        f"Closed call arity mismatch: got {len(in_avals)} operands for a jaxpr "
        f"with {len(call_jaxpr.in_avals)} inputs")
  if not all(map(typecompat, call_jaxpr.in_avals, in_avals)):
    raise JaxprTypeError("Closed call in_avals mismatch")
  return call_jaxpr.out_avals, call_jaxpr.effects
custom_typechecks[closed_call_p] = _check_closed_call
|
|
|
|
|
|
|
|
def check_jaxpr(jaxpr: Jaxpr):
  """Checks well-formedness of a jaxpr.

  Specifically, check that:
  - variables that are read are bound beforehand
  - variables are typed equally throughout a jaxpr
  - variable type annotations are compatible with their binding expression

  Raises `JaxprTypeError` if `jaxpr` is determined invalid. The error message
  includes a pretty-printed window of the jaxpr around the failing equation,
  or its first 20 equations when no index is available. Returns `None`
  otherwise. When `config.debug_key_reuse` is enabled, the key-reuse checker
  also runs on a valid jaxpr.
  """
  # Cached so the (potentially expensive) pretty-printing context is built at
  # most once, and only if an error message actually needs it.
  @functools.cache
  def ctx_factory():
    ctx = JaxprPpContext()
    pp_settings = JaxprPpSettings()
    # Best effort: pretty-printing only assigns variable names on `ctx` for
    # error messages, so ordinary failures are ignored. Catching `Exception`
    # rather than using a bare `except` keeps KeyboardInterrupt/SystemExit
    # from being swallowed.
    try:
      pp_jaxpr(jaxpr, ctx, pp_settings)  # side-effect on ctx, build variable names
    except Exception:
      pass
    return ctx, pp_settings

  try:
    _check_jaxpr(ctx_factory, jaxpr)
  except JaxprTypeError as e:
    ctx, pp_settings = ctx_factory()
    if len(e.args) == 2:
      msg, eqnidx = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx,
                                         pp_settings))
    else:
      msg, = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, ctx, pp_settings))
    msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
    raise JaxprTypeError(msg) from None

  # Run key reuse checker after validating jaxpr:
  if config.debug_key_reuse.value:
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import check_key_reuse_jaxpr  # pytype: disable=import-error
    check_key_reuse_jaxpr(jaxpr)
|
|
|
|
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
def _check_jaxpr(
    ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
    jaxpr: Jaxpr
  ) -> None:
  """Type-check `jaxpr`, raising JaxprTypeError on the first problem found.

  `ctx_factory` lazily supplies a pretty-printing context, used only to build
  error messages. Errors raised while checking an equation are re-raised with
  args ``(msg, eqn_idx)`` so `check_jaxpr` can print a window around that
  equation.
  """
  # Set of variables bound so far; used to check that variables are in scope.
  env: set[Var] = set()

  def read(x: Atom) -> Atom:
    # Check the type annotation is itself well-typed.
    check_type(ctx_factory, env, x.aval)
    if isinstance(x, Var):
      # Check the variable is in scope (bound earlier in the jaxpr).
      if x not in env:
        ctx, _ = ctx_factory()
        raise JaxprTypeError(f"Variable '{pp_var(x, ctx)}' not defined")
      return x
    elif isinstance(x, Literal):
      # Check that the literal matches its type annotation.
      if not typecheck(x.aval, x.val):
        ctx, _ = ctx_factory()
        raise JaxprTypeError(
            f"Literal value {x.val} does not match its type annotation "
            f"{pp_aval(x.aval, ctx)}")
      return x
    else:
      assert False, "syntactically invalid jaxpr"

  def write(v: Var, a: AbstractValue) -> None:
    assert isinstance(v, Var), "syntactically invalid jaxpr"
    # Check the type annotation of the binder is itself well-typed.
    check_type(ctx_factory, env, v.aval)
    # Check that the variable is not already bound.
    if v in env:
      ctx, _ = ctx_factory()
      raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound")
    # Check that the computed type is consistent with the binder annotation.
    if not typematch(v.aval, a):
      ctx, _ = ctx_factory()
      raise JaxprTypeError(
          f"Value for variable '{pp_var(v, ctx)}' inconsistently typed "
          f"as {pp_aval(a, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}")
    # If the variable is not a DropVar, add it to the environment.
    if not isinstance(v, DropVar):
      env.add(v)

  # Check type annotations on lambda binders. (`write` also calls check_type,
  # so each binder's annotation is checked twice here.)
  for v in it.chain(jaxpr.constvars, jaxpr.invars):
    check_type(ctx_factory, env, v.aval)
    write(v, v.aval)

  # Check each eqn.
  # `in_idx` maps each jaxpr input (constvars first, then invars) to its
  # position; it is used to translate an equation's JaxprInputEffect into the
  # corresponding effect on the whole jaxpr. `sentinel` marks "not an input".
  sentinel = object()
  in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
  for eqn_idx, eqn in enumerate(jaxpr.eqns):
    prim = eqn.primitive
    try:
      in_atoms = map(read, eqn.invars)
      in_avals = [x.aval for x in in_atoms]  # use in_atoms for dyn shapes

      # Compute the type of the primitive application.
      if prim in custom_typechecks:
        out_type, eqn_effects = custom_typechecks[prim](
          ctx_factory, *in_atoms, **eqn.params)
      elif prim.call_primitive:
        out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
                                            eqn.params)
      elif prim.map_primitive:
        out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
                                           eqn.params)
      else:
        out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)

      # Check the computed effect type matches the eqn's annotation, and is
      # included in the jaxpr's annotation.
      # The output of `mutable_array_p` is registered in `in_idx` with index
      # None, so later JaxprInputEffects on it pass the `sentinel` check below
      # (and are translated with input_index=None).
      if prim is mutable_array_p:
        outvar, = eqn.outvars
        in_idx[outvar] = None  # type: ignore
      if eqn.effects != eqn_effects:
        raise JaxprTypeError("Inferred effects do not match equation effects. "
                             f"Equation effects: {eqn.effects}. "
                             f"Inferred effects: {eqn_effects}")
      for eff in eqn.effects:
        if isinstance(eff, effects.JaxprInputEffect):
          eqn_invar = eqn.invars[eff.input_index]
          if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
            raise JaxprTypeError(
                "Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
          jaxpr_effect = eff.replace(input_index=jaxpr_index)
          if jaxpr_effect not in jaxpr.effects:
            raise JaxprTypeError(
                "Invalid `JaxprInputEffect`: must be present in jaxpr. "
                f"{jaxpr_effect} is not in {jaxpr.effects}.")
        elif isinstance(eff, NamedAxisEffect):
          # It is valid for a primitive to discharge the named axis effect.
          continue
        elif eff not in jaxpr.effects:
          raise JaxprTypeError("Equation effect not present in jaxpr effects. "
                               f"Equation effect: {eff}. "
                               f"Jaxpr effects: {jaxpr.effects}")

      # Check out_type matches the let-binders' annotation (after substitution).
      out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars)
      # NOTE(review): called for its side effects on `env`; this relies on the
      # module-level `map` being eager (a builtin `map` would be lazy and never
      # call `write`).
      map(write, eqn.outvars, out_type)

    except JaxprTypeError as e:
      ctx, settings = ctx_factory()
      msg, = e.args
      src = source_info_util.summarize(eqn.source_info)
      msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx, settings))),
                         f"from source: {src}"])
      raise JaxprTypeError(msg, eqn_idx) from None

  # TODO(mattjj): include output type annotation on jaxpr and check it here
  # Outputs are only checked for being in scope and well-annotated.
  map(read, jaxpr.outvars)
|
|
|
|
|
|
|
|
def check_type(
    ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
    env: set[Var],
    ty: AbstractValue,
  ) -> None:
  """Check that a type annotation is itself well-formed.

  Only DShapedArray carries structure to validate. Each shape entry must be
  one of:
  - an int;
  - a scalar DArray of `bint` dtype;
  - a Var bound in `env` whose aval is a scalar integer ShapedArray or a
    scalar `bint` DShapedArray.

  Raises JaxprTypeError otherwise.
  """
  if not isinstance(ty, DShapedArray):
    return  # Except for DShapedArray, all syntactic forms are valid

  for d in ty.shape:
    if isinstance(d, int):
      continue
    if isinstance(d, DArray) and not d.shape and type(d.dtype) == bint:
      continue
    if not isinstance(d, Var):
      raise JaxprTypeError(f"unexpected type in shape: {type(d)}")

    # `d` is a Var: it must be in scope and annotated as a scalar size.
    if d not in env:
      ctx, _ = ctx_factory()
      raise JaxprTypeError(f"unbound axis size: '{pp_var(d, ctx)}'")
    size_aval = d.aval
    if isinstance(size_aval, ShapedArray):
      if size_aval.shape:
        raise JaxprTypeError(f"axis size nonscalar: {d.aval}")
      if not dtypes.issubdtype(size_aval.dtype, np.integer):
        raise JaxprTypeError(f"axis size with non-integer dtype: {d.aval}")
    elif isinstance(size_aval, DShapedArray):
      if size_aval.shape:
        raise JaxprTypeError(f"axis size nonscalar: {d.aval}")
      if type(size_aval.dtype) is not bint:
        raise JaxprTypeError(
            f"DArray axis size with non-bint dtype: {d.aval}")
    else:
      raise JaxprTypeError(f"axis size with unexpected type annotation: "
                           f"{d.aval} of type {type(d.aval)}")
|
|
|
|
|
|
|
|
def substitute_vars_in_output_ty(
    out_type: Sequence[AbstractValue],  # shapes may contain InDBIdx / OutDBIdx
    in_atoms: Sequence[Atom],
    out_binders: Sequence[Var],
  ) -> list[AbstractValue]:  # shapes may contain Vars
  """Resolve InDBIdx / OutDBIdx shape entries in `out_type`.

  - InDBIdx(i) becomes operand `i`; a Literal operand is replaced by its value.
  - OutDBIdx(i) becomes `out_binders[i]`.

  Only DShapedArray avals are rewritten; others are returned as-is.
  """
  in_vals = [a.val if type(a) is Literal else a for a in in_atoms]

  def resolve(d):
    if type(d) is InDBIdx:
      return in_vals[d.val]
    if type(d) is OutDBIdx:
      return out_binders[d.val]
    return d

  return [aval.update(shape=tuple(resolve(d) for d in aval.shape))
          if type(aval) is DShapedArray else aval
          for aval in out_type]
|
|
|
|
|
|
|
|
def check_eqn(prim, in_avals, params):
  """Type an ordinary (non-call, non-map) equation.

  Every jaxpr found in `params` is checked first. Then `prim.abstract_eval` is
  applied to the input avals. Returns ``(out_avals, effects)``, with
  `out_avals` always a list (wrapped for single-result primitives).
  """
  for sub_jaxpr in jaxprs_in_params(params):
    check_jaxpr(sub_jaxpr)
  avals, eqn_effects = prim.abstract_eval(*in_avals, **params)
  return (avals if prim.multiple_results else [avals]), eqn_effects
|
|
|
|
|
|
|
|
def _check_call(ctx_factory, prim, in_atoms, params):
  """Type-check an application of a call-like primitive.

  Verifies the following, then checks the body with ``_check_jaxpr``:
  - ``params`` carries a ``call_jaxpr`` whose number of binders equals
    ``len(in_atoms)``.
  - Each operand's aval is compatible with its binder's aval. Before the
    comparison, dynamic shape dims referring to earlier binders are replaced
    by the operands (or literal values) bound to them.

  Returns:
    ``(out_type, effects)``. ``out_type`` holds the jaxpr's output avals. In
    every DShapedArray shape, a Var dim is rewritten to ``InDBIdx`` if it is
    an input binder, else to ``OutDBIdx`` if it is an output. ``effects`` is
    ``call_jaxpr.effects``.

  Raises:
    JaxprTypeError: if ``call_jaxpr`` is missing, the arity differs, or an
      operand's type is incompatible with its binder.
  """
  if "call_jaxpr" not in params:
    raise JaxprTypeError(
        f"Call primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]
  binders = call_jaxpr.invars

  num_operands, num_binders = len(in_atoms), len(binders)
  if num_operands != num_binders:
    raise JaxprTypeError(f"Call primitive {prim} with {num_operands} "
                         f"operands cannot call jaxpr with "
                         f"{num_binders} inputs")

  # Check `call_jaxpr` can be applied to in_atoms. `bound` maps each binder
  # seen so far to the Var (or literal value) that instantiates it, so that
  # later binders' dynamic shapes can be specialized.
  bound: dict[Var, Atom] = {}

  def instantiate(aval: AbstractValue):
    if not isinstance(aval, DShapedArray):
      return aval
    return aval.update(shape=tuple(bound.get(d, d) for d in aval.shape))  # type: ignore

  for binder, operand in zip(binders, in_atoms):
    expected = instantiate(binder.aval)
    if not typecompat(expected, operand.aval):
      # TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
      raise JaxprTypeError(f"Call primitive {prim} passes operand {operand} of type "
                           f"{operand.aval} to jaxpr expecting type "
                           f"{expected}")
    bound[binder] = operand if type(operand) is Var else operand.val

  _check_jaxpr(ctx_factory, call_jaxpr)

  outvars = call_jaxpr.outvars
  input_index: dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate(binders)}
  output_index: dict[Var, OutDBIdx] = {v: OutDBIdx(i) for i, v in enumerate(outvars)
                                       if type(v) is Var}

  def to_index(dim):
    if type(dim) is not Var:
      return dim
    return input_index.get(dim, output_index.get(dim))

  out_type = []
  for outvar in outvars:
    aval = outvar.aval
    if type(aval) is DShapedArray:
      aval = aval.update(shape=tuple(to_index(d) for d in aval.shape))
    out_type.append(aval)
  return out_type, call_jaxpr.effects
|
|
|
|
|
|
|
|
def _check_map(ctx_factory, prim, in_avals, params):
  """Type-check an application of a map primitive.

  Checks that:
  - ``params`` supplies ``call_jaxpr``, ``axis_size``, ``axis_name``,
    ``in_axes`` and ``out_axes``.
  - The body has no ordered effects.
  - The numbers of operands, ``in_axes`` and ``out_axes`` match the jaxpr's
    inputs and outputs.
  - Each operand is compatible with its binder's aval after the mapped axis
    is re-inserted with ``unmapped_aval``.

  The body is then checked with ``axis_name`` bound in the axis environment.

  Returns:
    ``(out_avals, effects)``. ``out_avals`` are the body's output avals with
    the mapped axis re-inserted where ``out_axes`` is not None. ``effects``
    are the body's effects filtered by ``filter_named_axis_effects`` for
    ``axis_name``.

  Raises:
    JaxprTypeError: on any of the failures above.
  """
  def require(key):
    if key not in params:
      raise JaxprTypeError(f"Map primitive {prim} missing '{key}' parameter")
    return params[key]

  call_jaxpr = require("call_jaxpr")
  ordered_effects_ = effects.ordered_effects.filter_in(call_jaxpr.effects)
  if ordered_effects_:
    raise JaxprTypeError(
        f"Map primitive {prim} mapping ordered effects: {ordered_effects_}")
  axis_size = require("axis_size")
  axis_name = require("axis_name")
  in_axes = tuple(require("in_axes"))
  out_axes = tuple(require("out_axes"))
  in_avals = list(in_avals)

  # The zip() calls below would silently truncate on a length mismatch,
  # leaving operands unchecked or producing too few outputs; reject such
  # malformed equations explicitly instead.
  invars, outvars = call_jaxpr.invars, call_jaxpr.outvars
  if len(in_avals) != len(invars):
    raise JaxprTypeError(f"Map primitive {prim} with {len(in_avals)} "
                         f"operands cannot call jaxpr with "
                         f"{len(invars)} inputs")
  if len(in_axes) != len(invars):
    raise JaxprTypeError(f"Map primitive {prim} has {len(in_axes)} in_axes "
                         f"for a jaxpr with {len(invars)} inputs")
  if len(out_axes) != len(outvars):
    raise JaxprTypeError(f"Map primitive {prim} has {len(out_axes)} out_axes "
                         f"for a jaxpr with {len(outvars)} outputs")

  binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
                  if in_axis is not None else v.aval
                  for v, in_axis in zip(invars, in_axes)]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    if not typecompat(binder_aval, in_aval):
      raise JaxprTypeError(f"Map primitive {prim} passes operand {in_aval} "
                           f"to jaxpr expecting {binder_aval}")

  with extend_axis_env(axis_name, axis_size, None):
    _check_jaxpr(ctx_factory, call_jaxpr)

  out_avals = [unmapped_aval(axis_size, axis_name, out_axis, v.aval)
               if out_axis is not None else v.aval
               for v, out_axis in zip(outvars, out_axes)]
  return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
|
|
|
|
# ------------------- Jaxpr printed representation -------------------
|
|
|
|
|
2024-05-06 09:59:18 -04:00
|
|
|
def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True,
                      custom_pp_eqn_rules=True, name_stack=False,
                      print_effects: bool = False) -> pp.Doc:
  """Pretty-print a jaxpr, hoisting sub-jaxprs that are reached more than once.

  Sub-jaxprs are found with ``jaxprs_in_params``. A sub-jaxpr reached more
  than once is printed a single time as a ``let <name> = { ... } in`` binding
  ahead of the main jaxpr, and each use site refers to it by name.

  Naming rules:
  - The base name comes from the ``name`` param of the equation holding the
    sub-jaxpr, with ``<``/``>`` stripped. It defaults to ``"jaxpr"``.
  - Names are made unique by appending 1, 2, ...
  - Hoisted names are reserved so no variable is printed with the same name.
  """
  context = JaxprPpContext()
  settings = JaxprPpSettings(
      source_info=source_info,
      print_shapes=print_shapes,
      custom_pp_eqn_rules=custom_pp_eqn_rules,
      name_stack=name_stack,
      print_effects=print_effects)

  # Compute how many times each jaxpr is used.
  names = defaultdict[Jaxpr, str](lambda: "jaxpr")
  jaxpr_counts = Counter[Jaxpr]()
  s = deque([jaxpr_to_print])
  while s:
    jaxpr = s.popleft()
    jaxpr_counts[jaxpr] += 1
    for eqn in jaxpr.eqns:
      # TODO(slebedev): Come up with a more elaborate heuristic for name=.
      name = eqn.params.get("name")
      if name is None:
        s.extend(jaxprs_in_params(eqn.params))
        continue
      name = name.strip("<>")  # <lambda> -> lambda
      for subjaxpr in jaxprs_in_params(eqn.params):
        s.append(subjaxpr)
        names.setdefault(subjaxpr, name)

  # Pull jaxprs occurring more than once to the top level. All names are
  # chosen before anything is printed, for two reasons:
  # (a) Uniqueness is checked against every name already taken, not only
  #     against the same base name. E.g. "f1", "f", "f" yields "f1", "f", "f2"
  #     rather than a duplicate "f1".
  # (b) Reserving the names in used_names up front stops variables, which are
  #     named lazily while printing, from receiving a hoisted jaxpr's name.
  top_level_names: dict[Jaxpr, str] = {}
  for jaxpr, c in jaxpr_counts.items():
    if c == 1:
      continue
    base = name = names[jaxpr]
    suffix = 1
    while name in context.used_names:
      name = f"{base}{suffix}"
      suffix += 1
    context.used_names.add(name)
    top_level_names[jaxpr] = name

  docs = []
  for jaxpr, name in top_level_names.items():
    docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings))
    # Register only after printing, so the definition itself is printed in
    # full rather than as a reference to its own name.
    context.top_level_jaxprs[jaxpr] = name
  docs.append(pp_jaxpr(jaxpr_to_print, context, settings))
  return pp.concat(docs)
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
class JaxprPpSettings(NamedTuple):
  """Options controlling jaxpr pretty-printing.

  Fields are positional; their order is part of the interface.
  """
  # Annotate each binder with its aval, e.g. ``a:f32[]``.
  print_shapes: bool = True
  # Annotate each equation with a summary of its source info.
  source_info: bool = False
  # Annotate each primitive name with the equation's name stack.
  name_stack: bool = False
  # Consult ``pp_eqn_rules`` for per-primitive printers.
  custom_pp_eqn_rules: bool = True
  # Append each jaxpr's effects after its outputs.
  print_effects: bool = False
|
2022-12-15 20:34:43 -08:00
|
|
|
|
Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.
Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)
But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.
So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.
In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:
```
import jax
def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
# b:f32[] = sin a
# c:f32[] = sin b
# d:f32[] = sin c
# in (d,) }
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```
Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!
So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```
PiperOrigin-RevId: 607664497
2024-02-16 05:56:45 -08:00
|
|
|
def _encode_digits_alphabetic(n: int) -> str:
|
|
|
|
if n == -1:
|
|
|
|
return '*'
|
|
|
|
s = ''
|
|
|
|
while len(s) == 0 or n:
|
|
|
|
n, i = n // 26, n % 26
|
|
|
|
s = chr(97 + i % 26) + s
|
|
|
|
return s
|
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
# A JaxprPpContext allows us to globally uniquify variable names within nested
# Jaxprs.
class JaxprPpContext:
  """Naming state shared by every printer contributing to one output.

  Attributes:
    var_names: maps each Var to its printed name. Names are assigned lazily,
      on first lookup, from ``_encode_digits_alphabetic(0)``,
      ``_encode_digits_alphabetic(1)``, ... A candidate is skipped if it is
      in ``used_names`` when it would be handed out.
    used_names: names that must not be given to variables.
    top_level_jaxprs: jaxprs already emitted as top-level bindings, mapped
      to their binding name.
  """
  var_names: defaultdict[Var, str]
  used_names: MutableSet[str]
  top_level_jaxprs: MutableMapping[Jaxpr, str]

  def __init__(self) -> None:
    self.top_level_jaxprs = {}
    self.used_names = set()
    self.var_names = defaultdict(self._fresh_names().__next__)

  def _fresh_names(self) -> Iterator[str]:
    # Reads self.used_names at every step, so names reserved after this
    # generator is created are still skipped.
    for index in it.count():
      candidate = _encode_digits_alphabetic(index)
      if candidate not in self.used_names:
        yield candidate
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
|
Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.
Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)
But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.
So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.
In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:
```
import jax
def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
# b:f32[] = sin a
# c:f32[] = sin b
# d:f32[] = sin c
# in (d,) }
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```
Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!
So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```
PiperOrigin-RevId: 607664497
2024-02-16 05:56:45 -08:00
|
|
|
def pp_var(v: Var | Literal, context: JaxprPpContext) -> str:
  """Return the printed form of ``v`` within ``context``.

  Literals and DropVars print via ``str``. Any other Var prints as the name
  ``context`` assigns it, followed by the Var's ``suffix``.
  """
  if not isinstance(v, (Literal, DropVar)):
    name = context.var_names[v]
    return f"{name}{v.suffix}"
  return str(v)
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str:
  """Return a short type string for ``a``.

  For a DShapedArray, Var dimensions are printed via ``pp_var``, giving
  e.g. ``f32[a,3]``. Any other aval uses ``a.str_short(short_dtypes=True)``.
  """
  if not isinstance(a, DShapedArray):
    return a.str_short(short_dtypes=True)
  dims = ",".join(pp_var(d, context) if type(d) is Var else str(d)
                  for d in a.shape)
  return f"{_short_dtype_name(a.dtype)}[{dims}]"
|
|
|
|
|
|
|
|
def pp_vars(vs: Sequence[Any], context: JaxprPpContext,
            *, separator="", print_shapes: bool = False) -> pp.Doc:
  """Lay out ``vs`` as a breakable group joined by ``separator``.

  When ``print_shapes`` is true, each var is followed by a type annotation
  ``:<aval>`` rendered with ``pp_aval``.
  """
  joiner = pp.text(separator) + pp.group(pp.brk())

  def format_var(v) -> pp.Doc:
    doc = pp.text(pp_var(v, context))
    if print_shapes:
      doc = doc + pp.type_annotation(pp.text(":" + pp_aval(v.aval, context)))
    return doc

  return pp.nest(2, pp.group(pp.join(joiner, [format_var(v) for v in vs])))
|
|
|
|
|
|
|
|
def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
  """Render one equation parameter as ``k=<value>``.

  Jaxprs, ClosedJaxprs (via their ``.jaxpr``) and tuples consisting only of
  those are printed structurally. Any other value prints via ``str``.
  """
  def value_doc() -> pp.Doc:
    if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
      return pp_jaxprs(v, context, settings)
    if isinstance(v, Jaxpr):
      return pp_jaxpr(v, context, settings)
    if isinstance(v, ClosedJaxpr):
      return pp_jaxpr(v.jaxpr, context, settings)
    return pp.text(str(v))

  rendered = value_doc()
  return pp.text(f'{k}=') + rendered
|
|
|
|
|
|
|
|
def pp_kv_pairs(kv_pairs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
  """Render ``(key, value)`` pairs as a bracketed, breakable ``[k=v ...]``.

  Returns ``pp.nil()`` when ``kv_pairs`` is empty.
  """
  if not kv_pairs:
    return pp.nil()
  entries = [pp_kv_pair(k, v, context, settings) for k, v in kv_pairs]
  opened = pp.nest(2, pp.concat([pp.text("["), pp.brk(""),
                                 pp.join(pp.brk(), entries)]))
  return pp.group(opened + pp.brk("") + pp.text("]"))
|
|
|
|
|
2023-02-09 11:02:24 -08:00
|
|
|
def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings
           ) -> pp.Doc:
  """Pretty-print one equation.

  When ``settings.custom_pp_eqn_rules`` is set, a rule registered for the
  primitive in ``pp_eqn_rules`` is used; otherwise ``_pp_eqn`` is used. If
  ``source_info_util.user_frame`` finds a user frame, the result is wrapped
  in ``pp.source_map``.
  """
  if settings.custom_pp_eqn_rules:
    rule = pp_eqn_rules.get(eqn.primitive, _pp_eqn)
  else:
    rule = _pp_eqn
  doc = rule(eqn, context, settings)  # type: ignore[operator]
  frame = source_info_util.user_frame(eqn.source_info)
  if frame is not None:
    doc = pp.source_map(doc, frame)
  return doc
|
2023-02-09 11:02:24 -08:00
|
|
|
|
2023-12-07 15:56:56 +00:00
|
|
|
def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc:
  """Generic equation printer: ``outs = prim[params] ins``.

  ``params`` selects which ``eqn.params`` keys to show, and in what order.
  It defaults to all keys, sorted. When the output side formats as empty,
  only the right-hand side is returned.
  """
  src = eqn.source_info
  eq_annotation = source_info_util.summarize(src) if settings.source_info else None
  if params is None:
    params = sorted(eqn.params)
  prim_annotation = f'[{src.name_stack}]' if settings.name_stack else None
  # Build in this order (outputs, params, inputs): variables are named on
  # first appearance.
  lhs = pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
  prim_doc = pp.text(eqn.primitive.name, annotation=prim_annotation)
  params_doc = pp_kv_pairs([(p, eqn.params[p]) for p in params], context, settings)
  args_doc = pp.text(" ") + pp_vars(eqn.invars, context)
  rhs = [prim_doc, params_doc, args_doc]
  if not lhs.format():
    return pp.concat(rhs)
  return pp.concat([lhs, pp.text(" = ", annotation=eq_annotation), *rhs])
|
2023-02-09 11:02:24 -08:00
|
|
|
# Signature of a per-primitive pretty-printing rule.
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext, JaxprPpSettings], pp.Doc]
# Per-primitive rules that pp_eqn consults when settings.custom_pp_eqn_rules
# is set.
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def pp_eqns(eqns, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
  """Print each equation with ``pp_eqn``, joined by ``;`` line breaks."""
  separator = pp.brk("; ")
  docs = [pp_eqn(eqn, context, settings) for eqn in eqns]
  return pp.join(separator, docs)
|
|
|
|
|
|
|
|
def _compact_eqn_should_include(k: str, v: Any) -> bool:
|
|
|
|
if k == 'branches': return False
|
|
|
|
if isinstance(v, (Jaxpr, ClosedJaxpr)): return False
|
|
|
|
if (isinstance(v, tuple) and
|
|
|
|
any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)):
|
|
|
|
return False
|
|
|
|
return True
|
|
|
|
|
2024-07-12 16:02:48 -07:00
|
|
|
def str_eqn_compact(primitive: Primitive, params: dict[Any, Any]) -> str:
  """Render ``primitive`` and its params compactly, for HLO metadata.

  A rule registered in ``custom_str_eqn_compact_rules`` takes precedence.
  Otherwise the result is ``name[k1=v1 k2=v2]``, showing only params accepted
  by ``_compact_eqn_should_include``, or just ``name`` if none are shown.
  """
  if primitive in custom_str_eqn_compact_rules:
    return custom_str_eqn_compact_rules[primitive](primitive, params)
  shown = [f"{k}={v}" for k, v in params.items()
           if _compact_eqn_should_include(k, v)]
  if not shown:
    return primitive.name
  return f"{primitive.name}[{' '.join(shown)}]"
|
2024-07-12 16:02:48 -07:00
|
|
|
# Per-primitive overrides for str_eqn_compact. Each rule receives the
# primitive and its params and returns the compact string.
custom_str_eqn_compact_rules: dict[
    Primitive, Callable[[Primitive, dict[Any, Any]], str]
] = {}
|
2022-12-15 20:34:43 -08:00
|
|
|
|
|
|
|
def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext,
                      settings: JaxprPpSettings) -> pp.Doc:
  """Lay out ``{ lambda consts ; ins. let <eqns> in (outs) }``.

  ``eqns_fn`` is a zero-argument callable producing the body Doc. It is
  called after the binders and before the outputs are printed, so variables
  are named in order of appearance.

  When ``settings.print_effects`` is set, the effects are appended as
  `` : { e1, e2 }``. Each JaxprInputEffect's ``input_index`` is first
  replaced by the Var it indexes in ``constvars + invars``.
  """
  show_shapes = settings.print_shapes
  constvars = pp_vars(jaxpr.constvars, context, print_shapes=show_shapes)
  invars = pp_vars(jaxpr.invars, context, print_shapes=show_shapes)
  eqns = eqns_fn()
  closer = ",)" if len(jaxpr.outvars) == 1 else ")"
  outvars = pp.concat([pp.text("("),
                       pp_vars(jaxpr.outvars, context, separator=","),
                       pp.text(closer)])

  eff_text = []
  if settings.print_effects:
    # TODO(sharadmv): render an entire signature here
    binders = [*jaxpr.constvars, *jaxpr.invars]
    eff_text.append(pp.text(" : { "))
    for i, eff in enumerate(jaxpr.effects):
      if i:
        eff_text.append(pp.text(", "))
      if isinstance(eff, effects.JaxprInputEffect):
        eff = eff.replace(input_index=binders[eff.input_index])
      eff_text.append(pp_effect(eff, context))
    eff_text.append(pp.text(" }"))

  header = [pp.text("{ "), pp.keyword(pp.text("lambda ")),
            constvars, pp.text("; "), invars,
            pp.text(". "), pp.keyword(pp.text("let"))]
  body = [pp.nest(2, pp.brk() + eqns), pp.brk()]
  footer = [pp.keyword(pp.text("in ")), outvars, pp.concat(eff_text)]
  return pp.group(pp.nest(2, pp.concat(header + body + footer)) + pp.text(" }"))
|
|
|
|
|
|
|
|
|
2023-12-08 20:10:08 +00:00
|
|
|
def pp_top_level_jaxpr(
    name: str,
    jaxpr: Jaxpr,
    context: JaxprPpContext,
    settings: JaxprPpSettings,
) -> pp.Doc:
  """Render a hoisted jaxpr as ``let <name> = <jaxpr> in`` plus a line break."""
  definition = pp_jaxpr(jaxpr, context, settings)
  return pp.concat([pp.text("let " + name + " = "), definition,
                    pp.text(" in"), pp.brk()])
|
|
|
|
|
|
|
|
|
|
|
|
def pp_jaxpr(
    jaxpr: Jaxpr,
    context: JaxprPpContext,
    settings: JaxprPpSettings,
) -> pp.Doc:
  """Print ``jaxpr`` in full, or by name if it was hoisted to the top level."""
  hoisted_name = context.top_level_jaxprs.get(jaxpr)
  if hoisted_name:
    return pp.text(hoisted_name)

  def print_eqns():
    return pp_eqns(jaxpr.eqns, context, settings)

  return pp_jaxpr_skeleton(jaxpr, print_eqns, context, settings)
|
|
|
|
|
2023-12-08 20:10:08 +00:00
|
|
|
|
2022-12-15 20:34:43 -08:00
|
|
|
def pp_jaxprs(jaxprs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
  """Print a sequence of (Closed)Jaxprs as a parenthesized, breakable group.

  ClosedJaxprs are printed via their underlying ``.jaxpr``.
  """
  unwrapped = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
  items = [pp_jaxpr(j, context, settings) for j in unwrapped]
  inner = pp.nest(2, pp.concat([pp.text('('), pp.brk(""),
                                pp.join(pp.brk(), items)]))
  return pp.group(inner + pp.brk("") + pp.text(')'))
|
|
|
|
|
|
|
|
|
|
|
|
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, context: JaxprPpContext,
                       settings: JaxprPpSettings) -> pp.Doc:
  """Pretty-print ``jaxpr`` showing only equations ``lo`` through ``hi - 1``.

  ``lo`` and ``hi`` are clamped into ``[0, len(jaxpr.eqns)]`` with
  ``hi >= lo``. Elided equations before or after the range appear as
  ``...``. If the clamped range is empty but the jaxpr has equations, the
  body is a single ``...``. Binders and outputs are always printed in full.
  """
  num_eqns = len(jaxpr.eqns)
  # Clamp lo from above too. Otherwise, for a jaxpr with no equations,
  # lo > 0 printed spurious leading and trailing ellipses.
  lo = min(max(lo, 0), num_eqns)
  hi = max(lo, min(hi, num_eqns))
  eqns = jaxpr.eqns[lo:hi]

  def eqns_fn():
    pps = []
    if not eqns and num_eqns:
      pps.append(pp.text('...'))
    else:
      if lo != 0:
        pps.append(pp.text('...'))
      pps.extend(pp_eqn(e, context, settings) for e in eqns)
      if hi != num_eqns:
        pps.append(pp.text('...'))
    return pp.join(pp.brk("; "), pps)

  return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, settings)
|
2023-02-17 12:45:39 -08:00
|
|
|
|
|
|
|
def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:
  """Render an effect, preferring its own ``_pretty_print(context)`` hook."""
  if not hasattr(effect, "_pretty_print"):
    return pp.text(str(effect))
  return effect._pretty_print(context)
|
2023-06-16 06:07:54 -07:00
|
|
|
|
|
|
|
# ------------------- Jaxpr util -------------------
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
def last_used(jaxpr: Jaxpr) -> dict[Var, JaxprEqn | None]:
  """Map each non-literal var read in ``jaxpr`` to its last reader.

  Contents of the result:
  - A var among ``jaxpr.outvars`` maps to None, since it is needed until the
    end.
  - Any other var read as an equation input maps to the last equation that
    reads it.
  - Vars that are never read are absent.
  """
  readers: dict[Var, JaxprEqn | None] = {
      v: None for v in jaxpr.outvars if not isinstance(v, Literal)}
  # Walk backwards so the first reader found for a var is its last use.
  for eqn in reversed(jaxpr.eqns):
    for v in eqn.invars:
      if isinstance(v, Literal) or v in readers:
        continue
      readers[v] = eqn
  return readers
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any],
                       last_used: dict[Var, JaxprEqn | None]):
  """Drop from ``env`` every input var of ``eqn`` whose last use is ``eqn``.

  ``last_used`` is a mapping like the one built by ``last_used(jaxpr)``.
  Literal inputs are ignored.
  """
  non_literals = {a for a in eqn.invars if not isinstance(a, Literal)}
  for var in non_literals:
    if last_used[var] is not eqn:
      continue
    # No later equation reads this var, so release env's reference to it.
    del env[var]
|
2024-07-22 17:56:10 -07:00
|
|
|
|
|
|
|
# Handler registries used in shard_map for converting avals: one for
# sharding, one for unsharding.
shard_aval_handlers = {}  # type: ignore
unshard_aval_handlers = {}  # type: ignore
|