# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
JAX user-facing transformations and utilities.

The transformations here mostly wrap internal transformations, providing
convenience flags to control behavior and handling Python containers of
arguments and outputs. The Python containers handled are pytrees (see
tree_util.py), which include nested tuples/lists/dicts, where the leaves are
arrays or JaxTuples.
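
For example, ``{'w': (w1, w2), 'b': b}`` is a pytree whose leaves are the
arrays ``w1``, ``w2``, and ``b`` (the names here are illustrative only).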
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools
import operator as op
import os
from warnings import warn

import numpy as onp
from contextlib import contextmanager
from distutils.util import strtobool
from six.moves import reduce

from . import core
from . import linear_util as lu
from . import ad_util
from .core import pack, eval_jaxpr
from .api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
                       pytree_fun_to_flatjaxtuple_fun, apply_jaxtree_fun, wraps,
                       pytree_fun_to_jaxtupletree_fun2, flatten_fun_leafout)
from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef,
                        tree_map, tree_flatten, tree_unflatten, tree_structure,
                        tree_transpose, leaf)
from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
                   WrapHashably, Hashable, prod)
from .lib.xla_bridge import canonicalize_dtype, device_count
from .abstract_arrays import ShapedArray
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
from .config import flags, config

map = safe_map
zip = safe_zip

FLAGS = flags.FLAGS
flags.DEFINE_bool("jax_disable_jit",
                  strtobool(os.getenv("JAX_DISABLE_JIT", "False")),
                  "Disable JIT compilation and just call original Python.")


def jit(fun, static_argnums=()):
  """Sets up `fun` for just-in-time compilation with XLA.

  Args:
    fun: Function to be jitted. Should be a pure function, as side-effects may
      only be executed once. Its arguments and return value should be arrays,
      scalars, or (nested) standard Python containers (tuple/list/dict)
      thereof.

      Positional arguments indicated by `static_argnums` can be anything at
      all, provided they are hashable and have an equality operation defined.
      Static arguments are included as part of a compilation cache key, which
      is why hash and equality operators must be defined.
    static_argnums: A tuple of ints specifying which positional arguments to
      treat as static (compile-time constant). Operations that only depend on
      static arguments will be constant-folded. Calling the jitted function
      with different values for these constants will trigger recompilation.
      If the jitted function is called with fewer positional arguments than
      indicated by `static_argnums` then an error is raised. Defaults to ().
      (See the second example below.)

  Returns:
    A wrapped version of `fun`, set up for just-in-time compilation.

  In the following example, `selu` can be compiled into a single fused kernel
  by XLA:

  >>> @jax.jit
  ... def selu(x, alpha=1.67, lmbda=1.05):
  ...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
  ...
  >>> key = jax.random.PRNGKey(0)
  >>> x = jax.random.normal(key, (10,))
  >>> print(selu(x))
  [-0.54485154  0.27744263 -0.29255125 -0.91421586 -0.62452525 -0.2474813
   -0.8574326  -0.7823267   0.7682731   0.59566754]
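
  When some arguments are marked static with `static_argnums`, the function is
  re-traced and re-compiled whenever those arguments take new values. A
  minimal sketch of that usage (the function and values here are illustrative
  only):

  >>> from functools import partial
  >>> @partial(jax.jit, static_argnums=(1,))
  ... def scale(x, use_double):
  ...   return 2 * x if use_double else x
  ...
  >>> y = scale(3., True)   # compiles with use_double=True baked in
  >>> z = scale(3., False)  # a different static value triggers a recompile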
  """
  return _jit(fun, static_argnums)


def _jit(fun, static_argnums, device_values=True):
  if isinstance(static_argnums, int):
    static_argnums = (static_argnums,)

  @wraps(fun)
  def f_jitted(*args, **kwargs):
    if _jit_is_disabled or config.read('jax_disable_jit'):
      return fun(*args, **kwargs)
    if static_argnums and max(static_argnums) >= len(args):
      msg = ("Jitted function has static_argnums={} but was called with only {}"
             " positional arguments.")
      raise TypeError(msg.format(static_argnums, len(args)))
    f = lu.wrap_init(fun)
    dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
    f, dyn_args = _argnums_partial(f, dyn_argnums, args)
    args_flat, in_tree = tree_flatten((dyn_args, kwargs))
    _check_args(args_flat)
    flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
    out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    return out if out_tree() is leaf else tree_unflatten(out_tree(), out)

  jitted_name = "jit({}, static_argnums={})"
  f_jitted.__name__ = jitted_name.format(f_jitted.__name__, static_argnums)
  return f_jitted


@contextmanager
def disable_jit():
  """Context manager that disables `jit` behavior under its dynamic context.

  For debugging purposes, it is useful to have a mechanism that disables `jit`
  everywhere in a dynamic context.

  Values that have a data dependence on the arguments to a jitted function are
  traced and abstracted. For example, an abstract value may be a ShapedArray
  instance, representing the set of all possible arrays with a given shape and
  dtype, but not representing one concrete array with specific values. You
  might notice those if you use a benign side-effecting operation in a jitted
  function, like a print:

  >>> @jax.jit
  ... def f(x):
  ...   y = x * 2
  ...   print("Value of y is", y)
  ...   return y + 3
  ...
  >>> print(f(jax.numpy.array([1, 2, 3])))
  Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
  [5 7 9]

  Here `y` has been abstracted by `jit` to a `ShapedArray`, which represents
  an array with a fixed shape and type but an arbitrary value. It's also
  traced. If we want to see a concrete value while debugging, and avoid the
  tracer too, we can use the `disable_jit` context manager:

  >>> with jax.disable_jit():
  ...   print(f(jax.numpy.array([1, 2, 3])))
  ...
  Value of y is [2 4 6]
  [5 7 9]
  """
  global _jit_is_disabled
  try:
    _jit_is_disabled, prev_val = True, _jit_is_disabled
    yield
  finally:
    _jit_is_disabled = prev_val
_jit_is_disabled = False


def xla_computation(fun, static_argnums=(), axis_env=None):
  """Creates a function that produces its XLA computation given example args.

  Args:
    fun: Function from which to form XLA computations.
    static_argnums: See the ``jax.jit`` docstring.
    axis_env: Optional, a list of pairs where the first element is an axis
      name and the second element is a positive integer representing the size
      of the mapped axis with that name. This parameter is useful when
      lowering functions that involve parallel communication collectives, and
      it specifies the axis name/size environment that would be set up by
      applications of ``jax.pmap``. See the examples below.

  Returns:
    A wrapped version of ``fun`` that when applied to example arguments
    returns a built XLA Computation (see xla_client.py), from which
    representations of the unoptimized XLA HLO computation can be extracted
    using methods like ``GetHloText``, ``GetSerializedProto``, and
    ``GetHloDotGraph``.

  For example:

  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
  >>> c = jax.xla_computation(f)(3.)
  >>> print(c.GetHloText())
  HloModule jaxpr_computation__4.5
  ENTRY jaxpr_computation__4.5 {
    tuple.1 = () tuple()
    parameter.2 = f32[] parameter(0)
    cosine.3 = f32[] cosine(parameter.2)
    ROOT sine.4 = f32[] sine(cosine.3)
  }

  Here's an example that involves a parallel collective and axis name:

  >>> def f(x): return x - jax.lax.psum(x, 'i')
  >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
  >>> print(c.GetHloText())
  HloModule jaxpr_computation.9
  primitive_computation.3 {
    parameter.4 = s32[] parameter(0)
    parameter.5 = s32[] parameter(1)
    ROOT add.6 = s32[] add(parameter.4, parameter.5)
  }
  ENTRY jaxpr_computation.9 {
    tuple.1 = () tuple()
    parameter.2 = s32[] parameter(0)
    all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
    ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
  }

  Notice the ``replica_groups`` that were generated. Here's an example that
  generates more interesting ``replica_groups``:

  >>> def g(x):
  ...   rowsum = lax.psum(x, 'i')
  ...   colsum = lax.psum(x, 'j')
  ...   allsum = lax.psum(x, ('i', 'j'))
  ...   return rowsum, colsum, allsum
  ...
  >>> axis_env = [('i', 4), ('j', 2)]
  >>> c = xla_computation(g, axis_env=axis_env)(5.)
  >>> print(c.GetHloText())
  HloModule jaxpr_computation__1.19
  [removed uninteresting text here]
  ENTRY jaxpr_computation__1.19 {
    tuple.1 = () tuple()
    parameter.2 = f32[] parameter(0)
    all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
    all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
    all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
    ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
  }
  """
  def pv_like(x):
    aval = xla.abstractify(x)
    return pe.PartialVal((aval, core.unit))

  def make_axis_env(nreps):
    if axis_env is None:
      return xla.AxisEnv(nreps, [], [])
    else:
      nreps = nreps * prod(size for name, size in axis_env)
      return xla.AxisEnv(nreps, *zip(*axis_env))
  @wraps(fun)
  def computation_maker(*args, **kwargs):
    wrapped = lu.wrap_init(fun)
    jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    if not kwargs:
      jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
      pvals = map(pv_like, jax_args)
      jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
      axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
      return xla.build_jaxpr(jaxpr, axis_env_, consts,
                             *map(xla.abstractify, jax_args))
    else:
      jax_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
      jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun2(wrapped, kwargs_tree, in_trees)
      pvals = map(pv_like, (jax_kwargs,) + tuple(jax_args))
      jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
      axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
      return xla.build_jaxpr(jaxpr, axis_env_, consts, xla.abstractify(jax_kwargs),
                             *map(xla.abstractify, jax_args))

  return computation_maker


def grad(fun, argnums=0, has_aux=False, holomorphic=False):
  """Creates a function which evaluates the gradient of `fun`.

  Args:
    fun: Function to be differentiated. Its arguments at positions specified
      by `argnums` should be arrays, scalars, or standard Python containers.
      It should return a scalar (which includes arrays with shape `()` but not
      arrays with shape `(1,)` etc.)
    argnums: Optional, integer or tuple of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether `fun` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False. (See the second example below.)
    holomorphic: Optional, bool. Indicates whether `fun` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as `fun`, that evaluates the gradient
    of `fun`. If `argnums` is an integer then the gradient has the same shape
    and type as the positional argument indicated by that integer. If argnums
    is a tuple of integers, the gradient is a tuple of values with the same
    shapes and types as the corresponding arguments. If `has_aux` is True then
    a pair of (gradient, auxiliary_data) is returned.

  For example:

  >>> grad_tanh = jax.grad(jax.numpy.tanh)
  >>> print(grad_tanh(0.2))
  0.961043
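
  If `has_aux` is True, `fun` should return a `(scalar, aux)` pair; only the
  scalar is differentiated and the auxiliary data is passed through. A minimal
  sketch of that usage (the function and names here are illustrative only):

  >>> def loss_with_stats(x):
  ...   loss = jax.numpy.sum(x ** 2)
  ...   return loss, {'mean': jax.numpy.mean(x)}
  ...
  >>> g, stats = jax.grad(loss_with_stats, has_aux=True)(jax.numpy.ones(3))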
  """
  value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
                                    holomorphic=holomorphic)

  docstr = ("Gradient of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "gradient, which has the same shape as the arguments at "
            "positions {argnums}.")

  @wraps(fun, docstr=docstr, argnums=argnums)
  def grad_f(*args, **kwargs):
    if not has_aux:
      _, g = value_and_grad_f(*args, **kwargs)
      return g
    else:
      (_, aux), g = value_and_grad_f(*args, **kwargs)
      return g, aux

  return grad_f


def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
  """Creates a function which evaluates both `fun` and the gradient of `fun`.

  Args:
    fun: Function to be differentiated. Its arguments at positions specified
      by `argnums` should be arrays, scalars, or standard Python containers.
      It should return a scalar (which includes arrays with shape `()` but not
      arrays with shape `(1,)` etc.)
    argnums: Optional, integer or tuple of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether `fun` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    holomorphic: Optional, bool. Indicates whether `fun` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as `fun` that evaluates both `fun` and
    the gradient of `fun` and returns them as a pair (a two-element tuple). If
    `argnums` is an integer then the gradient has the same shape and type as
    the positional argument indicated by that integer. If argnums is a tuple
    of integers, the gradient is a tuple of values with the same shapes and
    types as the corresponding arguments.
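
  For example, a minimal usage sketch mirroring the `grad` example above:

  >>> tanh_value, tanh_grad = jax.value_and_grad(jax.numpy.tanh)(0.2)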
  """

  docstr = ("Value and gradient of {fun} with respect to positional "
            "argument(s) {argnums}. Takes the same arguments as {fun} but "
            "returns a two-element tuple where the first element is the value "
            "of {fun} and the second element is the gradient, which has the "
            "same shape as the arguments at positions {argnums}.")

  @wraps(fun, docstr=docstr, argnums=argnums)
  def value_and_grad_f(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = _argnums_partial(f, argnums, args)
    if not has_aux:
      ans, vjp_py = vjp(f_partial, *dyn_args)
    else:
      ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
    _check_scalar(ans)
    dtype = onp.result_type(ans)
    if not (holomorphic or onp.issubdtype(dtype, onp.floating)):
      msg = ("Gradient only defined for real-output functions (with dtype that "
             "is a subdtype of np.floating), but got dtype {}. For holomorphic "
             "differentiation, pass holomorphic=True.")
      raise TypeError(msg.format(dtype))
    g = vjp_py(onp.ones((), dtype=dtype))
    g = g[0] if isinstance(argnums, int) else g
    if not has_aux:
      return ans, g
    else:
      return (ans, aux), g

  return value_and_grad_f


def _check_scalar(x):
  msg = "Gradient only defined for scalar-output functions. Output {}.".format
  try:
    aval = core.get_aval(x)
  except TypeError:
    raise TypeError(msg("was {}".format(x)))
  else:
    if isinstance(aval, ShapedArray):
      if aval.shape != ():
        raise TypeError(msg("had shape: {}".format(aval.shape)))
    else:
      raise TypeError(msg("had abstract value {}".format(aval)))


def jacfwd(fun, argnums=0, holomorphic=False):
  """Jacobian of `fun` evaluated column-by-column using forward-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or tuple of integers. Specifies which
      positional argument(s) to differentiate with respect to (default `0`).
    holomorphic: Optional, bool. Indicates whether `fun` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as `fun`, that evaluates the Jacobian
    of `fun` using forward-mode automatic differentiation.

  >>> def f(x):
  ...   return jax.numpy.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
  ...
  >>> print(jax.jacfwd(f)(np.array([1., 2., 3.])))
  [[ 1.       ,  0.       ,  0.        ],
   [ 0.       ,  0.       ,  5.        ],
   [ 0.       , 16.       , -2.        ],
   [ 1.6209068,  0.       ,  0.84147096]]
  """
  def jacfun(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = _argnums_partial(f, argnums, args)
    holomorphic or tree_map(_check_real_input_jacfwd, dyn_args)
    pushfwd = partial(jvp, f_partial, dyn_args)
    y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)

  return jacfun


def _check_real_input_jacfwd(x):
  aval = core.get_aval(x)
  if not onp.issubdtype(aval.dtype, onp.floating):
    msg = ("jacfwd only defined for functions with input dtypes that are "
           "sub-dtypes of `np.floating` (i.e. that model real values), but got "
           "{}. For holomorphic differentiation, pass holomorphic=True.")
    raise TypeError(msg.format(aval.dtype.name))


def jacrev(fun, argnums=0, holomorphic=False):
  """Jacobian of `fun` evaluated row-by-row using reverse-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or tuple of integers. Specifies which
      positional argument(s) to differentiate with respect to (default `0`).
    holomorphic: Optional, bool. Indicates whether `fun` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as `fun`, that evaluates the Jacobian
    of `fun` using reverse-mode automatic differentiation.

  >>> def f(x):
  ...   return jax.numpy.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
  ...
  >>> print(jax.jacrev(f)(np.array([1., 2., 3.])))
  [[ 1.       ,  0.       ,  0.        ],
   [ 0.       ,  0.       ,  5.        ],
   [ 0.       , 16.       , -2.        ],
   [ 1.6209068,  0.       ,  0.84147096]]
  """
  def jacfun(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = _argnums_partial(f, argnums, args)
    y, pullback = vjp(f_partial, *dyn_args)
    holomorphic or tree_map(_check_real_output_jacrev, y)
    jac = vmap(pullback)(_std_basis(y))
    jac = jac[0] if isinstance(argnums, int) else jac
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
    return tree_transpose(tree_structure(example_args), tree_structure(y), jac)

  return jacfun
jacobian = jacrev


def _check_real_output_jacrev(x):
  aval = core.get_aval(x)
  if not onp.issubdtype(aval.dtype, onp.floating):
    msg = ("jacrev only defined for functions with output dtypes that are "
           "sub-dtypes of `np.floating` (i.e. that model real values), but got "
           "{}. For holomorphic differentiation, pass holomorphic=True.")
    raise TypeError(msg.format(aval.dtype.name))


def hessian(fun, argnums=0, holomorphic=False):
  """Hessian of `fun`.

  Args:
    fun: Function whose Hessian is to be computed.
    argnums: Optional, integer or tuple of integers. Specifies which
      positional argument(s) to differentiate with respect to (default `0`).
    holomorphic: Optional, bool. Indicates whether `fun` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as `fun`, that evaluates the Hessian of
    `fun`.

  >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
  >>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
  [[   6.,   -2.],
   [  -2., -480.]]
  """
  return jacfwd(jacrev(fun, argnums, holomorphic), argnums, holomorphic)


def _std_basis(pytree):
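  # Forms the rows of an identity matrix whose size is the total number of
  # scalar entries in `pytree`, then splits that basis back into the pytree
  # structure; jacfwd/jacrev map over its leading axis to obtain one-hot
  # tangents/cotangents.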
  leaves, _ = tree_flatten(pytree)
  ndim = sum(map(onp.size, leaves))
  # TODO(mattjj): use a symbolic identity matrix here
  dtype = onp.result_type(*leaves)
  flat_basis = onp.eye(ndim, dtype=dtype)
  return _unravel_array_into_pytree(pytree, 1, flat_basis)


def _unravel_array_into_pytree(pytree, axis, arr):
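  # Splits `arr` along `axis` into pieces sized like the leaves of `pytree`,
  # reshapes each piece to the corresponding leaf shape (keeping the other
  # axes of `arr`), and reassembles the pieces into `pytree`'s structure.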
  leaves, treedef = tree_flatten(pytree)
  axis = axis % arr.ndim
  shapes = [arr.shape[:axis] + onp.shape(l) + arr.shape[axis+1:] for l in leaves]
  parts = _split(arr, onp.cumsum(map(onp.size, leaves[:-1])), axis)
  reshaped_parts = [onp.reshape(x, shape) for x, shape in zip(parts, shapes)]
  return tree_unflatten(treedef, reshaped_parts)


def _split(x, indices, axis):
  if isinstance(x, onp.ndarray):
    return onp.split(x, indices, axis)
  else:
    return x.split(indices, axis)


def _dtype(x):
  return canonicalize_dtype(onp.result_type(x))


def vmap(fun, in_axes=0, out_axes=0):
  """Vectorizing map. Creates a function which maps `fun` over argument axes.

  Args:
    fun: Function to be mapped over additional axes.
    in_axes: Specifies which input axes to map over. These may be integers,
      `None`, or (possibly nested) tuples of integers or `None`.
    out_axes: Specifies which output axes to map over. These may be integers,
      `None`, or (possibly nested) tuples of integers or `None`.

  Returns:
    Batched/vectorized version of `fun` with arguments that correspond to
    those of `fun`, but with extra array axes at positions indicated by
    `in_axes`, and a return value that corresponds to that of `fun`, but with
    extra array axes at positions indicated by `out_axes`.

  For example, we can implement a matrix-matrix product using a vector dot
  product:

  >>> vv = lambda x, y: np.vdot(x, y)  #  ([a], [a]) -> []
  >>> mv = vmap(vv, (0, None), 0)      #  ([a,b], [b]) -> [a]
  >>> mm = vmap(mv, (None, 1), 1)      #  ([a,b], [b,c]) -> [a,c]

  (here we use `[a,b]` to indicate an array with shape (a,b))
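
  As a quick check of the composed `mm`, its output shape matches that of an
  ordinary matrix product (a small illustrative sketch):

  >>> A, B = np.ones((2, 3)), np.ones((3, 4))
  >>> print(mm(A, B).shape)
  (2, 4)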
  """
  docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
            "but with additional array axes over which {fun} is mapped.")

  if (not isinstance(in_axes, (list, tuple, type(None), int))
      or not isinstance(out_axes, (list, tuple, type(None), int))):
    msg = ("vmap arguments in_axes and out_axes must each be an integer, None, "
           "or a (nested) tuple of those types, got {} and {} respectively.")
    raise TypeError(msg.format(type(in_axes), type(out_axes)))

  @wraps(fun, docstr=docstr)
  def batched_fun(*args, **kwargs):
    if kwargs:
      msg = ("kwargs not yet supported for functions output by vmap. Please "
             "+1 the issue https://github.com/google/jax/issues/912")
      raise NotImplementedError(msg)
    f = lu.wrap_init(fun, kwargs) if not isinstance(fun, lu.WrappedFun) else fun
    in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
    in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
    out_flat = batching.batch(jaxtree_fun, in_flat, in_axes_, out_axes)
    return build_tree(out_tree(), out_flat)

  return batched_fun


def pmap(fun, axis_name=None):
  """Parallel map with support for collectives.

  The purpose of ``pmap`` is to express single-program multiple-data (SPMD)
  programs and execute them in parallel on XLA devices, such as multiple GPUs
  or multiple TPU cores. Semantically it is comparable to ``vmap`` because
  both transformations map a function over array axes, but where ``vmap``
  vectorizes functions by pushing the mapped axis down into primitive
  operations, ``pmap`` instead replicates the function and executes each
  replica on its own XLA device in parallel.

  Another key difference with ``vmap`` is that while ``vmap`` can only express
  pure maps, ``pmap`` enables the use of parallel SPMD collective operations,
  like all-reduce sum.

  The mapped axis size must be less than or equal to the number of XLA devices
  available. For nested ``pmap`` calls, the product of the mapped axis sizes
  must be less than or equal to the number of XLA devices.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
fun: Function to be mapped over argument axes.
|
|
|
|
axis_name: Optional, a hashable Python object used to identify the mapped
|
2019-05-15 08:24:47 -07:00
|
|
|
axis so that parallel collectives can be applied.
|
2019-05-15 08:13:30 -07:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
A parallelized version of ``fun`` with arguments that correspond to those of
|
|
|
|
``fun`` but each with an additional leading array axis (with equal sizes)
|
|
|
|
and with output that has an additional leading array axis (with the same
|
|
|
|
size).
|
|
|
|
|
2019-05-15 08:24:47 -07:00
|
|
|
For example, assuming 8 XLA devices are available, ``pmap`` can be used as a
|
2019-05-15 08:13:30 -07:00
|
|
|
map along a leading array axes:
|
|
|
|
|
|
|
|
>>> out = pmap(lambda x: x ** 2)(np.arange(8))
|
|
|
|
>>> print(out)
|
|
|
|
[0, 1, 4, 9, 16, 25, 36, 49]
|
|
|
|
>>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
|
|
|
|
>>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
|
|
|
|
>>> out = pmap(np.dot)(x, y)
|
|
|
|
>>> print(out)
|
|
|
|
[[[ 4. 9.]
|
|
|
|
[ 12. 29.]]
|
|
|
|
[[ 244. 345.]
|
|
|
|
[ 348. 493.]]
|
|
|
|
[[ 1412. 1737.]
|
|
|
|
[ 1740. 2141.]]]
|
|
|
|
|
|
|
|
In addition to expressing pure maps, ``pmap`` can also be used to express
|
|
|
|
parallel single-program multiple-data (SPMD) programs that communicate via
|
2019-05-15 08:24:47 -07:00
|
|
|
collective operations. For example:
|
2019-05-15 08:13:30 -07:00
|
|
|
|
|
|
|
>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
|
|
|
|
>>> out = pmap(f, axis_name='i')(np.arange(4.))
|
|
|
|
>>> print(out)
|
|
|
|
[ 0. 0.16666667 0.33333334 0.5 ]
|
|
|
|
>>> print(out.sum())
|
|
|
|
1.0
|
|
|
|
|
|
|
|
In this example, ``axis_name`` is a string, but it can be any Python object
|
|
|
|
with ``__hash__`` and ``__eq__`` defined.
|
|
|
|
|
|
|
|
The argument ``axis_name`` to ``pmap`` names the mapped axis so that
|
|
|
|
collective operations, like ``jax.lax.psum``, can refer to it. Axis names are
|
|
|
|
important particularly in the case of nested ``pmap`` functions, where
|
|
|
|
collectives can operate over distinct axes:
|
|
|
|
|
|
|
|
>>> from functools import partial
|
|
|
|
>>> @partial(pmap, axis_name='rows')
|
|
|
|
>>> @partial(pmap, axis_name='cols')
|
|
|
|
>>> def normalize(x):
|
|
|
|
>>> row_normed = x / jax.lax.psum(x, 'rows')
|
|
|
|
>>> col_normed = x / jax.lax.psum(x, 'cols')
|
|
|
|
>>> doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
|
|
|
|
>>> return row_normed, col_normed, doubly_normed
|
|
|
|
>>>
|
|
|
|
>>> x = np.arange(8.).reshape((4, 2))
|
|
|
|
>>> row_normed, col_normed, doubly_normed = normalize(x)
|
|
|
|
>>> print(row_normed.sum(0))
|
|
|
|
[ 1. 1.]
|
|
|
|
>>> print(col_normed.sum(1))
|
|
|
|
[ 1. 1. 1. 1.]
|
|
|
|
>>> print(doubly_normed.sum((0, 1)))
|
|
|
|
1.0
|
|
|
|
"""
|
2019-03-06 14:03:47 -08:00
|
|
|
axis_name = _TempAxisName() if axis_name is None else axis_name
|
|
|
|
|
2019-01-25 08:20:33 -08:00
|
|
|
@wraps(fun)
|
2019-05-17 07:36:52 -07:00
|
|
|
def f_pmapped(*args, **kwargs):
|
2019-05-02 22:13:49 -07:00
|
|
|
axis_size = _pmap_axis_size(args)
|
2019-04-10 22:09:14 -07:00
|
|
|
f = lu.wrap_init(fun)
|
2019-05-06 06:50:15 -07:00
|
|
|
args_flat, in_tree = tree_flatten((args, kwargs))
|
|
|
|
_check_args(args_flat)
|
2019-05-17 07:36:52 -07:00
|
|
|
flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
|
2019-05-06 06:50:15 -07:00
|
|
|
out = pxla.xla_pmap(flat_fun, *args_flat,
|
2019-04-10 22:09:14 -07:00
|
|
|
axis_name=axis_name, axis_size=axis_size)
|
2019-05-17 07:36:52 -07:00
|
|
|
return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
|
2019-01-25 08:20:33 -08:00
|
|
|
|
2019-03-06 14:36:47 -08:00
|
|
|
namestr = "pmap({}, axis_name={})".format
|
2019-05-17 07:36:52 -07:00
|
|
|
f_pmapped.__name__ = namestr(f_pmapped.__name__, axis_name)
|
|
|
|
return f_pmapped
|
2019-01-25 08:20:33 -08:00
|
|
|
|
2019-06-23 15:31:13 -07:00
|
|
|
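# Comment-only sketch of the nested-pmap size constraint mentioned in the
# docstring above: the product of the mapped axis sizes must not exceed the
# number of available XLA devices (assuming 8 devices here, and the docstring
# convention that `np` means `jax.numpy`):
#
#   doubler = pmap(pmap(lambda x: 2. * x, axis_name='j'), axis_name='i')
#   out = doubler(np.zeros((4, 2)))   # ok: 4 * 2 == 8 <= device_count()
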
class _TempAxisName(object):
  def __repr__(self):
    return '<axis {}>'.format(hex(id(self)))


def _pmap_axis_size(args):
  leaves, _ = tree_flatten(args)
  axis_sizes = reduce(set.union, map(_axis_size, leaves), set())
  if len(axis_sizes) == 0:
    raise ValueError("pmap requires a leading axis to map over.")
  if len(axis_sizes) > 1:
    msg = "pmap requires all leading axes to have equal length, got {}."
    raise ValueError(msg.format(axis_sizes))
  return axis_sizes.pop()


def _axis_size(x):
  if isinstance(x, core.Tracer):
    aval = x.aval
  else:
    aval = xla.abstractify(x)
  return _aval_axis_size(aval)


def _aval_axis_size(aval):
  if isinstance(aval, core.AbstractTuple):
    return reduce(set.union, map(_aval_axis_size, aval), set())
  else:
    if aval.shape:
      return {aval.shape[0]}
    else:
      raise ValueError("pmap can't map over scalars.")

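# Comment-only illustration of the leading-axis-size computation above: all
# leaves of the arguments must agree on their leading axis length, which
# becomes the mapped axis size (the example arrays are illustrative):
#
#   _pmap_axis_size((onp.zeros((4, 3)), [onp.ones(4)]))   # -> 4
#   _pmap_axis_size((onp.zeros((4, 3)), onp.ones(5)))     # raises ValueError
#   _pmap_axis_size((onp.float32(1.),))                   # raises ValueError (scalar)
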
def soft_pmap(fun, axis_name=None):
  axis_name = _TempAxisName() if axis_name is None else axis_name

  @wraps(fun)
  def f_pmapped(*args):
    axis_size = _pmap_axis_size(args)
    chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count())
    if chunk_size == 0 and leftover:
      return pmap(fun, axis_name)(*args)  # can map directly onto hardware
    elif leftover:
      msg = ("soft_pmap mapped axis size must be divisible by the number of "
             "XLA devices (or be smaller than it), got size {} for {} devices.")
      raise ValueError(msg.format(axis_size, pxla.unmapped_device_count()))
    num_chunks = axis_size // chunk_size

    f = lu.wrap_init(fun)
    in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
    reshaped_args = map(partial(_reshape_split, num_chunks), in_flat)
    soft_mapped_fun = pxla.split_axis(jaxtree_fun, axis_name, chunk_size)
    reshaped_out = pxla.xla_pmap(soft_mapped_fun, *reshaped_args,
                                 axis_name=axis_name, axis_size=num_chunks)
    return build_tree(out_tree(), _reshape_merge(reshaped_out))

  namestr = "soft_pmap({}, axis_name={})".format
  f_pmapped.__name__ = namestr(f_pmapped.__name__, axis_name)
  return f_pmapped

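# Comment-only sketch of the chunking arithmetic in `soft_pmap` above: with a
# mapped axis of 16 elements and 4 XLA devices, each device handles a chunk of
# 4 elements and the hardware pmap runs over 4 chunks:
#
#   chunk_size, leftover = divmod(16, 4)   # -> (4, 0)
#   num_chunks = 16 // chunk_size          # -> 4
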
def _reshape_split(num_chunks, arg):
  def split(aval, arg):
    t = type(aval)
    if t is core.AbstractTuple:
      return core.pack(map(split, aval, arg))
    elif t is ShapedArray:
      prefix = (num_chunks, arg.shape[0] // num_chunks)
      return arg.reshape(prefix + arg.shape[1:])
    else:
      raise TypeError(aval)

  return split(batching.get_aval(arg), arg)


def _reshape_merge(ans):
  def merge(aval, ans):
    t = type(aval)
    if t is core.AbstractTuple:
      return core.pack(map(merge, aval, ans))
    elif t is ShapedArray:
      return ans.reshape((-1,) + ans.shape[2:])
    else:
      raise TypeError(aval)

  return merge(batching.get_aval(ans), ans)

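# Comment-only sketch of the reshape round trip used by `soft_pmap`: splitting
# an argument with leading axis 8 into 2 chunks inserts a chunk axis, and
# merging collapses it again (plain NumPy shown for the shape arithmetic):
#
#   x = onp.arange(24.).reshape(8, 3)
#   split = x.reshape((2, 8 // 2) + x.shape[1:])      # shape (2, 4, 3)
#   merged = split.reshape((-1,) + split.shape[2:])   # back to shape (8, 3)
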
def _papply(fun):
  # This function is for testing purposes.
  axis_name = _TempAxisName()

  def papply_fun(*args, **kwargs):
    axis_size = _pmap_axis_size(args)
    f = lu.wrap_init(fun, kwargs)
    args_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
    out_flat = parallel.papply(jaxtree_fun, axis_name, args_flat, axis_size)
    return build_tree(out_tree(), out_flat)

  return papply_fun, axis_name

def _parallelize(fun):
  axis_name = _TempAxisName()

  def pfun(*args):
    axis_size = _pmap_axis_size(args)
    chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count())
    if chunk_size == 0 and leftover:
      return pmap(fun, axis_name)(*args)  # can map directly onto hardware
    elif leftover:
      msg = ("parallelized mapped axis size must be divisible by the number of "
             "XLA devices (or be smaller than it), got size {} for {} devices.")
      raise ValueError(msg.format(axis_size, pxla.unmapped_device_count()))
    num_chunks = axis_size // chunk_size

    f = lu.wrap_init(fun)
    args_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    args_flat = map(partial(_reshape_split, num_chunks), args_flat)
    f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
    f, out_axis = parallel.papply_transform(f, axis_name, axis_size)
    f = pxla.split_axis(f, axis_name, chunk_size)
    out = pxla.xla_pmap(f, *args_flat, axis_name=axis_name, axis_size=num_chunks)
    out = parallel.match_axis(0, out_axis(), _reshape_merge(out))
    return build_tree(out_tree(), out)

  return pfun

def jvp(fun, primals, tangents):
  """Computes a (forward-mode) Jacobian-vector product of `fun`.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Jacobian of `fun` should be
      evaluated. Should be a tuple of arrays, scalars, or standard Python
      containers thereof. The length of the tuple is equal to the number of
      positional parameters of `fun`.
    tangents: The tangent vector for which the Jacobian-vector product should be
      evaluated. Should be a tuple of arrays, scalars, or standard Python
      containers thereof, with the same tree structure and array shapes as
      `primals`.

  Returns:
    A `(primals_out, tangents_out)` pair, where `primals_out` is
    `fun(*primals)`, and `tangents_out` is the Jacobian-vector product of
    `fun` evaluated at `primals` with `tangents`. The `tangents_out` value
    has the same Python tree structure and shapes as `primals_out`.

  For example:

  >>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
  >>> print(y)
  0.09983342
  >>> print(v)
  0.19900084
  """
  def trim_arg(primal, tangent):
    primal_jtuple, tree_def = pytree_to_jaxtupletree(primal)
    tangent_jtuple, tree_def_2 = pytree_to_jaxtupletree(tangent)
    assert tree_def == tree_def_2, (tree_def, tree_def_2)
    return primal_jtuple, tangent_jtuple, tree_def

  if not isinstance(fun, lu.WrappedFun):
    fun = lu.wrap_init(fun)
  ps_flat, ts_flat, in_trees = unzip3(map(trim_arg, primals, tangents))
  jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
  out_primal, out_tangent = ad.jvp(jaxtree_fun).call_wrapped(ps_flat, ts_flat)
  return (build_tree(out_tree(), out_primal), build_tree(out_tree(), out_tangent))

def linearize(fun, *primals):
  """Produce a linear approximation to `fun` using `jvp` and partial evaluation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Jacobian of `fun` should be
      evaluated. Should be a tuple of arrays, scalars, or standard Python
      containers thereof. The length of the tuple is equal to the number of
      positional parameters of `fun`.

  Returns:
    A pair where the first element is the value of `f(*primals)` and the second
    element is a function that evaluates the (forward-mode) Jacobian-vector
    product of `fun` evaluated at `primals` without re-doing the linearization
    work.

  In terms of values computed, `linearize` behaves much like a curried `jvp`,
  where these two code blocks compute the same values::

    y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

    y, f_jvp = jax.linearize(f, x)
    out_tangent = f_jvp(in_tangent)

  However, the difference is that `linearize` uses partial evaluation so that
  the function `f` is not re-linearized on calls to `f_jvp`. In general that
  means the memory usage scales with the size of the computation, much like in
  reverse-mode. (Indeed, `linearize` has a similar signature to `vjp`!)

  This function is mainly useful if you want to apply `f_jvp` multiple times,
  i.e. to evaluate a pushforward for many different input tangent vectors at the
  same linearization point. Moreover if all the input tangent vectors are known
  at once, it can be more efficient to vectorize using `vmap`, as in::

    pushfwd = partial(jvp, f, (x,))
    y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

  By using `vmap` and `jvp` together like this we avoid the stored-linearization
  memory cost that scales with the depth of the computation, which is incurred
  by both `linearize` and `vjp`.

  Here's a more complete example of using `linearize`:

  >>> def f(x): return 3. * np.sin(x) + np.cos(x / 2.)
  ...
  >>> jax.jvp(f, (2.,), (3.,))
  (array(3.2681944, dtype=float32), array(-5.007528, dtype=float32))
  >>> y, f_jvp = jax.linearize(f, 2.)
  >>> print(y)
  3.2681944
  >>> print(f_jvp(3.))
  -5.007528
  >>> print(f_jvp(4.))
  -6.676704
  """
  f = lu.wrap_init(fun)
  primals_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, primals))
  jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
  out_primal, out_pval, jaxpr, consts = ad.linearize(jaxtree_fun, *primals_flat)
  out_tree = out_tree()
  out_primal_py = build_tree(out_tree, out_primal)
  primal_avals = list(map(core.get_aval, primals_flat))
  lifted_jvp = partial(lift_linearized, jaxpr, primal_avals, consts,
                       (in_trees, out_tree), out_pval)
  return out_primal_py, lifted_jvp

def lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pval, *py_args):
  def fun(*tangents):
    tangent_avals = list(map(core.get_aval, tangents))
    for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
      try:
        core.lattice_join(primal_aval, tangent_aval)
      except TypeError:
        msg = ("linearized function called on tangent values inconsistent with "
               "the original primal values.")
        raise ValueError(msg)
    primals = pack(tangents)  # doesn't matter what these are; they'll be ignored
    tangents = pack(tangents)
    _, ans = eval_jaxpr(jaxpr, consts, (), primals, tangents)
    return pe.merge_pvals(ans, out_pval)

  return apply_jaxtree_fun(fun, io_tree, *py_args)

def _check_inexact_input_vjp(x):
  aval = core.get_aval(x)
  if not onp.issubdtype(aval.dtype, onp.inexact):
    msg = ("Primal inputs to reverse-mode differentiation must be of float "
           "or complex type, got type {}")
    raise TypeError(msg.format(aval.dtype.name))

def vjp(fun, *primals, **kwargs):
  """Compute a (reverse-mode) vector-Jacobian product of `fun`.

  `grad` is implemented as a special case of `vjp`.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: A sequence of primal values at which the Jacobian of `fun`
      should be evaluated. The length of `primals` should be equal to the number
      of positional parameters to `fun`. Each primal value should be a tuple of
      arrays, scalars, or standard Python containers thereof.
    has_aux: Optional, bool. Indicates whether `fun` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.

  Returns:
    A `(primals_out, vjpfun)` pair, where `primals_out` is `fun(*primals)`.
    `vjpfun` is a function from a cotangent vector with the same shape as
    `primals_out` to a tuple of cotangent vectors with the same shape as
    `primals`, representing the vector-Jacobian product of `fun` evaluated at
    `primals`.

  >>> def f(x, y):
  ...   return jax.numpy.sin(x), jax.numpy.cos(y)
  ...
  >>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
  >>> xbar, ybar = f_vjp((-0.7, 0.3))
  >>> print(xbar)
  -0.61430776
  >>> print(ybar)
  -0.2524413
  """
  has_aux = kwargs.pop('has_aux', False)
  assert not kwargs
  if not isinstance(fun, lu.WrappedFun):
    fun = lu.wrap_init(fun)
  primals_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, primals))
  _check_args(primals_flat)
  tree_map(_check_inexact_input_vjp, primals)
  jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
  if not has_aux:
    out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)
  else:
    out_primal, out_vjp, aux = ad.vjp(jaxtree_fun, primals_flat, has_aux=True)
  out_tree = out_tree()
  if has_aux:
    out_tree, aux_tree = out_tree.children
  out_primal_py = build_tree(out_tree, out_primal)
  ct_in_trees = [out_tree]
  ct_out_tree = PyTreeDef(node_types[tuple], None, in_trees)
  def out_vjp_packed(cotangent_in):
    return out_vjp(cotangent_in)
  vjp_py = partial(apply_jaxtree_fun, out_vjp_packed, (ct_in_trees, ct_out_tree))
  if not has_aux:
    return out_primal_py, vjp_py
  else:
    return out_primal_py, vjp_py, build_tree(aux_tree, aux)

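# Comment-only sketch relating `vjp` to reverse-mode gradients: for a
# scalar-output function, pulling back the cotangent 1.0 yields the gradient
# (assumes the docstring convention that `np` means `jax.numpy`; the names
# below are illustrative):
#
#   loss = lambda w: np.sum(np.tanh(w) ** 2)
#   value, pullback = vjp(loss, np.ones(3))
#   grad_w, = pullback(1.0)   # tuple of cotangents, one per positional argument
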
def trace_to_jaxpr(traceable, py_pvals, **kwargs):
  fun = lu.wrap_init(traceable, kwargs)
  pvals, in_trees = unzip2(map(tree_to_pval_tuples, py_pvals))
  jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
  jaxpr, out_pval, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
  return jaxpr, consts, out_pval, (in_trees, out_tree())


def lift_jaxpr(jaxpr, consts, io_tree, pvals, py_args):
  def fun(*args):
    ans = eval_jaxpr(jaxpr, consts, (), *args)
    return pe.merge_pvals(ans, pvals)
  return apply_jaxtree_fun(fun, io_tree, *py_args)

def make_jaxpr(fun):
"""Creates a function that produces its jaxpr given example args.
|
2019-02-14 11:58:18 -05:00
|
|
|
|
|
|
|
Args:
|
|
|
|
fun: The function whose `jaxpr` is to be computed. Its positional arguments
|
|
|
|
and return value should be arrays, scalars, or standard Python containers
|
|
|
|
(tuple/list/dict) thereof.
|
|
|
|
|
|
|
|
Returns:
    A wrapped version of `fun` that when applied to example arguments returns a
    jaxpr representation of `fun` on those arguments.

  A `jaxpr` is JAX's intermediate representation for program traces. The `jaxpr`
  language is based on the simply-typed first-order lambda calculus with
  let-bindings. `make_jaxpr` adapts a function to return its `jaxpr`, which we
  can inspect to understand what JAX is doing internally.

  The `jaxpr` returned is a trace of `fun` abstracted to `ShapedArray` level.
  Other levels of abstraction exist internally.

  We do not describe the semantics of the `jaxpr` language in detail here, but
  instead give a few examples.

  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
  >>> print(f(3.0))
  -0.83602184
  >>> jax.make_jaxpr(f)(3.0)
  { lambda  ;  ; a.
    let b = cos a
        c = sin b
    in c }
  >>> jax.make_jaxpr(jax.grad(f))(3.0)
  { lambda b ;  ; a.
    let c = pack a
        (d) = id c
        e = cos d
        f = cos e
        g = mul b f
        h = neg g
        i = sin d
        j = mul h i
        k = pack j
        (l) = id k
    in l }
  """
  def pv_like(x):
    aval = xla.abstractify(x)
    return pe.PartialVal((aval, core.unit))

  @wraps(fun)
  def jaxpr_maker(*args, **kwargs):
    wrapped = lu.wrap_init(fun, kwargs)
    jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
    pvals = map(pv_like, jax_args)
    jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals)
    return jaxpr

  jaxpr_maker.__name__ = "make_jaxpr({})".format(jaxpr_maker.__name__)
  return jaxpr_maker


tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals)

def device_put(x, device_num=0):
  return tree_map(lambda y: xla.device_put_p.bind(y, device_num=device_num), x)


device_get = _jit(lambda x: x, (), device_values=False)

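# Comment-only usage sketch: `device_put` commits a pytree of values to a
# device up front, and `device_get` fetches results back as host NumPy values
# (the array values below are illustrative):
#
#   params = device_put({'w': onp.ones((2, 2)), 'b': onp.zeros(2)})
#   host_params = device_get(params)
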
def _argnums_partial(f, dyn_argnums, args):
  if isinstance(dyn_argnums, int):
    dyn_argnums = (dyn_argnums,)
  else:
    dyn_argnums = tuple(dyn_argnums)
  fixed_args = tuple([None if i in dyn_argnums else _wrap_hashably(arg)
                      for i, arg in enumerate(args)])
  dyn_args = tuple(args[i] for i in dyn_argnums)
  return _argnums_partial_(f, dyn_argnums, fixed_args), dyn_args


def _wrap_hashably(arg):
  try:
    hash(arg)
  except TypeError:
    return WrapHashably(arg)  # e.g. ndarrays, DeviceArrays
  else:
    return Hashable(arg)


@lu.transformation
def _argnums_partial_(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  args = [None if arg is None else arg.val for arg in fixed_args]
  for i, arg in zip(dyn_argnums, dyn_args):
    args[i] = arg
  ans = yield args, kwargs
  yield ans

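# Comment-only sketch of the static/dynamic argument split above: with
# dyn_argnums=(0,) the first argument stays dynamic while the second is
# wrapped hashably and closed over as a constant (values illustrative):
#
#   f = lu.wrap_init(lambda x, n: x * n)
#   f_dyn, dyn_args = _argnums_partial(f, (0,), (onp.ones(3), 2))
#   # dyn_args == (onp.ones(3),); the integer 2 is now a fixed, hashable constant
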
def _check_args(args):
  for arg in args:
    if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
      raise TypeError("Argument '{}' of type {} is not a valid JAX type"
                      .format(arg, type(arg)))


def _valid_jaxtype(arg):
  try:
    xla.abstractify(arg)
  except TypeError:
    return False
  else:
    return True

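# Comment-only illustration of the validity check above: values that XLA can
# abstractify (arrays, scalars) are valid JAX types, while e.g. strings are
# not and trigger the TypeError in `_check_args`:
#
#   _valid_jaxtype(onp.ones(3))   # -> True
#   _valid_jaxtype("hello")       # -> False
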
class CustomTransformsFunction(object):
  def __init__(self, fun, prim):
    self.fun = fun
    self.prim = prim
    wraps(fun)(self)

  def __repr__(self):
    return '<jax.custom_transforms function {fun}>'.format(fun=self.__name__)

  def __call__(self, *args, **kwargs):
    def pv_like(x):
      return pe.PartialVal((batching.get_aval(x), core.unit))  # use shaped aval
    jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jax_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun2(
        lu.wrap_init(self.fun), kwargs_tree, in_trees)
    pvals_in = map(pv_like, (jax_kwargs,) + jax_args)
    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals_in, instantiate=True)
    ans = self.prim.bind(core.pack(consts), jax_kwargs, *jax_args,
                         in_trees=in_trees, jaxpr=jaxpr)
    return build_tree(out_tree(), ans)

def custom_transforms(fun):
  """Wraps a function so that its transformation behavior can be controlled.

  A primary use case of ``custom_transforms`` is defining custom VJP rules (aka
  custom gradients) for a Python function, while still supporting other
  transformations like ``jax.jit`` and ``jax.vmap``. Custom differentiation
  rules can be supplied using the ``jax.defjvp`` and ``jax.defvjp`` functions.

  The ``custom_transforms`` decorator wraps ``fun`` so that its transformation
  behavior can be overridden, but not all transformation rules need to be
  specified manually. The default behavior is retained for any non-overridden
  rules.

  The function ``fun`` must satisfy the same constraints required for jit
  compilation. In particular the shapes of arrays in the computation of ``fun``
  may depend on the shapes of ``fun``'s arguments, but not their values.
  Value dependent Python control flow is also not yet supported.

  Args:
    fun: a Python callable. Must be functionally pure. Its arguments and return
      value should be arrays, scalars, or (nested) standard Python containers
      (tuple/list/dict) thereof.

  Returns:
    A Python callable with the same input/output and transformation behavior as
    ``fun``, but for which custom transformation rules can be supplied, e.g.
    using ``jax.defvjp``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x):
  ...   return np.sin(x ** 2)
  ...
  >>> print(f(3.))
  0.4121185
  >>> print(jax.grad(f)(3.))
  -5.4667816
  >>> jax.defvjp(f, lambda g, ans, x: g * x)
  >>> print(jax.grad(f)(3.))
  3.0
  """
  name = getattr(fun, '__name__', '<unnamed custom_transforms primitive>')
  fun_p = core.Primitive(name)

  def fun_impl(consts, jax_kwargs, *jax_args, **params):
    return core.eval_jaxpr(params['jaxpr'], consts, (), jax_kwargs, *jax_args)
  fun_p.def_impl(fun_impl)

  def fun_jvp(primals, tangents, **params):
    return ad.jvp(lu.wrap_init(fun_impl, params)).call_wrapped(primals, tangents)
  ad.primitive_jvps[fun_p] = fun_jvp

  def fun_batch(batched_args, batch_dims, **params):
    out = batching.batch(lu.wrap_init(fun_impl, params), batched_args, batch_dims, 0)
    return out, 0
  batching.primitive_batchers[fun_p] = fun_batch

  def fun_abstract_eval(*avals, **params):
    return pe.abstract_eval_fun(fun_impl, *avals, **params)
  fun_p.def_abstract_eval(fun_abstract_eval)

  def fun_translation(c, *xla_args, **params):
    return xla.lower_fun(fun_impl, True)(c, *xla_args, **params)
  xla.translations[fun_p] = fun_translation

  return CustomTransformsFunction(fun, fun_p)

def _check_custom_transforms_type(name, fun):
  if type(fun) is not CustomTransformsFunction:
    msg = ("{} requires a custom_transforms function as its first argument, "
           "but got type {}.")
    raise TypeError(msg.format(name, type(fun)))

def defjvp_all(fun, custom_jvp):
  """Define a custom JVP rule for a ``custom_transforms`` function.

  If ``fun`` represents a function with signature ``a -> b``, then
  ``custom_jvp`` represents a function with signature ``(a, T a) -> (b, T b)``,
  where we use ``T x`` to represent a tangent type for the type ``x``.

  In more detail, ``custom_jvp`` must take two arguments, both tuples of length
  equal to the number of positional arguments to ``fun``. The first argument to
  ``custom_jvp`` represents the input primal values, and the second represents
  the input tangent values. ``custom_jvp`` must return a pair where the first
  element represents the output primal value and the second element represents
  the output tangent value.

  Defining a custom JVP rule also affects the default VJP rule, which is derived
  from the JVP rule automatically via transposition.

  Args:
    fun: a custom_transforms function.
    custom_jvp: a Python callable specifying the JVP rule, taking two tuples as
      arguments specifying the input primal values and tangent values,
      respectively. The tuple elements can be arrays, scalars, or (nested)
      standard Python containers (tuple/list/dict) thereof. The output must be a
      pair representing the primal output and tangent output, which can be
      arrays, scalars, or (nested) standard Python containers. Must be
      functionally pure.

  Returns:
    None. A side-effect is that ``fun`` is associated with the JVP rule
    specified by ``custom_jvp``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x):
  ...   return np.sin(x ** 2)
  ...
  >>> print(f(3.))
  0.4121185
  >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
  >>> print(out_primal)
  0.4121185
  >>> print(out_tangent)
  -10.933563
  >>> jax.defjvp_all(f, lambda ps, ts: (np.sin(ps[0] ** 2), 8. * ts[0]))
  >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
  >>> print(out_primal)
  0.4121185
  >>> print(out_tangent)
  16.0
  """
  _check_custom_transforms_type("defjvp_all", fun)
  def custom_transforms_jvp(primals, tangents, **params):
    consts, jax_kwargs, jax_args = primals[0], primals[1], primals[2:]
    consts_dot, _, jax_args_dot = tangents[0], tangents[1], tangents[2:]
    if consts_dot is not ad_util.zero:
      msg = (
          "Detected differentiation w.r.t. variables from outside the scope of "
          "{}, but defjvp and defjvp_all only support differentiation w.r.t. "
          "positional arguments.")
      raise ValueError(msg.format(str(fun)))
    if jax_kwargs:
      msg = ("defjvp_all requires the corresponding custom_transforms function "
             "not to be called with keyword arguments.")
      raise ValueError(msg)
    in_trees = params['in_trees']
    args = tuple(map(build_tree, in_trees, jax_args))
    args_dot = tuple(map(build_tree, in_trees, jax_args_dot))
    pytree_out, pytree_out_dot = custom_jvp(args, args_dot)
    out, out_tree = pytree_to_jaxtupletree(pytree_out)
    out_dot, out_tree2 = pytree_to_jaxtupletree(pytree_out_dot)
    if out_tree != out_tree2:
      msg = ("custom jvp rule returned different tree structures for primals "
             "and tangents, but they must be equal: {} vs {}.")
      raise TypeError(msg.format(out_tree, out_tree2))
    return out, out_dot
  ad.primitive_jvps[fun.prim] = custom_transforms_jvp

def defjvp(fun, *jvprules):
  """Define JVP rules for each argument separately.

  This function is a convenience wrapper around ``jax.defjvp_all`` for
  separately defining JVP rules for each of the function's arguments. This
  convenience wrapper does not provide a mechanism for depending on anything
  other than the function arguments and its primal output value, though
  depending on intermediate results is possible using ``jax.defjvp_all``.

  The signature of each component JVP rule is ``lambda g, ans, *primals: ...``
  where ``g`` represents the tangent of the corresponding positional argument,
  ``ans`` represents the output primal, and ``*primals`` represents all the
  primal positional arguments.

  Defining a custom JVP rule also affects the default VJP rule, which is derived
  from the JVP rule automatically via transposition.

  Args:
    fun: a custom_transforms function.
    *jvprules: a sequence of functions or Nones specifying the JVP rule for each
      corresponding positional argument. When an element is None, it indicates
      that the Jacobian from the corresponding input to the output is zero.

  Returns:
    None. A side-effect is that ``fun`` is associated with the JVP rule
    specified by ``*jvprules``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x):
  ...   return np.sin(x ** 2)
  ...
  >>> print(f(3.))
  0.4121185
  >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
  >>> print(out_primal)
  0.4121185
  >>> print(out_tangent)
  -10.933563
  >>> jax.defjvp(f, lambda g, ans, x: 8. * g + ans)
  >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
  >>> print(out_primal)
  0.4121185
  >>> print(out_tangent)
  16.412119
  """
  _check_custom_transforms_type("defjvp", fun)
  def custom_jvp(primals, tangents):
    ans = fun(*primals)
    tangents_out = [rule(t, ans, *primals) for rule, t in zip(jvprules, tangents)
                    if rule is not None and t is not ad_util.zero]
    return ans, reduce(ad.add_tangents, tangents_out, ad_util.zero)
  defjvp_all(fun, custom_jvp)

def defvjp_all(fun, custom_vjp):
  """Define a custom VJP rule for a ``custom_transforms`` function.

  If ``fun`` represents a function with signature ``a -> b``, then
  ``custom_vjp`` represents a function with signature ``a -> (b, CT b -> CT a)``
  where we use ``CT x`` to represent a cotangent type for the type ``x``. That
  is, ``custom_vjp`` should take the same arguments as ``fun`` and return a pair
  where the first element represents the primal value of ``fun`` applied to the
  arguments, and the second element is a VJP function that maps from output
  cotangents to input cotangents, returning a tuple with length equal to the
  number of positional arguments supplied to ``fun``.

  The VJP function returned as the second element of the output of
  ``custom_vjp`` can close over intermediate values computed when evaluating the
  primal value of ``fun``. That is, use lexical closure to share work between
  the forward pass and the backward pass of reverse-mode automatic
  differentiation.

  See also ``jax.custom_gradient``.

  Args:
    fun: a custom_transforms function.
    custom_vjp: a Python callable specifying the VJP rule, taking the same
      arguments as ``fun`` and returning a pair where the first element is the
      value of ``fun`` applied to the arguments and the second element is a
      Python callable representing the VJP map from output cotangents to input
      cotangents. The returned VJP function must accept a value with the same
      shape as the value of ``fun`` applied to the arguments and must return a
      tuple with length equal to the number of positional arguments to ``fun``.
      Arguments can be arrays, scalars, or (nested) standard Python containers
      (tuple/list/dict) thereof. Must be functionally pure.

  Returns:
    None. A side-effect is that ``fun`` is associated with the VJP rule
    specified by ``custom_vjp``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x):
  ...   return np.sin(x ** 2)
  ...
  >>> print(f(3.))
  0.4121185
  >>> print(jax.grad(f)(3.))
  -5.4667816
  >>> jax.defvjp_all(f, lambda x: (np.sin(x ** 2), lambda g: (g * x,)))
  >>> print(f(3.))
  0.4121185
  >>> print(jax.grad(f)(3.))
  3.0

  An example with a function on two arguments, so that the VJP function must
  return a tuple of length two:

  >>> @jax.custom_transforms
  ... def f(x, y):
  ...   return x * y
  ...
  >>> jax.defvjp_all(f, lambda x, y: (x * y, lambda g: (y, x)))
  >>> print(f(3., 4.))
  12.0
  >>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
  (4.0, 3.0)
  """
  _check_custom_transforms_type("defvjp_all", fun)
  def custom_transforms_vjp(argnums, consts, jax_kwargs, *jax_args, **params):
    if 0 in argnums:
      msg = (
          "Detected differentiation w.r.t. variables from outside the scope of "
          "{}, but defvjp and defvjp_all only support differentiation w.r.t. "
          "positional arguments.")
      raise ValueError(msg.format(str(fun)))
    if jax_kwargs:
      msg = ("defvjp_all requires the corresponding custom_transforms function "
             "not to be called with keyword arguments.")
      raise ValueError(msg)
    args = map(build_tree, params['in_trees'], jax_args)
    pytree_out, vjp_pytree = custom_vjp(*args)
    out, out_tree = pytree_to_jaxtupletree(pytree_out)
    def vjp_pytree_(ct):
      args_cts = tuple(vjp_pytree(ct))
      if len(args_cts) != len(params['in_trees']):
        msg = ("custom VJP function must return a tuple of length equal to the "
               "number of positional arguments to the function being "
               "differentiated: expected {}, got {}")
        raise TypeError(msg.format(len(params['in_trees']), len(args_cts)))
      return ((), {},) + args_cts
    vjp, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(vjp_pytree_), (out_tree,))
    return out, vjp.call_wrapped
  ad.defvjp_argnums(fun.prim, custom_transforms_vjp)

def defvjp(fun, *vjprules):
  """Define VJP rules for each argument separately.

  This function is a convenience wrapper around ``jax.defvjp_all`` for
  separately defining VJP rules for each of the function's arguments. This
  convenience wrapper does not provide a mechanism for depending on anything
  other than the function arguments and its primal output value, though
  depending on intermediate results is possible using ``jax.defvjp_all``.

  The signature of each component VJP rule is ``lambda g, ans, *primals: ...``
  where ``g`` represents the output cotangent, ``ans`` represents the output
  primal, and ``*primals`` represents all the primal positional arguments.

  Args:
    fun: a custom_transforms function.
    *vjprules: a sequence of functions or Nones specifying the VJP rule for each
      corresponding positional argument. When an element is None, it indicates
      that the Jacobian from the corresponding input to the output is zero.

  Returns:
    None. A side-effect is that ``fun`` is associated with the VJP rule
    specified by ``*vjprules``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x, y):
  ...   return np.sin(x ** 2 + y)
  ...
  >>> print(f(3., 4.))
  0.42016703
  >>> print(jax.grad(f)(3., 4.))
  5.4446807
  >>> print(jax.grad(f, 1)(3., 4.))
  0.9074468
  >>> jax.defvjp(f, None, lambda g, ans, x, y: g + x + y + ans)
  >>> print(jax.grad(f)(3., 4.))
  0.0
  >>> print(jax.grad(f, 1)(3., 4.))
  8.420167
  """
  _check_custom_transforms_type("defvjp", fun)
  def custom_vjp(*primals):
    ans = fun(*primals)
    # TODO(mattjj): avoid instantiating zeros?
    vjpfun = lambda ct: [vjp(ct, ans, *primals) if vjp else ad_util.zeros_like_jaxval(x)
                         for x, vjp in zip(primals, vjprules)]
    return ans, vjpfun
  defvjp_all(fun, custom_vjp)

def custom_gradient(fun):
|
|
|
|
"""Convenience function for defining custom VJP rules (aka custom gradients).
|
|
|
|
|
|
|
|
While the canonical way to define custom VJP rules is via ``jax.defvjp_all``
|
|
|
|
and its convenience wrappers, the ``custom_gradient`` convenience wrapper
|
2019-06-11 06:44:59 -07:00
|
|
|
follows TensorFlow's ``tf.custom_gradient`` API. The difference here is that
|
|
|
|
``custom_gradient`` can be used as a decorator on one function that returns
|
|
|
|
both the primal value (representing the output of the mathematical function to
|
|
|
|
be differentiated) and the VJP (gradient) function.
|
2019-06-05 13:48:04 -07:00
|
|
|
|
|
|
|
See https://www.tensorflow.org/api_docs/python/tf/custom_gradient.
|
|
|
|
|
2019-06-11 06:44:59 -07:00
|
|
|
If the mathematical function to be differentiated has type signature
|
|
|
|
``a -> b``, then the Python callable ``fun`` should have signature
|
|
|
|
``a -> (b, CT b -> CT a)`` where we use ``CT x`` to denote a cotangent type
|
|
|
|
for ``x``. See the example below. That is, ``fun`` should return a pair where
|
|
|
|
the first element represents the value of the mathematical function to be
|
|
|
|
differentiated and the second element is a function that represents the custom
|
|
|
|
VJP rule.
|
2019-06-05 18:02:15 -07:00
|
|
|
|
|
|
|
The custom VJP function returned as the second element of the output of ``fun``
|
|
|
|
can close over intermediate values computed when evaluating the function to be
|
|
|
|
differentiated. That is, use lexical closure to share work between the forward
|
|
|
|
pass and the backward pass of reverse-mode automatic differentiation.
|
2019-06-05 13:48:04 -07:00
|
|
|
|
|
|
|
Args:
|
2019-06-11 06:44:59 -07:00
|
|
|
fun: a Python callable specifying both the mathematical function to be
|
|
|
|
differentiated and its reverse-mode differentiation rule. It should return
|
|
|
|
a pair consisting of an output value and a Python callable that represents
|
|
|
|
the custom gradient function.
|
2019-06-05 13:48:04 -07:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
A Python callable with signature ``a -> b``, i.e. that returns the output
|
|
|
|
value specified by the first element of ``fun``'s output pair. A side effect
|
2019-06-05 19:18:36 -07:00
|
|
|
is that under-the-hood ``jax.defvjp_all`` is called to set up the returned
|
|
|
|
Python callable with the custom VJP rule specified by the second element
|
2019-06-05 13:48:04 -07:00
|
|
|
of ``fun``'s output pair.
|
|
|
|
|
|
|
|
For example:
|
|
|
|
|
|
|
|
>>> @jax.custom_gradient
|
|
|
|
... def f(x):
|
2019-06-05 19:13:33 -07:00
|
|
|
... return x ** 2, lambda g: (g * x,)
|
2019-06-05 13:48:04 -07:00
|
|
|
...
|
|
|
|
>>> print(f(3.))
|
|
|
|
9.0
|
|
|
|
>>> print(jax.grad(f)(3.))
|
|
|
|
3.0
|
2019-06-05 19:13:33 -07:00
|
|
|
|
|
|
|
An example with a function on two arguments, so that the VJP function must
|
|
|
|
return a tuple of length two:
|
|
|
|
|
|
|
|
>>> @jax.custom_gradient
|
|
|
|
... def f(x, y):
|
|
|
|
... return x * y, lambda g: (y, x)
|
|
|
|
...
|
|
|
|
>>> print(f(3., 4.))
|
|
|
|
12.0
|
|
|
|
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
|
|
|
|
(4.0, 3.0)
|
2019-06-05 13:48:04 -07:00
|
|
|
"""
  def primal_fun(*args, **kwargs):
    ans, _ = fun(*args, **kwargs)
    return ans
  primal_fun = custom_transforms(primal_fun)
  defvjp_all(primal_fun, fun)
  return primal_fun

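# `jarrett` wraps `fun` as a custom_transforms function and installs a JVP
# rule that treats `fun` as elementwise: it pushes an elementwise standard
# basis through `jvp(fun, ...)` with `vmap` and contracts the per-leaf
# derivatives with the input tangents. Presumably this is intended for
# functions that act elementwise on their inputs, since the rule is only
# exact in that case.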
def jarrett(fun):
  new_fun = custom_transforms(fun)

  def elementwise_jvp(primals, tangents):
    pushfwd = partial(jvp, fun, primals)
    y, jacs = vmap(pushfwd, out_axes=(None, 0))(_elementwise_std_basis(tangents))
    flat_tangents, _ = tree_flatten(tangents)
    out_tangent = sum([t * jac for t, jac in zip(flat_tangents, jacs)])
    return y, out_tangent
  defjvp_all(new_fun, elementwise_jvp)

  return new_fun


def _elementwise_std_basis(pytree):
  leaves, _ = tree_flatten(pytree)
  arity = len(leaves)
  dims = map(onp.size, leaves)
  # TODO(mattjj): use symbolic constants
  dtype = onp.result_type(*leaves)
  if not onp.issubdtype(dtype, onp.floating):
    msg = ("Jacobian only defined for functions with floating input and output "
           "dtypes (i.e. dtypes that model real numbers), got {}.")
    raise TypeError(msg.format(dtype))  # TODO(mattjj, dougalm): handle complex
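  # Illustrative sketch (values assumed for exposition): for a tangent pytree
  # of two float32 scalars (x, y), arity is 2, dims is [1, 1], and basis_array
  # below is [[1., 0.], [0., 1.]]. _unravel_array_into_pytree then appears to
  # split axis 1 back into the (x, y) structure, so each leaf gains a leading
  # basis axis of length arity; vmapping over that axis (as in jarrett above)
  # evaluates one JVP per input leaf.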
  basis_array = onp.stack([onp.concatenate(
      [onp.ones(dims[j], dtype) if i == j else onp.zeros(dims[j], dtype)
       for j in range(arity)]) for i in range(arity)])
  return _unravel_array_into_pytree(pytree, 1, basis_array)


# This function mostly exists for making slides about JAX.
def _make_graphviz(fun):
  """Adapts `fun` to return a graphviz dot string of its program representation.

  Args:
    fun: The function whose `jaxpr` is to be rendered into graphviz dot. Its
      positional arguments and return value should be arrays, scalars, or
      standard Python containers (tuple/list/dict) thereof.

  Returns:
    A wrapped version of `fun`, set up to return a graphviz dot string.

  See make_jaxpr for a related function.
  """
  # TODO(mattjj): handle eqn.restructure
  # TODO(mattjj): handle subjaxprs

  def pv_like(x):
    aval = xla.abstractify(x)
    return pe.PartialVal((aval, core.unit))

  id_names = ("id{}".format(i) for i in itertools.count())

  def jaxpr_to_graphviz(jaxpr, consts):
    fragment = []

    fragment.extend(map(invar_node, jaxpr.invars, jaxpr.invars))
    fragment.extend(map(freevar_node, jaxpr.freevars, jaxpr.freevars))
    fragment.extend(map(constant_node, jaxpr.constvars, consts))

    for eqn in jaxpr.eqns:
      if eqn.destructure:
        id_name = next(id_names)
        fragment.append(function_node(id_name, eqn.primitive.name))
        fragment.extend(edge(invar, id_name) for invar in eqn.invars)
        fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars)
      else:
        fragment.append(function_node(eqn.outvars[0], eqn.primitive.name))
        fragment.extend(edge(invar, eqn.outvars[0]) for invar in eqn.invars)
    fragment.append(outvar_node(jaxpr.outvar, "out"))
    return graph(''.join(fragment))

  edge = '{} -> {} [color=gray30];\n'.format
  function_node = '{} [label="{}", shape=box, color=lightskyblue, style=filled];\n'.format
  invar_node = '{} [rank=2, label="{}", color=mediumspringgreen, style=filled];\n'.format
  outvar_node = '{} [label="{}", fillcolor=indianred1, style="filled,dashed", color=black];\n'.format
  constant_node = '{} [rank=2, label="{}", color=goldenrod1, style=filled];\n'.format
  freevar_node = '{} [rank=2, label="{}", color=palegreen, style=filled];\n'.format
  graph = 'digraph G {{{}}}'.format
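  # For example, edge('a', 'b') yields 'a -> b [color=gray30];\n', and
  # graph(body) wraps the accumulated fragments as 'digraph G {<body>}'.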

  @wraps(fun)
  def graphviz_maker(*args, **kwargs):
    wrapped = lu.wrap_init(fun, kwargs)
    jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
    pvals = map(pv_like, jax_args)
    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
    return jaxpr_to_graphviz(jaxpr, consts)

  graphviz_maker.__name__ = "make_graphviz({})".format(graphviz_maker.__name__)
  return graphviz_maker


def eval_shape(fun, *args, **kwargs):
  """Compute the shape of ``fun(*args, **kwargs)`` without incurring any FLOPs.

  This utility function is useful for performing shape inference. Its
  input/output behavior is defined by:

    def eval_shape(fun, *args, **kwargs):
      out = fun(*args, **kwargs)
      return jax.tree_util.tree_map(np.shape, out)

  But instead of applying ``fun`` directly, which might be expensive, it uses
  JAX's abstract interpretation machinery to evaluate the shapes without doing
  any FLOPs.

  Using ``eval_shape`` can also catch shape errors, and will raise the same
  shape errors as evaluating ``fun(*args, **kwargs)``.

  Args:
    fun: The function whose output shape should be evaluated.
    *args: a positional argument tuple of arrays, scalars, or (nested) standard
      Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
      those types. Since only the ``shape`` and ``dtype`` attributes are
      accessed, only values that duck-type arrays are required, rather than
      real ndarrays. The duck-typed objects cannot be namedtuples because those
      are treated as standard Python containers. See the example below.
    **kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
      Python containers (pytrees) of those types. As in ``args``, array values
      need only be duck-typed to have ``shape`` and ``dtype`` attributes.

  For example:

  >>> f = lambda A, x: np.tanh(np.dot(A, x))
  >>> class MyArgArray(object):
  ...   def __init__(self, shape, dtype):
  ...     self.shape = shape
  ...     self.dtype = dtype
  ...
  >>> A = MyArgArray((2000, 3000), np.float32)
  >>> x = MyArgArray((3000, 1000), np.float32)
  >>> out_shape = jax.eval_shape(f, A, x)  # no FLOPs performed
  >>> print(out_shape)
  (2000, 1000)
  """
  def abstractify(x):
    if type(x) is core.JaxTuple:
      return core.AbstractTuple(map(abstractify, x))
    else:
      return ShapedArray(onp.shape(x), onp.result_type(x))

  jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
  jax_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
  f, out_tree = pytree_fun_to_jaxtupletree_fun2(lu.wrap_init(fun), kwargs_tree, in_trees)
  abstract_args = map(abstractify, (jax_kwargs,) + tuple(jax_args))
  out = pe.abstract_eval_fun(f.call_wrapped, *abstract_args)
  return tree_map(onp.shape, build_tree(out_tree(), out))