omnistaging on by default (#4038)

This commit is contained in:
Matthew Johnson 2020-09-15 08:06:46 -07:00 committed by GitHub
parent 6af476900a
commit 2678a4647a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 896 additions and 867 deletions

View File

@ -53,27 +53,27 @@ jobs:
- python-version: 3.6
os: ubuntu-latest
enable-x64: 0
enable-omnistaging: 0
enable-omnistaging: 1
package-overrides: "none"
num_generated_cases: 25
- python-version: 3.7
os: ubuntu-latest
enable-x64: 1
enable-omnistaging: 0
enable-omnistaging: 1
# Test experimental NumPy dispatch
package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
num_generated_cases: 25
- python-version: 3.6
os: ubuntu-latest
enable-x64: 1
enable-omnistaging: 0
enable-omnistaging: 1
# Test with numpy version that matches Google-internal version
package-overrides: "numpy==1.16.4"
num_generated_cases: 10
- python-version: 3.7
os: ubuntu-latest
enable-x64: 0
enable-omnistaging: 1
enable-omnistaging: 0
package-overrides: "none"
num_generated_cases: 8
steps:

View File

@ -164,34 +164,15 @@ before (with two input vars, one for each element of the input tuple)
Constant Vars
--------------
-------------
ConstVars arise when the computation contains array constants, either
from the Python program, or from constant-folding. For example, the function
``func6`` below
Some values in jaxprs are constants, in that their value does not depend on the
jaxpr's arguments. When these values are scalars they are represented directly
in the jaxpr equations; non-scalar array constants are instead hoisted out to
the top-level jaxpr, where they correspond to constant variables ("constvars").
These constvars differ from the other jaxpr parameters ("invars") only as a
bookkeeping convention.
>>> def func5(first, second):
... temp = first + jnp.sin(second) * 3. - jnp.ones(8)
... return temp
...
>>> def func6(first):
... return func5(first, jnp.ones(8))
...
JAX produces the following jaxpr
>>> print(make_jaxpr(func6)(jnp.ones(8)))
{ lambda b d ; a.
let c = add a b
e = sub c d
in (e,) }
When tracing ``func6``, the function ``func5`` is invoked with a constant value
(``np.ones(8)``) for the second argument. As a result, the sub-expression
``jnp.sin(second) * 3.`` is constant-folded.
There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d``
(standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the
jaxpr notation what constants the constant variables stand for.
Higher-order primitives
-----------------------
@ -293,44 +274,25 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`
>>> def func8(arg1, arg2): # arg2 is a pair
... return lax.cond(arg1 >= 0.,
... lambda xtrue: xtrue[0],
... lambda xfalse: jnp.ones(1) + xfalse[1],
... lambda xfalse: jnp.array([1]) + xfalse[1],
... arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda f ; a b c.
let d = ge a 0.0
e = convert_element_type[ new_dtype=int32
old_dtype=bool ] d
g = cond[ branches=( { lambda ; c a b.
let d = add c b
in (d,) }
{ lambda ; e_ a b.
let
{ lambda a ; b c d.
let e = ge b 0.0
f = convert_element_type[ new_dtype=int32
old_dtype=bool ] e
g = cond[ branches=( { lambda ; a b c.
let d = convert_element_type[ new_dtype=float32
old_dtype=int32 ] a
e = add d c
in (e,) }
{ lambda ; f_ a b.
let
in (a,) } )
linear=(False, False, False) ] e f b c
linear=(False, False, False) ] f a c d
in (g,) }
The top-level jaxpr has one `constvar` ``f`` (corresponding to
``jnp.ones(1)`` from the body of the first (false) branch) and three
input variables ``a b c`` (corresponding to ``arg1`` and the two
elements of ``arg2``; note that ``arg2`` has been flattened). The
``false_jaxpr`` has three input variables (``c`` corresponding to the
constant for ``jnp.ones(1)``, and ``a b`` for the two elements of
``arg2`` that are passed to ``false_jaxpr``). The ``true_jaxpr`` has
three input variables. The first (``e_``) is an unused argument
matching the constant first argument ``c`` of ``false_jaxpr``
(required for the jaxpr signatures to match). The subsequent two
correspond to the two elements of ``arg2`` that are passed to
``true_jaxpr``.
The actual operands to the cond primitive are: ``e f b c``, which
correspond in order to:
* one operand for the predicate,
* one constant (only used by ``false_jaxpr``, but passed to both),
i.e., ``f``, which is a constvar for the top-level jaxpr
* two operands passed to both jaxprs, i.e., ``b`` and ``c``, which are
input vars, corresponding to ``arg2`` for the top-level jaxpr.
While
^^^^^
@ -357,32 +319,22 @@ For example, here is an example fori loop
... arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda c d ; a b.
let e = add a d
_ _ f = while[ body_jaxpr={ lambda ; e g a b c.
let d = add a 1
f = add c e
h = add f g
in (d, b, h) }
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(16,) ] 1.0
d = add a c
_ _ e = while[ body_jaxpr={ lambda ; a b c d e.
let f = add c 1
g = mul a 3.0
h = add e g
i = add h b
in (f, d, i) }
body_nconsts=2
cond_jaxpr={ lambda ; a b c.
let d = lt a b
in (d,) }
cond_nconsts=0 ] c a 0 b e
in (f,) }
The top-level jaxpr has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry).
There are also two input variables (``a`` corresponding to ``arg`` and ``b`` corresponding
to ``n``).
The loop carry consists of three values, as seen in the body of ``cond_jaxpr``
(corresponding to the iteration index, iteration end, and the accumulated value carry).
Note that ``body_jaxpr`` takes 5 input variables. The first two are actually
constvars: ``e`` corresponding to ``ones * 3`` and ``g`` corresponding to the
captured use of ``arg`` in the loop body.
The parameter ``body_nconsts = 2`` specifies that there are 2 constants for the
``body_jaxpr``.
The other 3 input variables for ``body_jaxpr`` correspond to the flattened carry values.
cond_nconsts=0 ] c a 0 b d
in (e,) }
The while primitive takes 5 arguments: ``c a 0 b d``, as follows:
@ -395,13 +347,13 @@ Scan
JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there are a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are constructed
with the :py:func:`jax.lax.scan` operator::
makes this form of looping easily reverse-differentiable. Such loops are
constructed with the :py:func:`jax.lax.scan` function::
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
Here ``C`` is the type of the scan carry, ``A`` is the element type of the input array(s),
and ``B`` is the element type of the output array(s).
Here ``C`` is the type of the scan carry, ``A`` is the element type of the
input array(s), and ``B`` is the element type of the output array(s).
For example, consider the function ``func11`` below
@ -415,12 +367,14 @@ For the example consider the function ``func11`` below
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda c ; a b.
let d e = scan[ jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
in (g, a) }
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(16,) ] 1.0
d e = scan[ jaxpr={ lambda ; a b c d.
let e = mul c d
f = add b e
g = add f a
in (g, b) }
length=16
linear=(False, False, False, False)
num_carry=1
@ -429,17 +383,6 @@ For the example consider the function ``func11`` below
unroll=1 ] b 0.0 a c
in (d, e) }
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
and two input variables corresponding to the arguments ``arr`` and ``extra``.
The body of the scan has 4 input variables, of which:
* one (``f``) is a constant (since ``num_consts = 1``), and stands for the
captured variable ``extra`` used in the loop body,
* one (``a``) is the value of the carry (since ``num_carry = 1``)
* The remaining 2 are the input values. ``b`` is the array element from the
first array passed to lax.scan (``arr``) and ``c`` is the second array
(``ones``).
The ``linear`` parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.
@ -466,37 +409,27 @@ computation should run. For example
... return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda b ; a.
let c = sub a 2.0
d = xla_call[ backend=None
call_jaxpr={ lambda ; c b a.
let d = mul b c
e = add a d
{ lambda ; a.
let b = sub a 2.0
c = xla_call[ backend=None
call_jaxpr={ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
d = mul a c
e = add b d
in (e,) }
device=None
donated_invars=(False, False, False)
name=inner ] b a c
e = add a d
in (e,) }
donated_invars=(False, False)
name=inner ] a b
d = add a c
in (d,) }
The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and
the top-level input variable `a` refers to the ``arg`` parameter of ``func12``.
The ``xla_call`` primitive stands for a call to the jitted ``inner`` function.
The primitive has the function body in the ``call_jaxpr`` parameter, a jaxpr
with 3 input parameters:
* ``c`` is a constvar and stands for the ``ones`` constant,
* ``b`` corresponds to the free variable ``arg`` captured in the ``inner`` function,
* ``a`` corresponds to the ``inner`` parameter ``x``.
The primitive takes three arguments ``b a c``.
XLA_pmap
^^^^^^^^
If you use the :py:func:`jax.pmap` transformation, the function to be
mapped is captured using the ``xla_pmap`` primitive. Consider this
example
If you use the :py:func:`jax.pmap` transformation, the function to be mapped is
captured using the ``xla_pmap`` primitive. Consider this example
>>> from jax import pmap
>>>
@ -507,34 +440,30 @@ example
... return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda c ; a b.
let d = xla_pmap[ axis_name=rows
{ lambda ; a b.
let c = xla_pmap[ axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; d b a.
let c = add a b
call_jaxpr={ lambda ; a b.
let c = add b a
d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
e = add c d
f = psum[ axis_index_groups=None
axis_name=rows ] a
axis_name=rows ] b
g = div e f
in (g,) }
devices=None
donated_invars=(False, False, False)
donated_invars=(False, False)
global_axis_size=None
mapped_invars=(True, False, True)
name=inner ] c b a
in (d,) }
mapped_invars=(False, True)
name=inner ] b a
in (c,) }
The top-level constvar ``c`` refers to the ``jnp.ones(1)`` constant.
The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``)
and the body of the function to be mapped as the ``call_jaxpr`` parameter. The
and the body of the function to be mapped as the ``call_jaxpr`` parameter.
value of this parameter is a Jaxpr with 3 input variables:
* ``d`` stands for the constant ``jnp.ones(1)``,
* ``b`` stands for the free variable ``extra``,
* ``a`` stands for the parameter ``x`` of ``inner``.
The parameter ``mapped_invars`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast; the other input values are mapped.

View File

@ -393,7 +393,7 @@ def disable_jit():
... return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
[5 7 9]
Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,
@ -651,7 +651,7 @@ def _xla_computation(
else:
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, pvals, instantiate=True, stage_out=True)
jaxtree_fun, pvals, instantiate=True, stage_out=True) # type: ignore
out_avals = [raise_to_shaped(pval.get_aval()) for pval in out_pvals]
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
@ -1910,11 +1910,12 @@ def make_jaxpr(fun: Callable,
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a.
let b = cos a
c = cos b
d = mul 1.0 c
e = neg d
f = sin a
g = mul e f
c = sin a
_ = sin b
d = cos b
e = mul 1.0 d
f = neg e
g = mul f c
in (g,) }
"""
_check_callable(fun)
@ -1936,7 +1937,7 @@ def make_jaxpr(fun: Callable,
else:
in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, in_pvals, instantiate=True, stage_out=True)
jaxtree_fun, in_pvals, instantiate=True, stage_out=True) # type: ignore
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
return typed_jaxpr
@ -2214,7 +2215,7 @@ class CustomTransformsFunction(object):
if config.omnistaging_enabled:
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
else:
with core.initial_style_staging():
with core.initial_style_staging(): # type: ignore
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
in_tree=in_tree, out_tree=out_tree(),

View File

@ -42,8 +42,8 @@ class Config:
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', False)
self._omnistaging_enablers = []
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
self._omnistaging_disablers = []
def update(self, name, val):
if self.use_absl:
@ -114,24 +114,25 @@ class Config:
self.complete_absl_config(absl.flags)
already_configured_with_absl = True
if FLAGS.jax_omnistaging:
self.enable_omnistaging()
if not FLAGS.jax_omnistaging:
self.disable_omnistaging()
def register_omnistaging_enabler(self, enabler):
if not self.omnistaging_enabled:
self._omnistaging_enablers.append(enabler)
def register_omnistaging_disabler(self, disabler):
if self.omnistaging_enabled:
self._omnistaging_disablers.append(disabler)
else:
enabler()
disabler()
# TODO(mattjj): remove this when omnistaging fully lands
def enable_omnistaging(self):
if not self.omnistaging_enabled:
for enabler in self._omnistaging_enablers:
enabler()
self.omnistaging_enabled = True
raise Exception("can't re-enable omnistaging after it's been disabled")
def disable_omnistaging(self):
pass
if self.omnistaging_enabled:
for disabler in self._omnistaging_disablers:
disabler()
self.omnistaging_enabled = False
class NameSpace(object):
@ -156,6 +157,6 @@ flags.DEFINE_bool(
flags.DEFINE_bool(
'jax_omnistaging',
bool_env('JAX_OMNISTAGING', False),
bool_env('JAX_OMNISTAGING', True),
help='Enable staging based on dynamic context rather than data dependence.'
)

View File

@ -24,7 +24,7 @@ import threading
import types
from typing import (Any, Callable, ClassVar, Dict, Generator,
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
Type, Union, cast, no_type_check)
Type, Union, cast)
import numpy as np
@ -266,19 +266,14 @@ class Primitive:
def __repr__(self):
return '{}'.format(self.name)
def bind(self, *args, **kwargs):
def bind(self, *args, **params):
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:
return self.impl(*args, **kwargs)
tracers = map(top_trace.full_raise, args)
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
if self.multiple_results:
return map(full_lower, out_tracer)
else:
return full_lower(out_tracer)
out = top_trace.process_primitive(self, tracers, params)
return map(full_lower, out) if self.multiple_results else full_lower(out)
def def_impl(self, impl):
self.impl = impl
@ -517,14 +512,8 @@ class Tracer:
def __long__(self): return self.aval._long(self)
def __hex__(self): return self.aval._hex(self)
def __oct__(self): return self.aval._oct(self)
def __float__(self):
raise TypeError("JAX Tracer object cannot be interpreted as a float. "
"Try using `x.astype(float)` instead.")
def __complex__(self):
raise TypeError("JAX Tracer object cannot be interpreted as a complex. "
"Try using `x.astype(complex)` instead.")
def __float__(self): return self.aval._float(self)
def __complex__(self): return self.aval._complex(self)
def __setitem__(self, idx, val):
raise TypeError("JAX 'Tracer' objects do not support item assignment")
@ -571,6 +560,9 @@ class Tracer:
def __deepcopy__(self, unused_memo):
return self
def _origin_msg(self) -> str:
return ""
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals
aval_property = namedtuple("aval_property", ["fget"])
@ -612,57 +604,47 @@ class TraceStack:
downward: List[MainTrace]
def __init__(self):
self.upward = []
self.downward = []
eval_trace = MainTrace(0, EvalTrace)
self.stack = [eval_trace]
self.dynamic = eval_trace
def next_level(self, bottom: bool) -> int:
if bottom:
return - (len(self.downward) + 1)
else:
return len(self.upward)
def next_level(self) -> int:
return len(self.stack)
def push(self, main_trace: MainTrace, bottom: bool) -> None:
if bottom:
self.downward.append(main_trace)
else:
self.upward.append(main_trace)
def push(self, main_trace: MainTrace) -> None:
self.stack.append(main_trace)
def pop(self, bottom: bool) -> None:
if bottom:
self.downward.pop()
else:
self.upward.pop()
def pop(self) -> None:
self.stack.pop()
def __repr__(self) -> str:
return 'Trace stack\n{} ---\n{}'.format(
map(' {}\n'.format, self.upward[::-1]),
map(' {}\n'.format, self.downward))
stack_str = map(' {}\n'.format, self.stack[::-1])
return f'Trace stack\n{stack_str}\n{self.dynamic}'
def copy(self):
new = TraceStack()
new.upward = self.upward[:]
new.downward = self.downward[:]
new = self.__new__(TraceStack)
new.stack = self.stack[:]
new.dynamic = self.dynamic
return new
class Sublevel(int): pass
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
class TraceState:
trace_stack: TraceStack
substack: List[Sublevel]
initial_style: bool
axis_env: List[AxisEnvFrame]
def __init__(self) -> None:
self.trace_stack = TraceStack()
self.substack = [Sublevel(0)]
self.initial_style = False
self.axis_env = []
def copy(self):
new = TraceState()
new = self.__new__(TraceState)
new.trace_stack = self.trace_stack.copy()
new.substack = self.substack[:]
new.initial_style = self.initial_style
new.axis_env = self.axis_env[:]
return new
# The global state of the tracer is accessed by a thread-local object.
@ -676,8 +658,9 @@ thread_local_state = ThreadLocalState()
def reset_trace_state() -> bool:
"Reset the global trace state and return True if it was already clean."
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
thread_local_state.trace_state.trace_stack.downward or
thread_local_state.trace_state.trace_stack.upward):
thread_local_state.trace_state.axis_env != [] or
thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or
thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)):
thread_local_state.trace_state.__init__() # type: ignore
return False
else:
@ -687,15 +670,21 @@ def cur_sublevel() -> Sublevel:
return thread_local_state.trace_state.substack[-1]
@contextmanager
def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]:
level = thread_local_state.trace_state.trace_stack.next_level(bottom)
def new_main(trace_type: Type[Trace], dynamic: bool = False,
) -> Generator[MainTrace, None, None]:
stack = thread_local_state.trace_state.trace_stack
level = stack.next_level()
main = MainTrace(level, trace_type)
thread_local_state.trace_state.trace_stack.push(main, bottom)
stack.push(main)
if dynamic:
prev_dynamic, stack.dynamic = stack.dynamic, main
try:
yield main
finally:
thread_local_state.trace_state.trace_stack.pop(bottom)
thread_local_state.trace_state.trace_stack.pop()
if dynamic:
stack.dynamic = prev_dynamic
if check_leaks:
t = ref(main)
@ -704,6 +693,23 @@ def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None
print(thread_local_state.trace_state.trace_stack)
raise Exception('Leaked trace {}'.format(t()))
@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
stack = thread_local_state.trace_state.trace_stack
main = MainTrace(0, trace_type)
prev_dynamic, stack.dynamic = stack.dynamic, main
prev_base, stack.stack[0] = stack.stack[0], main
try:
yield main
finally:
stack.dynamic = prev_dynamic
stack.stack[0] = prev_base
@contextmanager
def eval_context():
with new_base_main(EvalTrace):
yield
@contextmanager
def new_sublevel() -> Generator[None, None, None]:
sublevel = Sublevel(len(thread_local_state.trace_state.substack))
@ -719,29 +725,24 @@ def new_sublevel() -> Generator[None, None, None]:
if t() is not None:
raise Exception('Leaked sublevel {}'.format(t()))
def maybe_new_sublevel(trace):
# dynamic traces run the WrappedFun, so we raise the sublevel for them
dynamic = thread_local_state.trace_state.trace_stack.dynamic
return new_sublevel() if trace.main is dynamic else suppress()
def full_lower(val):
if isinstance(val, Tracer):
return val.full_lower()
else:
return val
def find_top_trace(xs) -> Optional[Trace]:
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'), default=None)
return top_trace and type(top_trace)(top_trace.main, cur_sublevel())
@contextmanager
def initial_style_staging():
trace_state = thread_local_state.trace_state
prev, trace_state.initial_style = trace_state.initial_style, True
try:
yield
finally:
trace_state.initial_style = prev
@contextmanager
def eval_context():
yield # dummy implementation for forward compatibility
def find_top_trace(xs) -> Trace:
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=None, key=attrgetter('level'))
dynamic = thread_local_state.trace_state.trace_stack.dynamic
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
else top_main)
return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore
# -------------------- abstract values --------------------
@ -844,20 +845,24 @@ pytype_aval_mappings[Unit] = lambda _: abstract_unit
class ConcretizationTypeError(TypeError): pass
def raise_concretization_error(val, context=""):
msg = (f"Abstract tracer value encountered where concrete value is expected ({context}).\n"
"Use transformation parameters such as `static_argnums` for `jit` "
"to avoid tracing input values.\n"
"See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.\n"
f"Encountered value: {val}")
def raise_concretization_error(val: Tracer, context=""):
msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
+ context + "\n\n"
+ val._origin_msg() + "\n\n"
+ "You can use transformation parameters such as `static_argnums` for "
"`jit` to avoid tracing particular arguments of transformed functions.\n\n"
"See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
f"Encountered tracer value: {val}")
raise ConcretizationTypeError(msg)
def concretization_function_error(fun, context=""):
def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
fname_context = f"in `{fname}`"
if context:
fname_context += f" {context}"
fname_context = f"The problem arose with the `{fname}` function. "
if suggest_astype:
fname_context += ("If trying to convert the data type of a value, "
f"try using `x.astype({fun.__name__})` "
f"or `jnp.array(x, {fun.__name__})` instead.")
def error(self, arg):
raise_concretization_error(arg, fname_context)
return error
@ -899,12 +904,9 @@ class UnshapedArray(AbstractValue):
", weak_type=True" if self.weak_type else "")
_bool = _nonzero = concretization_function_error(bool)
_float = concretization_function_error(
float, "Try using `x.astype(float)` instead.")
_int = concretization_function_error(
int, "Try using `x.astype(int)` instead.")
_complex = concretization_function_error(
complex, "Try using `x.astype(complex)` instead.")
_float = concretization_function_error(float, True)
_int = concretization_function_error(int, True)
_complex = concretization_function_error(complex, True)
_hex = concretization_function_error(hex)
_oct = concretization_function_error(oct)
@ -1036,9 +1038,12 @@ class ConcreteArray(ShapedArray):
return ConcreteArray(self.val) if self.weak_type else self
_bool = _nonzero = partialmethod(_forward_to_value, bool)
_int = partialmethod(_forward_to_value, int)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
_int = partialmethod(_forward_to_value, int)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)
class AbstractToken(AbstractValue):
@ -1123,20 +1128,16 @@ def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
yield outs, tuple(todo) # Ensure the aux output is immutable
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
fun: lu.WrappedFun, *args, **params):
fun, *args, **params):
params_tuple = tuple(params.items())
top_trace = find_top_trace(args)
level = (thread_local_state.trace_state.trace_stack.next_level(True)
if top_trace is None else top_trace.level)
params_tuple = tuple(params.items())
fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple)
if top_trace is None:
with new_sublevel():
outs = primitive.impl(fun, *args, **params)
else:
tracers = map(top_trace.full_raise, args)
fun, env_trace_todo = process_env_traces(
fun, primitive, top_trace and top_trace.level, params_tuple)
tracers = map(top_trace.full_raise, args)
with maybe_new_sublevel(top_trace):
outs = primitive.process(top_trace, fun, tracers, params)
return apply_todos(env_trace_todo(), map(full_lower, outs))
return map(full_lower, apply_todos(env_trace_todo(), outs))
class CallPrimitive(Primitive):
multiple_results = True
@ -1176,10 +1177,65 @@ class MapPrimitive(Primitive):
def post_process(self, trace, out_tracers, params):
return trace.post_process_map(self, out_tracers, params)
# This is a no-op with omnistaging disabled
@contextmanager
def extend_axis_env(axis_name, size: int, tag: Any):
yield
frame = AxisEnvFrame(axis_name, size, tag)
thread_local_state.trace_state.axis_env.append(frame)
try:
yield
finally:
thread_local_state.trace_state.axis_env.pop()
def axis_frame(axis_name):
frames = thread_local_state.trace_state.axis_env
for frame in reversed(frames):
if frame.name == axis_name:
return frame
else:
raise NameError("unbound axis name: {}".format(axis_name))
def axis_index(axis_name):
"""Return the index along the mapped axis ``axis_name``.
Args:
axis_name: hashable Python object used to name the mapped axis.
Returns:
An integer representing the index.
For example, with 8 XLA devices available:
>>> from functools import partial
>>> @partial(jax.pmap, axis_name='i')
... def f(_):
... return lax.axis_index('i')
...
>>> f(np.zeros(4))
ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
>>> f(np.zeros(8))
ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
>>> @partial(jax.pmap, axis_name='i')
... @partial(jax.pmap, axis_name='j')
... def f(_):
... return lax.axis_index('i'), lax.axis_index('j')
...
>>> x, y = f(np.zeros((4, 2)))
>>> print(x)
[[0 0]
[1 1]
[2 2]
[3 3]]
>>> print(y)
[[0 1]
[0 1]
[0 1]
[0 1]]
"""
return axis_index_p.bind(axis_name=axis_name)
axis_index_p = Primitive('axis_index')
axis_index_p.def_abstract_eval(lambda *, axis_name: ShapedArray((), np.int32))
# ------------------- Jaxpr checking -------------------
@ -1368,7 +1424,7 @@ def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint:
pp_rhs = (pp(eqn.primitive.name) >>
pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >>
pp(pp_vars(eqn.invars, print_shapes)))
if len(lhs) <= 6:
if len(lhs) <= 6 or print_shapes:
return pp_lhs >> pp(' ') >> pp_rhs
else:
return pp_lhs + pp_rhs.indent(2)
@ -1428,61 +1484,63 @@ def pp_kv_pairs(kv_pairs):
else:
return pp('')
axis_frame = None
# TODO(mattjj): remove when omnistaging fully lands
@config.register_omnistaging_enabler
@no_type_check
def omnistaging_enabler() -> None:
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
new_main, reset_trace_state, extend_axis_env, axis_frame, \
new_base_main, eval_context, \
TraceStack, TraceState
del initial_style_staging
new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env
class TraceStack:
stack: List[MainTrace]
dynamic: MainTrace
upward: List[MainTrace]
downward: List[MainTrace]
def __init__(self):
eval_trace = MainTrace(0, EvalTrace)
self.stack = [eval_trace]
self.dynamic = eval_trace
self.upward = []
self.downward = []
def next_level(self) -> int:
return len(self.stack)
def next_level(self, bottom: bool) -> int:
if bottom:
return - (len(self.downward) + 1)
else:
return len(self.upward)
def push(self, main_trace: MainTrace) -> None:
self.stack.append(main_trace)
def push(self, main_trace: MainTrace, bottom: bool) -> None:
if bottom:
self.downward.append(main_trace)
else:
self.upward.append(main_trace)
def pop(self) -> None:
self.stack.pop()
def pop(self, bottom: bool) -> None:
if bottom:
self.downward.pop()
else:
self.upward.pop()
def __repr__(self) -> str:
stack_str = map(' {}\n'.format, self.stack[::-1])
return f'Trace stack\n{stack_str}\n{self.dynamic}'
return 'Trace stack\n{} ---\n{}'.format(
map(' {}\n'.format, self.upward[::-1]),
map(' {}\n'.format, self.downward))
def copy(self):
new = self.__new__(TraceStack)
new.stack = self.stack[:]
new.dynamic = self.dynamic
new = TraceStack()
new.upward = self.upward[:]
new.downward = self.downward[:]
return new
class TraceState:
trace_stack: TraceStack
substack: List[Sublevel]
axis_env: List[AxisEnvFrame]
initial_style: bool
def __init__(self) -> None:
self.trace_stack = TraceStack()
self.trace_stack = TraceStack() # type: ignore
self.substack = [Sublevel(0)]
self.axis_env = []
self.initial_style = False
def copy(self):
new = self.__new__(TraceState)
new = TraceState()
new.trace_stack = self.trace_stack.copy()
new.substack = self.substack[:]
new.axis_env = self.axis_env[:]
new.initial_style = self.initial_style
return new
thread_local_state = ThreadLocalState()
@ -1490,54 +1548,23 @@ def omnistaging_enabler() -> None:
def reset_trace_state() -> bool:
"Reset the global trace state and return True if it was already clean."
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
thread_local_state.trace_state.axis_env != [] or
thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or
thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)):
thread_local_state.trace_state.trace_stack.downward or
thread_local_state.trace_state.trace_stack.upward):
thread_local_state.trace_state.__init__() # type: ignore
return False
else:
return True
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
fun, *args, **params):
params_tuple = tuple(params.items())
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces(
fun, primitive, top_trace and top_trace.level, params_tuple)
tracers = map(top_trace.full_raise, args)
with maybe_new_sublevel(top_trace):
outs = primitive.process(top_trace, fun, tracers, params)
return map(full_lower, apply_todos(env_trace_todo(), outs))
def maybe_new_sublevel(trace):
# dynamic traces run the WrappedFun, so we raise the sublevel for them
dynamic = thread_local_state.trace_state.trace_stack.dynamic
return new_sublevel() if trace.main is dynamic else suppress()
def find_top_trace(xs) -> Trace:
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=None, key=attrgetter('level'))
dynamic = thread_local_state.trace_state.trace_stack.dynamic
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
else top_main)
return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore
@contextmanager
def new_main(trace_type: Type[Trace], dynamic: bool = False,
) -> Generator[MainTrace, None, None]:
stack = thread_local_state.trace_state.trace_stack
level = stack.next_level()
def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]:
level = thread_local_state.trace_state.trace_stack.next_level(bottom)
main = MainTrace(level, trace_type)
stack.push(main)
if dynamic:
prev_dynamic, stack.dynamic = stack.dynamic, main
thread_local_state.trace_state.trace_stack.push(main, bottom)
try:
yield main
finally:
thread_local_state.trace_state.trace_stack.pop()
if dynamic:
stack.dynamic = prev_dynamic
thread_local_state.trace_state.trace_stack.pop(bottom)
if check_leaks:
t = ref(main)
@ -1546,47 +1573,55 @@ def omnistaging_enabler() -> None:
print(thread_local_state.trace_state.trace_stack)
raise Exception('Leaked trace {}'.format(t()))
@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
stack = thread_local_state.trace_state.trace_stack
main = MainTrace(0, trace_type)
prev_dynamic, stack.dynamic = stack.dynamic, main
prev_base, stack.stack[0] = stack.stack[0], main
try:
yield main
finally:
stack.dynamic = prev_dynamic
stack.stack[0] = prev_base
def find_top_trace(xs) -> Optional[Trace]:
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'), default=None)
return top_trace and type(top_trace)(top_trace.main, cur_sublevel())
@contextmanager
def eval_context():
with new_base_main(EvalTrace):
yield
yield # dummy implementation for forward compatibility
def bind(self, *args, **params):
def bind(self, *args, **kwargs):
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:
return self.impl(*args, **kwargs)
tracers = map(top_trace.full_raise, args)
out = top_trace.process_primitive(self, tracers, params)
return map(full_lower, out) if self.multiple_results else full_lower(out)
Primitive.bind = bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
if self.multiple_results:
return map(full_lower, out_tracer)
else:
return full_lower(out_tracer)
Primitive.bind = bind # type: ignore
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
fun: lu.WrappedFun, *args, **params):
params_tuple = tuple(params.items())
top_trace = find_top_trace(args)
level = (thread_local_state.trace_state.trace_stack.next_level(True)
if top_trace is None else top_trace.level)
params_tuple = tuple(params.items())
fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple)
if top_trace is None:
with new_sublevel():
outs = primitive.impl(fun, *args, **params)
else:
tracers = map(top_trace.full_raise, args)
outs = primitive.process(top_trace, fun, tracers, params)
return apply_todos(env_trace_todo(), map(full_lower, outs))
@contextmanager
def extend_axis_env(axis_name, size: int, main_trace: Optional[MainTrace]):
frame = AxisEnvFrame(axis_name, size, main_trace)
thread_local_state.trace_state.axis_env.append(frame)
def extend_axis_env(axis_name, size: int, tag: Any):
yield
@contextmanager
def initial_style_staging():
trace_state = thread_local_state.trace_state
prev, trace_state.initial_style = trace_state.initial_style, True
try:
yield
finally:
frame_ = thread_local_state.trace_state.axis_env.pop()
assert frame is frame_ # Only runs if there was no exception
def axis_frame(axis_name):
frames = thread_local_state.trace_state.axis_env
for frame in reversed(frames):
if frame.name == axis_name:
return frame
else:
raise NameError(f"Unbound axis name: {axis_name}.\n"
f"The currently bound axes are: {','.join(f.name for f in frames)}")
trace_state.initial_style = prev

View File

@ -69,11 +69,7 @@ def _memoize(thunk):
return memoized
def _initial_style_jaxpr(fun, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
bottom=True, stage_out=False)
assert not any(isinstance(c, core.Tracer) for c in consts)
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
return typed_jaxpr
@ -276,12 +272,8 @@ class CustomJVPCallPrimitive(core.CallPrimitive):
fun, self, top_trace and top_trace.level, ())
jvp, env_trace_todo2 = core.process_env_traces(
jvp, self, top_trace and top_trace.level, ())
if top_trace is None:
with core.new_sublevel():
outs = self.impl(fun, jvp, *args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
tracers = map(top_trace.full_raise, args) # type: ignore
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if env_trace_todo:
raise core.UnexpectedTracerError
@ -602,13 +594,16 @@ xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
# TODO(mattjj): remove when omnistaging fully lands
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
global _initial_style_jaxpr
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _initial_style_jaxpr, custom_jvp_call
def _initial_style_jaxpr(fun, in_avals):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
bottom=True, stage_out=False) # type: ignore
assert not any(isinstance(c, core.Tracer) for c in consts)
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
return typed_jaxpr
@ -619,10 +614,15 @@ def omnistaging_enabler() -> None:
fun, self, top_trace and top_trace.level, ())
jvp, env_trace_todo2 = core.process_env_traces(
jvp, self, top_trace and top_trace.level, ())
tracers = map(top_trace.full_raise, args) # type: ignore
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
if top_trace is None:
with core.new_sublevel():
outs = self.impl(fun, jvp, *args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if env_trace_todo:
raise core.UnexpectedTracerError
return map(core.full_lower, outs)
CustomJVPCallPrimitive.bind = bind # type: ignore
custom_jvp_call = custom_jvp_call_p.bind

View File

@ -237,10 +237,6 @@ deflinear(lax.reduce_window_sum_p)
deflinear(lax_fft.fft_p)
deflinear(xla.device_put_p)
# TODO(mattjj): remove when omnistaging fully lands
try: deflinear(lax.tie_in_p)
except AttributeError: pass
def _cumulative_jet_rule(primals_in, series_in, *, axis: int,
prefix_scan: Callable):
# Irrespective of backend, we always use the parallel prefix scan
@ -523,8 +519,8 @@ def _gen_reduce_choose_taylor_rule(chooser_fun):
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
return primal_out, series_out
return chooser_taylor_rule
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax.reduce_max_p.bind)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax.reduce_min_p.bind)
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax._reduce_max)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax._reduce_min)
def _abs_taylor_rule(x, series_in, **params):
x, = x
@ -584,3 +580,6 @@ def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr,
del jvp_jaxpr_thunk
return jet(core.jaxpr_as_fun(fun_jaxpr), primals_in, series_in)
jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule
deflinear(lax.tie_in_p)

View File

@ -51,9 +51,9 @@ def closure_convert(fun, in_tree, in_avals):
else:
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
with core.initial_style_staging():
with core.initial_style_staging(): # type: ignore
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore
out_tree = out_tree()
# We only want to closure convert for constants with respect to which we're

View File

@ -504,7 +504,7 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_
else:
primal_jaxpr, tangent_jaxpr, out_unknowns = \
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
trace_type=None)
trace_type=None) # type: ignore
def do_transpose(primals_in, cotangents_in):
# NOTE: This is passing in undefined primals in place of tangent arguments, but it
@ -555,9 +555,7 @@ def jvp_jaxpr(jaxpr, nonzeros, instantiate):
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
avals_out, _ = unzip2(pvals_out)
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
return jaxpr_out, out_nonzeros()
@ -644,7 +642,7 @@ def defvjp_all(prim, custom_vjp):
if config.omnistaging_enabled:
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
else:
with core.initial_style_staging():
with core.initial_style_staging(): # type: ignore
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
instantiate=True)
tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
@ -680,9 +678,8 @@ def defvjp2(prim, *vjps):
defvjp_all(prim, vjpmaker)
# TODO(mattjj): remove when omnistaging fully lands
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global jvp_jaxpr
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
@ -691,6 +688,8 @@ def omnistaging_enabler() -> None:
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
avals_out, _ = unzip2(pvals_out)
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
return jaxpr_out, out_nonzeros()

View File

@ -375,10 +375,8 @@ def batch_jaxpr(jaxpr, size, batched, instantiate):
f, batched_out = batched_traceable(f, size, batched, instantiate)
avals_in = [_promote_aval_rank(size, a) if b else a
for a, b in zip(jaxpr.in_avals, batched)]
in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
avals_out, _ = unzip2(pvals_out)
jaxpr_out = core.TypedJaxpr(jaxpr_out, consts_out, avals_in, avals_out)
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
return jaxpr_out, batched_out()
@lu.transformation_with_aux
@ -427,8 +425,8 @@ def _merge_bdims(x, y):
return x # arbitrary
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global batch_jaxpr
def batch_jaxpr(jaxpr, size, batched, instantiate):
@ -436,8 +434,10 @@ def omnistaging_enabler() -> None:
f, batched_out = batched_traceable(f, size, batched, instantiate)
avals_in = [_promote_aval_rank(size, a) if b else a
for a, b in zip(jaxpr.in_avals, batched)]
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
avals_out, _ = unzip2(pvals_out)
jaxpr_out = core.TypedJaxpr(jaxpr_out, consts_out, avals_in, avals_out)
return jaxpr_out, batched_out()

View File

@ -217,9 +217,9 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
else:
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( # type: ignore
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
instantiate=True, stage_out=False)
instantiate=True, stage_out=False) # type: ignore
assert not ivjp_jaxpr.constvars # That might happen some time, but don't bother until then
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
ivjp_jaxpr = core.TypedJaxpr(ivjp_jaxpr, [], in_avals, out_avals)
@ -234,8 +234,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
ivjp_jaxpr, unknowns, instantiate=False) # type:ignore
else:
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(
ivjp_jaxpr, unknowns, instantiate=False, trace_type=None)
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
ivjp_jaxpr, unknowns, instantiate=False, trace_type=None) # type: ignore
unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
# Make sure we're able to compute all cotangents. We don't really care if we
# can reconstruct primals or not, although failure to do so might result in

View File

@ -17,7 +17,7 @@ from collections import namedtuple
import contextlib
import functools
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
List, Union, cast, Type, Set)
List, Union, cast, Type, no_type_check)
from weakref import ref
import numpy as np
@ -165,8 +165,8 @@ class JaxprTrace(Trace):
# We use process_call to handle both call and map primitives.
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
if not config.omnistaging_enabled:
if (self.main.trace_type is StagingJaxprTrace
and primitive in staged_out_calls):
if (self.main.trace_type is StagingJaxprTrace # type: ignore
and primitive in staged_out_calls): # type: ignore
tracers = map(self.instantiate_const_abstracted, tracers)
if primitive in call_partial_eval_rules:
@ -284,21 +284,6 @@ class JaxprTrace(Trace):
env_tracers = map(self.full_raise, env)
return jaxpr, out_pvs, consts, env_tracers
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# See comment at top of `JaxprTrace`. This method should be reachable
# only when we stage out, and in that case we drop the custom differentiation
# rules, because we do not need them.
assert self.main.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
# See comment in the above process_custom_jvp_call method.
assert self.main.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)
class StagingJaxprTrace(JaxprTrace): pass
@lu.transformation_with_aux
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
@ -320,7 +305,7 @@ def abstract_eval_fun(fun, *avals, **params):
else:
pvals_in = [PartialVal.unknown(a) for a in avals]
_, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
instantiate=True, stage_out=True)
instantiate=True, stage_out=True) # type: ignore
avals_out, _ = unzip2(pvals_out)
for aval_out in avals_out:
assert isinstance(aval_out, AbstractValue) # instantiate=True
@ -370,10 +355,8 @@ class JaxprTracer(Tracer):
# TODO(necula): this should return a TypedJaxpr
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
stage_out=False, bottom=False,
trace_type: Optional[Type[Trace]] = None) \
-> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
"""Traces a function into a Jaxpr, given PartialVals for inputs.
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
@ -416,8 +399,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
consts = [3, 6] # values for `ka` and `kb` constvars
"""
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
with core.new_main(trace_type, bottom=bottom) as main:
with core.new_main(JaxprTrace) as main:
fun = trace_to_subjaxpr(fun, main, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
@ -574,7 +556,6 @@ def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, Abst
def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
instantiate: Union[bool, Sequence[bool]],
trace_type: Optional[Type[core.Trace]]
) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]:
"""Specializes a Jaxpr given an indication of which inputs are known.
@ -616,19 +597,16 @@ def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
cell = []
def fun(*vals):
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
trace_type=trace_type)
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
return out_consts_2 + consts_2
# The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
# the avals, and it will never actually reach any primitives, because the `fun` above will
# execute the jaxpr with the right avals (it reconstructs `pvals` inside).
pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
for aval, uk in zip(jaxpr.in_avals, unknowns)]
jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
# For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
# known inputs.
in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
(out_pvs_2, jaxpr_2, num_res), = cell
assert len(jaxpr_2.constvars) == num_res
@ -647,11 +625,10 @@ def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
# out_avals_1 and in_avals_2 need the residuals added
out_pvs, _ = unzip2(out_pvals)
res_avals = out_pvs[len(jaxpr.out_avals):]
res_avals = out_avals[len(jaxpr.out_avals):]
assert len(res_avals) == num_res
out_avals_1 = out_avals_1 + res_avals
in_avals_2 = in_avals_2 + res_avals
out_avals_1 = [*out_avals_1, *res_avals]
in_avals_2 = [*in_avals_2, *res_avals]
typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1)
typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2)
@ -685,7 +662,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))
else:
with core.initial_style_staging():
with core.initial_style_staging(): # type: ignore
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))
@ -710,7 +687,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
typed_jaxpr, in_unknowns, instantiate=False) # type: ignore
else:
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type) # type: ignore
out_knowns = [not b for b in out_unknowns]
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
@ -839,26 +816,19 @@ class DynamicJaxprTracer(core.Tracer):
def _contents(self):
return ()
def __bool__(self):
self._concretization_error('__bool__')
def _concretization_error(self, name):
msgs = self._progenitor_messages()
msg = (f"Abstract tracer value passed to {name} for which a concrete value "
"is required.\n"
f"While tracing the function {self._trace.main.source_info}, "
"this tracer originated from using JAX operations on these lines:"
"\n\n" + "\n\n".join(msgs) + "\n\n"
"See the above traceback for where this tracer was encountered.")
raise core.ConcretizationTypeError(msg)
def _progenitor_messages(self):
def _origin_msg(self):
progenitor_eqns = self._trace.frame.find_progenitors(self)
# TODO mention which jit this tracer belongs to
msgs = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns]
return msgs
if msgs:
origin = (f"While tracing the function {self._trace.main.source_info}, "
"this value became a tracer due to JAX operations on these lines:"
"\n\n" + "\n\n".join(msgs))
else:
origin = ("The error occured while tracing the function "
f"{self._trace.main.source_info}.")
return origin
class JaxprStackFrame:
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
@ -883,7 +853,10 @@ class JaxprStackFrame:
return jaxpr, out_avals, constvals
def find_progenitors(self, tracer):
active_vars = {self.tracer_to_var[id(tracer)]}
var = self.tracer_to_var.get(id(tracer))
if not var:
return []
active_vars = {var}
for eqn in self.eqns[::-1]:
produced = set(eqn.outvars) & active_vars
if produced:
@ -1078,19 +1051,60 @@ def fun_sourceinfo(fun):
return "<unknown>"
# TODO(mattjj): remove when omnistaging fully lands
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
global trace_to_jaxpr, partial_eval_jaxpr
del JaxprTrace.process_custom_jvp_call
del JaxprTrace.process_custom_vjp_call
@config.register_omnistaging_disabler
@no_type_check
def omnistaging_disabler() -> None:
global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
with core.new_main(JaxprTrace) as main:
instantiate: Union[bool, Sequence[bool]] = False,
stage_out=False, bottom=False,
trace_type: Optional[Type[Trace]] = None,
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
"""Traces a function into a Jaxpr, given PartialVals for inputs.
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
computation that depends on unknown inputs. The `out_pvals` are the PartialVal
for the outputs. The intermediate values that depend only on known inputs and
are needed to compute the output of `jaxpr` are in `consts` and are passed in
as the constvars of the `jaxpr`. The handling of the known outputs depends on
`instantiate`.
For example, given `fun` defined as follows::
def fun(ki, ui): # ki will be a known input in this example
ka = ki + 2
kb = ka + 3
return (kb, ui + ka)
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
computation that depends on unknown inputs is `ui + ka` and will be the only
computation in the body of the `jaxpr`. This computation depends on the known
intermediate value `ka`, which will be computed statically. Currently, such
constants are either embedded in the Jaxpr if they are scalars, or passed as a
constvar to `jaxpr`, and then the value of the actual constant will be in
`consts`:
When `instantiate=False` we get::
jaxpr =
{ lambda ka ; ki ui.
let c = add ui ka
in (*, c) } # known outputs are `*`
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
consts = [3] # the constant for `ka`
When `instantiate=True` we get::
jaxpr =
{ lambda ka kb ; ki ui.
let c = add ui ka
in (kb, c) } # known outputs are explicit
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
consts = [3, 6] # values for `ka` and `kb` constvars
"""
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
with core.new_main(trace_type, bottom=bottom) as main:
fun = trace_to_subjaxpr(fun, main, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
@ -1100,6 +1114,7 @@ def omnistaging_enabler() -> None:
def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
instantiate: Union[bool, Sequence[bool]],
trace_type: Optional[Type[core.Trace]]
) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]:
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
@ -1107,15 +1122,18 @@ def omnistaging_enabler() -> None:
def fun(*vals):
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
trace_type=trace_type)
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
return out_consts_2 + consts_2
# For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
# known inputs.
in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
# The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
# the avals, and it will never actually reach any primitives, because the `fun` above will
# execute the jaxpr with the right avals (it reconstructs `pvals` inside).
pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
for aval, uk in zip(jaxpr.in_avals, unknowns)]
jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
(out_pvs_2, jaxpr_2, num_res), = cell
assert len(jaxpr_2.constvars) == num_res
@ -1134,13 +1152,32 @@ def omnistaging_enabler() -> None:
in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
# out_avals_1 and in_avals_2 need the residuals added
res_avals = out_avals[len(jaxpr.out_avals):]
out_pvs, _ = unzip2(out_pvals)
res_avals = out_pvs[len(jaxpr.out_avals):]
assert len(res_avals) == num_res
out_avals_1 = [*out_avals_1, *res_avals]
in_avals_2 = [*in_avals_2, *res_avals]
out_avals_1 = out_avals_1 + res_avals
in_avals_2 = in_avals_2 + res_avals
typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1)
typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2)
return typed_jaxpr_1, typed_jaxpr_2, uk_out
staged_out_calls: Set[core.Primitive] = set()
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# See comment at top of `JaxprTrace`. This method should be reachable
# only when we stage out, and in that case we drop the custom differentiation
# rules, because we do not need them.
if not config.omnistaging_enabled:
assert self.main.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)
JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
# See comment in the above process_custom_jvp_call method.
if not config.omnistaging_enabled:
assert self.main.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)
JaxprTrace.process_custom_vjp_call = process_custom_vjp_call
staged_out_calls = set()
class StagingJaxprTrace(JaxprTrace): pass

