mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

Previously, we had two almost identical classes: `TracingDebugInfo` and `JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had a thunk to return the result paths, while `JaxprDebugInfo` had the result paths resolved to a tuple. The separation of these types provided some clarity, but also led to code duplication and required conversions as the debugging info goes from `WrappedFun` to a `Jaxpr` and then to `WrappedFun` again.
260 lines
11 KiB
Python
260 lines
11 KiB
Python
# Copyright 2022 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Module for the common control flow utilities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Sequence
|
|
import os
|
|
from functools import partial
|
|
from typing import Any
|
|
|
|
from jax._src import api_util
|
|
from jax._src import core
|
|
from jax._src import linear_util as lu
|
|
from jax._src.lax import lax
|
|
from jax._src import effects
|
|
from jax._src import ad_util
|
|
from jax._src import state
|
|
from jax._src import util
|
|
from jax._src.util import weakref_lru_cache, safe_map, partition_list
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef
|
|
from jax._src.tree_util import equality_errors_pytreedef
|
|
|
|
# Within this module `map` refers to `safe_map`; the builtin `map` remains
# reachable under the name `unsafe_map`.
map, unsafe_map = safe_map, map
|
|
|
|
# Add `lax.InOutFeedEffect` to the effect types held by
# `effects.control_flow_allowed_effects`.
effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect)
|
|
|
|
|
|
def _typecheck_param(prim, param, name, msg_required, pred):
|
|
if not pred:
|
|
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
|
|
f'{msg_required} required:')
|
|
param_str = str(param)
|
|
# Avoid using os.linesep here to have the same multi-line error message
|
|
# format on different platforms.
|
|
sep = os.linesep if '\n' in param_str or '\r' in param_str else ' '
|
|
msg = sep.join([msg, param_str])
|
|
raise core.JaxprTypeError(msg)
|
|
|
|
@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable,
                              in_tree: PyTreeDef,
                              in_avals: Sequence[core.AbstractValue],
                              debug_info: core.DebugInfo):
  """Trace `fun` on flat inputs and return the jaxpr with its consts unclosed.

  `fun` is wrapped with `lu.wrap_init`, flattened against `in_tree` with
  `api_util.flatten_fun_nokwargs`, and traced by `pe.trace_to_jaxpr_dynamic`
  on `in_avals`.

  Returns:
    A 4-tuple `(jaxpr, consts, out_tree, attrs_tracked)`, where `out_tree` is
    the result of calling the output-tree thunk produced by the flattening.
  """
  traceable = lu.wrap_init(fun, debug_info=debug_info)
  flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(traceable, in_tree)
  traced = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
  # The second component of the trace result is not needed here.
  open_jaxpr, _, closure_consts, tracked_attrs = traced
  return open_jaxpr, closure_consts, out_tree_thunk(), tracked_attrs
|
|
|
|
@weakref_lru_cache
def _initial_style_jaxpr(fun: Callable,
                         in_tree: PyTreeDef,
                         in_avals: Sequence[core.AbstractValue],
                         debug_info: core.DebugInfo):
  """Trace `fun` and return `(closed_jaxpr, consts, out_tree)`.

  The open jaxpr from `_initial_style_open_jaxpr` is passed through
  `pe.convert_constvars_jaxpr` and then `pe.close_jaxpr`; the consts are
  returned alongside it rather than stored in the closed jaxpr.
  """
  traced = _initial_style_open_jaxpr(fun, in_tree, in_avals, debug_info)
  # The trailing `()` target makes the unpacking fail unless the fourth
  # component (the tracked attrs) is empty.
  open_jaxpr, closure_consts, result_tree, () = traced
  converted = pe.convert_constvars_jaxpr(open_jaxpr)
  return pe.close_jaxpr(converted), closure_consts, result_tree
|
|
|
|
def _initial_style_jaxpr_attrs(fun: Callable,
                               in_tree: PyTreeDef,
                               in_avals: Sequence[core.AbstractValue],
                               debug_info: core.DebugInfo):
  """Trace `fun` and return `(closed_jaxpr, consts, out_tree, attrs_tracked)`.

  Same steps as `_initial_style_jaxpr`, except that the tracked attrs reported
  by `_initial_style_open_jaxpr` are returned to the caller instead of being
  required to be empty.
  """
  traced = _initial_style_open_jaxpr(fun, in_tree, in_avals, debug_info)
  open_jaxpr, closure_consts, result_tree, tracked_attrs = traced
  converted = pe.convert_constvars_jaxpr(open_jaxpr)
  return pe.close_jaxpr(converted), closure_consts, result_tree, tracked_attrs
|
|
|
|
def _initial_style_jaxprs_with_common_consts(
|
|
funs: Sequence[Callable],
|
|
in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
|
|
debug_infos: Sequence[core.DebugInfo]):
|
|
# When staging the branches of a conditional into jaxprs, constants are
|
|
# extracted from each branch and converted to jaxpr arguments. To use the
|
|
# staged jaxprs as the branches to a conditional *primitive*, we need for
|
|
# their (input) signatures to match. This function "joins" the staged jaxprs:
|
|
# for each one, it makes another that accepts *all* constants, but only uses
|
|
# those that it needs (dropping the rest).
|
|
jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, debug_info)
|
|
for fn, debug_info in zip(funs, debug_infos)]
|
|
if not jaxpr_data:
|
|
return [], [], []
|
|
|
|
jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
|
|
all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
|
|
|
|
# TODO(sharadmv,mattjj): we could dedup *all consts* instead of just the Refs.
|
|
|
|
# We don't want two different Refs in a jaxpr's input to refer to the same
|
|
# Ref in the caller. We call this the "Ref aliasing problem" and it introduces
|
|
# difficulties when discharging Refs and when reasoning about programs with
|
|
# state effects. When unifying the arguments to each branch in a cond,
|
|
# however, we might naively pass the same Ref in multiple times.
|
|
#
|
|
# Here we dedup any `Ref`s that were closed over across the branches and
|
|
# pad out constants used across different branches.
|
|
# Let's consider an example case. For the following branch jaxprs, we will
|
|
# produce the following const lists, where `t_` indicates a tracer (a Ref).
|
|
# { lambda x:i32[] a:Ref{float64[]} c:Ref[float64[]}; . let
|
|
# a[] <- 1.0
|
|
# c[] <- 3.14
|
|
# in () }
|
|
|
|
# { lambda d:Ref[float64[]} b:Ref{float64[]} y:i32[]; . let
|
|
# d[] <- 6.28
|
|
# b[] <- 2.0
|
|
# in () }
|
|
# consts = [[0, t_e, t_f], [t_g, t_e, 1]]
|
|
#
|
|
# Notice how `t_e` is duplicated. To deduplicate the `Ref`s we first
|
|
# 1) Detecting duplicate `Ref` tracers. We keep track of duplicates in
|
|
# `tracer_id_to_canonical_id.` We store the deduped `Ref` tracers in a
|
|
# list called `canonical_refs`. We remove the `Ref`s from the consts.
|
|
# We should have the following lists:
|
|
# canonical_refs = [t_e, t_f, t_g]
|
|
# consts = [[0], [1]]
|
|
# 2) We need to munge the branch jaxprs to take in *all* the canonical Refs
|
|
# and ignore the ones it doesn't actually use. We do this by keeping track
|
|
# for each jaxpr for each of its input Refs which canonical_ref it
|
|
# corresponds to, producing the following list:
|
|
# canonical_ref_indices = [[0, 1], [2, 0]]
|
|
#
|
|
# Afterwards, we proceed by rewriting the jaxprs to be the following:
|
|
# { lambda a:Ref{float64[]} c:Ref[float64[]} b_:Ref{float64[]} x:i32[]; . let
|
|
# a[] <- 1.0
|
|
# c[] <- 3.14
|
|
# in () }
|
|
# { lambda b:Ref{float64[]} _:Ref{float64[]} d:Ref{float64[]} y:i32[]; . let
|
|
# d[] <- 6.28
|
|
# b[] <- 2.0
|
|
# in () }
|
|
canonical_ref_indices = []
|
|
canonical_refs: list[Any] = []
|
|
tracer_id_to_canonical_id = {}
|
|
all_nonref_consts = []
|
|
canonical_ref_avals = []
|
|
all_nonref_const_avals = []
|
|
for consts, consts_avals in zip(all_consts, all_const_avals):
|
|
ref_indices = []
|
|
nonref_consts = []
|
|
nonref_const_avals = []
|
|
for c, aval in zip(consts, consts_avals):
|
|
if isinstance(aval, state.AbstractRef):
|
|
tracer_id = id(c)
|
|
if tracer_id not in tracer_id_to_canonical_id:
|
|
canonical_id = len(canonical_refs)
|
|
canonical_refs.append(c)
|
|
tracer_id_to_canonical_id[tracer_id] = canonical_id
|
|
canonical_ref_avals.append(aval)
|
|
canonical_id = tracer_id_to_canonical_id[tracer_id]
|
|
ref_indices.append(canonical_id)
|
|
else:
|
|
nonref_consts.append(c)
|
|
nonref_const_avals.append(aval)
|
|
all_nonref_consts.append(nonref_consts)
|
|
all_nonref_const_avals.append(tuple(nonref_const_avals))
|
|
canonical_ref_indices.append(tuple(ref_indices))
|
|
|
|
consts = [*canonical_refs, *util.concatenate(all_nonref_consts)]
|
|
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*all_nonref_const_avals,))
|
|
for i, jaxpr in enumerate(jaxprs))
|
|
return jaxprs, consts, all_out_trees
|
|
|
|
@weakref_lru_cache
def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
                         all_nonref_const_avals):
  """Rewrite branch `i`'s jaxpr so its constvars follow the shared layout.

  The new constvars are, in order: one variable per entry of
  `canonical_ref_avals` (this jaxpr's own Ref constvars are placed at the
  positions given by `canonical_ref_indices[i]`, fresh variables elsewhere),
  fresh variables for the non-Ref consts of branches before `i`, this jaxpr's
  own non-Ref constvars, and fresh variables for the non-Ref consts of
  branches after `i`.

  Returns:
    A `core.ClosedJaxpr` with no consts, built from the re-laid-out jaxpr
    after recomputing its effects and applying `pe.convert_constvars_jaxpr`.
  """
  ref_mask = [isinstance(var.aval, state.AbstractRef)
              for var in jaxpr.constvars]
  own_nonref_vars, own_ref_vars = partition_list(ref_mask, jaxpr.constvars)
  fresh_var = core.gensym(suffix='_')
  # Placeholders are created for every branch (including `i`); only those of
  # the other branches are used below.
  placeholder_vars = [tuple([fresh_var(aval) for aval in branch_avals])
                      for branch_avals in all_nonref_const_avals]
  ref_vars = [fresh_var(aval) for aval in canonical_ref_avals]
  for slot, own_ref in zip(canonical_ref_indices[i], own_ref_vars):
    ref_vars[slot] = own_ref
  before = util.concatenate(placeholder_vars[:i])
  after = util.concatenate(placeholder_vars[i + 1:])
  padded = jaxpr.replace(
      constvars=[*ref_vars, *before, *own_nonref_vars, *after])
  # The effects are recomputed from the jaxpr with its new constvars.
  recomputed_effects = pe.make_jaxpr_effects(padded.constvars, padded.invars,
                                             padded.outvars, padded.eqns)
  padded = padded.replace(effects=recomputed_effects)
  return core.ClosedJaxpr(pe.convert_constvars_jaxpr(padded), ())
|
|
|
|
def _check_tree_and_avals(what1, tree1, avals1, what2, tree2, avals2):
|
|
"""Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
|
|
|
|
Corresponding `tree` and `avals` must match in the sense that the number of
|
|
leaves in `tree` must be equal to the length of `avals`. `what1` and
|
|
`what2` describe what the `tree1` and `tree2` represent.
|
|
"""
|
|
if tree1 != tree2:
|
|
errs = list(equality_errors_pytreedef(tree1, tree2))
|
|
msg = []
|
|
msg.append(
|
|
f"{what1} must have same type structure as {what2}, but there are differences: ")
|
|
for path, thing1, thing2, explanation in errs:
|
|
msg.append(
|
|
f" * at output{keystr(tuple(path))}, {what1} has {thing1} and "
|
|
f"{what2} has {thing2}, so {explanation}")
|
|
raise TypeError('\n'.join(msg))
|
|
|
|
if not all(map(core.typematch, avals1, avals2)):
|
|
diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
|
|
tree_unflatten(tree2, avals2))
|
|
raise TypeError(f"{what1} and {what2} must have identical types, got\n{diff}.")
|
|
|
|
def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
|
|
if has_aux:
|
|
actual_tree_children = actual_tree.children()
|
|
|
|
if len(actual_tree_children) == 2:
|
|
# select first child as result tree
|
|
actual_tree = actual_tree_children[0]
|
|
else:
|
|
raise ValueError(
|
|
f"{func_name}() produced a pytree with structure "
|
|
f"{actual_tree}, but a pytree tuple with auxiliary "
|
|
f"output was expected because has_aux was set to True.")
|
|
|
|
if actual_tree != expected_tree:
|
|
raise TypeError(
|
|
f"{func_name}() output pytree structure must match {expected_name}, "
|
|
f"got {actual_tree} and {expected_tree}.")
|
|
|
|
def _prune_zeros(ts):
|
|
return [t for t in ts if type(t) is not ad_util.Zero]
|
|
|
|
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
  """Trace `traceable` on `in_avals` and return the result as a `core.ClosedJaxpr`."""
  traced = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  # The trailing `()` target makes the unpacking fail unless the fourth
  # component of the trace result is empty.
  open_jaxpr, _, closure_consts, () = traced
  return core.ClosedJaxpr(open_jaxpr, closure_consts)
|
|
|
|
def _make_closed_jaxpr_attrs(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
  """Trace `traceable` on `in_avals`; return `(ClosedJaxpr, attrs_tracked)`."""
  traced = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  open_jaxpr, _, closure_consts, tracked_attrs = traced
  closed = core.ClosedJaxpr(open_jaxpr, closure_consts)
  return closed, tracked_attrs
|
|
|
|
|
|
def _show_diff(array1, array2):
|
|
if core.typematch(array1, array2):
|
|
return f"{array1}"
|
|
return f"DIFFERENT {array1} vs. {array2}"
|
|
|
|
def _avals_short(avals):
|
|
to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
|
|
return ' '.join(map(to_str, avals))
|