mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Dedup non-ref constants closed in cond branch functions.
PiperOrigin-RevId: 734497907
This commit is contained in:
parent
bf95bf49d4
commit
e6db7a9d99
@ -27,7 +27,6 @@ from jax._src.lax import lax
|
||||
from jax._src import effects
|
||||
from jax._src import ad_util
|
||||
from jax._src import state
|
||||
from jax._src import util
|
||||
from jax._src.util import weakref_lru_cache, safe_map, partition_list
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef
|
||||
@ -144,52 +143,55 @@ def _initial_style_jaxprs_with_common_consts(
|
||||
# b[] <- 2.0
|
||||
# in () }
|
||||
canonical_ref_indices = []
|
||||
canonical_non_ref_indices = []
|
||||
canonical_refs: list[Any] = []
|
||||
tracer_id_to_canonical_id = {}
|
||||
all_nonref_consts = []
|
||||
canonical_non_refs: list[Any] = []
|
||||
tracer_id_to_canonical_ref_id = {}
|
||||
tracer_id_to_canonical_non_ref_id = {}
|
||||
canonical_ref_avals = []
|
||||
all_nonref_const_avals = []
|
||||
canonical_non_ref_avals = []
|
||||
for consts, consts_avals in zip(all_consts, all_const_avals):
|
||||
ref_indices = []
|
||||
nonref_consts = []
|
||||
nonref_const_avals = []
|
||||
non_ref_indices = []
|
||||
for c, aval in zip(consts, consts_avals):
|
||||
tracer_id = id(c)
|
||||
if isinstance(aval, state.AbstractRef):
|
||||
tracer_id = id(c)
|
||||
if tracer_id not in tracer_id_to_canonical_id:
|
||||
if tracer_id not in tracer_id_to_canonical_ref_id:
|
||||
canonical_id = len(canonical_refs)
|
||||
canonical_refs.append(c)
|
||||
tracer_id_to_canonical_id[tracer_id] = canonical_id
|
||||
tracer_id_to_canonical_ref_id[tracer_id] = canonical_id
|
||||
canonical_ref_avals.append(aval)
|
||||
canonical_id = tracer_id_to_canonical_id[tracer_id]
|
||||
canonical_id = tracer_id_to_canonical_ref_id[tracer_id]
|
||||
ref_indices.append(canonical_id)
|
||||
else:
|
||||
nonref_consts.append(c)
|
||||
nonref_const_avals.append(aval)
|
||||
all_nonref_consts.append(nonref_consts)
|
||||
all_nonref_const_avals.append(tuple(nonref_const_avals))
|
||||
if tracer_id not in tracer_id_to_canonical_non_ref_id:
|
||||
canonical_id = len(canonical_non_refs)
|
||||
canonical_non_refs.append(c)
|
||||
tracer_id_to_canonical_non_ref_id[tracer_id] = canonical_id
|
||||
canonical_non_ref_avals.append(aval)
|
||||
canonical_id = tracer_id_to_canonical_non_ref_id[tracer_id]
|
||||
non_ref_indices.append(canonical_id)
|
||||
canonical_ref_indices.append(tuple(ref_indices))
|
||||
canonical_non_ref_indices.append(tuple(non_ref_indices))
|
||||
|
||||
consts = [*canonical_refs, *util.concatenate(all_nonref_consts)]
|
||||
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*all_nonref_const_avals,))
|
||||
consts = [*canonical_refs, *canonical_non_refs]
|
||||
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,))
|
||||
for i, jaxpr in enumerate(jaxprs))
|
||||
return jaxprs, consts, all_out_trees
|
||||
|
||||
@weakref_lru_cache
|
||||
def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
|
||||
all_nonref_const_avals):
|
||||
canonical_non_ref_avals, canonical_non_ref_indices):
|
||||
is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars]
|
||||
nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars)
|
||||
newvar = core.gensym(suffix='_')
|
||||
unused_const_vars = [tuple(map(newvar, const_avals))
|
||||
for const_avals in all_nonref_const_avals]
|
||||
padded_ref_constvars = map(newvar, canonical_ref_avals)
|
||||
padded_non_ref_constvars = map(newvar, canonical_non_ref_avals)
|
||||
for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars):
|
||||
padded_ref_constvars[canonical_id] = ref_var
|
||||
const_prefix = util.concatenate(unused_const_vars[:i])
|
||||
const_suffix = util.concatenate(unused_const_vars[i + 1:])
|
||||
constvars = [*padded_ref_constvars, *const_prefix, *nonref_constvars,
|
||||
*const_suffix]
|
||||
for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars):
|
||||
padded_non_ref_constvars[canonical_id] = non_ref_var
|
||||
constvars = [*padded_ref_constvars, *padded_non_ref_constvars]
|
||||
jaxpr = jaxpr.replace(constvars=constvars)
|
||||
effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars,
|
||||
jaxpr.outvars, jaxpr.eqns)
|
||||
|
@ -281,8 +281,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
num_consts = len(consts)
|
||||
out_ = iter(out)
|
||||
|
||||
all_inputs = [*consts, *ops]
|
||||
out = [
|
||||
next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts])
|
||||
next(out_) if fwd is None else lax.asarray(all_inputs[fwd])
|
||||
for fwd in in_fwd
|
||||
]
|
||||
assert next(out_, None) is None
|
||||
|
@ -6446,14 +6446,10 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
|
||||
f:f32[] = cond[
|
||||
branches=(
|
||||
{ lambda ; g_:f32[] h:f32[] i:f32[] j:f32[]. let
|
||||
k:f32[] = sub j h
|
||||
in (k,) }
|
||||
{ lambda ; l:f32[] m_:f32[] n:f32[] o:f32[]. let
|
||||
p:f32[] = add n l
|
||||
in (p,) }
|
||||
{ lambda ; g:f32[] h:f32[] i:f32[]. let j:f32[] = sub i g in (j,) }
|
||||
{ lambda ; k:f32[] l:f32[] m:f32[]. let n:f32[] = add l k in (n,) }
|
||||
)
|
||||
] e a a c d
|
||||
] e a c d
|
||||
in (f,) }"""
|
||||
jaxpr = api.make_jaxpr(f)(jnp.float32(3.))
|
||||
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
|
||||
|
Loading…
x
Reference in New Issue
Block a user