View File

@ -34,7 +34,7 @@ import itertools as it
import operator as op
import threading
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
Type, Union)
Type, Union, no_type_check)
from absl import logging
import numpy as np
@ -330,91 +330,6 @@ pxla_result_handlers[ShapedArray] = array_result_handler
pxla_result_handlers[ConcreteArray] = array_result_handler
### applying parallel primitives in op-by-op Python dispatch
# There are at least two cases where we might want to evaluate a parallel
# primitive dispatched from Python, rather than being staged out:
# 1. axis_size = psum(1, 'axis_name'),
# 2. to enable an implicit outermost pmap-like context for multi-host
# multi-controller SPMD programs.
# In each case, we can't rely on any data dependence on a pmap trace; instead we
# need some dynamic context, basically modeling the axis name environment stack.
# To handle the former case, we don't need to communicate at all; we instead
# have a table of parallel_pure_rules. To handle the latter case, we'll have a
# globally-scoped root environment frame and compile and execute a single-op
# XLA collective.
class DynamicAxisEnvFrame(object):
__slots__ = ["name", "pmap_trace", "hard_size"]
def __init__(self, name, pmap_trace, hard_size):
self.name = name
self.pmap_trace = pmap_trace
self.hard_size = hard_size
class DynamicAxisEnv(list):
def __contains__(self, axis_name):
return axis_name in (frame.name for frame in self)
def __getitem__(self, axis_name):
if axis_name not in self:
raise NameError("unbound axis name: {}".format(axis_name))
for frame in reversed(self):
if frame.name == axis_name:
return frame
else:
assert False
@property
def sizes(self):
return tuple(frame.hard_size for frame in self)
@property
def nreps(self):
return prod(frame.hard_size for frame in self)
class _ThreadLocalState(threading.local):
def __init__(self):
self.dynamic_axis_env = DynamicAxisEnv()
_thread_local_state = _ThreadLocalState()
@contextmanager
def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
dynamic_axis_env = _thread_local_state.dynamic_axis_env
dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size))
try:
yield
finally:
dynamic_axis_env.pop()
def unmapped_device_count(backend=None):
dynamic_axis_env = _thread_local_state.dynamic_axis_env
mapped = prod(frame.hard_size for frame in dynamic_axis_env)
unmapped, ragged = divmod(xb.device_count(backend), mapped)
assert not ragged and unmapped > 0
return unmapped
def apply_parallel_primitive(prim, *args, **params):
# This is the op-by-op version of applying a collective primitive, like a psum
# that doesn't have a data dependence on the argument of a pmap function. In
# particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
# look up information in the dynamic axis env.
dynamic_axis_env = _thread_local_state.dynamic_axis_env
axis_name = params.pop('axis_name')
axis_index_groups = params.pop('axis_index_groups')
if axis_index_groups is not None:
shape = (len(axis_index_groups[0]),)
else:
logical_size = lambda frame: frame.hard_size
if isinstance(axis_name, (list, tuple)):
shape = tuple(logical_size(dynamic_axis_env[name]) for name in axis_name)
else:
shape = (logical_size(dynamic_axis_env[axis_name]),)
return parallel_pure_rules[prim](*args, shape=shape, **params)
parallel_pure_rules: Dict[core.Primitive, Callable] = {}
### lazy device-memory persistence and result handling
class ShardedDeviceArray(xla.DeviceArray):
@ -636,7 +551,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
# We add a dummy first invar, to carry the trace details to `dynamic_fun`
pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True)
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
jaxpr.invars = jaxpr.invars[1:] # ignore dummy
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
@ -795,7 +710,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
handle_outs = avals_to_results_handler( # type: ignore
axis_size, num_local_replicas, num_partitions, out_parts, out_avals)
else:
handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas,
handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas, # type: ignore
num_partitions, out_parts,
out_pvals, compiled.local_devices(),
backend)
@ -884,17 +799,20 @@ def get_num_partitions(*partitions):
class ResultToPopulate: pass
result_to_populate = ResultToPopulate()
def _pvals_to_results_handler(
size, nrep, npart,
out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
out_pvals, devices, backend):
nouts = len(out_pvals)
def avals_to_results_handler(size, nrep, npart, out_parts, out_avals):
nouts = len(out_avals)
if out_parts is None:
out_parts = (None,) * len(out_pvals)
handlers = [
_pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend)
for pval, parts in safe_zip(out_pvals, out_parts) # type: ignore
]
out_parts = (None,) * len(out_avals)
# TODO(mattjj,skyewm): can probably clean up this logic
out_specs = [_pmap_sharding_spec(nrep, size, npart, parts, aval, True)
if aval is not core.abstract_unit else None
for parts, aval in zip(out_parts, out_avals)]
out_indices = [spec_to_indices(core.unmapped_aval(size, aval).shape, spec)
if aval is not core.abstract_unit else None
for aval, spec in zip(out_avals, out_specs)] # pytype: disable=attribute-error
handlers = [aval_to_result_handler(spec, idcs, core.unmapped_aval(size, aval))
for spec, idcs, aval in zip(out_specs, out_indices, out_avals)]
def handler(out_bufs):
assert nrep * npart == len(out_bufs)
@ -903,7 +821,7 @@ def _pvals_to_results_handler(
for i, buf in enumerate(tuple_buf):
buffers[i][r] = buf
assert not any(buf is result_to_populate for bufs in buffers
for buf in bufs)
for buf in bufs)
return [h(bufs) for h, bufs in zip(handlers, buffers)]
return handler
@ -947,37 +865,6 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
device_buffers = [xla.device_put(val, d) for d in devices]
return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)
# Build the result handler for one pmap output described by a partial value
# `pval`, which unpacks into `(pv, const)`.
#   * `pv is None`: the returned handler ignores its argument and returns
#     `bcast_const` (`core.unit` as-is, otherwise the result of `replicate`).
#   * otherwise: returns `aval_to_result_handler(...)` for `pv`; the sharding
#     spec and indices are None when `pv` is `core.abstract_unit`.
# NOTE(review): leading indentation is not visible in this view, so whether the
# `isinstance(const, ShardedDeviceArray)` check is nested inside
# `if nrep is None:` cannot be confirmed here -- check against the real file.
def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend):
if devices:
# Every device passed in must report the host id given by xb.host_id(backend).
assert all(d.host_id == xb.host_id(backend) for d in devices)
pv, const = pval
if pv is None:
if nrep is None:
nrep = axis_size
# If 'const' is a ShardedDeviceArray, it must have come from a pmap nested
# inside the one we're currently evaluating, and we should replicate
# 'const' across the total number of devices needed. We don't necessarily
# know the nested pmap's axis_size (e.g. the jaxpr for
# pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the
# axis size of the output 'const'.
# TODO: we might be doing unnecessary device transfers in the inner pmap.
if isinstance(const, ShardedDeviceArray):
nrep *= len(const)
bcast_const = (core.unit if const is core.unit
else replicate(const, axis_size, nrep, devices, backend)) # type: ignore
# The handler discards the output buffers and always returns the constant.
return lambda _: bcast_const # type: ignore
else:
if pv is not core.abstract_unit:
# The unsharded result gains a leading dimension of length axis_size.
unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype)
sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv,
True)
indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
else:
sharding_spec = indices = None
unsharded_aval = pv
return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped):
"""Sharding spec for arguments or results of a pmap.
Args:
@ -1014,7 +901,6 @@ def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped):
replication_factors=[(replication_factor * axis_size, 0)] +
shard_spec.replication_factors)
def partitioned_sharding_spec(num_partitions: int,
partitions: Optional[Sequence[int]], aval):
if partitions is None:
@ -1043,7 +929,6 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):
# Create the 'xla_pmap' map primitive; `xla_pmap` is shorthand for its bind.
xla_pmap_p = core.MapPrimitive('xla_pmap')
xla_pmap = xla_pmap_p.bind
# Register `xla_pmap_impl` as the primitive's impl rule.
xla_pmap_p.def_impl(xla_pmap_impl)
# Add the primitive to `pe.staged_out_calls`.
pe.staged_out_calls.add(xla_pmap_p)
# Set param update handlers to update `donated_invars` just like xla_call_p
pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
@ -1279,41 +1164,35 @@ soft_pmap_p.def_impl(soft_pmap_impl)
soft_pmap_rules: Dict[core.Primitive, Callable] = {}
def deleted_with_omnistaging(*a, **k):
  """Placeholder that fails whenever it is called.

  Accepts and ignores any positional and keyword arguments.

  Raises:
    AssertionError: always, with the message "Should be deleted".
  """
  # Raise explicitly instead of `assert False`: assert statements are removed
  # under `python -O`, which would turn a call into a silent `return None`.
  raise AssertionError("Should be deleted")
@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
yield
with core.extend_axis_env(*args, **kwargs):
yield
@config.register_omnistaging_enabler
def omnistaging_enable() -> None:
@config.register_omnistaging_disabler
@no_type_check
def omnistaging_disabler() -> None:
global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
_thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
apply_parallel_primitive, parallel_pure_rules, \
_pvals_to_results_handler, _pval_to_result_handler, replicate, \
avals_to_results_handler, maybe_extend_axis_env
del DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
_thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
_pvals_to_results_handler, _pval_to_result_handler, replicate
avals_to_results_handler, axis_index, maybe_extend_axis_env
apply_parallel_primitive = deleted_with_omnistaging
parallel_pure_rules.clear()
@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
yield
def avals_to_results_handler(size, nrep, npart, out_parts, out_avals):
nouts = len(out_avals)
def _pvals_to_results_handler(
size, nrep, npart,
out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
out_pvals, devices, backend):
nouts = len(out_pvals)
if out_parts is None:
out_parts = (None,) * len(out_avals)
# TODO(mattjj,skyewm): can probably clean up this logic
out_specs = [_pmap_sharding_spec(nrep, size, npart, parts, aval, True)
if aval is not core.abstract_unit else None
for parts, aval in zip(out_parts, out_avals)]
out_indices = [spec_to_indices(core.unmapped_aval(size, aval).shape, spec)
if aval is not core.abstract_unit else None
for aval, spec in zip(out_avals, out_specs)] # pytype: disable=attribute-error
handlers = [aval_to_result_handler(spec, idcs, core.unmapped_aval(size, aval))
for spec, idcs, aval in zip(out_specs, out_indices, out_avals)]
out_parts = (None,) * len(out_pvals)
handlers = [
_pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend)
for pval, parts in safe_zip(out_pvals, out_parts) # type: ignore
]
def handler(out_bufs):
assert nrep * npart == len(out_bufs)
@ -1326,7 +1205,105 @@ def omnistaging_enable() -> None:
return [h(bufs) for h, bufs in zip(handlers, buffers)]
return handler
# Build the result handler for one pmap output described by a partial value
# `pval`, which unpacks into `(pv, const)`.
#   * `pv is None`: the returned handler ignores its argument and returns
#     `bcast_const` (`core.unit` as-is, otherwise the result of `replicate`).
#   * otherwise: returns `aval_to_result_handler(...)` for `pv`; the sharding
#     spec and indices are None when `pv` is `core.abstract_unit`.
# NOTE(review): leading indentation is not visible in this view, so whether the
# `isinstance(const, ShardedDeviceArray)` check is nested inside
# `if nrep is None:` cannot be confirmed here -- check against the real file.
def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend):
if devices:
# Every device passed in must report the host id given by xb.host_id(backend).
assert all(d.host_id == xb.host_id(backend) for d in devices)
pv, const = pval
if pv is None:
if nrep is None:
nrep = axis_size
# If 'const' is a ShardedDeviceArray, it must have come from a pmap nested
# inside the one we're currently evaluating, and we should replicate
# 'const' across the total number of devices needed. We don't necessarily
# know the nested pmap's axis_size (e.g. the jaxpr for
# pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the
# axis size of the output 'const'.
# TODO: we might be doing unnecessary device transfers in the inner pmap.
if isinstance(const, ShardedDeviceArray):
nrep *= len(const)
bcast_const = (core.unit if const is core.unit
else replicate(const, axis_size, nrep, devices, backend)) # type: ignore
# The handler discards the output buffers and always returns the constant.
return lambda _: bcast_const # type: ignore
else:
if pv is not core.abstract_unit:
# The unsharded result gains a leading dimension of length axis_size.
unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype)
sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv,
True)
indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
else:
sharding_spec = indices = None
unsharded_aval = pv
return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
with core.extend_axis_env(*args, **kwargs):
def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
  """Temporarily push an axis frame onto this thread's dynamic axis env.

  Generator: appends ``DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size)``
  to ``_thread_local_state.dynamic_axis_env``, yields once, then pops the last
  frame again.

  NOTE(review): presumably meant to be wrapped by ``@contextmanager`` --
  confirm against the real file.

  Args:
    axis_name: name stored on the new frame.
    pmap_trace: stored on the new frame as given.
    hard_size: stored on the new frame as given.
  """
  dynamic_axis_env = _thread_local_state.dynamic_axis_env
  dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size))
  try:
    yield
  finally:
    # Pop even when the body raised, so no stale frame is left behind.
    dynamic_axis_env.pop()
def unmapped_device_count(backend=None):
  """Return the device count divided by the product of active axis sizes.

  Args:
    backend: forwarded unchanged to ``xb.device_count``.

  Returns:
    ``xb.device_count(backend) // mapped``, where ``mapped`` is the product of
    ``hard_size`` over every frame in the thread-local dynamic axis env.
    The division is asserted to be exact and the quotient positive.
  """
  dynamic_axis_env = _thread_local_state.dynamic_axis_env
  mapped = prod(frame.hard_size for frame in dynamic_axis_env)
  unmapped, ragged = divmod(xb.device_count(backend), mapped)
  # Internal invariant: the mapped axes must tile the devices evenly.
  assert not ragged and unmapped > 0
  return unmapped
def apply_parallel_primitive(prim, *args, **params):
  """Apply a collective primitive op-by-op via its registered pure rule.

  Args:
    prim: key into ``parallel_pure_rules``.
    *args: operands forwarded to the pure rule.
    **params: must contain ``axis_name`` and ``axis_index_groups``; both are
      removed before the remaining params are forwarded.

  Returns:
    Whatever ``parallel_pure_rules[prim]`` returns when called with ``args``,
    the computed ``shape`` and the remaining ``params``.
  """
  # This is the op-by-op version of applying a collective primitive, like a psum
  # that doesn't have a data dependence on the argument of a pmap function. In
  # particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
  # look up information in the dynamic axis env.
  dynamic_axis_env = _thread_local_state.dynamic_axis_env
  axis_name = params.pop('axis_name')
  axis_index_groups = params.pop('axis_index_groups')
  if axis_index_groups is not None:
    # With explicit groups the shape is the length of the first group.
    shape = (len(axis_index_groups[0]),)
  elif isinstance(axis_name, (list, tuple)):
    # Several axis names: one entry per named axis, in the order given.
    shape = tuple(dynamic_axis_env[name].hard_size for name in axis_name)
  else:
    shape = (dynamic_axis_env[axis_name].hard_size,)
  return parallel_pure_rules[prim](*args, shape=shape, **params)
pe.staged_out_calls.add(xla_pmap_p) # type: ignore
parallel_pure_rules = {} # type: ignore
class DynamicAxisEnvFrame(object):
  """One entry of the dynamic axis environment.

  A plain record of its three constructor arguments. ``__slots__`` restricts
  instances to exactly these attributes (no per-instance ``__dict__``).
  """
  __slots__ = ["name", "pmap_trace", "hard_size"]

  def __init__(self, name, pmap_trace, hard_size):
    self.name = name
    self.pmap_trace = pmap_trace
    self.hard_size = hard_size
class DynamicAxisEnv(list):
  """List of axis frames that is queried by axis name.

  Frames are kept in insertion order (this is a ``list``), but ``in`` and
  ``[]`` are overridden to match on each frame's ``name`` attribute instead of
  on list elements and integer positions.
  """

  def __contains__(self, axis_name):
    """Return True if some frame in the env is named ``axis_name``."""
    return axis_name in (frame.name for frame in self)

  def __getitem__(self, axis_name):
    """Return the most recently appended frame named ``axis_name``.

    Raises:
      NameError: if no frame has that name.
    """
    # A single backwards scan finds the latest match and detects a miss; a
    # separate membership pre-check and an unreachable fallthrough are not
    # needed.
    for frame in reversed(self):
      if frame.name == axis_name:
        return frame
    raise NameError("unbound axis name: {}".format(axis_name))

  @property
  def sizes(self):
    """Tuple of every frame's ``hard_size``, in insertion order."""
    return tuple(frame.hard_size for frame in self)

  @property
  def nreps(self):
    """Product of every frame's ``hard_size``."""
    return prod(frame.hard_size for frame in self)
