rocm_jax/jax/_src/core.py
2024-09-20 07:52:33 -07:00

3412 lines
118 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import Counter, defaultdict, deque, namedtuple
from collections.abc import (Callable, Collection, Generator, Hashable,
Iterable, Iterator, Set, Sequence, MutableSet,
MutableMapping)
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
import functools
from functools import partial, partialmethod, total_ordering
import gc
import inspect
import itertools as it
import math
import operator
import threading
import types
from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar,
cast, overload, Union)
import warnings
from weakref import ref
import numpy as np
from jax._src import deprecations
from jax._src import dtypes
from jax._src import config
from jax._src import effects
from jax._src import compute_on
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, as_hashable_function,
HashableFunction, HashableWrapper, weakref_lru_cache,
partition_list, StrictABCMeta)
import jax._src.pretty_printer as pp
from jax._src.lib import jax_jit
from jax._src import traceback_util
from jax._src.typing import Array, DimSize, Shape
from jax._src import typing
from jax._src import xla_metadata as xla_metadata_lib
# Register this file with traceback_util's exclusion list.
# NOTE(review): presumably this hides frames from this file in user-facing
# tracebacks -- confirm against traceback_util.
traceback_util.register_exclusion(__file__)
# Within this module `zip` and `map` refer to `safe_zip` and `safe_map` from
# jax._src.util; the Python builtins stay reachable as `unsafe_zip` and
# `unsafe_map`.
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
# Integer flag read by `escaped_tracer_error` to decide how many stack frames
# to include in its message. The default comes from the
# JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES environment variable, falling back to 5.
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag(
'jax_tracer_error_num_traceback_frames',
config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
help='Set the number of stack frames in JAX tracer error messages.'
)
# -------------------- jaxprs --------------------
# Module-local aliases for the names defined in jax._src.effects, so the rest
# of this file can refer to them without the `effects.` prefix.
Effect = effects.Effect
Effects = effects.Effects
EffectTypeSet = effects.EffectTypeSet
# Shared "no effects" value; used below as the default `effects` of a Jaxpr.
no_effects: Effects = effects.no_effects
class JaxprDebugInfo(NamedTuple):
  """Optional debugging metadata that can be attached to a `Jaxpr`.

  `Jaxpr.__init__` asserts that `arg_names` has one entry per input variable
  and `result_paths` has one entry per output.
  """
  # What the function was traced for, e.g. 'jit', 'scan', etc
  traced_for: str
  # Where the traced function lives,
  # e.g. f'{fun.__name__} at {filename}:{lineno}'
  func_src_info: str | None
  # One (possibly missing) name per input, e.g. ('args[0]', ... )
  arg_names: tuple[str | None, ...]
  # One path per output, e.g. ('[0]', '[1]', ...)
  result_paths: tuple[str, ...]
class Jaxpr:
  """A sequence of equations together with its inputs, outputs and effects.

  All fields are stored in private slots and exposed through read-only
  properties; use `replace` to build a modified copy.
  """
  __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
               '_effects', '_debug_info']

  _constvars: list[Var]
  _invars: list[Var]
  _outvars: list[Atom]
  _eqns: list[JaxprEqn]
  _effects: Effects
  _debug_info: JaxprDebugInfo | None

  def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
               outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
               effects: Effects = no_effects,
               debug_info: JaxprDebugInfo | None = None):
    """
    Args:
      constvars: list of variables introduced for constants. Array constants are
        replaced with such variables while scalar constants are kept inline.
      invars: list of input variables. Together, `constvars` and `invars` are
        the inputs to the Jaxpr.
      outvars: list of output atoms.
      eqns: list of equations.
      effects: set of effects. The effects on a jaxpr are a superset of the
        union of the effects for each equation.
      debug_info: optional JaxprDebugInfo.
    """
    # Each sequence is copied into a fresh list owned by this instance.
    self._constvars = list(constvars)
    self._invars = list(invars)
    self._outvars = list(outvars)
    self._eqns = list(eqns)
    self._effects = effects
    self._debug_info = debug_info
    # When debug info is supplied it must describe every input and output.
    if debug_info:
      assert (len(debug_info.arg_names) == len(invars) and
              len(debug_info.result_paths) == len(outvars))

  @property
  def constvars(self) -> list[Var]:
    """Variables bound to the jaxpr's constants."""
    return self._constvars

  @property
  def invars(self) -> list[Var]:
    """Input variables."""
    return self._invars

  @property
  def outvars(self) -> list[Atom]:
    """Output atoms."""
    return self._outvars

  @property
  def eqns(self) -> list[JaxprEqn]:
    """Equations, in order."""
    return self._eqns

  @property
  def effects(self) -> Effects:
    """Effects recorded for this jaxpr."""
    return self._effects

  @property
  def debug_info(self) -> JaxprDebugInfo | None:
    """Attached debug info, or None."""
    return self._debug_info

  def __str__(self):
    return str(self.pretty_print())

  __repr__ = __str__

  def pretty_print(self, *, source_info=False, print_shapes=True,
                   custom_pp_eqn_rules=True, name_stack=False,
                   print_effects: bool = False, **kwargs):
    """Render this jaxpr via `pp_toplevel_jaxpr`.

    The named keyword options are passed to `pp_toplevel_jaxpr`; any remaining
    `kwargs` are passed to the resulting document's `format` method.
    """
    return pp_toplevel_jaxpr(
        self, source_info=source_info, print_shapes=print_shapes,
        custom_pp_eqn_rules=custom_pp_eqn_rules, name_stack=name_stack,
        print_effects=print_effects).format(**kwargs)

  def _repr_pretty_(self, p, cycle):
    # NOTE(review): looks like the IPython pretty-printer hook -- confirm.
    return p.text(self.pretty_print(use_color=True))

  def replace(self, **kwargs):
    """Return a new Jaxpr with the given fields overridden.

    Accepted keywords are `constvars`, `invars`, `outvars`, `eqns`, `effects`
    and `debug_info`; fields not mentioned are taken from `self`.

    Raises:
      ValueError: if any other keyword is passed.
    """
    field_names = ("constvars", "invars", "outvars", "eqns", "effects",
                   "debug_info")
    ctor_args = {name: kwargs.pop(name, getattr(self, name))
                 for name in field_names}
    # The new Jaxpr is built before leftover keywords are rejected.
    jaxpr = Jaxpr(**ctor_args)
    if kwargs:
      raise ValueError(f"Unknown keyword arguments: {kwargs}")
    return jaxpr
def join_effects(*effects: Effects) -> Effects:
  """Return the union of the given effect sets.

  With no arguments the shared `no_effects` value is returned; otherwise a
  new set is built.
  """
  if not effects:
    return no_effects
  return set().union(*effects)
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  """Yield every Jaxpr found directly among the values of `params`.

  Each value is inspected itself or, if it is a tuple, element by element.
  A `Jaxpr` is yielded as is; for a `ClosedJaxpr` its `.jaxpr` is yielded.
  Anything else is ignored.
  """
  for param in params.values():
    candidates = param if isinstance(param, tuple) else (param,)
    for item in candidates:
      if isinstance(item, Jaxpr):
        yield item
        continue
      if isinstance(item, ClosedJaxpr):
        yield item.jaxpr
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Generator for all subjaxprs found in the params of jaxpr.eqns.

  Only the params of `jaxpr`'s own equations are searched; the subjaxprs that
  are found are not themselves searched recursively.
  """
  for equation in jaxpr.eqns:
    for found in jaxprs_in_params(equation.params):
      yield found
class ClosedJaxpr:
  """A Jaxpr bundled with one constant value per entry of its `constvars`."""
  __slots__ = ['__weakref__', '_jaxpr', '_consts']

  _jaxpr: Jaxpr
  _consts: list[Any]

  def __init__(self, jaxpr: Jaxpr, consts: Sequence):
    """
    Args:
      jaxpr: the wrapped Jaxpr.
      consts: values for `jaxpr.constvars`; must have the same length. They
        are copied into a new list.
    """
    assert len(consts) == len(jaxpr.constvars)
    # assert not any(isinstance(c, Tracer) for c in consts)  # TODO(mattjj): enable
    self._jaxpr = jaxpr
    self._consts = list(consts)

  @property
  def jaxpr(self):
    """The wrapped Jaxpr."""
    return self._jaxpr

  @property
  def consts(self):
    """The constant values, in `constvars` order."""
    return self._consts

  @property
  def in_avals(self):
    """The `aval` of each input variable of the wrapped jaxpr."""
    return [invar.aval for invar in self.jaxpr.invars]

  @property
  def out_avals(self):
    """The `aval` of each output of the wrapped jaxpr."""
    return [outvar.aval for outvar in self.jaxpr.outvars]

  @property
  def literals(self):
    # backwards compatible alias
    return self.consts

  @property
  def eqns(self):
    """The equations of the wrapped jaxpr."""
    return self.jaxpr.eqns

  @property
  def effects(self) -> Effects:
    """The effects of the wrapped jaxpr."""
    return self.jaxpr.effects

  def map_jaxpr(self, f):
    """Return a ClosedJaxpr wrapping `f(self.jaxpr)` with the same consts."""
    mapped = f(self.jaxpr)
    return ClosedJaxpr(mapped, self.consts)

  def replace(self, *, jaxpr=None, consts=None):
    """Return a new ClosedJaxpr; arguments left as None keep current values."""
    if jaxpr is None:
      jaxpr = self.jaxpr
    if consts is None:
      consts = self.consts
    return ClosedJaxpr(jaxpr, consts)

  def __str__(self):
    return str(self.jaxpr)

  def __repr__(self):
    return repr(self.jaxpr)

  def pretty_print(self, *, source_info=False, print_shapes=True,
                   name_stack=False, custom_pp_eqn_rules=True,
                   print_effects=False, **kwargs):
    """Forward to `self.jaxpr.pretty_print` with the same options."""
    options = dict(
        source_info=source_info,
        print_shapes=print_shapes,
        name_stack=name_stack,
        custom_pp_eqn_rules=custom_pp_eqn_rules,
        print_effects=print_effects,
    )
    return self.jaxpr.pretty_print(**options, **kwargs)

  def _repr_pretty_(self, p, cycle):
    # NOTE(review): looks like the IPython pretty-printer hook -- confirm.
    return p.text(self.pretty_print(use_color=True))
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
  """Evaluate `closed_jaxpr` on `args` using `eval_jaxpr`.

  The jaxpr's own consts are supplied as the constants. The function is
  wrapped in `curry`.
  NOTE(review): presumably `jaxpr_as_fun(closed_jaxpr)` therefore yields a
  callable taking `*args` -- confirm against `curry` in jax._src.util.
  """
  jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
  return eval_jaxpr(jaxpr, consts, *args)
class JaxprEqnContext:
  """Ambient settings captured for an equation, re-enterable via `manager`."""

  def __init__(self, compute_type: str | None, threefry_partitionable: bool,
               xla_metadata=None):
    self.compute_type = compute_type
    self.threefry_partitionable = threefry_partitionable
    self.xla_metadata = xla_metadata
    # Pairs of (context-manager factory, value to call it with); `manager`
    # enters them in this order.
    self._managers = [
        (compute_on.extend_compute_type, self.compute_type),
        (config.threefry_partitionable.__call__, self.threefry_partitionable),
        (xla_metadata_lib.set_xla_metadata, self.xla_metadata),
    ]

  @property
  @contextmanager
  def manager(self):
    """Context manager that enters every entry of `self._managers`."""
    with ExitStack() as stack:
      for make_cm, value in self._managers:
        stack.enter_context(make_cm(value))
      yield

  def __repr__(self):
    fields = (
        f"compute_type={self.compute_type}",
        f"threefry_partitionable={self.threefry_partitionable}",
        f"xla_metadata={self.xla_metadata}",
    )
    return "JaxprEqnContext(" + ", ".join(fields) + ")"
class JaxprEqn:
  """One equation of a jaxpr: a primitive applied to inputs, giving outputs."""
  invars: list[Atom]
  outvars: list[Var]
  primitive: Primitive
  params: dict[str, Any]
  effects: Effects
  source_info: source_info_util.SourceInfo
  ctx: JaxprEqnContext

  # It's slightly faster to use a class with __slots__ than a NamedTuple.
  __slots__ = ['invars', 'outvars', 'primitive', 'params', 'effects',
               'source_info', 'ctx']

  def __init__(self, invars, outvars, primitive, params, effects, source_info,
               ctx):
    # The arguments are stored as given; nothing is copied or validated here.
    self.invars, self.outvars = invars, outvars
    self.primitive, self.params = primitive, params
    self.effects = effects
    self.source_info = source_info
    self.ctx = ctx

  def __repr__(self):
    printed = pp_eqn(self, JaxprPpContext(), JaxprPpSettings())
    return str(printed).rstrip()

  def replace(
      self,
      invars: list[Atom] | None = None,
      outvars: list[Var] | None = None,
      primitive: Primitive | None = None,
      params: dict[str, Any] | None = None,
      effects: Effects | None = None,
      source_info: source_info_util.SourceInfo | None = None,
      ctx: JaxprEqnContext | None = None
  ):
    """Return a new JaxprEqn; arguments left as None keep the current field.

    Only `None` means "keep": an empty list, dict or set passed explicitly
    replaces the field.
    """
    def pick(override, current):
      return current if override is None else override

    return JaxprEqn(
        pick(invars, self.invars),
        pick(outvars, self.outvars),
        pick(primitive, self.primitive),
        pick(params, self.params),
        pick(effects, self.effects),
        pick(source_info, self.source_info),
        pick(ctx, self.ctx),
    )
# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
                  ctx=None):
  """Build a JaxprEqn, filling in `source_info` and `ctx` when not supplied.

  A falsy `source_info` is replaced by `source_info_util.new_source_info()`,
  and a falsy `ctx` by a JaxprEqnContext built from the current compute type,
  the `threefry_partitionable` config value and the current XLA metadata.
  When `config.enable_checks` is set, the inputs are asserted to be Vars or
  Literals and the outputs to be Vars.
  """
  if not source_info:
    source_info = source_info_util.new_source_info()
  if not ctx:
    ctx = JaxprEqnContext(
        compute_on.current_compute_type(),
        config.threefry_partitionable.value,
        xla_metadata_lib.current_xla_metadata())
  if config.enable_checks.value:
    assert all(isinstance(atom, (Var, Literal)) for atom in invars)
    assert all(isinstance(out, Var) for out in outvars)
  return JaxprEqn(invars, outvars, primitive, params, effects, source_info, ctx)
# Module-wide counter; every Var takes the next value as its `count`.
_var_counter = it.count()

@total_ordering
class Var:
  """A jaxpr variable, carrying a creation count, a print suffix and an aval."""
  __slots__ = ["count", "suffix", "aval"]

  count: int
  suffix: str
  aval: AbstractValue

  def __init__(self, suffix: str, aval: AbstractValue):
    # The counter is advanced before `raise_to_shaped` is called.
    self.count = next(_var_counter)
    self.suffix = suffix
    self.aval = raise_to_shaped(aval)

  # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not
  # care about variable ordering, but the downstream package kfac_jax does.
  def __lt__(self, other):
    # Variables are ordered by creation count.
    return self.count < other.count

  def __repr__(self):
    ident = id(self)
    short_aval = self.aval.str_short()
    return f'Var(id={ident}){self.suffix}:{short_aval}'
def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]:
  """Return a factory of fresh variables that all share `suffix`.

  The result is `Var` with its `suffix` argument pre-bound, so calling it with
  an aval creates a new, distinct Var.
  """
  make_var = partial(Var, suffix)
  return make_var
# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that
# the assignment is dropped, i.e. that an expression's output value will never
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
  """A Var with an empty suffix that prints as `_`."""

  def __init__(self, aval: AbstractValue):
    super().__init__(suffix='', aval=aval)

  def __repr__(self):
    return '_'
class Literal:
  """A constant value appearing inline in a jaxpr, paired with its aval.

  Instances are unhashable (`__hash__` is None); a hash of the wrapped value
  is kept in the `hash` slot instead when one can be computed.
  """
  __slots__ = ["val", "aval", "hash"]

  val: Any
  aval: AbstractValue
  hash: int | None

  def __init__(self, val, aval):
    self.val = val
    self.aval = aval
    try:
      self.hash = hash(val)
    except TypeError:
      # Unhashable value. Types outside `literalable_types` leave the `hash`
      # slot unset, which `__repr__` detects with hasattr().
      if type(val) not in literalable_types:
        return
      try:
        key = (val.item(), val.dtype)
        self.hash = hash(key)
      except (TypeError, AttributeError, ValueError):
        self.hash = None

  __hash__ = None  # type: ignore

  def __repr__(self):
    has_hash_slot = hasattr(self, 'hash')
    return f'{self.val}' if has_hash_slot else f'Literal(val={self.val})'
# Types that `Literal.__init__` will hash through `(val.item(), val.dtype)`
# when the value itself is unhashable. Empty here; presumably populated by
# other modules -- confirm at registration sites.
literalable_types: set[type] = set()
# An atom is either a variable or an inline literal.
Atom = Union[Var, Literal]
class Primitive:
  """A named operation whose rules are attached through the `def_*` methods.

  `impl` and `abstract_eval` raise NotImplementedError until replaced on the
  instance via `def_impl` / `def_abstract_eval` /
  `def_effectful_abstract_eval`.
  """
  name: str
  # set for multi-output primitives.
  multiple_results: bool = False
  # set for call primitives processed in final style.
  call_primitive: bool = False
  # set for map primitives processed in final style.
  map_primitive: bool = False

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return f'{self.name}'

  def bind(self, *args, **params):
    """Apply the primitive under the top trace found for `args`."""
    # The argument check only runs when config.enable_checks is on.
    assert (not config.enable_checks.value or
            all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
    top_trace = find_top_trace(args)
    return self.bind_with_trace(top_trace, args, params)

  def bind_with_trace(self, trace, args, params):
    """Raise `args` into `trace`, process the primitive, lower the outputs."""
    with pop_level(trace.level):
      raised = map(trace.full_raise, args)
      out = trace.process_primitive(self, raised, params)
    if self.multiple_results:
      return map(full_lower, out)
    return full_lower(out)

  def def_impl(self, impl):
    """Install `impl` as this primitive's evaluation rule and return it."""
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    """Install an effect-free abstract evaluation rule and return it.

    The stored rule is wrapped so that it also returns `no_effects`; the
    unwrapped function is what gets returned to the caller.
    """
    wrapped = _effect_free_abstract_eval(abstract_eval)
    self.abstract_eval = wrapped
    return abstract_eval

  def def_effectful_abstract_eval(self, effectful_abstract_eval):
    """Install an abstract evaluation rule as given and return it."""
    self.abstract_eval = effectful_abstract_eval
    return effectful_abstract_eval

  def def_custom_bind(self, bind):
    """Replace `bind` on this instance and return the new function."""
    self.bind = bind
    return bind

  def impl(self, *args, **params):
    msg = "Evaluation rule for '{}' not implemented".format(self.name)
    raise NotImplementedError(msg)

  def abstract_eval(self, *args, **params):
    msg = "Abstract evaluation for '{}' not implemented".format(self.name)
    raise NotImplementedError(msg)

  def get_bind_params(self, params):
    """Return `(subfuns, bind_params)`; by default no subfuns and `params`."""
    return [], params
