mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
simplify remat partial eval parameterization
The main win here is reducing the number of arguments to the function that parameterizes _remat_partial_eval (so that it can be used with both the remat and invertible-AD features). I also included a fix to _remat_partial_eval that is needed in #3370, though I don't think it's needed on master; it was simply easier to include the fix now. Both of these changes made rebasing #3370 easier!
This commit is contained in:
parent
6bcf0568aa
commit
05e0716f40
@ -22,7 +22,8 @@ from jax import linear_util as lu
|
||||
from . import ad
|
||||
from . import partial_eval as pe
|
||||
from .partial_eval import (PartialVal, partial_eval_jaxpr,
|
||||
JaxprTracer, ConstVar, convert_constvars_jaxpr, new_eqn_recipe)
|
||||
JaxprTracer, ConstVar, convert_constvars_jaxpr,
|
||||
new_eqn_recipe, _partition_knowns)
|
||||
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
|
||||
from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs
|
||||
from ..api_util import flatten_fun_nokwargs
|
||||
@ -43,24 +44,25 @@ invertible_call_p.def_custom_bind(invertible_call)
|
||||
invertible_call_p.def_impl(core.call_impl)
|
||||
invertible_call_p.multiple_results = True
|
||||
|
||||
def _invertible_call_make_output_tracers(trace, typed_jaxpr, in_tracers, out_known_pvals, out_unknown_pvals, _, params):
|
||||
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
|
||||
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
|
||||
# Add dummy arguments representing the outputs to the jaxpr. Those should remain unused in case
|
||||
# the expression actually ends up being evaluated, but they make it well-formed.
|
||||
out_known_avals = tuple(raise_to_shaped(get_aval(pval.get_known())) for pval in out_known_pvals)
|
||||
lifted_jaxpr = _append_invars(lifted_jaxpr, out_known_avals)
|
||||
new_params = dict(params, call_jaxpr=lifted_jaxpr)
|
||||
# We also append some dummy outputs that correspond to the known outputs we left in the call_jaxpr
|
||||
dummy_outputs = [JaxprTracer(trace, pval, core.unit) for pval in out_known_pvals]
|
||||
def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params):
|
||||
uks = [not t.pval.is_known() for t in out_tracers]
|
||||
out_tracers_known, out_tracers_unknown = _partition_knowns(out_tracers, uks)
|
||||
|
||||
output_constants = [JaxprTracer(trace, pval, ConstVar(pval.get_known())) for pval in out_known_pvals]
|
||||
eqn = new_eqn_recipe(tuple(it.chain(in_tracers, output_constants)),
|
||||
dummy_outputs + unknown_output_tracers,
|
||||
invertible_call_p,
|
||||
new_params)
|
||||
for t in unknown_output_tracers: t.recipe = eqn
|
||||
return unknown_output_tracers
|
||||
# Add dummy arguments representing the outputs to the jaxpr. Those should
|
||||
# remain unused if the expression is evaluated, but they make it well-formed.
|
||||
out_known_avals = [raise_to_shaped(t.pval.get_aval()) for t in out_tracers_known]
|
||||
out_consts = [trace.instantiate_const(t) for t in out_tracers_known]
|
||||
new_jaxpr = _append_invars(params['call_jaxpr'], tuple(out_known_avals))
|
||||
new_in_tracers = (*in_tracers, *out_consts)
|
||||
|
||||
# Append dummy outputs that correspond to known outputs left in the call_jaxpr
|
||||
dummy_outputs = [JaxprTracer(trace, t.pval, core.unit) for t in out_tracers_known]
|
||||
new_out_tracers = (*dummy_outputs, *out_tracers_unknown)
|
||||
|
||||
eqn = new_eqn_recipe(new_in_tracers, new_out_tracers, invertible_call_p,
|
||||
dict(params, call_jaxpr=new_jaxpr))
|
||||
for t in out_tracers_unknown: t.recipe = eqn
|
||||
return new_out_tracers
|
||||
|
||||
pe.call_partial_eval_rules[invertible_call_p] = partial(
|
||||
pe._remat_partial_eval, _invertible_call_make_output_tracers)
|
||||
@ -69,7 +71,8 @@ pe.call_partial_eval_rules[invertible_call_p] = partial(
|
||||
@cache()
|
||||
def _append_invars(jaxpr, avals):
|
||||
newvar = core.gensym([jaxpr])
|
||||
return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals), jaxpr.outvars, jaxpr.eqns)
|
||||
return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals),
|
||||
jaxpr.outvars, jaxpr.eqns)
|
||||
|
||||
|
||||
def _invertible_call_transpose(params, call_jaxpr, args, ct, _):
|
||||
|
@ -680,7 +680,12 @@ remat_call_p.def_custom_bind(remat_call)
|
||||
remat_call_p.def_impl(core.call_impl)
|
||||
remat_call_p.multiple_results = True
|
||||
|
||||
def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
|
||||
# We reuse the _remat_partial_eval function both for remat_call and for
|
||||
# invertible_call, both of which in a sense stage out operations to
|
||||
# rematerialize values. The two usages differ only in details of what jaxpr eqn
|
||||
# and output tracers are formed. As a result we parameterize _remat_partial_eval
|
||||
# by a `process_out` function.
|
||||
def _remat_partial_eval(process_out, trace, _, f, tracers, params):
|
||||
concrete = params['concrete']
|
||||
|
||||
# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
|
||||
@ -699,24 +704,27 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
|
||||
|
||||
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
|
||||
in_pvals = [t.pval for t in instantiated_tracers]
|
||||
with core.initial_style_staging():
|
||||
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
|
||||
f, in_pvals, partial(remat_call_p.bind, **params))
|
||||
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
|
||||
f, in_pvals, partial(remat_call_p.bind, **params))
|
||||
|
||||
# Convert consts to inputs, since they may contain Tracer instances.
|
||||
jaxpr = convert_constvars_jaxpr(jaxpr)
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
|
||||
# Since we traced with everything marked as unknown, but we need to know which
|
||||
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
|
||||
|
||||
in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env_tracers]
|
||||
+ [raise_to_shaped(pval.get_aval()) for pval in in_pvals])
|
||||
in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in const_tracers] +
|
||||
[raise_to_shaped(t.pval.get_aval()) for t in env_tracers] +
|
||||
[raise_to_shaped(pval.get_aval()) for pval in in_pvals])
|
||||
out_avals = [raise_to_shaped(abstract_unit if var is unitvar
|
||||
else get_aval(var.val) if type(var) is Literal
|
||||
else pval.get_aval())
|
||||
for var, pval in zip(jaxpr.outvars, eval_out_pvals)]
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
in_unknowns = [not t.is_known() for t in it.chain(env_tracers, tracers)]
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns,
|
||||
instantiate=False,
|
||||
trace_type=trace.master.trace_type)
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, (), in_avals, out_avals)
|
||||
in_unknowns = ([False] * len(consts) +
|
||||
[not t.is_known() for t in it.chain(env_tracers, tracers)])
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
|
||||
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.master.trace_type)
|
||||
out_knowns = [not b for b in out_unknowns]
|
||||
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
|
||||
|
||||
@ -728,56 +736,52 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
|
||||
# values. For the use case of inverse-mode ad in op-by-op ("eager mode")
|
||||
# evaluation, all the primal outputs should be concrete (thus not recomputed).
|
||||
to_compute = [type(pval[0]) is not ConcreteArray
|
||||
for uk, pval in zip(out_unknowns, eval_out_pvals)
|
||||
if not uk]
|
||||
for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
|
||||
num_outputs = len(jaxpr_unknown.out_avals)
|
||||
num_res = len(jaxpr_known.out_avals) - num_outputs
|
||||
jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
|
||||
jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
|
||||
_, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
|
||||
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*in_consts)
|
||||
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
|
||||
out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)
|
||||
|
||||
# Now that we have out_pvals, the rest is similar to JaxprTrace.process_call.
|
||||
# Known outputs should keep propagating as constants
|
||||
assert all(pv.is_known() for pv in out_known_pvals)
|
||||
known_output_tracers = [trace.new_const(pval.get_known()) for pval in out_known_pvals]
|
||||
known_output_tracers = [trace.new_const(pval.get_known())
|
||||
for pval in out_known_pvals]
|
||||
# Unknown outputs get wrapped in tracers with the appropriate recipe
|
||||
unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
|
||||
for out_pval in out_unknown_pvals]
|
||||
out_tracers = _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
|
||||
|
||||
# Unknown outputs get wrapped in tracers with the appropriate recipe, as in JaxprTrace.process_call
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
unknown_output_tracers = wrap_unknown_pvals(
|
||||
trace,
|
||||
typed_jaxpr,
|
||||
tuple(it.chain(const_tracers, env_tracers, instantiated_tracers)),
|
||||
out_known_pvals,
|
||||
out_unknown_pvals,
|
||||
out_unknowns,
|
||||
params)
|
||||
in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
|
||||
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(typed_jaxpr.jaxpr))
|
||||
return process_out(trace, in_tracers, out_tracers, new_params)
|
||||
|
||||
return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
|
||||
def _remat_make_output_tracers(_, in_tracers, out_tracers, params):
|
||||
# dce jaxpr outputs
|
||||
jaxpr = params['call_jaxpr']
|
||||
out_unknowns = [not t.pval.is_known() for t in out_tracers]
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, (), [v.aval for v in jaxpr.invars],
|
||||
[v.aval for v in jaxpr.outvars])
|
||||
new_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
|
||||
new_params = dict(params, call_jaxpr=new_jaxpr)
|
||||
|
||||
def _remat_make_output_tracers(trace, typed_jaxpr, input_tracers, _, out_unknown_pvals, out_unknowns, params):
|
||||
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
|
||||
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True)
|
||||
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
|
||||
new_params = dict(params, call_jaxpr=lifted_jaxpr)
|
||||
eqn = new_eqn_recipe(input_tracers,
|
||||
unknown_output_tracers,
|
||||
remat_call_p,
|
||||
new_params)
|
||||
for t in unknown_output_tracers: t.recipe = eqn
|
||||
return unknown_output_tracers
|
||||
# set up eqn for unknown outputs
|
||||
unknown_out_tracers = [t for t in out_tracers if not t.pval.is_known()]
|
||||
eqn = new_eqn_recipe(in_tracers, unknown_out_tracers, remat_call_p, new_params)
|
||||
for t in unknown_out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
call_partial_eval_rules[remat_call_p] = partial(
|
||||
_remat_partial_eval, _remat_make_output_tracers)
|
||||
|
||||
call_partial_eval_rules[remat_call_p] = partial(_remat_partial_eval, _remat_make_output_tracers)
|
||||
def _partition_knowns(pvals, unknowns: Sequence[bool]):
|
||||
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
|
||||
[e for e, unknown in zip(pvals, unknowns) if unknown])
|
||||
|
||||
def _partition_knowns(l, unknowns):
|
||||
return ([e for e, unknown in zip(l, unknowns) if not unknown],
|
||||
[e for e, unknown in zip(l, unknowns) if unknown])
|
||||
|
||||
def _zip_knowns(kl, ul, unknowns):
|
||||
ul_it = iter(ul)
|
||||
kl_it = iter(kl)
|
||||
return [next(ul_it) if unknown else next(kl_it) for unknown in unknowns]
|
||||
def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
|
||||
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
|
||||
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
|
||||
|
||||
|
||||
def _dce_jaxpr(typed_jaxpr: TypedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> TypedJaxpr:
|
||||
|
@ -1266,6 +1266,156 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans1, np.cos(2.), check_dtypes=False)
|
||||
self.assertAllClose(ans2, np.cos(3.), check_dtypes=False)
|
||||
|
||||
def test_trivial_computations(self):
|
||||
x = jnp.array([1, 2, 3])
|
||||
y = api.jit(lambda x: x)(x)
|
||||
self.assertIs(x, y)
|
||||
|
||||
z1, z2 = api.jit(lambda x: (x, x))(x)
|
||||
self.assertIs(z1, z2)
|
||||
|
||||
x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
|
||||
z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
|
||||
self.assertIs(z1, x2)
|
||||
self.assertIs(z3, x1)
|
||||
self.assertEqual(z2, 1)
|
||||
|
||||
def test_nested_jit_hoisting(self):
|
||||
@api.jit
|
||||
def f(x, y):
|
||||
z = 2 * x
|
||||
return y + z, 3
|
||||
|
||||
@api.jit
|
||||
def g(x):
|
||||
return f(2, x)
|
||||
|
||||
jaxpr_subcomp = xla.jaxpr_subcomp
|
||||
|
||||
jaxprs = []
|
||||
def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
|
||||
jaxprs.append(jaxpr)
|
||||
return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
|
||||
|
||||
try:
|
||||
xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
|
||||
ans = g(3)
|
||||
finally:
|
||||
xla.jaxpr_subcomp = jaxpr_subcomp
|
||||
|
||||
self.assertEqual(ans, (7, 3))
|
||||
self.assertLen(jaxprs, 2)
|
||||
outer_jaxpr, inner_jaxpr = jaxprs
|
||||
|
||||
self.assertLen(outer_jaxpr.eqns, 1)
|
||||
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
|
||||
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
|
||||
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
|
||||
self.assertLen(inner_jaxpr.eqns, 2)
|
||||
self.assertEqual(inner_jaxpr.eqns[0].primitive.name, 'mul')
|
||||
self.assertEqual(inner_jaxpr.eqns[1].primitive.name, 'add')
|
||||
|
||||
def test_primitive_compilation_cache(self):
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
lax.add(1, 2)
|
||||
lax.add(2, 3)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
def test_arange_jit(self):
|
||||
# see https://github.com/google/jax/issues/553
|
||||
def fun(x):
|
||||
r = jnp.arange(x.shape[0])[x]
|
||||
return r
|
||||
|
||||
jit(fun)(jnp.array([0, 1, 2], dtype=jnp.int32)) # doesn't crash
|
||||
|
||||
def helper_save_tracer(self, x):
|
||||
self._saved_tracer = x
|
||||
return x
|
||||
|
||||
def test_escaped_tracers_diffent_top_level_traces(self):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Different traces at same level",
|
||||
re.DOTALL)):
|
||||
api.jit(lambda x: self._saved_tracer)(0.)
|
||||
|
||||
def test_escaped_tracers_cant_lift_sublevels(self):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
|
||||
re.DOTALL)):
|
||||
api.jit(lambda x: x)(self._saved_tracer)
|
||||
|
||||
def test_escaped_tracers_tracer_from_higher_level(self):
|
||||
api.grad(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Tracer from a higher level",
|
||||
re.DOTALL)):
|
||||
api.grad(lambda x: x)(self._saved_tracer)
|
||||
|
||||
def test_escaped_tracers_incompatible_sublevel(self):
|
||||
def func1(x):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
# Use the tracer
|
||||
return x + self._saved_tracer
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile("Encountered an unexpected tracer.*Incompatible sublevel",
|
||||
re.DOTALL)):
|
||||
api.jit(func1)(2.)
|
||||
|
||||
def test_escaped_tracers_cant_lift(self):
|
||||
def func1(x):
|
||||
api.grad(self.helper_save_tracer)(0.)
|
||||
return x + self._saved_tracer
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile("Encountered an unexpected tracer.*Can't lift",
|
||||
re.DOTALL)):
|
||||
api.grad(func1)(2.)
|
||||
|
||||
def test_escaped_tracers_not_among_input_tracers(self):
|
||||
def func1(x):
|
||||
api.grad(self.helper_save_tracer)(x)
|
||||
# Use the tracer
|
||||
return x + self._saved_tracer
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Tracer not among input tracers",
|
||||
re.DOTALL)):
|
||||
api.jit(func1)(2.)
|
||||
|
||||
def test_pmap_static_kwarg_error_message(self):
|
||||
# https://github.com/google/jax/issues/3007
|
||||
def f(a, b):
|
||||
return a + b
|
||||
|
||||
g = jax.pmap(f, static_broadcasted_argnums=(1,))
|
||||
|
||||
msg = (r"pmapped function has static_broadcasted_argnums=\(1,\) but was "
|
||||
r"called with only 1 positional argument. All static broadcasted "
|
||||
r"arguments must be passed positionally.")
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
g(jnp.ones((1, 1)), b=1)
|
||||
|
||||
def test_vmap_unmapped_last(self):
|
||||
@partial(jax.vmap, out_axes=jax.interpreters.batching.last)
|
||||
def f(x):
|
||||
return np.zeros((2,))
|
||||
f(np.zeros((5,)))
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
def test_remat_basic(self):
|
||||
@api.remat
|
||||
def g(x):
|
||||
@ -1622,154 +1772,6 @@ class APITest(jtu.JaxTestCase):
|
||||
vjp(v)
|
||||
|
||||
|
||||
def test_trivial_computations(self):
|
||||
x = jnp.array([1, 2, 3])
|
||||
y = api.jit(lambda x: x)(x)
|
||||
self.assertIs(x, y)
|
||||
|
||||
z1, z2 = api.jit(lambda x: (x, x))(x)
|
||||
self.assertIs(z1, z2)
|
||||
|
||||
x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
|
||||
z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
|
||||
self.assertIs(z1, x2)
|
||||
self.assertIs(z3, x1)
|
||||
self.assertEqual(z2, 1)
|
||||
|
||||
def test_nested_jit_hoisting(self):
|
||||
@api.jit
|
||||
def f(x, y):
|
||||
z = 2 * x
|
||||
return y + z, 3
|
||||
|
||||
@api.jit
|
||||
def g(x):
|
||||
return f(2, x)
|
||||
|
||||
jaxpr_subcomp = xla.jaxpr_subcomp
|
||||
|
||||
jaxprs = []
|
||||
def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
|
||||
jaxprs.append(jaxpr)
|
||||
return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
|
||||
|
||||
try:
|
||||
xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
|
||||
ans = g(3)
|
||||
finally:
|
||||
xla.jaxpr_subcomp = jaxpr_subcomp
|
||||
|
||||
self.assertEqual(ans, (7, 3))
|
||||
self.assertLen(jaxprs, 2)
|
||||
outer_jaxpr, inner_jaxpr = jaxprs
|
||||
|
||||
self.assertLen(outer_jaxpr.eqns, 1)
|
||||
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
|
||||
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
|
||||
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
|
||||
self.assertLen(inner_jaxpr.eqns, 2)
|
||||
self.assertEqual(inner_jaxpr.eqns[0].primitive.name, 'mul')
|
||||
self.assertEqual(inner_jaxpr.eqns[1].primitive.name, 'add')
|
||||
|
||||
def test_primitive_compilation_cache(self):
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
lax.add(1, 2)
|
||||
lax.add(2, 3)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
def test_arange_jit(self):
|
||||
# see https://github.com/google/jax/issues/553
|
||||
def fun(x):
|
||||
r = jnp.arange(x.shape[0])[x]
|
||||
return r
|
||||
|
||||
jit(fun)(jnp.array([0, 1, 2], dtype=jnp.int32)) # doesn't crash
|
||||
|
||||
def helper_save_tracer(self, x):
|
||||
self._saved_tracer = x
|
||||
return x
|
||||
|
||||
def test_escaped_tracers_diffent_top_level_traces(self):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Different traces at same level",
|
||||
re.DOTALL)):
|
||||
api.jit(lambda x: self._saved_tracer)(0.)
|
||||
|
||||
def test_escaped_tracers_cant_lift_sublevels(self):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
|
||||
re.DOTALL)):
|
||||
api.jit(lambda x: x)(self._saved_tracer)
|
||||
|
||||
def test_escaped_tracers_tracer_from_higher_level(self):
|
||||
api.grad(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Tracer from a higher level",
|
||||
re.DOTALL)):
|
||||
api.grad(lambda x: x)(self._saved_tracer)
|
||||
|
||||
def test_escaped_tracers_incompatible_sublevel(self):
|
||||
def func1(x):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
# Use the tracer
|
||||
return x + self._saved_tracer
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile("Encountered an unexpected tracer.*Incompatible sublevel",
|
||||
re.DOTALL)):
|
||||
api.jit(func1)(2.)
|
||||
|
||||
def test_escaped_tracers_cant_lift(self):
|
||||
def func1(x):
|
||||
api.grad(self.helper_save_tracer)(0.)
|
||||
return x + self._saved_tracer
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile("Encountered an unexpected tracer.*Can't lift",
|
||||
re.DOTALL)):
|
||||
api.grad(func1)(2.)
|
||||
|
||||
def test_escaped_tracers_not_among_input_tracers(self):
|
||||
def func1(x):
|
||||
api.grad(self.helper_save_tracer)(x)
|
||||
# Use the tracer
|
||||
return x + self._saved_tracer
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Tracer not among input tracers",
|
||||
re.DOTALL)):
|
||||
api.jit(func1)(2.)
|
||||
|
||||
def test_pmap_static_kwarg_error_message(self):
|
||||
# https://github.com/google/jax/issues/3007
|
||||
def f(a, b):
|
||||
return a + b
|
||||
|
||||
g = jax.pmap(f, static_broadcasted_argnums=(1,))
|
||||
|
||||
msg = (r"pmapped function has static_broadcasted_argnums=\(1,\) but was "
|
||||
r"called with only 1 positional argument. All static broadcasted "
|
||||
r"arguments must be passed positionally.")
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
g(jnp.ones((1, 1)), b=1)
|
||||
|
||||
def test_vmap_unmapped_last(self):
|
||||
@partial(jax.vmap, out_axes=jax.interpreters.batching.last)
|
||||
def f(x):
|
||||
return np.zeros((2,))
|
||||
f(np.zeros((5,)))
|
||||
|
||||
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
def test_scalar_literals(self):
|
||||
@ -2833,6 +2835,7 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
jax.grad(clip_gradient)(1.) # doesn't crash
|
||||
|
||||
|
||||
class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
def test_invertible_basic(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user