Add dce_rules for pjit primitive so that remat can DCE through the pjit primitive and remove unused residuals

PiperOrigin-RevId: 504123801
This commit is contained in:
Yash Katariya 2023-01-23 17:31:33 -08:00 committed by jax authors
parent 1ee21d121c
commit fb9b5ec1e4
2 changed files with 51 additions and 6 deletions

View File

@ -1705,6 +1705,44 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
ad.reducing_transposes[pjit_p] = _pjit_transpose
@weakref_lru_cache
def _dce_jaxpr_pjit(
jaxpr: core.ClosedJaxpr, used_outputs: Tuple[bool]
) -> Tuple[core.ClosedJaxpr, List[bool]]:
new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, used_outputs)
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts), used_inputs
def dce_jaxpr_pjit_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
eqn.params['jaxpr'], tuple(used_outputs))
def keep_where(xs, keeps):
return tuple(x for x, keep in safe_zip(xs, keeps) if keep)
eqn_params = eqn.params
new_params = dict(
eqn_params,
jaxpr=dced_jaxpr,
in_shardings=keep_where(eqn_params["in_shardings"], used_inputs),
out_shardings=keep_where(eqn_params["out_shardings"], used_outputs),
in_positional_semantics=keep_where(eqn_params["in_positional_semantics"],
used_inputs),
donated_invars=keep_where(eqn_params["donated_invars"], used_inputs),
)
if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects:
return used_inputs, None
else:
new_eqn = core.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info)
return used_inputs, new_eqn
pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule
def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_axis_resources):
pjit_resources = set(
it.chain.from_iterable([d for d in pos_axis_resources if d is not None]))

View File

@ -3258,12 +3258,14 @@ class APITest(jtu.JaxTestCase):
# Use the tracer
return x + self._saved_tracer
if jax.config.jax_jit_pjit_api_merge:
msg = "Encountered an unexpected tracer"
else:
msg = "Encountered an unexpected tracer.*Tracer not in input tracers"
with self.assertRaisesRegex(
UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer not in input tracers",
re.DOTALL)):
api.jit(func1)(2.)
UnexpectedTracerError, re.compile(msg, re.DOTALL)):
api.jit(func1)(2.0)
def test_escaped_tracer_omnistaging(self):
count = 1
@ -3640,7 +3642,12 @@ class APITest(jtu.JaxTestCase):
x = g(x)
return x
with self.assertRaisesRegex(Exception, r"Leaked sublevel"):
if jax.config.jax_jit_pjit_api_merge:
msg = r'Leaked trace MainTrace\(2,DynamicJaxprTrace\)'
else:
msg = r'Leaked sublevel'
with self.assertRaisesRegex(Exception, f"{msg}"):
f(3)
def test_leak_checker_avoids_false_positive_custom_jvp(self):