def _effect_free_abstract_eval(abstract_eval):
  """Wrap `abstract_eval` so that it returns `(result, no_effects)`."""
  def abstract_eval_(*args, **kwargs):
    result = abstract_eval(*args, **kwargs)
    return result, no_effects
  return abstract_eval_
# -------------------- lifting --------------------
# TODO(mattjj): replace this approach with a primitive-keyed table of rules
def traverse_jaxpr_params(f, params):
  """Applies f to each jaxpr parameter and returns a dict of returned values.

  A parameter value is examined itself or, if it is a tuple or list, element
  by element. Only objects whose type is exactly `Jaxpr` or `ClosedJaxpr` are
  passed to `f`. The result maps the parameter name to `f`'s return value;
  when one parameter holds several jaxprs, `f` is called on each and the last
  result is the one kept under that name.
  """
  results = {}
  for name, param in params.items():
    candidates = param if isinstance(param, (tuple, list)) else [param]
    for candidate in candidates:
      if type(candidate) in (Jaxpr, ClosedJaxpr):
        results[name] = f(candidate)
  return results
def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[Any]:
  """Evaluate `jaxpr` equation by equation by binding each primitive.

  Args:
    jaxpr: the Jaxpr to evaluate.
    consts: values for `jaxpr.constvars`.
    *args: values for `jaxpr.invars`.
    propagate_source_info: if true, each equation's traceback is passed to
      `source_info_util.user_context` while its primitive is bound; otherwise
      None is passed.

  Returns:
    The values of `jaxpr.outvars`.
  """
  env: dict[Var, Any] = {}

  def read(v: Atom) -> Any:
    # Literals carry their value inline; variables are looked up in `env`.
    if isinstance(v, Literal):
      return v.val
    return env[v]

  def write(v: Var, val: Any) -> None:
    checks_on = config.enable_checks.value and not config.dynamic_shapes.value
    if checks_on:
      assert typecheck(v.aval, val), (v.aval, val)
    env[v] = val

  # `map` is this module's `safe_map`, so these calls run eagerly.
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  last_use = last_used(jaxpr)
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    subfuns, bind_params = prim.get_bind_params(eqn.params)
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    traceback = eqn.source_info.traceback if propagate_source_info else None
    with source_info_util.user_context(
        traceback, name_stack=name_stack), eqn.ctx.manager:
      in_vals = map(read, eqn.invars)
      ans = prim.bind(*subfuns, *in_vals, **bind_params)
    if prim.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
    # NOTE(review): presumably drops env entries that are no longer needed --
    # confirm against clean_up_dead_vars.
    clean_up_dead_vars(eqn, env, last_use)
  return map(read, jaxpr.outvars)
# -------------------- tracing --------------------
TracerType = TypeVar('TracerType', bound='Tracer')

class Trace(Generic[TracerType]):
  """Base class for traces; subclasses override the `pure`/`lift`/`process_*` hooks.

  A trace records the MainTrace it belongs to, that main trace's level, and a
  sublevel.
  """
  __slots__ = ['main', 'level', 'sublevel']

  main: MainTrace
  level: int
  sublevel: Sublevel

  def __init__(self, main: MainTrace, sublevel: Sublevel) -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel

  def full_raise(self, val) -> TracerType:
    """Bring `val` into this trace, or raise if it cannot legally be lifted.

    Non-tracers go through `pure` (after `dimension_as_value()` if they have
    that method). Tracers are returned as is, sublifted or lifted depending on
    how their trace compares with this one; any other combination raises the
    error built by `escaped_tracer_error`.
    """
    if not isinstance(val, Tracer):
      # This check is only applied to non-Tracers, because the hasattr() is
      # expensive (Tracer.__getattr__) in the common case that val is a Tracer.
      if not hasattr(val, "dimension_as_value"):
        return self.pure(val)
      val = val.dimension_as_value()  # Used for shape_poly._DimExpr
      if not isinstance(val, Tracer):
        return self.pure(val)
    val._assert_live()
    level, sublevel = self.level, self.sublevel
    src = val._trace
    if src.main is self.main:
      # Same main trace: only the sublevels can differ.
      if src.sublevel == sublevel:
        return cast(TracerType, val)
      if src.sublevel < sublevel:
        return self.sublift(val)
      raise escaped_tracer_error(
          val, f"Can't lift sublevels {src.sublevel} to {sublevel}")
    if src.level < level:
      if src.sublevel > sublevel:
        raise escaped_tracer_error(
            val, f"Incompatible sublevel: {src}, {(level, sublevel)}")
      return self.lift(val)
    if src.level > level:
      raise escaped_tracer_error(
          val, f"Can't lift level {val} to {self}")
    # Remaining case: src.level == self.level but a different main trace.
    raise escaped_tracer_error(
        val, f"Different traces at same level: {val}, {self}")

  def pure(self, val) -> TracerType:
    raise NotImplementedError("must override")

  def lift(self, tracer) -> TracerType:
    raise NotImplementedError("must override")

  def sublift(self, tracer) -> TracerType:
    raise NotImplementedError("must override")

  def process_primitive(self, primitive, tracers, params):
    raise NotImplementedError("must override")

  def __repr__(self):
    cls_name = self.__class__.__name__
    return '{}(level={}/{})'.format(cls_name, self.level, self.sublevel)

  def process_call(self, call_primitive, f, tracers, params):
    raise NotImplementedError(
        f"{type(self)} must override process_call to handle call-like "
        "primitives")

  def process_map(self, map_primitive, f, tracers, params):
    raise NotImplementedError(
        f"{type(self)} must override process_map to handle map-like "
        "primitives")

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
                              symbolic_zeros):
    raise NotImplementedError(
        f"{type(self)} must override process_custom_jvp_call "
        "to handle custom_jvp primitives")

  def process_custom_transpose(self, prim, call, tracers, **params):
    raise NotImplementedError(
        f"{type(self)} must override process_custom_transpose "
        "to handle custom_transpose_call primitives")

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                              out_trees, symbolic_zeros):
    raise NotImplementedError(
        f"{type(self)} must override process_custom_vjp_call "
        "to handle custom_vjp primitives")
def raise_as_much_as_possible(tracer) -> Tracer:
  """Raise `tracer` through every trace from the dynamic one upwards.

  Starting at the stack entry that is the current dynamic MainTrace, each
  main trace is instantiated at the current sublevel, and `tracer` is
  full-raised into it whenever it is not a Tracer or sits at a lower level.
  """
  # Find effective bottom of trace stack (highest dynamic Trace on the stack).
  stack_state = thread_local_state.trace_state.trace_stack
  mains = stack_state.stack
  start = next(pos for pos, main in enumerate(mains)
               if main is stack_state.dynamic)

  # Only pay attention to effective part of trace stack.
  # Lift tracer into everything in the effective stack higher than its level
  for main in mains[start:]:
    trace = main.with_cur_sublevel()
    needs_raise = (not isinstance(tracer, Tracer)
                   or tracer._trace.level < trace.level)
    if needs_raise:
      tracer = trace.full_raise(tracer)

  return tracer
def escaped_tracer_error(tracer, detail=None):
  """Build (but do not raise) an UnexpectedTracerError describing a leak.

  The message always names the tracer's aval and type. It additionally
  includes the traced function when the tracer has `_debug_info`, creation
  site information when it has `_line_info`, and `detail` when given.
  """
  num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value
  parts = [
      'Encountered an unexpected tracer. A function transformed by JAX '
      'had a side effect, allowing for a reference to an intermediate value '
      f'with type {tracer.aval.str_short()} wrapped in a '
      f'{type(tracer).__name__} to escape the scope of the transformation.\n'
      'JAX transformations require that functions explicitly return their '
      'outputs, and disallow saving intermediate values to global state.'
  ]

  dbg = getattr(tracer, '_debug_info', None)
  if dbg is not None:
    parts.append('\nThe function being traced when the value leaked was '
                 f'{dbg.func_src_info} traced for {dbg.traced_for}.')

  line_info = getattr(tracer, '_line_info', None)
  if line_info is not None:
    divider = '\n' + '-'*30 + '\n'
    parts.append(divider)
    parts.append('The leaked intermediate value was created on line '
                 f'{source_info_util.summarize(line_info)}. ')
    parts.append(divider)
    # The multi-frame summary is only added when the flag asks for frames.
    if num_frames > 0:
      parts.append(f'When the value was created, the final {num_frames} stack '
                   'frames (most recent last) excluding JAX-internal frames were:')
      parts.append(divider + source_info_util.summarize(
          line_info, num_frames=num_frames) + divider)

  parts.append('\nTo catch the leak earlier, try setting the environment variable '
               'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
               'manager.')
  if detail:
    parts.append(f'Detail: {detail}')
  return UnexpectedTracerError(''.join(parts))
def check_scalar_conversion(arr: Array):
  """Raise TypeError unless `arr` is zero-dimensional."""
  if not arr.ndim > 0:
    return
  raise TypeError("Only scalar arrays can be converted to Python scalars; "
                  f"got {arr.ndim=}")
def check_integer_conversion(arr: Array):
  """Raise TypeError unless `arr` has shape `()` and an integer dtype."""
  is_scalar = arr.shape == ()
  # The dtype is only inspected for scalars.
  if not is_scalar or not dtypes.issubdtype(arr.dtype, np.integer):
    raise TypeError("Only integer scalar arrays can be converted to a scalar index.")
def check_bool_conversion(arr: Array):
  """Raise ValueError unless `arr` holds exactly one element."""
  n_elements = arr.size
  if n_elements == 0:
    raise ValueError("The truth value of an empty array is ambiguous. Use"
                     " `array.size > 0` to check that an array is not empty.")
  if n_elements > 1:
    raise ValueError("The truth value of an array with more than one element"
                     " is ambiguous. Use a.any() or a.all()")
def _aval_property(name):
  """Return a read-only property that forwards `name` to `self.aval`."""
  def fget(self):
    return getattr(self.aval, name)
  return property(fget)
