Merge pull request #10469 from mattjj:remove-units-pjit

PiperOrigin-RevId: 445041987
This commit is contained in:
jax authors 2022-04-27 20:27:33 -07:00
commit b85277860d

View File

@ -49,7 +49,8 @@ from jax._src.tree_util import prefix_errors
from jax._src import util
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wrap_name, wraps,
distributed_debug_log, split_list, cache, tuple_insert, weakref_lru_cache)
distributed_debug_log, split_list, cache, tuple_insert, weakref_lru_cache,
merge_lists)
class _FromGdaSingleton:
pass
@ -803,21 +804,15 @@ def _pjit_partial_eval(trace, *in_tracers,
known_ins = tuple(pv.is_known() for pv in in_pvals)
unknown_ins = tuple(not k for k in known_ins)
raw_known_jaxpr, raw_unknown_jaxpr, unknown_outs = pe.partial_eval_jaxpr(
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = pe.partial_eval_jaxpr_nounits(
jaxpr, unknown_ins, instantiate=False)
unknown_outs = tuple(unknown_outs)
known_outs = tuple(not uk for uk in unknown_outs)
num_residuals = len(raw_known_jaxpr.jaxpr.outvars) - len(unknown_outs)
num_residuals = len(res_avals)
def keep_where(l, should_keep):
return tuple(x for x, keep in zip(l, should_keep) if keep)
# Prepare the known jaxpr
# TODO(apaszke): map_jaxpr will break caching!
known_jaxpr = raw_known_jaxpr.map_jaxpr(lambda jaxpr: pe._drop_vars(
jaxpr,
drop_ins=unknown_ins,
drop_outs=unknown_outs + (False,) * num_residuals))
# Compute the known outputs
known_params = dict(
jaxpr=known_jaxpr,
@ -850,18 +845,18 @@ def _pjit_partial_eval(trace, *in_tracers,
*(pv.get_known() for pv in in_pvals if pv.is_known()),
**known_params)
if num_residuals:
known_out_vals, residual_vals = split_list(all_known_outs, [-num_residuals])
known_out_vals, residual_vals = \
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
else:
known_out_vals, residual_vals = all_known_outs, ()
known_tracers_out = [trace.new_const(known_out) for known_out in known_out_vals]
residual_tracers = [trace.new_instantiated_const(residual) for residual in residual_vals]
# Prepare the unknown jaxpr
# TODO(apaszke): map_jaxpr will break caching!
unknown_jaxpr = raw_unknown_jaxpr.map_jaxpr(lambda jaxpr: pe._drop_vars(
jaxpr,
drop_ins=known_ins + (False,) * num_residuals,
drop_outs=known_outs))
# The convention of partial_eval_jaxpr_nounits is to place residual binders
# at the front of the jaxpr produced, so we move them to the back since both
# the jaxpr equation built below and the pjit transpose rule assume a
# residual-inputs-last convention.
unknown_jaxpr = pe.move_binders_to_back(
unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins))
# Prepare unknown tracers
unknown_params = dict(
jaxpr=unknown_jaxpr,
@ -888,7 +883,7 @@ def _pjit_partial_eval(trace, *in_tracers,
unknown_jaxpr.effects,
source_info_util.current())
for t in unknown_tracers_out: t.recipe = eqn
return pe._zip_knowns(known_tracers_out, unknown_tracers_out, unknown_outs)
return merge_lists(unknown_outs, known_out_vals, unknown_tracers_out)
pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval