Merge remote-tracking branch 'upstream/master' into jaxpr_pp

George Necula 2019-11-28 08:56:00 +01:00
commit 2b0b04fcad
21 changed files with 652 additions and 167 deletions

View File

@ -30,7 +30,10 @@ before_install:
- conda update -q conda
install:
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip six protobuf>=3.6.0 absl-py opt_einsum numpy scipy pytest-xdist fastcache
- pip install jaxlib
# The jaxlib version should match the minimum jaxlib version in
# jax/lib/__init__.py. This tests JAX PRs against the oldest permitted
# jaxlib.
- pip install jaxlib==0.1.36
- pip install -v .
# The following are needed to test the Colab notebooks and the documentation building
- if [[ "$JAX_ONLY_DOCUMENTATION" != "" ]]; then

View File

@ -104,7 +104,6 @@ Operators
scatter
scatter_add
select
shaped_identity
shift_left
shift_right_arithmetic
shift_right_logical

View File

@ -77,3 +77,9 @@ py_library(
srcs = ["experimental/vectorize.py"],
deps = [":jax"],
)
py_library(
name = "loops",
srcs = ["experimental/loops.py"],
deps = [":jax"],
)

View File

@ -22,7 +22,7 @@ import six
from . import core
from . import ad_util
from . import dtypes
from .util import prod
from .util import prod, partialmethod
def concretization_err_msg(fun):
@ -145,6 +145,9 @@ class ShapedArray(UnshapedArray):
def strip_weak_type(self):
return ShapedArray(self.shape, self.dtype) if self.weak_type else self
def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)
class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
@ -185,6 +188,15 @@ class ConcreteArray(ShapedArray):
def strip_weak_type(self):
return ConcreteArray(self.val) if self.weak_type else self
_bool = _nonzero = partialmethod(_forward_to_value, bool)
_float = partialmethod(_forward_to_value, float)
_int = partialmethod(_forward_to_value, int)
if six.PY2:
_long = partialmethod(_forward_to_value, long) # noqa: F821
_complex = partialmethod(_forward_to_value, complex)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
class AbstractToken(core.AbstractValue): pass
abstract_token = AbstractToken()

View File

@ -55,7 +55,7 @@ from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
from .lib import xla_bridge as xb
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ShapedArray, raise_to_shaped
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
@ -1064,13 +1064,11 @@ def jvp(fun, primals, tangents):
or standard Python containers of arrays or scalars. It should return an
array, scalar, or standard Python container of arrays or scalars.
primals: The primal values at which the Jacobian of `fun` should be
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof. The length of the tuple is equal to the number of
positional parameters of `fun`.
evaluated. Should be either a tuple or a list of arguments,
and its length should equal the number of positional parameters of `fun`.
tangents: The tangent vector for which the Jacobian-vector product should be
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof, with the same tree structure and array shapes as
`primals`.
evaluated. Should be either a tuple or a list of tangents, with the same
tree structure and array shapes as `primals`.
Returns:
A `(primals_out, tangents_out)` pair, where `primals_out` is
@ -1089,6 +1087,12 @@ def jvp(fun, primals, tangents):
if not isinstance(fun, lu.WrappedFun):
fun = lu.wrap_init(fun)
if (not isinstance(primals, (tuple, list)) or
not isinstance(tangents, (tuple, list))):
msg = ("primal and tangent arguments to jax.jvp must be tuples or lists; "
"found {} and {}.")
raise TypeError(msg.format(type(primals).__name__, type(tangents).__name__))
ps_flat, tree_def = tree_flatten(primals)
ts_flat, tree_def_2 = tree_flatten(tangents)
if tree_def != tree_def_2:
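As a brief usage sketch of the stricter argument check introduced above (not part of the diff; imports follow the conventions of the tests added later in this commit), both primals and tangents must be tuples or lists with one entry per positional argument:

from jax import api
import jax.numpy as np

f = lambda x, y: np.sin(x) * y

# primals and tangents are tuples with one entry per positional argument of f
primals_out, tangents_out = api.jvp(f, (1.0, 2.0), (1.0, 0.0))

# passing a bare scalar instead of a tuple or list now raises the TypeError above:
#   api.jvp(f, 1.0, (1.0,))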
@ -1978,3 +1982,14 @@ def eval_shape(fun, *args, **kwargs):
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
return tree_unflatten(out_tree(), out)
def checkpoint(fun, concrete=False):
@wraps(fun)
def fun_remat(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
out_flat = pe.remat_call(flat_fun, *args_flat, concrete=concrete)
return tree_unflatten(out_tree(), out_flat)
return fun_remat
remat = checkpoint
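A short sketch of the new checkpoint/remat wrapper in use (the tests added later in this commit exercise the same pattern):

from jax import api
import jax.numpy as np

@api.remat          # alias of api.checkpoint
def g(x):
  return np.sin(np.sin(x))

# under grad, g's forward pass is recomputed during the backward pass
# instead of storing its intermediate values
api.grad(g)(2.0)    # == cos(sin(2.0)) * cos(2.0)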

View File

@ -597,7 +597,8 @@ def call_bind(primitive, f, *args, **params):
def call_impl(f, *args, **params):
return f.call_wrapped(*args, **params)
del params # params parameterize the call primitive, not the function
return f.call_wrapped(*args)
call_p = Primitive('call')

View File

@ -56,6 +56,7 @@ from jax.tree_util import tree_unflatten
### Composable gradient transformations. ###
InitUpdate = collections.namedtuple("InitUpdate", ("init", "update"))
ClipState = collections.namedtuple("ClipState", "")
@ -77,7 +78,7 @@ def clip(max_delta):
lambda g: jnp.clip(g, -max_delta, max_delta), updates)
return updates, state
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
ClipByGlobalNormState = collections.namedtuple("ClipByGlobalNormState", "")
@ -111,7 +112,7 @@ def clip_by_global_norm(max_norm):
lambda t: jnp.where(trigger, t, t * (max_norm / g_norm)), updates)
return updates, state
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
TraceState = collections.namedtuple("TraceState", "trace")
@ -138,7 +139,7 @@ def trace(decay, nesterov):
tree_multimap(f, updates, update_trace) if nesterov else update_trace)
return updates, TraceState(trace=update_trace)
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
ScaleByRmsState = collections.namedtuple("ScaleByRmsState", "nu")
@ -172,7 +173,7 @@ def scale_by_rms(decay=0.9, eps=1e-8):
updates = tree_multimap(lambda g, n: g / (jnp.sqrt(n + eps)), updates, nu)
return updates, ScaleByRmsState(nu=nu)
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
ScaleByRStdDevState = collections.namedtuple("ScaleByRStdDevState", "mu nu")
@ -204,7 +205,7 @@ def scale_by_stddev(decay=0.9, eps=1e-8):
lambda g, m, n: g / jnp.sqrt(n - m**2 + eps), updates, mu, nu)
return updates, ScaleByRStdDevState(mu=mu, nu=nu)
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
ScaleByAdamState = collections.namedtuple("ScaleByAdamState", "count mu nu")
@ -239,7 +240,7 @@ def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)
return updates, ScaleByAdamState(count=state.count + 1, mu=mu, nu=nu)
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
ScaleState = collections.namedtuple("ScaleState", "")
@ -262,7 +263,7 @@ def scale(step_size):
updates = tree_multimap(lambda g: step_size * g, updates)
return updates, state
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
ScaleByScheduleState = collections.namedtuple("ScaleByScheduleState", "count")
@ -286,7 +287,7 @@ def scale_by_schedule(step_size_fn):
updates = tree_multimap(lambda g: step_size_fn(state.count) * g, updates)
return updates, ScaleByScheduleState(count=state.count + 1)
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
AddNoiseState = collections.namedtuple("AddNoiseState", "count rng_key")
@ -322,7 +323,7 @@ def add_noise(eta, gamma, seed):
lambda g, n: g + variance * n, updates, noise)
return updates, AddNoiseState(count=state.count + 1, rng_key=all_keys[0])
return init_fn, update_fn
return InitUpdate(init_fn, update_fn)
### Utilities for building and using custom optimizers. ###
@ -345,17 +346,17 @@ def chain(*args):
init_fns, update_fns = zip(*args)
def init(params):
def init_fn(params):
return [fn(params) for fn in init_fns]
def update(updates, state):
def update_fn(updates, state):
new_state = []
for s, fn in zip(state, update_fns):
updates, new_s = fn(updates, s)
new_state.append(new_s)
return updates, new_state
return init, update
return InitUpdate(init_fn, update_fn)
def apply_updates(params, updates):
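A quick sketch of the new calling convention: each gradient transformation now returns an InitUpdate namedtuple, so callers access named init/update fields rather than unpacking a pair (mirroring the updated optix tests later in this commit; the params/grads pytrees here are illustrative):

import jax.numpy as np
from jax.experimental import optix

params = {"w": np.ones(3)}
grads = {"w": np.full(3, 0.5)}

sgd = optix.sgd(0.1, 0.0)                     # returns InitUpdate(init, update)
state = sgd.init(params)
updates, state = sgd.update(grads, state)
params = optix.apply_updates(params, updates)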

View File

@ -26,7 +26,8 @@ from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like
from ..abstract_arrays import raise_to_shaped
from ..util import unzip2, unzip3, safe_map, safe_zip, partial, split_list
from ..tree_util import build_tree, register_pytree_node, tree_map
from ..linear_util import thunk, transformation, transformation_with_aux, wrap_init
from ..linear_util import (thunk, transformation, transformation_with_aux,
wrap_init, hashable_partial)
from ..api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten
@ -139,6 +140,9 @@ def unpair_pval(pval):
return (aval_1, const_1), (aval_2, const_2)
def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
if all(ct is zero for ct in cotangents_in):
return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)
def write_cotangent(v, ct):
# assert v not in primal_env
if ct is not None:
@ -158,13 +162,51 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
primal_env[v] = val
primal_env = {}
write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)
def is_linear(var):
if type(var) is Literal:
return False
else:
return primal_env.get(var, undefined_primal) is undefined_primal
linear_eqns = []
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxprs:
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
else:
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
assert not any(is_linear(v) for v in const_vars)
if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
linear_eqns.append(eqn)
elif eqn.primitive is not pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
# we special-case remat_call here because it can be mixed linear /
# nonlinear, so we always evaluate it even if it has a linear part
if eqn.primitive is pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
ct_env = {}
map(write_cotangent, jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
for eqn in linear_eqns[::-1]:
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
@ -186,6 +228,56 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
cotangents_out = map(read_cotangent, jaxpr.invars)
return freevar_cts, cotangents_out
def _eval_subjaxpr_primals(prim, jaxpr, consts, freevar_vals, in_vals, params):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = prim.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
def _eval_primals(jaxpr, consts, freevar_vals, args):
primal_env = {}
def read_primal(v):
if type(v) is Literal:
return v.val
else:
return primal_env.get(v, undefined_primal)
def write_primal(v, val):
if val is not undefined_primal:
primal_env[v] = val
def is_linear(var):
if type(var) is Literal:
return False
else:
return primal_env.get(var, undefined_primal) is undefined_primal
write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxprs:
if not any(is_linear(v) for v in eqn.invars):
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
assert not any(is_linear(v) for v in const_vars)
if (eqn.primitive is pe.remat_call_p or
not any(is_linear(v) for v in it.chain(eqn.invars, bound_vars))):
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
return map(read_primal, jaxpr.outvars)
class UndefinedPrimal(object):
def __repr__(self): return '_'
undefined_primal = UndefinedPrimal()
@ -451,18 +543,19 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
out_flat, tree_def = tree_flatten((primal_out, tangent_out))
yield out_flat, tree_def
def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
fun = wrap_init(partial(backward_pass, jaxpr))
fun = hashable_partial(wrap_init(backward_pass), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)
def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
fun = wrap_init(partial(backward_pass, jaxpr))
fun = hashable_partial(wrap_init(backward_pass), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = primitive.bind(fun, *all_args, **params)
freevar_cts, arg_cts = tree_unflatten(out_tree(), out_flat)

View File

@ -18,13 +18,15 @@ from __future__ import print_function
import itertools as it
from collections import namedtuple, Counter, defaultdict
import contextlib
import threading
from weakref import ref
import numpy as onp
from .. import core
from .. import linear_util as lu
from ..abstract_arrays import ShapedArray, ConcreteArray
from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped
from ..linear_util import thunk, transformation, transformation_with_aux
from ..util import unzip2, safe_zip, safe_map, toposort, partial, split_list
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
@ -104,6 +106,8 @@ class JaxprTrace(Trace):
return out_tracer
def process_call(self, call_primitive, f, tracers, params):
if call_primitive in call_partial_eval_rules:
return call_partial_eval_rules[call_primitive](self, f, tracers, params)
if call_primitive in map_primitives:
return self.process_map(call_primitive, f, tracers, params)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
@ -188,7 +192,6 @@ class JaxprTrace(Trace):
return out_tracers
return out, todo
def _mapped_aval(aval):
if aval is core.abstract_unit:
return aval
@ -207,6 +210,8 @@ def _unmapped_aval(size, aval):
raise TypeError(aval)
map_primitives = set()
custom_partial_eval_rules = {}
call_partial_eval_rules = {}
def partial_eval(f, trace, pvs):
@ -408,6 +413,14 @@ def closure_convert_jaxpr(jaxpr):
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
def convert_freevars_jaxpr(jaxpr):
core.skip_checks or core.check_jaxpr(jaxpr)
lifted_jaxpr = Jaxpr(constvars=jaxpr.constvars, freevars=(),
invars=jaxpr.freevars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
@ -450,4 +463,148 @@ def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
def _split_aval(unknown, aval):
return (abstract_unit, aval) if unknown else (aval, abstract_unit)
custom_partial_eval_rules = {}
remat_call_p = core.Primitive('remat_call')
remat_call = partial(core.call_bind, remat_call_p)
remat_call_p.def_custom_bind(remat_call)
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True
def _remat_partial_eval(trace, f, tracers, params):
concrete = params['concrete']
# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
# the function being called, not just for the unknown parts. To do that, we
# instantiate all the input tracers as constants in the jaxpr being formed.
# Those tracers might have concrete avals, and doing abstract interpretation
# on concrete avals engenders a tradeoff: it allows data-dependent Python
# control flow to work, but it can in some cases lead to redundant FLOPs (done
# both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
# `concrete` parameter to switch this behavior, and if `concrete` is False
# then we raise the avals to the Shaped level.
instantiated_tracers = map(trace.instantiate_const, tracers)
if not concrete:
instantiated_tracers = [
JaxprTracer(trace, PartialVal((raise_to_shaped(t.pval[0]), unit)), t.recipe)
if type(t.pval[0]) is ConcreteArray else t for t in instantiated_tracers]
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers)
fun, aux = partial_eval(f, trace, in_pvs)
if concrete:
# TODO(mattjj): remove `remat_context` when confident no accidental FLOPs
with remat_context():
out_flat = remat_call_p.bind(fun, *in_consts, **params)
else:
out_flat = remat_call_p.bind(fun, *in_consts, **params)
out_pvs, jaxpr, env = aux()
env = map(trace.full_raise, env)
out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]
# We traced with everything marked as unknown, but we need to know which
# outputs are known/unknown, so we use partial_eval_jaxpr to get out_unknowns.
jaxpr_converted = convert_freevars_jaxpr(jaxpr)
in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
+ [raise_to_shaped(pv) for pv in in_pvs])
out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
for pv, const in zip(out_pvs, out_pval_consts1)]
typed_jaxpr = core.TypedJaxpr(jaxpr_converted, consts, in_avals, out_avals)
in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)
# First, we prune the jaxpr to be staged out so that it doesn't carry outputs we don't need.
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)
# Next, we need values for the outputs that should be known. Since consts
# weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
# minus the residual outputs that we don't need. When `concrete=True`, as an
# optimization we can avoid redoing *some* redundant FLOPs, namely those that
# produced concrete avals at the output, simply by using those as computed
# values. For the use case of reverse-mode ad in op-by-op ("eager mode")
# evaluation, all the primal outputs should be concrete (thus not recomputed).
to_compute = [not uk and type(pv) is not ConcreteArray
for uk, pv in zip(out_unknowns, out_pvs)]
jaxpr_1_primals = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
_, in_consts = unzip2(t.pval for t in it.chain(env, tracers))
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1_primals)(*in_consts)[:-num_res or None]
out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
instantiated_tracers = env + instantiated_tracers
const_tracers = map(trace.new_instantiated_const, consts)
bound_subjaxpr = (typed_jaxpr.jaxpr, const_tracers, ())
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
(bound_subjaxpr,), params)
for t in out_tracers: t.recipe = eqn
return out_tracers
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
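As a concrete illustration of the `concrete` switch described in the comment above (a sketch only; the remat tests later in this commit cover the same behavior):

from functools import partial
from jax import api, lax

# concrete=True keeps concrete avals while tracing the remat'd function, so
# data-dependent Python control flow like the `if` below still works, at the
# cost of possibly redundant FLOPs during tracing.
@partial(api.remat, concrete=True)
def g(x):
  return lax.sin(x) if x > 0 else lax.cos(x)

api.grad(g)(2.0)   # cos(2.0)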
def _dce_jaxpr(typed_jaxpr, outputs):
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan, call, or other higher-order primitives.
# TODO(mattjj): better DCE
jaxpr = typed_jaxpr.jaxpr
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
new_outvars, new_out_avals = unzip2(out_pairs)
needed_vars = set(new_outvars)
new_eqns = []
for eqn in jaxpr.eqns[::-1]:
if set(eqn.outvars) & needed_vars:
new_eqns.append(eqn)
needed_vars.update(eqn.invars)
new_eqns = new_eqns[::-1]
new_jaxpr = core.Jaxpr(jaxpr.constvars, jaxpr.freevars, jaxpr.invars,
new_outvars, new_eqns)
return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, typed_jaxpr.in_avals,
new_out_avals)
def _reconstruct_pval(pval1, const2, unknown):
pv1, const1 = pval1
if unknown or pv1 is None:
return pval1
else:
if type(pv1) is ConcreteArray:
return PartialVal((None, pv1.val))
else:
return PartialVal((None, const2))
# TODO(mattjj): for https://github.com/google/jax/pull/1749 we allowed
# standard_abstract_eval to perform concrete evaluation (i.e. FLOPs), but we
# don't think it should happen except in a remat context
@contextlib.contextmanager
def remat_context():
try:
prev_state = _thread_local_state.remat
_thread_local_state.remat = True
yield
finally:
_thread_local_state.remat = prev_state
class _ThreadLocalState(threading.local):
def __init__(self):
self.remat = False
_thread_local_state = _ThreadLocalState()
def move_binders_to_front(typed_jaxpr, to_move):
assert not typed_jaxpr.jaxpr.constvars and not typed_jaxpr.jaxpr.freevars
assert len(typed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = core.Jaxpr((), (), new_invars, typed_jaxpr.jaxpr.outvars,
typed_jaxpr.jaxpr.eqns)
new_in_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals,
new_in_avals, typed_jaxpr.out_avals)
return new_typed_jaxpr
def _move_to_front(lst, to_move):
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])

View File

@ -22,6 +22,7 @@ import itertools as it
import operator as op
import os
from absl import logging
import numpy as onp
import six
from six.moves import xrange
@ -381,8 +382,10 @@ def _xla_call_impl(fun, *args, **params):
@lu.cache
def _xla_callable(fun, device, backend, *abstract_args):
if FLAGS.jax_log_compiles:
print("Compiling {} for args {}.".format(fun.__name__, abstract_args))
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
"Compiling {} for args {}.".format(fun.__name__, abstract_args))
pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
with core.new_master(pe.JaxprTrace, True) as master:
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
@ -394,7 +397,7 @@ def _xla_callable(fun, device, backend, *abstract_args):
if nreps > xb.device_count(backend):
msg = ("compiling computation that requires {} replicas, but only {} XLA "
"devices are available")
raise ValueError(msg.format(num_replicas, xb.device_count(backend)))
raise ValueError(msg.format(nreps, xb.device_count(backend)))
axis_env = AxisEnv(nreps, [], [])
if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
@ -759,6 +762,36 @@ pe.custom_partial_eval_rules[device_put_p] = lambda trace, x, **params: x
ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])
def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_nodes,
backend=None, device=None, concrete=None):
# This looks a lot like _xla_call_translation_rule, except for a widget we use
# to foil CSE.
del device, concrete # Unused.
subc = xb.make_computation_builder("remat_call_subcomputation")
consts = [subc.ParameterWithShape(c.GetShape(n)) for n in const_nodes]
freevars = [subc.ParameterWithShape(c.GetShape(n)) for n in freevar_nodes]
args = [subc.ParameterWithShape(c.GetShape(n)) for n in in_nodes]
args = [_foil_cse(subc, x) for x in args]
out_nodes = jaxpr_subcomp(subc, jaxpr, backend, axis_env, consts, freevars, *args)
subc = subc.Build(subc.Tuple(*out_nodes))
return c.Call(subc, list(const_nodes) + list(freevar_nodes) + list(in_nodes))
call_translations[pe.remat_call_p] = _remat_translation_rule
def _foil_cse(c, x):
xla_shape = c.GetShape(x)
if xla_shape.is_tuple():
assert not xla_shape.tuple_shapes()
return x
else:
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
c.Constant(onp.array(1, dtype=onp.float32)),
[])
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
return c.Select(pred, x, zero)
### lazy constants
class DeviceConstant(DeviceArray):

View File

@ -1035,9 +1035,6 @@ def sort_key_val(keys, values, dimension=-1):
def tie_in(x, y):
return tie_in_p.bind(x, y)
def shaped_identity(x):
return shaped_identity_p.bind(x, shape=x.shape)
def full(shape, fill_value, dtype=None):
"""Returns an array of `shape` filled with `fill_value`.
@ -1472,7 +1469,7 @@ _complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
xla.translations[prim] = translation_rule or partial(standard_translate, name)
return prim
@ -1480,17 +1477,21 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None):
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
return prim
def standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs):
def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
assert all(isinstance(arg, UnshapedArray) for arg in args), args
least_specialized = _max(
map(type, args), key=operator.attrgetter('array_abstraction_level'))
if least_specialized is ConcreteArray:
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
msg = ("If you see this error, please let us know by opening an issue at\n"
"https://github.com/google/jax/issues \n"
"since we thought this was unreachable!")
assert pe._thread_local_state.remat, msg
return ConcreteArray(prim.impl(*[x.val for x in args], **kwargs))
elif least_specialized is ShapedArray:
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
elif least_specialized is UnshapedArray:
@ -3933,7 +3934,7 @@ batching.primitive_batchers[sort_p] = _sort_batch_rule
def _sort_key_val_abstract_eval(keys, values, dimension):
return keys, values
return raise_to_shaped(keys), raise_to_shaped(values)
def _sort_key_val_jvp(primals, tangents, dimension):
# NOTE(mattjj): this re-sorts three times, but if we had a variadic
@ -3991,7 +3992,7 @@ def _sort_key_val_batch_rule(batched_args, batch_dims, dimension):
new_dimension = dimension + (keys_bdim <= dimension)
return sort_key_val(keys, new_values, new_dimension), (keys_bdim, keys_bdim)
else:
raise Exception # unreachable
assert False # unreachable
sort_key_val_p = Primitive('sort_key_val')
sort_key_val_p.multiple_results = True
@ -4013,21 +4014,13 @@ def _tie_in_batch_rule(batched_args, batch_dims):
tie_in_p = Primitive('tie_in')
tie_in_p.def_impl(lambda x, y: y)
tie_in_p.def_abstract_eval(lambda x, y: y)
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
xla.translations[tie_in_p] = lambda c, x, y: y
ad.deflinear(tie_in_p, _tie_in_transpose_rule)
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
masking.shape_rules[tie_in_p] = lambda shape_exprs: shape_exprs[1]
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
shaped_identity_p = Primitive('shape_id')
shaped_identity_p.def_impl(lambda x, shape: x)
shaped_identity_p.def_abstract_eval(lambda x, shape: x)
xla.translations[shaped_identity_p] = lambda c, x, shape: x
ad.deflinear(shaped_identity_p, lambda t, shape: [shaped_identity(t)])
batching.primitive_batchers[shaped_identity_p] = \
lambda a, d, shape: (shaped_identity(a[0]), d[0])
### constants

View File

@ -675,7 +675,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
_, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys])
intensive_residuals = [const for pv, const in res_pvals if pv is None]
move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals]
jaxpr_2_opt = _move_binders_to_front(jaxpr_2, move)
jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
num_consts_2 = num_consts + len(intensive_residuals)
in_consts = (list(consts_1) + [core.unit] * num_consts +
@ -712,21 +712,6 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
for t in out_tracers: t.recipe = eqn
return out_tracers
def _move_binders_to_front(typed_jaxpr, to_move):
assert not typed_jaxpr.jaxpr.constvars and not typed_jaxpr.jaxpr.freevars
assert len(typed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = core.Jaxpr((), (), new_invars, typed_jaxpr.jaxpr.outvars,
typed_jaxpr.jaxpr.eqns)
new_in_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals,
new_in_avals, typed_jaxpr.out_avals)
return new_typed_jaxpr
def _move_to_front(lst, to_move):
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])
def _promote_aval_rank(sz, aval):
if aval is core.abstract_unit:
return core.abstract_unit
@ -1084,7 +1069,7 @@ def custom_root(f, initial_guess, solve, tangent_solve):
def _root_abstract_eval(*args, **kwargs):
return args[sum(kwargs['const_lengths']):]
return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])
def _root_impl(*args, **kwargs):
@ -1253,7 +1238,7 @@ def custom_linear_solve(
def _linear_solve_abstract_eval(*args, **kwargs):
return args[sum(kwargs['const_lengths']):]
return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])
def _custom_linear_solve_impl(*args, **kwargs):

View File

@ -17,7 +17,7 @@
import jaxlib
_minimum_jaxlib_version = (0, 1, 31)
_minimum_jaxlib_version = (0, 1, 36)
try:
from jaxlib import version as jaxlib_version
except:

View File

@ -148,8 +148,8 @@ class WrappedFun(object):
gen = gen(*(gen_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = None
del gen
ans = self.f(*args, **dict(self.params, **kwargs))
del args
while stack:
@ -210,3 +210,8 @@ def cache(call):
cache[key] = (ans, fun.stores)
return ans
return memoized_fun
# Like functools.partial, but the bound positional arguments are carried in the
# WrappedFun's transformation spec, so the wrapped function stays hashable.
@transformation
def hashable_partial(x, *args):
ans = yield (x,) + args, {}
yield ans

View File

@ -256,6 +256,8 @@ def update_numpydoc(docstr, fun, op):
#Some numpy functions have an extra tab at the beginning of each line,
#If this function is one of those we remove this extra tab from all the lines
if not hasattr(op, '__code__'):
return docstr
if docstr[:4] == ' ':
lines = docstr.split('\n')
for idx, line in enumerate(lines):
@ -289,6 +291,8 @@ def _wraps(fun, update_doc=True, lax_description=""):
If False, include the numpy docstring verbatim.
"""
def wrap(op):
if not hasattr(fun, '__doc__') or fun.__doc__ is None:
return op
try:
# Numpy doc comments have the form:
# fn(x, y, z) (optional)
@ -993,7 +997,7 @@ def broadcast_arrays(*args):
def broadcast_to(arr, shape):
"""Like Numpy's broadcast_to but doesn't necessarily return views."""
arr = arr if isinstance(arr, ndarray) or isscalar(arr) else array(arr)
arr = arr if isinstance(arr, ndarray) else array(arr)
shape = tuple(map(int, shape)) # check that shape is concrete
arr_shape = _shape(arr)
if arr_shape == shape:
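A tiny sketch of what the scalar-handling change above enables (checked by the new testBroadcastToOnScalar later in this commit):

import jax.numpy as np

# scalars are now wrapped with array(...) up front, so even a 0-d broadcast
# returns a JAX ndarray rather than the bare Python scalar
x = np.broadcast_to(10.0, ())
assert isinstance(x, np.ndarray)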
@ -2157,38 +2161,35 @@ def vdot(a, b, precision=None):
@_wraps(onp.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, precision=None):
_check_arraylike("tensordot", a, b)
if not (ndim(a) >= 1 and ndim(b) >= 1):
a_ndim = ndim(a)
b_ndim = ndim(b)
if a_ndim < 1 or b_ndim < 1:
msg = "tensordot requires a.ndim and b.dim to be at least 1, got {} and {}."
raise TypeError(msg.format(ndim(a), ndim(b)))
a, b = _promote_dtypes(a, b)
if type(axes) is int:
if axes == 0:
a, b = _promote_dtypes(a, b)
return lax.mul(lax.reshape(a, shape(a) + (1,) * ndim(b)),
lax.reshape(b, (1,) * ndim(a) + shape(b)))
else:
a, b = _promote_dtypes(a, b)
a_reshape = lax.reshape(a, (_prod(a.shape[:-axes]), _prod(a.shape[-axes:])))
b_reshape = lax.reshape(b, (_prod(b.shape[:axes]), _prod(b.shape[axes:])))
out_reshape = lax.dot(a_reshape, b_reshape, precision=precision)
return lax.reshape(out_reshape, a.shape[:-axes] + b.shape[axes:])
if axes > _min(a_ndim, b_ndim):
msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
raise ValueError(msg.format(axes, a.shape, b.shape))
contracting_dims = tuple(range(a_ndim - axes, a_ndim)), tuple(range(axes))
elif type(axes) in (list, tuple) and len(axes) == 2:
ax1, ax2 = axes
if type(ax1) == type(ax2) == int:
a_transposed = moveaxis(a, ax1, -1) if ax1 != a.ndim - 1 else a
b_transposed = moveaxis(b, ax2, 0) if ax2 != 0 else b
return tensordot(a_transposed, b_transposed, 1, precision)
contracting_dims = ((_canonicalize_axis(ax1, a_ndim),),
(_canonicalize_axis(ax2, b_ndim),))
elif type(ax1) in (list, tuple) and type(ax2) in (list, tuple):
if len(ax1) != len(ax2):
msg = "tensordot requires axes lists to have equal length, got {} and {}."
raise TypeError(msg.format(ax1, ax2))
num_axes = len(ax1)
a_transposed = moveaxis(a, ax1, tuple(range(a.ndim - num_axes, a.ndim)))
b_transposed = moveaxis(b, ax2, tuple(range(num_axes)))
return tensordot(a_transposed, b_transposed, num_axes, precision)
msg = ("tensordot axes argument must be an int, a pair of ints, or a pair of "
"lists/tuples of ints.")
raise TypeError(msg)
contracting_dims = (tuple(_canonicalize_axis(i, a_ndim) for i in ax1),
tuple(_canonicalize_axis(i, b_ndim) for i in ax2))
else:
msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
"of lists/tuples of ints.")
raise TypeError(msg)
return lax.dot_general(a, b, (contracting_dims, ((), ())),
precision=precision)
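For reference, a small sketch of the three axes forms the rewritten tensordot accepts (assuming import jax.numpy as np; resulting shapes noted in comments):

import jax.numpy as np

a = np.ones((2, 3, 4))
b = np.ones((3, 4, 5))

np.tensordot(a, b, axes=2)                 # contract a's last 2 dims with b's first 2 -> (2, 5)
np.tensordot(a, b, axes=(1, 0))            # a single pair of axes                     -> (2, 4, 4, 5)
np.tensordot(a, b, axes=([1, 2], [0, 1]))  # matching lists of axes                    -> (2, 5)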
@_wraps(onp.einsum, lax_description=_PRECISION_DOC)
@ -2298,24 +2299,15 @@ def _einsum(operands, contractions, precision):
batch_names = ''.join(lhs_names[i] for i in range(len(lhs_names))
if i in batch_dims)
if contracted_names:
# contract using lax.dot_general
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
operand = _dot_general(lhs, rhs, lhs_cont, rhs_cont, len(batch_dims),
precision)
deleted_names = batch_names + ''.join(contracted_names)
names = (batch_names + removechars(lhs_names, deleted_names)
+ removechars(rhs_names, deleted_names))
else:
# no contraction, just a tensor product
nbatch = len(batch_names)
assert lhs.shape[:nbatch] == rhs.shape[:nbatch]
names = batch_names + lhs_names[nbatch:] + rhs_names[nbatch:]
lhs_shape = lhs.shape + (1,) * (rhs.ndim - nbatch)
rhs_shape = rhs.shape[:nbatch] + (1,) * (lhs.ndim - nbatch) + rhs.shape[nbatch:]
operand = lax.reshape(lhs, lhs_shape) * lax.reshape(rhs, rhs_shape)
# contract using lax.dot_general
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
bdims = tuple(range(len(batch_dims)))
dimension_numbers = [(lhs_cont, rhs_cont), (bdims, bdims)]
operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
deleted_names = batch_names + ''.join(contracted_names)
names = (batch_names + removechars(lhs_names, deleted_names)
+ removechars(rhs_names, deleted_names))
else:
raise NotImplementedError # if this is actually reachable, open an issue!
@ -2331,50 +2323,6 @@ def _einsum(operands, contractions, precision):
return operands[0]
def _dot_general(lhs, rhs, lhs_cont, rhs_cont, nbatch, precision):
"""Helper for einsum contractions."""
# lax.dot_general has some tight constraints on dimension_numbers that this
# wrapper loosens via transposes and reshapes
assert len(lhs_cont) == len(rhs_cont) > 0
ncont = len(lhs_cont)
lhs_ntensor = lhs.ndim - nbatch - ncont
rhs_ntensor = rhs.ndim - nbatch - ncont
batch_dims = tuple(range(nbatch))
if ncont == 1 and 0 <= lhs_ntensor <= 1 and 0 <= rhs_ntensor <= 1:
dimension_numbers = [(lhs_cont, rhs_cont), (batch_dims, batch_dims)]
return lax.dot_general(lhs, rhs, dimension_numbers, precision)
else:
# move contracting dimensions to the end. lax.dot_general only allows one
# contracting dimension, so if there's more than one we collapse them.
if ncont > 1:
lhs_cdims = tuple(range(lhs.ndim - ncont, lhs.ndim))
lhs = moveaxis(lhs, lhs_cont, lhs_cdims)
lhs = lhs.reshape(lhs.shape[:-ncont] + (-1,))
rhs_cdims = tuple(range(rhs.ndim - ncont, rhs.ndim))
rhs = moveaxis(rhs, rhs_cont, rhs_cdims)
rhs = rhs.reshape(rhs.shape[:-ncont] + (-1,))
else:
lhs = moveaxis(lhs, lhs_cont[0], -1)
rhs = moveaxis(rhs, rhs_cont[0], -1)
# lax.dot_general only allows zero or one tensor product dims per operand,
# so if there's more than one we collapse them.
result_shape = lhs.shape[:nbatch] + lhs.shape[nbatch:-1] + rhs.shape[nbatch:-1]
if lhs_ntensor > 1:
lhs = lhs.reshape(lhs.shape[:nbatch] + (-1,) + lhs.shape[-1:])
if rhs_ntensor > 1:
rhs = rhs.reshape(rhs.shape[:nbatch] + (-1,) + rhs.shape[-1:])
lhs_cont, rhs_cont = [lhs.ndim - 1], [rhs.ndim - 1]
dimension_numbers = [(lhs_cont, rhs_cont), (batch_dims, batch_dims)]
result = lax.dot_general(lhs, rhs, dimension_numbers, precision)
return lax.reshape(result, result_shape)
def _movechars(s, src, dst):
"""Helper for einsum string munging, like moveaxis on identifier strings."""
chars = [c for i, c in enumerate(s) if i not in src]

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.1.52"
__version__ = "0.1.53"

View File

@ -29,6 +29,7 @@ cc_library(
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/base",
"@pybind11",
@ -161,6 +162,7 @@ cuda_library(
":gpu_kernel_helpers",
":kernel_helpers",
"@local_config_cuda//cuda:cuda_headers",
"@pybind11",
],
)

View File

@ -22,6 +22,7 @@ import unittest
import warnings
import weakref
from absl import logging
from absl.testing import absltest
import numpy as onp
import six
@ -494,11 +495,29 @@ class APITest(jtu.JaxTestCase):
("primal and tangent arguments to jax.jvp must have the same tree "
"structure"),
lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), ()))
# primals and tangents must both be tuples or both be lists
self.assertRaisesRegex(
TypeError,
("primal and tangent arguments to jax.jvp must have the same tree "
"structure"),
lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), [onp.float32(2)]))
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must have equal types",
lambda: api.jvp(lambda x: -x, (onp.float16(2),), (onp.float32(4),)))
def test_jvp_non_tuple_arguments(self):
def f(x, y): return x + y
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple.",
lambda: partial(api.jvp(f, 0., (1.,))))
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray.",
lambda: partial(api.jvp(f, (0.,), onp.array([1., 2.]))))
def test_vjp_mismatched_arguments(self):
_, pullback = api.vjp(lambda x, y: x * y, onp.float32(3), onp.float32(4))
self.assertRaisesRegex(
@ -1326,5 +1345,214 @@ class JaxprTest(jtu.JaxTestCase):
""")
def test_grad_of_jit_compilation_caching(self):
if not hasattr(self, "assertLogs"):
raise unittest.SkipTest("test requires assertLogs (python 3)")
lax.add(1, 2) # make sure some initial warnings are already printed
sin = api.jit(np.sin)
prev_level = logging.get_verbosity()
try:
logging.set_verbosity('DEBUG')
with self.assertLogs(level=logging.DEBUG) as l:
ans1 = api.grad(sin)(2.)
ans2 = api.grad(sin)(3.)
finally:
logging.set_verbosity(prev_level)
self.assertLen(l.output, 2)
self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)
def test_remat_basic(self):
@api.remat
def g(x):
return lax.sin(lax.sin(x)), 3.
def f(x):
x, _ = g(x)
return x
ans = f(2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)
ans, f_lin = api.linearize(f, 2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = f_lin(3.)
expected = onp.cos(onp.sin(2.)) * onp.cos(2.) * 3.
self.assertAllClose(ans, expected, check_dtypes=False)
sin_calls = []
cos_calls = []
sin_impl = lax.sin_p.impl
cos_impl = lax.cos_p.impl
try:
lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
f_lin(3.)
finally:
lax.sin_p.def_impl(sin_impl)
lax.cos_p.def_impl(cos_impl)
self.assertEqual(len(sin_calls), 1)
self.assertEqual(len(cos_calls), 2)
def test_remat_freevars(self):
def f1(x):
y = 2 * np.sin(x)
z = np.cos(x) * np.sin(y)
return z
def f2(x):
y = 2 * np.sin(x)
z = api.remat(lambda x: np.cos(x) * np.sin(y))(x)
return z
ans, f_lin = api.linearize(f2, 2.)
expected, f_lin_expected = api.linearize(f1, 2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = f_lin(3.)
expected = f_lin_expected(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_grad_python_control_flow(self):
@partial(api.remat, concrete=True)
def g(x):
if x > 0:
return lax.sin(x), 3.
else:
return lax.cos(x), 4.
def f(x):
x, _ = g(x)
return x
ans = f(2.)
expected = onp.sin(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(f)(2.)
expected = onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_jit(self):
@api.remat
def g(x):
return lax.sin(lax.sin(x))
def f_(x):
return g(x)
f = api.jit(f_)
ans = f(2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(f)(2.)
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jit(api.grad(f_))(2.)
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_vmap(self):
@api.remat
def g(x):
return lax.sin(lax.sin(x))
x = onp.arange(3.)
ans = api.vmap(g)(x)
expected = onp.sin(onp.sin(x))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jacfwd(g)(x)
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jacrev(g)(x)
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_higher_order_autodiff(self):
def f(x):
return lax.cos(lax.sin(x))
g = api.remat(f)
ans = api.grad(api.grad(g))(3.)
expected = api.grad(api.grad(f))(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_scan(self):
to_scan = lambda c, x: (np.sin(c), None)
def f_noremat(x):
y, _ = lax.scan(to_scan, x, onp.arange(3.))
return y
def f_yesremat(x):
y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.))
return y
ans = f_yesremat(4.)
expected = f_noremat(4.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(f_yesremat)(4.)
expected = api.grad(f_noremat)(4.)
self.assertAllClose(ans, expected, check_dtypes=False)
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
def test_remat_no_redundant_flops(self):
# see https://github.com/google/jax/pull/1749#issuecomment-558267584
@api.jit
def g(x):
return f(2., x)
@api.remat
def f(x, y):
return np.sin(x) * y
# We swap out sin_p's impl rule to count how many times it's invoked
called = []
sin_impl = lax.sin_p.impl
try:
lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
api.grad(g)(3.)
finally:
lax.sin_p.def_impl(sin_impl)
num_calls = len(called)
self.assertEqual(num_calls, 1)
def test_remat_binomial_checkpointing(self):
def binom_checkpoint(funs):
if len(funs) == 1:
return funs[0]
else:
f1 = binom_checkpoint(funs[:len(funs)//2])
f2 = binom_checkpoint(funs[len(funs)//2:])
return api.remat(lambda x: f1(f2(x)))
f1 = binom_checkpoint([np.sin, np.sin, np.sin, np.sin])
f2 = lambda x: np.sin(np.sin(np.sin(np.sin(x))))
x = 4.
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)
if __name__ == '__main__':
absltest.main()

View File

@ -248,7 +248,7 @@ class GeneratedFunTest(jtu.JaxTestCase):
tangents = [tangents[i] for i in dyn_argnums]
fun, vals = partial_argnums(fun, vals, dyn_argnums)
ans1, deriv1 = jvp_fd(fun, vals, tangents)
ans2, deriv2 = jvp(fun, vals, tangents)
ans2, deriv2 = jvp(fun, tuple(vals), tuple(tangents))
check_all_close(ans1, ans2)
check_all_close(deriv1, deriv2)

View File

@ -2393,6 +2393,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)),
check_dtypes=False)
def testBroadcastToOnScalar(self):
self.assertIsInstance(lnp.broadcast_to(10.0, ()), lnp.ndarray)
self.assertIsInstance(onp.broadcast_to(10.0, ()), onp.ndarray)
def testPrecision(self):
def iter_eqns(jaxpr):

View File

@ -53,10 +53,10 @@ class OptixTest(absltest.TestCase):
# experimental/optix.py
optix_params = self.init_params
opt_init, opt_update = optix.sgd(LR, 0.0)
state = opt_init(optix_params)
sgd = optix.sgd(LR, 0.0)
state = sgd.init(optix_params)
for _ in range(STEPS):
updates, state = opt_update(self.per_step_updates, state)
updates, state = sgd.update(self.per_step_updates, state)
optix_params = optix.apply_updates(optix_params, updates)
# Check equivalence.
@ -76,10 +76,10 @@ class OptixTest(absltest.TestCase):
# experimental/optix.py
optix_params = self.init_params
opt_init, opt_update = optix.adam(LR, b1, b2, eps)
state = opt_init(optix_params)
adam = optix.adam(LR, b1, b2, eps)
state = adam.init(optix_params)
for _ in range(STEPS):
updates, state = opt_update(self.per_step_updates, state)
updates, state = adam.update(self.per_step_updates, state)
optix_params = optix.apply_updates(optix_params, updates)
# Check equivalence.
@ -99,10 +99,10 @@ class OptixTest(absltest.TestCase):
# experimental/optix.py
optix_params = self.init_params
opt_init, opt_update = optix.rmsprop(LR, decay, eps)
state = opt_init(optix_params)
rmsprop = optix.rmsprop(LR, decay, eps)
state = rmsprop.init(optix_params)
for _ in range(STEPS):
updates, state = opt_update(self.per_step_updates, state)
updates, state = rmsprop.update(self.per_step_updates, state)
optix_params = optix.apply_updates(optix_params, updates)
# Check equivalence.