class Tracer(typing.Array, metaclass=StrictABCMeta):
  """Base class for traced array values.

  Most array behaviour is forwarded to `self.aval`, which subclasses must
  provide. Operations that only make sense for a materialized array raise a
  descriptive error instead.
  """
  __array_priority__ = 1000
  __slots__ = ['_trace', '_line_info']

  dtype = _aval_property('dtype')
  ndim = _aval_property('ndim')
  size = _aval_property('size')
  shape = _aval_property('shape')

  def __hash__(self):
    # TODO(jakevdp) finalize this deprecation and set __hash__ = None
    # Warning added 2024-06-13
    if deprecations.is_accelerated('tracer-hash'):
      raise TypeError(f"unhashable type: {type(self)}")
    # Use FutureWarning rather than DeprecationWarning because hash is likely
    # not called directly by the user, so we want to warn at all stacklevels.
    warnings.warn(
      f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an"
      " error in a future JAX release.", category=FutureWarning)
    return super().__hash__()

  def __init__(self, trace: Trace):
    self._trace = trace

  def _error_repr(self):
    """Short description of this tracer for use in error messages."""
    if self.aval is None:
      return f"traced array with aval {self.aval}"
    return f"traced array with shape {raise_to_shaped(self.aval).str_short()}"

  def __array__(self, *args, **kw):
    raise TracerArrayConversionError(self)

  def __dlpack__(self, *args, **kw):
    raise ConcretizationTypeError(self,
      f"The __dlpack__() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def tolist(self):
    raise ConcretizationTypeError(self,
      f"The tolist() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def tobytes(self, order="C"):
    del order
    raise ConcretizationTypeError(self,
      f"The tobytes() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def __iter__(self):
    return iter(self.aval._iter(self))

  def __reversed__(self):
    return iter(self[::-1])

  def __len__(self):
    return self.aval._len(self)

  @property
  def sharding(self):
    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
    # we raise an AttributeError so that hasattr() and getattr() work as expected.
    raise AttributeError(self,
      f"The 'sharding' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def device(self):
    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
    # we raise an AttributeError so that hasattr() and getattr() work as expected.
    raise AttributeError(self,
      f"The 'device' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def addressable_shards(self):
    raise ConcretizationTypeError(self,
      f"The 'addressable_shards' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def at(self):
    return self.aval.at.fget(self)

  @property
  def aval(self):
    raise NotImplementedError("must override")

  def _assert_live(self) -> None:
    pass  # Override for liveness checking

  def get_referent(self) -> Any:
    return self  # Override for object equivalence checking

  # The scalar conversions below validate the array first and then delegate
  # to the corresponding hook on the aval, returning whatever it returns.
  def __bool__(self):
    check_bool_conversion(self)
    return self.aval._bool(self)

  def __int__(self):
    check_scalar_conversion(self)
    return self.aval._int(self)

  def __float__(self):
    check_scalar_conversion(self)
    return self.aval._float(self)

  def __complex__(self):
    check_scalar_conversion(self)
    return self.aval._complex(self)

  def __hex__(self):
    check_integer_conversion(self)
    return self.aval._hex(self)

  def __oct__(self):
    check_integer_conversion(self)
    return self.aval._oct(self)

  def __index__(self):
    check_integer_conversion(self)
    # Return the aval's result like the other conversions do. The previous
    # `raise self.aval._index(self)` failed with "exceptions must derive from
    # BaseException" whenever the aval hook returned a value.
    return self.aval._index(self)

  # raises a useful error on attempts to pickle a Tracer.
  def __reduce__(self):
    raise ConcretizationTypeError(
      self, ("The error occurred in the __reduce__ method, which may "
             "indicate an attempt to serialize/pickle a traced value."))

  # raises the better error message from ShapedArray
  def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)

  # NumPy also only looks up special methods on classes.
  def __array_module__(self, types): return self.aval._array_module(self, types)

  def __getattr__(self, name):
    """Fall back to the aval for attributes not found on the tracer.

    `aval_property` and `aval_method` values found on the aval are bound to
    this tracer; anything else is returned unchanged.
    """
    # if the aval property raises an AttributeError, gets caught here
    assert not config.enable_checks.value or name != "aval"

    try:
      attr = getattr(self.aval, name)
    except AttributeError as err:
      raise AttributeError(
          f"{self.__class__.__name__} has no attribute {name}"
      ) from err
    else:
      t = type(attr)
      if t is aval_property:
        return attr.fget(self)
      elif t is aval_method:
        return types.MethodType(attr.fun, self)
      else:
        return attr

  def _pretty_print(self):
    base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
    contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
                 else pp.text(repr(attr))) for name, attr in self._contents()]
    if contents:
      base = pp.group(pp.nest(2, pp.concat([
        base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
          pp.text(f'{name} = ') + pp_payload
          for name, pp_payload in contents])
      ])))
    return base

  def __repr__(self):
    return self._pretty_print().format()

  def _contents(self):
    """Return (name, value) pairs for `__slots__`, or () if any is unset."""
    try:
      return [(name, getattr(self, name)) for name in self.__slots__]
    except AttributeError:
      return ()

  def _origin_msg(self) -> str:
    return ""

  # Methods that are only valid for materialized arrays
  def addressable_data(self, index):
    raise ConcretizationTypeError(self,
      f"The addressable_data() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def block_until_ready(self):
    # Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
    raise AttributeError(self,
      f"The 'block_until_ready' method is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def copy_to_host_async(self):
    # Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
    raise AttributeError(self,
      f"The 'copy_to_host_async' method is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  def delete(self):
    raise ConcretizationTypeError(self,
      f"The delete() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def devices(self):
    raise ConcretizationTypeError(self,
      f"The devices() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def global_shards(self):
    raise ConcretizationTypeError(self,
      f"The global_shards property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def is_deleted(self):
    raise ConcretizationTypeError(self,
      f"The is_deleted() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def is_fully_addressable(self):
    raise ConcretizationTypeError(self,
      f"The is_fully_addressable property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def is_fully_replicated(self):
    raise ConcretizationTypeError(self,
      f"The is_fully_replicated property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def on_device_size_in_bytes(self):
    raise ConcretizationTypeError(self,
      f"The on_device_size_in_bytes() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def traceback(self):
    raise ConcretizationTypeError(self,
      f"The traceback property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def unsafe_buffer_pointer(self):
    raise ConcretizationTypeError(self,
      f"The unsafe_buffer_pointer() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals. `Tracer.__getattr__` checks for
# these exact types: an `aval_property` has its `fget` called with the tracer,
# and an `aval_method` has its `fun` bound to the tracer as a method.
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
class EvalTrace(Trace):
  """Trace that evaluates primitives directly through their `impl` rule."""
  # See comments in https://github.com/jax-ml/jax/pull/3370
  def pure(self, x):
    # Values pass through unchanged; lifting and sublifting do the same.
    return x
  lift = sublift = pure

  def process_primitive(self, primitive, tracers, params):
    if not config.debug_key_reuse.value:
      return primitive.impl(*tracers, **params)
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks  # pytype: disable=import-error
    return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)

  def process_call(self, primitive, f, tracers, params):
    if not config.debug_key_reuse.value:
      return primitive.impl(f, *tracers, **params)
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks  # pytype: disable=import-error
    return call_impl_with_key_reuse_checks(primitive, primitive.impl, f, *tracers, **params)
  process_map = process_call

  # The three custom_* handlers below ignore the custom rules and simply call
  # the wrapped function inside a new sublevel.
  def process_custom_transpose(self, primitive, call, tracers, **_):
    del primitive, _
    with new_sublevel():
      result = call.call_wrapped(*tracers)
      return result

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **_):
    del primitive, jvp, _  # Unused.
    with new_sublevel():
      result = fun.call_wrapped(*tracers)
      return result

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_):  # pytype: disable=signature-mismatch
    del primitive, fwd, bwd, _  # Unused.
    with new_sublevel():
      result = fun.call_wrapped(*tracers)
      return result
class MainTrace:
  """A trace-stack entry: a level, a Trace subclass and constructor payload."""
  level: int
  trace_type: type[Trace]
  payload: dict[str, Any]

  def __init__(self, level, trace_type, **payload) -> None:
    self.level, self.trace_type = level, trace_type
    self.payload = payload

  def __repr__(self) -> str:
    type_name = self.trace_type.__name__
    return f"MainTrace({self.level},{type_name})"

  def __hash__(self) -> int:
    # The payload takes part in equality but not in the hash.
    key = (self.level, self.trace_type)
    return hash(key)

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, MainTrace):
      return False
    same_slot = (self.level == other.level and
                 self.trace_type == other.trace_type)
    return same_slot and self.payload == other.payload

  def with_cur_sublevel(self):
    """Instantiate `trace_type` for this main trace at the current sublevel."""
    sublevel = cur_sublevel()
    return self.trace_type(self, sublevel, **self.payload)
class TraceStack:
  """Stack of MainTraces plus a reference to the current dynamic one.

  A new stack holds a single level-0 EvalTrace main trace, which is also the
  initial `dynamic` entry.
  """
  # See comments in https://github.com/jax-ml/jax/pull/3370
  stack: list[MainTrace]
  dynamic: MainTrace

  def __init__(self):
    eval_trace = MainTrace(0, EvalTrace)
    self.stack = [eval_trace]
    self.dynamic = eval_trace

  def next_level(self) -> int:
    """Level a newly pushed main trace would occupy (the current depth)."""
    return len(self.stack)

  def push(self, main_trace: MainTrace) -> None:
    self.stack.append(main_trace)

  def pop(self) -> None:
    self.stack.pop()

  def __repr__(self) -> str:
    # One line per entry, top of the stack first. The lines are joined
    # explicitly: `map` in this module returns a list, and formatting that
    # list directly would embed its Python repr in the output.
    stack_str = ''.join(map(' {}\n'.format, self.stack[::-1]))
    return f'Trace stack\n{stack_str}\n{self.dynamic}'

  def copy(self):
    """Return a new TraceStack with a copied stack list and same `dynamic`."""
    new = self.__new__(TraceStack)
    new.stack = self.stack[:]
    new.dynamic = self.dynamic
    return new
@total_ordering
class Sublevel:
  """Integer sublevel, comparable only with other `Sublevel` instances.

  Comparisons against any other type (including subclasses) are False.
  """

  def __init__(self, level: int):
    self.level = level

  def __repr__(self):
    return str(self.level)

  def __eq__(self, other):
    if type(other) is not Sublevel:
      return False
    return self.level == other.level

  def __lt__(self, other):
    if type(other) is not Sublevel:
      return False
    return self.level < other.level
# Record type for entries of `TraceState.axis_env`: an axis name, its size and
# the associated main trace.
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
# Axis names may be any hashable object.
AxisName = Hashable
# Unique sentinel object.
# NOTE(review): presumably stands for "no axis name" -- confirm at use sites.
no_axis_name = object()
class TraceState:
  """The tracing state: the trace stack, the sublevel stack and the axis env."""
  trace_stack: TraceStack
  substack: list[Sublevel]
  axis_env: list[AxisEnvFrame]

  def __init__(self) -> None:
    # Fresh state: a new TraceStack, sublevel 0 and no axis frames.
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.axis_env = []

  def copy(self):
    """Return a TraceState whose three containers are copies of this one's."""
    duplicate = self.__new__(TraceState)
    duplicate.trace_stack = self.trace_stack.copy()
    duplicate.substack = list(self.substack)
    duplicate.axis_env = list(self.axis_env)
    return duplicate
def _update_thread_local_jit_state(dynamic):
  """Publish `dynamic`'s level and trace type as the thread-local jit state."""
  config.update_thread_local_jit_state(
      dynamic_trace_state=(dynamic.level, dynamic.trace_type))
# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class ThreadLocalState(threading.local):
  """Thread-local container holding a `TraceState` in `trace_state`."""

  def __init__(self):
    self.trace_state = TraceState()

# Module-wide instance through which the trace state is accessed.
thread_local_state = ThreadLocalState()
def _initialize_jax_jit_thread_local_state():
  """Initializes the C++ thread-local context.

  When the user spawns threads, the C++ `jax_jit.thread_local_state` is None.

  The C++ accessor calls this function if it realizes the thread_local_state
  is None (which means it's not yet initialized for this thread).

  This function does not live in `config.py`, to prevent circular imports.
  """
  if jax_jit.thread_local_state().extra_jit_context is not None:
    return  # already initialized for this thread
  current = thread_local_state.trace_state.trace_stack.dynamic
  config.update_thread_local_jit_state(
      dynamic_trace_state=(current.level, current.trace_type))


jax_jit.set_thread_local_state_initialization_callback(
    _initialize_jax_jit_thread_local_state)
def trace_state_clean() -> bool:
  """Report whether this thread's trace state equals a freshly created one."""
  state = thread_local_state.trace_state
  if not state.substack == [Sublevel(0)]:
    return False
  if not state.axis_env == []:
    return False
  if not state.trace_stack.stack == [MainTrace(0, EvalTrace)]:
    return False
  return state.trace_stack.dynamic == MainTrace(0, EvalTrace)
def reset_trace_state() -> bool:
  """Resets the global trace state and returns True if it was already clean."""
  if trace_state_clean():
    return True
  # Re-run __init__ in place so existing references to the TraceState object
  # observe the reset.
  thread_local_state.trace_state.__init__()
  return False
def cur_sublevel() -> Sublevel:
  """Return the top entry of this thread's sublevel stack."""
  substack = thread_local_state.trace_state.substack
  return substack[-1]
TRACER_LEAK_DEBUGGER_WARNING = """\
JAX check_tracer_leaks behavior can trigger false positives when used with a debugger.
To avoid false positives and silence this warning, you can disable thread tracing using
the following:
import threading
threading.current_thread().pydev_do_not_trace = True
"""
def maybe_find_leaked_tracers(x: MainTrace | Sublevel | None
                              ) -> list[Tracer]:
  """Find the leaked tracers holding a reference to the MainTrace or Sublevel.

  It's possible there's none! eg. there's some cases where JAX itself holds a
  reference to `x` inside of a lambda closure, and no tracers were leaked
  by the user. In this case an empty list is returned.
  """
  tracing_disabled = getattr(
      threading.current_thread(), 'pydev_do_not_trace', True)
  if not tracing_disabled:
    warnings.warn(TRACER_LEAK_DEBUGGER_WARNING)
  # Trigger garbage collection to filter out unreachable objects that are alive
  # only due to cyclical dependencies. (We don't care about unreachable leaked
  # tracers since they can't interact with user code and cause a problem.)
  gc.collect()
  # Two hops: Trace objects referring to `x`, then Tracers referring to those.
  traces = [r for r in gc.get_referrers(x) if isinstance(r, Trace)]
  tracers = [r for r in gc.get_referrers(*traces) if isinstance(r, Tracer)]
  return tracers
def leaked_tracer_error(name: str, t, tracers: list[Tracer]) -> Exception:
  """Build (without raising) the exception reporting leaked tracers.

  Args:
    name: label used in the message for the kind of leaked object (callers in
      this file pass "trace" or "sublevel").
    t: the leaked object itself, interpolated into the message.
    tracers: non-empty list of leaked tracers.

  Returns:
    An `Exception` whose message lists, for each tracer, its string form, its
    `_origin_msg()` and the referrer chain computed by `_why_alive`.
  """
  assert tracers
  # The `tracers` list itself refers to every tracer; its id is passed as the
  # ignore set so `_why_alive` does not report it as a referrer.
  why = partial(_why_alive, {id(tracers)})
  msgs = '\n\n'.join(f'{tracers[i]}{tracers[i]._origin_msg()}{why(tracers[i])}'
                     for i in range(len(tracers)))
  return Exception(f'Leaked {name} {t}. Leaked tracer(s):\n\n{msgs}\n')
def _why_alive(ignore_ids: set[int], x: Any) -> str:
  """Describe one chain of referrers that keeps `x` alive.

  Starting at `x`, repeatedly picks the first referrer reported by
  `gc.get_referrers` (skipping objects whose id is in `ignore_ids`) and
  records one line per hop. The walk stops at a module, at an object that was
  already visited, or when no referrer is left.

  Args:
    ignore_ids: ids of objects that must not be reported as referrers.
    x: the object whose liveness is explained.

  Returns:
    A newline-prefixed, newline-separated description, or '' if no hop was
    recorded.
  """
  parents = lambda x: [r for r in gc.get_referrers(x) if id(r) not in ignore_ids]
  child, lines, seen = x, [], set()
  while (id(child) not in seen and type(child) is not types.ModuleType
         and parents(child)):
    parent = parents(child)[0]  # just pick one parent

    # For namespaces (like modules and class instances) and closures, the
    # references may form a simple chain: e.g. instance refers to its own
    # __dict__ which refers to child, or function refers to its __closure__
    # which refers to cells which refer to child. In these cases, we can provide
    # a more intuitive description by collapsing the chain into a single
    # parent->child jump. We do that by setting `parent` here to be a
    # grandparent (or great-grandparent) of `child`, and then handling that case
    # in _why_alive_container_info. See example:
    #  https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599
    # To prevent this collapsing behavior, just comment out this code block.
    if (isinstance(parent, dict) and
        getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
      parent = parents(parent)[0]
    elif type(parent) is types.CellType:
      parent = parents(parents(parent)[0])[0]

    line = f'<{type(child).__name__} {id(child)}> is referred to by '
    lines.append(line + _why_alive_container_info(parent, id(child)))
    seen.add(id(child))
    child = parent
  return '\n' + '\n'.join(lines) if lines else ''
def _why_alive_container_info(container, obj_id) -> str:
name = f'<{type(container).__name__} {id(container)}>'
if type(container) is types.ModuleType:
name = getattr(container, '__name__', name)
if type(container) is types.FunctionType:
name_ = getattr(container, '__name__', '<no-name>')
closure = inspect.getclosurevars(container)
keys = [k for k, v in dict(closure.nonlocals, **closure.globals).items()
if id(v) == obj_id]
if len(keys) == 1: return f'{name} ({name_}) closed-over variable {keys[0]}'
elif len(keys) > 1: return (f'{name} in closed-over variables ' +
', '.join(map(repr, keys)))
if hasattr(container, '__dict__'):
keys = [k for k in vars(container) if id(vars(container)[k]) == obj_id]
if len(keys) == 1: return f'{name}.{keys[0]}'
elif len(keys) > 1: return f'{name} in vars ' + ', '.join(map(repr, keys))
if isinstance(container, (list, tuple)):
idxs = [i for i, x in enumerate(container) if id(x) == obj_id]
if len(idxs) == 1: return f'{name}[{idxs[0]}]'
else: return f'{name} at indices ' + ', '.join(map(str, idxs))
if isinstance(container, dict):
keys = [k for k in container if id(container[k]) == obj_id]
if len(keys) == 1: return f'{name}[{keys[0]!r}]'
else: return f'{name} at keys ' + ', '.join(map(repr, keys))
if isinstance(container, types.ModuleType):
return f' named {container.__name__}'
return name
@contextmanager
def new_main(trace_type: type[Trace], dynamic: bool = False,
             **payload) -> Generator[MainTrace, None, None]:
  """Context manager that pushes a new MainTrace onto this thread's trace stack.

  Args:
    trace_type: Trace subclass passed to the MainTrace constructor.
    dynamic: if true, the new main trace also replaces the stack's `dynamic`
      trace while the context is active.
    **payload: extra keyword arguments forwarded to the MainTrace constructor.

  Yields:
    The new MainTrace; its level is the stack depth at entry.

  Raises:
    Exception: after the context exits, if `config.check_tracer_leaks` is
      enabled and leaked tracers referring to the main trace are found.
  """
  # See comments in https://github.com/jax-ml/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  level = stack.next_level()
  main = MainTrace(level, trace_type, **payload)
  stack.push(main)
  if dynamic:
    prev_dynamic, stack.dynamic = stack.dynamic, main
    _update_thread_local_jit_state(stack.dynamic)

  try:
    yield main
  finally:
    stack.pop()
    if dynamic:
      stack.dynamic = prev_dynamic
      _update_thread_local_jit_state(stack.dynamic)

  if config.check_tracer_leaks.value:
    # Keep only a weak reference and drop the local strong one: if the weak
    # reference is still alive afterwards, something else holds the main trace.
    t = ref(main)
    del main
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
@contextmanager
def new_dynamic(level: int) -> Generator[None, None, None]:
  """Temporarily make the stack entry at index `level` the dynamic main trace.

  The previous dynamic trace is restored on exit, and the thread-local jit
  state is refreshed after both switches.
  """
  stack = thread_local_state.trace_state.trace_stack
  saved_dynamic = stack.dynamic
  stack.dynamic = stack.stack[level]
  _update_thread_local_jit_state(stack.dynamic)
  try:
    yield
  finally:
    stack.dynamic = saved_dynamic
    _update_thread_local_jit_state(stack.dynamic)
def dynamic_level() -> int:
  """Return the level of this thread's current dynamic main trace."""
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  return dynamic.level
@contextmanager
def new_base_main(trace_type: type[Trace],
                  **payload) -> Generator[MainTrace, None, None]:
  """Context manager that temporarily replaces the base (level 0) main trace.

  A new `MainTrace(0, trace_type, **payload)` is installed both as the
  stack's `dynamic` trace and as `stack.stack[0]`; both are restored on exit.

  Yields:
    The new level-0 MainTrace.

  Raises:
    Exception: after the context exits, if `config.check_tracer_leaks` is
      enabled and leaked tracers referring to the main trace are found.
  """
  # See comments in https://github.com/jax-ml/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  main = MainTrace(0, trace_type, **payload)
  prev_dynamic, stack.dynamic = stack.dynamic, main
  prev_base, stack.stack[0] = stack.stack[0], main
  _update_thread_local_jit_state(stack.dynamic)
  try:
    yield main
  finally:
    stack.dynamic = prev_dynamic
    stack.stack[0] = prev_base
    _update_thread_local_jit_state(stack.dynamic)

  if config.check_tracer_leaks.value:
    # Keep only a weak reference and drop the local strong one: if the weak
    # reference is still alive afterwards, something else holds the main trace.
    t = ref(main)
    del main
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
@contextmanager
def pop_level(level: int):
  """Temporarily truncate this thread's trace stack to its first `level` entries.

  For `level == 0` the stack is left untouched. Otherwise the stack list is
  swapped for the slice `stack[:level]` and the original list object is put
  back on exit.
  """
  if level == 0:
    return (yield)  # noqa: B901
  # The attribute chain is looked up afresh each time rather than cached in a
  # local, so the restore below targets whatever trace stack is current then.
  prev, thread_local_state.trace_state.trace_stack.stack = \
      thread_local_state.trace_state.trace_stack.stack, \
      thread_local_state.trace_state.trace_stack.stack[:level]
  try:
    yield
  finally:
    thread_local_state.trace_state.trace_stack.stack = prev
@contextmanager
def ensure_compile_time_eval():
  """Context manager to ensure evaluation at trace/compile time (or error).

  Some JAX APIs like :func:`jax.jit` and :func:`jax.lax.scan` involve staging,
  i.e., delaying the evaluation of numerical expressions (like :mod:`jax.numpy`
  function applications) so that instead of performing those computations
  eagerly while evaluating the corresponding Python expressions, their
  computation is carried out separately, e.g. after optimized compilation. But
  this delay can be undesirable. For example, numerical values might be needed
  to evaluate Python control flow and so their evaluation cannot be delayed. As
  another example, it may be beneficial to ensure compile time evaluation (or
  "constant folding") for performance reasons.

  This context manager ensures that JAX computations are evaluated eagerly. If
  eager evaluation is not possible, a ``ConcretizationTypeError`` is raised.

  Here's a contrived example::

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      with jax.ensure_compile_time_eval():
        y = jnp.sin(3.0)
        z = jnp.sin(y)
        z_positive = z > 0
      if z_positive:  # z_positive is usable in Python control flow
        return jnp.sin(x)
      else:
        return jnp.cos(x)

  Here's a real-world example from https://github.com/jax-ml/jax/issues/3974::

    import jax
    import jax.numpy as jnp
    from jax import random

    @jax.jit
    def jax_fn(x):
      with jax.ensure_compile_time_eval():
        y = random.randint(random.key(0), (1000,1000), 0, 100)
      y2 = y @ y
      x2 = jnp.sum(y2) * x
      return x2

  A similar behavior can often be achieved simply by 'hoisting' the constant
  expression out of the corresponding staging API::

    y = random.randint(random.key(0), (1000,1000), 0, 100)

    @jax.jit
    def jax_fn(x):
      y2 = y @ y
      x2 = jnp.sum(y2)*x
      return x2

  But in some cases it can be more convenient to use this context manager.
  """
  # Installing an EvalTrace as the base main trace for the duration of the
  # context is what implements the behavior described above.
  with new_base_main(EvalTrace):
    yield
eval_context = ensure_compile_time_eval  # alias, backward compatibility
@contextmanager
def new_sublevel() -> Generator[None, None, None]:
  """Context manager that pushes a new Sublevel onto this thread's substack.

  The new sublevel's level is the substack depth at entry; it is popped again
  on exit.

  Raises:
    Exception: after the context exits, if `config.check_tracer_leaks` is
      enabled and leaked tracers referring to the sublevel are found.
  """
  sublevel = Sublevel(len(thread_local_state.trace_state.substack))
  thread_local_state.trace_state.substack.append(sublevel)
  try:
    yield
  finally:
    thread_local_state.trace_state.substack.pop()

  if config.check_tracer_leaks.value:
    # Keep only a weak reference and drop the local strong one: if the weak
    # reference is still alive afterwards, something else holds the sublevel.
    t = ref(sublevel)
    del sublevel
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers:
        raise leaked_tracer_error("sublevel", t(), leaked_tracers)
def full_lower(val):
  """Return `val.full_lower()` for a Tracer, and `val` itself otherwise."""
  return val.full_lower() if isinstance(val, Tracer) else val
def _get_trace_level(t: Tracer) -> int:
  """Return the level of the trace that tracer `t` belongs to."""
  return t._trace.level
def find_top_trace(xs) -> Trace:
  """Pick the trace that should handle an operation on the values `xs`.

  The candidate is the main trace of the highest-level Tracer among `xs` (if
  any); the thread's dynamic main trace wins when there is no Tracer or when
  its level is strictly higher. The result is instantiated at the current
  sublevel.
  """
  tracers = [x for x in xs if isinstance(x, Tracer)]
  top_tracer = max(tracers, key=_get_trace_level, default=None)
  top_main = None
  if top_tracer is not None:
    top_tracer._assert_live()
    top_main = top_tracer._trace.main
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  if top_main is None or dynamic.level > top_main.level:
    top_main = dynamic
  return top_main.with_cur_sublevel()
def get_referent(x: Any) -> Any:
  """Return `x.get_referent()` for a Tracer, and `x` itself otherwise."""
  if isinstance(x, Tracer):
    return x.get_referent()
  return x
def same_referent(x: Any, y: Any) -> bool:
  """Report whether `x` and `y` have the identical (`is`) referent."""
  referent_x = get_referent(x)
  referent_y = get_referent(y)
  return referent_x is referent_y
def dedup_referents(itr: Iterable[Any]) -> list[Any]:
  """Drop entries of `itr` that share a referent with another entry.

  Entries are keyed by `HashableWrapper(get_referent(x))`. The position of
  the first occurrence is kept, while the value stored is the last one seen
  for that key.
  """
  by_referent = {}
  for x in itr:
    by_referent[HashableWrapper(get_referent(x))] = x
  return list(by_referent.values())
def definitely_equal(x, y):
  """Conservative equality check.

  If either argument is a Tracer the answer is referent identity. Otherwise
  identical objects are equal, and `x == y` is used, with an inconclusive
  symbolic-dimension comparison counting as "not equal".
  """
  if isinstance(x, Tracer) or isinstance(y, Tracer):
    return same_referent(x, y)
  if x is y:
    return True
  try:
    return x == y
  except InconclusiveDimensionOperation:
    return False
# -------------------- abstract values --------------------
class AbstractValue:
  """Base class for abstract values.

  `to_tangent_aval`, `join` and `update` must be overridden by subclasses;
  the remaining methods provide defaults.
  """
  __slots__: list[str] = []

  def to_tangent_aval(self):
    raise NotImplementedError("must override")

  # TODO(dougalm): deprecate this alias
  def at_least_vspace(self):
    return self.to_tangent_aval()

  def __repr__(self):
    cls_name = self.__class__.__name__
    try:
      fields = ','.join(f'{k}={v}' for k, v in self.__dict__.items())
    except AttributeError:
      # Slot-only instances have no __dict__; show just the class name.
      return cls_name
    return f'{cls_name}({fields})'

  def strip_weak_type(self) -> AbstractValue:
    return self

  def join(self, other):
    raise NotImplementedError("must override")

  def update(self, **kwargs):
    raise NotImplementedError("must override")

  def str_short(self, short_dtypes=False):
    return str(self)
# For type signatures involving dynamic shapes, we use lists of abstract values
# which may contain (reverse) de Bruijn indices in their shapes.
class DBIdx(NamedTuple):
  """Index wrapper placed in shapes of `InputType` avals."""
  val: int

@dataclass(frozen=True)
class InDBIdx:
  """Index wrapper placed in shapes of `OutputType` avals.

  NOTE(review): presumably refers to an input binder -- confirm.
  """
  val: int

@dataclass(frozen=True)
class OutDBIdx:
  """Index wrapper placed in shapes of `OutputType` avals.

  NOTE(review): presumably refers to an output binder -- confirm.
  """
  val: int
# For annotating input types of callables (i.e. linear_util.WrappedFuns), we use
# a sequence of pairs where the first element of each pair is an AbstractValue
# (possibly containing DBIdx instances in its shape) and the second is a boolean
# indicating whether that argument is explicit (i.e. passed to the callable).
InputType = tuple[tuple[AbstractValue, bool], ...] # DBIdx in shapes
# For annotating jaxpr output types, we use a sequence of pairs where the first
# element of each pair is an AbstractValue (possibly containing InDBIdx and/or
# OutDBIdx instances in its shape) and the second is a boolean indicating
# whether that argument is explicit (i.e. returned by the callable).
OutputType = tuple[tuple[AbstractValue, bool], ...] # InDBIdx / OutDBIdx shapes
def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
  """Build the `InputType` annotation for the inputs of `jaxpr`.

  Every invar contributes `(aval, True)`. For avals that are exactly
  `DShapedArray`, shape entries that are constvars or invars of the jaxpr are
  replaced by a `DBIdx` giving their position in `(*constvars, *invars)`.
  """
  binders = (*jaxpr.constvars, *jaxpr.invars)
  idxs = {v: DBIdx(i) for i, v in enumerate(binders)}
  out = []
  for v in jaxpr.invars:
    aval = v.aval
    if type(aval) is DShapedArray:
      indexed_shape = tuple(idxs.get(d, d) for d in aval.shape)
      aval = aval.update(shape=indexed_shape)  # type: ignore
    out.append((aval, True))
  return tuple(out)
# AbstractValue subclass that adds no behavior; `bot` is its module-level
# instance.
# NOTE(review): presumably the bottom element of the aval lattice -- confirm.
class Bot(AbstractValue): pass
bot = Bot()
def lattice_join(x: AbstractValue | None,
                 y: AbstractValue | None) -> AbstractValue:
  """Join two abstract values, treating `None` as "no information".

  When one argument's type is a subclass of the other's, `join` is called on
  the instance of the more general type.

  Raises:
    TypeError: if the two types are unrelated (other than the
      DShapedArray/ShapedArray special case).
  """
  if x is None:
    assert y is not None
    return y
  if y is None:
    return x
  if isinstance(x, type(y)):
    return y.join(x)
  if isinstance(y, type(x)):
    return x.join(y)
  if isinstance(x, DShapedArray) and isinstance(y, ShapedArray):
    # TODO(mattjj): remove this special case after dynamic shapes are integrated
    return x.join(y)
  raise TypeError(x, y)
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any
def valid_jaxtype(x) -> bool:
  """Report whether `concrete_aval(x)` succeeds, i.e. does not raise TypeError."""
  try:
    concrete_aval(x)
    return True
  except TypeError:
    return False
def check_valid_jaxtype(x):
  """Raise TypeError unless `valid_jaxtype(x)` holds."""
  if valid_jaxtype(x):
    return
  raise TypeError(
      f"Value {x!r} of type {type(x)} is not a valid JAX type")
def concrete_aval(x):
  """Return the abstract value for the concrete Python value `x`.

  The handler is the first truthy entry of `pytype_aval_mappings` found along
  the MRO of `type(x)`. Without a handler, objects exposing `__jax_array__`
  are converted through it and retried.

  Raises:
    TypeError: if neither lookup applies.
  """
  candidates = (pytype_aval_mappings.get(typ) for typ in type(x).__mro__)
  handler = next(filter(None, candidates), None)
  if handler:
    return handler(x)
  if hasattr(x, '__jax_array__'):
    return concrete_aval(x.__jax_array__())
  raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
                  "type")
def get_aval(x):
  """Return `x.aval` for a Tracer, else the abstract value from `concrete_aval`."""
  return x.aval if isinstance(x, Tracer) else concrete_aval(x)
def get_type(x):
  """Return the aval of `x`, with a ConcreteArray passed through `raise_to_shaped`."""
  aval = get_aval(x)
  return raise_to_shaped(aval) if isinstance(aval, ConcreteArray) else aval
def concretization_function_error(fun, suggest_astype=False):
  """Build a method that raises the concretization error matching `fun`.

  Args:
    fun: the conversion callable the error is about (e.g. `bool`, `int`,
      `operator.index`). It does not need a `__name__` attribute; the object
      itself is interpolated into the message when it has none.
    suggest_astype: if true, the message additionally suggests `x.astype(...)`
      or `jnp.array(x, ...)`.

  Returns:
    A function `error(self, arg)` that always raises:
    `TracerBoolConversionError` for `bool`, `TracerIntegerConversionError` for
    `hex`, `oct` and `operator.index`, and `ConcretizationTypeError` (carrying
    the context message) for everything else.
  """
  fname = getattr(fun, "__name__", fun)
  fname_context = f"The problem arose with the `{fname}` function. "
  if suggest_astype:
    # Use `fname`, not `fun.__name__`: the latter raises AttributeError for
    # callables without a `__name__`, defeating the fallback above.
    fname_context += ("If trying to convert the data type of a value, "
                      f"try using `x.astype({fname})` "
                      f"or `jnp.array(x, {fname})` instead.")
  if fun is bool:
    def error(self, arg):
      raise TracerBoolConversionError(arg)
  elif fun in (hex, oct, operator.index):
    def error(self, arg):
      raise TracerIntegerConversionError(arg)
  else:
    def error(self, arg):
      raise ConcretizationTypeError(arg, fname_context)
  return error
def concrete_or_error(force: Any, val: Any, context=""):
  """Like force(val), but gives the context in the error message."""
  convert = (lambda x: x) if force is None else force
  if not isinstance(val, Tracer):
    return convert(val)
  if isinstance(val.aval, ConcreteArray):
    # A tracer with a concrete aval still carries an actual value.
    return convert(val.aval.val)
  raise ConcretizationTypeError(val, context)
def concrete_dim_or_error(val: Any, context=""):
  """Like concrete_or_error(operator.index), allowing symbolic dimensions."""
  if not is_symbolic_dim(val):
    return concrete_or_error(operator.index, val, context=context)
  return val
### Extended dtypes
#
# Extended dtypes are JAX-specific dtypes that allow us to represent logical
# arrays of element types that do not have an obvious direct correspondence
# to ("physical") arrays of basic types in a compiler. In particular, their
# element types differ from those of XLA and NumPy (e.g. int32). These dtypes
# are only known to JAX. Their implementation is determined by:
# a) an object representing the extended dtype, accessible via the `dtype`
# attribute on corresponding JAX arrays and, internally, on avals such
# as ShapedArrays that correspond to such JAX arrays;
# b) a set of rules, available via a private attribute on the extended dtype
# object in (a).
# The rules in (b) tell JAX internals how to ground out the element
# type for interaction with the compiler and runtime, e.g. when lowering
# to the compiler's language.
@overload
def physical_aval(aval: ShapedArray) -> ShapedArray: ...
@overload
def physical_aval(aval: DShapedArray) -> DShapedArray: ...
@overload                       # TODO(frostig): remove this case
def physical_aval(aval: AbstractValue) -> AbstractValue: ...

def physical_aval(aval):
  """Return the "physical" aval backing an aval with an extended dtype.

  Avals whose `dtype` is not an extended dtype are returned unchanged.
  Otherwise the result has the same type as `aval`, the dtype of the physical
  element aval given by the dtype's rules, and the element aval's shape
  appended to the original shape.
  """
  aval_dtype = getattr(aval, 'dtype', None)
  if not (aval_dtype and isinstance(aval_dtype, dtypes.ExtendedDType)):
    return aval
  ctor = type(aval)
  aval_shape = getattr(aval, 'shape', None)
  assert aval_shape is not None, (ctor, aval)
  elt_aval = aval_dtype._rules.physical_element_aval(aval_dtype)
  assert type(elt_aval) is ShapedArray
  physical_shape = (*aval_shape, *elt_aval.shape)
  return ctor(physical_shape, elt_aval.dtype)  # pytype: disable=wrong-arg-count
def _short_dtype_name(dtype) -> str:
  """Return an abbreviated dtype name; extended dtypes use their `str()`."""
  if isinstance(dtype, dtypes.ExtendedDType):
    return str(dtype)
  short = dtype.name
  # Order matters: 'uint' has to be rewritten before 'int'.
  abbreviations = (('float', 'f'), ('uint', 'u'), ('int', 'i'), ('complex', 'c'))
  for long_form, abbrev in abbreviations:
    short = short.replace(long_form, abbrev)
  return short
def _dtype_object(dtype):
  """Return extended dtypes as-is and everything else as `np.dtype(dtype)`."""
  if isinstance(dtype, dtypes.ExtendedDType):
    return dtype
  return np.dtype(dtype)
class UnshapedArray(AbstractValue):
  """Abstract value that records only a dtype and a weak-type flag."""
  __slots__ = ['dtype', 'weak_type']
  array_abstraction_level = 4

  def __init__(self, dtype, weak_type=False):
    self.dtype = _dtype_object(dtype)
    self.weak_type = weak_type

  def update(self, dtype=None, weak_type=None):
    """Return an UnshapedArray with the given fields replaced (None keeps them)."""
    new_dtype = self.dtype if dtype is None else dtype
    new_weak_type = self.weak_type if weak_type is None else weak_type
    return UnshapedArray(new_dtype, new_weak_type)

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return self.dtype == other.dtype and self.weak_type == other.weak_type

  def __ne__(self, other):
    return not (self == other)

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.dtype, self.weak_type))

  def __repr__(self):
    weak_suffix = ", weak_type=True" if self.weak_type else ""
    return f'{self.__class__.__name__}({self.str_short()}{weak_suffix})'

  # Conversion hooks: each one raises the concretization error for its function.
  _bool    = concretization_function_error(bool)
  _int     = concretization_function_error(int, True)
  _float   = concretization_function_error(float, True)
  _complex = concretization_function_error(complex, True)
  _hex     = concretization_function_error(hex)
  _oct     = concretization_function_error(oct)
  _index   = concretization_function_error(operator.index)

  def to_tangent_aval(self) -> AbstractValue:
    tangent_dtype = primal_dtype_to_tangent_dtype(self.dtype)
    return UnshapedArray(tangent_dtype, self.weak_type)

  def join(self, other):
    """Join with `other`; differing weak types join to a non-weak aval."""
    if self.dtype != other.dtype:
      raise TypeError(self, other)
    if self.weak_type == other.weak_type:
      return self
    return UnshapedArray(self.dtype, weak_type=False)

  def str_short(self, short_dtypes=False) -> str:
    if short_dtypes:
      return _short_dtype_name(self.dtype)
    return self.dtype.name

  def strip_weak_type(self):
    """Returns a copy of the aval with weak_type=False."""
    return self.update(weak_type=False)

  @property
  def shape(self):
    msg = ("UnshapedArray has no shape. Please open an issue at "
           "https://github.com/jax-ml/jax/issues because it's unexpected for "
           "UnshapedArray instances to ever be produced.")
    raise TypeError(msg)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
# Dimensions are most commonly integral (by far), so we check that first.
try:
return operator.index(dim)
except TypeError as e:
type_error = e
if isinstance(dim, Tracer) and config.dynamic_shapes.value:
if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
or isinstance(dim.dtype, bint))):
raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
return dim
elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
type(dim._aval.dtype) is bint and not dim._aval.shape):
return dim
elif is_dim(dim):
return dim
else:
raise type_error
def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.
    context: optional text appended to the error message.

  Returns:
    A tuple of canonical dimension values.
  """
  dims = (shape,) if isinstance(shape, int) else shape
  try:
    return tuple(unsafe_map(_canonicalize_dimension, dims))
  except TypeError:
    pass
  # Raised outside the `except` block so the original TypeError is not chained.
  raise _invalid_shape_error(dims, context)
def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
  """Canonicalizes and checks for errors in a user-provided shape dimension value.

  Args:
    d: a Python value that represents a dimension.
    context: optional text appended to the error message.

  Returns:
    A canonical dimension value.
  """
  (canonical,) = canonicalize_shape((d,), context)
  return canonical
def _invalid_shape_error(shape: Shape, context: str=""):
  """Build (without raising) the TypeError describing an invalid shape.

  Args:
    shape: the offending shape value; it is iterated to look for tracers.
    context: optional text appended to the message.

  Returns:
    A `TypeError` instance for the caller to raise.
  """
  if config.dynamic_shapes.value:
    msg = ("Shapes must be 1D sequences of integer scalars, "
           f"got {shape}")
  else:
    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
           f"got {shape}.")
  if context:
    msg += f" {context}."
  # With static shapes, a non-concrete shaped tracer among the dimensions gets
  # an extra hint, followed by the origin message of each tracer that has one.
  if not config.dynamic_shapes.value and any(
         isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
         and not isinstance(get_aval(x), ConcreteArray) for x in shape):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
    for x in shape:
      if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
        msg += x._origin_msg()

  return TypeError(msg)
class ShapedArray(UnshapedArray):
  """Abstract value with a canonicalized static shape, a dtype and a weak-type flag.

  The `sharding` slot is only assigned when `config.sharding_in_types` is
  enabled, so readers go through `getattr(..., 'sharding', None)`.
  """
  __slots__ = ['shape', 'sharding']  # inherits slots from parent
  array_abstraction_level = 2

  def __init__(self, shape, dtype, weak_type=False, sharding=None):
    self.shape = canonicalize_shape(shape)
    self.dtype = _dtype_object(dtype)
    self.weak_type = weak_type
    if config.sharding_in_types.value:
      self.sharding = sharding

  def update(self, shape=None, dtype=None, weak_type=None, sharding=None):
    """Return a ShapedArray with the given fields replaced (None keeps them)."""
    new_shape = self.shape if shape is None else shape
    new_dtype = self.dtype if dtype is None else dtype
    new_weak_type = self.weak_type if weak_type is None else weak_type
    new_sharding = getattr(self, 'sharding', None) if sharding is None else sharding
    return ShapedArray(new_shape, new_dtype, new_weak_type, sharding=new_sharding)

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def size(self):
    # A literal int 0 in the shape gives size 0 without multiplying the rest.
    if any(type(d) is int and d == 0 for d in self.shape):
      return 0
    return math.prod(self.shape)

  broadcast: ClassVar[aval_method | None] = None
  transpose: ClassVar[aval_method | None] = None
  reshape: ClassVar[aval_method | None] = None
  _iter: ClassVar[staticmethod | None] = None

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return (self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type
            and getattr(self, 'sharding', None) == getattr(other, 'sharding', None))

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    key = (self.shape, self.dtype, self.weak_type,
           getattr(self, 'sharding', None))
    return hash(key)

  def to_tangent_aval(self):
    tangent_dtype = primal_dtype_to_tangent_dtype(self.dtype)
    return ShapedArray(self.shape, tangent_dtype, self.weak_type)

  def join(self, other):
    """Join with `other`: equal shapes keep the shape, otherwise only the dtype."""
    if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
      return self.update(weak_type=self.weak_type and other.weak_type)
    if self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    raise TypeError(self, other)

  def str_short(self, short_dtypes=False):
    dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
    dt_str = dt_str.replace('void', 'float0')
    shapestr = ','.join(map(str, self.shape))
    if not hasattr(self, 'sharding'):
      return f'{dt_str}[{shapestr}]'
    return f'{dt_str}[{shapestr}]({self.sharding})'

  def _len(self, ignored_tracer):
    try:
      return self.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error
def _forward_to_value(self, fun, ignored_tracer, *args):
  """Apply `fun` to the concrete value `self.val`, ignoring the tracer argument."""
  return fun(self.val, *args)
class ConcreteArray(ShapedArray):
  """ShapedArray that additionally carries the concrete value `val`.

  Equality compares the values element-wise, while hashing uses the identity
  of `val`.
  """
  __slots__ = ['val']
  array_abstraction_level = 0

  def __init__(self, dtype, val, weak_type=None):
    # The shape comes from `val`; when `weak_type` is not given it is derived
    # from `val` via `dtypes.is_weakly_typed`.
    super().__init__(
        np.shape(val), dtype,
        weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
    dtypes.check_valid_dtype(self.dtype)
    # Note: canonicalized self.dtype doesn't necessarily match self.val
    assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
    self.val = val

  def update(self, dtype=None, val=None, weak_type=None):
    """Return a ConcreteArray with the given fields replaced (None keeps them)."""
    dtype = self.dtype if dtype is None else dtype
    val = self.val if val is None else val
    weak_type = self.weak_type if weak_type is None else weak_type
    return ConcreteArray(dtype, val, weak_type)

  def __eq__(self, other):
    if (type(self) is type(other) and self.dtype == other.dtype
        and self.shape == other.shape and self.weak_type == other.weak_type):
      with eval_context():  # in case self.val is an Array
        return (self.val == other.val).all()
    else:
      return False

  def __hash__(self):
    # Identity-based: two instances that compare equal by value hash the same
    # only if they share the same `val` object.
    return id(self.val)

  def join(self, other) -> AbstractValue:
    """Join with `other`, dropping the value, then the shape, as they differ."""
    if self == other:
      return self
    elif self.shape == other.shape and self.dtype == other.dtype:
      weak_type = self.weak_type and other.weak_type
      return ShapedArray(self.shape, self.dtype, weak_type=weak_type)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype, weak_type=self.weak_type and other.weak_type)
    else:
      raise TypeError(self, other)

  def str_short(self, short_dtypes=False) -> str:
    dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
    return f'{self.val}, dtype={dt_str}'

  # These conversions are forwarded to the concrete value ...
  _bool = partialmethod(_forward_to_value, bool)
  _int = partialmethod(_forward_to_value, int)
  _hex = partialmethod(_forward_to_value, hex)
  _oct = partialmethod(_forward_to_value, oct)
  _index = partialmethod(_forward_to_value, operator.index)

  # ... while float and complex conversion raise a concretization error.
  _float = concretization_function_error(float, True)
  _complex = concretization_function_error(complex, True)
def primal_dtype_to_tangent_dtype(primal_dtype):
  """Return the tangent dtype for `primal_dtype`.

  Extended dtypes delegate to their rules, inexact dtypes are their own
  tangent dtype, and every other dtype maps to `float0`.
  """
  if isinstance(primal_dtype, dtypes.ExtendedDType):
    return primal_dtype._rules.tangent_dtype(primal_dtype)
  if dtypes.issubdtype(primal_dtype, np.inexact):
    return primal_dtype
  return dtypes.float0
# Dynamic shape stuff below here! We keep the abstract values distinct just so
# as not to interfere with any static shape machinery.
# We have a convention of reusing AbstractValues as types, even though we could
# make a distinction and use abstract values during tracing only. This reuse
# becomes a bit more extreme with DShapedArrays. A DShapedArray's shape
# attribute is a tuple which can contain several different types: int, DArray
# (scalar and with dtype of bint type), Tracer (while tracing), Var (when used
# as jaxpr type annotations), or DBIdx/InDBIdx/OutDBIdx (when used in InputType
# or OutputType). We could reduce this polymorphism if it seems cleaner, though
# it's kind of convenient!
class DShapedArray(UnshapedArray):
  """Abstract value whose shape may contain non-integer (dynamic) entries.

  Unlike ShapedArray, the constructor stores `shape` and `dtype` as given.
  """
  __slots__ = ['shape']
  shape: tuple[AxisSize, ...]  # noqa: F821
  array_abstraction_level: int = 3

  def __init__(self, shape, dtype, weak_type=False):
    self.shape = shape
    self.dtype = dtype
    self.weak_type = weak_type

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def size(self):
    # A literal int 0 in the shape gives size 0 without multiplying the rest.
    if any(type(d) is int and d == 0 for d in self.shape):
      return 0
    return math.prod(self.shape)

  def str_short(self, short_dtypes=False) -> str:
    del short_dtypes  # ignored
    shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
    dtype = _short_dtype_name(self.dtype)
    return f'{dtype}[{shape}]'
  __str__ = __repr__ = str_short

  def update(self, shape=None, dtype=None, weak_type=None):
    """Return a DShapedArray with the given fields replaced (None keeps them)."""
    new_shape = self.shape if shape is None else shape
    new_dtype = self.dtype if dtype is None else dtype
    new_weak_type = self.weak_type if weak_type is None else weak_type
    return DShapedArray(new_shape, new_dtype, new_weak_type)

  def _len(self, tracer):
    return self.shape[0]

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return (self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    return hash((self.shape, self.dtype, self.weak_type))

  def join(self, other):
    """Join with `other`: equal shapes keep the shape, otherwise only the dtype."""
    if (definitely_equal_shape(self.shape, other.shape) and
        self.dtype == other.dtype):
      return self.update(weak_type=self.weak_type and other.weak_type)
    if self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    raise TypeError(self, other)

  def to_tangent_aval(self):
    tangent_dtype = primal_dtype_to_tangent_dtype(self.dtype)
    return DShapedArray(self.shape, tangent_dtype, self.weak_type)
class DConcreteArray(DShapedArray):
  """DShapedArray that additionally carries a concrete value `val`."""
  __slots__ = ['val']
  array_abstraction_level = 1
  def __init__(self, shape, dtype, weak_type, val):
    super().__init__(shape, dtype, weak_type)
    self.val = val
# Registry mapping a Python type to a handler that builds the abstract value
# for an instance of that type; `concrete_aval` looks it up along the MRO.
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
class DArray:
  """Concrete array paired with a DShapedArray aval.

  `_data` is stored padded: every dimension that is a scalar `bint` DArray
  occupies `bound` entries. The `data` property slices that padding away.
  """
  _aval: DShapedArray
  _data: Any  # standard array type
  def __init__(self, aval, data):
    # Expected padded shape: bint-typed DArray dimensions count as their bound.
    pad_shape = tuple(d.dtype.bound if type(d) is DArray and
                      type(d.dtype) is bint else d for d in aval.shape)
    assert data.shape == pad_shape
    self._aval = aval
    self._data = data

  shape = property(lambda self: self._aval.shape)
  dtype = property(lambda self: self._aval.dtype)
  aval = property(lambda self: self._aval)
  def __repr__(self) -> str:
    if not self.shape and type(self.dtype) is bint:
      # special-case scalar bints
      return f'{int(self._data)}{{{self.dtype.bound}}}'

    dtypestr = _short_dtype_name(self._aval.dtype)
    shapestr = ','.join(map(str, self.shape))
    data = self.data
    return f'{dtypestr}[{shapestr}] with value: {data}'

  def __hash__(self) -> int:
    # Only scalars are hashable.
    if not self.shape:
      return hash((self._aval, int(self._data)))
    raise TypeError("unhashable type: DArray")

  def __eq__(self, other):
    # For DArrays with equal avals this returns the result of comparing the
    # padded buffers, whatever type `==` on them produces.
    if isinstance(other, DArray) and self._aval == other._aval:
      return self._data == other._data
    return False

  def __len__(self):
    return self.shape[0]

  @property
  def data(self):
    """The underlying data with padding of bint-sized dimensions sliced off."""
    if not self.shape and type(self.dtype) is bint:
      # special-case scalar bints
      return self._data

    slices = tuple(
        slice(int(d._data))
        if type(d) is DArray and type(d.dtype) is bint
        else slice(None)
        for d in self.shape
    )
    data = self._data[slices]
    return data

# A DArray abstracts to a DConcreteArray holding the padded buffer.
pytype_aval_mappings[DArray] = \
    lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
                             x._data)
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
  """Extended dtype parameterized by an integer `bound`.

  Its name is rendered as `bint{<bound>}`.
  """
  bound: int

  @property
  def type(self) -> type:
    return dtypes.extended

  @property
  def name(self) -> str:
    return f'bint{{{self.bound}}}'

  def __str__(self) -> str:
    return self.name
# The kinds of values that may appear as entries of a DShapedArray shape.
AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx]
class MutableArray:
  """Pairs an aval with a buffer; indexing is delegated to the aval."""
  _aval: ShapedArray
  _buf: Array
  def __init__(self, aval, buf):
    self._aval = aval
    self._buf = buf
  aval = property(lambda self: self._aval)
  shape = property(lambda self: self._aval.shape)
  dtype = property(lambda self: self._aval.dtype)
  # Reads and writes go through the `_getitem` / `_setitem` hooks of the aval.
  def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
  def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
  # Shows the full contents (`self[...]`) prefixed with 'Mutable'.
  def __repr__(self) -> str: return 'Mutable' + repr(self[...])
# The abstract value of a MutableArray is the aval it was constructed with.
pytype_aval_mappings[MutableArray] = lambda x: x._aval
def mutable_array(init_val):
  """Bind the `mutable_array` primitive to `init_val` and return the result."""
  return mutable_array_p.bind(init_val)
# Primitive behind `mutable_array`.
mutable_array_p = Primitive('mutable_array')

class InternalMutableArrayEffect(effects.Effect):
  """Effect attached to the result of `mutable_array_p`'s abstract evaluation."""
  pass
internal_mutable_array_effect = InternalMutableArrayEffect()
# Register the effect type with the control-flow allowed-effects set.
effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect)
@mutable_array_p.def_effectful_abstract_eval
def mutable_array_abstract_eval(init_aval):
  """Abstract evaluation: an `AbstractRef` of `init_aval` plus the internal effect."""
  # Imported inside the function rather than at module level.
  # NOTE(review): presumably to avoid a circular import -- confirm.
  from jax._src.state.types import AbstractRef  # pytype: disable=import-error
  return AbstractRef(init_aval), {internal_mutable_array_effect}
@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
  """Implementation rule: wrap `init_val` in a MutableArray with an AbstractRef aval."""
  # Imported inside the function rather than at module level.
  # NOTE(review): presumably to avoid a circular import -- confirm.
  from jax._src.state.types import AbstractRef  # pytype: disable=import-error
  aval = raise_to_shaped(get_aval(init_val))
  return MutableArray(AbstractRef(aval), init_val)
class AbstractToken(AbstractValue):
  """Abstract value used for tokens."""

  def join(self, other):
    """Return `self` if `other` is an AbstractToken; otherwise fail an assertion."""
    if isinstance(other, AbstractToken):
      return self
    assert False, f"Cannot join {self} with {other}"

  def str_short(self, short_dtypes=False):
    """Short printed form; `short_dtypes` is accepted but not used."""
    return 'Tok'

  def to_tangent_aval(self):
    """The tangent aval of a token is the token aval itself."""
    return self
# Module-level AbstractToken instance; the Token -> aval mapping below returns it.
abstract_token: AbstractToken = AbstractToken()
# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))
# Concrete token object
class Token:
  """Concrete token: a thin wrapper around a buffer."""
  # The wrapped buffer. It can be threaded into and out of computations to
  # build up data dependencies.
  _buf: Array

  def __init__(self, buf):
    self._buf = buf

  def block_until_ready(self):
    """Call `block_until_ready()` on the wrapped buffer; returns None."""
    buf = self._buf
    buf.block_until_ready()
# Every Token maps to the shared `abstract_token` aval.
pytype_aval_mappings[Token] = lambda _: abstract_token
def raise_to_shaped(aval: AbstractValue, weak_type=None):
  """Convert `aval` using the handler registered for its type.

  Args:
    aval: the abstract value to convert.
    weak_type: weak-type flag passed to the handler. When None, the flag is
      read from `aval.weak_type` (False if absent), and an `aval` whose exact
      type is ShapedArray is returned unchanged.

  Returns:
    Whatever the first handler found along `type(aval).__mro__` returns.

  Raises:
    TypeError: if no class in the MRO has a registered handler.
  """
  cls = type(aval)
  if weak_type is None:
    if cls is ShapedArray:
      return aval
    weak_type = getattr(aval, 'weak_type', False)
  # Walk the MRO so subclasses fall back to a base-class handler.
  candidates = (raise_to_shaped_mappings.get(base) for base in cls.__mro__)
  handler = next((h for h in candidates if h), None)
  if handler is None:
    raise TypeError(cls)
  return handler(aval, weak_type)
# Handlers consulted by `raise_to_shaped`, keyed by aval class. Each handler
# takes (aval, weak_type). The first three return the aval unchanged; the
# ShapedArray and DConcreteArray handlers rebuild a ShapedArray / DShapedArray
# from the aval's shape and dtype with the requested weak_type.
raise_to_shaped_mappings: dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: lambda aval, weak_type: ShapedArray(
aval.shape, aval.dtype, weak_type),
DConcreteArray: lambda aval, weak_type: DShapedArray(
aval.shape, aval.dtype, weak_type
),
}
### Operations on shapes and dimension sizes.
class InconclusiveDimensionOperation(Exception):
  """Raised when we cannot conclusively compute with symbolic dimensions."""
def is_symbolic_dim(v: Any) -> bool:
  """Return True if `v` is a symbolic dimension used for shape polymorphism.

  The check is duck-typed: a value counts as symbolic when it has a
  `dimension_as_value` attribute. This should rarely be needed, because
  symbolic dimensions overload all operators and should just work.
  """
  has_marker = hasattr(v, "dimension_as_value")
  return has_marker
def is_constant_dim(d: DimSize) -> bool:
  """Return whether the dimension is a static integer constant.

  A dimension is constant when `operator.index(d)` succeeds. Any ordinary
  exception raised by the conversion means "not constant".

  Args:
    d: the dimension size to test.

  Returns:
    True if `d` converts to an index, False otherwise.
  """
  try:
    operator.index(d)
  except Exception:
    # Only ordinary errors mean "not a constant"; a bare `except:` would also
    # swallow KeyboardInterrupt / SystemExit raised during the conversion.
    return False
  return True
def is_dim(v: Any) -> bool:
  """Return True if `v` is either a symbolic or a constant dimension."""
  if is_symbolic_dim(v):
    return True
  return is_constant_dim(v)
def is_constant_shape(s: Shape) -> bool:
  """Return whether every dimension of shape `s` is a static constant."""
  for d in s:
    if not is_constant_dim(d):
      return False
  return True
def definitely_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
  """Return True if `d1` is definitely equal to at least one entry of `dlist`."""
  for candidate in dlist:
    if definitely_equal(d1, candidate):
      return True
  return False
def definitely_equal_shape(s1: Shape, s2: Shape) -> bool:
  """Check that two shapes are guaranteed to be element-wise equal.

  In the presence of dynamic shapes this may return False even when the
  shapes turn out to be equal at runtime.
  """
  if len(s1) != len(s2):
    return False
  # Ranks agree, so the dimensions can be compared pairwise.
  return all(unsafe_map(definitely_equal, s1, s2))
def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
  """Return the "i" such that i * size(s2) == size(s1).

  Raises:
    InconclusiveDimensionOperation: if the sizes do not divide evenly, or the
      remainder is a Tracer.
  """
  total1, total2 = math.prod(s1), math.prod(s2)
  # Equal sizes short-circuit to 1; this also covers both sizes being 0.
  if definitely_equal(total1, total2):
    return 1
  q, r = divmod(total1, total2)
  uneven = isinstance(r, Tracer) or r != 0
  if uneven:
    raise InconclusiveDimensionOperation(
        f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}. "
        f"The remainder {r} should be 0.")
  return q
def cancel_divide_tracers(num, denom):
  """Divide the product of `num` by the product of `denom`, cancelling Tracers.

  Tracer entries of `denom` are cancelled against matching Tracer entries of
  `num` via `_cancel_divide`; the remaining non-Tracer entries are divided
  with integer division.

  Returns:
    The quotient, or None when neither list contains a Tracer, when the
    Tracers cannot be cancelled, or when the non-Tracer denominator is 0
    while the non-Tracer sizes differ.
  """
  def split(dims):
    # Separate non-Tracer entries (first) from Tracer entries (second).
    return partition_list([isinstance(d, Tracer) for d in dims], dims)

  num_static, num_tracers = split(num)
  denom_static, denom_tracers = split(denom)
  if not (num_tracers or denom_tracers):
    return None
  factor = _cancel_divide(num_tracers, denom_tracers)
  if factor is None:
    return None
  size1 = math.prod(num_static)
  size2 = math.prod(denom_static)
  if size1 == size2 or size2 != 0:
    return factor * (size1 // size2 if size1 != size2 else 1)
  return None
def _cancel_divide(num, denom):
  """Cancel every entry of `denom` against a definitely-equal entry of `num`.

  Returns:
    The product of the entries of `num` left after cancellation, or None if
    some entry of `denom` has no match.
  """
  remaining = list(num)
  for a in denom:
    for pos, b in enumerate(remaining):
      if definitely_equal(a, b):
        del remaining[pos]
        break
    else:
      return None  # couldn't cancel
  return math.prod(remaining)
def is_empty_shape(s: Shape) -> bool:
  """Return True if some dimension of `s` is definitely equal to 0."""
  for d in s:
    if definitely_equal(d, 0):
      return True
  return False
def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
  """Compute max(0, 1 + dilation * (d - 1)).

  Assumes dilation >= 1.
  """
  if not definitely_equal(dilation, 1):
    dilated = 1 + dilation * (d - 1)
    return max_dim(dilated, 0)
  # Fast path: a dilation of 1 leaves the dimension unchanged.
  return d
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
  """Compute max(0, (d - window_size) // window_stride + 1).

  Returns 0 when d < window_size.
  Assumes window_size >= 1 and window_stride >= 1.
  """
  # When d < window_size the floor division is negative, so the sum is <= 0
  # and the clamp below yields 0.
  n_windows = (d - window_size) // window_stride + 1
  return max_dim(n_windows, 0)
# TODO(necula): Deprecated Jan 2024, to be removed.
def non_negative_dim(d: DimSize) -> DimSize:
  """Compute max(d, 0) by delegating to `max_dim`."""
  clamped = max_dim(d, 0)
  return clamped
def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
  """Like min(d1, d2) but for both constant and symbolic dimensions."""
  first_is_const = is_constant_dim(d1)
  if first_is_const and is_constant_dim(d2):
    return min(d1, d2)
  d1 = concrete_dim_or_error(d1, "argument `d1` of `core.min_dim`")
  d2 = concrete_dim_or_error(d2, "argument `d2` of `core.min_dim`")
  # Dispatch to whichever operand is not a constant: the reflected `rmin`
  # when only `d1` is constant, the direct `min` otherwise.
  return d2.rmin(d1) if first_is_const else d1.min(d2)
def max_dim(d1: DimSize, d2: DimSize) -> DimSize:
  """Like max(d1, d2) but for both constant and symbolic dimensions."""
  first_is_const = is_constant_dim(d1)
  if first_is_const and is_constant_dim(d2):
    return max(d1, d2)
  d1 = concrete_dim_or_error(d1, "argument `d1` of `core.max_dim`")
  d2 = concrete_dim_or_error(d2, "argument `d2` of `core.max_dim`")
  # Dispatch to whichever operand is not a constant: the reflected `rmax`
  # when only `d1` is constant, the direct `max` otherwise.
  return d2.rmax(d1) if first_is_const else d1.max(d2)
def dimension_as_value(d: DimSize):
  """Turn a dimension size into a JAX array.

  This is the identity function for constant dimensions, and the result has
  the same abstract value as Python constants.
  """
  passthrough = (int, Tracer, np.int32, np.int64)
  if isinstance(d, passthrough):
    return d
  # For shape_poly._DimPolynomial
  if hasattr(d, "dimension_as_value"):
    return d.dimension_as_value()
  # Anything else must be convertible to an index.
  return operator.index(d)
class SomeTracer:
  """Stateless placeholder whose repr is the fixed text "[dynamic]"."""
  __slots__ = ()

  def __repr__(self):
    return "[dynamic]"
def replace_tracer_for_error_message(obj):
  """Return a `SomeTracer` placeholder for Tracers and `obj` itself otherwise."""
  # TODO(mattjj): Many ideas for improving this. Crawl the stack and see if
  # there are user variables whose value is == to this object? Or search
  # parameters of functions being transformed, at least? Or at least assign
  # short unique ids to them?
  return SomeTracer() if isinstance(obj, Tracer) else obj
def evaluate_shape(shape: Shape, dim_vars: Sequence[str],
                   *dim_values: Array) -> Sequence[Array]:
  """Evaluates a shape possibly containing non-constants.

  Args:
    shape: the shape to evaluate.
    dim_vars: the dimension variables names that may appear in `shape`.
    dim_values: the dimension values corresponding to `dim_vars`.

  Returns:
    a tuple of JAX values corresponding to `shape`, of type
    `dim_value_dtype`.
  """
  env = dict(zip(dim_vars, dim_values))

  def eval_one_dim(d: DimSize):
    # Constant dimensions convert directly; anything that fails to convert is
    # treated as a dimension expression and evaluated in `env`.
    try:
      return operator.index(d)
    except Exception:
      # Catch only ordinary errors: a bare `except:` would also swallow
      # KeyboardInterrupt / SystemExit raised during the conversion.
      # Is a _DimExpr
      return d._evaluate(env)  # type: ignore

  return tuple(eval_one_dim(d) for d in shape)
def dim_value_dtype():
  """Return the dtype used for dimension values: canonicalized `np.int64`."""
  canonical = dtypes.canonicalize_dtype(np.int64)
  return canonical
def dim_constant(ct: int):
  """Convert `ct` to a NumPy scalar of the dimension-value dtype.

  The dtype returned by `dim_value_dtype()` is asserted to be int32 or int64.
  """
  dtype = dim_value_dtype()
  assert dtype in (np.int32, np.int64)
  for scalar_type in (np.int32, np.int64):
    if dtype == scalar_type:
      return scalar_type(ct)
  return None
def dim_value_aval() -> AbstractValue:
  """Return a weakly-typed scalar ShapedArray of the dimension-value dtype."""
  scalar_shape = ()
  return ShapedArray(scalar_shape, dim_value_dtype(), weak_type=True)
# ------------------- Call -------------------
class CallPrimitive(Primitive):
  """Primitive whose first operand is a function to be called."""
  multiple_results = True
  call_primitive = True

  def bind(self, fun, *args, **params):
    """Process the call on the top trace, then apply the continuation."""
    continuation, top_trace, wrapped_fun, tracers, params = (
        call_bind_with_continuation(self, fun, *args, **params))
    results = top_trace.process_call(self, wrapped_fun, tracers, params)
    return continuation(results)

  def get_bind_params(self, params):
    """Turn equation params into `([subfun], remaining_params)` for `bind`.

    The 'call_jaxpr' entry is removed from a copy of `params` and wrapped as a
    function that evaluates it with no constants.
    """
    remaining = dict(params)
    jaxpr = remaining.pop('call_jaxpr')
    subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
    if config.dynamic_shapes.value:
      annotation = _jaxpr_type_to_callable_annotation(jaxpr)
      subfun = lu.annotate(subfun, annotation)
    return [subfun], remaining
def call_bind_with_continuation(primitive: CallPrimitive, fun, *args, **params):
  """Prepare a call-primitive bind.

  Returns:
    A tuple `(continuation, top_trace, fun, tracers, params)`: the
    continuation to apply to the processed outputs, the top trace found for
    `args`, the wrapped and annotated function, `args` raised to the top
    trace, and `params` unchanged.
  """
  top_trace = find_top_trace(args)
  params_tuple = tuple(params.items())
  wrapped, env_trace_todo = process_env_traces_call(
      fun, primitive, top_trace.level, params_tuple)
  tracers = map(top_trace.full_raise, args)
  wrapped = lu.annotate(wrapped, fun.in_type)

  def call_bind_continuation(outs):
    # Apply the todos recorded while running `wrapped`, then fully lower.
    processed = apply_todos(env_trace_todo(), outs)
    return map(full_lower, processed)

  return call_bind_continuation, top_trace, wrapped, tracers, params
@lu.transformation_with_aux
def process_env_traces_call(primitive: CallPrimitive, level: int,
                            params_tuple: tuple, *args):
  """Post-process outputs that carry Tracers above `level`.

  While any output is a Tracer whose trace level exceeds `level`, the outputs
  are raised to the highest such trace and passed to its `post_process_call`.
  The todo returned by each step is collected and yielded as auxiliary output.
  """
  outs = yield args, {}
  params = dict(params_tuple)
  pending = []
  while (escaped := [x for x in outs
                     if isinstance(x, Tracer) and x._trace.level > level]):
    highest = max(escaped, key=_get_trace_level)
    trace = highest._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo = trace.post_process_call(primitive, outs, params)
    pending.append(cur_todo)
  yield outs, tuple(pending)  # Ensure the aux output is immutable
def apply_todos(todos, outs):
  """Apply `todos` to `outs` from last to first, fully lowering after each."""
  # Walk a reversed copy so the most recently recorded todo runs first.
  for todo in reversed(list(todos)):
    outs = map(full_lower, todo(outs))
  return outs
def call_impl(f: lu.WrappedFun, *args, **params):
  """Impl rule for call primitives: run `f` on `args` inside a new sublevel."""
  del params  # params parameterize the call primitive, not the function
  with new_sublevel():
    result = f.call_wrapped(*args)
  return result
# The generic call primitive, a `call` shorthand for binding it, and its impl
# rule.
call_p: CallPrimitive = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)
class ClosedCallPrimitive(CallPrimitive):
  """Call primitive whose 'call_jaxpr' param carries its own consts."""

  def get_bind_params(self, params):
    """Turn equation params into `([subfun], remaining_params)` for `bind`.

    The 'call_jaxpr' entry is removed from a copy of `params`; the subfun
    evaluates its `.jaxpr` with its `.consts`.
    """
    remaining = dict(params)
    closed = remaining.pop('call_jaxpr')
    evaluator = partial(eval_jaxpr, closed.jaxpr, closed.consts)
    return [lu.wrap_init(evaluator)], remaining
# The closed-call primitive. It shares `call_impl`, and its abstract-eval rule
# reads the output avals and effects straight off the `call_jaxpr` param.
closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)
closed_call_p.def_effectful_abstract_eval(
lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects))
# Primitives that `primitive_uses_outfeed` reports as using outfeed.
outfeed_primitives: set[Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
  """Finds if there are outfeed primitives anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if primitive_uses_outfeed(eqn.primitive, eqn.params):
      return True
  return False
def _param_uses_outfeed(param):
  """Return True if `param` is a Jaxpr or ClosedJaxpr that uses outfeed."""
  kind = type(param)
  if kind is Jaxpr:
    return jaxpr_uses_outfeed(param)
  if kind is ClosedJaxpr:
    return jaxpr_uses_outfeed(param.jaxpr)
  # Any other parameter type cannot contain equations.
  return False
def primitive_uses_outfeed(prim: Primitive, params: dict) -> bool:
  """Return True if `prim` is an outfeed primitive or a param uses outfeed.

  Tuple-valued params are searched element by element.
  """
  if prim in outfeed_primitives:
    return True

  def uses_outfeed(param):
    if isinstance(param, tuple):
      return any(unsafe_map(_param_uses_outfeed, param))
    return _param_uses_outfeed(param)

  return any(uses_outfeed(param) for param in params.values())
# ------------------- Map -------------------
class MapPrimitive(Primitive):
  """Primitive whose first operand is a function to be mapped."""
  multiple_results = True
  map_primitive = True

  def bind(self, fun, *args, **params):
    """Bind via `map_bind`; `params['in_axes']` must have one entry per arg."""
    assert len(params['in_axes']) == len(args)
    return map_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    """Forward to `trace.process_map`."""
    return trace.process_map(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    """Forward to `trace.post_process_map`."""
    return trace.post_process_map(self, out_tracers, params)

  def get_bind_params(self, params):
    """Turn equation params into `([subfun], remaining_params)` for `bind`.

    'call_jaxpr' becomes a function evaluating it with no constants, and
    'out_axes' is replaced by an 'out_axes_thunk' returning the same value.
    """
    remaining = dict(params)
    jaxpr = remaining.pop('call_jaxpr')
    axes = remaining.pop('out_axes')
    subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
    remaining['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
    return [subfun], remaining
# Prepares a map-primitive bind. Returns the tuple
# (continuation, top_trace, fun, tracers, params), where `params` carries a
# replacement `out_axes_thunk` and `fun` is wrapped by `process_env_traces_map`.
def map_bind_with_continuation(primitive: MapPrimitive, fun, *args,
out_axes_thunk, **params):
# The new thunk depends deterministically on the old thunk and the wrapped
# function. Any caching already has to include the wrapped function as part
# of the key, so we only use the previous thunk for equality checks.
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
out_axes = out_axes_thunk()
# `todo_and_xforms` is bound further down in the enclosing function; it is
# only looked up when this thunk is actually called.
_, out_axes_transforms = todo_and_xforms()
for t in out_axes_transforms:
out_axes = t(out_axes)
return out_axes
params = dict(params, out_axes_thunk=new_out_axes_thunk)
top_trace = find_top_trace(args)
# `top_trace and top_trace.level` passes `top_trace` itself when it is falsy,
# otherwise its level.
fun, todo_and_xforms = process_env_traces_map(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
def map_bind_continuation(outs):
# Apply the recorded todos to the outputs, then fully lower each result.
env_trace_todo, _ = todo_and_xforms()
return map(full_lower, apply_todos(env_trace_todo, outs))
return map_bind_continuation, top_trace, fun, tracers, params
def map_bind(primitive: MapPrimitive, fun, *args, **params):
  """Bind a map primitive: prepare, process on the top trace, then continue."""
  continuation, top_trace, fun, tracers, params = (
      map_bind_with_continuation(primitive, fun, *args, **params))
  processed = primitive.process(top_trace, fun, tracers, params)
  return continuation(processed)
@lu.transformation_with_aux
def process_env_traces_map(primitive: MapPrimitive, level: int,
                           params_tuple: tuple, *args):
  """Post-process outputs that carry Tracers above `level`.

  While any output is a Tracer (above `level`, or any Tracer at all when
  `level` is None), the outputs are raised to the highest such trace and
  passed to `primitive.post_process`. The todo and out-axes transform from
  each step are collected and yielded as auxiliary output.
  """
  outs = yield args, {}
  params = dict(params_tuple)
  pending = []
  axes_xforms = []
  while (escaped := [x for x in outs if isinstance(x, Tracer)
                     and (level is None or x._trace.level > level)]):
    highest = max(escaped, key=_get_trace_level)
    trace = highest._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, (cur_todo, cur_xform) = primitive.post_process(trace, outs, params)
    pending.append(cur_todo)
    axes_xforms.append(cur_xform)
  yield outs, (tuple(pending), tuple(axes_xforms))
def mapped_aval(size: AxisSize, axis: int | None,
                aval: AbstractValue) -> AbstractValue:
  """Apply the mapping handler registered for the exact type of `aval`.

  Raises:
    TypeError: if no handler is registered for `type(aval)`.
  """
  map_handler, _ = aval_mapping_handlers.get(type(aval), (None, None))
  if map_handler is None:
    raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")
  return map_handler(size, axis, aval)
def unmapped_aval(size: AxisSize, axis_name, axis: int | None,
                  aval: AbstractValue) -> AbstractValue:
  """Apply the unmapping handler registered for the exact type of `aval`.

  Raises:
    TypeError: if no handler is registered for `type(aval)`.
  """
  _, unmap_handler = aval_mapping_handlers.get(type(aval), (None, None))
  if unmap_handler is None:
    raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
  return unmap_handler(size, axis_name, axis, aval)
def _map_shaped_array(
    size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
  """Remove the mapped `axis` from `aval`; a None axis leaves it unchanged."""
  # TODO: Extend the named shape
  if axis is None:
    return aval
  assert aval.shape[axis] == size
  reduced_shape = tuple_delete(aval.shape, axis)
  return ShapedArray(reduced_shape, aval.dtype, weak_type=aval.weak_type)
def _unmap_shaped_array(
    size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
) -> ShapedArray:
  """Insert an axis of `size` at position `axis`; a None axis is a no-op.

  Raises:
    TypeError: if `axis` is neither None nor exactly an `int`.
  """
  if axis is None:
    return aval
  if type(axis) is not int:
    raise TypeError(axis)
  grown_shape = tuple_insert(aval.shape, axis, size)
  return ShapedArray(grown_shape, aval.dtype, weak_type=aval.weak_type)
def _map_dshaped_array(
    size: AxisSize, axis: int | None, aval: DShapedArray) -> DShapedArray:
  """Remove the mapped `axis` from `aval`; a None axis leaves it unchanged."""
  if axis is None:
    return aval
  reduced_shape = tuple_delete(aval.shape, axis)
  return DShapedArray(reduced_shape, aval.dtype, aval.weak_type)
def _unmap_dshaped_array(
    size: AxisSize, axis_name: AxisName, axis: int | None, aval: DShapedArray
) -> DShapedArray:
  """Insert an axis of `size` at position `axis`; a None axis is a no-op.

  Raises:
    TypeError: if `axis` is neither None nor exactly an `int`.
  """
  if axis is None:
    return aval
  if type(axis) is not int:
    raise TypeError(axis)
  grown_shape = tuple_insert(aval.shape, axis, size)
  return DShapedArray(grown_shape, aval.dtype, weak_type=aval.weak_type)
# (map handler, unmap handler) pairs keyed by exact aval type; looked up by
# `mapped_aval` and `unmapped_aval`. Tokens are returned unchanged both ways.
AvalMapHandlerPair = tuple[Callable, Callable]
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
}
@contextmanager
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
  """Context manager pushing one axis frame onto the thread-local axis env.

  The frame is popped on exit. The thread-local jit state is refreshed after
  both the push and the pop.
  """
  frame = AxisEnvFrame(axis_name, size, tag)
  ts = thread_local_state.trace_state

  def sync_jit_state():
    # Frames named `no_axis_name` are left out of the jit state.
    config.update_thread_local_jit_state(
        axis_env_state=tuple(f for f in ts.axis_env
                             if f.name is not no_axis_name))

  ts.axis_env.append(frame)
  sync_jit_state()
  try:
    yield
  finally:
    ts.axis_env.pop()
    sync_jit_state()
@contextmanager
def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None):
  """Context manager pushing one axis frame per `(name, size)` pair in `axes`.

  All pushed frames are popped on exit. The thread-local jit state is
  refreshed after both the push and the pop.
  """
  frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
  ts = thread_local_state.trace_state

  def sync_jit_state():
    # Frames named `no_axis_name` are left out of the jit state.
    config.update_thread_local_jit_state(
        axis_env_state=tuple(f for f in ts.axis_env
                             if f.name is not no_axis_name))

  ts.axis_env.extend(frames)
  sync_jit_state()
  try:
    yield
  finally:
    for _ in frames:
      ts.axis_env.pop()
    sync_jit_state()
@contextmanager
def stash_axis_env():
  """Promise that the wrapped code does not depend implicitly on the axis env.

  The axis env is replaced by an empty list for the duration of the block and
  restored on exit.
  """
  # If the promise is broken, then a NameError about an unbound axis name will
  # be raised.
  ts = thread_local_state.trace_state
  saved_env = ts.axis_env
  ts.axis_env = []
  config.update_thread_local_jit_state(axis_env_state=())
  try:
    yield
  finally:
    ts.axis_env = saved_env
    visible_frames = tuple(f for f in ts.axis_env
                           if f.name is not no_axis_name)
    config.update_thread_local_jit_state(axis_env_state=visible_frames)
# When a mapped function is given no axis name, we generate a name object based
# on the id of the function object. Collisions aren't important because this
# name can't be used in collectives, as user code never gets a ref to this
# object. We don't want to use the function object itself because that might
# persist references to the function object.
# TODO(mattjj): revisit this unique axis name strategy
@total_ordering
class _TempAxisName:
def __init__(self, obj):
self.id = id(obj)
def __repr__(self):
return f'<axis {hex(self.id)}>'
def __hash__(self):
return hash(self.id)
def __eq__(self, other):
return type(other) is _TempAxisName and self.id == other.id
def __lt__(self, other):
return type(other) is _TempAxisName and self.id < other.id
def axis_frame(axis_name: AxisName, main_trace: MainTrace | None = None
               ) -> AxisEnvFrame:
  """Return the innermost axis-env frame named `axis_name`.

  When `main_trace` is given, the frame's `main_trace` must also be that
  exact object.

  Raises:
    NameError: if no frame matches. The message lists the available axis
      names, leaving out temporary ones.
  """
  frames = thread_local_state.trace_state.axis_env

  def matches(frame):
    return (frame.name == axis_name and
            (main_trace is None or frame.main_trace is main_trace))

  # Search from the most recently pushed frame outwards.
  found = next((frame for frame in reversed(frames) if matches(frame)), None)
  if found is not None:
    return found
  named_axes = [frame.name for frame in reversed(frames)
                if not isinstance(frame.name, _TempAxisName)]
  raise NameError(
      f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
      f'by pmap) are available to collective operations: {named_axes}')
@dataclass(frozen=True)
class NamedAxisEffect(effects.Effect):
  """A side-effect introducing a new named axis into the current scope."""
  # Name of the axis this effect refers to.
  name: AxisName
# Register NamedAxisEffect with the control-flow, custom-derivatives,
# lowerable and remat effect sets.
effects.control_flow_allowed_effects.add_type(NamedAxisEffect)
effects.custom_derivatives_allowed_effects.add_type(NamedAxisEffect)
effects.lowerable_effects.add_type(NamedAxisEffect)
effects.remat_allowed_effects.add_type(NamedAxisEffect)
def filter_named_axis_effects(
    effects: Effects, names: Collection[AxisName]
) -> Effects:
  """Return `effects` without the NamedAxisEffects whose name is in `names`."""
  kept = set()
  for e in effects:
    if isinstance(e, NamedAxisEffect) and e.name in names:
      continue
    kept.add(e)
  return kept
def remove_named_axis_effects(
    jaxpr: Jaxpr, names: Collection[AxisName]
) -> Jaxpr:
  """Return `jaxpr` with the named-axis effects for `names` removed.

  The input jaxpr itself is returned when `names` is empty or the jaxpr has
  no effects.
  """
  if names and jaxpr.effects:
    filtered = filter_named_axis_effects(jaxpr.effects, names)
    return jaxpr.replace(effects=filtered)
  return jaxpr
# Equation parameters, and the type of an axis-name substitution: a function
# from one axis name to a tuple of axis names.
ParamDict = dict[str, Any]
AxisSubst = Callable[[AxisName], tuple[AxisName, ...]]
class NameGatheringSubst:
  """Identity axis substitution that records every name it is called with."""

  def __init__(self):
    self.axis_names = set()

  def __call__(self, axis_name):
    # Record the name, then map it to itself.
    self.axis_names.add(axis_name)
    unchanged = (axis_name,)
    return unchanged
def used_axis_names(primitive: Primitive, params: ParamDict) -> set[AxisName]:
  """Return the axis names visited when substituting over `params`."""
  gatherer = NameGatheringSubst()
  # The substituted params are discarded; only the recorded names matter.
  subst_axis_names(primitive, params, gatherer)
  return gatherer.axis_names
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst, traverse: bool = True) -> ParamDict:
  """Apply the axis-name substitution `subst` to an equation's params.

  A rule registered in `axis_substitution_rules` takes precedence. Otherwise,
  when `traverse` is set, every Jaxpr / ClosedJaxpr param is rewritten. For a
  MapPrimitive, the primitive's own 'axis_name' is mapped to itself rather
  than through `subst`.

  Returns:
    `params` itself when nothing needs rewriting, else an updated copy.
  """
  rule = axis_substitution_rules.get(primitive)
  if rule is not None:
    return rule(params, subst, traverse)
  if not traverse:
    return params
  # Default implementation: substitute names in all jaxpr parameters
  if isinstance(primitive, MapPrimitive):
    def shadowed_subst(name):
      return (name,) if name == params['axis_name'] else subst(name)
  else:
    shadowed_subst = subst
  jaxpr_params = [(key, val) for key, val in params.items()
                  if isinstance(val, (Jaxpr, ClosedJaxpr))]
  if not jaxpr_params:
    return params
  new_params = dict(params)
  new_params.update({key: subst_axis_names_jaxpr(jaxpr, shadowed_subst)
                     for key, jaxpr in jaxpr_params})
  return new_params
class DuplicateAxisNameError(Exception):
  """Error carrying the offending `var` and, once known, the equation `eqn`."""

  def __init__(self, var):
    # `eqn` starts as None; `subst_axis_names_eqn` sets it when it catches
    # this error.
    self.var, self.eqn = var, None
def subst_axis_names_effects(effects: Set[Effect], subst: AxisSubst) -> Set[Effect]:
  """Rewrite NamedAxisEffects through `subst`; other effects pass through.

  Each NamedAxisEffect is replaced by one effect per name in
  `subst(effect.name)`.
  """
  new_effects = set[Effect]()
  for e in effects:
    if not isinstance(e, NamedAxisEffect):
      new_effects.add(e)
      continue
    new_effects.update(NamedAxisEffect(name) for name in subst(e.name))
  return new_effects
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var:
  """Record `v` in `var_map` (mapped to itself) and return it unchanged.

  DropVars are returned without being recorded. `subst` is not used here.
  """
  # Var identity is load-bearing, so we can't have duplicates!
  if not isinstance(v, DropVar):
    assert v not in var_map
    var_map[v] = v
  return v
def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn:
  """Return `eqn` with axis names substituted in its params and effects.

  Input variables are looked up in `var_map` (Literals are kept as is) and
  output variables are recorded in it. A DuplicateAxisNameError raised while
  handling the outputs is tagged with `eqn` and re-raised.
  """
  invars: list[Atom] = []
  for v in eqn.invars:
    invars.append(v if isinstance(v, Literal) else var_map[v])
  try:
    outvars = [subst_axis_names_var(v, subst, var_map) for v in eqn.outvars]
  except DuplicateAxisNameError as err:
    err.eqn = eqn
    raise
  new_params = subst_axis_names(eqn.primitive, eqn.params, subst)
  new_effects = subst_axis_names_effects(eqn.effects, subst)
  return eqn.replace(invars=invars, outvars=outvars, params=new_params,
                     effects=new_effects)
def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
  """Rebuild `jaxpr` with `subst` applied to every equation and its effects.

  A ClosedJaxpr is unwrapped first and re-wrapped with its original consts.
  """
  if isinstance(jaxpr, ClosedJaxpr):
    consts, jaxpr = jaxpr.consts, jaxpr.jaxpr
  else:
    consts = None
  var_map: dict[Var, Var] = {}
  # Binders are registered in the same order: invars, constvars, then the
  # outputs of each equation.
  invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars]  # type: ignore[union-attr]
  constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars]  # type: ignore[union-attr]
  eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns]
  outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars]  # type: ignore[union-attr]
  new_effects = subst_axis_names_effects(jaxpr.effects, subst)
  rebuilt = Jaxpr(constvars, invars, outvars, eqns, new_effects)
  if consts is None:
    return rebuilt
  return ClosedJaxpr(rebuilt, consts)
def used_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr):
  """Return the names of all NamedAxisEffects among `jaxpr.effects`."""
  names = set()
  for e in jaxpr.effects:
    if isinstance(e, NamedAxisEffect):
      names.add(e.name)
  return names
def subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
  """Apply `subst` to `jaxpr`, with a shortcut for name gathering.

  A NameGatheringSubst only needs the names, so they are read from the
  jaxpr's effects and the jaxpr is returned untouched.
  """
  if not isinstance(subst, NameGatheringSubst):
    return do_subst_axis_names_jaxpr(jaxpr, subst)
  # This is a common case, so we optimize it!
  subst.axis_names |= used_axis_names_jaxpr(jaxpr)
  return jaxpr
def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects):
  """Return `jaxpr` with its effects replaced by `effects`.

  The effects are frozen before being passed to the cached helper.
  """
  frozen = frozenset(effects)
  return _replace_jaxpr_effects(jaxpr, frozen)
@weakref_lru_cache
def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: frozenset[Effect]):
  """Cached helper: swap the inner jaxpr's effects for `set(effects)`."""
  inner = jaxpr.jaxpr.replace(effects=set(effects))
  return jaxpr.replace(jaxpr=inner)
# Per-primitive overrides consulted first by `subst_axis_names`; each rule is
# called as rule(params, subst, traverse) and returns the new params.
axis_substitution_rules: dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
# ------------------- AxisPrimitive -------------------
# Primitives that store axis names in params and want those axis names to
# participate in dispatch should subclass AxisPrimitive.
class AxisPrimitive(Primitive):
  """Primitive whose bind also considers the traces of the axes it names."""

  def bind(self, *args, **params):
    """Bind on the top trace of `args` or of a used axis, whichever is higher."""
    top_trace = find_top_trace(args)
    axis_mains = (axis_frame(a).main_trace
                  for a in used_axis_names(self, params))
    # Main traces without a `level` attribute (e.g. None) rank lowest.
    axis_main = max(axis_mains, default=None,
                    key=lambda t: getattr(t, 'level', -1))
    if axis_main and not (axis_main.level < top_trace.level):
      top_trace = axis_main.with_cur_sublevel()
    return self.bind_with_trace(top_trace, args, params)
# ------------------- Jaxpr checking -------------------
def typecheck(aval: AbstractValue, x) -> bool:
  """Return whether the aval of the value `x` conforms to `aval`."""
  value_aval = get_aval(x)
  return typecompat(aval, value_aval)
def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
  """Determine whether `aval` conforms to `aval_ref`. Ignores weak_type.

  A TypeError raised while joining or matching is reported as False.
  """
  try:
    joined = lattice_join(aval_ref, aval)
    return typematch(aval_ref, joined)
  except TypeError:
    return False
def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
  """Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type."""
  if aval1 == aval2:
    return True
  # unequal avals may still represent the same type, because type is represented
  # by avals at the shaped level, and because weak type tags aren't considered
  # part of the type
  shaped1 = raise_to_shaped(aval1, weak_type=False)
  shaped2 = raise_to_shaped(aval2, weak_type=False)
  return shaped1 == shaped2
class JaxprTypeError(TypeError):
  """TypeError raised by the jaxpr checker when a jaxpr is found invalid."""
# Per-primitive typecheck rules; `_check_jaxpr` consults this first and calls
# a rule as rule(ctx_factory, *in_atoms, **eqn.params).
custom_typechecks: dict[Primitive, Callable] = {}
def _check_closed_call(_, *in_atoms, call_jaxpr):
  """Typecheck rule for closed calls.

  Returns:
    `(call_jaxpr.out_avals, call_jaxpr.effects)`.

  Raises:
    JaxprTypeError: if some operand's aval is not compatible with the
      corresponding input aval of `call_jaxpr`.
  """
  operand_avals = [atom.aval for atom in in_atoms]
  compatible = all(map(typecompat, call_jaxpr.in_avals, operand_avals))
  if not compatible:
    raise JaxprTypeError("Closed call in_avals mismatch")
  return call_jaxpr.out_avals, call_jaxpr.effects
# Use the dedicated rule above when checking closed_call equations.
custom_typechecks[closed_call_p] = _check_closed_call
def check_jaxpr(jaxpr: Jaxpr):
  """Checks well-formedness of a jaxpr.

  Specifically, check that:
  - variables that are read are bound beforehand
  - variables are typed equally throughout a jaxpr
  - variable type annotations are compatible with their binding expression

  Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
  otherwise.
  """
  @functools.cache
  def ctx_factory():
    # Built lazily and at most once: the context is only needed for errors.
    ctx = JaxprPpContext()
    pp_settings = JaxprPpSettings()
    try:
      pp_jaxpr(jaxpr, ctx, pp_settings)  # side-effect on ctx, build variable names
    except Exception:
      # Best effort: a pretty-printing failure must not mask the type error
      # being reported. Catch only ordinary errors so KeyboardInterrupt and
      # SystemExit still propagate.
      pass
    return ctx, pp_settings

  try:
    _check_jaxpr(ctx_factory, jaxpr)
  except JaxprTypeError as e:
    ctx, pp_settings = ctx_factory()
    if len(e.args) == 2:
      # The error carries the index of the failing equation: print the
      # equations around it.
      msg, eqnidx = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx,
                                         pp_settings))
    else:
      # No equation index: print the leading equations.
      msg, = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, ctx, pp_settings))
    msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
    raise JaxprTypeError(msg) from None

  # Run key reuse checker after validating jaxpr:
  if config.debug_key_reuse.value:
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import check_key_reuse_jaxpr  # pytype: disable=import-error
    check_key_reuse_jaxpr(jaxpr)
