# Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 12:26:07 +00:00
# (55 lines, 2.3 KiB, Python).
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utilities for tracing stateful functions."""
|
|
|
|
from jax.interpreters import partial_eval as pe
|
|
from jax._src import core
|
|
from jax._src import linear_util as lu
|
|
from jax._src.state import AbstractRef
|
|
from jax._src.util import (partition_list, merge_lists, split_list, safe_map,
|
|
safe_zip)
|
|
from jax._src.state.primitives import ref_get
|
|
|
|
# For the rest of this module, `map` and `zip` refer to `safe_map` and
# `safe_zip` from `jax._src.util`; the builtins remain available as
# `unsafe_map` and `unsafe_zip`. The right-hand side is evaluated before any
# name is rebound, so `unsafe_*` captures the builtins.
unsafe_map, map = map, safe_map
unsafe_zip, zip = zip, safe_zip
|
|
|
|
def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
  """Re-traces `jaxpr` so that its constvars become leading `Ref` inputs.

  Each constvar of `jaxpr` is turned into an input of the new jaxpr, placed
  before the original invars and in the original constvar order. A constvar
  whose aval is not already an `AbstractRef` gets an `AbstractRef` input
  aval; inside the new jaxpr its value is read back out with `ref_get` before
  the original jaxpr is evaluated. Constvars that already are refs are passed
  through unchanged.

  Args:
    jaxpr: the jaxpr whose constvars should be hoisted into inputs.

  Returns:
    The jaxpr produced by tracing the wrapper over
    `[*const_input_avals, *invar_avals]`.
  """
  # True where the constvar's aval is already an AbstractRef.
  ref_mask = [isinstance(v.aval, AbstractRef) for v in jaxpr.constvars]
  # Order-preserving: wrap only the non-ref constant avals.
  const_in_avals = [
      v.aval if is_ref else AbstractRef(v.aval)
      for v, is_ref in zip(jaxpr.constvars, ref_mask)
  ]
  n_consts = len(const_in_avals)
  in_avals = const_in_avals + [v.aval for v in jaxpr.invars]

  def _hoist(*flat_args):
    const_inputs, real_args = split_list(flat_args, [n_consts])
    # Constants that were wrapped above arrive as Refs; read each one out
    # right away so the original jaxpr sees a value. Ref constants are
    # forwarded as-is.
    const_vals = [
        c if is_ref else ref_get(c, ())
        for c, is_ref in zip(const_inputs, ref_mask)
    ]
    return core.eval_jaxpr(jaxpr, const_vals, *real_args)

  hoisted, _, leftover_consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_hoist), in_avals)
  assert not leftover_consts, "All consts should have been converted to refs"
  return hoisted
|
|
|
|
def val_to_ref_aval(x) -> AbstractRef:
  """Builds an `AbstractRef` whose inner aval is the shaped aval of `x`.

  Args:
    x: any value accepted by `core.get_aval`.

  Returns:
    `AbstractRef(aval)`, where `aval` is `core.raise_to_shaped(core.get_aval(x))`.

  Raises:
    TypeError: if that shaped aval is not exactly a `core.ShapedArray`.
  """
  aval = core.raise_to_shaped(core.get_aval(x))
  # Exact type check (not isinstance): subclasses of ShapedArray are rejected
  # as well as every other aval kind.
  if type(aval) is not core.ShapedArray:
    # TypeError subclasses Exception, so existing `except Exception` handlers
    # still catch this, while callers can now catch it narrowly.
    raise TypeError(f"can't make ref from {x}")
  return AbstractRef(aval)
|