class _ThreadLocalState(threading.local):
  """Thread-local container holding a ``DynamicAxisEnv``."""

  def __init__(self):
    # For a threading.local subclass, __init__ runs again in each thread that
    # uses the instance, so every thread starts with its own empty env.
    self.dynamic_axis_env = DynamicAxisEnv()
_thread_local_state = _ThreadLocalState()

View File

@ -41,10 +41,11 @@ def _map(f, *xs):
class ResultToPopulate: pass
result_to_populate = ResultToPopulate()
def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
nouts = len(out_pvals)
handlers = [_pval_to_result_handler(npart, parts, out_pval)
for parts, out_pval in safe_zip(partitions, out_pvals)] # type: ignore
def _avals_to_results_handler(nrep, npart, partitions, out_avals):
nouts = len(out_avals)
handlers = [_aval_to_result_handler(npart, parts, out_aval)
for parts, out_aval in safe_zip(partitions, out_avals)]
def handler(out_bufs):
assert nrep * npart == len(out_bufs)
@ -53,23 +54,18 @@ def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
for i, buf in enumerate(tuple_buf):
buffers[i][r] = buf
assert not any(buf is result_to_populate for bufs in buffers
for buf in bufs)
for buf in bufs)
return [h(bufs) for h, bufs in zip(handlers, buffers)]
return handler
def _pval_to_result_handler(npart, parts, pval):
pv, const = pval
if pv is None:
raise NotImplementedError # TODO(skye): handle constant outputs
def _aval_to_result_handler(npart, parts, aval):
if aval is not core.abstract_unit:
spec = pxla.partitioned_sharding_spec(npart, parts, aval)
indices = pxla.spec_to_indices(aval.shape, spec)
else:
if pv is not core.abstract_unit:
spec = pxla.partitioned_sharding_spec(npart, parts, pv)
indices = pxla.spec_to_indices(pv.shape, spec)
else:
spec = indices = None
return pxla.aval_to_result_handler(spec, indices, pv)
spec = indices = None
return pxla.aval_to_result_handler(spec, indices, aval)
@lu.cache
@ -84,7 +80,8 @@ def _sharded_callable(
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
else:
in_pvals = [pe.PartialVal.unknown(aval) for aval in abstract_args]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=False, bottom=True)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore
instantiate=False, bottom=True) # type: ignore
# TODO(skye): add tests for equationless jaxpr cases
if not jaxpr.eqns and all(outvar.aval is core.abstract_unit
@ -138,7 +135,7 @@ def _sharded_callable(
handle_outs = _avals_to_results_handler(nrep, num_partitions, out_parts, # type: ignore
out_avals)
else:
handle_outs = _pvals_to_results_handler(nrep, num_partitions, out_parts,
handle_outs = _pvals_to_results_handler(nrep, num_partitions, out_parts, # type: ignore
out_pvals)
return partial(_execute_spatially_partitioned, compiled, handle_args,
handle_outs)
@ -354,16 +351,14 @@ def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
return sharding_constraint_p.bind(x, partitions=partitions)
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
global _avals_to_results_handler, _aval_to_result_handler, \
_pvals_to_results_handler, _pval_to_result_handler
del _pvals_to_results_handler, _pval_to_result_handler
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _pvals_to_results_handler, _pval_to_result_handler
def _avals_to_results_handler(nrep, npart, partitions, out_avals):
nouts = len(out_avals)
handlers = [_aval_to_result_handler(npart, parts, out_aval)
for parts, out_aval in safe_zip(partitions, out_avals)]
def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
nouts = len(out_pvals)
handlers = [_pval_to_result_handler(npart, parts, out_pval)
for parts, out_pval in safe_zip(partitions, out_pvals)] # type: ignore
def handler(out_bufs):
assert nrep * npart == len(out_bufs)
@ -377,11 +372,14 @@ def omnistaging_enabler() -> None:
return handler
def _aval_to_result_handler(npart, parts, aval):
if aval is not core.abstract_unit:
spec = pxla.partitioned_sharding_spec(npart, parts, aval)
indices = pxla.spec_to_indices(aval.shape, spec)
def _pval_to_result_handler(npart, parts, pval):
pv, const = pval
if pv is None:
raise NotImplementedError # TODO(skye): handle constant outputs
else:
spec = indices = None
return pxla.aval_to_result_handler(spec, indices, aval)
if pv is not core.abstract_unit:
spec = pxla.partitioned_sharding_spec(npart, parts, pv)
indices = pxla.spec_to_indices(pv.shape, spec)
else:
spec = indices = None
return pxla.aval_to_result_handler(spec, indices, pv)

View File

@ -597,8 +597,8 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
else:
pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
jaxpr, pvals, consts = pe.trace_to_jaxpr(
fun, pvals, instantiate=False, stage_out=True, bottom=True)
jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore
fun, pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
jaxpr = apply_outfeed_rewriter(jaxpr)
@ -608,7 +608,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
if config.omnistaging_enabled:
result_handlers = tuple(aval_to_result_handler(device, a) for a in out_avals)
else:
result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals))
result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals)) # type: ignore
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
@ -879,7 +879,7 @@ def lower_fun(fun, multiple_results):
else:
pvals = [pe.PartialVal.unknown(a) for a in avals]
jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
stage_out=True)
stage_out=True) # type: ignore
xla_consts = _xla_consts(c, consts)
outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args)
if multiple_results:
@ -905,7 +905,7 @@ def lower_fun_initial_style(fun):
else:
pvals = [pe.PartialVal.unknown(a) for a in avals]
jaxpr, _, consts = pe.trace_to_jaxpr(
lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True)
lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True) # type: ignore
xla_consts = _xla_consts(c, consts)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack,
*xla_args)
@ -1277,19 +1277,25 @@ def _call_translation_rule(c, axis_env, in_nodes, name_stack,
call_translations[core.call_p] = _call_translation_rule
# TODO(mattjj): remove when omnistaging fully lands
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
  """Translation rule that derives an axis index from the replica id.

  Emits ``(ReplicaId // div) % mod`` in uint32 and converts it to int32, with
  ``div = axis_env.nreps // prod(axis_env.sizes)`` and
  ``mod = axis_env.sizes[-1]``.

  NOTE(review): ``axis_name`` and ``platform`` are accepted but not read in
  this body.
  """
  div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes),
                                dtype=np.uint32))
  mod = xb.constant(c, np.array(axis_env.sizes[-1], dtype=np.uint32))
  unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