# Core jaxpr checker. Walks the binders and equations of `jaxpr`, checking
# scoping, type annotations and effect annotations. Raises JaxprTypeError on
# the first problem; errors raised while checking an equation are re-raised
# with the equation's index as a second argument.
# `ctx_factory` is only called when an error message needs to be built.
def _check_jaxpr(
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
jaxpr: Jaxpr
) -> None:
# Use set of variables to types to check that variables are in scope.
env: set[Var] = set()
# Validates one atom that is being read and returns it unchanged.
def read(x: Atom) -> Atom:
# Check the type annotation is itself well-typed.
check_type(ctx_factory, env, x.aval)
if isinstance(x, Var):
# Check the variable is in-scope and consistently typed.
if x not in env:
ctx, _ = ctx_factory()
raise JaxprTypeError(f"Variable '{pp_var(x, ctx)}' not defined")
return x
elif isinstance(x, Literal):
# Check that the literal matches its type annotation.
if not typecheck(x.aval, x.val):
ctx, _ = ctx_factory()
raise JaxprTypeError(
f"Literal value {x.val} does not match its type annotation "
f"{pp_aval(x.aval, ctx)}")
return x
else:
assert False, "syntactically invalid jaxpr"
# Validates the binder `v` against the computed type `a` and brings it into
# scope.
def write(v: Var, a: AbstractValue) -> None:
assert isinstance(v, Var), "syntactically invalid jaxpr"
# Check the type annotation of the binder is itself well-typed.
check_type(ctx_factory, env, v.aval)
# Check that the variable is not already bound.
if v in env:
ctx, _ = ctx_factory()
raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound")
# Check that the computed type is consistent with the binder annotation.
if not typematch(v.aval, a):
ctx, _ = ctx_factory()
raise JaxprTypeError(
f"Value for variable '{pp_var(v, ctx)}' inconsistently typed "
f"as {pp_aval(a, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}")
# If the variable is not a DropVar, add it to the environment.
if not isinstance(v, DropVar):
env.add(v)
# Check type annotations on lambda binders.
for v in it.chain(jaxpr.constvars, jaxpr.invars):
check_type(ctx_factory, env, v.aval)
write(v, v.aval)
# Check each eqn.
# `in_idx` maps each constvar/invar to its position among the jaxpr inputs;
# `sentinel` marks "not found", since None is itself stored as a value below.
sentinel = object()
in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
for eqn_idx, eqn in enumerate(jaxpr.eqns):
prim = eqn.primitive
try:
in_atoms = map(read, eqn.invars)
in_avals = [x.aval for x in in_atoms] # use in_atoms for dyn shapes
# Compute the type of the primitive application.
if prim in custom_typechecks:
out_type, eqn_effects = custom_typechecks[prim](
ctx_factory, *in_atoms, **eqn.params)
elif prim.call_primitive:
out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
eqn.params)
elif prim.map_primitive:
out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
eqn.params)
else:
out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)
# Check the computed effect type matches the eqn's annotation, and is
# included in the jaxpr's annotation.
if prim is mutable_array_p:
# The output of mutable_array is recorded in `in_idx` with index None, so
# the sentinel lookup below accepts it.
outvar, = eqn.outvars
in_idx[outvar] = None # type: ignore
if eqn.effects != eqn_effects:
raise JaxprTypeError("Inferred effects do not match equation effects. "
f"Equation effects: {eqn.effects}. "
f"Inferred effects: {eqn_effects}")
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
# Re-index the effect from the equation's operand position to the
# jaxpr-level input position before looking it up.
eqn_invar = eqn.invars[eff.input_index]
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
jaxpr_effect = eff.replace(input_index=jaxpr_index)
if jaxpr_effect not in jaxpr.effects:
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must be present in jaxpr. "
f"{jaxpr_effect} is not in {jaxpr.effects}.")
elif isinstance(eff, NamedAxisEffect):
# It is valid for a primitive to discharge the named axis effect.
continue
elif eff not in jaxpr.effects:
raise JaxprTypeError("Equation effect not present in jaxpr effects. "
f"Equation effect: {eff}. "
f"Jaxpr effects: {jaxpr.effects}")
# Check out_type matches the let-binders' annotation (after substitution).
out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars)
# NOTE(review): this relies on `map` being eager here (presumably the module's
# safe_map alias), otherwise `write` would never run; confirm.
map(write, eqn.outvars, out_type)
except JaxprTypeError as e:
# Extend the message with the printed equation and its source location, and
# attach the equation index as a second argument.
ctx, settings = ctx_factory()
msg, = e.args
src = source_info_util.summarize(eqn.source_info)
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx, settings))),
f"from source: {src}"])
raise JaxprTypeError(msg, eqn_idx) from None
# TODO(mattjj): include output type annotation on jaxpr and check it here
map(read, jaxpr.outvars)
# Checks that the type `ty` is well-formed with respect to the variables in
# `env`. Only DShapedArray types are inspected; each entry of their shape must
# be an int, a scalar bint-typed DArray, or an in-scope Var with a valid
# scalar size type. Raises JaxprTypeError otherwise.
def check_type(
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
env: set[Var],
ty: AbstractValue,
) -> None:
if isinstance(ty, DShapedArray):
# Check all elements in the shape tuple are well-typed.
for d in ty.shape:
if (isinstance(d, int) or
isinstance(d, DArray) and not d.shape and type(d.dtype) == bint):
continue
elif isinstance(d, Var):
if d not in env:
ctx, _ = ctx_factory()
raise JaxprTypeError(f"unbound axis size: '{pp_var(d, ctx)}'")
if not isinstance(d.aval, (ShapedArray, DShapedArray)):
raise JaxprTypeError(f"axis size with unexpected type annotation: "
f"{d.aval} of type {type(d.aval)}")
# A ShapedArray size must be an integer scalar; a DShapedArray size must be
# a bint scalar.
if isinstance(d.aval, ShapedArray):
shape, dtype = d.aval.shape, d.aval.dtype
if shape: raise JaxprTypeError(f"axis size nonscalar: {d.aval}")
if not dtypes.issubdtype(dtype, np.integer):
raise JaxprTypeError(f"axis size with non-integer dtype: {d.aval}")
else:
assert isinstance(d.aval, DShapedArray)
shape, dtype = d.aval.shape, d.aval.dtype
if shape: raise JaxprTypeError(f"axis size nonscalar: {d.aval}")
if type(dtype) is not bint:
raise JaxprTypeError(
f"DArray axis size with non-bint dtype: {d.aval}")
else:
raise JaxprTypeError(f"unexpected type in shape: {type(d)}")
else:
return # Except in above case(s), all syntactic forms are valid
def substitute_vars_in_output_ty(
    out_type: Sequence[AbstractValue],  # shapes may contain InDBIdx / OutDBIdx
    in_atoms: Sequence[Atom],
    out_binders: Sequence[Var],
) -> list[AbstractValue]:  # shapes may contain Vars
  """Replace InDBIdx / OutDBIdx entries in output shapes with atoms/binders.

  Only avals whose exact type is DShapedArray are rewritten. An InDBIdx entry
  becomes the indexed input atom (the value of a Literal), an OutDBIdx entry
  becomes the indexed output binder, and every other entry is kept.
  """
  inputs = [x.val if type(x) is Literal else x for x in in_atoms]

  def resolve(d):
    if type(d) is InDBIdx:
      return inputs[d.val]
    if type(d) is OutDBIdx:
      return out_binders[d.val]
    return d

  result = []
  for aval in out_type:
    if type(aval) is DShapedArray:
      new_shape = [resolve(d) for d in aval.shape]
      aval = aval.update(shape=tuple(new_shape))
    result.append(aval)
  return result
