mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix grad of jit caching bug
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
parent
7a9f1f3f1c
commit
2867e4be08
@ -26,7 +26,8 @@ from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like
|
||||
from ..abstract_arrays import raise_to_shaped
|
||||
from ..util import unzip2, unzip3, safe_map, safe_zip, partial, split_list
|
||||
from ..tree_util import build_tree, register_pytree_node, tree_map
|
||||
from ..linear_util import thunk, transformation, transformation_with_aux, wrap_init
|
||||
from ..linear_util import (thunk, transformation, transformation_with_aux,
|
||||
wrap_init, hashable_partial)
|
||||
from ..api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from ..tree_util import tree_flatten, tree_unflatten
|
||||
|
||||
@ -451,18 +452,18 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
|
||||
out_flat, tree_def = tree_flatten((primal_out, tangent_out))
|
||||
yield out_flat, tree_def
|
||||
|
||||
|
||||
def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
|
||||
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
|
||||
fun = wrap_init(partial(backward_pass, jaxpr))
|
||||
fun = hashable_partial(wrap_init(backward_pass), jaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = primitive.bind(fun, *all_args, **params)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
|
||||
|
||||
def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
|
||||
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
|
||||
fun = wrap_init(partial(backward_pass, jaxpr))
|
||||
fun = hashable_partial(wrap_init(backward_pass), jaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = primitive.bind(fun, *all_args, **params)
|
||||
freevar_cts, arg_cts = tree_unflatten(out_tree(), out_flat)
|
||||
|
@ -22,6 +22,7 @@ import itertools as it
|
||||
import operator as op
|
||||
import os
|
||||
|
||||
from absl import logging
|
||||
import numpy as onp
|
||||
import six
|
||||
from six.moves import xrange
|
||||
@ -381,8 +382,10 @@ def _xla_call_impl(fun, *args, **params):
|
||||
|
||||
@lu.cache
|
||||
def _xla_callable(fun, device, backend, *abstract_args):
|
||||
if FLAGS.jax_log_compiles:
|
||||
print("Compiling {} for args {}.".format(fun.__name__, abstract_args))
|
||||
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority,
|
||||
"Compiling {} for args {}.".format(fun.__name__, abstract_args))
|
||||
|
||||
pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
|
||||
with core.new_master(pe.JaxprTrace, True) as master:
|
||||
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
|
||||
|
@ -210,3 +210,8 @@ def cache(call):
|
||||
cache[key] = (ans, fun.stores)
|
||||
return ans
|
||||
return memoized_fun
|
||||
|
||||
@transformation
|
||||
def hashable_partial(x, *args):
|
||||
ans = yield (x,) + args, {}
|
||||
yield ans
|
||||
|
@ -22,6 +22,7 @@ import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
from absl import logging
|
||||
from absl.testing import absltest
|
||||
import numpy as onp
|
||||
import six
|
||||
@ -1286,5 +1287,22 @@ class APITest(jtu.JaxTestCase):
|
||||
def test_scalar_literals(self):
|
||||
self.assertLen(api.make_jaxpr(lambda x: x + 2)(42).constvars, 0)
|
||||
|
||||
def test_grad_of_jit_compilation_caching(self):
|
||||
if not hasattr(self, "assertLogs"):
|
||||
raise unittest.SkipTest("test requires assertLogs (python 3)")
|
||||
|
||||
lax.add(1, 2) # make sure some initial warnings are already printed
|
||||
|
||||
sin = api.jit(np.sin)
|
||||
|
||||
with self.assertLogs(level=logging.DEBUG) as l:
|
||||
ans1 = api.grad(sin)(2.)
|
||||
ans2 = api.grad(sin)(3.)
|
||||
self.assertLen(l.output, 2)
|
||||
|
||||
self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
|
||||
self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user