parallel_translations[core.axis_index_p] = _axis_index_translation_rule # type: ignore
def _pval_to_result_handler(device, pval):
  """Return the result handler for one output given as a partial value.

  Args:
    device: falsy to leave a constant output where it is; otherwise the
      constant is passed through ``_device_put_impl(const, device)``.
    pval: a pair ``(pv, const)``.

  Returns:
    If ``pv`` is None, a function that ignores its argument and returns the
    constant; otherwise ``aval_to_result_handler(device, pv)``.
  """
  pv, const = pval
  if pv is not None:
    return aval_to_result_handler(device, pv)
  # Known output: place the constant once, now, and hand it back on every call.
  const = _device_put_impl(const, device) if device else const
  return lambda _: const
pe.staged_out_calls.add(xla_call_p)
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _pval_to_result_handler
del _pval_to_result_handler
def _pval_to_result_handler(device, pval):
  """Return the result handler for one output given as a partial value.

  Args:
    device: falsy to leave a constant output where it is; otherwise the
      constant is passed through ``_device_put_impl(const, device)``.
    pval: a pair ``(pv, const)``.

  Returns:
    If ``pv`` is None, a function that ignores its argument and returns the
    constant; otherwise ``aval_to_result_handler(device, pv)``.
  """
  pv, const = pval
  if pv is not None:
    return aval_to_result_handler(device, pv)
  # Known output: place the constant once, now, and hand it back on every call.
  const = _device_put_impl(const, device) if device else const
  return lambda _: const