def check_eqn(prim, in_avals, params):
  """Check the jaxprs in `params`, then abstractly evaluate `prim`.

  Returns:
    `(out_avals, effects)`, with `out_avals` always a sequence: a single
    result is wrapped in a one-element list.
  """
  for sub_jaxpr in jaxprs_in_params(params):
    check_jaxpr(sub_jaxpr)
  result, eqn_effects = prim.abstract_eval(*in_avals, **params)
  out_avals = result if prim.multiple_results else [result]
  return out_avals, eqn_effects
# Typechecks the application of call primitive `prim` to `in_atoms`: checks
# arity and operand types against `params["call_jaxpr"]`, checks the callee
# itself, and returns (out_type, call_jaxpr.effects). Vars in the output
# shapes are rewritten to InDBIdx / OutDBIdx references.
def _check_call(ctx_factory, prim, in_atoms, params):
if "call_jaxpr" not in params:
raise JaxprTypeError(
f"Call primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
if len(in_atoms) != len(call_jaxpr.invars):
raise JaxprTypeError(f"Call primitive {prim} with {len(in_atoms)} "
f"operands cannot call jaxpr with "
f"{len(call_jaxpr.invars)} inputs")
# Check `call_jaxpr` can be applied to in_atoms.
# `env` maps each callee invar seen so far to the operand passed for it, so
# later invar types that mention earlier invars in their shapes can be
# rewritten in terms of the caller's atoms.
env: dict[Var, Atom] = {}
def substitute(aval: AbstractValue):
if isinstance(aval, DShapedArray):
aval = aval.update(shape=tuple(env.get(d, d) for d in aval.shape)) # type: ignore
return aval
for v, x in zip(call_jaxpr.invars, in_atoms):
if not typecompat(substitute(v.aval), x.aval):
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
f"{substitute(v.aval)}")
env[v] = x if type(x) is Var else x.val
_check_jaxpr(ctx_factory, call_jaxpr)
invars, outvars = call_jaxpr.invars, call_jaxpr.outvars
# Positions of the callee's inputs and (Var) outputs, used to express output
# shapes by index. A Var found in neither map becomes None.
in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
if type(x) is Var}
out_avals = [x.aval for x in call_jaxpr.outvars]
out_type = [a.update(shape=tuple(in_map.get(d, out_map.get(d))
if type(d) is Var else d for d in a.shape))
if type(a) is DShapedArray else a for a in out_avals]
return out_type, call_jaxpr.effects
def _check_map(ctx_factory, prim, in_avals, params):
  """Type-checks an application of a map primitive and computes its output type.

  Args:
    ctx_factory: passed through unchanged to `_check_jaxpr` for the mapped body.
    prim: the map primitive being applied; used in error messages.
    in_avals: abstract values of the operands.
    params: equation parameters; must contain 'call_jaxpr', 'axis_size',
      'axis_name', 'in_axes' and 'out_axes'.

  Returns:
    A pair `(out_avals, effects)`. `out_avals` are the avals of the body's
    outvars, passed through `unmapped_aval` wherever the corresponding entry
    of `out_axes` is not None. `effects` are the body's effects after
    `filter_named_axis_effects` is applied for `axis_name`.

  Raises:
    JaxprTypeError: if a required parameter is missing, the body has ordered
      effects, or an operand's type is incompatible with its binder.
  """
  if "call_jaxpr" not in params:
    raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]

  ordered_effects_ = effects.ordered_effects.filter_in(call_jaxpr.effects)
  if ordered_effects_:
    raise JaxprTypeError(
        f"Map primitive {prim} mapping ordered effects: {ordered_effects_}")

  # Checked in this fixed order so the first missing parameter is the one
  # reported.
  for required in ("axis_size", "axis_name", "in_axes", "out_axes"):
    if required not in params:
      raise JaxprTypeError(
          f"Map primitive {prim} missing '{required}' parameter")
  axis_size = params["axis_size"]
  axis_name = params["axis_name"]
  in_axes = params["in_axes"]
  out_axes = params["out_axes"]

  # The body's binders carry per-element types; compare them to the operands
  # after passing them through `unmapped_aval` for each mapped input.
  binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
                  if in_axis is not None else v.aval
                  for v, in_axis in zip(call_jaxpr.invars, in_axes)]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    if not typecompat(binder_aval, in_aval):
      raise JaxprTypeError(f"Map primitive {prim} passes operand {in_aval} "
                           f"to jaxpr expecting {binder_aval}")

  with extend_axis_env(axis_name, axis_size, None):
    _check_jaxpr(ctx_factory, call_jaxpr)

  mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
  out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval)
               if out_axis is not None else aval
               for aval, out_axis in zip(mapped_out_avals, out_axes)]
  return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})
