mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #10775 from pschuh:mlir-caching
PiperOrigin-RevId: 462263487
This commit is contained in:
commit
be6db2e619
@ -892,6 +892,11 @@ config.define_bool_state(
|
||||
)
|
||||
)
|
||||
|
||||
config.define_bool_state(
|
||||
name='jax_experimental_subjaxpr_lowering_cache',
|
||||
default=False,
|
||||
help='Enable using a cache for lowering subjaxprs.')
|
||||
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
|
@ -1818,7 +1818,11 @@ class DynamicJaxprTrace(core.Trace):
|
||||
in_tracers = [*implicit_tracers, *explicit_tracers]
|
||||
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
|
||||
with core.new_sublevel():
|
||||
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
|
||||
if config.jax_check_tracer_leaks or not config.jax_experimental_subjaxpr_lowering_cache:
|
||||
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
|
||||
else:
|
||||
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2_memoized(
|
||||
f, self.main).val
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Effects not supported for call primitives.')
|
||||
if params.get('inline', False):
|
||||
@ -2117,6 +2121,18 @@ def trace_to_subjaxpr_dynamic2(
|
||||
return jaxpr, out_type, consts
|
||||
|
||||
|
||||
@lu.cache
|
||||
def trace_to_subjaxpr_dynamic2_memoized(fun: lu.WrappedFun,
|
||||
main: core.MainTrace):
|
||||
return WrapperForWeakRef(trace_to_subjaxpr_dynamic2(fun, main))
|
||||
|
||||
|
||||
class WrapperForWeakRef:
|
||||
val: Any
|
||||
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
@contextlib.contextmanager
|
||||
def extend_jaxpr_stack(main, frame):
|
||||
main.jaxpr_stack = main.jaxpr_stack + (frame,)
|
||||
|
@ -3779,6 +3779,54 @@ class APITest(jtu.JaxTestCase):
|
||||
g(1, 2) # doesn't crash
|
||||
|
||||
|
||||
@jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True)
|
||||
class SubcallTraceCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_subcall_trace_caching(self):
|
||||
should_be_tracing_f = False
|
||||
|
||||
@api.jit
|
||||
def f(x):
|
||||
self.assertTrue(should_be_tracing_f)
|
||||
return x**2
|
||||
|
||||
@api.jit
|
||||
def g(x):
|
||||
nonlocal should_be_tracing_f
|
||||
self.assertTrue(should_be_tracing_g)
|
||||
should_be_tracing_f = True
|
||||
y = f(x)
|
||||
should_be_tracing_f = False
|
||||
z = f(x + 1)
|
||||
return y + z
|
||||
|
||||
should_be_tracing_g = True
|
||||
out = g(2)
|
||||
self.assertEqual(out, 13)
|
||||
|
||||
should_be_tracing_g = False
|
||||
out = g(3)
|
||||
self.assertEqual(out, 25)
|
||||
|
||||
def test_subcall_jaxpr_id(self):
|
||||
|
||||
@api.jit
|
||||
def f(x):
|
||||
return x**2
|
||||
|
||||
def g(x):
|
||||
y = f(x)
|
||||
z = f(x + 1)
|
||||
return y + z
|
||||
|
||||
jaxpr = api.make_jaxpr(g)(2)
|
||||
self.assertIn("call_jaxpr", jaxpr.eqns[0].params)
|
||||
self.assertIn("call_jaxpr", jaxpr.eqns[2].params)
|
||||
subjaxpr1 = jaxpr.eqns[0].params["call_jaxpr"]
|
||||
subjaxpr2 = jaxpr.eqns[2].params["call_jaxpr"]
|
||||
self.assertIs(subjaxpr1, subjaxpr2)
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -4274,7 +4322,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
return seq[0]
|
||||
|
||||
remat(g)()
|
||||
remat(g)()
|
||||
remat(lambda: g())() # lambda defeats caching
|
||||
|
||||
with self.assertRaisesRegex(UnexpectedTracerError, "global state"):
|
||||
api.jit(f)()
|
||||
|
Loading…
x
Reference in New Issue
Block a user