pe.staged_out_calls.add(xla_call_p) # type: ignore

View File

@ -22,6 +22,7 @@ import warnings
import numpy as np
import jax
from .. import core
from .. import ad_util
from .. import api
@ -1373,30 +1374,8 @@ def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
return top_k_p.bind(operand, k=k)
def tie_in(x: Array, y: Array) -> Array:
"""Returns the value of ``y`` but with a fake data dependence on ``x``.
When staging to XLA (e.g. running under jit or pmap), values that don't depend
on computation inputs are computed op-by-op, and folded into the XLA
computation as constants.
``tie_in`` provides a way to explicitly stage values into the computation.
When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
is ``y``, but staged to XLA. Downstream use of the result will also be staged
to XLA.
For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
XLA as long as ``x`` is staged to XLA.
"""
if config.omnistaging_enabled:
return y
else:
return tie_in_p.bind(x, y)
# def tie_in(x: Array, y: Array) -> Array:
# """Deprecated. Ignores ``x`` and returns ``y``."""
# return y
"""Deprecated. Ignores ``x`` and returns ``y``."""
return y
def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array:
"""Returns an array of `shape` filled with `fill_value`.
@ -1502,7 +1481,13 @@ def stop_gradient(x):
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
array(0., dtype=float32)
"""
return tree_map(ad_util.stop_gradient_p.bind, x)
def stop(x):
if (dtypes.issubdtype(_dtype(x), np.floating) or
dtypes.issubdtype(_dtype(x), np.complexfloating)):
return ad_util.stop_gradient_p.bind(x)
else:
return x # only bind primitive on inexact dtypes, to avoid some staging
return tree_map(stop, x)
### convenience wrappers around traceables
@ -5656,30 +5641,6 @@ xla.translations[top_k_p] = partial(standard_translate, 'top_k')
ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
def _tie_in_transpose_rule(t, x, y):
  """Transpose rule for ``tie_in``: zero cotangent for ``x``, ``t`` for ``y``.

  The zero is built from ``x.aval`` when ``x`` is an undefined primal and from
  the value of ``x`` otherwise.
  """
  if ad.is_undefined_primal(x):
    zero_for_x = ad_util.Zero(x.aval)
  else:
    zero_for_x = ad_util.Zero.from_value(x)
  return [zero_for_x, t]