# ------------------- Jaxpr printed representation -------------------
def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True,
                      custom_pp_eqn_rules=True, name_stack=False,
                      print_effects: bool = False) -> pp.Doc:
  """Pretty-prints `jaxpr_to_print` with repeated sub-jaxprs hoisted out.

  Every jaxpr reached more than once while walking the equations' params is
  emitted first as a named `let <name> = ... in` binding, and later occurrences
  are printed by name. The keyword arguments are forwarded into a
  `JaxprPpSettings`.
  """
  context = JaxprPpContext()
  settings = JaxprPpSettings(
      print_shapes=print_shapes,
      source_info=source_info,
      name_stack=name_stack,
      custom_pp_eqn_rules=custom_pp_eqn_rules,
      print_effects=print_effects)

  # Compute how many times each jaxpr is used, remembering for each sub-jaxpr
  # the first `name` param of an equation that referenced it.
  names = defaultdict[Jaxpr, str](lambda: "jaxpr")
  jaxpr_counts = Counter[Jaxpr]()
  worklist = deque([jaxpr_to_print])
  while worklist:
    current = worklist.popleft()
    jaxpr_counts[current] += 1
    for eqn in current.eqns:
      # TODO(slebedev): Come up with a more elaborate heuristic for name=.
      raw_name = eqn.params.get("name")
      if raw_name is not None:
        label = raw_name.strip("<>")  # <lambda> -> lambda
        for subjaxpr in jaxprs_in_params(eqn.params):
          worklist.append(subjaxpr)
          names.setdefault(subjaxpr, label)
      else:
        worklist.extend(jaxprs_in_params(eqn.params))

  # Pull jaxprs occurring more than once to the top-level, making sure
  # that their names are unique.
  docs = []
  name_counts = Counter[str]()
  for shared, uses in jaxpr_counts.items():
    if uses == 1:
      continue
    name = names[shared]
    prior = name_counts[name]
    name_counts[name] += 1
    if prior > 0:
      # Base name already taken: append the number of earlier uses and also
      # record the suffixed name as used.
      name += str(prior)
      name_counts[name] += 1
    docs.append(pp_top_level_jaxpr(name, shared, context, settings))
    context.used_names.add(name)
    context.top_level_jaxprs[shared] = name

  docs.append(pp_jaxpr(jaxpr_to_print, context, settings))
  return pp.concat(docs)
