fix grad of jit caching bug

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
Matthew Johnson 2019-11-26 07:56:48 -08:00 committed by Matthew Johnson
parent 7a9f1f3f1c
commit 2867e4be08
4 changed files with 33 additions and 6 deletions

View File

@ -26,7 +26,8 @@ from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like
from ..abstract_arrays import raise_to_shaped
from ..util import unzip2, unzip3, safe_map, safe_zip, partial, split_list
from ..tree_util import build_tree, register_pytree_node, tree_map
from ..linear_util import thunk, transformation, transformation_with_aux, wrap_init
from ..linear_util import (thunk, transformation, transformation_with_aux,
wrap_init, hashable_partial)
from ..api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten
@ -451,18 +452,18 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
out_flat, tree_def = tree_flatten((primal_out, tangent_out))
yield out_flat, tree_def
def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
fun = wrap_init(partial(backward_pass, jaxpr))
fun = hashable_partial(wrap_init(backward_pass), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
fun = wrap_init(partial(backward_pass, jaxpr))
fun = hashable_partial(wrap_init(backward_pass), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = primitive.bind(fun, *all_args, **params)
freevar_cts, arg_cts = tree_unflatten(out_tree(), out_flat)

View File

@ -22,6 +22,7 @@ import itertools as it
import operator as op
import os
from absl import logging
import numpy as onp
import six
from six.moves import xrange
@ -381,8 +382,10 @@ def _xla_call_impl(fun, *args, **params):
@lu.cache
def _xla_callable(fun, device, backend, *abstract_args):
if FLAGS.jax_log_compiles:
print("Compiling {} for args {}.".format(fun.__name__, abstract_args))
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
"Compiling {} for args {}.".format(fun.__name__, abstract_args))
pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
with core.new_master(pe.JaxprTrace, True) as master:
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)

View File

@ -210,3 +210,8 @@ def cache(call):
cache[key] = (ans, fun.stores)
return ans
return memoized_fun
@transformation
def hashable_partial(x, *args):
ans = yield (x,) + args, {}
yield ans

View File

@ -22,6 +22,7 @@ import unittest
import warnings
import weakref
from absl import logging
from absl.testing import absltest
import numpy as onp
import six
@ -1286,5 +1287,22 @@ class APITest(jtu.JaxTestCase):
def test_scalar_literals(self):
self.assertLen(api.make_jaxpr(lambda x: x + 2)(42).constvars, 0)
def test_grad_of_jit_compilation_caching(self):
if not hasattr(self, "assertLogs"):
raise unittest.SkipTest("test requires assertLogs (python 3)")
lax.add(1, 2) # make sure some initial warnings are already printed
sin = api.jit(np.sin)
with self.assertLogs(level=logging.DEBUG) as l:
ans1 = api.grad(sin)(2.)
ans2 = api.grad(sin)(3.)
self.assertLen(l.output, 2)
self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)
if __name__ == '__main__':
absltest.main()