mirror of https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

commit 2b0b04fcad
Merge remote-tracking branch 'upstream/master' into jaxpr_pp
@@ -30,7 +30,10 @@ before_install:
- conda update -q conda
install:
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip six protobuf>=3.6.0 absl-py opt_einsum numpy scipy pytest-xdist fastcache
- pip install jaxlib
# The jaxlib version should match the minimum jaxlib version in
# jax/lib/__init__.py. This tests JAX PRs against the oldest permitted
# jaxlib.
- pip install jaxlib==0.1.36
- pip install -v .
# The following are needed to test the Colab notebooks and the documentation building
- if [[ "$JAX_ONLY_DOCUMENTATION" != "" ]]; then
@@ -104,7 +104,6 @@ Operators
scatter
scatter_add
select
shaped_identity
shift_left
shift_right_arithmetic
shift_right_logical
@@ -77,3 +77,9 @@ py_library(
srcs = ["experimental/vectorize.py"],
deps = [":jax"],
)

py_library(
name = "loops",
srcs = ["experimental/loops.py"],
deps = [":jax"],
)
@@ -22,7 +22,7 @@ import six
from . import core
from . import ad_util
from . import dtypes
from .util import prod
from .util import prod, partialmethod

def concretization_err_msg(fun):
@@ -145,6 +145,9 @@ class ShapedArray(UnshapedArray):
def strip_weak_type(self):
return ShapedArray(self.shape, self.dtype) if self.weak_type else self

def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)

class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
@@ -185,6 +188,15 @@ class ConcreteArray(ShapedArray):
def strip_weak_type(self):
return ConcreteArray(self.val) if self.weak_type else self

_bool = _nonzero = partialmethod(_forward_to_value, bool)
_float = partialmethod(_forward_to_value, float)
_int = partialmethod(_forward_to_value, int)
if six.PY2:
_long = partialmethod(_forward_to_value, long)  # noqa: F821
_complex = partialmethod(_forward_to_value, complex)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)

class AbstractToken(core.AbstractValue): pass

abstract_token = AbstractToken()
jax/api.py
@@ -55,7 +55,7 @@ from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
from .lib import xla_bridge as xb
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ShapedArray, raise_to_shaped
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
@@ -1064,13 +1064,11 @@ def jvp(fun, primals, tangents):
or standard Python containers of arrays or scalars. It should return an
array, scalar, or standard Python container of arrays or scalars.
primals: The primal values at which the Jacobian of `fun` should be
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof. The length of the tuple is equal to the number of
positional parameters of `fun`.
evaluated. Should be either a tuple or a list of arguments,
and its length should be equal to the number of positional parameters of `fun`.
tangents: The tangent vector for which the Jacobian-vector product should be
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof, with the same tree structure and array shapes as
`primals`.
evaluated. Should be either a tuple or a list of tangents, with the same
tree structure and array shapes as `primals`.

Returns:
A `(primals_out, tangents_out)` pair, where `primals_out` is
@@ -1089,6 +1087,12 @@ def jvp(fun, primals, tangents):
if not isinstance(fun, lu.WrappedFun):
fun = lu.wrap_init(fun)

if (not isinstance(primals, (tuple, list)) or
not isinstance(tangents, (tuple, list))):
msg = ("primal and tangent arguments to jax.jvp must be tuples or lists; "
"found {} and {}.")
raise TypeError(msg.format(type(primals).__name__, type(tangents).__name__))

ps_flat, tree_def = tree_flatten(primals)
ts_flat, tree_def_2 = tree_flatten(tangents)
if tree_def != tree_def_2:
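A minimal sketch of the calling convention the new check enforces; the function `f` and the values below are made-up examples, not part of the diff:

import jax

def f(x, y):                       # hypothetical example function
    return x * y

primals = (2.0, 3.0)               # one entry per positional argument of f
tangents = (1.0, 0.0)              # same tree structure and shapes as primals
primal_out, tangent_out = jax.jvp(f, primals, tangents)

# Passing a bare scalar instead of a tuple/list now raises TypeError:
# jax.jvp(f, 2.0, (1.0,))  ->  "primal and tangent arguments to jax.jvp must be tuples or lists"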
@@ -1978,3 +1982,14 @@ def eval_shape(fun, *args, **kwargs):
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
return tree_unflatten(out_tree(), out)

def checkpoint(fun, concrete=False):
@wraps(fun)
def fun_remat(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
out_flat = pe.remat_call(flat_fun, *args_flat, concrete=concrete)
return tree_unflatten(out_tree(), out_flat)
return fun_remat
remat = checkpoint
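The new `checkpoint` wrapper (aliased as `remat`) is used as a decorator. A small sketch, assuming it is exposed at the top level as `jax.remat` (the tests below reach it via `jax.api`):

import jax
import jax.numpy as jnp

@jax.remat                         # alias for jax.checkpoint
def g(x):
    return jnp.sin(jnp.sin(x))

# Under grad, the primal values inside g are recomputed (rematerialized) on the
# backward pass instead of being stored, trading extra FLOPs for less memory.
print(jax.grad(g)(2.0))            # cos(sin(2)) * cos(2)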
@@ -597,7 +597,8 @@ def call_bind(primitive, f, *args, **params):

def call_impl(f, *args, **params):
return f.call_wrapped(*args, **params)
del params  # params parameterize the call primitive, not the function
return f.call_wrapped(*args)

call_p = Primitive('call')
@@ -56,6 +56,7 @@ from jax.tree_util import tree_unflatten
### Composable gradient transformations. ###

InitUpdate = collections.namedtuple("InitUpdate", ("init", "update"))
ClipState = collections.namedtuple("ClipState", "")

@@ -77,7 +78,7 @@ def clip(max_delta):
lambda g: jnp.clip_by_value(g, -max_delta, max_delta), updates)
return updates, state

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

ClipByGlobalNormState = collections.namedtuple("ClipByGlobalNormState", "")
@@ -111,7 +112,7 @@ def clip_by_global_norm(max_norm):
lambda t: jnp.where(trigger, t, t * (max_norm / g_norm)), updates)
return updates, state

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

TraceState = collections.namedtuple("TraceState", "trace")
@@ -138,7 +139,7 @@ def trace(decay, nesterov):
tree_multimap(f, updates, update_trace) if nesterov else update_trace)
return updates, TraceState(trace=update_trace)

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

ScaleByRmsState = collections.namedtuple("ScaleByRmsState", "nu")
@@ -172,7 +173,7 @@ def scale_by_rms(decay=0.9, eps=1e-8):
updates = tree_multimap(lambda g, n: g / (jnp.sqrt(n + eps)), updates, nu)
return updates, ScaleByRmsState(nu=nu)

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

ScaleByRStdDevState = collections.namedtuple("ScaleByRStdDevState", "mu nu")
@@ -204,7 +205,7 @@ def scale_by_stddev(decay=0.9, eps=1e-8):
lambda g, m, n: g / jnp.sqrt(n - m**2 + eps), updates, mu, nu)
return updates, ScaleByRStdDevState(mu=mu, nu=nu)

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

ScaleByAdamState = collections.namedtuple("ScaleByAdamState", "count mu nu")
@@ -239,7 +240,7 @@ def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)
return updates, ScaleByAdamState(count=state.count + 1, mu=mu, nu=nu)

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

ScaleState = collections.namedtuple("ScaleState", "")
@@ -262,7 +263,7 @@ def scale(step_size):
updates = tree_multimap(lambda g: step_size * g, updates)
return updates, state

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

ScaleByScheduleState = collections.namedtuple("ScaleByScheduleState", "count")
@@ -286,7 +287,7 @@ def scale_by_schedule(step_size_fn):
updates = tree_multimap(lambda g: step_size_fn(state.count) * g, updates)
return updates, ScaleByScheduleState(count=state.count + 1)

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

AddNoiseState = collections.namedtuple("AddNoiseState", "count rng_key")
@@ -322,7 +323,7 @@ def add_noise(eta, gamma, seed):
lambda g, n: g + variance * n, updates, noise)
return updates, AddNoiseState(count=state.count + 1, rng_key=all_keys[0])

return init_fn, update_fn
return InitUpdate(init_fn, update_fn)

### Utilities for building and using custom optimizers. ###
@@ -345,17 +346,17 @@ def chain(*args):

init_fns, update_fns = zip(*args)

def init(params):
def init_fn(params):
return [fn(params) for fn in init_fns]

def update(updates, state):
def update_fn(updates, state):
new_state = []
for s, fn in zip(state, update_fns):
updates, new_s = fn(updates, s)
new_state.append(new_s)
return updates, new_state

return init, update
return InitUpdate(init_fn, update_fn)

def apply_updates(params, updates):
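The optix change above makes every transform (and `chain`) return an `InitUpdate` namedtuple instead of a bare `(init_fn, update_fn)` pair. A sketch of the resulting usage, mirroring the updated tests near the end of this diff; the parameters, gradients, and learning rate are made up:

import jax.numpy as jnp
from jax.experimental import optix

params = {"w": jnp.ones((3,))}         # hypothetical parameter pytree
grads = {"w": jnp.ones((3,))}          # hypothetical gradient pytree

sgd = optix.sgd(0.1, 0.0)              # returns InitUpdate(init, update)
state = sgd.init(params)
updates, state = sgd.update(grads, state)
params = optix.apply_updates(params, updates)

Since `InitUpdate` is a namedtuple, the old `opt_init, opt_update = optix.sgd(...)` unpacking still works.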
@@ -26,7 +26,8 @@ from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like
from ..abstract_arrays import raise_to_shaped
from ..util import unzip2, unzip3, safe_map, safe_zip, partial, split_list
from ..tree_util import build_tree, register_pytree_node, tree_map
from ..linear_util import thunk, transformation, transformation_with_aux, wrap_init
from ..linear_util import (thunk, transformation, transformation_with_aux,
wrap_init, hashable_partial)
from ..api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten

@@ -139,6 +140,9 @@ def unpair_pval(pval):
return (aval_1, const_1), (aval_2, const_2)

def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
if all(ct is zero for ct in cotangents_in):
return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)

def write_cotangent(v, ct):
# assert v not in primal_env
if ct is not None:
@@ -158,13 +162,51 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
primal_env[v] = val

primal_env = {}
write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)

def is_linear(var):
if type(var) is Literal:
return False
else:
return primal_env.get(var, undefined_primal) is undefined_primal

linear_eqns = []
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxprs:
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
else:
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
assert not any(is_linear(v) for v in const_vars)
if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
linear_eqns.append(eqn)
elif eqn.primitive is not pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)

# we special-case remat_call here because it can be mixed linear /
# nonlinear, so we always evaluate it even if it has a linear part
if eqn.primitive is pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)

ct_env = {}
map(write_cotangent, jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
for eqn in linear_eqns[::-1]:
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
@@ -186,6 +228,56 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
cotangents_out = map(read_cotangent, jaxpr.invars)
return freevar_cts, cotangents_out
def _eval_subjaxpr_primals(prim, jaxpr, consts, freevar_vals, in_vals, params):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = prim.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)

def _eval_primals(jaxpr, consts, freevar_vals, args):
primal_env = {}

def read_primal(v):
if type(v) is Literal:
return v.val
else:
return primal_env.get(v, undefined_primal)

def write_primal(v, val):
if val is not undefined_primal:
primal_env[v] = val

def is_linear(var):
if type(var) is Literal:
return False
else:
return primal_env.get(var, undefined_primal) is undefined_primal

write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxprs:
if not any(is_linear(v) for v in eqn.invars):
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
assert not any(is_linear(v) for v in const_vars)
if (eqn.primitive is pe.remat_call_p or
not any(is_linear(v) for v in it.chain(eqn.invars, bound_vars))):
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
return map(read_primal, jaxpr.outvars)

class UndefinedPrimal(object):
def __repr__(self): return '_'
undefined_primal = UndefinedPrimal()
@@ -451,18 +543,19 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
out_flat, tree_def = tree_flatten((primal_out, tangent_out))
yield out_flat, tree_def

def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
fun = wrap_init(partial(backward_pass, jaxpr))
fun = hashable_partial(wrap_init(backward_pass), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)

primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)

def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
fun = wrap_init(partial(backward_pass, jaxpr))
fun = hashable_partial(wrap_init(backward_pass), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = primitive.bind(fun, *all_args, **params)
freevar_cts, arg_cts = tree_unflatten(out_tree(), out_flat)
@@ -18,13 +18,15 @@ from __future__ import print_function

import itertools as it
from collections import namedtuple, Counter, defaultdict
import contextlib
import threading
from weakref import ref

import numpy as onp

from .. import core
from .. import linear_util as lu
from ..abstract_arrays import ShapedArray, ConcreteArray
from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped
from ..linear_util import thunk, transformation, transformation_with_aux
from ..util import unzip2, safe_zip, safe_map, toposort, partial, split_list
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
@@ -104,6 +106,8 @@ class JaxprTrace(Trace):
return out_tracer

def process_call(self, call_primitive, f, tracers, params):
if call_primitive in call_partial_eval_rules:
return call_partial_eval_rules[call_primitive](self, f, tracers, params)
if call_primitive in map_primitives:
return self.process_map(call_primitive, f, tracers, params)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
@@ -188,7 +192,6 @@ class JaxprTrace(Trace):
return out_tracers
return out, todo

def _mapped_aval(aval):
if aval is core.abstract_unit:
return aval
@@ -207,6 +210,8 @@ def _unmapped_aval(size, aval):
raise TypeError(aval)

map_primitives = set()
custom_partial_eval_rules = {}
call_partial_eval_rules = {}

def partial_eval(f, trace, pvs):
@@ -408,6 +413,14 @@ def closure_convert_jaxpr(jaxpr):
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr

def convert_freevars_jaxpr(jaxpr):
core.skip_checks or core.check_jaxpr(jaxpr)
lifted_jaxpr = Jaxpr(constvars=jaxpr.constvars, freevars=(),
invars=jaxpr.freevars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr

def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

@@ -450,4 +463,148 @@ def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
def _split_aval(unknown, aval):
return (abstract_unit, aval) if unknown else (aval, abstract_unit)

custom_partial_eval_rules = {}

remat_call_p = core.Primitive('remat_call')
remat_call = partial(core.call_bind, remat_call_p)
remat_call_p.def_custom_bind(remat_call)
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True
def _remat_partial_eval(trace, f, tracers, params):
concrete = params['concrete']

# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
# the function being called, not just for the unknown parts. To do that, we
# instantiate all the input tracers as constants in the jaxpr being formed.
# Those tracers might have concrete avals, and doing abstract interpretation
# on concrete avals engenders a tradeoff: it allows data-dependent Python
# control flow to work, but it can in some cases lead to redundant FLOPs (done
# both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
# `concrete` parameter to switch this behavior, and if `concrete` is False
# then we raise the avals to the Shaped level.
instantiated_tracers = map(trace.instantiate_const, tracers)
if not concrete:
instantiated_tracers = [
JaxprTracer(trace, PartialVal((raise_to_shaped(t.pval[0]), unit)), t.recipe)
if type(t.pval[0]) is ConcreteArray else t for t in instantiated_tracers]

# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers)
fun, aux = partial_eval(f, trace, in_pvs)
if concrete:
# TODO(mattjj): remove `remat_context` when confident no accidental FLOPs
with remat_context():
out_flat = remat_call_p.bind(fun, *in_consts, **params)
else:
out_flat = remat_call_p.bind(fun, *in_consts, **params)
out_pvs, jaxpr, env = aux()
env = map(trace.full_raise, env)
out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]

# Since we traced with everything marked as unknown, but we need to know which
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
jaxpr_converted = convert_freevars_jaxpr(jaxpr)
in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
+ [raise_to_shaped(pv) for pv in in_pvs])
out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
for pv, const in zip(out_pvs, out_pval_consts1)]
typed_jaxpr = core.TypedJaxpr(jaxpr_converted, consts, in_avals, out_avals)
in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)

# First, we prune the jaxpr to be staged out not to have too many outputs.
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)

# Next, we need values for the outputs that should be known. Since consts
# weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
# minus the residual outputs that we don't need. When `concrete=True`, as an
# optimization we can avoid redoing *some* redundant FLOPs, namely those that
# produced concrete avals at the output, simply by using those as computed
# values. For the use case of reverse-mode ad in op-by-op ("eager mode")
# evaluation, all the primal outputs should be concrete (thus not recomputed).
to_compute = [not uk and type(pv) is not ConcreteArray
for uk, pv in zip(out_unknowns, out_pvs)]
jaxpr_1_primals = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
_, in_consts = unzip2(t.pval for t in it.chain(env, tracers))
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1_primals)(*in_consts)[:-num_res or None]
out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)

# Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
instantiated_tracers = env + instantiated_tracers
const_tracers = map(trace.new_instantiated_const, consts)
bound_subjaxpr = (typed_jaxpr.jaxpr, const_tracers, ())
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
(bound_subjaxpr,), params)
for t in out_tracers: t.recipe = eqn
return out_tracers
call_partial_eval_rules[remat_call_p] = _remat_partial_eval

def _dce_jaxpr(typed_jaxpr, outputs):
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan, call, or other higher-order primitives.
# TODO(mattjj): better DCE
jaxpr = typed_jaxpr.jaxpr
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
new_outvars, new_out_avals = unzip2(out_pairs)

needed_vars = set(new_outvars)
new_eqns = []
for eqn in jaxpr.eqns[::-1]:
if set(eqn.outvars) & needed_vars:
new_eqns.append(eqn)
needed_vars.update(eqn.invars)
new_eqns = new_eqns[::-1]

new_jaxpr = core.Jaxpr(jaxpr.constvars, jaxpr.freevars, jaxpr.invars,
new_outvars, new_eqns)
return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, typed_jaxpr.in_avals,
new_out_avals)

def _reconstruct_pval(pval1, const2, unknown):
pv1, const1 = pval1
if unknown or pv1 is None:
return pval1
else:
if type(pv1) is ConcreteArray:
return PartialVal((None, pv1.val))
else:
return PartialVal((None, const2))

# TODO(mattjj): for https://github.com/google/jax/pull/1749 we allowed
# standard_abstract_eval to perform concrete evaluation (i.e. FLOPs), but we
# don't think it should happen except for in a remat context
@contextlib.contextmanager
def remat_context():
try:
prev_state = _thread_local_state.remat
_thread_local_state.remat = True
yield
finally:
_thread_local_state.remat = prev_state

class _ThreadLocalState(threading.local):
def __init__(self):
self.remat = False
_thread_local_state = _ThreadLocalState()

def move_binders_to_front(typed_jaxpr, to_move):
assert not typed_jaxpr.jaxpr.constvars and not typed_jaxpr.jaxpr.freevars
assert len(typed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = core.Jaxpr((), (), new_invars, typed_jaxpr.jaxpr.outvars,
typed_jaxpr.jaxpr.eqns)
new_in_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals,
new_in_avals, typed_jaxpr.out_avals)
return new_typed_jaxpr

def _move_to_front(lst, to_move):
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])
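The `concrete` flag discussed in the comments above is what lets `remat` handle data-dependent Python control flow, at the cost of possible extra FLOPs. A small sketch adapted from `test_remat_grad_python_control_flow` later in this diff (simplified to a single return value):

from functools import partial
import jax
from jax import lax

@partial(jax.remat, concrete=True)   # concrete avals let the `if x > 0` below trace
def g(x):
    if x > 0:
        return lax.sin(x)
    else:
        return lax.cos(x)

print(jax.grad(g)(2.0))              # cos(2.0), with sin recomputed on the backward pass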
@@ -22,6 +22,7 @@ import itertools as it
import operator as op
import os

from absl import logging
import numpy as onp
import six
from six.moves import xrange
@@ -381,8 +382,10 @@ def _xla_call_impl(fun, *args, **params):

@lu.cache
def _xla_callable(fun, device, backend, *abstract_args):
if FLAGS.jax_log_compiles:
print("Compiling {} for args {}.".format(fun.__name__, abstract_args))
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
"Compiling {} for args {}.".format(fun.__name__, abstract_args))

pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
with core.new_master(pe.JaxprTrace, True) as master:
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
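With this change the "Compiling ..." message goes through absl logging instead of print: it is emitted at WARNING when `--jax_log_compiles` is set and at DEBUG otherwise. A hedged sketch of surfacing it, mirroring `test_grad_of_jit_compilation_caching` later in this diff:

import jax
from absl import logging

logging.set_verbosity('DEBUG')       # show DEBUG-level messages, including compiles

f = jax.jit(lambda x: x * 2.0)
f(1.0)                               # logs: Compiling <lambda> for args (...)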
@@ -394,7 +397,7 @@ def _xla_callable(fun, device, backend, *abstract_args):
if nreps > xb.device_count(backend):
msg = ("compiling computation that requires {} replicas, but only {} XLA "
"devices are available")
raise ValueError(msg.format(num_replicas, xb.device_count(backend)))
raise ValueError(msg.format(nreps, xb.device_count(backend)))
axis_env = AxisEnv(nreps, [], [])

if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
@@ -759,6 +762,36 @@ pe.custom_partial_eval_rules[device_put_p] = lambda trace, x, **params: x
ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])

def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_nodes,
backend=None, device=None, concrete=None):
# This looks a lot like _xla_call_translation_rule, except for a widget we use
# to foil CSE.
del device, concrete  # Unused.
subc = xb.make_computation_builder("remat_call_subcomputation")
consts = [subc.ParameterWithShape(c.GetShape(n)) for n in const_nodes]
freevars = [subc.ParameterWithShape(c.GetShape(n)) for n in freevar_nodes]
args = [subc.ParameterWithShape(c.GetShape(n)) for n in in_nodes]
args = [_foil_cse(subc, x) for x in args]
out_nodes = jaxpr_subcomp(subc, jaxpr, backend, axis_env, consts, freevars, *args)
subc = subc.Build(subc.Tuple(*out_nodes))
return c.Call(subc, list(const_nodes) + list(freevar_nodes) + list(in_nodes))
call_translations[pe.remat_call_p] = _remat_translation_rule

def _foil_cse(c, x):
xla_shape = c.GetShape(x)
if xla_shape.is_tuple():
assert not xla_shape.tuple_shapes()
return x
else:
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
c.Constant(onp.array(1, dtype=onp.float32)),
[])
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
return c.Select(pred, x, zero)

### lazy constants

class DeviceConstant(DeviceArray):
@@ -1035,9 +1035,6 @@ def sort_key_val(keys, values, dimension=-1):
def tie_in(x, y):
return tie_in_p.bind(x, y)

def shaped_identity(x):
return shaped_identity_p.bind(x, shape=x.shape)

def full(shape, fill_value, dtype=None):
"""Returns an array of `shape` filled with `fill_value`.
@@ -1472,7 +1469,7 @@ _complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
xla.translations[prim] = translation_rule or partial(standard_translate, name)
return prim

@@ -1480,17 +1477,21 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None):
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
return prim

def standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs):
def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
assert all(isinstance(arg, UnshapedArray) for arg in args), args
least_specialized = _max(
map(type, args), key=operator.attrgetter('array_abstraction_level'))
if least_specialized is ConcreteArray:
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
msg = ("If you see this error, please let us know by opening an issue at\n"
"https://github.com/google/jax/issues \n"
"since we thought this was unreachable!")
assert pe._thread_local_state.remat, msg
return ConcreteArray(prim.impl(*[x.val for x in args], **kwargs))
elif least_specialized is ShapedArray:
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
elif least_specialized is UnshapedArray:
@@ -3933,7 +3934,7 @@ batching.primitive_batchers[sort_p] = _sort_batch_rule

def _sort_key_val_abstract_eval(keys, values, dimension):
return keys, values
return raise_to_shaped(keys), raise_to_shaped(values)

def _sort_key_val_jvp(primals, tangents, dimension):
# NOTE(mattjj): this re-sorts three times, but if we had a variadic
@@ -3991,7 +3992,7 @@ def _sort_key_val_batch_rule(batched_args, batch_dims, dimension):
new_dimension = dimension + (keys_bdim <= dimension)
return sort_key_val(keys, new_values, new_dimension), (keys_bdim, keys_bdim)
else:
raise Exception  # unreachable
assert False  # unreachable

sort_key_val_p = Primitive('sort_key_val')
sort_key_val_p.multiple_results = True
@@ -4013,21 +4014,13 @@ def _tie_in_batch_rule(batched_args, batch_dims):

tie_in_p = Primitive('tie_in')
tie_in_p.def_impl(lambda x, y: y)
tie_in_p.def_abstract_eval(lambda x, y: y)
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
xla.translations[tie_in_p] = lambda c, x, y: y
ad.deflinear(tie_in_p, _tie_in_transpose_rule)
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
masking.shape_rules[tie_in_p] = lambda shape_exprs: shape_exprs[1]
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]

shaped_identity_p = Primitive('shape_id')
shaped_identity_p.def_impl(lambda x, shape: x)
shaped_identity_p.def_abstract_eval(lambda x, shape: x)
xla.translations[shaped_identity_p] = lambda c, x, shape: x
ad.deflinear(shaped_identity_p, lambda t, shape: [shaped_identity(t)])
batching.primitive_batchers[shaped_identity_p] = \
lambda a, d, shape: (shaped_identity(a[0]), d[0])

### constants
@@ -675,7 +675,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
_, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys])
intensive_residuals = [const for pv, const in res_pvals if pv is None]
move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals]
jaxpr_2_opt = _move_binders_to_front(jaxpr_2, move)
jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
num_consts_2 = num_consts + len(intensive_residuals)

in_consts = (list(consts_1) + [core.unit] * num_consts +
@@ -712,21 +712,6 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
for t in out_tracers: t.recipe = eqn
return out_tracers

def _move_binders_to_front(typed_jaxpr, to_move):
assert not typed_jaxpr.jaxpr.constvars and not typed_jaxpr.jaxpr.freevars
assert len(typed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = core.Jaxpr((), (), new_invars, typed_jaxpr.jaxpr.outvars,
typed_jaxpr.jaxpr.eqns)
new_in_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals,
new_in_avals, typed_jaxpr.out_avals)
return new_typed_jaxpr

def _move_to_front(lst, to_move):
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])

def _promote_aval_rank(sz, aval):
if aval is core.abstract_unit:
return core.abstract_unit
@@ -1084,7 +1069,7 @@ def custom_root(f, initial_guess, solve, tangent_solve):

def _root_abstract_eval(*args, **kwargs):
return args[sum(kwargs['const_lengths']):]
return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])

def _root_impl(*args, **kwargs):
@@ -1253,7 +1238,7 @@ def custom_linear_solve(

def _linear_solve_abstract_eval(*args, **kwargs):
return args[sum(kwargs['const_lengths']):]
return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])

def _custom_linear_solve_impl(*args, **kwargs):
@@ -17,7 +17,7 @@

import jaxlib

_minimum_jaxlib_version = (0, 1, 31)
_minimum_jaxlib_version = (0, 1, 36)
try:
from jaxlib import version as jaxlib_version
except:
@@ -148,8 +148,8 @@ class WrappedFun(object):
gen = gen(*(gen_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = None

del gen
ans = self.f(*args, **dict(self.params, **kwargs))
del args
while stack:
@@ -210,3 +210,8 @@ def cache(call):
cache[key] = (ans, fun.stores)
return ans
return memoized_fun

@transformation
def hashable_partial(x, *args):
ans = yield (x,) + args, {}
yield ans
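`hashable_partial` is added so that a function closed over a jaxpr stays hashable and comparable, letting `lu.cache`-style memoization hit; a fresh `functools.partial` wrapper would defeat that. A minimal plain-Python sketch of the problem it avoids, not the `jax.linear_util` machinery itself:

import functools

def backward(jaxpr, *args):        # stand-in for ad.backward_pass
    return args

# Two structurally identical partials are distinct objects and compare unequal,
# so a cache keyed on the wrapped function would miss on every call:
p1 = functools.partial(backward, "same_jaxpr")
p2 = functools.partial(backward, "same_jaxpr")
assert p1 != p2

# hashable_partial instead records the closed-over arguments on the WrappedFun's
# transformation list, so two wrappings of the same jaxpr compare equal.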
@@ -256,6 +256,8 @@ def update_numpydoc(docstr, fun, op):

#Some numpy functions have an extra tab at the beginning of each line,
#If this function is one of those we remove this extra tab from all the lines
if not hasattr(op, '__code__'):
return docstr
if docstr[:4] == '    ':
lines = docstr.split('\n')
for idx, line in enumerate(lines):
@@ -289,6 +291,8 @@ def _wraps(fun, update_doc=True, lax_description=""):
If False, include the numpy docstring verbatim.
"""
def wrap(op):
if not hasattr(fun, '__doc__') or fun.__doc__ is None:
return op
try:
# Numpy doc comments have the form:
# fn(x, y, z) (optional)
@@ -993,7 +997,7 @@ def broadcast_arrays(*args):

def broadcast_to(arr, shape):
"""Like Numpy's broadcast_to but doesn't necessarily return views."""
arr = arr if isinstance(arr, ndarray) or isscalar(arr) else array(arr)
arr = arr if isinstance(arr, ndarray) else array(arr)
shape = tuple(map(int, shape))  # check that shape is concrete
arr_shape = _shape(arr)
if arr_shape == shape:
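With the `broadcast_to` tweak above, scalar inputs are converted to JAX arrays instead of being passed through, so a scalar broadcast returns a `jax.numpy` ndarray (see `testBroadcastToOnScalar` below). A tiny illustration with made-up values:

import jax.numpy as jnp

x = jnp.broadcast_to(10.0, ())     # now a jnp.ndarray, not a Python float
y = jnp.broadcast_to(1, (3, 2))    # broadcasting a scalar to a full shape still works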
@@ -2157,38 +2161,35 @@ def vdot(a, b, precision=None):
@_wraps(onp.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, precision=None):
_check_arraylike("tensordot", a, b)
if not (ndim(a) >= 1 and ndim(b) >= 1):
a_ndim = ndim(a)
b_ndim = ndim(b)
if a_ndim < 1 or b_ndim < 1:
msg = "tensordot requires a.ndim and b.dim to be at least 1, got {} and {}."
raise TypeError(msg.format(ndim(a), ndim(b)))

a, b = _promote_dtypes(a, b)
if type(axes) is int:
if axes == 0:
a, b = _promote_dtypes(a, b)
return lax.mul(lax.reshape(a, shape(a) + (1,) * ndim(b)),
lax.reshape(b, (1,) * ndim(a) + shape(b)))
else:
a, b = _promote_dtypes(a, b)
a_reshape = lax.reshape(a, (_prod(a.shape[:-axes]), _prod(a.shape[-axes:])))
b_reshape = lax.reshape(b, (_prod(b.shape[:axes]), _prod(b.shape[axes:])))
out_reshape = lax.dot(a_reshape, b_reshape, precision=precision)
return lax.reshape(out_reshape, a.shape[:-axes] + b.shape[axes:])
if axes > _min(a_ndim, b_ndim):
msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
raise msg.format(axes, a.shape, b.shape)
contracting_dims = tuple(range(a_ndim - axes, a_ndim)), tuple(range(axes))
elif type(axes) in (list, tuple) and len(axes) == 2:
ax1, ax2 = axes
if type(ax1) == type(ax2) == int:
a_transposed = moveaxis(a, ax1, -1) if ax1 != a.ndim - 1 else a
b_transposed = moveaxis(b, ax2, 0) if ax2 != 0 else b
return tensordot(a_transposed, b_transposed, 1, precision)
contracting_dims = ((_canonicalize_axis(ax1, a_ndim),),
(_canonicalize_axis(ax2, b_ndim),))
elif type(ax1) in (list, tuple) and type(ax2) in (list, tuple):
if len(ax1) != len(ax2):
msg = "tensordot requires axes lists to have equal length, got {} and {}."
raise TypeError(msg.format(ax1, ax2))
num_axes = len(ax1)
a_transposed = moveaxis(a, ax1, tuple(range(a.ndim - num_axes, a.ndim)))
b_transposed = moveaxis(b, ax2, tuple(range(num_axes)))
return tensordot(a_transposed, b_transposed, num_axes, precision)
msg = ("tensordot axes argument must be an int, a pair of ints, or a pair of "
"lists/tuples of ints.")
raise TypeError(msg)
contracting_dims = (tuple(_canonicalize_axis(i, a_ndim) for i in ax1),
tuple(_canonicalize_axis(i, b_ndim) for i in ax2))
else:
msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
"of lists/tuples of ints.")
raise TypeError(msg)
return lax.dot_general(a, b, (contracting_dims, ((), ())),
precision=precision)
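The rewrite above maps every accepted `axes` form onto a single `lax.dot_general` call instead of the old moveaxis/reshape plus `lax.dot` path. A hedged sketch of the three forms, with made-up shapes:

import jax.numpy as jnp

a = jnp.ones((2, 3, 4))
b = jnp.ones((4, 3, 5))

# int axes: contract the last `axes` dims of a with the first `axes` dims of b
jnp.tensordot(a, b, axes=1)                   # shape (2, 3, 3, 5)

# pair of ints: one contracting dimension from each operand
jnp.tensordot(a, b, axes=(1, 1))              # contract a's dim 1 with b's dim 1 -> (2, 4, 4, 5)

# pair of sequences: several contracting dimensions at once
jnp.tensordot(a, b, axes=([1, 2], [1, 0]))    # -> (2, 5), matching onp.tensordot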
@_wraps(onp.einsum, lax_description=_PRECISION_DOC)
@@ -2298,24 +2299,15 @@ def _einsum(operands, contractions, precision):
batch_names = ''.join(lhs_names[i] for i in range(len(lhs_names))
if i in batch_dims)

if contracted_names:
# contract using lax.dot_general
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
operand = _dot_general(lhs, rhs, lhs_cont, rhs_cont, len(batch_dims),
precision)
deleted_names = batch_names + ''.join(contracted_names)
names = (batch_names + removechars(lhs_names, deleted_names)
+ removechars(rhs_names, deleted_names))
else:
# no contraction, just a tensor product
nbatch = len(batch_names)
assert lhs.shape[:nbatch] == rhs.shape[:nbatch]
names = batch_names + lhs_names[nbatch:] + rhs_names[nbatch:]
lhs_shape = lhs.shape + (1,) * (rhs.ndim - nbatch)
rhs_shape = rhs.shape[:nbatch] + (1,) * (lhs.ndim - nbatch) + rhs.shape[nbatch:]
operand = lax.reshape(lhs, lhs_shape) * lax.reshape(rhs, rhs_shape)

# contract using lax.dot_general
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
bdims = tuple(range(len(batch_dims)))
dimension_numbers = [(lhs_cont, rhs_cont), (bdims, bdims)]
operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
deleted_names = batch_names + ''.join(contracted_names)
names = (batch_names + removechars(lhs_names, deleted_names)
+ removechars(rhs_names, deleted_names))
else:
raise NotImplementedError  # if this is actually reachable, open an issue!

@@ -2331,50 +2323,6 @@ def _einsum(operands, contractions, precision):
return operands[0]

def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch, precision):
"""Helper for einsum contractions."""
# lax.dot_general has some tight constraints on dimension_numbers that this
# wrapper loosens via transposes and reshapes
assert len(lhs_cont) == len(rhs_cont) > 0
ncont = len(lhs_cont)
lhs_ntensor = lhs.ndim - nbatch - ncont
rhs_ntensor = rhs.ndim - nbatch - ncont
batch_dims = tuple(range(nbatch))

if ncont == 1 and 0 <= lhs_ntensor <= 1 and 0 <= rhs_ntensor <= 1:
dimension_numbers = [(lhs_cont, rhs_cont), (batch_dims, batch_dims)]
return lax.dot_general(lhs, rhs, dimension_numbers, precision)
else:
# move contracting dimensions to the end. lax.dot_general only allows one
# contracting dimension, so if there's more than one we collapse them.
if ncont > 1:
lhs_cdims = tuple(range(lhs.ndim - ncont, lhs.ndim))
lhs = moveaxis(lhs, lhs_cont, lhs_cdims)
lhs = lhs.reshape(lhs.shape[:-ncont] + (-1,))

rhs_cdims = tuple(range(rhs.ndim - ncont, rhs.ndim))
rhs = moveaxis(rhs, rhs_cont, rhs_cdims)
rhs = rhs.reshape(rhs.shape[:-ncont] + (-1,))
else:
lhs = moveaxis(lhs, lhs_cont[0], -1)
rhs = moveaxis(rhs, rhs_cont[0], -1)

# lax.dot_general only allows zero or one tensor product dims per operand,
# so if there's more than one we collapse them.
result_shape = lhs.shape[:nbatch] + lhs.shape[nbatch:-1] + rhs.shape[nbatch:-1]

if lhs_ntensor > 1:
lhs = lhs.reshape(lhs.shape[:nbatch] + (-1,) + lhs.shape[-1:])

if rhs_ntensor > 1:
rhs = rhs.reshape(rhs.shape[:nbatch] + (-1,) + rhs.shape[-1:])

lhs_cont, rhs_cont = [lhs.ndim - 1], [rhs.ndim - 1]
dimension_numbers = [(lhs_cont, rhs_cont), (batch_dims, batch_dims)]
result = lax.dot_general(lhs, rhs, dimension_numbers, precision)
return lax.reshape(result, result_shape)

def _movechars(s, src, dst):
"""Helper for einsum string munging, like moveaxis on identifier strings."""
chars = [c for i, c in enumerate(s) if i not in src]
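After this change `_einsum` hands contractions straight to `lax.dot_general`, whose `dimension_numbers` are `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`, so the `_dot_general` reshape/moveaxis helper above could be deleted. A hedged sketch of the correspondence for a batched matmul-style einsum, with made-up shapes:

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((8, 3, 4))   # 'bij'
rhs = jnp.ones((8, 4, 5))   # 'bjk'

via_einsum = jnp.einsum('bij,bjk->bik', lhs, rhs)

# The same contraction spelled out as dot_general dimension_numbers:
dimension_numbers = (((2,), (1,)),   # contract lhs dim 2 ('j') with rhs dim 1 ('j')
                     ((0,), (0,)))   # batch over dim 0 ('b') of both operands
via_dot_general = lax.dot_general(lhs, rhs, dimension_numbers)

assert via_einsum.shape == via_dot_general.shape == (8, 3, 5)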
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.1.52"
__version__ = "0.1.53"
@@ -29,6 +29,7 @@ cc_library(
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/base",
"@pybind11",
@@ -161,6 +162,7 @@ cuda_library(
":gpu_kernel_helpers",
":kernel_helpers",
"@local_config_cuda//cuda:cuda_headers",
"@pybind11",
],
)
@@ -22,6 +22,7 @@ import unittest
import warnings
import weakref

from absl import logging
from absl.testing import absltest
import numpy as onp
import six
@@ -494,11 +495,29 @@ class APITest(jtu.JaxTestCase):
("primal and tangent arguments to jax.jvp must have the same tree "
"structure"),
lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), ()))
# If primals and tangents must both be tuples or both lists
self.assertRaisesRegex(
TypeError,
("primal and tangent arguments to jax.jvp must have the same tree "
"structure"),
lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), [onp.float32(2)]))
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must have equal types",
lambda: api.jvp(lambda x: -x, (onp.float16(2),), (onp.float32(4),)))

def test_jvp_non_tuple_arguments(self):
def f(x, y): return x + y
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple.",
lambda: partial(api.jvp(f, 0., (1.,))))
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray.",
lambda: partial(api.jvp(f, (0.,), onp.array([1., 2.]))))

def test_vjp_mismatched_arguments(self):
_, pullback = api.vjp(lambda x, y: x * y, onp.float32(3), onp.float32(4))
self.assertRaisesRegex(
@@ -1326,5 +1345,214 @@ class JaxprTest(jtu.JaxTestCase):
""")
def test_grad_of_jit_compilation_caching(self):
if not hasattr(self, "assertLogs"):
raise unittest.SkipTest("test requires assertLogs (python 3)")

lax.add(1, 2)  # make sure some initial warnings are already printed

sin = api.jit(np.sin)

prev_level = logging.get_verbosity()
try:
logging.set_verbosity('DEBUG')
with self.assertLogs(level=logging.DEBUG) as l:
ans1 = api.grad(sin)(2.)
ans2 = api.grad(sin)(3.)
finally:
logging.set_verbosity(prev_level)
self.assertLen(l.output, 2)

self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)

def test_remat_basic(self):
@api.remat
def g(x):
return lax.sin(lax.sin(x)), 3.

def f(x):
x, _ = g(x)
return x

ans = f(2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)

ans, f_lin = api.linearize(f, 2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)

ans = f_lin(3.)
expected = onp.cos(onp.sin(2.)) * onp.cos(2.) * 3.
self.assertAllClose(ans, expected, check_dtypes=False)

sin_calls = []
cos_calls = []
sin_impl = lax.sin_p.impl
cos_impl = lax.cos_p.impl
try:
lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
f_lin(3.)
finally:
lax.sin_p.def_impl(sin_impl)
lax.cos_p.def_impl(cos_impl)
self.assertEqual(len(sin_calls), 1)
self.assertEqual(len(cos_calls), 2)

def test_remat_freevars(self):
def f1(x):
y = 2 * np.sin(x)
z = np.cos(x) * np.sin(y)
return z

def f2(x):
y = 2 * np.sin(x)
z = api.remat(lambda x: np.cos(x) * np.sin(y))(x)
return z

ans, f_lin = api.linearize(f2, 2.)
expected, f_lin_expected = api.linearize(f1, 2.)
self.assertAllClose(ans, expected, check_dtypes=False)

ans = f_lin(3.)
expected = f_lin_expected(3.)
self.assertAllClose(ans, expected, check_dtypes=False)

def test_remat_grad_python_control_flow(self):
@partial(api.remat, concrete=True)
def g(x):
if x > 0:
return lax.sin(x), 3.
else:
return lax.cos(x), 4.

def f(x):
x, _ = g(x)
return x

ans = f(2.)
expected = onp.sin(2.)
self.assertAllClose(ans, expected, check_dtypes=False)

ans = api.grad(f)(2.)
expected = onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_jit(self):
@api.remat
def g(x):
return lax.sin(lax.sin(x))

def f_(x):
return g(x)
f = api.jit(f_)

ans = f(2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)

ans = api.grad(f)(2.)
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)

ans = api.jit(api.grad(f_))(2.)
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)

def test_remat_vmap(self):
@api.remat
def g(x):
return lax.sin(lax.sin(x))

x = onp.arange(3.)

ans = api.vmap(g)(x)
expected = onp.sin(onp.sin(x))
self.assertAllClose(ans, expected, check_dtypes=False)

ans = api.jacfwd(g)(x)
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
self.assertAllClose(ans, expected, check_dtypes=False)

ans = api.jacrev(g)(x)
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
self.assertAllClose(ans, expected, check_dtypes=False)

def test_remat_higher_order_autodiff(self):
def f(x):
return lax.cos(lax.sin(x))
g = api.remat(f)

ans = api.grad(api.grad(g))(3.)
expected = api.grad(api.grad(f))(3.)
self.assertAllClose(ans, expected, check_dtypes=False)

def test_remat_scan(self):
to_scan = lambda c, x: (np.sin(c), None)

def f_noremat(x):
y, _ = lax.scan(to_scan, x, onp.arange(3.))
return y

def f_yesremat(x):
y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.))
return y

ans = f_yesremat(4.)
expected = f_noremat(4.)
self.assertAllClose(ans, expected, check_dtypes=False)

ans = api.grad(f_yesremat)(4.)
expected = api.grad(f_noremat)(4.)
self.assertAllClose(ans, expected, check_dtypes=False)

jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))

jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))

def test_remat_no_redundant_flops(self):
# see https://github.com/google/jax/pull/1749#issuecomment-558267584

@api.jit
def g(x):
return f(2., x)

@api.remat
def f(x, y):
return np.sin(x) * y

# We swap out sin_p's impl rule to count how many times it's invoked
called = []
sin_impl = lax.sin_p.impl
try:
lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
api.grad(g)(3.)
finally:
lax.sin_p.def_impl(sin_impl)
num_calls = len(called)
self.assertEqual(num_calls, 1)

def test_remat_binomial_checkpointing(self):
def binom_checkpoint(funs):
if len(funs) == 1:
return funs[0]
else:
f1 = binom_checkpoint(funs[:len(funs)//2])
f2 = binom_checkpoint(funs[len(funs)//2:])
return api.remat(lambda x: f1(f2(x)))

f1 = binom_checkpoint([np.sin, np.sin, np.sin, np.sin])
f2 = lambda x: np.sin(np.sin(np.sin(np.sin(x))))
x = 4.
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)

if __name__ == '__main__':
absltest.main()
@@ -248,7 +248,7 @@ class GeneratedFunTest(jtu.JaxTestCase):
tangents = [tangents[i] for i in dyn_argnums]
fun, vals = partial_argnums(fun, vals, dyn_argnums)
ans1, deriv1 = jvp_fd(fun, vals, tangents)
ans2, deriv2 = jvp(fun, vals, tangents)
ans2, deriv2 = jvp(fun, tuple(vals), tuple(tangents))
check_all_close(ans1, ans2)
check_all_close(deriv1, deriv2)
@@ -2393,6 +2393,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)),
check_dtypes=False)

def testBroadcastToOnScalar(self):
self.assertIsInstance(lnp.broadcast_to(10.0, ()), lnp.ndarray)
self.assertIsInstance(onp.broadcast_to(10.0, ()), onp.ndarray)

def testPrecision(self):

def iter_eqns(jaxpr):
@@ -53,10 +53,10 @@ class OptixTest(absltest.TestCase):

# experimental/optix.py
optix_params = self.init_params
opt_init, opt_update = optix.sgd(LR, 0.0)
state = opt_init(optix_params)
sgd = optix.sgd(LR, 0.0)
state = sgd.init(optix_params)
for _ in range(STEPS):
updates, state = opt_update(self.per_step_updates, state)
updates, state = sgd.update(self.per_step_updates, state)
optix_params = optix.apply_updates(optix_params, updates)

# Check equivalence.
@@ -76,10 +76,10 @@ class OptixTest(absltest.TestCase):

# experimental/optix.py
optix_params = self.init_params
opt_init, opt_update = optix.adam(LR, b1, b2, eps)
state = opt_init(optix_params)
adam = optix.adam(LR, b1, b2, eps)
state = adam.init(optix_params)
for _ in range(STEPS):
updates, state = opt_update(self.per_step_updates, state)
updates, state = adam.update(self.per_step_updates, state)
optix_params = optix.apply_updates(optix_params, updates)

# Check equivalence.
@@ -99,10 +99,10 @@ class OptixTest(absltest.TestCase):

# experimental/optix.py
optix_params = self.init_params
opt_init, opt_update = optix.rmsprop(LR, decay, eps)
state = opt_init(optix_params)
rmsprop = optix.rmsprop(LR, decay, eps)
state = rmsprop.init(optix_params)
for _ in range(STEPS):
updates, state = opt_update(self.per_step_updates, state)
updates, state = rmsprop.update(self.per_step_updates, state)
optix_params = optix.apply_updates(optix_params, updates)

# Check equivalence.