mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 14:26:07 +00:00
586 lines
24 KiB
Python
586 lines
24 KiB
Python
# Copyright 2019 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Loops is an **experimental** module for syntactic sugar for loops and control-flow.
|
|
|
|
The current implementation should convert loops correctly to JAX internal
|
|
representation, and most transformations should work (see below), but we have
|
|
not yet fine-tuned the performance of the resulting XLA compilation!
|
|
|
|
By default, loops and control-flow in JAX are executed and inlined during tracing.
|
|
For example, in the following code the `for` loop is unrolled during JAX tracing::
|
|
|
|
arr = np.zeros(5)
|
|
for i in range(arr.shape[0]):
|
|
arr[i] += 2.
|
|
if i % 2 == 0:
|
|
arr[i] += 1.
|
|
|
|
In order to capture the structured control-flow one can use the higher-order
|
|
JAX operations, which require you to express the body of the loops and
|
|
conditionals as functions, and the array updates using a functional style that
|
|
returns an updated array, e.g.::
|
|
|
|
arr = np.zeros(5)
|
|
def loop_body(i, acc_arr):
|
|
arr1 = acc_arr.at[i].set(acc_arr[i] + 2.)
|
|
return lax.cond(i % 2 == 0,
|
|
arr1,
|
|
lambda arr1: arr1.at[i].set(arr1[i] + 1),
|
|
arr1,
|
|
lambda arr1: arr1)
|
|
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
|
|
|
|
This API quickly gets unreadable with deeper nested loops.
|
|
With the utilities in this module you can write loops and conditionals that
|
|
look closer to plain Python, as long as you keep the loop-carried state in a
|
|
special `loops.scope` object and use `for` loops over special
|
|
`scope.range` iterators::
|
|
|
|
from jax.experimental import loops
|
|
with loops.Scope() as s:
|
|
s.arr = np.zeros(5) # Create the mutable state of the loop as `scope` fields.
|
|
for i in s.range(s.arr.shape[0]):
|
|
s.arr = s.arr.at[i].set(s.arr[i] + 2.)
|
|
for _ in s.cond_range(i % 2 == 0): # Conditionals as loops with 0 or 1 iterations
|
|
s.arr = s.arr.at[i].set(s.arr[i] + 1.)
|
|
|
|
Loops constructed with `range` must have literal constant bounds. If you need
|
|
loops with dynamic bounds, you can use the more general `while_range` iterator.
|
|
However, in that case the `grad` transformation is not supported::
|
|
|
|
s.idx = start
|
|
for _ in s.while_range(lambda: s.idx < end):
|
|
s.idx += 1
|
|
|
|
Notes:
|
|
* Loops and conditionals to be functionalized can appear only inside scopes
|
|
constructed with `loops.Scope` and they must use one of the `Scope.range`
|
|
iterators. All other loops are unrolled during tracing, as usual in JAX.
|
|
* Only scope data (stored in fields of the scope object) is functionalized.
|
|
All other state, e.g., in other Python variables, will not be considered as
|
|
being part of the loop output. All references to the mutable state should be
|
|
through the scope, e.g., `s.arr`.
|
|
* The scope fields can be pytrees, and can themselves be mutable data structures.
|
|
* Conceptually, this model is still "functional" in the sense that a loop over
|
|
a `Scope.range` behaves as a function whose input and output is the scope data.
|
|
* Scopes should be passed down to callees that need to use loop
|
|
functionalization, or they may be nested.
|
|
* The programming model is that the loop body over a `scope.range` is traced
|
|
only once, using abstract shape values, similar to how JAX traces function
|
|
bodies.
|
|
|
|
Restrictions:
|
|
* The tracing of the loop body should not exit prematurely with `return`,
|
|
`exception`, `break`. This would be detected and reported as errors when we
|
|
encounter unnested scopes.
|
|
* The loop index variable should not be used after the loop. Similarly, one
|
|
should not use outside the loop data computed in the loop body, except data
|
|
stored in fields of the scope object.
|
|
* No new mutable state can be created inside a loop to be functionalized.
|
|
All mutable state must be created outside all loops and conditionals.
|
|
* Once the loop starts all updates to loop state must be with new values of the
|
|
same abstract values as the values on loop start.
|
|
* For a `while` loop, the conditional function is not allowed to modify the
|
|
scope state. This is a checked error. Also, for `while` loops, the `grad`
|
|
transformation does not work. An alternative that allows `grad` is a bounded
|
|
loop (`range`).
|
|
|
|
Transformations:
|
|
* All transformations are supported, except `grad` is not supported for
|
|
`Scope.while_range` loops.
|
|
* `vmap` is very useful for such loops because it pushes more work into the
|
|
inner-loops, which should help performance for accelerators.
|
|
|
|
For usage example, see tests/loops_test.py.
|
|
"""
|
|
|
|
from functools import partial
|
|
import itertools
|
|
import numpy as np
|
|
import traceback
|
|
from typing import Any, Dict, List, cast
|
|
|
|
from jax import lax, core
|
|
from jax._src.lax import control_flow as lax_control_flow
|
|
from jax import tree_util
|
|
from jax.errors import UnexpectedTracerError
|
|
from jax.interpreters import partial_eval as pe
|
|
from jax._src.util import safe_map
|
|
|
|
|
|
class Scope(object):
|
|
"""A scope context manager to keep the state of loop bodies for functionalization.
|
|
|
|
Usage::
|
|
|
|
with Scope() as s:
|
|
s.data = 0.
|
|
for i in s.range(5):
|
|
s.data += 1.
|
|
return s.data
|
|
|
|
"""
|
|
|
|
def __init__(self):
|
|
# state to be functionalized, indexed by names, can be pytrees
|
|
self._mutable_state: Dict[str, Any] = {}
|
|
# the pytrees of abstract values; set when the loop starts.
|
|
self._mutable_state_aval: Dict[str, core.AbstractValue] = {}
|
|
|
|
self._active_ranges = [] # stack of active ranges, last one is the innermost.
|
|
self._count_subtraces = 0 # How many net started subtraces, for error recovery
|
|
|
|
def range(self, first, second=None, third=None):
|
|
"""Creates an iterator for bounded iterations to be functionalized.
|
|
|
|
The body is converted to a `lax.scan`, for which all JAX transformations work.
|
|
The `first`, `second`, and `third` arguments must be integer literals.
|
|
|
|
Usage::
|
|
|
|
range(5) # start=0, end=5, step=1
|
|
range(1, 5) # start=1, end=5, step=1
|
|
range(1, 5, 2) # start=1, end=5, step=2
|
|
|
|
s.out = 1.
|
|
for i in scope.range(5):
|
|
s.out += 1.
|
|
"""
|
|
if third is not None:
|
|
start = int(first)
|
|
stop = int(second)
|
|
step = int(third)
|
|
else:
|
|
step = 1
|
|
if second is not None:
|
|
start = int(first)
|
|
stop = int(second)
|
|
else:
|
|
start = 0
|
|
stop = int(first)
|
|
return _BodyTracer(self, _BoundedLoopBuilder(start, stop, step))
|
|
|
|
def cond_range(self, pred):
|
|
"""Creates a conditional iterator with 0 or 1 iterations based on the boolean.
|
|
|
|
The body is converted to a `lax.cond`. All JAX transformations work.
|
|
|
|
Usage::
|
|
|
|
for _ in scope.cond_range(s.field < 0.):
|
|
s.field = - s.field
|
|
"""
|
|
# TODO: share these checks with lax_control_flow.cond
|
|
if len(np.shape(pred)) != 0:
|
|
raise TypeError(
|
|
"Pred must be a scalar, got {} of shape {}.".format(pred, np.shape(pred)))
|
|
|
|
try:
|
|
pred_dtype = np.result_type(pred)
|
|
except TypeError as err:
|
|
msg = ("Pred type must be either boolean or number, got {}.")
|
|
raise TypeError(msg.format(pred)) from err
|
|
|
|
if pred_dtype.kind != 'b':
|
|
if pred_dtype.kind in 'iuf':
|
|
pred = pred != 0
|
|
else:
|
|
msg = ("Pred type must be either boolean or number, got {}.")
|
|
raise TypeError(msg.format(pred_dtype))
|
|
|
|
return _BodyTracer(self, _CondBuilder(pred))
|
|
|
|
def while_range(self, cond_func):
|
|
"""Creates an iterator that continues as long as `cond_func` returns true.
|
|
|
|
The body is converted to a `lax.while_loop`.
|
|
The `grad` transformation does not work.
|
|
|
|
Usage::
|
|
|
|
for _ in scope.while_range(lambda: s.loss > 1.e-5):
|
|
s.loss = loss(...)
|
|
|
|
Args:
|
|
cond_func: a lambda with no arguments, the condition for the "while".
|
|
"""
|
|
return _BodyTracer(self, _WhileBuilder(cond_func))
|
|
|
|
def _push_range(self, range_):
|
|
for ar in self._active_ranges:
|
|
if ar is range_:
|
|
raise ValueError("Range is reused nested inside itself.")
|
|
self._active_ranges.append(range_)
|
|
|
|
def _pop_range(self, range_):
|
|
if not (range_ is self._active_ranges[-1]):
|
|
self._error_premature_exit_range()
|
|
self._active_ranges.pop()
|
|
|
|
def _error_premature_exit_range(self):
|
|
"""Raises error about premature exit from a range"""
|
|
msg = "Some ranges have exited prematurely. The innermost such range is at\n{}"
|
|
raise ValueError(msg.format(self._active_ranges[-1].location()))
|
|
|
|
def __getattr__(self, key):
|
|
"""Accessor for scope data.
|
|
|
|
Called only if the attribute is not found, which will happen when we read
|
|
scope data that has been stored in self._mutable_state.
|
|
"""
|
|
mt_val = self._mutable_state.get(key)
|
|
if mt_val is None:
|
|
raise AttributeError(
|
|
"Reading uninitialized data '{}' from the scope.".format(key))
|
|
return mt_val
|
|
|
|
def __setattr__(self, key, value):
|
|
"""Update scope data to be functionalized.
|
|
|
|
Called for *all* attribute setting.
|
|
"""
|
|
if key in ["_active_ranges", "_mutable_state", "_mutable_state_aval", "_count_subtraces"]:
|
|
object.__setattr__(self, key, value)
|
|
else:
|
|
if self._active_ranges:
|
|
if key not in self._mutable_state:
|
|
raise ValueError(
|
|
"New mutable state '{}' cannot be created inside a loop.".format(key))
|
|
assert key in self._mutable_state_aval
|
|
old_aval = self._mutable_state_aval[key]
|
|
flat_values, flat_tree = tree_util.tree_flatten(value)
|
|
new_aval = flat_tree.unflatten(safe_map(_BodyTracer.abstractify, flat_values))
|
|
if old_aval != new_aval:
|
|
msg = (f"Mutable state '{key}' is updated with new abstract value "
|
|
f"{new_aval}, which is different from previous one {old_aval}")
|
|
raise TypeError(msg)
|
|
self._mutable_state[key] = value
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
try:
|
|
if exc_type is None:
|
|
if self._active_ranges: # We have some ranges that we did not exit properly
|
|
self._error_premature_exit_range()
|
|
return True
|
|
else:
|
|
# The exception may come from inside one or more ranges. We let the current
|
|
# exception propagate, assuming it terminates the tracing. If not, the
|
|
# tracers may be left in an inconsistent state.
|
|
return False # re-raise
|
|
finally:
|
|
# Ensure we leave the global trace_state as we found it
|
|
while self._count_subtraces > 0:
|
|
self.end_subtrace()
|
|
|
|
def start_subtrace(self):
|
|
"""Starts a nested trace, returns the Trace object."""
|
|
# TODO: This follows the __enter__ part of core.new_main.
|
|
level = core.thread_local_state.trace_state.trace_stack.next_level()
|
|
main = core.MainTrace(level, pe.JaxprTrace)
|
|
core.thread_local_state.trace_state.trace_stack.push(main)
|
|
self._count_subtraces += 1
|
|
return pe.JaxprTrace(main, core.cur_sublevel())
|
|
|
|
def end_subtrace(self):
|
|
# TODO: This follows the __exit__ part of core.new_main
|
|
core.thread_local_state.trace_state.trace_stack.pop()
|
|
self._count_subtraces -= 1
|
|
|
|
|
|
class _BodyTracer(object):
|
|
"""Traces the body of the loop and builds a functional control-flow representation.
|
|
|
|
This class is also an iterator, only the first iteration is traced.
|
|
"""
|
|
|
|
def __init__(self, scope, loop_builder):
|
|
"""
|
|
Params:
|
|
scope: the current scope
|
|
loop_builder: instance of _LoopBuilder
|
|
"""
|
|
self.scope = scope
|
|
self.loop_builder = loop_builder
|
|
self.first_iteration = True # If we are tracing the first iteration
|
|
# Stack trace, without this line and the s.range function
|
|
self.stack = traceback.StackSummary.from_list(
|
|
cast(List[Any], traceback.extract_stack()[:-2]))
|
|
|
|
# Next are state kept from the start of the first iteration to the end of the iteration.
|
|
# List of scope fields carried through the loop
|
|
self.carried_state_names: List[str] = None
|
|
self.carried_state_initial = {} # Copy of the initial values of state, before loop starts
|
|
# The parameters that were created for state upon entering an arbitrary iteration.
|
|
self.carried_state_vars = {} # For each state, the list of Tracer variables introduced
|
|
# when starting to trace the loop body.
|
|
|
|
self.trace = None
|
|
|
|
def location(self):
|
|
"""A multiline string representing the source location of the range."""
|
|
if self.stack is not None:
|
|
return " ".join(self.stack.format())
|
|
else:
|
|
return ""
|
|
|
|
def __iter__(self):
|
|
"""Called before starting the first iteration."""
|
|
self.first_iteration = True # In case we reuse the range
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self.first_iteration:
|
|
self.first_iteration = False
|
|
self.scope._push_range(self)
|
|
self.start_tracing_body()
|
|
return self._index_var
|
|
else:
|
|
self.end_tracing_body()
|
|
self.scope._pop_range(self)
|
|
raise StopIteration # Trace only one iteration.
|
|
|
|
def next(self): # For PY2
|
|
return self.__next__()
|
|
|
|
def start_tracing_body(self):
|
|
"""Called upon starting the tracing of the loop body."""
|
|
# TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
|
|
self.trace = self.scope.start_subtrace()
|
|
# The entire state is carried.
|
|
self.carried_state_names = sorted(self.scope._mutable_state.keys())
|
|
for key in self.carried_state_names:
|
|
init_val = self.scope._mutable_state[key]
|
|
flat_init_vals, init_tree = tree_util.tree_flatten(init_val)
|
|
flat_init_avals = safe_map(_BodyTracer.abstractify, flat_init_vals)
|
|
flat_init_pvals = safe_map(pe.PartialVal.unknown, flat_init_avals)
|
|
flat_init_vars = safe_map(self.trace.new_arg, flat_init_pvals)
|
|
self.carried_state_vars[key] = flat_init_vars
|
|
# Set the scope._mutable_state to new tracing variables.
|
|
self.scope._mutable_state[key] = init_tree.unflatten(flat_init_vars)
|
|
self.scope._mutable_state_aval[key] = init_tree.unflatten(flat_init_avals)
|
|
# Make a copy of the initial state by unflattening the flat_init_vals
|
|
self.carried_state_initial[key] = init_tree.unflatten(flat_init_vals)
|
|
|
|
index_var_aval = _BodyTracer.abstractify(0)
|
|
index_var_pval = pe.PartialVal.unknown(index_var_aval)
|
|
self._index_var = self.trace.new_arg(index_var_pval)
|
|
|
|
def end_tracing_body(self):
|
|
"""Called when we are done tracing one iteration of the body."""
|
|
# We will turn the body of the loop into a function that takes some values
|
|
# for the scope state (carried_state_names) and returns the values for the
|
|
# same state fields after one execution of the body. For some of the ranges,
|
|
# e.g., scope.range, the function will also take the index_var as last parameter.
|
|
in_tracers = tuple(itertools.chain(*[self.carried_state_vars[ms] for ms in self.carried_state_names]))
|
|
if self.loop_builder.can_use_index_var():
|
|
in_tracers += (self._index_var,)
|
|
|
|
# Make the jaxpr for the body of the loop
|
|
# TODO: See which mutable state was changed in the one iteration.
|
|
# For now, we assume all state changes.
|
|
body_out_tracers = []
|
|
for key in self.carried_state_names:
|
|
new_val = self.scope._mutable_state[key]
|
|
flat_new_values, flat_new_tree = tree_util.tree_flatten(new_val)
|
|
body_out_tracers.extend(flat_new_values)
|
|
assert key in self.scope._mutable_state_aval
|
|
old_aval = self.scope._mutable_state_aval[key]
|
|
new_aval = flat_new_tree.unflatten(safe_map(_BodyTracer.abstractify, flat_new_values))
|
|
if old_aval != new_aval:
|
|
msg = (f"Mutable state '{key}' had at the end of the loop body new abstract value "
|
|
f"{new_aval}, which is different from initial one {old_aval}")
|
|
raise TypeError(msg)
|
|
|
|
try:
|
|
# If the body actually uses the index variable, and is not allowed to
|
|
# (e.g., cond_range and while_range), then in_tracers will not contain
|
|
# the tracer for the index_var, and trace_to_jaxpr_finalize will throw
|
|
# an assertion error.
|
|
body_closed_jaxpr, body_const_vals = _BodyTracer.trace_to_jaxpr_finalize(
|
|
in_tracers=in_tracers,
|
|
out_tracers=body_out_tracers,
|
|
trace=self.trace)
|
|
except UnexpectedTracerError as e:
|
|
if "Tracer not among input tracers" in str(e):
|
|
raise ValueError("Body of cond_range or while_range should not use the "
|
|
"index variable returned by iterator.") from e
|
|
raise
|
|
# End the subtrace for the loop body, before we trace the condition
|
|
self.scope.end_subtrace()
|
|
|
|
carried_init_val = tuple([self.carried_state_initial[ms]
|
|
for ms in self.carried_state_names])
|
|
carried_init_vals, carried_tree = tree_util.tree_flatten(carried_init_val)
|
|
assert len(carried_init_vals) == len(body_out_tracers)
|
|
|
|
carried_out_vals = self.loop_builder.build_output_vals(
|
|
self.scope, self.carried_state_names, carried_tree,
|
|
carried_init_vals, body_closed_jaxpr, body_const_vals)
|
|
carried_mutable_state_unflattened = tree_util.tree_unflatten(carried_tree,
|
|
carried_out_vals)
|
|
|
|
# Update the mutable state with the values of the changed vars, after the loop.
|
|
for ms, mv in zip(self.carried_state_names, carried_mutable_state_unflattened):
|
|
self.scope._mutable_state[ms] = mv
|
|
|
|
@staticmethod
|
|
def abstractify(x):
|
|
return core.raise_to_shaped(core.get_aval(x), weak_type=False)
|
|
|
|
@staticmethod
|
|
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
|
|
# TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
|
|
instantiate = [instantiate] * len(out_tracers)
|
|
out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
|
|
out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
|
|
instantiate, out_tracers)
|
|
jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
|
|
assert not env # TODO: this is from partial_eval.trace_to_jaxpr. Share.
|
|
closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
|
return closed_jaxpr, consts
|
|
|
|
|
|
class _LoopBuilder(object):
|
|
"""Abstract superclass for the loop builders"""
|
|
|
|
def can_use_index_var(self):
|
|
"""Whether this kind of loop can use the index var returned by the range iterator."""
|
|
raise NotImplementedError
|
|
|
|
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
|
init_vals, body_closed_jaxpr, body_const_vals):
|
|
"""Builds the output values for the loop carried state.
|
|
|
|
Params:
|
|
scope: the current Scope object.
|
|
carried_state_names: the list of names of mutable state fields that is
|
|
carried through the body.
|
|
carried_tree: the PyTreeDef for the tuple of carried_state_names.
|
|
init_vals: the initial values on body entry corresponding to the init_tree.
|
|
body_closed_jaxpr: the Jaxpr for the body returning the new values of
|
|
carried_state_names.
|
|
body_const_vals: the constant values for the body.
|
|
|
|
Returns:
|
|
the output tracer corresponding to the lax primitive representing the loop.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def __str__(self):
|
|
raise NotImplementedError
|
|
|
|
|
|
class _BoundedLoopBuilder(_LoopBuilder):
|
|
"""Builds a lax operation corresponding to a bounded range iteration."""
|
|
|
|
def __init__(self, start, stop, step):
|
|
self.start = start
|
|
self.stop = stop
|
|
self.step = step
|
|
self._index_var = None # The parameter for the index variable
|
|
|
|
def can_use_index_var(self):
|
|
return True
|
|
|
|
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
|
init_vals, body_closed_jaxpr, body_const_vals):
|
|
arange_val = np.arange(self.start, stop=self.stop, step=self.step)
|
|
return lax_control_flow.scan_p.bind(*body_const_vals, *init_vals, arange_val,
|
|
reverse=False, length=arange_val.shape[0],
|
|
jaxpr=body_closed_jaxpr,
|
|
num_consts=len(body_const_vals),
|
|
num_carry=len(init_vals),
|
|
linear=(False,) * (len(body_const_vals) +
|
|
len(init_vals) + 1),
|
|
unroll=1)
|
|
|
|
|
|
class _CondBuilder(_LoopBuilder):
|
|
"""Builds a lax.cond operation."""
|
|
|
|
def __init__(self, pred):
|
|
self.index = lax.convert_element_type(pred, np.int32)
|
|
|
|
def can_use_index_var(self):
|
|
return False
|
|
|
|
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
|
init_vals, body_closed_jaxpr, body_const_vals):
|
|
# Simulate a pass-through false branch
|
|
in_vals, in_tree = tree_util.tree_flatten(
|
|
(body_const_vals, tree_util.tree_unflatten(carried_tree, init_vals)))
|
|
in_avals = safe_map(_BodyTracer.abstractify, in_vals)
|
|
pass_through_closed_jaxpr, pass_through_const_vals, _ = (
|
|
lax_control_flow._initial_style_jaxpr(
|
|
lambda *args: args[1],
|
|
in_tree,
|
|
tuple(in_avals)))
|
|
assert len(pass_through_const_vals) == 0
|
|
args = [*body_const_vals, *init_vals]
|
|
return lax_control_flow.cond_p.bind(
|
|
self.index, *args,
|
|
branches=(pass_through_closed_jaxpr, body_closed_jaxpr),
|
|
linear=(False,) * len(args))
|
|
|
|
|
|
class _WhileBuilder(_LoopBuilder):
|
|
"""Builds a lax.while operation."""
|
|
|
|
def __init__(self, cond_func):
|
|
self.cond_func = cond_func # Function with 0 arguments (can reference the scope)
|
|
|
|
def can_use_index_var(self):
|
|
return False
|
|
|
|
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
|
init_vals, body_closed_jaxpr, body_const_vals):
|
|
# Trace the conditional function. cond_func takes 0 arguments, but
|
|
# for lax.while we need a conditional function that takes the
|
|
# carried_state_names. _initial_style_jaxpr will start its own trace and
|
|
# will create tracers for all the carried state. We must put these values
|
|
# in the scope._mutable_state before we trace the conditional
|
|
# function.
|
|
def cond_func_wrapped(*args):
|
|
assert len(args) == len(carried_state_names)
|
|
for ms, init_ms in zip(carried_state_names, args):
|
|
scope._mutable_state[ms] = init_ms
|
|
res = self.cond_func()
|
|
# Conditional function is not allowed to modify the scope state
|
|
for ms, init_ms in zip(carried_state_names, args):
|
|
if not (scope._mutable_state[ms] is init_ms):
|
|
raise ValueError(f"Conditional function modifies scope.{ms} field.")
|
|
return res
|
|
|
|
init_avals = safe_map(_BodyTracer.abstractify, init_vals)
|
|
cond_jaxpr, cond_consts, cond_tree = (
|
|
lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
|
|
carried_tree,
|
|
tuple(init_avals)))
|
|
# TODO: share these checks with lax_control_flow.while
|
|
if not tree_util.treedef_is_leaf(cond_tree):
|
|
raise TypeError(f"cond_fun must return a boolean scalar, but got pytree {cond_tree}.")
|
|
if not safe_map(core.typecompat, cond_jaxpr.out_avals, [core.ShapedArray((), np.bool_)]):
|
|
raise TypeError(f"cond_fun must return a boolean scalar, but got output type(s) "
|
|
f"{cond_jaxpr.out_avals}.")
|
|
|
|
return lax_control_flow.while_p.bind(*cond_consts, *body_const_vals, *init_vals,
|
|
cond_nconsts=len(cond_consts),
|
|
cond_jaxpr=cond_jaxpr,
|
|
body_nconsts=len(body_const_vals),
|
|
body_jaxpr=body_closed_jaxpr)
|