Merge pull request #22402 from superbobry:maint

PiperOrigin-RevId: 652487969
This commit is contained in:
jax authors 2024-07-15 08:20:56 -07:00
commit 2e7e700090
3 changed files with 39 additions and 84 deletions

View File

@ -70,29 +70,6 @@ for_p.multiple_results = True
### Tracing utilities
def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
all_const_avals = [var.aval for var in jaxpr.constvars]
is_const_ref = [isinstance(var.aval, AbstractRef) for var in
jaxpr.constvars]
const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals)
const_avals = map(AbstractRef, const_avals)
merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
i_aval, *arg_avals = (var.aval for var in jaxpr.invars)
in_avals = [i_aval, *merged_const_avals, *arg_avals]
num_consts = len(merged_const_avals)
def _hoist(i, *consts_args):
all_consts, args = split_list(consts_args, [num_consts])
consts, const_refs = partition_list(is_const_ref, all_consts)
# We immediately read the const values out of the `Ref`s.
consts = map(lambda x: ref_get(x, ()), consts)
all_consts = merge_lists(is_const_ref, consts, const_refs)
return core.eval_jaxpr(jaxpr, all_consts, i, *args)
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr
def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue]
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
@ -160,8 +137,7 @@ def for_loop(nsteps: int | Sequence[int],
body, state_tree, [idx_aval, *state_avals])
if out_tree != tree_structure(None):
raise Exception("`body` should not return anything.")
# Remove constvars from jaxpr and turn them into `Ref`s
jaxpr = _hoist_consts_to_refs(jaxpr)
jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=1)
which_linear = (False,) * (len(consts) + len(flat_state))
out_flat = for_p.bind(*consts, *flat_state, jaxpr=jaxpr, nsteps=int(nsteps),
reverse=reverse, which_linear=which_linear,

View File

@ -29,7 +29,6 @@ from jax._src import config
from jax._src import core as jax_core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -38,7 +37,7 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas.primitives import uninitialized_value
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as sp
from jax._src.state import utils as state_utils
from jax._src.util import (
safe_map,
safe_zip,
@ -693,42 +692,6 @@ def _pallas_call_batching_rule(
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
"""Hoists the constants in the given jaxpr into invars.
Args:
jaxpr: The jaxpr.
Returns:
A new jaxpr where the constants were hoisted into invars as ``Ref``s.
The invars for the constants are added *before* any existing invars.
"""
if not jaxpr.constvars:
return jaxpr # Nothing to hoist.
is_const_ref = [
isinstance(var.aval, state.AbstractRef) for var in jaxpr.constvars
]
const_avals = [
var.aval if is_ref else state.AbstractRef(var.aval)
for is_ref, var in zip(is_const_ref, jaxpr.constvars)
]
in_avals = const_avals + [var.aval for var in jaxpr.invars]
def _hoist(*consts_args):
all_consts, args = split_list(consts_args, [len(const_avals)])
# We immediately read the const values out of the `Ref`s.
all_consts = [
c if is_ref else sp.ref_get(c, ())
for is_ref, c in zip(is_const_ref, all_consts)
]
return jax_core.eval_jaxpr(jaxpr, all_consts, *args)
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr
def checkify_pallas_kernel_body_jaxpr(
body_jaxpr: jax_core.ClosedJaxpr,
@ -914,7 +877,7 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
if consts:
jaxpr = _hoist_consts_to_refs(jaxpr)
jaxpr = state_utils.hoist_consts_to_refs(jaxpr)
# Pad ``block_mappings`` to account for the hoisted constants.
grid_mapping = grid_mapping.replace(
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),

View File

@ -13,44 +13,60 @@
# limitations under the License.
"""Utilities for tracing stateful functions."""
from collections.abc import Sequence
from jax._src.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
from jax._src.state import AbstractRef
from jax._src.util import (partition_list, merge_lists, split_list, safe_map,
safe_zip)
from jax._src.util import split_list, safe_map, safe_zip
from jax._src.state.primitives import ref_get
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
all_const_avals = [var.aval for var in jaxpr.constvars]
is_const_ref = [isinstance(var.aval, AbstractRef) for var in
jaxpr.constvars]
const_avals_, const_ref_avals = partition_list(is_const_ref, all_const_avals)
const_avals: Sequence[AbstractRef] = map(AbstractRef, const_avals_)
merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
arg_avals = [var.aval for var in jaxpr.invars]
in_avals = [*merged_const_avals, *arg_avals]
num_consts = len(merged_const_avals)
def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
"""Hoists the constants in the given jaxpr into invars.
Args:
jaxpr: The jaxpr.
index: The index where the invars for the constants should be inserted.
By default, the new invars are inserted *before* any existing invars.
Returns:
A new jaxpr where the constants were hoisted into invars as ``Ref``s.
"""
if not jaxpr.constvars:
return jaxpr # Nothing to hoist.
is_const_ref = [
isinstance(var.aval, AbstractRef) for var in jaxpr.constvars
]
const_avals = [
var.aval if is_ref else AbstractRef(var.aval)
for is_ref, var in zip(is_const_ref, jaxpr.constvars)
]
in_avals = [var.aval for var in jaxpr.invars]
in_avals[index:index] = const_avals
def _hoist(*consts_args):
all_consts, args = split_list(consts_args, [num_consts])
consts, const_refs = partition_list(is_const_ref, all_consts)
args0, all_consts, args1 = split_list(
consts_args, [index, len(const_avals)]
)
# We immediately read the const values out of the `Ref`s.
consts = map(lambda x: ref_get(x, ()), consts)
all_consts = merge_lists(is_const_ref, consts, const_refs)
return core.eval_jaxpr(jaxpr, all_consts, *args)
all_consts = [
c if is_ref else ref_get(c, ())
for is_ref, c in zip(is_const_ref, all_consts)
]
return core.eval_jaxpr(jaxpr, all_consts, *args0, *args1)
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr
def val_to_ref_aval(x) -> AbstractRef:
aval = core.raise_to_shaped(core.get_aval(x))
if type(aval) is not core.ShapedArray:
raise Exception(f"can't make ref from {x}")
raise TypeError(f"can't make ref from {x}")
return AbstractRef(aval)