Merge pull request #10775 from pschuh:mlir-caching

PiperOrigin-RevId: 462263487
This commit is contained in:
jax authors 2022-07-20 17:10:40 -07:00
commit be6db2e619
3 changed files with 71 additions and 2 deletions

View File

@ -892,6 +892,11 @@ config.define_bool_state(
)
)
config.define_bool_state(
name='jax_experimental_subjaxpr_lowering_cache',
default=False,
help='Enable using a cache for lowering subjaxprs.')
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""

View File

@ -1818,7 +1818,11 @@ class DynamicJaxprTrace(core.Trace):
in_tracers = [*implicit_tracers, *explicit_tracers]
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
with core.new_sublevel():
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
if config.jax_check_tracer_leaks or not config.jax_experimental_subjaxpr_lowering_cache:
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
else:
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2_memoized(
f, self.main).val
if jaxpr.effects:
raise NotImplementedError('Effects not supported for call primitives.')
if params.get('inline', False):
@ -2117,6 +2121,18 @@ def trace_to_subjaxpr_dynamic2(
return jaxpr, out_type, consts
@lu.cache
def trace_to_subjaxpr_dynamic2_memoized(fun: lu.WrappedFun,
main: core.MainTrace):
return WrapperForWeakRef(trace_to_subjaxpr_dynamic2(fun, main))
class WrapperForWeakRef:
val: Any
def __init__(self, val):
self.val = val
@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
main.jaxpr_stack = main.jaxpr_stack + (frame,)

View File

@ -3779,6 +3779,54 @@ class APITest(jtu.JaxTestCase):
g(1, 2) # doesn't crash
@jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True)
class SubcallTraceCacheTest(jtu.JaxTestCase):
def test_subcall_trace_caching(self):
should_be_tracing_f = False
@api.jit
def f(x):
self.assertTrue(should_be_tracing_f)
return x**2
@api.jit
def g(x):
nonlocal should_be_tracing_f
self.assertTrue(should_be_tracing_g)
should_be_tracing_f = True
y = f(x)
should_be_tracing_f = False
z = f(x + 1)
return y + z
should_be_tracing_g = True
out = g(2)
self.assertEqual(out, 13)
should_be_tracing_g = False
out = g(3)
self.assertEqual(out, 25)
def test_subcall_jaxpr_id(self):
@api.jit
def f(x):
return x**2
def g(x):
y = f(x)
z = f(x + 1)
return y + z
jaxpr = api.make_jaxpr(g)(2)
self.assertIn("call_jaxpr", jaxpr.eqns[0].params)
self.assertIn("call_jaxpr", jaxpr.eqns[2].params)
subjaxpr1 = jaxpr.eqns[0].params["call_jaxpr"]
subjaxpr2 = jaxpr.eqns[2].params["call_jaxpr"]
self.assertIs(subjaxpr1, subjaxpr2)
class RematTest(jtu.JaxTestCase):
@parameterized.named_parameters(
@ -4274,7 +4322,7 @@ class RematTest(jtu.JaxTestCase):
return seq[0]
remat(g)()
remat(g)()
remat(lambda: g())() # lambda defeats caching
with self.assertRaisesRegex(UnexpectedTracerError, "global state"):
api.jit(f)()