Dedup non-ref constants closed in cond branch functions.

PiperOrigin-RevId: 734497907
This commit is contained in:
Daniel Suo 2025-03-07 04:00:57 -08:00 committed by jax authors
parent bf95bf49d4
commit e6db7a9d99
3 changed files with 30 additions and 31 deletions

View File

@ -27,7 +27,6 @@ from jax._src.lax import lax
from jax._src import effects
from jax._src import ad_util
from jax._src import state
from jax._src import util
from jax._src.util import weakref_lru_cache, safe_map, partition_list
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef
@ -144,52 +143,55 @@ def _initial_style_jaxprs_with_common_consts(
# b[] <- 2.0
# in () }
canonical_ref_indices = []
canonical_non_ref_indices = []
canonical_refs: list[Any] = []
tracer_id_to_canonical_id = {}
all_nonref_consts = []
canonical_non_refs: list[Any] = []
tracer_id_to_canonical_ref_id = {}
tracer_id_to_canonical_non_ref_id = {}
canonical_ref_avals = []
all_nonref_const_avals = []
canonical_non_ref_avals = []
for consts, consts_avals in zip(all_consts, all_const_avals):
ref_indices = []
nonref_consts = []
nonref_const_avals = []
non_ref_indices = []
for c, aval in zip(consts, consts_avals):
tracer_id = id(c)
if isinstance(aval, state.AbstractRef):
tracer_id = id(c)
if tracer_id not in tracer_id_to_canonical_id:
if tracer_id not in tracer_id_to_canonical_ref_id:
canonical_id = len(canonical_refs)
canonical_refs.append(c)
tracer_id_to_canonical_id[tracer_id] = canonical_id
tracer_id_to_canonical_ref_id[tracer_id] = canonical_id
canonical_ref_avals.append(aval)
canonical_id = tracer_id_to_canonical_id[tracer_id]
canonical_id = tracer_id_to_canonical_ref_id[tracer_id]
ref_indices.append(canonical_id)
else:
nonref_consts.append(c)
nonref_const_avals.append(aval)
all_nonref_consts.append(nonref_consts)
all_nonref_const_avals.append(tuple(nonref_const_avals))
if tracer_id not in tracer_id_to_canonical_non_ref_id:
canonical_id = len(canonical_non_refs)
canonical_non_refs.append(c)
tracer_id_to_canonical_non_ref_id[tracer_id] = canonical_id
canonical_non_ref_avals.append(aval)
canonical_id = tracer_id_to_canonical_non_ref_id[tracer_id]
non_ref_indices.append(canonical_id)
canonical_ref_indices.append(tuple(ref_indices))
canonical_non_ref_indices.append(tuple(non_ref_indices))
consts = [*canonical_refs, *util.concatenate(all_nonref_consts)]
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*all_nonref_const_avals,))
consts = [*canonical_refs, *canonical_non_refs]
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,))
for i, jaxpr in enumerate(jaxprs))
return jaxprs, consts, all_out_trees
@weakref_lru_cache
def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
all_nonref_const_avals):
canonical_non_ref_avals, canonical_non_ref_indices):
is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars]
nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars)
newvar = core.gensym(suffix='_')
unused_const_vars = [tuple(map(newvar, const_avals))
for const_avals in all_nonref_const_avals]
padded_ref_constvars = map(newvar, canonical_ref_avals)
padded_non_ref_constvars = map(newvar, canonical_non_ref_avals)
for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars):
padded_ref_constvars[canonical_id] = ref_var
const_prefix = util.concatenate(unused_const_vars[:i])
const_suffix = util.concatenate(unused_const_vars[i + 1:])
constvars = [*padded_ref_constvars, *const_prefix, *nonref_constvars,
*const_suffix]
for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars):
padded_non_ref_constvars[canonical_id] = non_ref_var
constvars = [*padded_ref_constvars, *padded_non_ref_constvars]
jaxpr = jaxpr.replace(constvars=constvars)
effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars,
jaxpr.outvars, jaxpr.eqns)

View File

@ -281,8 +281,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
num_consts = len(consts)
out_ = iter(out)
all_inputs = [*consts, *ops]
out = [
next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts])
next(out_) if fwd is None else lax.asarray(all_inputs[fwd])
for fwd in in_fwd
]
assert next(out_, None) is None

View File

@ -6446,14 +6446,10 @@ class JaxprTest(jtu.JaxTestCase):
e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
f:f32[] = cond[
branches=(
{ lambda ; g_:f32[] h:f32[] i:f32[] j:f32[]. let
k:f32[] = sub j h
in (k,) }
{ lambda ; l:f32[] m_:f32[] n:f32[] o:f32[]. let
p:f32[] = add n l
in (p,) }
{ lambda ; g:f32[] h:f32[] i:f32[]. let j:f32[] = sub i g in (j,) }
{ lambda ; k:f32[] l:f32[] m:f32[]. let n:f32[] = add l k in (n,) }
)
] e a a c d
] e a c d
in (f,) }"""
jaxpr = api.make_jaxpr(f)(jnp.float32(3.))
self.assertMultiLineStrippedEqual(expected, str(jaxpr))