From e6db7a9d99fbfa1a2de9fe649189611fa2e9b6ee Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 7 Mar 2025 04:00:57 -0800 Subject: [PATCH] Dedup non-ref constants closed in cond branch functions. PiperOrigin-RevId: 734497907 --- jax/_src/lax/control_flow/common.py | 48 ++++++++++++----------- jax/_src/lax/control_flow/conditionals.py | 3 +- tests/api_test.py | 10 ++--- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index cecd1cdc5..b75cbf6ac 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -27,7 +27,6 @@ from jax._src.lax import lax from jax._src import effects from jax._src import ad_util from jax._src import state -from jax._src import util from jax._src.util import weakref_lru_cache, safe_map, partition_list from jax._src.interpreters import partial_eval as pe from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef @@ -144,52 +143,55 @@ def _initial_style_jaxprs_with_common_consts( # b[] <- 2.0 # in () } canonical_ref_indices = [] + canonical_non_ref_indices = [] canonical_refs: list[Any] = [] - tracer_id_to_canonical_id = {} - all_nonref_consts = [] + canonical_non_refs: list[Any] = [] + tracer_id_to_canonical_ref_id = {} + tracer_id_to_canonical_non_ref_id = {} canonical_ref_avals = [] - all_nonref_const_avals = [] + canonical_non_ref_avals = [] for consts, consts_avals in zip(all_consts, all_const_avals): ref_indices = [] - nonref_consts = [] - nonref_const_avals = [] + non_ref_indices = [] for c, aval in zip(consts, consts_avals): + tracer_id = id(c) if isinstance(aval, state.AbstractRef): - tracer_id = id(c) - if tracer_id not in tracer_id_to_canonical_id: + if tracer_id not in tracer_id_to_canonical_ref_id: canonical_id = len(canonical_refs) canonical_refs.append(c) - tracer_id_to_canonical_id[tracer_id] = canonical_id + tracer_id_to_canonical_ref_id[tracer_id] = canonical_id canonical_ref_avals.append(aval) - canonical_id = tracer_id_to_canonical_id[tracer_id] + canonical_id = tracer_id_to_canonical_ref_id[tracer_id] ref_indices.append(canonical_id) else: - nonref_consts.append(c) - nonref_const_avals.append(aval) - all_nonref_consts.append(nonref_consts) - all_nonref_const_avals.append(tuple(nonref_const_avals)) + if tracer_id not in tracer_id_to_canonical_non_ref_id: + canonical_id = len(canonical_non_refs) + canonical_non_refs.append(c) + tracer_id_to_canonical_non_ref_id[tracer_id] = canonical_id + canonical_non_ref_avals.append(aval) + canonical_id = tracer_id_to_canonical_non_ref_id[tracer_id] + non_ref_indices.append(canonical_id) canonical_ref_indices.append(tuple(ref_indices)) + canonical_non_ref_indices.append(tuple(non_ref_indices)) - consts = [*canonical_refs, *util.concatenate(all_nonref_consts)] - jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*all_nonref_const_avals,)) + consts = [*canonical_refs, *canonical_non_refs] + jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,)) for i, jaxpr in enumerate(jaxprs)) return jaxprs, consts, all_out_trees @weakref_lru_cache def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, - all_nonref_const_avals): + canonical_non_ref_avals, canonical_non_ref_indices): is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars] nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars) newvar = core.gensym(suffix='_') - unused_const_vars = [tuple(map(newvar, const_avals)) - for const_avals in all_nonref_const_avals] padded_ref_constvars = map(newvar, canonical_ref_avals) + padded_non_ref_constvars = map(newvar, canonical_non_ref_avals) for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars): padded_ref_constvars[canonical_id] = ref_var - const_prefix = util.concatenate(unused_const_vars[:i]) - const_suffix = util.concatenate(unused_const_vars[i + 1:]) - constvars = [*padded_ref_constvars, *const_prefix, *nonref_constvars, - *const_suffix] + for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars): + padded_non_ref_constvars[canonical_id] = non_ref_var + constvars = [*padded_ref_constvars, *padded_non_ref_constvars] jaxpr = jaxpr.replace(constvars=constvars) effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, jaxpr.eqns) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index e2ad6ced1..63896cc2a 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -281,8 +281,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, num_consts = len(consts) out_ = iter(out) + all_inputs = [*consts, *ops] out = [ - next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts]) + next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) for fwd in in_fwd ] assert next(out_, None) is None diff --git a/tests/api_test.py b/tests/api_test.py index 571a33e24..ff729c03d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6446,14 +6446,10 @@ class JaxprTest(jtu.JaxTestCase): e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b f:f32[] = cond[ branches=( - { lambda ; g_:f32[] h:f32[] i:f32[] j:f32[]. let - k:f32[] = sub j h - in (k,) } - { lambda ; l:f32[] m_:f32[] n:f32[] o:f32[]. let - p:f32[] = add n l - in (p,) } + { lambda ; g:f32[] h:f32[] i:f32[]. let j:f32[] = sub i g in (j,) } + { lambda ; k:f32[] l:f32[] m:f32[]. let n:f32[] = add l k in (n,) } ) - ] e a a c d + ] e a c d in (f,) }""" jaxpr = api.make_jaxpr(f)(jnp.float32(3.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr))