diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 78f44196b..be5012aa5 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -53,27 +53,27 @@ jobs: - python-version: 3.6 os: ubuntu-latest enable-x64: 0 - enable-omnistaging: 0 + enable-omnistaging: 1 package-overrides: "none" num_generated_cases: 25 - python-version: 3.7 os: ubuntu-latest enable-x64: 1 - enable-omnistaging: 0 + enable-omnistaging: 1 # Test experimental NumPy dispatch package-overrides: "git+https://github.com/seberg/numpy-dispatch.git" num_generated_cases: 25 - python-version: 3.6 os: ubuntu-latest enable-x64: 1 - enable-omnistaging: 0 + enable-omnistaging: 1 # Test with numpy version that matches Google-internal version package-overrides: "numpy==1.16.4" num_generated_cases: 10 - python-version: 3.7 os: ubuntu-latest enable-x64: 0 - enable-omnistaging: 1 + enable-omnistaging: 0 package-overrides: "none" num_generated_cases: 8 steps: diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst index 7e4e778cc..469bdd5d0 100644 --- a/docs/jaxpr.rst +++ b/docs/jaxpr.rst @@ -164,34 +164,15 @@ before (with two input vars, one for each element of the input tuple) Constant Vars --------------- +------------- -ConstVars arise when the computation contains array constants, either -from the Python program, or from constant-folding. For example, the function -``func6`` below +Some values in jaxprs are constants, in that their value does not depend on the +jaxpr's arguments. When these values are scalars they are represented directly +in the jaxpr equations; non-scalar array constants are instead hoisted out to +the top-level jaxpr, where they correspond to constant variables ("constvars"). +These constvars differ from the other jaxpr parameters ("invars") only as a +bookkeeping convention. ->>> def func5(first, second): -... temp = first + jnp.sin(second) * 3. - jnp.ones(8) -... return temp -... ->>> def func6(first): -... return func5(first, jnp.ones(8)) -... - -JAX produces the following jaxpr - ->>> print(make_jaxpr(func6)(jnp.ones(8))) -{ lambda b d ; a. - let c = add a b - e = sub c d - in (e,) } - -When tracing ``func6``, the function ``func5`` is invoked with a constant value -(``np.ones(8)``) for the second argument. As a result, the sub-expression -``jnp.sin(second) * 3.`` is constant-folded. -There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d`` -(standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the -jaxpr notation what constants the constant variables stand for. Higher-order primitives ----------------------- @@ -293,44 +274,25 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar` >>> def func8(arg1, arg2): # arg2 is a pair ... return lax.cond(arg1 >= 0., ... lambda xtrue: xtrue[0], -... lambda xfalse: jnp.ones(1) + xfalse[1], +... lambda xfalse: jnp.array([1]) + xfalse[1], ... arg2) ... >>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.))) -{ lambda f ; a b c. - let d = ge a 0.0 - e = convert_element_type[ new_dtype=int32 - old_dtype=bool ] d - g = cond[ branches=( { lambda ; c a b. - let d = add c b - in (d,) } - { lambda ; e_ a b. - let +{ lambda a ; b c d. + let e = ge b 0.0 + f = convert_element_type[ new_dtype=int32 + old_dtype=bool ] e + g = cond[ branches=( { lambda ; a b c. + let d = convert_element_type[ new_dtype=float32 + old_dtype=int32 ] a + e = add d c + in (e,) } + { lambda ; f_ a b. + let in (a,) } ) - linear=(False, False, False) ] e f b c + linear=(False, False, False) ] f a c d in (g,) } -The top-level jaxpr has one `constvar` ``f`` (corresponding to -``jnp.ones(1)`` from the body of the first (false) branch) and three -input variables ``a b c`` (corresponding to ``arg1`` and the two -elements of ``arg2``; note that ``arg2`` has been flattened). The -``false_jaxpr`` has three input variables (``c`` corresponding to the -constant for ``jnp.ones(1)``, and ``a b`` for the two elements of -``arg2`` that are passed to ``false_jaxpr``). The ``true_jaxpr`` has -three input variables. The first (``e_``) is an unused argument -matching the constant first argument ``c`` of ``false_jaxpr`` -(required for the jaxpr signatures to match). The subsequent two -correspond to the two elements of ``arg2`` that is passed to -``true_jaxpr``. - -The actual operands to the cond primitive are: ``e f b c``, which -correspond in order to: - - * one operand for the predicate, - * one constant (only used by ``false_jaxpr``, but passed to both), - i.e., ``f``, which is a constvar for the top-level jaxpr - * two operands passed to both jaxprs, i.e., ``b`` and ``c``, which are - input vars, corresponding to ``arg2`` for the top-level jaxpr. While ^^^^^ @@ -357,32 +319,22 @@ For example, here is an example fori loop ... arg + ones) ... >>> print(make_jaxpr(func10)(np.ones(16), 5)) -{ lambda c d ; a b. - let e = add a d - _ _ f = while[ body_jaxpr={ lambda ; e g a b c. - let d = add a 1 - f = add c e - h = add f g - in (d, b, h) } +{ lambda ; a b. + let c = broadcast_in_dim[ broadcast_dimensions=( ) + shape=(16,) ] 1.0 + d = add a c + _ _ e = while[ body_jaxpr={ lambda ; a b c d e. + let f = add c 1 + g = mul a 3.0 + h = add e g + i = add h b + in (f, d, i) } body_nconsts=2 cond_jaxpr={ lambda ; a b c. let d = lt a b in (d,) } - cond_nconsts=0 ] c a 0 b e - in (f,) } - -The top-level jaxpr has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body -of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry). -There are also two input variables (``a`` corresponding to ``arg`` and ``b`` corresponding -to ``n``). -The loop carry consists of three values, as seen in the body of ``cond_jaxpr`` -(corresponding to the iteration index, iteration end, and the accumulated value carry). -Note that ``body_jaxpr`` takes 5 input variables. The first two are actually -constvars: ``e`` corresponding to ``ones * 3`` and ``g`` corresponding to the -captures use of ``arg`` in the loop body. -The parameter ``body_nconsts = 2`` specifies that there are 2 constants for the -``body_jaxpr``. -The other 3 input variables for ``body_jaxpr`` correspond to the flattened carry values. + cond_nconsts=0 ] c a 0 b d + in (e,) } The while primitive takes 5 arguments: ``c a 0 b e``, as follows: @@ -395,13 +347,13 @@ Scan JAX supports a special form of loop over the elements of an array (with statically known shape). The fact that there are a fixed number of iterations -makes this form of looping easily reverse-differentiable. Such loops are constructed -with the :py:func:`jax.lax.scan` operator:: +makes this form of looping easily reverse-differentiable. Such loops are +constructed with the :py:func:`jax.lax.scan` function:: lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B]) -Here ``C`` is the type of the scan carry, ``A`` is the element type of the input array(s), -and ``B`` is the element type of the output array(s). +Here ``C`` is the type of the scan carry, ``A`` is the element type of the +input array(s), and ``B`` is the element type of the output array(s). For the example consider the function ``func11`` below @@ -415,12 +367,14 @@ For the example consider the function ``func11`` below ... return lax.scan(body, 0., (arr, ones)) ... >>> print(make_jaxpr(func11)(np.ones(16), 5.)) -{ lambda c ; a b. - let d e = scan[ jaxpr={ lambda ; f a b c. - let d = mul b c - e = add a d - g = add e f - in (g, a) } +{ lambda ; a b. + let c = broadcast_in_dim[ broadcast_dimensions=( ) + shape=(16,) ] 1.0 + d e = scan[ jaxpr={ lambda ; a b c d. + let e = mul c d + f = add b e + g = add f a + in (g, b) } length=16 linear=(False, False, False, False) num_carry=1 @@ -429,17 +383,6 @@ For the example consider the function ``func11`` below unroll=1 ] b 0.0 a c in (d, e) } -The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant, -and two input variables corresponding to the arguments ``arr`` and ``extra``. -The body of the scan has 4 input variables, of which: - - * one (``f``) is a constant (since ``num_consts = 1``), and stands for the - captured variable ``extra`` used in the loop body, - * one (``a``) is the value of the carry (since ``num_carry = 1``) - * The remaining 2 are the input values. ``b`` is the array element from the - first array passed to lax.scan (``arr``) and ``c`` is the second array - (``ones``). - The ``linear`` parameter describes for each of the input variables whether they are guaranteed to be used linearly in the body. Once the scan goes through linearization, more arguments will be linear. @@ -466,37 +409,27 @@ computation should run. For example ... return arg + inner(arg - 2.) ... >>> print(make_jaxpr(func12)(1.)) -{ lambda b ; a. - let c = sub a 2.0 - d = xla_call[ backend=None - call_jaxpr={ lambda ; c b a. - let d = mul b c - e = add a d +{ lambda ; a. + let b = sub a 2.0 + c = xla_call[ backend=None + call_jaxpr={ lambda ; a b. + let c = broadcast_in_dim[ broadcast_dimensions=( ) + shape=(1,) ] 1.0 + d = mul a c + e = add b d in (e,) } device=None - donated_invars=(False, False, False) - name=inner ] b a c - e = add a d - in (e,) } + donated_invars=(False, False) + name=inner ] a b + d = add a c + in (d,) } -The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and -the top-level input variable `a` refers to the ``arg`` parameter of ``func12``. -The ``xla_call`` primitive stands for a call to the jitted ``inner`` function. -The primitive has the function body in the ``call_jaxpr`` parameter, a jaxpr -with 3 input parameters: - - * ``c`` is a constvar and stands for the ``ones`` constant, - * ``b`` corresponds to the free variable ``arg`` captured in the ``inner`` function, - * ``a`` corresponds to the ``inner`` parameter ``x``. - -The primitive takes three arguments ``b a c``. XLA_pmap ^^^^^^^^ -If you use the :py:func:`jax.pmap` transformation, the function to be -mapped is captured using the ``xla_pmap`` primitive. Consider this -example +If you use the :py:func:`jax.pmap` transformation, the function to be mapped is +captured using the ``xla_pmap`` primitive. Consider this example >>> from jax import pmap >>> @@ -507,34 +440,30 @@ example ... return pmap(inner, axis_name='rows')(arr) ... >>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.)) -{ lambda c ; a b. - let d = xla_pmap[ axis_name=rows +{ lambda ; a b. + let c = xla_pmap[ axis_name=rows axis_size=1 backend=None - call_jaxpr={ lambda ; d b a. - let c = add a b + call_jaxpr={ lambda ; a b. + let c = add b a + d = broadcast_in_dim[ broadcast_dimensions=( ) + shape=(1,) ] 1.0 e = add c d f = psum[ axis_index_groups=None - axis_name=rows ] a + axis_name=rows ] b g = div e f in (g,) } devices=None - donated_invars=(False, False, False) + donated_invars=(False, False) global_axis_size=None - mapped_invars=(True, False, True) - name=inner ] c b a - in (d,) } + mapped_invars=(False, True) + name=inner ] b a + in (c,) } -The top-level constvar ``c`` refers to the ``jnp.ones(1)`` constant. The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``) -and the body of the function to be mapped as the ``call_jaxpr`` parameter. The +and the body of the function to be mapped as the ``call_jaxpr`` parameter. value of this parameter is a Jaxpr with 3 input variables: - * ``d`` stands for the constant ``jnp.ones(1)``, - * ``b`` stands for the free variable ``extra``, - * ``a`` stands for the parameter ``x`` of ``inner``. - - The parameter ``mapped_invars`` specify which of the input variables should be mapped and which should be broadcast. In our example, the value of ``extra`` is broadcast, the other input values are mapped. diff --git a/jax/api.py b/jax/api.py index e012b4f21..63854ca82 100644 --- a/jax/api.py +++ b/jax/api.py @@ -393,7 +393,7 @@ def disable_jit(): ... return y + 3 ... >>> print(f(jax.numpy.array([1, 2, 3]))) - Value of y is Traced + Value of y is Tracedwith [5 7 9] Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`, @@ -651,7 +651,7 @@ def _xla_computation( else: pvals = [pe.PartialVal.unknown(aval) for aval in avals] jaxpr, out_pvals, consts = pe.trace_to_jaxpr( - jaxtree_fun, pvals, instantiate=True, stage_out=True) + jaxtree_fun, pvals, instantiate=True, stage_out=True) # type: ignore out_avals = [raise_to_shaped(pval.get_aval()) for pval in out_pvals] jaxpr = xla.apply_outfeed_rewriter(jaxpr) axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr)) @@ -1910,11 +1910,12 @@ def make_jaxpr(fun: Callable, >>> jax.make_jaxpr(jax.grad(f))(3.0) { lambda ; a. let b = cos a - c = cos b - d = mul 1.0 c - e = neg d - f = sin a - g = mul e f + c = sin a + _ = sin b + d = cos b + e = mul 1.0 d + f = neg e + g = mul f c in (g,) } """ _check_callable(fun) @@ -1936,7 +1937,7 @@ def make_jaxpr(fun: Callable, else: in_pvals = [pe.PartialVal.unknown(a) for a in in_avals] jaxpr, out_pvals, consts = pe.trace_to_jaxpr( - jaxtree_fun, in_pvals, instantiate=True, stage_out=True) + jaxtree_fun, in_pvals, instantiate=True, stage_out=True) # type: ignore out_avals = map(raise_to_shaped, unzip2(out_pvals)[0]) typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals) return typed_jaxpr @@ -2214,7 +2215,7 @@ class CustomTransformsFunction(object): if config.omnistaging_enabled: jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True) else: - with core.initial_style_staging(): + with core.initial_style_staging(): # type: ignore jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True) outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr, in_tree=in_tree, out_tree=out_tree(), diff --git a/jax/config.py b/jax/config.py index 9f0f993d6..b9a165df2 100644 --- a/jax/config.py +++ b/jax/config.py @@ -42,8 +42,8 @@ class Config: self.meta = {} self.FLAGS = NameSpace(self.read) self.use_absl = False - self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', False) - self._omnistaging_enablers = [] + self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True) + self._omnistaging_disablers = [] def update(self, name, val): if self.use_absl: @@ -114,24 +114,25 @@ class Config: self.complete_absl_config(absl.flags) already_configured_with_absl = True - if FLAGS.jax_omnistaging: - self.enable_omnistaging() + if not FLAGS.jax_omnistaging: + self.disable_omnistaging() - def register_omnistaging_enabler(self, enabler): - if not self.omnistaging_enabled: - self._omnistaging_enablers.append(enabler) + + def register_omnistaging_disabler(self, disabler): + if self.omnistaging_enabled: + self._omnistaging_disablers.append(disabler) else: - enabler() + disabler() - # TODO(mattjj): remove this when omnistaging fully lands def enable_omnistaging(self): if not self.omnistaging_enabled: - for enabler in self._omnistaging_enablers: - enabler() - self.omnistaging_enabled = True + raise Exception("can't re-enable omnistaging after it's been disabled") def disable_omnistaging(self): - pass + if self.omnistaging_enabled: + for disabler in self._omnistaging_disablers: + disabler() + self.omnistaging_enabled = False class NameSpace(object): @@ -156,6 +157,6 @@ flags.DEFINE_bool( flags.DEFINE_bool( 'jax_omnistaging', - bool_env('JAX_OMNISTAGING', False), + bool_env('JAX_OMNISTAGING', True), help='Enable staging based on dynamic context rather than data dependence.' ) diff --git a/jax/core.py b/jax/core.py index d2bc5b572..7762e955b 100644 --- a/jax/core.py +++ b/jax/core.py @@ -24,7 +24,7 @@ import threading import types from typing import (Any, Callable, ClassVar, Dict, Generator, Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple, - Type, Union, cast, no_type_check) + Type, Union, cast) import numpy as np @@ -266,19 +266,14 @@ class Primitive: def __repr__(self): return '{}'.format(self.name) - def bind(self, *args, **kwargs): + + def bind(self, *args, **params): assert skip_checks or all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args), args top_trace = find_top_trace(args) - if top_trace is None: - return self.impl(*args, **kwargs) - tracers = map(top_trace.full_raise, args) - out_tracer = top_trace.process_primitive(self, tracers, kwargs) - if self.multiple_results: - return map(full_lower, out_tracer) - else: - return full_lower(out_tracer) + out = top_trace.process_primitive(self, tracers, params) + return map(full_lower, out) if self.multiple_results else full_lower(out) def def_impl(self, impl): self.impl = impl @@ -517,14 +512,8 @@ class Tracer: def __long__(self): return self.aval._long(self) def __hex__(self): return self.aval._hex(self) def __oct__(self): return self.aval._oct(self) - - def __float__(self): - raise TypeError("JAX Tracer object cannot be interpreted as a float. " - "Try using `x.astype(float)` instead.") - - def __complex__(self): - raise TypeError("JAX Tracer object cannot be interpreted as a complex. " - "Try using `x.astype(complex)` instead.") + def __float__(self): return self.aval._float(self) + def __complex__(self): return self.aval._complex(self) def __setitem__(self, idx, val): raise TypeError("JAX 'Tracer' objects do not support item assignment") @@ -571,6 +560,9 @@ class Tracer: def __deepcopy__(self, unused_memo): return self + def _origin_msg(self) -> str: + return "" + # these can be used to set up forwarding of properties and instance methods from # Tracer instances to the underlying avals aval_property = namedtuple("aval_property", ["fget"]) @@ -612,57 +604,47 @@ class TraceStack: downward: List[MainTrace] def __init__(self): - self.upward = [] - self.downward = [] + eval_trace = MainTrace(0, EvalTrace) + self.stack = [eval_trace] + self.dynamic = eval_trace - def next_level(self, bottom: bool) -> int: - if bottom: - return - (len(self.downward) + 1) - else: - return len(self.upward) + def next_level(self) -> int: + return len(self.stack) - def push(self, main_trace: MainTrace, bottom: bool) -> None: - if bottom: - self.downward.append(main_trace) - else: - self.upward.append(main_trace) + def push(self, main_trace: MainTrace) -> None: + self.stack.append(main_trace) - def pop(self, bottom: bool) -> None: - if bottom: - self.downward.pop() - else: - self.upward.pop() + def pop(self) -> None: + self.stack.pop() def __repr__(self) -> str: - return 'Trace stack\n{} ---\n{}'.format( - map(' {}\n'.format, self.upward[::-1]), - map(' {}\n'.format, self.downward)) + stack_str = map(' {}\n'.format, self.stack[::-1]) + return f'Trace stack\n{stack_str}\n{self.dynamic}' def copy(self): - new = TraceStack() - new.upward = self.upward[:] - new.downward = self.downward[:] + new = self.__new__(TraceStack) + new.stack = self.stack[:] + new.dynamic = self.dynamic return new class Sublevel(int): pass AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace']) - class TraceState: trace_stack: TraceStack substack: List[Sublevel] - initial_style: bool + axis_env: List[AxisEnvFrame] def __init__(self) -> None: self.trace_stack = TraceStack() self.substack = [Sublevel(0)] - self.initial_style = False + self.axis_env = [] def copy(self): - new = TraceState() + new = self.__new__(TraceState) new.trace_stack = self.trace_stack.copy() new.substack = self.substack[:] - new.initial_style = self.initial_style + new.axis_env = self.axis_env[:] return new # The global state of the tracer is accessed by a thread-local object. @@ -676,8 +658,9 @@ thread_local_state = ThreadLocalState() def reset_trace_state() -> bool: "Reset the global trace state and return True if it was already clean." if (thread_local_state.trace_state.substack != [Sublevel(0)] or - thread_local_state.trace_state.trace_stack.downward or - thread_local_state.trace_state.trace_stack.upward): + thread_local_state.trace_state.axis_env != [] or + thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or + thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)): thread_local_state.trace_state.__init__() # type: ignore return False else: @@ -687,15 +670,21 @@ def cur_sublevel() -> Sublevel: return thread_local_state.trace_state.substack[-1] @contextmanager -def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]: - level = thread_local_state.trace_state.trace_stack.next_level(bottom) +def new_main(trace_type: Type[Trace], dynamic: bool = False, + ) -> Generator[MainTrace, None, None]: + stack = thread_local_state.trace_state.trace_stack + level = stack.next_level() main = MainTrace(level, trace_type) - thread_local_state.trace_state.trace_stack.push(main, bottom) + stack.push(main) + if dynamic: + prev_dynamic, stack.dynamic = stack.dynamic, main try: yield main finally: - thread_local_state.trace_state.trace_stack.pop(bottom) + thread_local_state.trace_state.trace_stack.pop() + if dynamic: + stack.dynamic = prev_dynamic if check_leaks: t = ref(main) @@ -704,6 +693,23 @@ def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None print(thread_local_state.trace_state.trace_stack) raise Exception('Leaked trace {}'.format(t())) +@contextmanager +def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]: + stack = thread_local_state.trace_state.trace_stack + main = MainTrace(0, trace_type) + prev_dynamic, stack.dynamic = stack.dynamic, main + prev_base, stack.stack[0] = stack.stack[0], main + try: + yield main + finally: + stack.dynamic = prev_dynamic + stack.stack[0] = prev_base + +@contextmanager +def eval_context(): + with new_base_main(EvalTrace): + yield + @contextmanager def new_sublevel() -> Generator[None, None, None]: sublevel = Sublevel(len(thread_local_state.trace_state.substack)) @@ -719,29 +725,24 @@ def new_sublevel() -> Generator[None, None, None]: if t() is not None: raise Exception('Leaked sublevel {}'.format(t())) +def maybe_new_sublevel(trace): + # dynamic traces run the WrappedFun, so we raise the sublevel for them + dynamic = thread_local_state.trace_state.trace_stack.dynamic + return new_sublevel() if trace.main is dynamic else suppress() + def full_lower(val): if isinstance(val, Tracer): return val.full_lower() else: return val -def find_top_trace(xs) -> Optional[Trace]: - top_trace = max((x._trace for x in xs if isinstance(x, Tracer)), - key=attrgetter('level'), default=None) - return top_trace and type(top_trace)(top_trace.main, cur_sublevel()) - -@contextmanager -def initial_style_staging(): - trace_state = thread_local_state.trace_state - prev, trace_state.initial_style = trace_state.initial_style, True - try: - yield - finally: - trace_state.initial_style = prev - -@contextmanager -def eval_context(): - yield # dummy implementation for forward compatibility +def find_top_trace(xs) -> Trace: + top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)), + default=None, key=attrgetter('level')) + dynamic = thread_local_state.trace_state.trace_stack.dynamic + top_main = (dynamic if top_main is None or dynamic.level > top_main.level + else top_main) + return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore # -------------------- abstract values -------------------- @@ -844,20 +845,24 @@ pytype_aval_mappings[Unit] = lambda _: abstract_unit class ConcretizationTypeError(TypeError): pass -def raise_concretization_error(val, context=""): - msg = (f"Abstract tracer value encountered where concrete value is expected ({context}).\n" - "Use transformation parameters such as `static_argnums` for `jit` " - "to avoid tracing input values.\n" - "See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.\n" - f"Encountered value: {val}") +def raise_concretization_error(val: Tracer, context=""): + msg = ("Abstract tracer value encountered where concrete value is expected.\n\n" + + context + "\n\n" + + val._origin_msg() + "\n\n" + + "You can use transformation parameters such as `static_argnums` for " + "`jit` to avoid tracing particular arguments of transformed functions.\n\n" + "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n" + f"Encountered tracer value: {val}") raise ConcretizationTypeError(msg) -def concretization_function_error(fun, context=""): +def concretization_function_error(fun, suggest_astype=False): fname = getattr(fun, "__name__", fun) - fname_context = f"in `{fname}`" - if context: - fname_context += f" {context}" + fname_context = f"The problem arose with the `{fname}` function. " + if suggest_astype: + fname_context += ("If trying to convert the data type of a value, " + f"try using `x.astype({fun.__name__})` " + f"or `jnp.array(x, {fun.__name__})` instead.") def error(self, arg): raise_concretization_error(arg, fname_context) return error @@ -899,12 +904,9 @@ class UnshapedArray(AbstractValue): ", weak_type=True" if self.weak_type else "") _bool = _nonzero = concretization_function_error(bool) - _float = concretization_function_error( - float, "Try using `x.astype(float)` instead.") - _int = concretization_function_error( - int, "Try using `x.astype(int)` instead.") - _complex = concretization_function_error( - complex, "Try using `x.astype(complex)` instead.") + _float = concretization_function_error(float, True) + _int = concretization_function_error(int, True) + _complex = concretization_function_error(complex, True) _hex = concretization_function_error(hex) _oct = concretization_function_error(oct) @@ -1036,9 +1038,12 @@ class ConcreteArray(ShapedArray): return ConcreteArray(self.val) if self.weak_type else self _bool = _nonzero = partialmethod(_forward_to_value, bool) - _int = partialmethod(_forward_to_value, int) - _hex = partialmethod(_forward_to_value, hex) - _oct = partialmethod(_forward_to_value, oct) + _int = partialmethod(_forward_to_value, int) + _hex = partialmethod(_forward_to_value, hex) + _oct = partialmethod(_forward_to_value, oct) + + _float = concretization_function_error(float, True) + _complex = concretization_function_error(complex, True) class AbstractToken(AbstractValue): @@ -1123,20 +1128,16 @@ def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'], yield outs, tuple(todo) # Ensure the aux output is immutable def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'], - fun: lu.WrappedFun, *args, **params): + fun, *args, **params): params_tuple = tuple(params.items()) top_trace = find_top_trace(args) - level = (thread_local_state.trace_state.trace_stack.next_level(True) - if top_trace is None else top_trace.level) - params_tuple = tuple(params.items()) - fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple) - if top_trace is None: - with new_sublevel(): - outs = primitive.impl(fun, *args, **params) - else: - tracers = map(top_trace.full_raise, args) + fun, env_trace_todo = process_env_traces( + fun, primitive, top_trace and top_trace.level, params_tuple) + tracers = map(top_trace.full_raise, args) + with maybe_new_sublevel(top_trace): outs = primitive.process(top_trace, fun, tracers, params) - return apply_todos(env_trace_todo(), map(full_lower, outs)) + return map(full_lower, apply_todos(env_trace_todo(), outs)) + class CallPrimitive(Primitive): multiple_results = True @@ -1176,10 +1177,65 @@ class MapPrimitive(Primitive): def post_process(self, trace, out_tracers, params): return trace.post_process_map(self, out_tracers, params) -# This is a no-op with omnistaging disabled @contextmanager def extend_axis_env(axis_name, size: int, tag: Any): - yield + frame = AxisEnvFrame(axis_name, size, tag) + thread_local_state.trace_state.axis_env.append(frame) + try: + yield + finally: + thread_local_state.trace_state.axis_env.pop() + +def axis_frame(axis_name): + frames = thread_local_state.trace_state.axis_env + for frame in reversed(frames): + if frame.name == axis_name: + return frame + else: + raise NameError("unbound axis name: {}".format(axis_name)) + +def axis_index(axis_name): + """Return the index along the mapped axis ``axis_name``. + + Args: + axis_name: hashable Python object used to name the mapped axis. + + Returns: + An integer representing the index. + + For example, with 8 XLA devices available: + + >>> from functools import partial + >>> @partial(jax.pmap, axis_name='i') + ... def f(_): + ... return lax.axis_index('i') + ... + >>> f(np.zeros(4)) + ShardedDeviceArray([0, 1, 2, 3], dtype=int32) + >>> f(np.zeros(8)) + ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) + >>> @partial(jax.pmap, axis_name='i') + ... @partial(jax.pmap, axis_name='j') + ... def f(_): + ... return lax.axis_index('i'), lax.axis_index('j') + ... + >>> x, y = f(np.zeros((4, 2))) + >>> print(x) + [[0 0] + [1 1] + [2 2] + [3 3]] + >>> print(y) + [[0 1] + [0 1] + [0 1] + [0 1]] + """ + return axis_index_p.bind(axis_name=axis_name) + +axis_index_p = Primitive('axis_index') +axis_index_p.def_abstract_eval(lambda *, axis_name: ShapedArray((), np.int32)) + # ------------------- Jaxpr checking ------------------- @@ -1368,7 +1424,7 @@ def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint: pp_rhs = (pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >> pp(pp_vars(eqn.invars, print_shapes))) - if len(lhs) <= 6: + if len(lhs) <= 6 or print_shapes: return pp_lhs >> pp(' ') >> pp_rhs else: return pp_lhs + pp_rhs.indent(2) @@ -1428,61 +1484,63 @@ def pp_kv_pairs(kv_pairs): else: return pp('') -axis_frame = None - -# TODO(mattjj): remove when omnistaging fully lands -@config.register_omnistaging_enabler -@no_type_check -def omnistaging_enabler() -> None: +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: global thread_local_state, call_bind, find_top_trace, initial_style_staging, \ - new_main, reset_trace_state, extend_axis_env, axis_frame, \ - new_base_main, eval_context, \ - TraceStack, TraceState - del initial_style_staging + new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env class TraceStack: - stack: List[MainTrace] - dynamic: MainTrace + upward: List[MainTrace] + downward: List[MainTrace] def __init__(self): - eval_trace = MainTrace(0, EvalTrace) - self.stack = [eval_trace] - self.dynamic = eval_trace + self.upward = [] + self.downward = [] - def next_level(self) -> int: - return len(self.stack) + def next_level(self, bottom: bool) -> int: + if bottom: + return - (len(self.downward) + 1) + else: + return len(self.upward) - def push(self, main_trace: MainTrace) -> None: - self.stack.append(main_trace) + def push(self, main_trace: MainTrace, bottom: bool) -> None: + if bottom: + self.downward.append(main_trace) + else: + self.upward.append(main_trace) - def pop(self) -> None: - self.stack.pop() + def pop(self, bottom: bool) -> None: + if bottom: + self.downward.pop() + else: + self.upward.pop() def __repr__(self) -> str: - stack_str = map(' {}\n'.format, self.stack[::-1]) - return f'Trace stack\n{stack_str}\n{self.dynamic}' + return 'Trace stack\n{} ---\n{}'.format( + map(' {}\n'.format, self.upward[::-1]), + map(' {}\n'.format, self.downward)) def copy(self): - new = self.__new__(TraceStack) - new.stack = self.stack[:] - new.dynamic = self.dynamic + new = TraceStack() + new.upward = self.upward[:] + new.downward = self.downward[:] return new class TraceState: trace_stack: TraceStack substack: List[Sublevel] - axis_env: List[AxisEnvFrame] + initial_style: bool def __init__(self) -> None: - self.trace_stack = TraceStack() + self.trace_stack = TraceStack() # type: ignore self.substack = [Sublevel(0)] - self.axis_env = [] + self.initial_style = False def copy(self): - new = self.__new__(TraceState) + new = TraceState() new.trace_stack = self.trace_stack.copy() new.substack = self.substack[:] - new.axis_env = self.axis_env[:] + new.initial_style = self.initial_style return new thread_local_state = ThreadLocalState() @@ -1490,54 +1548,23 @@ def omnistaging_enabler() -> None: def reset_trace_state() -> bool: "Reset the global trace state and return True if it was already clean." if (thread_local_state.trace_state.substack != [Sublevel(0)] or - thread_local_state.trace_state.axis_env != [] or - thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or - thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)): + thread_local_state.trace_state.trace_stack.downward or + thread_local_state.trace_state.trace_stack.upward): thread_local_state.trace_state.__init__() # type: ignore return False else: return True - def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'], - fun, *args, **params): - params_tuple = tuple(params.items()) - top_trace = find_top_trace(args) - fun, env_trace_todo = process_env_traces( - fun, primitive, top_trace and top_trace.level, params_tuple) - tracers = map(top_trace.full_raise, args) - with maybe_new_sublevel(top_trace): - outs = primitive.process(top_trace, fun, tracers, params) - return map(full_lower, apply_todos(env_trace_todo(), outs)) - - def maybe_new_sublevel(trace): - # dynamic traces run the WrappedFun, so we raise the sublevel for them - dynamic = thread_local_state.trace_state.trace_stack.dynamic - return new_sublevel() if trace.main is dynamic else suppress() - - def find_top_trace(xs) -> Trace: - top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)), - default=None, key=attrgetter('level')) - dynamic = thread_local_state.trace_state.trace_stack.dynamic - top_main = (dynamic if top_main is None or dynamic.level > top_main.level - else top_main) - return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore - @contextmanager - def new_main(trace_type: Type[Trace], dynamic: bool = False, - ) -> Generator[MainTrace, None, None]: - stack = thread_local_state.trace_state.trace_stack - level = stack.next_level() + def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]: + level = thread_local_state.trace_state.trace_stack.next_level(bottom) main = MainTrace(level, trace_type) - stack.push(main) - if dynamic: - prev_dynamic, stack.dynamic = stack.dynamic, main + thread_local_state.trace_state.trace_stack.push(main, bottom) try: yield main finally: - thread_local_state.trace_state.trace_stack.pop() - if dynamic: - stack.dynamic = prev_dynamic + thread_local_state.trace_state.trace_stack.pop(bottom) if check_leaks: t = ref(main) @@ -1546,47 +1573,55 @@ def omnistaging_enabler() -> None: print(thread_local_state.trace_state.trace_stack) raise Exception('Leaked trace {}'.format(t())) - @contextmanager - def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]: - stack = thread_local_state.trace_state.trace_stack - main = MainTrace(0, trace_type) - prev_dynamic, stack.dynamic = stack.dynamic, main - prev_base, stack.stack[0] = stack.stack[0], main - try: - yield main - finally: - stack.dynamic = prev_dynamic - stack.stack[0] = prev_base + def find_top_trace(xs) -> Optional[Trace]: + top_trace = max((x._trace for x in xs if isinstance(x, Tracer)), + key=attrgetter('level'), default=None) + return top_trace and type(top_trace)(top_trace.main, cur_sublevel()) @contextmanager def eval_context(): - with new_base_main(EvalTrace): - yield + yield # dummy implementation for forward compatibility - def bind(self, *args, **params): + def bind(self, *args, **kwargs): assert skip_checks or all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args), args top_trace = find_top_trace(args) + if top_trace is None: + return self.impl(*args, **kwargs) + tracers = map(top_trace.full_raise, args) - out = top_trace.process_primitive(self, tracers, params) - return map(full_lower, out) if self.multiple_results else full_lower(out) - Primitive.bind = bind + out_tracer = top_trace.process_primitive(self, tracers, kwargs) + if self.multiple_results: + return map(full_lower, out_tracer) + else: + return full_lower(out_tracer) + Primitive.bind = bind # type: ignore + + def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'], + fun: lu.WrappedFun, *args, **params): + params_tuple = tuple(params.items()) + top_trace = find_top_trace(args) + level = (thread_local_state.trace_state.trace_stack.next_level(True) + if top_trace is None else top_trace.level) + params_tuple = tuple(params.items()) + fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple) + if top_trace is None: + with new_sublevel(): + outs = primitive.impl(fun, *args, **params) + else: + tracers = map(top_trace.full_raise, args) + outs = primitive.process(top_trace, fun, tracers, params) + return apply_todos(env_trace_todo(), map(full_lower, outs)) @contextmanager - def extend_axis_env(axis_name, size: int, main_trace: Optional[MainTrace]): - frame = AxisEnvFrame(axis_name, size, main_trace) - thread_local_state.trace_state.axis_env.append(frame) + def extend_axis_env(axis_name, size: int, tag: Any): + yield + + @contextmanager + def initial_style_staging(): + trace_state = thread_local_state.trace_state + prev, trace_state.initial_style = trace_state.initial_style, True try: yield finally: - frame_ = thread_local_state.trace_state.axis_env.pop() - assert frame is frame_ # Only runs if there was was no exception - - def axis_frame(axis_name): - frames = thread_local_state.trace_state.axis_env - for frame in reversed(frames): - if frame.name == axis_name: - return frame - else: - raise NameError(f"Unbound axis name: {axis_name}.\n" - f"The currently bound axes are: {','.join(f.name for f in frames)}") + trace_state.initial_style = prev diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 503a271c2..c52f0836f 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -69,11 +69,7 @@ def _memoize(thunk): return memoized def _initial_style_jaxpr(fun, in_avals): - in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] - jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True, - bottom=True, stage_out=False) - assert not any(isinstance(c, core.Tracer) for c in consts) - out_avals = map(raise_to_shaped, unzip2(out_pvals)[0]) + jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals) return typed_jaxpr @@ -276,12 +272,8 @@ class CustomJVPCallPrimitive(core.CallPrimitive): fun, self, top_trace and top_trace.level, ()) jvp, env_trace_todo2 = core.process_env_traces( jvp, self, top_trace and top_trace.level, ()) - if top_trace is None: - with core.new_sublevel(): - outs = self.impl(fun, jvp, *args) - else: - tracers = map(top_trace.full_raise, args) - outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) + tracers = map(top_trace.full_raise, args) # type: ignore + outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) if env_trace_todo: raise core.UnexpectedTracerError @@ -602,13 +594,16 @@ xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \ batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp -# TODO(mattjj): remove when omnistaging fully lands -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: - global _initial_style_jaxpr +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: + global _initial_style_jaxpr, custom_jvp_call def _initial_style_jaxpr(fun, in_avals): - jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) + in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] + jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True, + bottom=True, stage_out=False) # type: ignore + assert not any(isinstance(c, core.Tracer) for c in consts) + out_avals = map(raise_to_shaped, unzip2(out_pvals)[0]) typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals) return typed_jaxpr @@ -619,10 +614,15 @@ def omnistaging_enabler() -> None: fun, self, top_trace and top_trace.level, ()) jvp, env_trace_todo2 = core.process_env_traces( jvp, self, top_trace and top_trace.level, ()) - tracers = map(top_trace.full_raise, args) # type: ignore - outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore + if top_trace is None: + with core.new_sublevel(): + outs = self.impl(fun, jvp, *args) + else: + tracers = map(top_trace.full_raise, args) + outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) if env_trace_todo: raise core.UnexpectedTracerError return map(core.full_lower, outs) CustomJVPCallPrimitive.bind = bind # type: ignore + custom_jvp_call = custom_jvp_call_p.bind diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 3efffa3ba..cff5bde38 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -237,10 +237,6 @@ deflinear(lax.reduce_window_sum_p) deflinear(lax_fft.fft_p) deflinear(xla.device_put_p) -# TODO(mattjj): remove when omnistaging fully lands -try: deflinear(lax.tie_in_p) -except AttributeError: pass - def _cumulative_jet_rule(primals_in, series_in, *, axis: int, prefix_scan: Callable): # Irrespective of backend, we always use the parallel prefix scan @@ -523,8 +519,8 @@ def _gen_reduce_choose_taylor_rule(chooser_fun): series_out = [_reduce_chooser_taylor_rule(g) for g in gs] return primal_out, series_out return chooser_taylor_rule -jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax.reduce_max_p.bind) -jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax.reduce_min_p.bind) +jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax._reduce_max) +jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax._reduce_min) def _abs_taylor_rule(x, series_in, **params): x, = x @@ -584,3 +580,6 @@ def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr, del jvp_jaxpr_thunk return jet(core.jaxpr_as_fun(fun_jaxpr), primals_in, series_in) jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule + + +deflinear(lax.tie_in_p) diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index fff9867dc..06667307e 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -51,9 +51,9 @@ def closure_convert(fun, in_tree, in_avals): else: in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - with core.initial_style_staging(): + with core.initial_style_staging(): # type: ignore jaxpr, out_pvals, consts = pe.trace_to_jaxpr( - wrapped_fun, in_pvals, instantiate=True, stage_out=False) + wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore out_tree = out_tree() # We only want to closure convert for constants with respect to which we're diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index c27b61921..c54a362dc 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -504,7 +504,7 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_ else: primal_jaxpr, tangent_jaxpr, out_unknowns = \ pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True, - trace_type=None) + trace_type=None) # type: ignore def do_transpose(primals_in, cotangents_in): # NOTE: This is passing in undefined primals in place of tangent arguments, but it @@ -555,9 +555,7 @@ def jvp_jaxpr(jaxpr, nonzeros, instantiate): f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros) tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) - pvals = [pe.PartialVal.unknown(aval) for aval in avals_in] - jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True) - avals_out, _ = unzip2(pvals_out) + jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out) return jaxpr_out, out_nonzeros() @@ -644,7 +642,7 @@ def defvjp_all(prim, custom_vjp): if config.omnistaging_enabled: jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True) else: - with core.initial_style_staging(): + with core.initial_style_staging(): # type: ignore jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True) tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr, @@ -680,9 +678,8 @@ def defvjp2(prim, *vjps): defvjp_all(prim, vjpmaker) -# TODO(mattjj): remove when omnistaging fully lands -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: global jvp_jaxpr def jvp_jaxpr(jaxpr, nonzeros, instantiate): @@ -691,6 +688,8 @@ def omnistaging_enabler() -> None: f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros) tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) - jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) + pvals = [pe.PartialVal.unknown(aval) for aval in avals_in] + jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True) + avals_out, _ = unzip2(pvals_out) jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out) return jaxpr_out, out_nonzeros() diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index f414d436c..9380da451 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -375,10 +375,8 @@ def batch_jaxpr(jaxpr, size, batched, instantiate): f, batched_out = batched_traceable(f, size, batched, instantiate) avals_in = [_promote_aval_rank(size, a) if b else a for a, b in zip(jaxpr.in_avals, batched)] - in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in] - jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True) - avals_out, _ = unzip2(pvals_out) - jaxpr_out = core.TypedJaxpr(jaxpr_out, consts_out, avals_in, avals_out) + jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in) + jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out) return jaxpr_out, batched_out() @lu.transformation_with_aux @@ -427,8 +425,8 @@ def _merge_bdims(x, y): return x # arbitrary -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: global batch_jaxpr def batch_jaxpr(jaxpr, size, batched, instantiate): @@ -436,8 +434,10 @@ def omnistaging_enabler() -> None: f, batched_out = batched_traceable(f, size, batched, instantiate) avals_in = [_promote_aval_rank(size, a) if b else a for a, b in zip(jaxpr.in_avals, batched)] - jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in) - jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out) + in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in] + jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True) + avals_out, _ = unzip2(pvals_out) + jaxpr_out = core.TypedJaxpr(jaxpr_out, consts_out, avals_in, avals_out) return jaxpr_out, batched_out() diff --git a/jax/interpreters/invertible_ad.py b/jax/interpreters/invertible_ad.py index 390264596..4ae3aaa8d 100644 --- a/jax/interpreters/invertible_ad.py +++ b/jax/interpreters/invertible_ad.py @@ -217,9 +217,9 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True) else: - ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( + ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( # type: ignore complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), - instantiate=True, stage_out=False) + instantiate=True, stage_out=False) # type: ignore assert not ivjp_jaxpr.constvars # That might happen some time, but don't bother until then out_avals = map(raise_to_shaped, unzip2(out_pvals)[0]) ivjp_jaxpr = core.TypedJaxpr(ivjp_jaxpr, [], in_avals, out_avals) @@ -234,8 +234,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore ivjp_jaxpr, unknowns, instantiate=False) # type:ignore else: - jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( - ivjp_jaxpr, unknowns, instantiate=False, trace_type=None) + jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore + ivjp_jaxpr, unknowns, instantiate=False, trace_type=None) # type: ignore unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs]) # Make sure we're able to compute all cotangents. We don't really care if we # can reconstruct or primals or not, although failure to do so might result in diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 17d26977f..516df03f7 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -17,7 +17,7 @@ from collections import namedtuple import contextlib import functools from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple, - List, Union, cast, Type, Set) + List, Union, cast, Type, no_type_check) from weakref import ref import numpy as np @@ -165,8 +165,8 @@ class JaxprTrace(Trace): # We use process_call to handle both call and map primitives. def process_call(self, primitive, f: lu.WrappedFun, tracers, params): if not config.omnistaging_enabled: - if (self.main.trace_type is StagingJaxprTrace - and primitive in staged_out_calls): + if (self.main.trace_type is StagingJaxprTrace # type: ignore + and primitive in staged_out_calls): # type: ignore tracers = map(self.instantiate_const_abstracted, tracers) if primitive in call_partial_eval_rules: @@ -284,21 +284,6 @@ class JaxprTrace(Trace): env_tracers = map(self.full_raise, env) return jaxpr, out_pvs, consts, env_tracers - def process_custom_jvp_call(self, prim, fun, jvp, tracers): - # See comment at top of `JaxprTrace`. This method should be reachable - # only when we stage out, and in that case we drop the custom differentiation - # rules, because we do not need them. - assert self.main.trace_type is StagingJaxprTrace - return fun.call_wrapped(*tracers) - - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): - # See comment in the above process_custom_jvp_call method. - assert self.main.trace_type is StagingJaxprTrace - return fun.call_wrapped(*tracers) - - -class StagingJaxprTrace(JaxprTrace): pass - @lu.transformation_with_aux def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts): @@ -320,7 +305,7 @@ def abstract_eval_fun(fun, *avals, **params): else: pvals_in = [PartialVal.unknown(a) for a in avals] _, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in, - instantiate=True, stage_out=True) + instantiate=True, stage_out=True) # type: ignore avals_out, _ = unzip2(pvals_out) for aval_out in avals_out: assert isinstance(aval_out, AbstractValue) # instantiate=True @@ -370,10 +355,8 @@ class JaxprTracer(Tracer): # TODO(necula): this should return a TypedJaxpr def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], - instantiate: Union[bool, Sequence[bool]] = False, - stage_out=False, bottom=False, - trace_type: Optional[Type[Trace]] = None) \ - -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]: + instantiate: Union[bool, Sequence[bool]] = False, + ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]: """Traces a function into a Jaxpr, given PartialVals for inputs. Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the @@ -416,8 +399,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)] consts = [3, 6] # values for `ka` and `kb` constvars """ - trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace) - with core.new_main(trace_type, bottom=bottom) as main: + with core.new_main(JaxprTrace) as main: fun = trace_to_subjaxpr(fun, main, instantiate) jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) assert not env @@ -574,7 +556,6 @@ def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, Abst def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool], instantiate: Union[bool, Sequence[bool]], - trace_type: Optional[Type[core.Trace]] ) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]: """Specializes a Jaxpr given an indication of which inputs are known. @@ -616,19 +597,16 @@ def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool], cell = [] def fun(*vals): pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val) - for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)] - jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate, - trace_type=trace_type) + for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)] + jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate) out_pvs_2, out_consts_2 = unzip2(out_pvals_2) cell.append((out_pvs_2, jaxpr_2, len(consts_2))) return out_consts_2 + consts_2 - # The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores - # the avals, and it will never actually reach any primitives, because the `fun` above will - # execute the jaxpr with the right avals (it reconstructs `pvals` inside). - pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval) - for aval, uk in zip(jaxpr.in_avals, unknowns)] - jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True) + # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the + # known inputs. + in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)] + jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals) (out_pvs_2, jaxpr_2, num_res), = cell assert len(jaxpr_2.constvars) == num_res @@ -647,11 +625,10 @@ def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool], in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals)) out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals)) # out_avals_1 and in_avals_2 need the residuals added - out_pvs, _ = unzip2(out_pvals) - res_avals = out_pvs[len(jaxpr.out_avals):] + res_avals = out_avals[len(jaxpr.out_avals):] assert len(res_avals) == num_res - out_avals_1 = out_avals_1 + res_avals - in_avals_2 = in_avals_2 + res_avals + out_avals_1 = [*out_avals_1, *res_avals] + in_avals_2 = [*in_avals_2, *res_avals] typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1) typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2) @@ -685,7 +662,7 @@ def _remat_partial_eval(trace, _, f, tracers, params): jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval( f, in_pvals, partial(remat_call_p.bind, **params)) else: - with core.initial_style_staging(): + with core.initial_style_staging(): # type: ignore jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval( f, in_pvals, partial(remat_call_p.bind, **params)) @@ -710,7 +687,7 @@ def _remat_partial_eval(trace, _, f, tracers, params): typed_jaxpr, in_unknowns, instantiate=False) # type: ignore else: jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr( - typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type) + typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type) # type: ignore out_knowns = [not b for b in out_unknowns] out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns) @@ -839,26 +816,19 @@ class DynamicJaxprTracer(core.Tracer): def _contents(self): return () - def __bool__(self): - self._concretization_error('__bool__') - - def _concretization_error(self, name): - msgs = self._progenitor_messages() - msg = (f"Abstract tracer value passed to {name} for which a concrete value " - "is required.\n" - f"While tracing the function {self._trace.main.source_info}, " - "this tracer originated from using JAX operations on these lines:" - "\n\n" + "\n\n".join(msgs) + "\n\n" - "See the above traceback for where this tracer was encountered.") - raise core.ConcretizationTypeError(msg) - - def _progenitor_messages(self): + def _origin_msg(self): progenitor_eqns = self._trace.frame.find_progenitors(self) - # TODO mention which jit this tracer belongs to msgs = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n" f" from line {source_info_util.summarize(eqn.source_info)}" for eqn in progenitor_eqns] - return msgs + if msgs: + origin = (f"While tracing the function {self._trace.main.source_info}, " + "this value became a tracer due to JAX operations on these lines:" + "\n\n" + "\n\n".join(msgs)) + else: + origin = ("The error occured while tracing the function " + f"{self._trace.main.source_info}.") + return origin class JaxprStackFrame: __slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val', @@ -883,7 +853,10 @@ class JaxprStackFrame: return jaxpr, out_avals, constvals def find_progenitors(self, tracer): - active_vars = {self.tracer_to_var[id(tracer)]} + var = self.tracer_to_var.get(id(tracer)) + if not var: + return [] + active_vars = {var} for eqn in self.eqns[::-1]: produced = set(eqn.outvars) & active_vars if produced: @@ -1078,19 +1051,60 @@ def fun_sourceinfo(fun): return "" -# TODO(mattjj): remove when omnistaging fully lands - -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: - global trace_to_jaxpr, partial_eval_jaxpr - - del JaxprTrace.process_custom_jvp_call - del JaxprTrace.process_custom_vjp_call +@config.register_omnistaging_disabler +@no_type_check +def omnistaging_disabler() -> None: + global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], - instantiate: Union[bool, Sequence[bool]] = False, - ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]: - with core.new_main(JaxprTrace) as main: + instantiate: Union[bool, Sequence[bool]] = False, + stage_out=False, bottom=False, + trace_type: Optional[Type[Trace]] = None, + ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]: + """Traces a function into a Jaxpr, given PartialVals for inputs. + + Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the + computation that depends on unknown inputs. The `out_pvals` are the PartialVal + for the outputs. The intermediate values that depend only on known inputs and + are needed to compute the output of `jaxpr` are in `consts` and are passed in + as the constvars of the `jaxpr`. The handling of the known outputs depends on + `instantiate`. + + For example, given `fun` defined as follows:: + + def fun(ki, ui): # ki will be a known input in this example + ka = ki + 2 + kb = ka + 3 + return (kb, ui + ka) + + with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only + computation that depends on unknown inputs is `ui + ka` and will be the only + computation in the body of the `jaxpr`. This computation depends on the known + intermediate value `ka`, which will be computed statically. Currently, such + constants are either embedded in the Jaxpr if they are scalars, or passed as a + constvar to `jaxpr`, and then the value of the actual constant will be in + `consts`: + + When `instantiate=False` we get:: + + jaxpr = + { lambda ka ; ki ui. + let c = add ui ka + in (*, c) } # known outputs are `*` + out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)] + consts = [3] # the constant for `ka` + + When `instantiate=True` we get:: + + jaxpr = + { lambda ka kb ; ki ui. + let c = add ui ka + in (kb, c) } # known output are explicit + out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)] + consts = [3, 6] # values for `ka` and `kb` constvars + """ + trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace) + with core.new_main(trace_type, bottom=bottom) as main: fun = trace_to_subjaxpr(fun, main, instantiate) jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) assert not env @@ -1100,6 +1114,7 @@ def omnistaging_enabler() -> None: def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool], instantiate: Union[bool, Sequence[bool]], + trace_type: Optional[Type[core.Trace]] ) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]: f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) @@ -1107,15 +1122,18 @@ def omnistaging_enabler() -> None: def fun(*vals): pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val) for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)] - jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate) + jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate, + trace_type=trace_type) out_pvs_2, out_consts_2 = unzip2(out_pvals_2) cell.append((out_pvs_2, jaxpr_2, len(consts_2))) return out_consts_2 + consts_2 - # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the - # known inputs. - in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)] - jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals) + # The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores + # the avals, and it will never actually reach any primitives, because the `fun` above will + # execute the jaxpr with the right avals (it reconstructs `pvals` inside). + pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval) + for aval, uk in zip(jaxpr.in_avals, unknowns)] + jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True) (out_pvs_2, jaxpr_2, num_res), = cell assert len(jaxpr_2.constvars) == num_res @@ -1134,13 +1152,32 @@ def omnistaging_enabler() -> None: in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals)) out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals)) # out_avals_1 and in_avals_2 need the residuals added - res_avals = out_avals[len(jaxpr.out_avals):] + out_pvs, _ = unzip2(out_pvals) + res_avals = out_pvs[len(jaxpr.out_avals):] assert len(res_avals) == num_res - out_avals_1 = [*out_avals_1, *res_avals] - in_avals_2 = [*in_avals_2, *res_avals] + out_avals_1 = out_avals_1 + res_avals + in_avals_2 = in_avals_2 + res_avals typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1) typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2) return typed_jaxpr_1, typed_jaxpr_2, uk_out -staged_out_calls: Set[core.Primitive] = set() + def process_custom_jvp_call(self, prim, fun, jvp, tracers): + # See comment at top of `JaxprTrace`. This method should be reachable + # only when we stage out, and in that case we drop the custom differentiation + # rules, because we do not need them. + if not config.omnistaging_enabled: + assert self.main.trace_type is StagingJaxprTrace + return fun.call_wrapped(*tracers) + JaxprTrace.process_custom_jvp_call = process_custom_jvp_call + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): + # See comment in the above process_custom_jvp_call method. + if not config.omnistaging_enabled: + assert self.main.trace_type is StagingJaxprTrace + return fun.call_wrapped(*tracers) + JaxprTrace.process_custom_vjp_call = process_custom_vjp_call + + staged_out_calls = set() + + class StagingJaxprTrace(JaxprTrace): pass diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 8e9bc2b1a..5821ec2b4 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -34,7 +34,7 @@ import itertools as it import operator as op import threading from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, - Type, Union) + Type, Union, no_type_check) from absl import logging import numpy as np @@ -330,91 +330,6 @@ pxla_result_handlers[ShapedArray] = array_result_handler pxla_result_handlers[ConcreteArray] = array_result_handler -### applying parallel primitives in op-by-op Python dispatch - -# There are at least two cases where we might want to evaluate a parallel -# primitive dispatched from Python, rather than being staged out: -# 1. axis_size = psum(1, 'axis_name'), -# 2. to enable an implicit outermost pmap-like context for multi-host -# multi-controller SPMD programs. -# In each case, we can't rely on any data dependence on a pmap trace; instead we -# need some dynamic context, basically modeling the axis name environment stack. -# To handle the former case, we don't need to communicate at all; we instead -# have a table of parallel_pure_rules. To handle the latter case, we'll have a -# globally-scoped root environment frame and compile and execute a single-op -# XLA collective. - -class DynamicAxisEnvFrame(object): - __slots__ = ["name", "pmap_trace", "hard_size"] - def __init__(self, name, pmap_trace, hard_size): - self.name = name - self.pmap_trace = pmap_trace - self.hard_size = hard_size - -class DynamicAxisEnv(list): - def __contains__(self, axis_name): - return axis_name in (frame.name for frame in self) - - def __getitem__(self, axis_name): - if axis_name not in self: - raise NameError("unbound axis name: {}".format(axis_name)) - for frame in reversed(self): - if frame.name == axis_name: - return frame - else: - assert False - - @property - def sizes(self): - return tuple(frame.hard_size for frame in self) - - @property - def nreps(self): - return prod(frame.hard_size for frame in self) - -class _ThreadLocalState(threading.local): - def __init__(self): - self.dynamic_axis_env = DynamicAxisEnv() - -_thread_local_state = _ThreadLocalState() - -@contextmanager -def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size): - dynamic_axis_env = _thread_local_state.dynamic_axis_env - dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size)) - try: - yield - finally: - dynamic_axis_env.pop() - -def unmapped_device_count(backend=None): - dynamic_axis_env = _thread_local_state.dynamic_axis_env - mapped = prod(frame.hard_size for frame in dynamic_axis_env) - unmapped, ragged = divmod(xb.device_count(backend), mapped) - assert not ragged and unmapped > 0 - return unmapped - -def apply_parallel_primitive(prim, *args, **params): - # This is the op-by-op version of applying a collective primitive, like a psum - # that doesn't have a data dependence on the argument of a pmap function. In - # particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We - # look up information in the dynamic axis env. - dynamic_axis_env = _thread_local_state.dynamic_axis_env - axis_name = params.pop('axis_name') - axis_index_groups = params.pop('axis_index_groups') - if axis_index_groups is not None: - shape = (len(axis_index_groups[0]),) - else: - logical_size = lambda frame: frame.hard_size - if isinstance(axis_name, (list, tuple)): - shape = tuple(logical_size(dynamic_axis_env[name]) for name in axis_name) - else: - shape = (logical_size(dynamic_axis_env[axis_name]),) - return parallel_pure_rules[prim](*args, shape=shape, **params) - -parallel_pure_rules: Dict[core.Primitive, Callable] = {} - - ### lazy device-memory persistence and result handling class ShardedDeviceArray(xla.DeviceArray): @@ -636,7 +551,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size, # We add a dummy first invar, to carry the trace details to `dynamic_fun` pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env jaxpr, out_pvals, consts = pe.trace_to_jaxpr( - dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) + dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore jaxpr.invars = jaxpr.invars[1:] # ignore dummy jaxpr = xla.apply_outfeed_rewriter(jaxpr) @@ -795,7 +710,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size, handle_outs = avals_to_results_handler( # type: ignore axis_size, num_local_replicas, num_partitions, out_parts, out_avals) else: - handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas, + handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas, # type: ignore num_partitions, out_parts, out_pvals, compiled.local_devices(), backend) @@ -884,17 +799,20 @@ def get_num_partitions(*partitions): class ResultToPopulate: pass result_to_populate = ResultToPopulate() -def _pvals_to_results_handler( - size, nrep, npart, - out_parts: Optional[Tuple[PartitionsOrReplicated, ...]], - out_pvals, devices, backend): - nouts = len(out_pvals) +def avals_to_results_handler(size, nrep, npart, out_parts, out_avals): + nouts = len(out_avals) if out_parts is None: - out_parts = (None,) * len(out_pvals) - handlers = [ - _pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend) - for pval, parts in safe_zip(out_pvals, out_parts) # type: ignore - ] + out_parts = (None,) * len(out_avals) + + # TODO(mattjj,skyewm): can probably clean up this logic + out_specs = [_pmap_sharding_spec(nrep, size, npart, parts, aval, True) + if aval is not core.abstract_unit else None + for parts, aval in zip(out_parts, out_avals)] + out_indices = [spec_to_indices(core.unmapped_aval(size, aval).shape, spec) + if aval is not core.abstract_unit else None + for aval, spec in zip(out_avals, out_specs)] # pytype: disable=attribute-error + handlers = [aval_to_result_handler(spec, idcs, core.unmapped_aval(size, aval)) + for spec, idcs, aval in zip(out_specs, out_indices, out_avals)] def handler(out_bufs): assert nrep * npart == len(out_bufs) @@ -903,7 +821,7 @@ def _pvals_to_results_handler( for i, buf in enumerate(tuple_buf): buffers[i][r] = buf assert not any(buf is result_to_populate for bufs in buffers - for buf in bufs) + for buf in bufs) return [h(bufs) for h, bufs in zip(handlers, buffers)] return handler @@ -947,37 +865,6 @@ def replicate(val, axis_size, nrep, devices=None, backend=None): device_buffers = [xla.device_put(val, d) for d in devices] return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers) -def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend): - if devices: - assert all(d.host_id == xb.host_id(backend) for d in devices) - pv, const = pval - if pv is None: - if nrep is None: - nrep = axis_size - # If 'const' is a ShardedDeviceArray, it must have come from a pmap nested - # inside the one we're currently evaluating, and we should replicate - # 'const' across the total number of devices needed. We don't necessarily - # know the nested pmap's axis_size (e.g. the jaxpr for - # pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the - # axis size of the output 'const'. - # TODO: we might be doing unnecessary device transfers in the inner pmap. - if isinstance(const, ShardedDeviceArray): - nrep *= len(const) - - bcast_const = (core.unit if const is core.unit - else replicate(const, axis_size, nrep, devices, backend)) # type: ignore - return lambda _: bcast_const # type: ignore - else: - if pv is not core.abstract_unit: - unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype) - sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv, - True) - indices = spec_to_indices(unsharded_aval.shape, sharding_spec) - else: - sharding_spec = indices = None - unsharded_aval = pv - return aval_to_result_handler(sharding_spec, indices, unsharded_aval) - def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped): """Sharding spec for arguments or results of a pmap. Args: @@ -1014,7 +901,6 @@ def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped): replication_factors=[(replication_factor * axis_size, 0)] + shard_spec.replication_factors) - def partitioned_sharding_spec(num_partitions: int, partitions: Optional[Sequence[int]], aval): if partitions is None: @@ -1043,7 +929,6 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args): xla_pmap_p = core.MapPrimitive('xla_pmap') xla_pmap = xla_pmap_p.bind xla_pmap_p.def_impl(xla_pmap_impl) -pe.staged_out_calls.add(xla_pmap_p) # Set param update handlers to update `donated_invars` just like xla_call_p pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p] @@ -1279,41 +1164,35 @@ soft_pmap_p.def_impl(soft_pmap_impl) soft_pmap_rules: Dict[core.Primitive, Callable] = {} -def deleted_with_omnistaging(*a, **k): - assert False, "Should be deleted" - @contextmanager def maybe_extend_axis_env(*args, **kwargs): - yield + with core.extend_axis_env(*args, **kwargs): + yield -@config.register_omnistaging_enabler -def omnistaging_enable() -> None: +@config.register_omnistaging_disabler +@no_type_check +def omnistaging_disabler() -> None: global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \ _thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \ apply_parallel_primitive, parallel_pure_rules, \ _pvals_to_results_handler, _pval_to_result_handler, replicate, \ - avals_to_results_handler, maybe_extend_axis_env - del DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \ - _thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \ - _pvals_to_results_handler, _pval_to_result_handler, replicate + avals_to_results_handler, axis_index, maybe_extend_axis_env - apply_parallel_primitive = deleted_with_omnistaging - parallel_pure_rules.clear() + @contextmanager + def maybe_extend_axis_env(*args, **kwargs): + yield - def avals_to_results_handler(size, nrep, npart, out_parts, out_avals): - nouts = len(out_avals) + def _pvals_to_results_handler( + size, nrep, npart, + out_parts: Optional[Tuple[PartitionsOrReplicated, ...]], + out_pvals, devices, backend): + nouts = len(out_pvals) if out_parts is None: - out_parts = (None,) * len(out_avals) - - # TODO(mattjj,skyewm): can probably clean up this logic - out_specs = [_pmap_sharding_spec(nrep, size, npart, parts, aval, True) - if aval is not core.abstract_unit else None - for parts, aval in zip(out_parts, out_avals)] - out_indices = [spec_to_indices(core.unmapped_aval(size, aval).shape, spec) - if aval is not core.abstract_unit else None - for aval, spec in zip(out_avals, out_specs)] # pytype: disable=attribute-error - handlers = [aval_to_result_handler(spec, idcs, core.unmapped_aval(size, aval)) - for spec, idcs, aval in zip(out_specs, out_indices, out_avals)] + out_parts = (None,) * len(out_pvals) + handlers = [ + _pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend) + for pval, parts in safe_zip(out_pvals, out_parts) # type: ignore + ] def handler(out_bufs): assert nrep * npart == len(out_bufs) @@ -1326,7 +1205,105 @@ def omnistaging_enable() -> None: return [h(bufs) for h, bufs in zip(handlers, buffers)] return handler + def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend): + if devices: + assert all(d.host_id == xb.host_id(backend) for d in devices) + pv, const = pval + if pv is None: + if nrep is None: + nrep = axis_size + # If 'const' is a ShardedDeviceArray, it must have come from a pmap nested + # inside the one we're currently evaluating, and we should replicate + # 'const' across the total number of devices needed. We don't necessarily + # know the nested pmap's axis_size (e.g. the jaxpr for + # pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the + # axis size of the output 'const'. + # TODO: we might be doing unnecessary device transfers in the inner pmap. + if isinstance(const, ShardedDeviceArray): + nrep *= len(const) + + bcast_const = (core.unit if const is core.unit + else replicate(const, axis_size, nrep, devices, backend)) # type: ignore + return lambda _: bcast_const # type: ignore + else: + if pv is not core.abstract_unit: + unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype) + sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv, + True) + indices = spec_to_indices(unsharded_aval.shape, sharding_spec) + else: + sharding_spec = indices = None + unsharded_aval = pv + return aval_to_result_handler(sharding_spec, indices, unsharded_aval) + @contextmanager - def maybe_extend_axis_env(*args, **kwargs): - with core.extend_axis_env(*args, **kwargs): + def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size): + dynamic_axis_env = _thread_local_state.dynamic_axis_env + dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size)) + try: yield + finally: + dynamic_axis_env.pop() + + def unmapped_device_count(backend=None): + dynamic_axis_env = _thread_local_state.dynamic_axis_env + mapped = prod(frame.hard_size for frame in dynamic_axis_env) + unmapped, ragged = divmod(xb.device_count(backend), mapped) + assert not ragged and unmapped > 0 + return unmapped + + def apply_parallel_primitive(prim, *args, **params): + # This is the op-by-op version of applying a collective primitive, like a psum + # that doesn't have a data dependence on the argument of a pmap function. In + # particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We + # look up information in the dynamic axis env. + dynamic_axis_env = _thread_local_state.dynamic_axis_env + axis_name = params.pop('axis_name') + axis_index_groups = params.pop('axis_index_groups') + if axis_index_groups is not None: + shape = (len(axis_index_groups[0]),) + else: + logical_size = lambda frame: frame.hard_size + if isinstance(axis_name, (list, tuple)): + shape = tuple(logical_size(dynamic_axis_env[name]) for name in axis_name) + else: + shape = (logical_size(dynamic_axis_env[axis_name]),) + return parallel_pure_rules[prim](*args, shape=shape, **params) + + pe.staged_out_calls.add(xla_pmap_p) # type: ignore + +parallel_pure_rules = {} # type: ignore + +class DynamicAxisEnvFrame(object): + __slots__ = ["name", "pmap_trace", "hard_size"] + def __init__(self, name, pmap_trace, hard_size): + self.name = name + self.pmap_trace = pmap_trace + self.hard_size = hard_size + +class DynamicAxisEnv(list): + def __contains__(self, axis_name): + return axis_name in (frame.name for frame in self) + + def __getitem__(self, axis_name): + if axis_name not in self: + raise NameError("unbound axis name: {}".format(axis_name)) + for frame in reversed(self): + if frame.name == axis_name: + return frame + else: + assert False + + @property + def sizes(self): + return tuple(frame.hard_size for frame in self) + + @property + def nreps(self): + return prod(frame.hard_size for frame in self) + +class _ThreadLocalState(threading.local): + def __init__(self): + self.dynamic_axis_env = DynamicAxisEnv() + +_thread_local_state = _ThreadLocalState() diff --git a/jax/interpreters/sharded_jit.py b/jax/interpreters/sharded_jit.py index c28f2a0dc..034961c0a 100644 --- a/jax/interpreters/sharded_jit.py +++ b/jax/interpreters/sharded_jit.py @@ -41,10 +41,11 @@ def _map(f, *xs): class ResultToPopulate: pass result_to_populate = ResultToPopulate() -def _pvals_to_results_handler(nrep, npart, partitions, out_pvals): - nouts = len(out_pvals) - handlers = [_pval_to_result_handler(npart, parts, out_pval) - for parts, out_pval in safe_zip(partitions, out_pvals)] # type: ignore + +def _avals_to_results_handler(nrep, npart, partitions, out_avals): + nouts = len(out_avals) + handlers = [_aval_to_result_handler(npart, parts, out_aval) + for parts, out_aval in safe_zip(partitions, out_avals)] def handler(out_bufs): assert nrep * npart == len(out_bufs) @@ -53,23 +54,18 @@ def _pvals_to_results_handler(nrep, npart, partitions, out_pvals): for i, buf in enumerate(tuple_buf): buffers[i][r] = buf assert not any(buf is result_to_populate for bufs in buffers - for buf in bufs) + for buf in bufs) return [h(bufs) for h, bufs in zip(handlers, buffers)] return handler - -def _pval_to_result_handler(npart, parts, pval): - pv, const = pval - if pv is None: - raise NotImplementedError # TODO(skye): handle constant outputs +def _aval_to_result_handler(npart, parts, aval): + if aval is not core.abstract_unit: + spec = pxla.partitioned_sharding_spec(npart, parts, aval) + indices = pxla.spec_to_indices(aval.shape, spec) else: - if pv is not core.abstract_unit: - spec = pxla.partitioned_sharding_spec(npart, parts, pv) - indices = pxla.spec_to_indices(pv.shape, spec) - else: - spec = indices = None - return pxla.aval_to_result_handler(spec, indices, pv) + spec = indices = None + return pxla.aval_to_result_handler(spec, indices, aval) @lu.cache @@ -84,7 +80,8 @@ def _sharded_callable( jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args) else: in_pvals = [pe.PartialVal.unknown(aval) for aval in abstract_args] - jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=False, bottom=True) + jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore + instantiate=False, bottom=True) # type: ignore # TODO(skye): add tests for equationless jaxpr cases if not jaxpr.eqns and all(outvar.aval is core.abstract_unit @@ -138,7 +135,7 @@ def _sharded_callable( handle_outs = _avals_to_results_handler(nrep, num_partitions, out_parts, # type: ignore out_avals) else: - handle_outs = _pvals_to_results_handler(nrep, num_partitions, out_parts, + handle_outs = _pvals_to_results_handler(nrep, num_partitions, out_parts, # type: ignore out_pvals) return partial(_execute_spatially_partitioned, compiled, handle_args, handle_outs) @@ -354,16 +351,14 @@ def with_sharding_constraint(x, partitions: Optional[PartitionSpec]): return sharding_constraint_p.bind(x, partitions=partitions) -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: - global _avals_to_results_handler, _aval_to_result_handler, \ - _pvals_to_results_handler, _pval_to_result_handler - del _pvals_to_results_handler, _pval_to_result_handler +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: + global _pvals_to_results_handler, _pval_to_result_handler - def _avals_to_results_handler(nrep, npart, partitions, out_avals): - nouts = len(out_avals) - handlers = [_aval_to_result_handler(npart, parts, out_aval) - for parts, out_aval in safe_zip(partitions, out_avals)] + def _pvals_to_results_handler(nrep, npart, partitions, out_pvals): + nouts = len(out_pvals) + handlers = [_pval_to_result_handler(npart, parts, out_pval) + for parts, out_pval in safe_zip(partitions, out_pvals)] # type: ignore def handler(out_bufs): assert nrep * npart == len(out_bufs) @@ -377,11 +372,14 @@ def omnistaging_enabler() -> None: return handler - - def _aval_to_result_handler(npart, parts, aval): - if aval is not core.abstract_unit: - spec = pxla.partitioned_sharding_spec(npart, parts, aval) - indices = pxla.spec_to_indices(aval.shape, spec) + def _pval_to_result_handler(npart, parts, pval): + pv, const = pval + if pv is None: + raise NotImplementedError # TODO(skye): handle constant outputs else: - spec = indices = None - return pxla.aval_to_result_handler(spec, indices, aval) + if pv is not core.abstract_unit: + spec = pxla.partitioned_sharding_spec(npart, parts, pv) + indices = pxla.spec_to_indices(pv.shape, spec) + else: + spec = indices = None + return pxla.aval_to_result_handler(spec, indices, pv) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index a82500ec2..fa8159874 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -597,8 +597,8 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar raise core.UnexpectedTracerError("Encountered an unexpected tracer.") else: pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args] - jaxpr, pvals, consts = pe.trace_to_jaxpr( - fun, pvals, instantiate=False, stage_out=True, bottom=True) + jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore + fun, pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore map(prefetch, it.chain(consts, jaxpr_literals(jaxpr))) jaxpr = apply_outfeed_rewriter(jaxpr) @@ -608,7 +608,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar if config.omnistaging_enabled: result_handlers = tuple(aval_to_result_handler(device, a) for a in out_avals) else: - result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals)) + result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals)) # type: ignore # Computations that only produce constants and/or only rearrange their inputs, # which are often produced from partial evaluation, don't need compilation, @@ -879,7 +879,7 @@ def lower_fun(fun, multiple_results): else: pvals = [pe.PartialVal.unknown(a) for a in avals] jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True, - stage_out=True) + stage_out=True) # type: ignore xla_consts = _xla_consts(c, consts) outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args) if multiple_results: @@ -905,7 +905,7 @@ def lower_fun_initial_style(fun): else: pvals = [pe.PartialVal.unknown(a) for a in avals] jaxpr, _, consts = pe.trace_to_jaxpr( - lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True) + lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True) # type: ignore xla_consts = _xla_consts(c, consts) outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack, *xla_args) @@ -1277,19 +1277,25 @@ def _call_translation_rule(c, axis_env, in_nodes, name_stack, call_translations[core.call_p] = _call_translation_rule -# TODO(mattjj): remove when omnistaging fully lands +def _axis_index_translation_rule(c, *, axis_name, axis_env, platform): + div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), + dtype=np.uint32)) + mod = xb.constant(c, np.array(axis_env.sizes[-1], dtype=np.uint32)) + unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) + return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32)) +parallel_translations[core.axis_index_p] = _axis_index_translation_rule # type: ignore -def _pval_to_result_handler(device, pval): - pv, const = pval - if pv is None: - const = _device_put_impl(const, device) if device else const - return lambda _: const - else: - return aval_to_result_handler(device, pv) -pe.staged_out_calls.add(xla_call_p) - -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: global _pval_to_result_handler - del _pval_to_result_handler + + def _pval_to_result_handler(device, pval): + pv, const = pval + if pv is None: + const = _device_put_impl(const, device) if device else const + return lambda _: const + else: + return aval_to_result_handler(device, pv) + + pe.staged_out_calls.add(xla_call_p) # type: ignore diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 64bc68f28..735c6f4de 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -22,6 +22,7 @@ import warnings import numpy as np +import jax from .. import core from .. import ad_util from .. import api @@ -1373,30 +1374,8 @@ def top_k(operand: Array, k: int) -> Tuple[Array, Array]: return top_k_p.bind(operand, k=k) def tie_in(x: Array, y: Array) -> Array: - """Returns the value of ``y`` but with a fake data dependence on ``x``. - - When staging to XLA (e.g. running under jit or pmap), values that don't depend - on computation inputs are computed op-by-op, and folded into the XLA - computation as constants. - - ``tie_in`` provides a way to explicitly stage values into the computation. - When staging to XLA and ``x`` is already staged, then the result of ``tie_in`` - is ``y``, but staged to XLA. Downstream use of the result will also be staged - to XLA. - - For example, ``lax.sin(const)`` would be constant-folded if ``const`` is - a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to - XLA as long as ``x`` is staged to XLA. - """ - if config.omnistaging_enabled: - return y - else: - return tie_in_p.bind(x, y) - -# def tie_in(x: Array, y: Array) -> Array: -# """Deprecated. Ignores ``x`` and returns ``y``.""" -# return y - + """Deprecated. Ignores ``x`` and returns ``y``.""" + return y def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array: """Returns an array of `shape` filled with `fill_value`. @@ -1502,7 +1481,13 @@ def stop_gradient(x): >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.) array(0., dtype=float32) """ - return tree_map(ad_util.stop_gradient_p.bind, x) + def stop(x): + if (dtypes.issubdtype(_dtype(x), np.floating) or + dtypes.issubdtype(_dtype(x), np.complexfloating)): + return ad_util.stop_gradient_p.bind(x) + else: + return x # only bind primitive on inexact dtypes, to avoid some staging + return tree_map(stop, x) ### convenience wrappers around traceables @@ -5656,30 +5641,6 @@ xla.translations[top_k_p] = partial(standard_translate, 'top_k') ad.primitive_jvps[top_k_p] = _top_k_jvp batching.primitive_batchers[top_k_p] = _top_k_batch_rule -def _tie_in_transpose_rule(t, x, y): - if ad.is_undefined_primal(x): - return [ad_util.Zero(x.aval), t] - else: - return [ad_util.Zero.from_value(x), t] - -def _tie_in_batch_rule(batched_args, batch_dims): - y = tie_in(*batched_args) - _, bdim_y = batch_dims - return y, bdim_y - -def _tie_in_impl(x, y): - core.check_valid_jaxtype(x) - core.check_valid_jaxtype(y) - return y - -tie_in_p = Primitive('tie_in') -tie_in_p.def_impl(_tie_in_impl) -tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y)) -xla.translations[tie_in_p] = lambda c, x, y: y -ad.deflinear2(tie_in_p, _tie_in_transpose_rule) -batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule -masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1] - def _stop_gradient_jvp_rule(primals, tangents): # if we don't call stop_gradient here, we'd only peel off one autodiff tracer @@ -6198,7 +6159,72 @@ def _check_user_dtype_supported(dtype, fun_name=None): warnings.warn(msg.format(dtype, fun_name , truncated_dtype)) -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: - global _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p - del _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p +def _canonicalize_axis(axis, num_dims): + """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" + axis = operator.index(axis) + if not -num_dims <= axis < num_dims: + raise ValueError( + "axis {} is out of bounds for array of dimension {}".format( + axis, num_dims)) + if axis < 0: + axis = axis + num_dims + return axis + + +tie_in_p = Primitive('tie_in') + +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: + global tie_in + + def tie_in(x: Array, y: Array) -> Array: + """Returns the value of ``y`` but with a fake data dependence on ``x``. + + When staging to XLA (e.g. running under jit or pmap), values that don't depend + on computation inputs are computed op-by-op, and folded into the XLA + computation as constants. + + ``tie_in`` provides a way to explicitly stage values into the computation. + When staging to XLA and ``x`` is already staged, then the result of ``tie_in`` + is ``y``, but staged to XLA. Downstream use of the result will also be staged + to XLA. + + For example, ``lax.sin(const)`` would be constant-folded if ``const`` is + a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to + XLA as long as ``x`` is staged to XLA. + """ + if config.omnistaging_enabled: + return y + else: + return tie_in_p.bind(x, y) + + # If lax has already been imported, we need to monkey-patch the + # lax/__init__.py import of tie_in. If not (i.e. if this is running at lax + # module creation time) then we'll get an import error. + try: + jax.lax.tie_in = tie_in + except AttributeError: + pass + + def _tie_in_transpose_rule(t, x, y): + if ad.is_undefined_primal(x): + return [ad_util.Zero(x.aval), t] + else: + return [ad_util.Zero.from_value(x), t] + + def _tie_in_batch_rule(batched_args, batch_dims): + y = tie_in(*batched_args) + _, bdim_y = batch_dims + return y, bdim_y + + def _tie_in_impl(x, y): + core.check_valid_jaxtype(x) + core.check_valid_jaxtype(y) + return y + + tie_in_p.def_impl(_tie_in_impl) + tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y)) + xla.translations[tie_in_p] = lambda c, x, y: y + ad.deflinear2(tie_in_p, _tie_in_transpose_rule) + batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule + masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1] diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py index f44df1e3c..d01c2664a 100644 --- a/jax/lax/lax_control_flow.py +++ b/jax/lax/lax_control_flow.py @@ -62,22 +62,18 @@ T = TypeVar('T') @cache() def _initial_style_untyped_jaxpr(fun: Callable, in_tree, in_avals): - in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - with core.initial_style_staging(): - jaxpr, out_pvals, consts = pe.trace_to_jaxpr( - wrapped_fun, in_pvals, instantiate=True, stage_out=False) - return jaxpr, out_pvals, consts, out_tree + jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) + return jaxpr, out_avals, consts, out_tree() @cache() def _initial_style_jaxpr(fun: Callable, in_tree, in_avals): - jaxpr, out_pvals, consts, out_tree = _initial_style_untyped_jaxpr( - fun, in_tree, in_avals) - out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0]) + jaxpr, out_avals, consts, out_tree = \ + _initial_style_untyped_jaxpr(fun, in_tree, in_avals) const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts) typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (), const_avals + in_avals, out_avals) - return typed_jaxpr, consts, out_tree() + return typed_jaxpr, consts, out_tree def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], in_tree, in_avals): @@ -88,37 +84,29 @@ def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], # for each one, it makes another that accepts *all* constants, but only uses # those that it needs (dropping the rest). - jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([ - _initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs]) + jaxprs, all_out_avals, all_consts, all_out_trees = unzip4( + _initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs) newvar = core.gensym(jaxprs, suffix='_') - all_const_avals = tuple( - tuple(raise_to_shaped(core.get_aval(c)) for c in consts) - for consts in all_consts) - unused_const_vars = tuple( - tuple(newvar(aval) for aval in const_avals) - for const_avals in all_const_avals) + all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts] + for consts in all_consts] + unused_const_vars = [[newvar(aval) for aval in const_avals] + for const_avals in all_const_avals] def pad_jaxpr_constvars(i, jaxpr): prefix = util.concatenate(unused_const_vars[:i]) suffix = util.concatenate(unused_const_vars[i+1:]) - constvars = prefix + jaxpr.constvars + suffix + constvars = [*prefix, *jaxpr.constvars, *suffix] return core.Jaxpr(constvars=constvars, invars=jaxpr.invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns) - const_avals = tuple(util.concatenate(all_const_avals)) - - def type_and_const_convert_jaxpr(jaxpr, out_pvals): - out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0]) - return core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), - (), const_avals + in_avals, out_avals) - + consts = util.concatenate(all_consts) + const_avals = util.concatenate(all_const_avals) jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)] - typed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals) - - return (tuple(typed_jaxprs), - tuple(util.concatenate(all_consts)), - tuple(out_tree() for out_tree in all_out_trees)) + typed_jaxprs = [core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), + (), [*const_avals, *in_avals], out_avals) + for jaxpr, out_avals in zip(jaxprs, all_out_avals)] + return typed_jaxprs, consts, all_out_trees def _abstractify(x): return raise_to_shaped(core.get_aval(x)) @@ -1517,7 +1505,7 @@ def _prune_zeros(ts): def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): - if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: + if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: # type: ignore params = dict(reverse=reverse, length=length, num_consts=num_consts, num_carry=num_carry, jaxpr=jaxpr, linear=linear, unroll=unroll) @@ -2450,26 +2438,29 @@ def associative_scan(fn, elems, reverse=False): return tree_unflatten(tree, scans) -# TODO(mattjj): remove when omnistaging fully lands -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: global _initial_style_untyped_jaxpr, _initial_style_jaxpr, \ _initial_style_jaxprs_with_common_consts @cache() def _initial_style_untyped_jaxpr(fun: Callable, in_tree, in_avals): + in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) - return jaxpr, out_avals, consts, out_tree() + with core.initial_style_staging(): # type: ignore + jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore + wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore + return jaxpr, out_pvals, consts, out_tree @cache() def _initial_style_jaxpr(fun: Callable, in_tree, in_avals): - jaxpr, out_avals, consts, out_tree = \ - _initial_style_untyped_jaxpr(fun, in_tree, in_avals) + jaxpr, out_pvals, consts, out_tree = _initial_style_untyped_jaxpr( + fun, in_tree, in_avals) + out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0]) const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts) typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (), const_avals + in_avals, out_avals) - return typed_jaxpr, consts, out_tree + return typed_jaxpr, consts, out_tree() def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable], in_tree, in_avals): @@ -2480,26 +2471,34 @@ def omnistaging_enabler() -> None: # for each one, it makes another that accepts *all* constants, but only uses # those that it needs (dropping the rest). - jaxprs, all_out_avals, all_consts, all_out_trees = unzip4( - _initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs) + jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([ + _initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs]) newvar = core.gensym(jaxprs, suffix='_') - all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts] - for consts in all_consts] - unused_const_vars = [[newvar(aval) for aval in const_avals] - for const_avals in all_const_avals] + all_const_avals = tuple( + tuple(raise_to_shaped(core.get_aval(c)) for c in consts) + for consts in all_consts) + unused_const_vars = tuple( + tuple(newvar(aval) for aval in const_avals) + for const_avals in all_const_avals) def pad_jaxpr_constvars(i, jaxpr): prefix = util.concatenate(unused_const_vars[:i]) suffix = util.concatenate(unused_const_vars[i+1:]) - constvars = [*prefix, *jaxpr.constvars, *suffix] + constvars = prefix + jaxpr.constvars + suffix return core.Jaxpr(constvars=constvars, invars=jaxpr.invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns) - consts = util.concatenate(all_consts) - const_avals = util.concatenate(all_const_avals) + const_avals = tuple(util.concatenate(all_const_avals)) + + def type_and_const_convert_jaxpr(jaxpr, out_pvals): + out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0]) + return core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), + (), const_avals + in_avals, out_avals) + jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)] - typed_jaxprs = [core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), - (), [*const_avals, *in_avals], out_avals) - for jaxpr, out_avals in zip(jaxprs, all_out_avals)] - return typed_jaxprs, consts, all_out_trees + typed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals) + + return (tuple(typed_jaxprs), + tuple(util.concatenate(all_consts)), + tuple(out_tree() for out_tree in all_out_trees)) diff --git a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py index 6f1c6dbb2..5d22d68e7 100644 --- a/jax/lax/lax_parallel.py +++ b/jax/lax/lax_parallel.py @@ -458,12 +458,10 @@ def _psum_transpose_rule(cts, axis_name, axis_index_groups): psum_p = core.Primitive('psum') psum_p.multiple_results = True -psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p)) psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args)) pxla.soft_pmap_rules[psum_p] = \ partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum) xla.parallel_translations[psum_p] = _psum_translation_rule -pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args) ad.deflinear(psum_p, _psum_transpose_rule) pxla.multi_host_supported_collectives.add(psum_p) batching.split_axis_rules[psum_p] = partial(_split_axis_comm_assoc, psum_p) @@ -474,6 +472,21 @@ batching.collective_rules[psum_p] = \ lambda v, d: v.sum(d), lambda v, axis_size: axis_size * v) +# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at +# tracing time. +@psum_p.def_custom_bind +def psum_bind(*args, axis_name, axis_index_groups): + if all(not isinstance(x, core.Tracer) for x in args): + if axis_index_groups is not None: + size = len(axis_index_groups[0]) + elif type(axis_name) is tuple: + size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore + else: + size = core.axis_frame(axis_name).size # type: ignore + return tuple(size * x for x in args) + return core.Primitive.bind( + psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups) + pmax_p = core.Primitive('pmax') pmax_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) @@ -660,20 +673,6 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform): unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32)) -def _axis_index_bind(*, axis_name): - dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env - frame = dynamic_axis_env[axis_name] - trace = frame.pmap_trace - - out_aval = ShapedArray((), np.int32) - out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) - eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p, - dict(axis_name=axis_name), - source_info_util.current()) - out_tracer.recipe = eqn - - return out_tracer - def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name): assert not vals and not mapped idx = axis_index(axis_name) # type: ignore @@ -682,46 +681,58 @@ def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name): axis_index_p = core.Primitive('axis_index') xla.parallel_translations[axis_index_p] = _axis_index_translation_rule pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule # type: ignore -axis_index_p.def_custom_bind(_axis_index_bind) axis_index_p.def_abstract_eval( lambda *args, **params: ShapedArray((), np.int32)) pxla.multi_host_supported_collectives.add(axis_index_p) +# Axis index doesn't get any arguments, so that the default bind would have no +# way to call into a data-dependency based trace such as vmap. Each trace that +# wants to bind an axis name has to additionally implement `process_axis_index` +# and put its main trace on the axis env stack. +def _axis_index_bind(*, axis_name): + frame = core.axis_frame(axis_name) + if frame.main_trace is not None: + trace = frame.main_trace.trace_type(frame.main_trace, core.cur_sublevel()) + return trace.process_axis_index(frame) + return core.Primitive.bind(axis_index_p, axis_name=axis_name) +axis_index_p.def_custom_bind(_axis_index_bind) -@config.register_omnistaging_enabler -def omnistaging_enabler() -> None: - # We set a special bind rule for psum so that psum(1, 'i') can be evaluated at - # tracing time. - @psum_p.def_custom_bind - def psum_bind(*args, axis_name, axis_index_groups): - if all(not isinstance(x, core.Tracer) for x in args): - if axis_index_groups is not None: - size = len(axis_index_groups[0]) - elif type(axis_name) is tuple: - size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore - else: - size = core.axis_frame(axis_name).size # type: ignore - return tuple(size * x for x in args) - return core.Primitive.bind( - psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups) +def _process_axis_index(self, frame): + return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0) +batching.BatchTrace.process_axis_index = _process_axis_index - if psum_p in pxla.parallel_pure_rules: - del pxla.parallel_pure_rules[psum_p] - # Axis index doesn't get any arguments, so that the default bind would have no - # way to call into a data-dependency based trace such as vmap. Each trace that - # wants to bind an axis name has to additionally implement `process_axis_index` - # and put its main trace on the axis env stack. +@config.register_omnistaging_disabler +def omnistaging_disabler() -> None: + global axis_index + + psum_p.bind = partial(core.Primitive.bind, psum_p) + psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p)) # type: ignore + pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args) # type: ignore + def _axis_index_bind(*, axis_name): - frame = core.axis_frame(axis_name) - if frame.main_trace is not None: - trace = frame.main_trace.trace_type(frame.main_trace, core.cur_sublevel()) - return trace.process_axis_index(frame) - return core.Primitive.bind(axis_index_p, axis_name=axis_name) + dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env + frame = dynamic_axis_env[axis_name] + sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1] + nreps = dynamic_axis_env.nreps + trace = frame.pmap_trace + + out_aval = ShapedArray((), np.int32) + out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) + eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p, + dict(nreps=nreps, sizes=sizes, axis_name=axis_name), + source_info_util.current()) + out_tracer.recipe = eqn + + return out_tracer + + def _axis_index_translation_rule(c, nreps, sizes, axis_name): + div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32)) + mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32)) + unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) + return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32)) axis_index_p.def_custom_bind(_axis_index_bind) - - def process_axis_index(self, frame): - return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0) - - batching.BatchTrace.process_axis_index = process_axis_index + axis_index_p.def_abstract_eval( + lambda *args, **params: ShapedArray((), np.int32)) + xla.translations[axis_index_p] = _axis_index_translation_rule diff --git a/jax/nn/functions.py b/jax/nn/functions.py index f511ee3a0..6e3d2457c 100644 --- a/jax/nn/functions.py +++ b/jax/nn/functions.py @@ -266,8 +266,9 @@ def one_hot(x, num_classes, *, dtype=jnp.float64): dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32). """ - num_classes = core.concrete_or_error(int, num_classes, - "in jax.nn.one_hot argument `num_classes`") + num_classes = core.concrete_or_error( + int, num_classes, + "The error arose in jax.nn.one_hot argument `num_classes`.") dtype = dtypes.canonicalize_dtype(dtype) x = jnp.asarray(x) lhs = x[..., jnp.newaxis] diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 7c5e32673..9eea2dbd1 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -1137,6 +1137,16 @@ def reshape(a, newshape, order="C"): def _compute_newshape(a, newshape): """Fixes a -1 value in newshape, if present.""" # other errors, like having more than one -1, are caught downstream + try: iter(newshape) + except: iterable = False + else: iterable = True + if iterable: + newshape = [core.concrete_or_error(int, d, + "The error arose in jax.numpy.reshape.") + for d in newshape] + else: + newshape = core.concrete_or_error(int, newshape, + "The error arose in jax.numpy.reshape.") newsize = _prod(newshape) if newsize < 0: fix = a.size // -newsize @@ -2417,7 +2427,7 @@ def identity(n, dtype=None): def arange(start, stop=None, step=None, dtype=None): lax._check_user_dtype_supported(dtype, "arange") require = partial(core.concrete_or_error, _np_asarray) - msg = "in jax.numpy.arange argument `{}`".format + msg = "It arose in jax.numpy.arange argument `{}`.".format if stop is None and step is None: start = require(start, msg("stop")) dtype = dtype or _dtype(start) @@ -2622,7 +2632,7 @@ def repeat(a, repeats, axis=None, *, total_repeat_length=None): # If total_repeat_length is not given, can't compile, use a default. if total_repeat_length is None: - repeats = core.concrete_or_error(np.array, repeats, "jax.numpy.repeat") + repeats = core.concrete_or_error(np.array, repeats, "It arose in jax.numpy.repeat.") repeats = np.ravel(repeats) if ndim(a) != 0: repeats = np.broadcast_to(repeats, [a.shape[axis]]) diff --git a/tests/api_test.py b/tests/api_test.py index 3b28f948a..6edb3992e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -513,7 +513,7 @@ class APITest(jtu.JaxTestCase): self.assertRaisesRegex( TypeError, - f"Try using `x.astype\\({castfun.__name__}\\)` instead.", + f"[Tt]ry using `x.astype\\({castfun.__name__}\\)`", lambda: jit(f)(1.0)) def test_switch_value_jit(self): @@ -549,7 +549,7 @@ class APITest(jtu.JaxTestCase): self.assertRaisesRegex( TypeError, "('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer" - "|Abstract tracer value encountered where concrete value is expected .*)", lambda: jit(f)(0)) + "|Abstract tracer value encountered where concrete value is expected.*)", lambda: jit(f)(0)) def test_unimplemented_interpreter_rules(self): foo_p = Primitive('foo') diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 5b5cef2ee..9761d38e3 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4145,7 +4145,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp.sum(jnp.arange(3), (0, 0)) def testArangeConcretizationError(self): - msg = r"Abstract tracer.*\(in jax.numpy.arange argument `{}`\).*".format + msg = r"It arose in jax.numpy.arange argument `{}`".format with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')): jax.jit(jnp.arange)(3) diff --git a/tests/nn_test.py b/tests/nn_test.py index 14f612325..83508af15 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -150,7 +150,7 @@ class NNFunctionsTest(jtu.JaxTestCase): def testOneHotConcretizationError(self): # https://github.com/google/jax/issues/3654 - msg = r"Abstract tracer.*\(in jax.nn.one_hot argument `num_classes`\).*" + msg = r"in jax.nn.one_hot argument `num_classes`" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): jax.jit(nn.one_hot)(3, 5)