mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #10469 from mattjj:remove-units-pjit
PiperOrigin-RevId: 445041987
This commit is contained in:
commit
b85277860d
@ -49,7 +49,8 @@ from jax._src.tree_util import prefix_errors
|
||||
from jax._src import util
|
||||
from jax._src.util import (
|
||||
HashableFunction, safe_map, safe_zip, wrap_name, wraps,
|
||||
distributed_debug_log, split_list, cache, tuple_insert, weakref_lru_cache)
|
||||
distributed_debug_log, split_list, cache, tuple_insert, weakref_lru_cache,
|
||||
merge_lists)
|
||||
|
||||
class _FromGdaSingleton:
|
||||
pass
|
||||
@ -803,21 +804,15 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
|
||||
known_ins = tuple(pv.is_known() for pv in in_pvals)
|
||||
unknown_ins = tuple(not k for k in known_ins)
|
||||
raw_known_jaxpr, raw_unknown_jaxpr, unknown_outs = pe.partial_eval_jaxpr(
|
||||
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = pe.partial_eval_jaxpr_nounits(
|
||||
jaxpr, unknown_ins, instantiate=False)
|
||||
unknown_outs = tuple(unknown_outs)
|
||||
known_outs = tuple(not uk for uk in unknown_outs)
|
||||
num_residuals = len(raw_known_jaxpr.jaxpr.outvars) - len(unknown_outs)
|
||||
num_residuals = len(res_avals)
|
||||
|
||||
def keep_where(l, should_keep):
|
||||
return tuple(x for x, keep in zip(l, should_keep) if keep)
|
||||
|
||||
# Prepare the known jaxpr
|
||||
# TODO(apaszke): map_jaxpr will break caching!
|
||||
known_jaxpr = raw_known_jaxpr.map_jaxpr(lambda jaxpr: pe._drop_vars(
|
||||
jaxpr,
|
||||
drop_ins=unknown_ins,
|
||||
drop_outs=unknown_outs + (False,) * num_residuals))
|
||||
# Compute the known outputs
|
||||
known_params = dict(
|
||||
jaxpr=known_jaxpr,
|
||||
@ -850,18 +845,18 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
*(pv.get_known() for pv in in_pvals if pv.is_known()),
|
||||
**known_params)
|
||||
if num_residuals:
|
||||
known_out_vals, residual_vals = split_list(all_known_outs, [-num_residuals])
|
||||
known_out_vals, residual_vals = \
|
||||
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
|
||||
else:
|
||||
known_out_vals, residual_vals = all_known_outs, ()
|
||||
known_tracers_out = [trace.new_const(known_out) for known_out in known_out_vals]
|
||||
residual_tracers = [trace.new_instantiated_const(residual) for residual in residual_vals]
|
||||
|
||||
# Prepare the unknown jaxpr
|
||||
# TODO(apaszke): map_jaxpr will break caching!
|
||||
unknown_jaxpr = raw_unknown_jaxpr.map_jaxpr(lambda jaxpr: pe._drop_vars(
|
||||
jaxpr,
|
||||
drop_ins=known_ins + (False,) * num_residuals,
|
||||
drop_outs=known_outs))
|
||||
# The convention of partial_eval_jaxpr_nounits is to place residual binders
|
||||
# at the front of the jaxpr produced, so we move them to the back since both
|
||||
# the jaxpr equation built below and the pjit transpose rule assume a
|
||||
# residual-inputs-last convention.
|
||||
unknown_jaxpr = pe.move_binders_to_back(
|
||||
unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins))
|
||||
# Prepare unknown tracers
|
||||
unknown_params = dict(
|
||||
jaxpr=unknown_jaxpr,
|
||||
@ -888,7 +883,7 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
unknown_jaxpr.effects,
|
||||
source_info_util.current())
|
||||
for t in unknown_tracers_out: t.recipe = eqn
|
||||
return pe._zip_knowns(known_tracers_out, unknown_tracers_out, unknown_outs)
|
||||
return merge_lists(unknown_outs, known_out_vals, unknown_tracers_out)
|
||||
pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user