class JaxprPpSettings(NamedTuple):
  """Flags controlling how jaxprs are pretty-printed."""
  # Forwarded as `print_shapes` to `pp_vars` for equation outputs and for the
  # jaxpr's constvars/invars, which then get a ":<aval>" type annotation.
  print_shapes: bool = True
  # When set, `_pp_eqn` annotates each equation with a summary of its
  # source info.
  source_info: bool = False
  # When set, `_pp_eqn` annotates the primitive name with the equation's
  # name stack.
  name_stack: bool = False
  # When set, `pp_eqn` consults `pp_eqn_rules` for a per-primitive printer
  # before falling back to `_pp_eqn`.
  custom_pp_eqn_rules: bool = True
  # When set, `pp_jaxpr_skeleton` appends the jaxpr's effects after the outputs.
  print_effects: bool = False
def _encode_digits_alphabetic(n: int) -> str:
  """Encodes an index as a lowercase alphabetic string.

  The encoding is base 26 with 'a' as the zero digit, so 0 -> 'a', 25 -> 'z',
  26 -> 'ba'. The special value -1 is encoded as '*'.

  Raises:
    ValueError: if `n` is less than -1. Floor division never brings such a
      value to zero, so the digit loop would otherwise never terminate.
  """
  if n == -1:
    return '*'
  if n < 0:
    raise ValueError(f"cannot alphabetically encode negative index: {n}")
  s = ''
  # Always emit at least one digit so that 0 encodes as 'a' rather than ''.
  while not s or n:
    n, digit = divmod(n, 26)
    s = chr(ord('a') + digit) + s
  return s