def _tie_in_batch_rule(batched_args, batch_dims):
  """Batching rule for ``tie_in``.

  Re-applies ``tie_in`` to the batched operands and returns the result with
  the batch dimension of the second operand (``y``).
  """
  y = tie_in(*batched_args)
  _, bdim_y = batch_dims
  return y, bdim_y
def _tie_in_impl(x, y):
  """Impl rule for ``tie_in``: check both operands, then return ``y``."""
  core.check_valid_jaxtype(x)
  core.check_valid_jaxtype(y)
  return y
# Create the 'tie_in' primitive and register its rules.
tie_in_p = Primitive('tie_in')
tie_in_p.def_impl(_tie_in_impl)
# Abstract evaluation depends only on the second operand `y`.
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
# The XLA translation returns the operand `y` unchanged.
xla.translations[tie_in_p] = lambda c, x, y: y
ad.deflinear2(tie_in_p, _tie_in_transpose_rule)
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
# The masking rule returns the second value, `vals[1]`.
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
def _stop_gradient_jvp_rule(primals, tangents):
# if we don't call stop_gradient here, we'd only peel off one autodiff tracer
@ -6198,7 +6159,72 @@ def _check_user_dtype_supported(dtype, fun_name=None):
warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
  """Remove the ``tie_in`` primitive and its rule helpers from this module.

  Registered through ``config.register_omnistaging_enabler``. Deletes the
  module-level names ``_tie_in_transpose_rule``, ``_tie_in_batch_rule``,
  ``_tie_in_impl`` and ``tie_in_p``.
  """
  global _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
  del _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
