mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #22402 from superbobry:maint
PiperOrigin-RevId: 652487969
This commit is contained in:
commit
2e7e700090
@ -70,29 +70,6 @@ for_p.multiple_results = True
|
||||
|
||||
### Tracing utilities
|
||||
|
||||
def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
all_const_avals = [var.aval for var in jaxpr.constvars]
|
||||
is_const_ref = [isinstance(var.aval, AbstractRef) for var in
|
||||
jaxpr.constvars]
|
||||
const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals)
|
||||
const_avals = map(AbstractRef, const_avals)
|
||||
merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
|
||||
i_aval, *arg_avals = (var.aval for var in jaxpr.invars)
|
||||
in_avals = [i_aval, *merged_const_avals, *arg_avals]
|
||||
num_consts = len(merged_const_avals)
|
||||
|
||||
def _hoist(i, *consts_args):
|
||||
all_consts, args = split_list(consts_args, [num_consts])
|
||||
consts, const_refs = partition_list(is_const_ref, all_consts)
|
||||
# We immediately read the const values out of the `Ref`s.
|
||||
consts = map(lambda x: ref_get(x, ()), consts)
|
||||
all_consts = merge_lists(is_const_ref, consts, const_refs)
|
||||
return core.eval_jaxpr(jaxpr, all_consts, i, *args)
|
||||
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_hoist), in_avals)
|
||||
assert not consts, "All consts should have been converted to refs"
|
||||
return hoisted_jaxpr
|
||||
|
||||
def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
|
||||
state_avals: Sequence[core.AbstractValue]
|
||||
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
|
||||
@ -160,8 +137,7 @@ def for_loop(nsteps: int | Sequence[int],
|
||||
body, state_tree, [idx_aval, *state_avals])
|
||||
if out_tree != tree_structure(None):
|
||||
raise Exception("`body` should not return anything.")
|
||||
# Remove constvars from jaxpr and turn them into `Ref`s
|
||||
jaxpr = _hoist_consts_to_refs(jaxpr)
|
||||
jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=1)
|
||||
which_linear = (False,) * (len(consts) + len(flat_state))
|
||||
out_flat = for_p.bind(*consts, *flat_state, jaxpr=jaxpr, nsteps=int(nsteps),
|
||||
reverse=reverse, which_linear=which_linear,
|
||||
|
@ -29,7 +29,6 @@ from jax._src import config
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import state
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
@ -38,7 +37,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.primitives import uninitialized_value
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.state import primitives as sp
|
||||
from jax._src.state import utils as state_utils
|
||||
from jax._src.util import (
|
||||
safe_map,
|
||||
safe_zip,
|
||||
@ -693,42 +692,6 @@ def _pallas_call_batching_rule(
|
||||
|
||||
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
|
||||
|
||||
def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
|
||||
"""Hoists the constants in the given jaxpr into invars.
|
||||
|
||||
Args:
|
||||
jaxpr: The jaxpr.
|
||||
|
||||
Returns:
|
||||
A new jaxpr where the constants were hoisted into invars as ``Ref``s.
|
||||
The invars for the constants are added *before* any existing invars.
|
||||
"""
|
||||
if not jaxpr.constvars:
|
||||
return jaxpr # Nothing to hoist.
|
||||
|
||||
is_const_ref = [
|
||||
isinstance(var.aval, state.AbstractRef) for var in jaxpr.constvars
|
||||
]
|
||||
const_avals = [
|
||||
var.aval if is_ref else state.AbstractRef(var.aval)
|
||||
for is_ref, var in zip(is_const_ref, jaxpr.constvars)
|
||||
]
|
||||
in_avals = const_avals + [var.aval for var in jaxpr.invars]
|
||||
|
||||
def _hoist(*consts_args):
|
||||
all_consts, args = split_list(consts_args, [len(const_avals)])
|
||||
# We immediately read the const values out of the `Ref`s.
|
||||
all_consts = [
|
||||
c if is_ref else sp.ref_get(c, ())
|
||||
for is_ref, c in zip(is_const_ref, all_consts)
|
||||
]
|
||||
return jax_core.eval_jaxpr(jaxpr, all_consts, *args)
|
||||
|
||||
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_hoist), in_avals)
|
||||
assert not consts, "All consts should have been converted to refs"
|
||||
return hoisted_jaxpr
|
||||
|
||||
|
||||
def checkify_pallas_kernel_body_jaxpr(
|
||||
body_jaxpr: jax_core.ClosedJaxpr,
|
||||
@ -914,7 +877,7 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
|
||||
jaxpr_flat_avals, debug)
|
||||
if consts:
|
||||
jaxpr = _hoist_consts_to_refs(jaxpr)
|
||||
jaxpr = state_utils.hoist_consts_to_refs(jaxpr)
|
||||
# Pad ``block_mappings`` to account for the hoisted constants.
|
||||
grid_mapping = grid_mapping.replace(
|
||||
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
|
||||
|
@ -13,44 +13,60 @@
|
||||
# limitations under the License.
|
||||
"""Utilities for tracing stateful functions."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.state import AbstractRef
|
||||
from jax._src.util import (partition_list, merge_lists, split_list, safe_map,
|
||||
safe_zip)
|
||||
from jax._src.util import split_list, safe_map, safe_zip
|
||||
from jax._src.state.primitives import ref_get
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
all_const_avals = [var.aval for var in jaxpr.constvars]
|
||||
is_const_ref = [isinstance(var.aval, AbstractRef) for var in
|
||||
jaxpr.constvars]
|
||||
const_avals_, const_ref_avals = partition_list(is_const_ref, all_const_avals)
|
||||
const_avals: Sequence[AbstractRef] = map(AbstractRef, const_avals_)
|
||||
merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
|
||||
arg_avals = [var.aval for var in jaxpr.invars]
|
||||
in_avals = [*merged_const_avals, *arg_avals]
|
||||
num_consts = len(merged_const_avals)
|
||||
|
||||
def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
|
||||
"""Hoists the constants in the given jaxpr into invars.
|
||||
|
||||
Args:
|
||||
jaxpr: The jaxpr.
|
||||
index: The index where the invars for the constants should be inserted.
|
||||
By default, the new invars are inserted *before* any existing invars.
|
||||
|
||||
Returns:
|
||||
A new jaxpr where the constants were hoisted into invars as ``Ref``s.
|
||||
"""
|
||||
if not jaxpr.constvars:
|
||||
return jaxpr # Nothing to hoist.
|
||||
|
||||
is_const_ref = [
|
||||
isinstance(var.aval, AbstractRef) for var in jaxpr.constvars
|
||||
]
|
||||
const_avals = [
|
||||
var.aval if is_ref else AbstractRef(var.aval)
|
||||
for is_ref, var in zip(is_const_ref, jaxpr.constvars)
|
||||
]
|
||||
in_avals = [var.aval for var in jaxpr.invars]
|
||||
in_avals[index:index] = const_avals
|
||||
|
||||
def _hoist(*consts_args):
|
||||
all_consts, args = split_list(consts_args, [num_consts])
|
||||
consts, const_refs = partition_list(is_const_ref, all_consts)
|
||||
args0, all_consts, args1 = split_list(
|
||||
consts_args, [index, len(const_avals)]
|
||||
)
|
||||
# We immediately read the const values out of the `Ref`s.
|
||||
consts = map(lambda x: ref_get(x, ()), consts)
|
||||
all_consts = merge_lists(is_const_ref, consts, const_refs)
|
||||
return core.eval_jaxpr(jaxpr, all_consts, *args)
|
||||
all_consts = [
|
||||
c if is_ref else ref_get(c, ())
|
||||
for is_ref, c in zip(is_const_ref, all_consts)
|
||||
]
|
||||
return core.eval_jaxpr(jaxpr, all_consts, *args0, *args1)
|
||||
|
||||
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_hoist), in_avals)
|
||||
assert not consts, "All consts should have been converted to refs"
|
||||
return hoisted_jaxpr
|
||||
|
||||
|
||||
def val_to_ref_aval(x) -> AbstractRef:
|
||||
aval = core.raise_to_shaped(core.get_aval(x))
|
||||
if type(aval) is not core.ShapedArray:
|
||||
raise Exception(f"can't make ref from {x}")
|
||||
raise TypeError(f"can't make ref from {x}")
|
||||
return AbstractRef(aval)
|
||||
|
Loading…
x
Reference in New Issue
Block a user