# A JaxprPpContext allows us to globally uniquify variable names within nested
# Jaxprs.
class JaxprPpContext:
  """Naming state shared across one pretty-printing pass.

  `var_names` lazily assigns each variable the next alphabetic name that is not
  in `used_names` at the time of assignment; `top_level_jaxprs` maps jaxprs
  that are printed by name to that name.
  """
  var_names: defaultdict[Var, str]
  used_names: MutableSet[str]
  top_level_jaxprs: MutableMapping[Jaxpr, str]

  def __init__(self) -> None:
    self.top_level_jaxprs = {}
    self.used_names = set()

    def fresh_names() -> Iterator[str]:
      # `used_names` is consulted at the moment each name is requested, so
      # names reserved after construction are skipped as well.
      for i in it.count():
        candidate = _encode_digits_alphabetic(i)
        if candidate not in self.used_names:
          yield candidate

    self.var_names = defaultdict(fresh_names().__next__)
def pp_var(v: Var | Literal, context: JaxprPpContext) -> str:
  """Returns the printed name of `v`.

  `Literal`s and `DropVar`s are rendered with `str`; any other variable gets
  the name assigned to it by `context.var_names` followed by its suffix.
  """
  if not isinstance(v, (Literal, DropVar)):
    base_name = context.var_names[v]
    return f"{base_name}{v.suffix}"
  return str(v)
def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str:
  """Returns a short string for the abstract value `a`.

  A `DShapedArray` is printed as `dtype[d0,d1,...]` with `Var` dimensions
  rendered through `pp_var`; every other aval uses its own `str_short`.
  """
  if not isinstance(a, DShapedArray):
    return a.str_short(short_dtypes=True)
  dims = [pp_var(d, context) if type(d) is Var else str(d) for d in a.shape]
  dtype_name = _short_dtype_name(a.dtype)
  return f'{dtype_name}[{",".join(dims)}]'
def pp_vars(vs: Sequence[Any], context: JaxprPpContext,
            *, separator="", print_shapes: bool = False) -> pp.Doc:
  """Renders a sequence of variables as one nested, breakable group.

  Entries are joined by `separator` followed by a break. With `print_shapes`
  each entry is followed by a ":<aval>" type annotation.
  """
  def render(v) -> pp.Doc:
    name_doc = pp.text(pp_var(v, context))
    if not print_shapes:
      return name_doc
    return name_doc + pp.type_annotation(
        pp.text(":" + pp_aval(v.aval, context)))

  joiner = pp.text(separator) + pp.group(pp.brk())
  entries = [render(v) for v in vs]
  return pp.nest(2, pp.group(pp.join(joiner, entries)))
def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
  """Renders one equation parameter as `k=<value>`.

  A tuple consisting only of jaxprs / closed jaxprs is printed via `pp_jaxprs`,
  a single jaxpr or closed jaxpr via `pp_jaxpr`, and anything else via `str`.
  """
  def value_doc() -> pp.Doc:
    if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
      return pp_jaxprs(v, context, settings)
    if isinstance(v, Jaxpr):
      return pp_jaxpr(v, context, settings)
    if isinstance(v, ClosedJaxpr):
      return pp_jaxpr(v.jaxpr, context, settings)
    return pp.text(str(v))

  rendered_value = value_doc()
  return pp.text(f'{k}=') + rendered_value
def pp_kv_pairs(kv_pairs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
  """Renders equation parameters as a bracketed, breakable list.

  Returns the empty doc when there are no pairs.
  """
  if not kv_pairs:
    return pp.nil()
  entries = [pp_kv_pair(key, val, context, settings) for key, val in kv_pairs]
  opened = pp.nest(2, pp.concat([
      pp.text("["), pp.brk(""),
      pp.join(pp.brk(), entries),
  ]))
  return pp.group(opened + pp.brk("") + pp.text("]"))
def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings
           ) -> pp.Doc:
  """Pretty-prints one equation.

  Uses the primitive's entry in `pp_eqn_rules` when custom rules are enabled
  and one exists, otherwise `_pp_eqn`. If the equation's source info has a
  user frame, the resulting doc is wrapped in a source map for that frame.
  """
  if settings.custom_pp_eqn_rules:
    rule = pp_eqn_rules.get(eqn.primitive, _pp_eqn)
  else:
    rule = _pp_eqn
  doc = rule(eqn, context, settings)  # type: ignore[operator]
  frame = source_info_util.user_frame(eqn.source_info)
  if frame is None:
    return doc
  return pp.source_map(doc, frame)
def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc:
  """Default equation printer: `outs = prim[params] ins`.

  `params` optionally names the parameters to show; when None, all of the
  equation's parameter names are shown in sorted order. The `outs = ` prefix is
  omitted when the rendered outputs are empty.
  """
  if settings.source_info:
    eq_annotation = source_info_util.summarize(eqn.source_info)
  else:
    eq_annotation = None
  param_names = sorted(eqn.params) if params is None else params
  if settings.name_stack:
    prim_annotation = f'[{eqn.source_info.name_stack}]'
  else:
    prim_annotation = None

  # Render outputs, then params, then inputs: names in `context` are handed
  # out on first use, so this order determines which variable gets which name.
  lhs = pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
  prim_doc = pp.text(eqn.primitive.name, annotation=prim_annotation)
  params_doc = pp_kv_pairs([(p, eqn.params[p]) for p in param_names],
                           context, settings)
  operands_doc = pp.text(" ") + pp_vars(eqn.invars, context)
  rhs = [prim_doc, params_doc, operands_doc]

  if not lhs.format():
    return pp.concat(rhs)
  return pp.concat([lhs, pp.text(" = ", annotation=eq_annotation), *rhs])
# Signature of a custom equation printer: it receives the equation, the shared
# naming context and the printing settings, and returns the rendered doc.
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext, JaxprPpSettings], pp.Doc]
# Per-primitive printers consulted by `pp_eqn` when
# `settings.custom_pp_eqn_rules` is set; primitives without an entry fall back
# to `_pp_eqn`.
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}
def pp_eqns(eqns, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
  """Renders `eqns` in order, separated by breaks whose flat form is "; "."""
  separator = pp.brk("; ")
  rendered = [pp_eqn(eqn, context, settings) for eqn in eqns]
  return pp.join(separator, rendered)
def _compact_eqn_should_include(k: str, v: Any) -> bool:
  """Returns whether param `k=v` belongs in the compact equation string.

  The 'branches' param, jaxpr / closed-jaxpr values, and tuples containing at
  least one jaxpr / closed jaxpr are excluded.
  """
  if k == 'branches':
    return False
  if isinstance(v, (Jaxpr, ClosedJaxpr)):
    return False
  holds_jaxpr = (isinstance(v, tuple) and
                 any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v))
  return not holds_jaxpr
def str_eqn_compact(primitive: Primitive, params: dict[Any, Any]) -> str:
  """Compact equation to string conversion used in HLO metadata.

  A rule registered for `primitive` in `custom_str_eqn_compact_rules` takes
  precedence. Otherwise the result is the primitive's name, followed by
  `[k=v ...]` for the params accepted by `_compact_eqn_should_include` when
  that list is non-empty.
  """
  if primitive in custom_str_eqn_compact_rules:
    rule = custom_str_eqn_compact_rules[primitive]
    return rule(primitive, params)
  primitive_name = primitive.name
  shown = [f"{key}={val}" for key, val in params.items()
           if _compact_eqn_should_include(key, val)]
  kvs = " ".join(shown)
  if len(kvs) > 0:
    return f"{primitive_name}[{kvs}]"
  return primitive_name
# Per-primitive overrides for `str_eqn_compact`: when a primitive has an entry,
# its rule is called with `(primitive, params)` and its result is returned
# instead of the default compact string.
custom_str_eqn_compact_rules: dict[
    Primitive, Callable[[Primitive, dict[Any, Any]], str]
] = {}
def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext,
                      settings: JaxprPpSettings) -> pp.Doc:
  """Renders the `{ lambda consts; ins. let <eqns> in outs }` frame of a jaxpr.

  The equation body comes from calling `eqns_fn()`. When
  `settings.print_effects` is set, the jaxpr's effects are appended after the
  outputs as ` : { ... }`.
  """
  # Binders are rendered before the equations and the outputs after them;
  # names in `context` are handed out on first use, so this order matters.
  show_shapes = settings.print_shapes
  constvars = pp_vars(jaxpr.constvars, context, print_shapes=show_shapes)
  invars = pp_vars(jaxpr.invars, context, print_shapes=show_shapes)
  eqns = eqns_fn()
  # A single output gets a trailing comma, like a one-element tuple.
  closing = ",)" if len(jaxpr.outvars) == 1 else ")"
  outvars = pp.concat([
      pp.text("("), pp_vars(jaxpr.outvars, context, separator=","),
      pp.text(closing)])

  eff_text = []
  if settings.print_effects:
    # TODO(sharadmv): render an entire signature here
    eff_text.append(pp.text(" : { "))
    for position, eff in enumerate(jaxpr.effects):
      if position > 0:
        eff_text.append(pp.text(", "))
      if isinstance(eff, effects.JaxprInputEffect):
        # Input effects store a position into constvars + invars; print the
        # effect with the variable at that position in place of the index.
        all_vars = [*jaxpr.constvars, *jaxpr.invars]
        eff = eff.replace(input_index=all_vars[eff.input_index])
      eff_text.append(pp_effect(eff, context))
    eff_text.append(pp.text(" }"))

  body = pp.concat([
      pp.text("{ "), pp.keyword(pp.text("lambda ")),
      constvars, pp.text("; "), invars,
      pp.text(". "), pp.keyword(pp.text("let")),
      pp.nest(2, pp.brk() + eqns), pp.brk(),
      pp.keyword(pp.text("in ")), outvars,
      pp.concat(eff_text)
  ])
  return pp.group(pp.nest(2, body) + pp.text(" }"))
def pp_top_level_jaxpr(
    name: str,
    jaxpr: Jaxpr,
    context: JaxprPpContext,
    settings: JaxprPpSettings,
) -> pp.Doc:
  """Renders `let <name> = <jaxpr> in` followed by a break."""
  binding = pp.text("let " + name + " = ")
  body = pp_jaxpr(jaxpr, context, settings)
  return pp.concat([binding, body, pp.text(" in"), pp.brk()])
def pp_jaxpr(
    jaxpr: Jaxpr,
    context: JaxprPpContext,
    settings: JaxprPpSettings,
) -> pp.Doc:
  """Pretty-prints `jaxpr`.

  If `context.top_level_jaxprs` holds a non-empty name for the jaxpr, only that
  name is printed; otherwise the full jaxpr is rendered.
  """
  hoisted_name = context.top_level_jaxprs.get(jaxpr)
  if hoisted_name:
    return pp.text(hoisted_name)

  def render_eqns():
    return pp_eqns(jaxpr.eqns, context, settings)

  return pp_jaxpr_skeleton(jaxpr, render_eqns, context, settings)
def pp_jaxprs(jaxprs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
  """Renders several jaxprs as a parenthesized, breakable sequence.

  `ClosedJaxpr` entries are printed through their underlying `.jaxpr`.
  """
  unwrapped = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
  rendered = [pp_jaxpr(j, context, settings) for j in unwrapped]
  opened = pp.nest(2, pp.concat([
      pp.text('('), pp.brk(""),
      pp.join(pp.brk(), rendered),
  ]))
  return pp.group(opened + pp.brk("") + pp.text(')'))
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, context: JaxprPpContext,
                       settings: JaxprPpSettings) -> pp.Doc:
  """Pretty-prints `jaxpr` showing only equations `lo:hi`.

  The bounds are clamped to the valid range. Equations left out before or after
  the shown window are each replaced by a single '...'; if the window is empty
  but the jaxpr has equations, one '...' stands for all of them.
  """
  total = len(jaxpr.eqns)
  lo = max(lo, 0)
  hi = max(lo, min(hi, total))
  window = jaxpr.eqns[lo:hi]

  def eqns_fn():
    if len(window) == 0 and total != 0:
      docs = [pp.text('...')]
    else:
      docs = []
      if lo != 0:
        docs.append(pp.text('...'))
      docs.extend([pp_eqn(e, context, settings) for e in window])
      if hi != total:
        docs.append(pp.text('...'))
    return pp.join(pp.brk("; "), docs)

  return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, settings)
def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:
  """Renders an effect.

  Effects exposing a `_pretty_print` attribute are rendered by calling it with
  `context`; all others are rendered as the text of `str(effect)`.
  """
  if not hasattr(effect, "_pretty_print"):
    return pp.text(str(effect))
  return effect._pretty_print(context)
# ------------------- Jaxpr util -------------------
def last_used(jaxpr: Jaxpr) -> dict[Var, JaxprEqn | None]:
  """Returns a mapping from every var in jaxpr to what equation uses it last.

  Non-literal outvars map to None. Every other non-literal equation input maps
  to the last equation in `jaxpr.eqns` that has it as an input.
  """
  result: dict[Var, JaxprEqn | None] = {}
  for outvar in jaxpr.outvars:
    if not isinstance(outvar, Literal):
      result[outvar] = None
  # Walking backwards, the first equation seen for a variable is its last use.
  for eqn in reversed(jaxpr.eqns):
    for operand in eqn.invars:
      if isinstance(operand, Literal):
        continue
      result.setdefault(operand, eqn)
  return result
def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any],
                       last_used: dict[Var, JaxprEqn | None]):
  """Remove all eqn.invars from env if eqn is the last time they were used.

  Literals are ignored, and a variable appearing several times among the
  inputs is handled once.
  """
  handled = set()
  for v in eqn.invars:
    if isinstance(v, Literal) or v in handled:
      continue
    handled.add(v)
    if last_used[v] is eqn:
      # Delete ref to variable when it is no longer needed by next equations.
      del env[v]
# Used in shard_map for converting avals. Both registries start empty here and
# are filled in elsewhere.
# NOTE(review): presumably keyed by aval type with converter callables as
# values -- confirm against the shard_map code that registers entries.
shard_aval_handlers = {}  # type: ignore
unshard_aval_handlers = {}  # type: ignore