mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add dce_rules for pjit primitive so that remat can DCE through the pjit primitive and remove unused residuals
PiperOrigin-RevId: 504123801
This commit is contained in:
parent
1ee21d121c
commit
fb9b5ec1e4
@ -1705,6 +1705,44 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
ad.reducing_transposes[pjit_p] = _pjit_transpose
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def _dce_jaxpr_pjit(
|
||||
jaxpr: core.ClosedJaxpr, used_outputs: Tuple[bool]
|
||||
) -> Tuple[core.ClosedJaxpr, List[bool]]:
|
||||
new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, used_outputs)
|
||||
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts), used_inputs
|
||||
|
||||
|
||||
def dce_jaxpr_pjit_rule(used_outputs: List[bool], eqn: core.JaxprEqn
|
||||
) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
|
||||
dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
|
||||
eqn.params['jaxpr'], tuple(used_outputs))
|
||||
|
||||
def keep_where(xs, keeps):
|
||||
return tuple(x for x, keep in safe_zip(xs, keeps) if keep)
|
||||
|
||||
eqn_params = eqn.params
|
||||
new_params = dict(
|
||||
eqn_params,
|
||||
jaxpr=dced_jaxpr,
|
||||
in_shardings=keep_where(eqn_params["in_shardings"], used_inputs),
|
||||
out_shardings=keep_where(eqn_params["out_shardings"], used_outputs),
|
||||
in_positional_semantics=keep_where(eqn_params["in_positional_semantics"],
|
||||
used_inputs),
|
||||
donated_invars=keep_where(eqn_params["donated_invars"], used_inputs),
|
||||
)
|
||||
if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects:
|
||||
return used_inputs, None
|
||||
else:
|
||||
new_eqn = core.new_jaxpr_eqn(
|
||||
[v for v, used in zip(eqn.invars, used_inputs) if used],
|
||||
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info)
|
||||
return used_inputs, new_eqn
|
||||
|
||||
pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule
|
||||
|
||||
|
||||
def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_axis_resources):
|
||||
pjit_resources = set(
|
||||
it.chain.from_iterable([d for d in pos_axis_resources if d is not None]))
|
||||
|
@ -3258,12 +3258,14 @@ class APITest(jtu.JaxTestCase):
|
||||
# Use the tracer
|
||||
return x + self._saved_tracer
|
||||
|
||||
if jax.config.jax_jit_pjit_api_merge:
|
||||
msg = "Encountered an unexpected tracer"
|
||||
else:
|
||||
msg = "Encountered an unexpected tracer.*Tracer not in input tracers"
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Tracer not in input tracers",
|
||||
re.DOTALL)):
|
||||
api.jit(func1)(2.)
|
||||
UnexpectedTracerError, re.compile(msg, re.DOTALL)):
|
||||
api.jit(func1)(2.0)
|
||||
|
||||
def test_escaped_tracer_omnistaging(self):
|
||||
count = 1
|
||||
@ -3640,7 +3642,12 @@ class APITest(jtu.JaxTestCase):
|
||||
x = g(x)
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(Exception, r"Leaked sublevel"):
|
||||
if jax.config.jax_jit_pjit_api_merge:
|
||||
msg = r'Leaked trace MainTrace\(2,DynamicJaxprTrace\)'
|
||||
else:
|
||||
msg = r'Leaked sublevel'
|
||||
|
||||
with self.assertRaisesRegex(Exception, f"{msg}"):
|
||||
f(3)
|
||||
|
||||
def test_leak_checker_avoids_false_positive_custom_jvp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user