mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
omnistaging on by default (#4038)
This commit is contained in:
parent
6af476900a
commit
2678a4647a
8
.github/workflows/ci-build.yaml
vendored
8
.github/workflows/ci-build.yaml
vendored
@ -53,27 +53,27 @@ jobs:
|
||||
- python-version: 3.6
|
||||
os: ubuntu-latest
|
||||
enable-x64: 0
|
||||
enable-omnistaging: 0
|
||||
enable-omnistaging: 1
|
||||
package-overrides: "none"
|
||||
num_generated_cases: 25
|
||||
- python-version: 3.7
|
||||
os: ubuntu-latest
|
||||
enable-x64: 1
|
||||
enable-omnistaging: 0
|
||||
enable-omnistaging: 1
|
||||
# Test experimental NumPy dispatch
|
||||
package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
|
||||
num_generated_cases: 25
|
||||
- python-version: 3.6
|
||||
os: ubuntu-latest
|
||||
enable-x64: 1
|
||||
enable-omnistaging: 0
|
||||
enable-omnistaging: 1
|
||||
# Test with numpy version that matches Google-internal version
|
||||
package-overrides: "numpy==1.16.4"
|
||||
num_generated_cases: 10
|
||||
- python-version: 3.7
|
||||
os: ubuntu-latest
|
||||
enable-x64: 0
|
||||
enable-omnistaging: 1
|
||||
enable-omnistaging: 0
|
||||
package-overrides: "none"
|
||||
num_generated_cases: 8
|
||||
steps:
|
||||
|
211
docs/jaxpr.rst
211
docs/jaxpr.rst
@ -164,34 +164,15 @@ before (with two input vars, one for each element of the input tuple)
|
||||
|
||||
|
||||
Constant Vars
|
||||
--------------
|
||||
-------------
|
||||
|
||||
ConstVars arise when the computation contains array constants, either
|
||||
from the Python program, or from constant-folding. For example, the function
|
||||
``func6`` below
|
||||
Some values in jaxprs are constants, in that their value does not depend on the
|
||||
jaxpr's arguments. When these values are scalars they are represented directly
|
||||
in the jaxpr equations; non-scalar array constants are instead hoisted out to
|
||||
the top-level jaxpr, where they correspond to constant variables ("constvars").
|
||||
These constvars differ from the other jaxpr parameters ("invars") only as a
|
||||
bookkeeping convention.
|
||||
|
||||
>>> def func5(first, second):
|
||||
... temp = first + jnp.sin(second) * 3. - jnp.ones(8)
|
||||
... return temp
|
||||
...
|
||||
>>> def func6(first):
|
||||
... return func5(first, jnp.ones(8))
|
||||
...
|
||||
|
||||
JAX produces the following jaxpr
|
||||
|
||||
>>> print(make_jaxpr(func6)(jnp.ones(8)))
|
||||
{ lambda b d ; a.
|
||||
let c = add a b
|
||||
e = sub c d
|
||||
in (e,) }
|
||||
|
||||
When tracing ``func6``, the function ``func5`` is invoked with a constant value
|
||||
(``np.ones(8)``) for the second argument. As a result, the sub-expression
|
||||
``jnp.sin(second) * 3.`` is constant-folded.
|
||||
There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d``
|
||||
(standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the
|
||||
jaxpr notation what constants the constant variables stand for.
|
||||
|
||||
Higher-order primitives
|
||||
-----------------------
|
||||
@ -293,44 +274,25 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`
|
||||
>>> def func8(arg1, arg2): # arg2 is a pair
|
||||
... return lax.cond(arg1 >= 0.,
|
||||
... lambda xtrue: xtrue[0],
|
||||
... lambda xfalse: jnp.ones(1) + xfalse[1],
|
||||
... lambda xfalse: jnp.array([1]) + xfalse[1],
|
||||
... arg2)
|
||||
...
|
||||
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
|
||||
{ lambda f ; a b c.
|
||||
let d = ge a 0.0
|
||||
e = convert_element_type[ new_dtype=int32
|
||||
old_dtype=bool ] d
|
||||
g = cond[ branches=( { lambda ; c a b.
|
||||
let d = add c b
|
||||
in (d,) }
|
||||
{ lambda ; e_ a b.
|
||||
let
|
||||
{ lambda a ; b c d.
|
||||
let e = ge b 0.0
|
||||
f = convert_element_type[ new_dtype=int32
|
||||
old_dtype=bool ] e
|
||||
g = cond[ branches=( { lambda ; a b c.
|
||||
let d = convert_element_type[ new_dtype=float32
|
||||
old_dtype=int32 ] a
|
||||
e = add d c
|
||||
in (e,) }
|
||||
{ lambda ; f_ a b.
|
||||
let
|
||||
in (a,) } )
|
||||
linear=(False, False, False) ] e f b c
|
||||
linear=(False, False, False) ] f a c d
|
||||
in (g,) }
|
||||
|
||||
The top-level jaxpr has one `constvar` ``f`` (corresponding to
|
||||
``jnp.ones(1)`` from the body of the first (false) branch) and three
|
||||
input variables ``a b c`` (corresponding to ``arg1`` and the two
|
||||
elements of ``arg2``; note that ``arg2`` has been flattened). The
|
||||
``false_jaxpr`` has three input variables (``c`` corresponding to the
|
||||
constant for ``jnp.ones(1)``, and ``a b`` for the two elements of
|
||||
``arg2`` that are passed to ``false_jaxpr``). The ``true_jaxpr`` has
|
||||
three input variables. The first (``e_``) is an unused argument
|
||||
matching the constant first argument ``c`` of ``false_jaxpr``
|
||||
(required for the jaxpr signatures to match). The subsequent two
|
||||
correspond to the two elements of ``arg2`` that is passed to
|
||||
``true_jaxpr``.
|
||||
|
||||
The actual operands to the cond primitive are: ``e f b c``, which
|
||||
correspond in order to:
|
||||
|
||||
* one operand for the predicate,
|
||||
* one constant (only used by ``false_jaxpr``, but passed to both),
|
||||
i.e., ``f``, which is a constvar for the top-level jaxpr
|
||||
* two operands passed to both jaxprs, i.e., ``b`` and ``c``, which are
|
||||
input vars, corresponding to ``arg2`` for the top-level jaxpr.
|
||||
|
||||
While
|
||||
^^^^^
|
||||
@ -357,32 +319,22 @@ For example, here is an example fori loop
|
||||
... arg + ones)
|
||||
...
|
||||
>>> print(make_jaxpr(func10)(np.ones(16), 5))
|
||||
{ lambda c d ; a b.
|
||||
let e = add a d
|
||||
_ _ f = while[ body_jaxpr={ lambda ; e g a b c.
|
||||
let d = add a 1
|
||||
f = add c e
|
||||
h = add f g
|
||||
in (d, b, h) }
|
||||
{ lambda ; a b.
|
||||
let c = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(16,) ] 1.0
|
||||
d = add a c
|
||||
_ _ e = while[ body_jaxpr={ lambda ; a b c d e.
|
||||
let f = add c 1
|
||||
g = mul a 3.0
|
||||
h = add e g
|
||||
i = add h b
|
||||
in (f, d, i) }
|
||||
body_nconsts=2
|
||||
cond_jaxpr={ lambda ; a b c.
|
||||
let d = lt a b
|
||||
in (d,) }
|
||||
cond_nconsts=0 ] c a 0 b e
|
||||
in (f,) }
|
||||
|
||||
The top-level jaxpr has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
|
||||
of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry).
|
||||
There are also two input variables (``a`` corresponding to ``arg`` and ``b`` corresponding
|
||||
to ``n``).
|
||||
The loop carry consists of three values, as seen in the body of ``cond_jaxpr``
|
||||
(corresponding to the iteration index, iteration end, and the accumulated value carry).
|
||||
Note that ``body_jaxpr`` takes 5 input variables. The first two are actually
|
||||
constvars: ``e`` corresponding to ``ones * 3`` and ``g`` corresponding to the
|
||||
captures use of ``arg`` in the loop body.
|
||||
The parameter ``body_nconsts = 2`` specifies that there are 2 constants for the
|
||||
``body_jaxpr``.
|
||||
The other 3 input variables for ``body_jaxpr`` correspond to the flattened carry values.
|
||||
cond_nconsts=0 ] c a 0 b d
|
||||
in (e,) }
|
||||
|
||||
The while primitive takes 5 arguments: ``c a 0 b e``, as follows:
|
||||
|
||||
@ -395,13 +347,13 @@ Scan
|
||||
|
||||
JAX supports a special form of loop over the elements of an array (with
|
||||
statically known shape). The fact that there are a fixed number of iterations
|
||||
makes this form of looping easily reverse-differentiable. Such loops are constructed
|
||||
with the :py:func:`jax.lax.scan` operator::
|
||||
makes this form of looping easily reverse-differentiable. Such loops are
|
||||
constructed with the :py:func:`jax.lax.scan` function::
|
||||
|
||||
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
|
||||
|
||||
Here ``C`` is the type of the scan carry, ``A`` is the element type of the input array(s),
|
||||
and ``B`` is the element type of the output array(s).
|
||||
Here ``C`` is the type of the scan carry, ``A`` is the element type of the
|
||||
input array(s), and ``B`` is the element type of the output array(s).
|
||||
|
||||
For the example consider the function ``func11`` below
|
||||
|
||||
@ -415,12 +367,14 @@ For the example consider the function ``func11`` below
|
||||
... return lax.scan(body, 0., (arr, ones))
|
||||
...
|
||||
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
|
||||
{ lambda c ; a b.
|
||||
let d e = scan[ jaxpr={ lambda ; f a b c.
|
||||
let d = mul b c
|
||||
e = add a d
|
||||
g = add e f
|
||||
in (g, a) }
|
||||
{ lambda ; a b.
|
||||
let c = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(16,) ] 1.0
|
||||
d e = scan[ jaxpr={ lambda ; a b c d.
|
||||
let e = mul c d
|
||||
f = add b e
|
||||
g = add f a
|
||||
in (g, b) }
|
||||
length=16
|
||||
linear=(False, False, False, False)
|
||||
num_carry=1
|
||||
@ -429,17 +383,6 @@ For the example consider the function ``func11`` below
|
||||
unroll=1 ] b 0.0 a c
|
||||
in (d, e) }
|
||||
|
||||
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
|
||||
and two input variables corresponding to the arguments ``arr`` and ``extra``.
|
||||
The body of the scan has 4 input variables, of which:
|
||||
|
||||
* one (``f``) is a constant (since ``num_consts = 1``), and stands for the
|
||||
captured variable ``extra`` used in the loop body,
|
||||
* one (``a``) is the value of the carry (since ``num_carry = 1``)
|
||||
* The remaining 2 are the input values. ``b`` is the array element from the
|
||||
first array passed to lax.scan (``arr``) and ``c`` is the second array
|
||||
(``ones``).
|
||||
|
||||
The ``linear`` parameter describes for each of the input variables whether they
|
||||
are guaranteed to be used linearly in the body. Once the scan goes through
|
||||
linearization, more arguments will be linear.
|
||||
@ -466,37 +409,27 @@ computation should run. For example
|
||||
... return arg + inner(arg - 2.)
|
||||
...
|
||||
>>> print(make_jaxpr(func12)(1.))
|
||||
{ lambda b ; a.
|
||||
let c = sub a 2.0
|
||||
d = xla_call[ backend=None
|
||||
call_jaxpr={ lambda ; c b a.
|
||||
let d = mul b c
|
||||
e = add a d
|
||||
{ lambda ; a.
|
||||
let b = sub a 2.0
|
||||
c = xla_call[ backend=None
|
||||
call_jaxpr={ lambda ; a b.
|
||||
let c = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(1,) ] 1.0
|
||||
d = mul a c
|
||||
e = add b d
|
||||
in (e,) }
|
||||
device=None
|
||||
donated_invars=(False, False, False)
|
||||
name=inner ] b a c
|
||||
e = add a d
|
||||
in (e,) }
|
||||
donated_invars=(False, False)
|
||||
name=inner ] a b
|
||||
d = add a c
|
||||
in (d,) }
|
||||
|
||||
The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and
|
||||
the top-level input variable `a` refers to the ``arg`` parameter of ``func12``.
|
||||
The ``xla_call`` primitive stands for a call to the jitted ``inner`` function.
|
||||
The primitive has the function body in the ``call_jaxpr`` parameter, a jaxpr
|
||||
with 3 input parameters:
|
||||
|
||||
* ``c`` is a constvar and stands for the ``ones`` constant,
|
||||
* ``b`` corresponds to the free variable ``arg`` captured in the ``inner`` function,
|
||||
* ``a`` corresponds to the ``inner`` parameter ``x``.
|
||||
|
||||
The primitive takes three arguments ``b a c``.
|
||||
|
||||
XLA_pmap
|
||||
^^^^^^^^
|
||||
|
||||
If you use the :py:func:`jax.pmap` transformation, the function to be
|
||||
mapped is captured using the ``xla_pmap`` primitive. Consider this
|
||||
example
|
||||
If you use the :py:func:`jax.pmap` transformation, the function to be mapped is
|
||||
captured using the ``xla_pmap`` primitive. Consider this example
|
||||
|
||||
>>> from jax import pmap
|
||||
>>>
|
||||
@ -507,34 +440,30 @@ example
|
||||
... return pmap(inner, axis_name='rows')(arr)
|
||||
...
|
||||
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
|
||||
{ lambda c ; a b.
|
||||
let d = xla_pmap[ axis_name=rows
|
||||
{ lambda ; a b.
|
||||
let c = xla_pmap[ axis_name=rows
|
||||
axis_size=1
|
||||
backend=None
|
||||
call_jaxpr={ lambda ; d b a.
|
||||
let c = add a b
|
||||
call_jaxpr={ lambda ; a b.
|
||||
let c = add b a
|
||||
d = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(1,) ] 1.0
|
||||
e = add c d
|
||||
f = psum[ axis_index_groups=None
|
||||
axis_name=rows ] a
|
||||
axis_name=rows ] b
|
||||
g = div e f
|
||||
in (g,) }
|
||||
devices=None
|
||||
donated_invars=(False, False, False)
|
||||
donated_invars=(False, False)
|
||||
global_axis_size=None
|
||||
mapped_invars=(True, False, True)
|
||||
name=inner ] c b a
|
||||
in (d,) }
|
||||
mapped_invars=(False, True)
|
||||
name=inner ] b a
|
||||
in (c,) }
|
||||
|
||||
The top-level constvar ``c`` refers to the ``jnp.ones(1)`` constant.
|
||||
The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``)
|
||||
and the body of the function to be mapped as the ``call_jaxpr`` parameter. The
|
||||
and the body of the function to be mapped as the ``call_jaxpr`` parameter.
|
||||
value of this parameter is a Jaxpr with 3 input variables:
|
||||
|
||||
* ``d`` stands for the constant ``jnp.ones(1)``,
|
||||
* ``b`` stands for the free variable ``extra``,
|
||||
* ``a`` stands for the parameter ``x`` of ``inner``.
|
||||
|
||||
|
||||
The parameter ``mapped_invars`` specify which of the input variables should be
|
||||
mapped and which should be broadcast. In our example, the value of ``extra``
|
||||
is broadcast, the other input values are mapped.
|
||||
|
19
jax/api.py
19
jax/api.py
@ -393,7 +393,7 @@ def disable_jit():
|
||||
... return y + 3
|
||||
...
|
||||
>>> print(f(jax.numpy.array([1, 2, 3])))
|
||||
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
|
||||
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
|
||||
[5 7 9]
|
||||
|
||||
Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,
|
||||
@ -651,7 +651,7 @@ def _xla_computation(
|
||||
else:
|
||||
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
jaxtree_fun, pvals, instantiate=True, stage_out=True)
|
||||
jaxtree_fun, pvals, instantiate=True, stage_out=True) # type: ignore
|
||||
out_avals = [raise_to_shaped(pval.get_aval()) for pval in out_pvals]
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
|
||||
@ -1910,11 +1910,12 @@ def make_jaxpr(fun: Callable,
|
||||
>>> jax.make_jaxpr(jax.grad(f))(3.0)
|
||||
{ lambda ; a.
|
||||
let b = cos a
|
||||
c = cos b
|
||||
d = mul 1.0 c
|
||||
e = neg d
|
||||
f = sin a
|
||||
g = mul e f
|
||||
c = sin a
|
||||
_ = sin b
|
||||
d = cos b
|
||||
e = mul 1.0 d
|
||||
f = neg e
|
||||
g = mul f c
|
||||
in (g,) }
|
||||
"""
|
||||
_check_callable(fun)
|
||||
@ -1936,7 +1937,7 @@ def make_jaxpr(fun: Callable,
|
||||
else:
|
||||
in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
jaxtree_fun, in_pvals, instantiate=True, stage_out=True)
|
||||
jaxtree_fun, in_pvals, instantiate=True, stage_out=True) # type: ignore
|
||||
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
return typed_jaxpr
|
||||
@ -2214,7 +2215,7 @@ class CustomTransformsFunction(object):
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
|
||||
else:
|
||||
with core.initial_style_staging():
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
|
||||
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
|
||||
in_tree=in_tree, out_tree=out_tree(),
|
||||
|
@ -42,8 +42,8 @@ class Config:
|
||||
self.meta = {}
|
||||
self.FLAGS = NameSpace(self.read)
|
||||
self.use_absl = False
|
||||
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', False)
|
||||
self._omnistaging_enablers = []
|
||||
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
|
||||
self._omnistaging_disablers = []
|
||||
|
||||
def update(self, name, val):
|
||||
if self.use_absl:
|
||||
@ -114,24 +114,25 @@ class Config:
|
||||
self.complete_absl_config(absl.flags)
|
||||
already_configured_with_absl = True
|
||||
|
||||
if FLAGS.jax_omnistaging:
|
||||
self.enable_omnistaging()
|
||||
if not FLAGS.jax_omnistaging:
|
||||
self.disable_omnistaging()
|
||||
|
||||
def register_omnistaging_enabler(self, enabler):
|
||||
if not self.omnistaging_enabled:
|
||||
self._omnistaging_enablers.append(enabler)
|
||||
|
||||
def register_omnistaging_disabler(self, disabler):
|
||||
if self.omnistaging_enabled:
|
||||
self._omnistaging_disablers.append(disabler)
|
||||
else:
|
||||
enabler()
|
||||
disabler()
|
||||
|
||||
# TODO(mattjj): remove this when omnistaging fully lands
|
||||
def enable_omnistaging(self):
|
||||
if not self.omnistaging_enabled:
|
||||
for enabler in self._omnistaging_enablers:
|
||||
enabler()
|
||||
self.omnistaging_enabled = True
|
||||
raise Exception("can't re-enable omnistaging after it's been disabled")
|
||||
|
||||
def disable_omnistaging(self):
|
||||
pass
|
||||
if self.omnistaging_enabled:
|
||||
for disabler in self._omnistaging_disablers:
|
||||
disabler()
|
||||
self.omnistaging_enabled = False
|
||||
|
||||
|
||||
class NameSpace(object):
|
||||
@ -156,6 +157,6 @@ flags.DEFINE_bool(
|
||||
|
||||
flags.DEFINE_bool(
|
||||
'jax_omnistaging',
|
||||
bool_env('JAX_OMNISTAGING', False),
|
||||
bool_env('JAX_OMNISTAGING', True),
|
||||
help='Enable staging based on dynamic context rather than data dependence.'
|
||||
)
|
||||
|
437
jax/core.py
437
jax/core.py
@ -24,7 +24,7 @@ import threading
|
||||
import types
|
||||
from typing import (Any, Callable, ClassVar, Dict, Generator,
|
||||
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
|
||||
Type, Union, cast, no_type_check)
|
||||
Type, Union, cast)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -266,19 +266,14 @@ class Primitive:
|
||||
def __repr__(self):
|
||||
return '{}'.format(self.name)
|
||||
|
||||
def bind(self, *args, **kwargs):
|
||||
|
||||
def bind(self, *args, **params):
|
||||
assert skip_checks or all(isinstance(arg, Tracer)
|
||||
or valid_jaxtype(arg) for arg in args), args
|
||||
top_trace = find_top_trace(args)
|
||||
if top_trace is None:
|
||||
return self.impl(*args, **kwargs)
|
||||
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
|
||||
if self.multiple_results:
|
||||
return map(full_lower, out_tracer)
|
||||
else:
|
||||
return full_lower(out_tracer)
|
||||
out = top_trace.process_primitive(self, tracers, params)
|
||||
return map(full_lower, out) if self.multiple_results else full_lower(out)
|
||||
|
||||
def def_impl(self, impl):
|
||||
self.impl = impl
|
||||
@ -517,14 +512,8 @@ class Tracer:
|
||||
def __long__(self): return self.aval._long(self)
|
||||
def __hex__(self): return self.aval._hex(self)
|
||||
def __oct__(self): return self.aval._oct(self)
|
||||
|
||||
def __float__(self):
|
||||
raise TypeError("JAX Tracer object cannot be interpreted as a float. "
|
||||
"Try using `x.astype(float)` instead.")
|
||||
|
||||
def __complex__(self):
|
||||
raise TypeError("JAX Tracer object cannot be interpreted as a complex. "
|
||||
"Try using `x.astype(complex)` instead.")
|
||||
def __float__(self): return self.aval._float(self)
|
||||
def __complex__(self): return self.aval._complex(self)
|
||||
|
||||
def __setitem__(self, idx, val):
|
||||
raise TypeError("JAX 'Tracer' objects do not support item assignment")
|
||||
@ -571,6 +560,9 @@ class Tracer:
|
||||
def __deepcopy__(self, unused_memo):
|
||||
return self
|
||||
|
||||
def _origin_msg(self) -> str:
|
||||
return ""
|
||||
|
||||
# these can be used to set up forwarding of properties and instance methods from
|
||||
# Tracer instances to the underlying avals
|
||||
aval_property = namedtuple("aval_property", ["fget"])
|
||||
@ -612,57 +604,47 @@ class TraceStack:
|
||||
downward: List[MainTrace]
|
||||
|
||||
def __init__(self):
|
||||
self.upward = []
|
||||
self.downward = []
|
||||
eval_trace = MainTrace(0, EvalTrace)
|
||||
self.stack = [eval_trace]
|
||||
self.dynamic = eval_trace
|
||||
|
||||
def next_level(self, bottom: bool) -> int:
|
||||
if bottom:
|
||||
return - (len(self.downward) + 1)
|
||||
else:
|
||||
return len(self.upward)
|
||||
def next_level(self) -> int:
|
||||
return len(self.stack)
|
||||
|
||||
def push(self, main_trace: MainTrace, bottom: bool) -> None:
|
||||
if bottom:
|
||||
self.downward.append(main_trace)
|
||||
else:
|
||||
self.upward.append(main_trace)
|
||||
def push(self, main_trace: MainTrace) -> None:
|
||||
self.stack.append(main_trace)
|
||||
|
||||
def pop(self, bottom: bool) -> None:
|
||||
if bottom:
|
||||
self.downward.pop()
|
||||
else:
|
||||
self.upward.pop()
|
||||
def pop(self) -> None:
|
||||
self.stack.pop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return 'Trace stack\n{} ---\n{}'.format(
|
||||
map(' {}\n'.format, self.upward[::-1]),
|
||||
map(' {}\n'.format, self.downward))
|
||||
stack_str = map(' {}\n'.format, self.stack[::-1])
|
||||
return f'Trace stack\n{stack_str}\n{self.dynamic}'
|
||||
|
||||
def copy(self):
|
||||
new = TraceStack()
|
||||
new.upward = self.upward[:]
|
||||
new.downward = self.downward[:]
|
||||
new = self.__new__(TraceStack)
|
||||
new.stack = self.stack[:]
|
||||
new.dynamic = self.dynamic
|
||||
return new
|
||||
|
||||
class Sublevel(int): pass
|
||||
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
|
||||
|
||||
|
||||
class TraceState:
|
||||
trace_stack: TraceStack
|
||||
substack: List[Sublevel]
|
||||
initial_style: bool
|
||||
axis_env: List[AxisEnvFrame]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.trace_stack = TraceStack()
|
||||
self.substack = [Sublevel(0)]
|
||||
self.initial_style = False
|
||||
self.axis_env = []
|
||||
|
||||
def copy(self):
|
||||
new = TraceState()
|
||||
new = self.__new__(TraceState)
|
||||
new.trace_stack = self.trace_stack.copy()
|
||||
new.substack = self.substack[:]
|
||||
new.initial_style = self.initial_style
|
||||
new.axis_env = self.axis_env[:]
|
||||
return new
|
||||
|
||||
# The global state of the tracer is accessed by a thread-local object.
|
||||
@ -676,8 +658,9 @@ thread_local_state = ThreadLocalState()
|
||||
def reset_trace_state() -> bool:
|
||||
"Reset the global trace state and return True if it was already clean."
|
||||
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
|
||||
thread_local_state.trace_state.trace_stack.downward or
|
||||
thread_local_state.trace_state.trace_stack.upward):
|
||||
thread_local_state.trace_state.axis_env != [] or
|
||||
thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or
|
||||
thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)):
|
||||
thread_local_state.trace_state.__init__() # type: ignore
|
||||
return False
|
||||
else:
|
||||
@ -687,15 +670,21 @@ def cur_sublevel() -> Sublevel:
|
||||
return thread_local_state.trace_state.substack[-1]
|
||||
|
||||
@contextmanager
|
||||
def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]:
|
||||
level = thread_local_state.trace_state.trace_stack.next_level(bottom)
|
||||
def new_main(trace_type: Type[Trace], dynamic: bool = False,
|
||||
) -> Generator[MainTrace, None, None]:
|
||||
stack = thread_local_state.trace_state.trace_stack
|
||||
level = stack.next_level()
|
||||
main = MainTrace(level, trace_type)
|
||||
thread_local_state.trace_state.trace_stack.push(main, bottom)
|
||||
stack.push(main)
|
||||
if dynamic:
|
||||
prev_dynamic, stack.dynamic = stack.dynamic, main
|
||||
|
||||
try:
|
||||
yield main
|
||||
finally:
|
||||
thread_local_state.trace_state.trace_stack.pop(bottom)
|
||||
thread_local_state.trace_state.trace_stack.pop()
|
||||
if dynamic:
|
||||
stack.dynamic = prev_dynamic
|
||||
|
||||
if check_leaks:
|
||||
t = ref(main)
|
||||
@ -704,6 +693,23 @@ def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None
|
||||
print(thread_local_state.trace_state.trace_stack)
|
||||
raise Exception('Leaked trace {}'.format(t()))
|
||||
|
||||
@contextmanager
|
||||
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
|
||||
stack = thread_local_state.trace_state.trace_stack
|
||||
main = MainTrace(0, trace_type)
|
||||
prev_dynamic, stack.dynamic = stack.dynamic, main
|
||||
prev_base, stack.stack[0] = stack.stack[0], main
|
||||
try:
|
||||
yield main
|
||||
finally:
|
||||
stack.dynamic = prev_dynamic
|
||||
stack.stack[0] = prev_base
|
||||
|
||||
@contextmanager
|
||||
def eval_context():
|
||||
with new_base_main(EvalTrace):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def new_sublevel() -> Generator[None, None, None]:
|
||||
sublevel = Sublevel(len(thread_local_state.trace_state.substack))
|
||||
@ -719,29 +725,24 @@ def new_sublevel() -> Generator[None, None, None]:
|
||||
if t() is not None:
|
||||
raise Exception('Leaked sublevel {}'.format(t()))
|
||||
|
||||
def maybe_new_sublevel(trace):
|
||||
# dynamic traces run the WrappedFun, so we raise the sublevel for them
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
return new_sublevel() if trace.main is dynamic else suppress()
|
||||
|
||||
def full_lower(val):
|
||||
if isinstance(val, Tracer):
|
||||
return val.full_lower()
|
||||
else:
|
||||
return val
|
||||
|
||||
def find_top_trace(xs) -> Optional[Trace]:
|
||||
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
|
||||
key=attrgetter('level'), default=None)
|
||||
return top_trace and type(top_trace)(top_trace.main, cur_sublevel())
|
||||
|
||||
@contextmanager
|
||||
def initial_style_staging():
|
||||
trace_state = thread_local_state.trace_state
|
||||
prev, trace_state.initial_style = trace_state.initial_style, True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
trace_state.initial_style = prev
|
||||
|
||||
@contextmanager
|
||||
def eval_context():
|
||||
yield # dummy implementation for forward compatibility
|
||||
def find_top_trace(xs) -> Trace:
|
||||
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
|
||||
default=None, key=attrgetter('level'))
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
|
||||
else top_main)
|
||||
return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore
|
||||
|
||||
|
||||
# -------------------- abstract values --------------------
|
||||
@ -844,20 +845,24 @@ pytype_aval_mappings[Unit] = lambda _: abstract_unit
|
||||
|
||||
class ConcretizationTypeError(TypeError): pass
|
||||
|
||||
def raise_concretization_error(val, context=""):
|
||||
msg = (f"Abstract tracer value encountered where concrete value is expected ({context}).\n"
|
||||
"Use transformation parameters such as `static_argnums` for `jit` "
|
||||
"to avoid tracing input values.\n"
|
||||
"See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.\n"
|
||||
f"Encountered value: {val}")
|
||||
def raise_concretization_error(val: Tracer, context=""):
|
||||
msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
|
||||
+ context + "\n\n"
|
||||
+ val._origin_msg() + "\n\n"
|
||||
+ "You can use transformation parameters such as `static_argnums` for "
|
||||
"`jit` to avoid tracing particular arguments of transformed functions.\n\n"
|
||||
"See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
|
||||
f"Encountered tracer value: {val}")
|
||||
raise ConcretizationTypeError(msg)
|
||||
|
||||
|
||||
def concretization_function_error(fun, context=""):
|
||||
def concretization_function_error(fun, suggest_astype=False):
|
||||
fname = getattr(fun, "__name__", fun)
|
||||
fname_context = f"in `{fname}`"
|
||||
if context:
|
||||
fname_context += f" {context}"
|
||||
fname_context = f"The problem arose with the `{fname}` function. "
|
||||
if suggest_astype:
|
||||
fname_context += ("If trying to convert the data type of a value, "
|
||||
f"try using `x.astype({fun.__name__})` "
|
||||
f"or `jnp.array(x, {fun.__name__})` instead.")
|
||||
def error(self, arg):
|
||||
raise_concretization_error(arg, fname_context)
|
||||
return error
|
||||
@ -899,12 +904,9 @@ class UnshapedArray(AbstractValue):
|
||||
", weak_type=True" if self.weak_type else "")
|
||||
|
||||
_bool = _nonzero = concretization_function_error(bool)
|
||||
_float = concretization_function_error(
|
||||
float, "Try using `x.astype(float)` instead.")
|
||||
_int = concretization_function_error(
|
||||
int, "Try using `x.astype(int)` instead.")
|
||||
_complex = concretization_function_error(
|
||||
complex, "Try using `x.astype(complex)` instead.")
|
||||
_float = concretization_function_error(float, True)
|
||||
_int = concretization_function_error(int, True)
|
||||
_complex = concretization_function_error(complex, True)
|
||||
_hex = concretization_function_error(hex)
|
||||
_oct = concretization_function_error(oct)
|
||||
|
||||
@ -1036,9 +1038,12 @@ class ConcreteArray(ShapedArray):
|
||||
return ConcreteArray(self.val) if self.weak_type else self
|
||||
|
||||
_bool = _nonzero = partialmethod(_forward_to_value, bool)
|
||||
_int = partialmethod(_forward_to_value, int)
|
||||
_hex = partialmethod(_forward_to_value, hex)
|
||||
_oct = partialmethod(_forward_to_value, oct)
|
||||
_int = partialmethod(_forward_to_value, int)
|
||||
_hex = partialmethod(_forward_to_value, hex)
|
||||
_oct = partialmethod(_forward_to_value, oct)
|
||||
|
||||
_float = concretization_function_error(float, True)
|
||||
_complex = concretization_function_error(complex, True)
|
||||
|
||||
|
||||
class AbstractToken(AbstractValue):
|
||||
@ -1123,20 +1128,16 @@ def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
yield outs, tuple(todo) # Ensure the aux output is immutable
|
||||
|
||||
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
fun: lu.WrappedFun, *args, **params):
|
||||
fun, *args, **params):
|
||||
params_tuple = tuple(params.items())
|
||||
top_trace = find_top_trace(args)
|
||||
level = (thread_local_state.trace_state.trace_stack.next_level(True)
|
||||
if top_trace is None else top_trace.level)
|
||||
params_tuple = tuple(params.items())
|
||||
fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple)
|
||||
if top_trace is None:
|
||||
with new_sublevel():
|
||||
outs = primitive.impl(fun, *args, **params)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
fun, env_trace_todo = process_env_traces(
|
||||
fun, primitive, top_trace and top_trace.level, params_tuple)
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
with maybe_new_sublevel(top_trace):
|
||||
outs = primitive.process(top_trace, fun, tracers, params)
|
||||
return apply_todos(env_trace_todo(), map(full_lower, outs))
|
||||
return map(full_lower, apply_todos(env_trace_todo(), outs))
|
||||
|
||||
|
||||
class CallPrimitive(Primitive):
|
||||
multiple_results = True
|
||||
@ -1176,10 +1177,65 @@ class MapPrimitive(Primitive):
|
||||
def post_process(self, trace, out_tracers, params):
|
||||
return trace.post_process_map(self, out_tracers, params)
|
||||
|
||||
# This is a no-op with omnistaging disabled
|
||||
@contextmanager
|
||||
def extend_axis_env(axis_name, size: int, tag: Any):
|
||||
yield
|
||||
frame = AxisEnvFrame(axis_name, size, tag)
|
||||
thread_local_state.trace_state.axis_env.append(frame)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
thread_local_state.trace_state.axis_env.pop()
|
||||
|
||||
def axis_frame(axis_name):
|
||||
frames = thread_local_state.trace_state.axis_env
|
||||
for frame in reversed(frames):
|
||||
if frame.name == axis_name:
|
||||
return frame
|
||||
else:
|
||||
raise NameError("unbound axis name: {}".format(axis_name))
|
||||
|
||||
def axis_index(axis_name):
|
||||
"""Return the index along the mapped axis ``axis_name``.
|
||||
|
||||
Args:
|
||||
axis_name: hashable Python object used to name the mapped axis.
|
||||
|
||||
Returns:
|
||||
An integer representing the index.
|
||||
|
||||
For example, with 8 XLA devices available:
|
||||
|
||||
>>> from functools import partial
|
||||
>>> @partial(jax.pmap, axis_name='i')
|
||||
... def f(_):
|
||||
... return lax.axis_index('i')
|
||||
...
|
||||
>>> f(np.zeros(4))
|
||||
ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
|
||||
>>> f(np.zeros(8))
|
||||
ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
|
||||
>>> @partial(jax.pmap, axis_name='i')
|
||||
... @partial(jax.pmap, axis_name='j')
|
||||
... def f(_):
|
||||
... return lax.axis_index('i'), lax.axis_index('j')
|
||||
...
|
||||
>>> x, y = f(np.zeros((4, 2)))
|
||||
>>> print(x)
|
||||
[[0 0]
|
||||
[1 1]
|
||||
[2 2]
|
||||
[3 3]]
|
||||
>>> print(y)
|
||||
[[0 1]
|
||||
[0 1]
|
||||
[0 1]
|
||||
[0 1]]
|
||||
"""
|
||||
return axis_index_p.bind(axis_name=axis_name)
|
||||
|
||||
axis_index_p = Primitive('axis_index')
|
||||
axis_index_p.def_abstract_eval(lambda *, axis_name: ShapedArray((), np.int32))
|
||||
|
||||
|
||||
# ------------------- Jaxpr checking -------------------
|
||||
|
||||
@ -1368,7 +1424,7 @@ def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint:
|
||||
pp_rhs = (pp(eqn.primitive.name) >>
|
||||
pp_kv_pairs(sorted(eqn.params.items())) >> pp(' ') >>
|
||||
pp(pp_vars(eqn.invars, print_shapes)))
|
||||
if len(lhs) <= 6:
|
||||
if len(lhs) <= 6 or print_shapes:
|
||||
return pp_lhs >> pp(' ') >> pp_rhs
|
||||
else:
|
||||
return pp_lhs + pp_rhs.indent(2)
|
||||
@ -1428,61 +1484,63 @@ def pp_kv_pairs(kv_pairs):
|
||||
else:
|
||||
return pp('')
|
||||
|
||||
axis_frame = None
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
@config.register_omnistaging_enabler
|
||||
@no_type_check
|
||||
def omnistaging_enabler() -> None:
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
|
||||
new_main, reset_trace_state, extend_axis_env, axis_frame, \
|
||||
new_base_main, eval_context, \
|
||||
TraceStack, TraceState
|
||||
del initial_style_staging
|
||||
new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env
|
||||
|
||||
class TraceStack:
|
||||
stack: List[MainTrace]
|
||||
dynamic: MainTrace
|
||||
upward: List[MainTrace]
|
||||
downward: List[MainTrace]
|
||||
|
||||
def __init__(self):
|
||||
eval_trace = MainTrace(0, EvalTrace)
|
||||
self.stack = [eval_trace]
|
||||
self.dynamic = eval_trace
|
||||
self.upward = []
|
||||
self.downward = []
|
||||
|
||||
def next_level(self) -> int:
|
||||
return len(self.stack)
|
||||
def next_level(self, bottom: bool) -> int:
|
||||
if bottom:
|
||||
return - (len(self.downward) + 1)
|
||||
else:
|
||||
return len(self.upward)
|
||||
|
||||
def push(self, main_trace: MainTrace) -> None:
|
||||
self.stack.append(main_trace)
|
||||
def push(self, main_trace: MainTrace, bottom: bool) -> None:
|
||||
if bottom:
|
||||
self.downward.append(main_trace)
|
||||
else:
|
||||
self.upward.append(main_trace)
|
||||
|
||||
def pop(self) -> None:
|
||||
self.stack.pop()
|
||||
def pop(self, bottom: bool) -> None:
|
||||
if bottom:
|
||||
self.downward.pop()
|
||||
else:
|
||||
self.upward.pop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
stack_str = map(' {}\n'.format, self.stack[::-1])
|
||||
return f'Trace stack\n{stack_str}\n{self.dynamic}'
|
||||
return 'Trace stack\n{} ---\n{}'.format(
|
||||
map(' {}\n'.format, self.upward[::-1]),
|
||||
map(' {}\n'.format, self.downward))
|
||||
|
||||
def copy(self):
|
||||
new = self.__new__(TraceStack)
|
||||
new.stack = self.stack[:]
|
||||
new.dynamic = self.dynamic
|
||||
new = TraceStack()
|
||||
new.upward = self.upward[:]
|
||||
new.downward = self.downward[:]
|
||||
return new
|
||||
|
||||
class TraceState:
|
||||
trace_stack: TraceStack
|
||||
substack: List[Sublevel]
|
||||
axis_env: List[AxisEnvFrame]
|
||||
initial_style: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.trace_stack = TraceStack()
|
||||
self.trace_stack = TraceStack() # type: ignore
|
||||
self.substack = [Sublevel(0)]
|
||||
self.axis_env = []
|
||||
self.initial_style = False
|
||||
|
||||
def copy(self):
|
||||
new = self.__new__(TraceState)
|
||||
new = TraceState()
|
||||
new.trace_stack = self.trace_stack.copy()
|
||||
new.substack = self.substack[:]
|
||||
new.axis_env = self.axis_env[:]
|
||||
new.initial_style = self.initial_style
|
||||
return new
|
||||
|
||||
thread_local_state = ThreadLocalState()
|
||||
@ -1490,54 +1548,23 @@ def omnistaging_enabler() -> None:
|
||||
def reset_trace_state() -> bool:
|
||||
"Reset the global trace state and return True if it was already clean."
|
||||
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
|
||||
thread_local_state.trace_state.axis_env != [] or
|
||||
thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or
|
||||
thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)):
|
||||
thread_local_state.trace_state.trace_stack.downward or
|
||||
thread_local_state.trace_state.trace_stack.upward):
|
||||
thread_local_state.trace_state.__init__() # type: ignore
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
fun, *args, **params):
|
||||
params_tuple = tuple(params.items())
|
||||
top_trace = find_top_trace(args)
|
||||
fun, env_trace_todo = process_env_traces(
|
||||
fun, primitive, top_trace and top_trace.level, params_tuple)
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
with maybe_new_sublevel(top_trace):
|
||||
outs = primitive.process(top_trace, fun, tracers, params)
|
||||
return map(full_lower, apply_todos(env_trace_todo(), outs))
|
||||
|
||||
def maybe_new_sublevel(trace):
|
||||
# dynamic traces run the WrappedFun, so we raise the sublevel for them
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
return new_sublevel() if trace.main is dynamic else suppress()
|
||||
|
||||
def find_top_trace(xs) -> Trace:
|
||||
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
|
||||
default=None, key=attrgetter('level'))
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
|
||||
else top_main)
|
||||
return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore
|
||||
|
||||
@contextmanager
|
||||
def new_main(trace_type: Type[Trace], dynamic: bool = False,
|
||||
) -> Generator[MainTrace, None, None]:
|
||||
stack = thread_local_state.trace_state.trace_stack
|
||||
level = stack.next_level()
|
||||
def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]:
|
||||
level = thread_local_state.trace_state.trace_stack.next_level(bottom)
|
||||
main = MainTrace(level, trace_type)
|
||||
stack.push(main)
|
||||
if dynamic:
|
||||
prev_dynamic, stack.dynamic = stack.dynamic, main
|
||||
thread_local_state.trace_state.trace_stack.push(main, bottom)
|
||||
|
||||
try:
|
||||
yield main
|
||||
finally:
|
||||
thread_local_state.trace_state.trace_stack.pop()
|
||||
if dynamic:
|
||||
stack.dynamic = prev_dynamic
|
||||
thread_local_state.trace_state.trace_stack.pop(bottom)
|
||||
|
||||
if check_leaks:
|
||||
t = ref(main)
|
||||
@ -1546,47 +1573,55 @@ def omnistaging_enabler() -> None:
|
||||
print(thread_local_state.trace_state.trace_stack)
|
||||
raise Exception('Leaked trace {}'.format(t()))
|
||||
|
||||
@contextmanager
|
||||
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
|
||||
stack = thread_local_state.trace_state.trace_stack
|
||||
main = MainTrace(0, trace_type)
|
||||
prev_dynamic, stack.dynamic = stack.dynamic, main
|
||||
prev_base, stack.stack[0] = stack.stack[0], main
|
||||
try:
|
||||
yield main
|
||||
finally:
|
||||
stack.dynamic = prev_dynamic
|
||||
stack.stack[0] = prev_base
|
||||
def find_top_trace(xs) -> Optional[Trace]:
|
||||
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
|
||||
key=attrgetter('level'), default=None)
|
||||
return top_trace and type(top_trace)(top_trace.main, cur_sublevel())
|
||||
|
||||
@contextmanager
|
||||
def eval_context():
|
||||
with new_base_main(EvalTrace):
|
||||
yield
|
||||
yield # dummy implementation for forward compatibility
|
||||
|
||||
def bind(self, *args, **params):
|
||||
def bind(self, *args, **kwargs):
|
||||
assert skip_checks or all(isinstance(arg, Tracer)
|
||||
or valid_jaxtype(arg) for arg in args), args
|
||||
top_trace = find_top_trace(args)
|
||||
if top_trace is None:
|
||||
return self.impl(*args, **kwargs)
|
||||
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
out = top_trace.process_primitive(self, tracers, params)
|
||||
return map(full_lower, out) if self.multiple_results else full_lower(out)
|
||||
Primitive.bind = bind
|
||||
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
|
||||
if self.multiple_results:
|
||||
return map(full_lower, out_tracer)
|
||||
else:
|
||||
return full_lower(out_tracer)
|
||||
Primitive.bind = bind # type: ignore
|
||||
|
||||
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
fun: lu.WrappedFun, *args, **params):
|
||||
params_tuple = tuple(params.items())
|
||||
top_trace = find_top_trace(args)
|
||||
level = (thread_local_state.trace_state.trace_stack.next_level(True)
|
||||
if top_trace is None else top_trace.level)
|
||||
params_tuple = tuple(params.items())
|
||||
fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple)
|
||||
if top_trace is None:
|
||||
with new_sublevel():
|
||||
outs = primitive.impl(fun, *args, **params)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = primitive.process(top_trace, fun, tracers, params)
|
||||
return apply_todos(env_trace_todo(), map(full_lower, outs))
|
||||
|
||||
@contextmanager
|
||||
def extend_axis_env(axis_name, size: int, main_trace: Optional[MainTrace]):
|
||||
frame = AxisEnvFrame(axis_name, size, main_trace)
|
||||
thread_local_state.trace_state.axis_env.append(frame)
|
||||
def extend_axis_env(axis_name, size: int, tag: Any):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def initial_style_staging():
|
||||
trace_state = thread_local_state.trace_state
|
||||
prev, trace_state.initial_style = trace_state.initial_style, True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
frame_ = thread_local_state.trace_state.axis_env.pop()
|
||||
assert frame is frame_ # Only runs if there was was no exception
|
||||
|
||||
def axis_frame(axis_name):
|
||||
frames = thread_local_state.trace_state.axis_env
|
||||
for frame in reversed(frames):
|
||||
if frame.name == axis_name:
|
||||
return frame
|
||||
else:
|
||||
raise NameError(f"Unbound axis name: {axis_name}.\n"
|
||||
f"The currently bound axes are: {','.join(f.name for f in frames)}")
|
||||
trace_state.initial_style = prev
|
||||
|
@ -69,11 +69,7 @@ def _memoize(thunk):
|
||||
return memoized
|
||||
|
||||
def _initial_style_jaxpr(fun, in_avals):
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
|
||||
bottom=True, stage_out=False)
|
||||
assert not any(isinstance(c, core.Tracer) for c in consts)
|
||||
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
return typed_jaxpr
|
||||
|
||||
@ -276,12 +272,8 @@ class CustomJVPCallPrimitive(core.CallPrimitive):
|
||||
fun, self, top_trace and top_trace.level, ())
|
||||
jvp, env_trace_todo2 = core.process_env_traces(
|
||||
jvp, self, top_trace and top_trace.level, ())
|
||||
if top_trace is None:
|
||||
with core.new_sublevel():
|
||||
outs = self.impl(fun, jvp, *args)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
|
||||
tracers = map(top_trace.full_raise, args) # type: ignore
|
||||
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
|
||||
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
|
||||
if env_trace_todo:
|
||||
raise core.UnexpectedTracerError
|
||||
@ -602,13 +594,16 @@ xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
|
||||
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global _initial_style_jaxpr
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global _initial_style_jaxpr, custom_jvp_call
|
||||
|
||||
def _initial_style_jaxpr(fun, in_avals):
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
|
||||
bottom=True, stage_out=False) # type: ignore
|
||||
assert not any(isinstance(c, core.Tracer) for c in consts)
|
||||
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
return typed_jaxpr
|
||||
|
||||
@ -619,10 +614,15 @@ def omnistaging_enabler() -> None:
|
||||
fun, self, top_trace and top_trace.level, ())
|
||||
jvp, env_trace_todo2 = core.process_env_traces(
|
||||
jvp, self, top_trace and top_trace.level, ())
|
||||
tracers = map(top_trace.full_raise, args) # type: ignore
|
||||
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
|
||||
if top_trace is None:
|
||||
with core.new_sublevel():
|
||||
outs = self.impl(fun, jvp, *args)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
|
||||
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
|
||||
if env_trace_todo:
|
||||
raise core.UnexpectedTracerError
|
||||
return map(core.full_lower, outs)
|
||||
CustomJVPCallPrimitive.bind = bind # type: ignore
|
||||
custom_jvp_call = custom_jvp_call_p.bind
|
||||
|
@ -237,10 +237,6 @@ deflinear(lax.reduce_window_sum_p)
|
||||
deflinear(lax_fft.fft_p)
|
||||
deflinear(xla.device_put_p)
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
try: deflinear(lax.tie_in_p)
|
||||
except AttributeError: pass
|
||||
|
||||
def _cumulative_jet_rule(primals_in, series_in, *, axis: int,
|
||||
prefix_scan: Callable):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
@ -523,8 +519,8 @@ def _gen_reduce_choose_taylor_rule(chooser_fun):
|
||||
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
|
||||
return primal_out, series_out
|
||||
return chooser_taylor_rule
|
||||
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax.reduce_max_p.bind)
|
||||
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax.reduce_min_p.bind)
|
||||
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax._reduce_max)
|
||||
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax._reduce_min)
|
||||
|
||||
def _abs_taylor_rule(x, series_in, **params):
|
||||
x, = x
|
||||
@ -584,3 +580,6 @@ def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr,
|
||||
del jvp_jaxpr_thunk
|
||||
return jet(core.jaxpr_as_fun(fun_jaxpr), primals_in, series_in)
|
||||
jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule
|
||||
|
||||
|
||||
deflinear(lax.tie_in_p)
|
||||
|
@ -51,9 +51,9 @@ def closure_convert(fun, in_tree, in_avals):
|
||||
else:
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
with core.initial_style_staging():
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
|
||||
wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore
|
||||
out_tree = out_tree()
|
||||
|
||||
# We only want to closure convert for constants with respect to which we're
|
||||
|
@ -504,7 +504,7 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_
|
||||
else:
|
||||
primal_jaxpr, tangent_jaxpr, out_unknowns = \
|
||||
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
|
||||
trace_type=None)
|
||||
trace_type=None) # type: ignore
|
||||
|
||||
def do_transpose(primals_in, cotangents_in):
|
||||
# NOTE: This is passing in undefined primals in place of tangent arguments, but it
|
||||
@ -555,9 +555,7 @@ def jvp_jaxpr(jaxpr, nonzeros, instantiate):
|
||||
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
|
||||
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
|
||||
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
|
||||
pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
|
||||
jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
|
||||
avals_out, _ = unzip2(pvals_out)
|
||||
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
||||
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
|
||||
return jaxpr_out, out_nonzeros()
|
||||
|
||||
@ -644,7 +642,7 @@ def defvjp_all(prim, custom_vjp):
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
|
||||
else:
|
||||
with core.initial_style_staging():
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
|
||||
instantiate=True)
|
||||
tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
|
||||
@ -680,9 +678,8 @@ def defvjp2(prim, *vjps):
|
||||
defvjp_all(prim, vjpmaker)
|
||||
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global jvp_jaxpr
|
||||
|
||||
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
|
||||
@ -691,6 +688,8 @@ def omnistaging_enabler() -> None:
|
||||
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
|
||||
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
|
||||
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
|
||||
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
||||
pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
|
||||
jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
|
||||
avals_out, _ = unzip2(pvals_out)
|
||||
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
|
||||
return jaxpr_out, out_nonzeros()
|
||||
|
@ -375,10 +375,8 @@ def batch_jaxpr(jaxpr, size, batched, instantiate):
|
||||
f, batched_out = batched_traceable(f, size, batched, instantiate)
|
||||
avals_in = [_promote_aval_rank(size, a) if b else a
|
||||
for a, b in zip(jaxpr.in_avals, batched)]
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
|
||||
jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
|
||||
avals_out, _ = unzip2(pvals_out)
|
||||
jaxpr_out = core.TypedJaxpr(jaxpr_out, consts_out, avals_in, avals_out)
|
||||
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
|
||||
return jaxpr_out, batched_out()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
@ -427,8 +425,8 @@ def _merge_bdims(x, y):
|
||||
return x # arbitrary
|
||||
|
||||
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global batch_jaxpr
|
||||
|
||||
def batch_jaxpr(jaxpr, size, batched, instantiate):
|
||||
@ -436,8 +434,10 @@ def omnistaging_enabler() -> None:
|
||||
f, batched_out = batched_traceable(f, size, batched, instantiate)
|
||||
avals_in = [_promote_aval_rank(size, a) if b else a
|
||||
for a, b in zip(jaxpr.in_avals, batched)]
|
||||
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
|
||||
jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
|
||||
avals_out, _ = unzip2(pvals_out)
|
||||
jaxpr_out = core.TypedJaxpr(jaxpr_out, consts_out, avals_in, avals_out)
|
||||
return jaxpr_out, batched_out()
|
||||
|
||||
|
||||
|
@ -217,9 +217,9 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
|
||||
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
|
||||
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
|
||||
else:
|
||||
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
|
||||
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( # type: ignore
|
||||
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
|
||||
instantiate=True, stage_out=False)
|
||||
instantiate=True, stage_out=False) # type: ignore
|
||||
assert not ivjp_jaxpr.constvars # That might happen some time, but don't bother until then
|
||||
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
ivjp_jaxpr = core.TypedJaxpr(ivjp_jaxpr, [], in_avals, out_avals)
|
||||
@ -234,8 +234,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
|
||||
ivjp_jaxpr, unknowns, instantiate=False) # type:ignore
|
||||
else:
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(
|
||||
ivjp_jaxpr, unknowns, instantiate=False, trace_type=None)
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
|
||||
ivjp_jaxpr, unknowns, instantiate=False, trace_type=None) # type: ignore
|
||||
unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
|
||||
# Make sure we're able to compute all cotangents. We don't really care if we
|
||||
# can reconstruct or primals or not, although failure to do so might result in
|
||||
|
@ -17,7 +17,7 @@ from collections import namedtuple
|
||||
import contextlib
|
||||
import functools
|
||||
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
|
||||
List, Union, cast, Type, Set)
|
||||
List, Union, cast, Type, no_type_check)
|
||||
from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
@ -165,8 +165,8 @@ class JaxprTrace(Trace):
|
||||
# We use process_call to handle both call and map primitives.
|
||||
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
|
||||
if not config.omnistaging_enabled:
|
||||
if (self.main.trace_type is StagingJaxprTrace
|
||||
and primitive in staged_out_calls):
|
||||
if (self.main.trace_type is StagingJaxprTrace # type: ignore
|
||||
and primitive in staged_out_calls): # type: ignore
|
||||
tracers = map(self.instantiate_const_abstracted, tracers)
|
||||
|
||||
if primitive in call_partial_eval_rules:
|
||||
@ -284,21 +284,6 @@ class JaxprTrace(Trace):
|
||||
env_tracers = map(self.full_raise, env)
|
||||
return jaxpr, out_pvs, consts, env_tracers
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
||||
# See comment at top of `JaxprTrace`. This method should be reachable
|
||||
# only when we stage out, and in that case we drop the custom differentiation
|
||||
# rules, because we do not need them.
|
||||
assert self.main.trace_type is StagingJaxprTrace
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
||||
# See comment in the above process_custom_jvp_call method.
|
||||
assert self.main.trace_type is StagingJaxprTrace
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
|
||||
class StagingJaxprTrace(JaxprTrace): pass
|
||||
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
|
||||
@ -320,7 +305,7 @@ def abstract_eval_fun(fun, *avals, **params):
|
||||
else:
|
||||
pvals_in = [PartialVal.unknown(a) for a in avals]
|
||||
_, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
|
||||
instantiate=True, stage_out=True)
|
||||
instantiate=True, stage_out=True) # type: ignore
|
||||
avals_out, _ = unzip2(pvals_out)
|
||||
for aval_out in avals_out:
|
||||
assert isinstance(aval_out, AbstractValue) # instantiate=True
|
||||
@ -370,10 +355,8 @@ class JaxprTracer(Tracer):
|
||||
|
||||
# TODO(necula): this should return a TypedJaxpr
|
||||
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
instantiate: Union[bool, Sequence[bool]] = False,
|
||||
stage_out=False, bottom=False,
|
||||
trace_type: Optional[Type[Trace]] = None) \
|
||||
-> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
||||
instantiate: Union[bool, Sequence[bool]] = False,
|
||||
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
||||
"""Traces a function into a Jaxpr, given PartialVals for inputs.
|
||||
|
||||
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
|
||||
@ -416,8 +399,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
|
||||
consts = [3, 6] # values for `ka` and `kb` constvars
|
||||
"""
|
||||
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
|
||||
with core.new_main(trace_type, bottom=bottom) as main:
|
||||
with core.new_main(JaxprTrace) as main:
|
||||
fun = trace_to_subjaxpr(fun, main, instantiate)
|
||||
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
||||
assert not env
|
||||
@ -574,7 +556,6 @@ def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, Abst
|
||||
|
||||
def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
|
||||
instantiate: Union[bool, Sequence[bool]],
|
||||
trace_type: Optional[Type[core.Trace]]
|
||||
) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]:
|
||||
"""Specializes a Jaxpr given an indication of which inputs are known.
|
||||
|
||||
@ -616,19 +597,16 @@ def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
|
||||
cell = []
|
||||
def fun(*vals):
|
||||
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
|
||||
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
|
||||
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
|
||||
trace_type=trace_type)
|
||||
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
|
||||
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
|
||||
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
|
||||
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
|
||||
return out_consts_2 + consts_2
|
||||
|
||||
# The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
|
||||
# the avals, and it will never actually reach any primitives, because the `fun` above will
|
||||
# execute the jaxpr with the right avals (it reconstructs `pvals` inside).
|
||||
pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
|
||||
for aval, uk in zip(jaxpr.in_avals, unknowns)]
|
||||
jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
|
||||
# For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
|
||||
# known inputs.
|
||||
in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
|
||||
jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
|
||||
(out_pvs_2, jaxpr_2, num_res), = cell
|
||||
assert len(jaxpr_2.constvars) == num_res
|
||||
|
||||
@ -647,11 +625,10 @@ def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
|
||||
in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
|
||||
out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
|
||||
# out_avals_1 and in_avals_2 need the residuals added
|
||||
out_pvs, _ = unzip2(out_pvals)
|
||||
res_avals = out_pvs[len(jaxpr.out_avals):]
|
||||
res_avals = out_avals[len(jaxpr.out_avals):]
|
||||
assert len(res_avals) == num_res
|
||||
out_avals_1 = out_avals_1 + res_avals
|
||||
in_avals_2 = in_avals_2 + res_avals
|
||||
out_avals_1 = [*out_avals_1, *res_avals]
|
||||
in_avals_2 = [*in_avals_2, *res_avals]
|
||||
|
||||
typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1)
|
||||
typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2)
|
||||
@ -685,7 +662,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
|
||||
f, in_pvals, partial(remat_call_p.bind, **params))
|
||||
else:
|
||||
with core.initial_style_staging():
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
|
||||
f, in_pvals, partial(remat_call_p.bind, **params))
|
||||
|
||||
@ -710,7 +687,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
typed_jaxpr, in_unknowns, instantiate=False) # type: ignore
|
||||
else:
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
|
||||
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)
|
||||
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type) # type: ignore
|
||||
out_knowns = [not b for b in out_unknowns]
|
||||
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
|
||||
|
||||
@ -839,26 +816,19 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
def _contents(self):
|
||||
return ()
|
||||
|
||||
def __bool__(self):
|
||||
self._concretization_error('__bool__')
|
||||
|
||||
def _concretization_error(self, name):
|
||||
msgs = self._progenitor_messages()
|
||||
msg = (f"Abstract tracer value passed to {name} for which a concrete value "
|
||||
"is required.\n"
|
||||
f"While tracing the function {self._trace.main.source_info}, "
|
||||
"this tracer originated from using JAX operations on these lines:"
|
||||
"\n\n" + "\n\n".join(msgs) + "\n\n"
|
||||
"See the above traceback for where this tracer was encountered.")
|
||||
raise core.ConcretizationTypeError(msg)
|
||||
|
||||
def _progenitor_messages(self):
|
||||
def _origin_msg(self):
|
||||
progenitor_eqns = self._trace.frame.find_progenitors(self)
|
||||
# TODO mention which jit this tracer belongs to
|
||||
msgs = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
|
||||
f" from line {source_info_util.summarize(eqn.source_info)}"
|
||||
for eqn in progenitor_eqns]
|
||||
return msgs
|
||||
if msgs:
|
||||
origin = (f"While tracing the function {self._trace.main.source_info}, "
|
||||
"this value became a tracer due to JAX operations on these lines:"
|
||||
"\n\n" + "\n\n".join(msgs))
|
||||
else:
|
||||
origin = ("The error occured while tracing the function "
|
||||
f"{self._trace.main.source_info}.")
|
||||
return origin
|
||||
|
||||
class JaxprStackFrame:
|
||||
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
|
||||
@ -883,7 +853,10 @@ class JaxprStackFrame:
|
||||
return jaxpr, out_avals, constvals
|
||||
|
||||
def find_progenitors(self, tracer):
|
||||
active_vars = {self.tracer_to_var[id(tracer)]}
|
||||
var = self.tracer_to_var.get(id(tracer))
|
||||
if not var:
|
||||
return []
|
||||
active_vars = {var}
|
||||
for eqn in self.eqns[::-1]:
|
||||
produced = set(eqn.outvars) & active_vars
|
||||
if produced:
|
||||
@ -1078,19 +1051,60 @@ def fun_sourceinfo(fun):
|
||||
return "<unknown>"
|
||||
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global trace_to_jaxpr, partial_eval_jaxpr
|
||||
|
||||
del JaxprTrace.process_custom_jvp_call
|
||||
del JaxprTrace.process_custom_vjp_call
|
||||
@config.register_omnistaging_disabler
|
||||
@no_type_check
|
||||
def omnistaging_disabler() -> None:
|
||||
global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace
|
||||
|
||||
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
instantiate: Union[bool, Sequence[bool]] = False,
|
||||
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
||||
with core.new_main(JaxprTrace) as main:
|
||||
instantiate: Union[bool, Sequence[bool]] = False,
|
||||
stage_out=False, bottom=False,
|
||||
trace_type: Optional[Type[Trace]] = None,
|
||||
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
||||
"""Traces a function into a Jaxpr, given PartialVals for inputs.
|
||||
|
||||
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
|
||||
computation that depends on unknown inputs. The `out_pvals` are the PartialVal
|
||||
for the outputs. The intermediate values that depend only on known inputs and
|
||||
are needed to compute the output of `jaxpr` are in `consts` and are passed in
|
||||
as the constvars of the `jaxpr`. The handling of the known outputs depends on
|
||||
`instantiate`.
|
||||
|
||||
For example, given `fun` defined as follows::
|
||||
|
||||
def fun(ki, ui): # ki will be a known input in this example
|
||||
ka = ki + 2
|
||||
kb = ka + 3
|
||||
return (kb, ui + ka)
|
||||
|
||||
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
|
||||
computation that depends on unknown inputs is `ui + ka` and will be the only
|
||||
computation in the body of the `jaxpr`. This computation depends on the known
|
||||
intermediate value `ka`, which will be computed statically. Currently, such
|
||||
constants are either embedded in the Jaxpr if they are scalars, or passed as a
|
||||
constvar to `jaxpr`, and then the value of the actual constant will be in
|
||||
`consts`:
|
||||
|
||||
When `instantiate=False` we get::
|
||||
|
||||
jaxpr =
|
||||
{ lambda ka ; ki ui.
|
||||
let c = add ui ka
|
||||
in (*, c) } # known outputs are `*`
|
||||
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
|
||||
consts = [3] # the constant for `ka`
|
||||
|
||||
When `instantiate=True` we get::
|
||||
|
||||
jaxpr =
|
||||
{ lambda ka kb ; ki ui.
|
||||
let c = add ui ka
|
||||
in (kb, c) } # known output are explicit
|
||||
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
|
||||
consts = [3, 6] # values for `ka` and `kb` constvars
|
||||
"""
|
||||
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
|
||||
with core.new_main(trace_type, bottom=bottom) as main:
|
||||
fun = trace_to_subjaxpr(fun, main, instantiate)
|
||||
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
||||
assert not env
|
||||
@ -1100,6 +1114,7 @@ def omnistaging_enabler() -> None:
|
||||
|
||||
def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
|
||||
instantiate: Union[bool, Sequence[bool]],
|
||||
trace_type: Optional[Type[core.Trace]]
|
||||
) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]:
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
|
||||
@ -1107,15 +1122,18 @@ def omnistaging_enabler() -> None:
|
||||
def fun(*vals):
|
||||
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
|
||||
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
|
||||
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
|
||||
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
|
||||
trace_type=trace_type)
|
||||
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
|
||||
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
|
||||
return out_consts_2 + consts_2
|
||||
|
||||
# For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
|
||||
# known inputs.
|
||||
in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
|
||||
jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
|
||||
# The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
|
||||
# the avals, and it will never actually reach any primitives, because the `fun` above will
|
||||
# execute the jaxpr with the right avals (it reconstructs `pvals` inside).
|
||||
pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
|
||||
for aval, uk in zip(jaxpr.in_avals, unknowns)]
|
||||
jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
|
||||
(out_pvs_2, jaxpr_2, num_res), = cell
|
||||
assert len(jaxpr_2.constvars) == num_res
|
||||
|
||||
@ -1134,13 +1152,32 @@ def omnistaging_enabler() -> None:
|
||||
in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
|
||||
out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
|
||||
# out_avals_1 and in_avals_2 need the residuals added
|
||||
res_avals = out_avals[len(jaxpr.out_avals):]
|
||||
out_pvs, _ = unzip2(out_pvals)
|
||||
res_avals = out_pvs[len(jaxpr.out_avals):]
|
||||
assert len(res_avals) == num_res
|
||||
out_avals_1 = [*out_avals_1, *res_avals]
|
||||
in_avals_2 = [*in_avals_2, *res_avals]
|
||||
out_avals_1 = out_avals_1 + res_avals
|
||||
in_avals_2 = in_avals_2 + res_avals
|
||||
|
||||
typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1)
|
||||
typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2)
|
||||
return typed_jaxpr_1, typed_jaxpr_2, uk_out
|
||||
|
||||
staged_out_calls: Set[core.Primitive] = set()
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
||||
# See comment at top of `JaxprTrace`. This method should be reachable
|
||||
# only when we stage out, and in that case we drop the custom differentiation
|
||||
# rules, because we do not need them.
|
||||
if not config.omnistaging_enabled:
|
||||
assert self.main.trace_type is StagingJaxprTrace
|
||||
return fun.call_wrapped(*tracers)
|
||||
JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
||||
# See comment in the above process_custom_jvp_call method.
|
||||
if not config.omnistaging_enabled:
|
||||
assert self.main.trace_type is StagingJaxprTrace
|
||||
return fun.call_wrapped(*tracers)
|
||||
JaxprTrace.process_custom_vjp_call = process_custom_vjp_call
|
||||
|
||||
staged_out_calls = set()
|
||||
|
||||
class StagingJaxprTrace(JaxprTrace): pass
|
||||
|
@ -34,7 +34,7 @@ import itertools as it
|
||||
import operator as op
|
||||
import threading
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
|
||||
Type, Union)
|
||||
Type, Union, no_type_check)
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
@ -330,91 +330,6 @@ pxla_result_handlers[ShapedArray] = array_result_handler
|
||||
pxla_result_handlers[ConcreteArray] = array_result_handler
|
||||
|
||||
|
||||
### applying parallel primitives in op-by-op Python dispatch
|
||||
|
||||
# There are at least two cases where we might want to evaluate a parallel
|
||||
# primitive dispatched from Python, rather than being staged out:
|
||||
# 1. axis_size = psum(1, 'axis_name'),
|
||||
# 2. to enable an implicit outermost pmap-like context for multi-host
|
||||
# multi-controller SPMD programs.
|
||||
# In each case, we can't rely on any data dependence on a pmap trace; instead we
|
||||
# need some dynamic context, basically modeling the axis name environment stack.
|
||||
# To handle the former case, we don't need to communicate at all; we instead
|
||||
# have a table of parallel_pure_rules. To handle the latter case, we'll have a
|
||||
# globally-scoped root environment frame and compile and execute a single-op
|
||||
# XLA collective.
|
||||
|
||||
class DynamicAxisEnvFrame(object):
|
||||
__slots__ = ["name", "pmap_trace", "hard_size"]
|
||||
def __init__(self, name, pmap_trace, hard_size):
|
||||
self.name = name
|
||||
self.pmap_trace = pmap_trace
|
||||
self.hard_size = hard_size
|
||||
|
||||
class DynamicAxisEnv(list):
|
||||
def __contains__(self, axis_name):
|
||||
return axis_name in (frame.name for frame in self)
|
||||
|
||||
def __getitem__(self, axis_name):
|
||||
if axis_name not in self:
|
||||
raise NameError("unbound axis name: {}".format(axis_name))
|
||||
for frame in reversed(self):
|
||||
if frame.name == axis_name:
|
||||
return frame
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return tuple(frame.hard_size for frame in self)
|
||||
|
||||
@property
|
||||
def nreps(self):
|
||||
return prod(frame.hard_size for frame in self)
|
||||
|
||||
class _ThreadLocalState(threading.local):
|
||||
def __init__(self):
|
||||
self.dynamic_axis_env = DynamicAxisEnv()
|
||||
|
||||
_thread_local_state = _ThreadLocalState()
|
||||
|
||||
@contextmanager
|
||||
def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
|
||||
dynamic_axis_env = _thread_local_state.dynamic_axis_env
|
||||
dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
dynamic_axis_env.pop()
|
||||
|
||||
def unmapped_device_count(backend=None):
|
||||
dynamic_axis_env = _thread_local_state.dynamic_axis_env
|
||||
mapped = prod(frame.hard_size for frame in dynamic_axis_env)
|
||||
unmapped, ragged = divmod(xb.device_count(backend), mapped)
|
||||
assert not ragged and unmapped > 0
|
||||
return unmapped
|
||||
|
||||
def apply_parallel_primitive(prim, *args, **params):
|
||||
# This is the op-by-op version of applying a collective primitive, like a psum
|
||||
# that doesn't have a data dependence on the argument of a pmap function. In
|
||||
# particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
|
||||
# look up information in the dynamic axis env.
|
||||
dynamic_axis_env = _thread_local_state.dynamic_axis_env
|
||||
axis_name = params.pop('axis_name')
|
||||
axis_index_groups = params.pop('axis_index_groups')
|
||||
if axis_index_groups is not None:
|
||||
shape = (len(axis_index_groups[0]),)
|
||||
else:
|
||||
logical_size = lambda frame: frame.hard_size
|
||||
if isinstance(axis_name, (list, tuple)):
|
||||
shape = tuple(logical_size(dynamic_axis_env[name]) for name in axis_name)
|
||||
else:
|
||||
shape = (logical_size(dynamic_axis_env[axis_name]),)
|
||||
return parallel_pure_rules[prim](*args, shape=shape, **params)
|
||||
|
||||
parallel_pure_rules: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
|
||||
### lazy device-memory persistence and result handling
|
||||
|
||||
class ShardedDeviceArray(xla.DeviceArray):
|
||||
@ -636,7 +551,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
|
||||
# We add a dummy first invar, to carry the trace details to `dynamic_fun`
|
||||
pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True)
|
||||
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
|
||||
jaxpr.invars = jaxpr.invars[1:] # ignore dummy
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
@ -795,7 +710,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
|
||||
handle_outs = avals_to_results_handler( # type: ignore
|
||||
axis_size, num_local_replicas, num_partitions, out_parts, out_avals)
|
||||
else:
|
||||
handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas,
|
||||
handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas, # type: ignore
|
||||
num_partitions, out_parts,
|
||||
out_pvals, compiled.local_devices(),
|
||||
backend)
|
||||
@ -884,17 +799,20 @@ def get_num_partitions(*partitions):
|
||||
class ResultToPopulate: pass
|
||||
result_to_populate = ResultToPopulate()
|
||||
|
||||
def _pvals_to_results_handler(
|
||||
size, nrep, npart,
|
||||
out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
|
||||
out_pvals, devices, backend):
|
||||
nouts = len(out_pvals)
|
||||
def avals_to_results_handler(size, nrep, npart, out_parts, out_avals):
|
||||
nouts = len(out_avals)
|
||||
if out_parts is None:
|
||||
out_parts = (None,) * len(out_pvals)
|
||||
handlers = [
|
||||
_pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend)
|
||||
for pval, parts in safe_zip(out_pvals, out_parts) # type: ignore
|
||||
]
|
||||
out_parts = (None,) * len(out_avals)
|
||||
|
||||
# TODO(mattjj,skyewm): can probably clean up this logic
|
||||
out_specs = [_pmap_sharding_spec(nrep, size, npart, parts, aval, True)
|
||||
if aval is not core.abstract_unit else None
|
||||
for parts, aval in zip(out_parts, out_avals)]
|
||||
out_indices = [spec_to_indices(core.unmapped_aval(size, aval).shape, spec)
|
||||
if aval is not core.abstract_unit else None
|
||||
for aval, spec in zip(out_avals, out_specs)] # pytype: disable=attribute-error
|
||||
handlers = [aval_to_result_handler(spec, idcs, core.unmapped_aval(size, aval))
|
||||
for spec, idcs, aval in zip(out_specs, out_indices, out_avals)]
|
||||
|
||||
def handler(out_bufs):
|
||||
assert nrep * npart == len(out_bufs)
|
||||
@ -903,7 +821,7 @@ def _pvals_to_results_handler(
|
||||
for i, buf in enumerate(tuple_buf):
|
||||
buffers[i][r] = buf
|
||||
assert not any(buf is result_to_populate for bufs in buffers
|
||||
for buf in bufs)
|
||||
for buf in bufs)
|
||||
return [h(bufs) for h, bufs in zip(handlers, buffers)]
|
||||
return handler
|
||||
|
||||
@ -947,37 +865,6 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
|
||||
device_buffers = [xla.device_put(val, d) for d in devices]
|
||||
return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)
|
||||
|
||||
def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend):
|
||||
if devices:
|
||||
assert all(d.host_id == xb.host_id(backend) for d in devices)
|
||||
pv, const = pval
|
||||
if pv is None:
|
||||
if nrep is None:
|
||||
nrep = axis_size
|
||||
# If 'const' is a ShardedDeviceArray, it must have come from a pmap nested
|
||||
# inside the one we're currently evaluating, and we should replicate
|
||||
# 'const' across the total number of devices needed. We don't necessarily
|
||||
# know the nested pmap's axis_size (e.g. the jaxpr for
|
||||
# pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the
|
||||
# axis size of the output 'const'.
|
||||
# TODO: we might be doing unnecessary device transfers in the inner pmap.
|
||||
if isinstance(const, ShardedDeviceArray):
|
||||
nrep *= len(const)
|
||||
|
||||
bcast_const = (core.unit if const is core.unit
|
||||
else replicate(const, axis_size, nrep, devices, backend)) # type: ignore
|
||||
return lambda _: bcast_const # type: ignore
|
||||
else:
|
||||
if pv is not core.abstract_unit:
|
||||
unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype)
|
||||
sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv,
|
||||
True)
|
||||
indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
|
||||
else:
|
||||
sharding_spec = indices = None
|
||||
unsharded_aval = pv
|
||||
return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
|
||||
|
||||
def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped):
|
||||
"""Sharding spec for arguments or results of a pmap.
|
||||
Args:
|
||||
@ -1014,7 +901,6 @@ def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped):
|
||||
replication_factors=[(replication_factor * axis_size, 0)] +
|
||||
shard_spec.replication_factors)
|
||||
|
||||
|
||||
def partitioned_sharding_spec(num_partitions: int,
|
||||
partitions: Optional[Sequence[int]], aval):
|
||||
if partitions is None:
|
||||
@ -1043,7 +929,6 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):
|
||||
xla_pmap_p = core.MapPrimitive('xla_pmap')
|
||||
xla_pmap = xla_pmap_p.bind
|
||||
xla_pmap_p.def_impl(xla_pmap_impl)
|
||||
pe.staged_out_calls.add(xla_pmap_p)
|
||||
|
||||
# Set param update handlers to update `donated_invars` just like xla_call_p
|
||||
pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
|
||||
@ -1279,41 +1164,35 @@ soft_pmap_p.def_impl(soft_pmap_impl)
|
||||
|
||||
soft_pmap_rules: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
def deleted_with_omnistaging(*a, **k):
|
||||
assert False, "Should be deleted"
|
||||
|
||||
@contextmanager
|
||||
def maybe_extend_axis_env(*args, **kwargs):
|
||||
yield
|
||||
with core.extend_axis_env(*args, **kwargs):
|
||||
yield
|
||||
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enable() -> None:
|
||||
@config.register_omnistaging_disabler
|
||||
@no_type_check
|
||||
def omnistaging_disabler() -> None:
|
||||
global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
|
||||
_thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
|
||||
apply_parallel_primitive, parallel_pure_rules, \
|
||||
_pvals_to_results_handler, _pval_to_result_handler, replicate, \
|
||||
avals_to_results_handler, maybe_extend_axis_env
|
||||
del DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
|
||||
_thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
|
||||
_pvals_to_results_handler, _pval_to_result_handler, replicate
|
||||
avals_to_results_handler, axis_index, maybe_extend_axis_env
|
||||
|
||||
apply_parallel_primitive = deleted_with_omnistaging
|
||||
parallel_pure_rules.clear()
|
||||
@contextmanager
|
||||
def maybe_extend_axis_env(*args, **kwargs):
|
||||
yield
|
||||
|
||||
def avals_to_results_handler(size, nrep, npart, out_parts, out_avals):
|
||||
nouts = len(out_avals)
|
||||
def _pvals_to_results_handler(
|
||||
size, nrep, npart,
|
||||
out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
|
||||
out_pvals, devices, backend):
|
||||
nouts = len(out_pvals)
|
||||
if out_parts is None:
|
||||
out_parts = (None,) * len(out_avals)
|
||||
|
||||
# TODO(mattjj,skyewm): can probably clean up this logic
|
||||
out_specs = [_pmap_sharding_spec(nrep, size, npart, parts, aval, True)
|
||||
if aval is not core.abstract_unit else None
|
||||
for parts, aval in zip(out_parts, out_avals)]
|
||||
out_indices = [spec_to_indices(core.unmapped_aval(size, aval).shape, spec)
|
||||
if aval is not core.abstract_unit else None
|
||||
for aval, spec in zip(out_avals, out_specs)] # pytype: disable=attribute-error
|
||||
handlers = [aval_to_result_handler(spec, idcs, core.unmapped_aval(size, aval))
|
||||
for spec, idcs, aval in zip(out_specs, out_indices, out_avals)]
|
||||
out_parts = (None,) * len(out_pvals)
|
||||
handlers = [
|
||||
_pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend)
|
||||
for pval, parts in safe_zip(out_pvals, out_parts) # type: ignore
|
||||
]
|
||||
|
||||
def handler(out_bufs):
|
||||
assert nrep * npart == len(out_bufs)
|
||||
@ -1326,7 +1205,105 @@ def omnistaging_enable() -> None:
|
||||
return [h(bufs) for h, bufs in zip(handlers, buffers)]
|
||||
return handler
|
||||
|
||||
def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend):
|
||||
if devices:
|
||||
assert all(d.host_id == xb.host_id(backend) for d in devices)
|
||||
pv, const = pval
|
||||
if pv is None:
|
||||
if nrep is None:
|
||||
nrep = axis_size
|
||||
# If 'const' is a ShardedDeviceArray, it must have come from a pmap nested
|
||||
# inside the one we're currently evaluating, and we should replicate
|
||||
# 'const' across the total number of devices needed. We don't necessarily
|
||||
# know the nested pmap's axis_size (e.g. the jaxpr for
|
||||
# pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the
|
||||
# axis size of the output 'const'.
|
||||
# TODO: we might be doing unnecessary device transfers in the inner pmap.
|
||||
if isinstance(const, ShardedDeviceArray):
|
||||
nrep *= len(const)
|
||||
|
||||
bcast_const = (core.unit if const is core.unit
|
||||
else replicate(const, axis_size, nrep, devices, backend)) # type: ignore
|
||||
return lambda _: bcast_const # type: ignore
|
||||
else:
|
||||
if pv is not core.abstract_unit:
|
||||
unsharded_aval = ShapedArray((axis_size,) + pv.shape, pv.dtype)
|
||||
sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, pv,
|
||||
True)
|
||||
indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
|
||||
else:
|
||||
sharding_spec = indices = None
|
||||
unsharded_aval = pv
|
||||
return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
|
||||
|
||||
@contextmanager
|
||||
def maybe_extend_axis_env(*args, **kwargs):
|
||||
with core.extend_axis_env(*args, **kwargs):
|
||||
def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
|
||||
dynamic_axis_env = _thread_local_state.dynamic_axis_env
|
||||
dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
dynamic_axis_env.pop()
|
||||
|
||||
def unmapped_device_count(backend=None):
|
||||
dynamic_axis_env = _thread_local_state.dynamic_axis_env
|
||||
mapped = prod(frame.hard_size for frame in dynamic_axis_env)
|
||||
unmapped, ragged = divmod(xb.device_count(backend), mapped)
|
||||
assert not ragged and unmapped > 0
|
||||
return unmapped
|
||||
|
||||
def apply_parallel_primitive(prim, *args, **params):
|
||||
# This is the op-by-op version of applying a collective primitive, like a psum
|
||||
# that doesn't have a data dependence on the argument of a pmap function. In
|
||||
# particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
|
||||
# look up information in the dynamic axis env.
|
||||
dynamic_axis_env = _thread_local_state.dynamic_axis_env
|
||||
axis_name = params.pop('axis_name')
|
||||
axis_index_groups = params.pop('axis_index_groups')
|
||||
if axis_index_groups is not None:
|
||||
shape = (len(axis_index_groups[0]),)
|
||||
else:
|
||||
logical_size = lambda frame: frame.hard_size
|
||||
if isinstance(axis_name, (list, tuple)):
|
||||
shape = tuple(logical_size(dynamic_axis_env[name]) for name in axis_name)
|
||||
else:
|
||||
shape = (logical_size(dynamic_axis_env[axis_name]),)
|
||||
return parallel_pure_rules[prim](*args, shape=shape, **params)
|
||||
|
||||
pe.staged_out_calls.add(xla_pmap_p) # type: ignore
|
||||
|
||||
parallel_pure_rules = {} # type: ignore
|
||||
|
||||
class DynamicAxisEnvFrame(object):
|
||||
__slots__ = ["name", "pmap_trace", "hard_size"]
|
||||
def __init__(self, name, pmap_trace, hard_size):
|
||||
self.name = name
|
||||
self.pmap_trace = pmap_trace
|
||||
self.hard_size = hard_size
|
||||
|
||||
class DynamicAxisEnv(list):
|
||||
def __contains__(self, axis_name):
|
||||
return axis_name in (frame.name for frame in self)
|
||||
|
||||
def __getitem__(self, axis_name):
|
||||
if axis_name not in self:
|
||||
raise NameError("unbound axis name: {}".format(axis_name))
|
||||
for frame in reversed(self):
|
||||
if frame.name == axis_name:
|
||||
return frame
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return tuple(frame.hard_size for frame in self)
|
||||
|
||||
@property
|
||||
def nreps(self):
|
||||
return prod(frame.hard_size for frame in self)
|
||||
|
||||
class _ThreadLocalState(threading.local):
|
||||
def __init__(self):
|
||||
self.dynamic_axis_env = DynamicAxisEnv()
|
||||
|
||||
_thread_local_state = _ThreadLocalState()
|
||||
|
@ -41,10 +41,11 @@ def _map(f, *xs):
|
||||
class ResultToPopulate: pass
|
||||
result_to_populate = ResultToPopulate()
|
||||
|
||||
def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
|
||||
nouts = len(out_pvals)
|
||||
handlers = [_pval_to_result_handler(npart, parts, out_pval)
|
||||
for parts, out_pval in safe_zip(partitions, out_pvals)] # type: ignore
|
||||
|
||||
def _avals_to_results_handler(nrep, npart, partitions, out_avals):
|
||||
nouts = len(out_avals)
|
||||
handlers = [_aval_to_result_handler(npart, parts, out_aval)
|
||||
for parts, out_aval in safe_zip(partitions, out_avals)]
|
||||
|
||||
def handler(out_bufs):
|
||||
assert nrep * npart == len(out_bufs)
|
||||
@ -53,23 +54,18 @@ def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
|
||||
for i, buf in enumerate(tuple_buf):
|
||||
buffers[i][r] = buf
|
||||
assert not any(buf is result_to_populate for bufs in buffers
|
||||
for buf in bufs)
|
||||
for buf in bufs)
|
||||
return [h(bufs) for h, bufs in zip(handlers, buffers)]
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _pval_to_result_handler(npart, parts, pval):
|
||||
pv, const = pval
|
||||
if pv is None:
|
||||
raise NotImplementedError # TODO(skye): handle constant outputs
|
||||
def _aval_to_result_handler(npart, parts, aval):
|
||||
if aval is not core.abstract_unit:
|
||||
spec = pxla.partitioned_sharding_spec(npart, parts, aval)
|
||||
indices = pxla.spec_to_indices(aval.shape, spec)
|
||||
else:
|
||||
if pv is not core.abstract_unit:
|
||||
spec = pxla.partitioned_sharding_spec(npart, parts, pv)
|
||||
indices = pxla.spec_to_indices(pv.shape, spec)
|
||||
else:
|
||||
spec = indices = None
|
||||
return pxla.aval_to_result_handler(spec, indices, pv)
|
||||
spec = indices = None
|
||||
return pxla.aval_to_result_handler(spec, indices, aval)
|
||||
|
||||
|
||||
@lu.cache
|
||||
@ -84,7 +80,8 @@ def _sharded_callable(
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
|
||||
else:
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in abstract_args]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=False, bottom=True)
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore
|
||||
instantiate=False, bottom=True) # type: ignore
|
||||
|
||||
# TODO(skye): add tests for equationless jaxpr cases
|
||||
if not jaxpr.eqns and all(outvar.aval is core.abstract_unit
|
||||
@ -138,7 +135,7 @@ def _sharded_callable(
|
||||
handle_outs = _avals_to_results_handler(nrep, num_partitions, out_parts, # type: ignore
|
||||
out_avals)
|
||||
else:
|
||||
handle_outs = _pvals_to_results_handler(nrep, num_partitions, out_parts,
|
||||
handle_outs = _pvals_to_results_handler(nrep, num_partitions, out_parts, # type: ignore
|
||||
out_pvals)
|
||||
return partial(_execute_spatially_partitioned, compiled, handle_args,
|
||||
handle_outs)
|
||||
@ -354,16 +351,14 @@ def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
|
||||
return sharding_constraint_p.bind(x, partitions=partitions)
|
||||
|
||||
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global _avals_to_results_handler, _aval_to_result_handler, \
|
||||
_pvals_to_results_handler, _pval_to_result_handler
|
||||
del _pvals_to_results_handler, _pval_to_result_handler
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global _pvals_to_results_handler, _pval_to_result_handler
|
||||
|
||||
def _avals_to_results_handler(nrep, npart, partitions, out_avals):
|
||||
nouts = len(out_avals)
|
||||
handlers = [_aval_to_result_handler(npart, parts, out_aval)
|
||||
for parts, out_aval in safe_zip(partitions, out_avals)]
|
||||
def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
|
||||
nouts = len(out_pvals)
|
||||
handlers = [_pval_to_result_handler(npart, parts, out_pval)
|
||||
for parts, out_pval in safe_zip(partitions, out_pvals)] # type: ignore
|
||||
|
||||
def handler(out_bufs):
|
||||
assert nrep * npart == len(out_bufs)
|
||||
@ -377,11 +372,14 @@ def omnistaging_enabler() -> None:
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _aval_to_result_handler(npart, parts, aval):
|
||||
if aval is not core.abstract_unit:
|
||||
spec = pxla.partitioned_sharding_spec(npart, parts, aval)
|
||||
indices = pxla.spec_to_indices(aval.shape, spec)
|
||||
def _pval_to_result_handler(npart, parts, pval):
|
||||
pv, const = pval
|
||||
if pv is None:
|
||||
raise NotImplementedError # TODO(skye): handle constant outputs
|
||||
else:
|
||||
spec = indices = None
|
||||
return pxla.aval_to_result_handler(spec, indices, aval)
|
||||
if pv is not core.abstract_unit:
|
||||
spec = pxla.partitioned_sharding_spec(npart, parts, pv)
|
||||
indices = pxla.spec_to_indices(pv.shape, spec)
|
||||
else:
|
||||
spec = indices = None
|
||||
return pxla.aval_to_result_handler(spec, indices, pv)
|
||||
|
@ -597,8 +597,8 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
|
||||
else:
|
||||
pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
|
||||
jaxpr, pvals, consts = pe.trace_to_jaxpr(
|
||||
fun, pvals, instantiate=False, stage_out=True, bottom=True)
|
||||
jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore
|
||||
fun, pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
|
||||
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
|
||||
jaxpr = apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
@ -608,7 +608,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
if config.omnistaging_enabled:
|
||||
result_handlers = tuple(aval_to_result_handler(device, a) for a in out_avals)
|
||||
else:
|
||||
result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals))
|
||||
result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals)) # type: ignore
|
||||
|
||||
# Computations that only produce constants and/or only rearrange their inputs,
|
||||
# which are often produced from partial evaluation, don't need compilation,
|
||||
@ -879,7 +879,7 @@ def lower_fun(fun, multiple_results):
|
||||
else:
|
||||
pvals = [pe.PartialVal.unknown(a) for a in avals]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
|
||||
stage_out=True)
|
||||
stage_out=True) # type: ignore
|
||||
xla_consts = _xla_consts(c, consts)
|
||||
outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args)
|
||||
if multiple_results:
|
||||
@ -905,7 +905,7 @@ def lower_fun_initial_style(fun):
|
||||
else:
|
||||
pvals = [pe.PartialVal.unknown(a) for a in avals]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(
|
||||
lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True)
|
||||
lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True) # type: ignore
|
||||
xla_consts = _xla_consts(c, consts)
|
||||
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack,
|
||||
*xla_args)
|
||||
@ -1277,19 +1277,25 @@ def _call_translation_rule(c, axis_env, in_nodes, name_stack,
|
||||
call_translations[core.call_p] = _call_translation_rule
|
||||
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
|
||||
div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes),
|
||||
dtype=np.uint32))
|
||||
mod = xb.constant(c, np.array(axis_env.sizes[-1], dtype=np.uint32))
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
parallel_translations[core.axis_index_p] = _axis_index_translation_rule # type: ignore
|
||||
|
||||
def _pval_to_result_handler(device, pval):
|
||||
pv, const = pval
|
||||
if pv is None:
|
||||
const = _device_put_impl(const, device) if device else const
|
||||
return lambda _: const
|
||||
else:
|
||||
return aval_to_result_handler(device, pv)
|
||||
|
||||
pe.staged_out_calls.add(xla_call_p)
|
||||
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global _pval_to_result_handler
|
||||
del _pval_to_result_handler
|
||||
|
||||
def _pval_to_result_handler(device, pval):
|
||||
pv, const = pval
|
||||
if pv is None:
|
||||
const = _device_put_impl(const, device) if device else const
|
||||
return lambda _: const
|
||||
else:
|
||||
return aval_to_result_handler(device, pv)
|
||||
|
||||
pe.staged_out_calls.add(xla_call_p) # type: ignore
|
||||
|
132
jax/lax/lax.py
132
jax/lax/lax.py
@ -22,6 +22,7 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from .. import core
|
||||
from .. import ad_util
|
||||
from .. import api
|
||||
@ -1373,30 +1374,8 @@ def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
|
||||
return top_k_p.bind(operand, k=k)
|
||||
|
||||
def tie_in(x: Array, y: Array) -> Array:
|
||||
"""Returns the value of ``y`` but with a fake data dependence on ``x``.
|
||||
|
||||
When staging to XLA (e.g. running under jit or pmap), values that don't depend
|
||||
on computation inputs are computed op-by-op, and folded into the XLA
|
||||
computation as constants.
|
||||
|
||||
``tie_in`` provides a way to explicitly stage values into the computation.
|
||||
When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
|
||||
is ``y``, but staged to XLA. Downstream use of the result will also be staged
|
||||
to XLA.
|
||||
|
||||
For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
|
||||
a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
|
||||
XLA as long as ``x`` is staged to XLA.
|
||||
"""
|
||||
if config.omnistaging_enabled:
|
||||
return y
|
||||
else:
|
||||
return tie_in_p.bind(x, y)
|
||||
|
||||
# def tie_in(x: Array, y: Array) -> Array:
|
||||
# """Deprecated. Ignores ``x`` and returns ``y``."""
|
||||
# return y
|
||||
|
||||
"""Deprecated. Ignores ``x`` and returns ``y``."""
|
||||
return y
|
||||
|
||||
def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array:
|
||||
"""Returns an array of `shape` filled with `fill_value`.
|
||||
@ -1502,7 +1481,13 @@ def stop_gradient(x):
|
||||
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
|
||||
array(0., dtype=float32)
|
||||
"""
|
||||
return tree_map(ad_util.stop_gradient_p.bind, x)
|
||||
def stop(x):
|
||||
if (dtypes.issubdtype(_dtype(x), np.floating) or
|
||||
dtypes.issubdtype(_dtype(x), np.complexfloating)):
|
||||
return ad_util.stop_gradient_p.bind(x)
|
||||
else:
|
||||
return x # only bind primitive on inexact dtypes, to avoid some staging
|
||||
return tree_map(stop, x)
|
||||
|
||||
|
||||
### convenience wrappers around traceables
|
||||
@ -5656,30 +5641,6 @@ xla.translations[top_k_p] = partial(standard_translate, 'top_k')
|
||||
ad.primitive_jvps[top_k_p] = _top_k_jvp
|
||||
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
|
||||
|
||||
def _tie_in_transpose_rule(t, x, y):
|
||||
if ad.is_undefined_primal(x):
|
||||
return [ad_util.Zero(x.aval), t]
|
||||
else:
|
||||
return [ad_util.Zero.from_value(x), t]
|
||||
|
||||
def _tie_in_batch_rule(batched_args, batch_dims):
|
||||
y = tie_in(*batched_args)
|
||||
_, bdim_y = batch_dims
|
||||
return y, bdim_y
|
||||
|
||||
def _tie_in_impl(x, y):
|
||||
core.check_valid_jaxtype(x)
|
||||
core.check_valid_jaxtype(y)
|
||||
return y
|
||||
|
||||
tie_in_p = Primitive('tie_in')
|
||||
tie_in_p.def_impl(_tie_in_impl)
|
||||
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
|
||||
xla.translations[tie_in_p] = lambda c, x, y: y
|
||||
ad.deflinear2(tie_in_p, _tie_in_transpose_rule)
|
||||
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
|
||||
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
|
||||
|
||||
|
||||
def _stop_gradient_jvp_rule(primals, tangents):
|
||||
# if we don't call stop_gradient here, we'd only peel off one autodiff tracer
|
||||
@ -6198,7 +6159,72 @@ def _check_user_dtype_supported(dtype, fun_name=None):
|
||||
warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
|
||||
|
||||
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
|
||||
del _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
|
||||
def _canonicalize_axis(axis, num_dims):
|
||||
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
||||
axis = operator.index(axis)
|
||||
if not -num_dims <= axis < num_dims:
|
||||
raise ValueError(
|
||||
"axis {} is out of bounds for array of dimension {}".format(
|
||||
axis, num_dims))
|
||||
if axis < 0:
|
||||
axis = axis + num_dims
|
||||
return axis
|
||||
|
||||
|
||||
tie_in_p = Primitive('tie_in')
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global tie_in
|
||||
|
||||
def tie_in(x: Array, y: Array) -> Array:
|
||||
"""Returns the value of ``y`` but with a fake data dependence on ``x``.
|
||||
|
||||
When staging to XLA (e.g. running under jit or pmap), values that don't depend
|
||||
on computation inputs are computed op-by-op, and folded into the XLA
|
||||
computation as constants.
|
||||
|
||||
``tie_in`` provides a way to explicitly stage values into the computation.
|
||||
When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
|
||||
is ``y``, but staged to XLA. Downstream use of the result will also be staged
|
||||
to XLA.
|
||||
|
||||
For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
|
||||
a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
|
||||
XLA as long as ``x`` is staged to XLA.
|
||||
"""
|
||||
if config.omnistaging_enabled:
|
||||
return y
|
||||
else:
|
||||
return tie_in_p.bind(x, y)
|
||||
|
||||
# If lax has already been imported, we need to monkey-patch the
|
||||
# lax/__init__.py import of tie_in. If not (i.e. if this is running at lax
|
||||
# module creation time) then we'll get an import error.
|
||||
try:
|
||||
jax.lax.tie_in = tie_in
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def _tie_in_transpose_rule(t, x, y):
|
||||
if ad.is_undefined_primal(x):
|
||||
return [ad_util.Zero(x.aval), t]
|
||||
else:
|
||||
return [ad_util.Zero.from_value(x), t]
|
||||
|
||||
def _tie_in_batch_rule(batched_args, batch_dims):
|
||||
y = tie_in(*batched_args)
|
||||
_, bdim_y = batch_dims
|
||||
return y, bdim_y
|
||||
|
||||
def _tie_in_impl(x, y):
|
||||
core.check_valid_jaxtype(x)
|
||||
core.check_valid_jaxtype(y)
|
||||
return y
|
||||
|
||||
tie_in_p.def_impl(_tie_in_impl)
|
||||
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
|
||||
xla.translations[tie_in_p] = lambda c, x, y: y
|
||||
ad.deflinear2(tie_in_p, _tie_in_transpose_rule)
|
||||
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
|
||||
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
|
||||
|
@ -62,22 +62,18 @@ T = TypeVar('T')
|
||||
|
||||
@cache()
|
||||
def _initial_style_untyped_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
with core.initial_style_staging():
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
|
||||
return jaxpr, out_pvals, consts, out_tree
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
||||
return jaxpr, out_avals, consts, out_tree()
|
||||
|
||||
@cache()
|
||||
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
jaxpr, out_pvals, consts, out_tree = _initial_style_untyped_jaxpr(
|
||||
fun, in_tree, in_avals)
|
||||
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
jaxpr, out_avals, consts, out_tree = \
|
||||
_initial_style_untyped_jaxpr(fun, in_tree, in_avals)
|
||||
const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
|
||||
typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
|
||||
(), const_avals + in_avals, out_avals)
|
||||
return typed_jaxpr, consts, out_tree()
|
||||
return typed_jaxpr, consts, out_tree
|
||||
|
||||
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
|
||||
in_tree, in_avals):
|
||||
@ -88,37 +84,29 @@ def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
|
||||
# for each one, it makes another that accepts *all* constants, but only uses
|
||||
# those that it needs (dropping the rest).
|
||||
|
||||
jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([
|
||||
_initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs])
|
||||
jaxprs, all_out_avals, all_consts, all_out_trees = unzip4(
|
||||
_initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs)
|
||||
|
||||
newvar = core.gensym(jaxprs, suffix='_')
|
||||
all_const_avals = tuple(
|
||||
tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
|
||||
for consts in all_consts)
|
||||
unused_const_vars = tuple(
|
||||
tuple(newvar(aval) for aval in const_avals)
|
||||
for const_avals in all_const_avals)
|
||||
all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
|
||||
for consts in all_consts]
|
||||
unused_const_vars = [[newvar(aval) for aval in const_avals]
|
||||
for const_avals in all_const_avals]
|
||||
|
||||
def pad_jaxpr_constvars(i, jaxpr):
|
||||
prefix = util.concatenate(unused_const_vars[:i])
|
||||
suffix = util.concatenate(unused_const_vars[i+1:])
|
||||
constvars = prefix + jaxpr.constvars + suffix
|
||||
constvars = [*prefix, *jaxpr.constvars, *suffix]
|
||||
return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
|
||||
const_avals = tuple(util.concatenate(all_const_avals))
|
||||
|
||||
def type_and_const_convert_jaxpr(jaxpr, out_pvals):
|
||||
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
return core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
|
||||
(), const_avals + in_avals, out_avals)
|
||||
|
||||
consts = util.concatenate(all_consts)
|
||||
const_avals = util.concatenate(all_const_avals)
|
||||
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
|
||||
typed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals)
|
||||
|
||||
return (tuple(typed_jaxprs),
|
||||
tuple(util.concatenate(all_consts)),
|
||||
tuple(out_tree() for out_tree in all_out_trees))
|
||||
typed_jaxprs = [core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
|
||||
(), [*const_avals, *in_avals], out_avals)
|
||||
for jaxpr, out_avals in zip(jaxprs, all_out_avals)]
|
||||
return typed_jaxprs, consts, all_out_trees
|
||||
|
||||
def _abstractify(x):
|
||||
return raise_to_shaped(core.get_aval(x))
|
||||
@ -1517,7 +1505,7 @@ def _prune_zeros(ts):
|
||||
|
||||
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
jaxpr, linear, unroll):
|
||||
if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace:
|
||||
if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: # type: ignore
|
||||
params = dict(reverse=reverse, length=length, num_consts=num_consts,
|
||||
num_carry=num_carry, jaxpr=jaxpr, linear=linear,
|
||||
unroll=unroll)
|
||||
@ -2450,26 +2438,29 @@ def associative_scan(fn, elems, reverse=False):
|
||||
return tree_unflatten(tree, scans)
|
||||
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global _initial_style_untyped_jaxpr, _initial_style_jaxpr, \
|
||||
_initial_style_jaxprs_with_common_consts
|
||||
|
||||
@cache()
|
||||
def _initial_style_untyped_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
||||
return jaxpr, out_avals, consts, out_tree()
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
|
||||
wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore
|
||||
return jaxpr, out_pvals, consts, out_tree
|
||||
|
||||
@cache()
|
||||
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
jaxpr, out_avals, consts, out_tree = \
|
||||
_initial_style_untyped_jaxpr(fun, in_tree, in_avals)
|
||||
jaxpr, out_pvals, consts, out_tree = _initial_style_untyped_jaxpr(
|
||||
fun, in_tree, in_avals)
|
||||
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
|
||||
typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
|
||||
(), const_avals + in_avals, out_avals)
|
||||
return typed_jaxpr, consts, out_tree
|
||||
return typed_jaxpr, consts, out_tree()
|
||||
|
||||
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
|
||||
in_tree, in_avals):
|
||||
@ -2480,26 +2471,34 @@ def omnistaging_enabler() -> None:
|
||||
# for each one, it makes another that accepts *all* constants, but only uses
|
||||
# those that it needs (dropping the rest).
|
||||
|
||||
jaxprs, all_out_avals, all_consts, all_out_trees = unzip4(
|
||||
_initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs)
|
||||
jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([
|
||||
_initial_style_untyped_jaxpr(fun, in_tree, in_avals) for fun in funs])
|
||||
|
||||
newvar = core.gensym(jaxprs, suffix='_')
|
||||
all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
|
||||
for consts in all_consts]
|
||||
unused_const_vars = [[newvar(aval) for aval in const_avals]
|
||||
for const_avals in all_const_avals]
|
||||
all_const_avals = tuple(
|
||||
tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
|
||||
for consts in all_consts)
|
||||
unused_const_vars = tuple(
|
||||
tuple(newvar(aval) for aval in const_avals)
|
||||
for const_avals in all_const_avals)
|
||||
|
||||
def pad_jaxpr_constvars(i, jaxpr):
|
||||
prefix = util.concatenate(unused_const_vars[:i])
|
||||
suffix = util.concatenate(unused_const_vars[i+1:])
|
||||
constvars = [*prefix, *jaxpr.constvars, *suffix]
|
||||
constvars = prefix + jaxpr.constvars + suffix
|
||||
return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
|
||||
consts = util.concatenate(all_consts)
|
||||
const_avals = util.concatenate(all_const_avals)
|
||||
const_avals = tuple(util.concatenate(all_const_avals))
|
||||
|
||||
def type_and_const_convert_jaxpr(jaxpr, out_pvals):
|
||||
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
return core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
|
||||
(), const_avals + in_avals, out_avals)
|
||||
|
||||
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
|
||||
typed_jaxprs = [core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
|
||||
(), [*const_avals, *in_avals], out_avals)
|
||||
for jaxpr, out_avals in zip(jaxprs, all_out_avals)]
|
||||
return typed_jaxprs, consts, all_out_trees
|
||||
typed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals)
|
||||
|
||||
return (tuple(typed_jaxprs),
|
||||
tuple(util.concatenate(all_consts)),
|
||||
tuple(out_tree() for out_tree in all_out_trees))
|
||||
|
@ -458,12 +458,10 @@ def _psum_transpose_rule(cts, axis_name, axis_index_groups):
|
||||
|
||||
psum_p = core.Primitive('psum')
|
||||
psum_p.multiple_results = True
|
||||
psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))
|
||||
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
|
||||
pxla.soft_pmap_rules[psum_p] = \
|
||||
partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
|
||||
xla.parallel_translations[psum_p] = _psum_translation_rule
|
||||
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
|
||||
ad.deflinear(psum_p, _psum_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(psum_p)
|
||||
batching.split_axis_rules[psum_p] = partial(_split_axis_comm_assoc, psum_p)
|
||||
@ -474,6 +472,21 @@ batching.collective_rules[psum_p] = \
|
||||
lambda v, d: v.sum(d),
|
||||
lambda v, axis_size: axis_size * v)
|
||||
|
||||
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
|
||||
# tracing time.
|
||||
@psum_p.def_custom_bind
|
||||
def psum_bind(*args, axis_name, axis_index_groups):
|
||||
if all(not isinstance(x, core.Tracer) for x in args):
|
||||
if axis_index_groups is not None:
|
||||
size = len(axis_index_groups[0])
|
||||
elif type(axis_name) is tuple:
|
||||
size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore
|
||||
else:
|
||||
size = core.axis_frame(axis_name).size # type: ignore
|
||||
return tuple(size * x for x in args)
|
||||
return core.Primitive.bind(
|
||||
psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
|
||||
|
||||
|
||||
pmax_p = core.Primitive('pmax')
|
||||
pmax_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
|
||||
@ -660,20 +673,6 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
|
||||
def _axis_index_bind(*, axis_name):
|
||||
dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
|
||||
frame = dynamic_axis_env[axis_name]
|
||||
trace = frame.pmap_trace
|
||||
|
||||
out_aval = ShapedArray((), np.int32)
|
||||
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
|
||||
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
|
||||
dict(axis_name=axis_name),
|
||||
source_info_util.current())
|
||||
out_tracer.recipe = eqn
|
||||
|
||||
return out_tracer
|
||||
|
||||
def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
|
||||
assert not vals and not mapped
|
||||
idx = axis_index(axis_name) # type: ignore
|
||||
@ -682,46 +681,58 @@ def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
|
||||
axis_index_p = core.Primitive('axis_index')
|
||||
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
|
||||
pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule # type: ignore
|
||||
axis_index_p.def_custom_bind(_axis_index_bind)
|
||||
axis_index_p.def_abstract_eval(
|
||||
lambda *args, **params: ShapedArray((), np.int32))
|
||||
pxla.multi_host_supported_collectives.add(axis_index_p)
|
||||
|
||||
# Axis index doesn't get any arguments, so that the default bind would have no
|
||||
# way to call into a data-dependency based trace such as vmap. Each trace that
|
||||
# wants to bind an axis name has to additionally implement `process_axis_index`
|
||||
# and put its main trace on the axis env stack.
|
||||
def _axis_index_bind(*, axis_name):
|
||||
frame = core.axis_frame(axis_name)
|
||||
if frame.main_trace is not None:
|
||||
trace = frame.main_trace.trace_type(frame.main_trace, core.cur_sublevel())
|
||||
return trace.process_axis_index(frame)
|
||||
return core.Primitive.bind(axis_index_p, axis_name=axis_name)
|
||||
axis_index_p.def_custom_bind(_axis_index_bind)
|
||||
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
|
||||
# tracing time.
|
||||
@psum_p.def_custom_bind
|
||||
def psum_bind(*args, axis_name, axis_index_groups):
|
||||
if all(not isinstance(x, core.Tracer) for x in args):
|
||||
if axis_index_groups is not None:
|
||||
size = len(axis_index_groups[0])
|
||||
elif type(axis_name) is tuple:
|
||||
size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore
|
||||
else:
|
||||
size = core.axis_frame(axis_name).size # type: ignore
|
||||
return tuple(size * x for x in args)
|
||||
return core.Primitive.bind(
|
||||
psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
|
||||
def _process_axis_index(self, frame):
|
||||
return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
|
||||
batching.BatchTrace.process_axis_index = _process_axis_index
|
||||
|
||||
if psum_p in pxla.parallel_pure_rules:
|
||||
del pxla.parallel_pure_rules[psum_p]
|
||||
|
||||
# Axis index doesn't get any arguments, so that the default bind would have no
|
||||
# way to call into a data-dependency based trace such as vmap. Each trace that
|
||||
# wants to bind an axis name has to additionally implement `process_axis_index`
|
||||
# and put its main trace on the axis env stack.
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global axis_index
|
||||
|
||||
psum_p.bind = partial(core.Primitive.bind, psum_p)
|
||||
psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p)) # type: ignore
|
||||
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args) # type: ignore
|
||||
|
||||
def _axis_index_bind(*, axis_name):
|
||||
frame = core.axis_frame(axis_name)
|
||||
if frame.main_trace is not None:
|
||||
trace = frame.main_trace.trace_type(frame.main_trace, core.cur_sublevel())
|
||||
return trace.process_axis_index(frame)
|
||||
return core.Primitive.bind(axis_index_p, axis_name=axis_name)
|
||||
dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
|
||||
frame = dynamic_axis_env[axis_name]
|
||||
sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
|
||||
nreps = dynamic_axis_env.nreps
|
||||
trace = frame.pmap_trace
|
||||
|
||||
out_aval = ShapedArray((), np.int32)
|
||||
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
|
||||
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
|
||||
dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
|
||||
source_info_util.current())
|
||||
out_tracer.recipe = eqn
|
||||
|
||||
return out_tracer
|
||||
|
||||
def _axis_index_translation_rule(c, nreps, sizes, axis_name):
|
||||
div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
|
||||
mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
|
||||
axis_index_p.def_custom_bind(_axis_index_bind)
|
||||
|
||||
def process_axis_index(self, frame):
|
||||
return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
|
||||
|
||||
batching.BatchTrace.process_axis_index = process_axis_index
|
||||
axis_index_p.def_abstract_eval(
|
||||
lambda *args, **params: ShapedArray((), np.int32))
|
||||
xla.translations[axis_index_p] = _axis_index_translation_rule
|
||||
|
@ -266,8 +266,9 @@ def one_hot(x, num_classes, *, dtype=jnp.float64):
|
||||
dtype: optional, a float dtype for the returned values (default float64 if
|
||||
jax_enable_x64 is true, otherwise float32).
|
||||
"""
|
||||
num_classes = core.concrete_or_error(int, num_classes,
|
||||
"in jax.nn.one_hot argument `num_classes`")
|
||||
num_classes = core.concrete_or_error(
|
||||
int, num_classes,
|
||||
"The error arose in jax.nn.one_hot argument `num_classes`.")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
x = jnp.asarray(x)
|
||||
lhs = x[..., jnp.newaxis]
|
||||
|
@ -1137,6 +1137,16 @@ def reshape(a, newshape, order="C"):
|
||||
def _compute_newshape(a, newshape):
|
||||
"""Fixes a -1 value in newshape, if present."""
|
||||
# other errors, like having more than one -1, are caught downstream
|
||||
try: iter(newshape)
|
||||
except: iterable = False
|
||||
else: iterable = True
|
||||
if iterable:
|
||||
newshape = [core.concrete_or_error(int, d,
|
||||
"The error arose in jax.numpy.reshape.")
|
||||
for d in newshape]
|
||||
else:
|
||||
newshape = core.concrete_or_error(int, newshape,
|
||||
"The error arose in jax.numpy.reshape.")
|
||||
newsize = _prod(newshape)
|
||||
if newsize < 0:
|
||||
fix = a.size // -newsize
|
||||
@ -2417,7 +2427,7 @@ def identity(n, dtype=None):
|
||||
def arange(start, stop=None, step=None, dtype=None):
|
||||
lax._check_user_dtype_supported(dtype, "arange")
|
||||
require = partial(core.concrete_or_error, _np_asarray)
|
||||
msg = "in jax.numpy.arange argument `{}`".format
|
||||
msg = "It arose in jax.numpy.arange argument `{}`.".format
|
||||
if stop is None and step is None:
|
||||
start = require(start, msg("stop"))
|
||||
dtype = dtype or _dtype(start)
|
||||
@ -2622,7 +2632,7 @@ def repeat(a, repeats, axis=None, *, total_repeat_length=None):
|
||||
|
||||
# If total_repeat_length is not given, can't compile, use a default.
|
||||
if total_repeat_length is None:
|
||||
repeats = core.concrete_or_error(np.array, repeats, "jax.numpy.repeat")
|
||||
repeats = core.concrete_or_error(np.array, repeats, "It arose in jax.numpy.repeat.")
|
||||
repeats = np.ravel(repeats)
|
||||
if ndim(a) != 0:
|
||||
repeats = np.broadcast_to(repeats, [a.shape[axis]])
|
||||
|
@ -513,7 +513,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
f"Try using `x.astype\\({castfun.__name__}\\)` instead.",
|
||||
f"[Tt]ry using `x.astype\\({castfun.__name__}\\)`",
|
||||
lambda: jit(f)(1.0))
|
||||
|
||||
def test_switch_value_jit(self):
|
||||
@ -549,7 +549,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
|
||||
"|Abstract tracer value encountered where concrete value is expected .*)", lambda: jit(f)(0))
|
||||
"|Abstract tracer value encountered where concrete value is expected.*)", lambda: jit(f)(0))
|
||||
|
||||
def test_unimplemented_interpreter_rules(self):
|
||||
foo_p = Primitive('foo')
|
||||
|
@ -4145,7 +4145,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp.sum(jnp.arange(3), (0, 0))
|
||||
|
||||
def testArangeConcretizationError(self):
|
||||
msg = r"Abstract tracer.*\(in jax.numpy.arange argument `{}`\).*".format
|
||||
msg = r"It arose in jax.numpy.arange argument `{}`".format
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):
|
||||
jax.jit(jnp.arange)(3)
|
||||
|
||||
|
@ -150,7 +150,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
def testOneHotConcretizationError(self):
|
||||
# https://github.com/google/jax/issues/3654
|
||||
msg = r"Abstract tracer.*\(in jax.nn.one_hot argument `num_classes`\).*"
|
||||
msg = r"in jax.nn.one_hot argument `num_classes`"
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
jax.jit(nn.one_hot)(3, 5)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user