def _canonicalize_axis(axis, num_dims):
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)
if not -num_dims <= axis < num_dims:
raise ValueError(
"axis {} is out of bounds for array of dimension {}".format(
axis, num_dims))
if axis < 0:
axis = axis + num_dims
return axis
tie_in_p = Primitive('tie_in')
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Install the primitive-backed ``tie_in`` and the rules of ``tie_in_p``.

  Registered through ``config.register_omnistaging_disabler``. Rebinds the
  module-level ``tie_in``, tries to mirror it onto ``jax.lax.tie_in``, and
  registers impl, abstract-eval, XLA translation, transpose, batching and
  masking rules for ``tie_in_p``.
  """
  global tie_in

  def tie_in(x: Array, y: Array) -> Array:
    """Returns the value of ``y`` but with a fake data dependence on ``x``.

    When staging to XLA (e.g. running under jit or pmap), values that don't depend
    on computation inputs are computed op-by-op, and folded into the XLA
    computation as constants.

    ``tie_in`` provides a way to explicitly stage values into the computation.
    When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
    is ``y``, but staged to XLA. Downstream use of the result will also be staged
    to XLA.

    For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
    a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
    XLA as long as ``x`` is staged to XLA.
    """
    if config.omnistaging_enabled:
      return y
    else:
      return tie_in_p.bind(x, y)

  # If lax has already been imported, we need to monkey-patch the
  # lax/__init__.py import of tie_in. If not (i.e. if this is running at lax
  # module creation time) then we'll get an import error.
  try:
    jax.lax.tie_in = tie_in
  except AttributeError:
    # Deliberate best-effort: nothing to patch in that case.
    pass

  def _tie_in_transpose_rule(t, x, y):
    # Zero cotangent for `x`; the incoming cotangent `t` goes to `y`.
    if ad.is_undefined_primal(x):
      return [ad_util.Zero(x.aval), t]
    else:
      return [ad_util.Zero.from_value(x), t]

  def _tie_in_batch_rule(batched_args, batch_dims):
    # The result carries the batch dimension of the second operand.
    y = tie_in(*batched_args)
    _, bdim_y = batch_dims
    return y, bdim_y

  def _tie_in_impl(x, y):
    core.check_valid_jaxtype(x)
    core.check_valid_jaxtype(y)
    return y

  # The helpers above are local, so the rules are registered in this scope.
  tie_in_p.def_impl(_tie_in_impl)
  tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
  xla.translations[tie_in_p] = lambda c, x, y: y
  ad.deflinear2(tie_in_p, _tie_in_transpose_rule)
  batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
  masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]

View File

@ -62,22 +62,18 @@ T = TypeVar('T')
@cache()
def _initial_style_untyped_jaxpr(fun: Callable, in_tree, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
with core.initial_style_staging():
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
return jaxpr, out_pvals, consts, out_tree
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
return jaxpr, out_avals, consts, out_tree()
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
jaxpr, out_pvals, consts, out_tree = _initial_style_untyped_jaxpr(
fun, in_tree, in_avals)
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
jaxpr, out_avals, consts, out_tree = \
_initial_style_untyped_jaxpr(fun, in_tree, in_avals)
const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
(), const_avals + in_avals, out_avals)
return typed_jaxpr, consts, out_tree()
return typed_jaxpr, consts, out_tree
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
in_tree, in_avals):
@ -88,37 +84,29 @@ def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
# for each one, it makes another that accepts *all* constants, but only uses
# those that it needs (dropping the rest).
jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([
_initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs])
jaxprs, all_out_avals, all_consts, all_out_trees = unzip4(
_initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs)
newvar = core.gensym(jaxprs, suffix='_')
all_const_avals = tuple(
tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
for consts in all_consts)
unused_const_vars = tuple(
tuple(newvar(aval) for aval in const_avals)
for const_avals in all_const_avals)
all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
for consts in all_consts]
unused_const_vars = [[newvar(aval) for aval in const_avals]
for const_avals in all_const_avals]
def pad_jaxpr_constvars(i, jaxpr):
prefix = util.concatenate(unused_const_vars[:i])
suffix = util.concatenate(unused_const_vars[i+1:])
constvars = prefix + jaxpr.constvars + suffix
constvars = [*prefix, *jaxpr.constvars, *suffix]
return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
const_avals = tuple(util.concatenate(all_const_avals))
def type_and_const_convert_jaxpr(jaxpr, out_pvals):
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
return core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
(), const_avals + in_avals, out_avals)
consts = util.concatenate(all_consts)
const_avals = util.concatenate(all_const_avals)
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
typed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals)
return (tuple(typed_jaxprs),
tuple(util.concatenate(all_consts)),
tuple(out_tree() for out_tree in all_out_trees))
typed_jaxprs = [core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
(), [*const_avals, *in_avals], out_avals)
for jaxpr, out_avals in zip(jaxprs, all_out_avals)]
return typed_jaxprs, consts, all_out_trees
def _abstractify(x):
  """Return ``raise_to_shaped`` applied to the abstract value of ``x``."""
  return raise_to_shaped(core.get_aval(x))
@ -1517,7 +1505,7 @@ def _prune_zeros(ts):
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace:
if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: # type: ignore
params = dict(reverse=reverse, length=length, num_consts=num_consts,
num_carry=num_carry, jaxpr=jaxpr, linear=linear,
unroll=unroll)
@ -2450,26 +2438,29 @@ def associative_scan(fn, elems, reverse=False):
return tree_unflatten(tree, scans)
# TODO(mattjj): remove when omnistaging fully lands
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _initial_style_untyped_jaxpr, _initial_style_jaxpr, \
_initial_style_jaxprs_with_common_consts
@cache()
def _initial_style_untyped_jaxpr(fun: Callable, in_tree, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
return jaxpr, out_avals, consts, out_tree()
with core.initial_style_staging(): # type: ignore
jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore
return jaxpr, out_pvals, consts, out_tree
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
jaxpr, out_avals, consts, out_tree = \
_initial_style_untyped_jaxpr(fun, in_tree, in_avals)
jaxpr, out_pvals, consts, out_tree = _initial_style_untyped_jaxpr(
fun, in_tree, in_avals)
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
(), const_avals + in_avals, out_avals)
return typed_jaxpr, consts, out_tree
return typed_jaxpr, consts, out_tree()
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
in_tree, in_avals):
@ -2480,26 +2471,34 @@ def omnistaging_enabler() -> None:
# for each one, it makes another that accepts *all* constants, but only uses
# those that it needs (dropping the rest).
jaxprs, all_out_avals, all_consts, all_out_trees = unzip4(
_initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs)
jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([
_initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs])
newvar = core.gensym(jaxprs, suffix='_')
all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
for consts in all_consts]
unused_const_vars = [[newvar(aval) for aval in const_avals]
for const_avals in all_const_avals]
all_const_avals = tuple(
tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
for consts in all_consts)
unused_const_vars = tuple(
tuple(newvar(aval) for aval in const_avals)
for const_avals in all_const_avals)
def pad_jaxpr_constvars(i, jaxpr):
prefix = util.concatenate(unused_const_vars[:i])
suffix = util.concatenate(unused_const_vars[i+1:])
constvars = [*prefix, *jaxpr.constvars, *suffix]
constvars = prefix + jaxpr.constvars + suffix
return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
consts = util.concatenate(all_consts)
const_avals = util.concatenate(all_const_avals)
const_avals = tuple(util.concatenate(all_const_avals))
def type_and_const_convert_jaxpr(jaxpr, out_pvals):
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
return core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
(), const_avals + in_avals, out_avals)
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
typed_jaxprs = [core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
(), [*const_avals, *in_avals], out_avals)
for jaxpr, out_avals in zip(jaxprs, all_out_avals)]
return typed_jaxprs, consts, all_out_trees
typed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals)
return (tuple(typed_jaxprs),
tuple(util.concatenate(all_consts)),
tuple(out_tree() for out_tree in all_out_trees))

View File

@ -458,12 +458,10 @@ def _psum_transpose_rule(cts, axis_name, axis_index_groups):
# Create the 'psum' primitive (marked as producing multiple results) and
# register its rules.
psum_p = core.Primitive('psum')
psum_p.multiple_results = True
# Impl rule: pxla.apply_parallel_primitive with this primitive pre-bound.
psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))
# Abstract evaluation maps raise_to_shaped over the operands.
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure rule: each operand multiplied by the product of `shape`.
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.split_axis_rules[psum_p] = partial(_split_axis_comm_assoc, psum_p)
@ -474,6 +472,21 @@ batching.collective_rules[psum_p] = \
lambda v, d: v.sum(d),
lambda v, axis_size: axis_size * v)
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
def psum_bind(*args, axis_name, axis_index_groups):
  """Bind ``psum``, short-circuiting when no operand is a tracer.

  Args:
    *args: operands of the psum.
    axis_name: a single axis name, or a tuple of axis names.
    axis_index_groups: None, or a sequence whose first element's length is
      used as the size.

  Returns:
    If no operand is a ``core.Tracer``: a tuple with each operand multiplied
    by the size. Otherwise the result of ``core.Primitive.bind`` on ``psum_p``.
  """
  if all(not isinstance(x, core.Tracer) for x in args):
    if axis_index_groups is not None:
      size = len(axis_index_groups[0])
    elif type(axis_name) is tuple:
      # Several axes: the size is the product of the individual frame sizes.
      size = prod([core.axis_frame(name).size for name in axis_name])  # type: ignore
    else:
      size = core.axis_frame(axis_name).size  # type: ignore
    return tuple(size * x for x in args)
  return core.Primitive.bind(
      psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
pmax_p = core.Primitive('pmax')
pmax_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
@ -660,20 +673,6 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
def _axis_index_bind(*, axis_name):
  """Bind ``axis_index_p`` by creating a tracer on the named axis's pmap trace.

  Looks up ``axis_name`` in the thread-local dynamic axis environment held by
  ``pxla``, builds a ``pe.JaxprTracer`` with an unknown partial value whose
  abstract value is a scalar ``np.int32`` ``ShapedArray``, and gives it an
  input-less ``axis_index_p`` equation recipe carrying ``axis_name`` as its
  only parameter.

  Returns:
    The newly created tracer.
  """
  axis_env = pxla._thread_local_state.dynamic_axis_env
  pmap_trace = axis_env[axis_name].pmap_trace

  index_aval = ShapedArray((), np.int32)
  index_tracer = pe.JaxprTracer(
      pmap_trace, pe.PartialVal.unknown(index_aval), None)
  # The recipe lists the tracer as its own output, so it can only be attached
  # once the tracer object exists.
  index_tracer.recipe = pe.new_eqn_recipe(
      [], [index_tracer], axis_index_p,
      {"axis_name": axis_name},
      source_info_util.current())
  return index_tracer
def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
assert not vals and not mapped
idx = axis_index(axis_name) # type: ignore
@ -682,46 +681,58 @@ def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule # type: ignore
axis_index_p.def_custom_bind(_axis_index_bind)
axis_index_p.def_abstract_eval(
lambda *args, **params: ShapedArray((), np.int32))
pxla.multi_host_supported_collectives.add(axis_index_p)
# axis_index takes no arguments, so the default bind would have no way to
# dispatch to a data-dependency-based trace such as vmap. Each trace that
# wants to bind an axis name must therefore also implement
# `process_axis_index` and put its main trace on the axis env stack.
def _axis_index_bind(*, axis_name):
  """Bind ``axis_index_p`` through the trace registered on the axis frame.

  Fetches the frame for ``axis_name`` via ``core.axis_frame``. If that frame
  has a ``main_trace``, a trace is instantiated from it at the current
  sublevel and its ``process_axis_index(frame)`` result is returned.
  Otherwise the call falls back to ``core.Primitive.bind``.
  """
  axis_frame = core.axis_frame(axis_name)
  main = axis_frame.main_trace
  if main is None:
    # No trace is attached to this axis frame: use the default bind.
    return core.Primitive.bind(axis_index_p, axis_name=axis_name)
  handler = main.trace_type(main, core.cur_sublevel())
  return handler.process_axis_index(axis_frame)
axis_index_p.def_custom_bind(_axis_index_bind)
@config.register_omnistaging_enabler
def omnistaging_enabler() -> None:
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
def psum_bind(*args, axis_name, axis_index_groups):
if all(not isinstance(x, core.Tracer) for x in args):
if axis_index_groups is not None:
size = len(axis_index_groups[0])
elif type(axis_name) is tuple:
size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore
else:
size = core.axis_frame(axis_name).size # type: ignore
return tuple(size * x for x in args)
return core.Primitive.bind(
psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
def _process_axis_index(self, frame):
return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
batching.BatchTrace.process_axis_index = _process_axis_index
if psum_p in pxla.parallel_pure_rules:
del pxla.parallel_pure_rules[psum_p]
# axis_index takes no arguments, so the default bind would have no way to
# dispatch to a data-dependency-based trace such as vmap. Each trace that
# wants to bind an axis name must therefore also implement
# `process_axis_index` and put its main trace on the axis env stack.
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global axis_index
psum_p.bind = partial(core.Primitive.bind, psum_p)
psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p)) # type: ignore
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args) # type: ignore
def _axis_index_bind(*, axis_name):
frame = core.axis_frame(axis_name)
if frame.main_trace is not None:
trace = frame.main_trace.trace_type(frame.main_trace, core.cur_sublevel())
return trace.process_axis_index(frame)
return core.Primitive.bind(axis_index_p, axis_name=axis_name)
dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
frame = dynamic_axis_env[axis_name]
sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
nreps = dynamic_axis_env.nreps
trace = frame.pmap_trace
out_aval = ShapedArray((), np.int32)
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
source_info_util.current())
out_tracer.recipe = eqn
return out_tracer
def _axis_index_translation_rule(c, nreps, sizes, axis_name):
div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
axis_index_p.def_custom_bind(_axis_index_bind)
def process_axis_index(self, frame):
return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
batching.BatchTrace.process_axis_index = process_axis_index
axis_index_p.def_abstract_eval(
lambda *args, **params: ShapedArray((), np.int32))
xla.translations[axis_index_p] = _axis_index_translation_rule

View File

@ -266,8 +266,9 @@ def one_hot(x, num_classes, *, dtype=jnp.float64):
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
"""
num_classes = core.concrete_or_error(int, num_classes,
"in jax.nn.one_hot argument `num_classes`")
num_classes = core.concrete_or_error(
int, num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
dtype = dtypes.canonicalize_dtype(dtype)
x = jnp.asarray(x)
lhs = x[..., jnp.newaxis]

View File

@ -1137,6 +1137,16 @@ def reshape(a, newshape, order="C"):
def _compute_newshape(a, newshape):
"""Fixes a -1 value in newshape, if present."""
# other errors, like having more than one -1, are caught downstream
try: iter(newshape)
except: iterable = False
else: iterable = True
if iterable:
newshape = [core.concrete_or_error(int, d,
"The error arose in jax.numpy.reshape.")
for d in newshape]
else:
newshape = core.concrete_or_error(int, newshape,
"The error arose in jax.numpy.reshape.")
newsize = _prod(newshape)
if newsize < 0:
fix = a.size // -newsize
@ -2417,7 +2427,7 @@ def identity(n, dtype=None):
def arange(start, stop=None, step=None, dtype=None):
lax._check_user_dtype_supported(dtype, "arange")
require = partial(core.concrete_or_error, _np_asarray)
msg = "in jax.numpy.arange argument `{}`".format
msg = "It arose in jax.numpy.arange argument `{}`.".format
if stop is None and step is None:
start = require(start, msg("stop"))
dtype = dtype or _dtype(start)
@ -2622,7 +2632,7 @@ def repeat(a, repeats, axis=None, *, total_repeat_length=None):
# If total_repeat_length is not given, can't compile, use a default.
if total_repeat_length is None:
repeats = core.concrete_or_error(np.array, repeats, "jax.numpy.repeat")
repeats = core.concrete_or_error(np.array, repeats, "It arose in jax.numpy.repeat.")
repeats = np.ravel(repeats)
if ndim(a) != 0:
repeats = np.broadcast_to(repeats, [a.shape[axis]])

View File

@ -513,7 +513,7 @@ class APITest(jtu.JaxTestCase):
self.assertRaisesRegex(
TypeError,
f"Try using `x.astype\\({castfun.__name__}\\)` instead.",
f"[Tt]ry using `x.astype\\({castfun.__name__}\\)`",
lambda: jit(f)(1.0))
def test_switch_value_jit(self):
@ -549,7 +549,7 @@ class APITest(jtu.JaxTestCase):
self.assertRaisesRegex(
TypeError,
"('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
"|Abstract tracer value encountered where concrete value is expected .*)", lambda: jit(f)(0))
"|Abstract tracer value encountered where concrete value is expected.*)", lambda: jit(f)(0))
def test_unimplemented_interpreter_rules(self):
foo_p = Primitive('foo')

View File

@ -4145,7 +4145,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.sum(jnp.arange(3), (0, 0))
def testArangeConcretizationError(self):
msg = r"Abstract tracer.*\(in jax.numpy.arange argument `{}`\).*".format
msg = r"It arose in jax.numpy.arange argument `{}`".format
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):
jax.jit(jnp.arange)(3)

View File

@ -150,7 +150,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
def testOneHotConcretizationError(self):
# https://github.com/google/jax/issues/3654
msg = r"Abstract tracer.*\(in jax.nn.one_hot argument `num_classes`\).*"
msg = r"in jax.nn.one_hot argument `num_classes`"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
jax.jit(nn.one_hot)(3, 5)