mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #1706 from gnecula/loops
An implementation of an experimental syntactic sugar for 'for' and `while` loops and conditionals.
This commit is contained in:
commit
397a244e7f
6
docs/jax.experimental.loops.rst
Normal file
6
docs/jax.experimental.loops.rst
Normal file
@ -0,0 +1,6 @@
|
||||
jax.experimental.loops module
|
||||
=============================
|
||||
|
||||
.. automodule:: jax.experimental.loops
|
||||
:members:
|
||||
:show-inheritance:
|
@ -4,6 +4,7 @@ jax.experimental package
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.experimental.loops
|
||||
jax.experimental.optimizers
|
||||
jax.experimental.stax
|
||||
jax.experimental.vectorize
|
||||
|
571
jax/experimental/loops.py
Normal file
571
jax/experimental/loops.py
Normal file
@ -0,0 +1,571 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Loops is an **experimental** module for syntactic sugar for loops and control-flow.
|
||||
|
||||
The current implementation should convert loops correctly to JAX internal
|
||||
representation, and most transformations should work (see below), but we have
|
||||
not yet fine-tuned the performance of the resulting XLA compilation!
|
||||
|
||||
By default, loops and control-flow in JAX are executed and inlined during tracing.
|
||||
For example, in the following code the `for` loop is unrolled during JAX tracing::
|
||||
|
||||
arr = onp.zeros(5)
|
||||
for i in range(arr.shape[0]):
|
||||
arr[i] += 2.
|
||||
if i % 2 == 0:
|
||||
arr[i] += 1.
|
||||
|
||||
In order to capture the structured control-flow one has to use the higher-order
|
||||
JAX operations, which require you to express the body of the loops and
|
||||
conditionals as functions, and the array updates using a functional style that
|
||||
returns an updated array, e.g.::
|
||||
|
||||
arr = onp.zeros(5)
|
||||
def loop_body(i, acc_arr):
|
||||
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
|
||||
return lax.cond(i % 2 == 0,
|
||||
arr1,
|
||||
lambda arr1: ops.index_update(arr1, i, arr1[i] + 1),
|
||||
arr1,
|
||||
lambda arr1: arr1)
|
||||
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
|
||||
|
||||
The default notation quickly gets unreadable with deeper nested loops.
|
||||
With the utilities in this module you can write loops and conditionals that
|
||||
look closer to plain Python, as long as you keep the loop-carried state in a
|
||||
special `loops.scope` object and use `for` loops over special
|
||||
`scope.range` iterators::
|
||||
|
||||
from jax.experimental import loops
|
||||
with loops.Scope() as s:
|
||||
s.arr = np.zeros(5) # Create the mutable state of the loop as `scope` fields.
|
||||
for i in s.range(s.arr.shape[0]):
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
|
||||
for _ in s.cond_range(i % 2 == 0): # Conditionals as loops with 0 or 1 iterations
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
|
||||
|
||||
Loops constructed with `range` must have literal constant bounds. If you need
|
||||
loops with dynamic bounds, you can use the more general `while_range` iterator.
|
||||
However, in that case that `grad` transformation is not supported::
|
||||
|
||||
s.idx = start
|
||||
for _ in s.while_range(lambda: s.idx < end):
|
||||
s.idx += 1
|
||||
|
||||
Notes:
|
||||
* Loops and conditionals to be functionalized can appear only inside scopes
|
||||
constructed with `loops.Scope` and they must use one of the `Scope.range`
|
||||
iterators. All other loops are unrolled during tracing, as usual in JAX.
|
||||
* Only scope data (stored in fields of the scope object) is functionalized.
|
||||
All other state, e.g., in other Python variables, will not be considered as
|
||||
being part of the loop output. All references to the mutable state should be
|
||||
through the scope: `s.arr`.
|
||||
* Conceptually, this model is still "functional" in the sense that a loop over
|
||||
a `Scope.range` behaves as a function whose input and output is the scope data.
|
||||
* Scopes should be passed down to callees that need to use loop
|
||||
functionalization, or they may be nested.
|
||||
* The programming model is that the loop body over a `scope.range` is traced
|
||||
only once, using abstract shape values, similar to how JAX traces function
|
||||
bodies.
|
||||
|
||||
Restrictions:
|
||||
* The tracing of the loop body should not exit prematurely with `return`,
|
||||
`exception`, `break`. This would be detected and reported as errors when we
|
||||
encounter unnested scopes.
|
||||
* The loop index variable should not be used after the loop. Similarly, one
|
||||
should not use outside the loop data computed in the loop body, except data
|
||||
stored in fields of the scope object.
|
||||
* No new mutable state can be created inside a loop to be functionalized.
|
||||
All mutable state must be created outside all loops and conditionals.
|
||||
* For a `while` loop, the conditional function is not allowed to modify the
|
||||
scope state. This is a checked error. Also, for `while` loops the `grad`
|
||||
transformation does not work. An alternative that allows `grad` is a bounded
|
||||
loop (`range`).
|
||||
|
||||
Transformations:
|
||||
* All transformations are supported, except `grad` is not supported for
|
||||
`Scope.while_range` loops.
|
||||
* `vmap` is very useful for such loops because it pushes more work into the
|
||||
inner-loops, which should help performance for accelerators.
|
||||
|
||||
For usage example, see tests/loops_test.py.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
import itertools
|
||||
import numpy as onp
|
||||
import six
|
||||
import traceback
|
||||
|
||||
from jax import abstract_arrays
|
||||
from jax import lax, core
|
||||
from jax.lax import lax_control_flow
|
||||
from jax import tree_util
|
||||
from jax import numpy as jnp
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.util import unzip2, safe_map
|
||||
|
||||
|
||||
class Scope(object):
|
||||
"""A scope context manager to keep the state of loop bodies for functionalization.
|
||||
|
||||
Usage::
|
||||
|
||||
with Scope() as s:
|
||||
s.data = 0.
|
||||
for i in s.range(5):
|
||||
s.data += 1.
|
||||
return s.data
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._mutable_state = {} # state to be functionalized, indexed by name.
|
||||
self._active_ranges = [] # stack of active ranges, last one is the innermost.
|
||||
|
||||
def range(self, first, second=None, third=None):
|
||||
"""Creates an iterator for bounded iterations to be functionalized.
|
||||
|
||||
The body is converted to a `lax.scan`, for which all JAX transformations work.
|
||||
The `first`, `second`, and `third` arguments must be integer literals.
|
||||
|
||||
Usage::
|
||||
|
||||
range(5) # start=0, end=5, step=1
|
||||
range(1, 5) # start=1, end=5, step=1
|
||||
range(1, 5, 2) # start=1, end=5, step=2
|
||||
|
||||
s.out = 1.
|
||||
for i in scope.range(5):
|
||||
s.out += 1.
|
||||
"""
|
||||
if third is not None:
|
||||
start = int(first)
|
||||
stop = int(second)
|
||||
step = int(third)
|
||||
else:
|
||||
step = 1
|
||||
if second is not None:
|
||||
start = int(first)
|
||||
stop = int(second)
|
||||
else:
|
||||
start = 0
|
||||
stop = int(first)
|
||||
return _BodyTracer(self, _BoundedLoopBuilder(start, stop, step))
|
||||
|
||||
def cond_range(self, pred):
|
||||
"""Creates a conditional iterator with 0 or 1 iterations based on the boolean.
|
||||
|
||||
The body is converted to a `lax.cond`. All JAX transformations work.
|
||||
|
||||
Usage::
|
||||
|
||||
for _ in scope.cond_range(s.field < 0.):
|
||||
s.field = - s.field
|
||||
"""
|
||||
# TODO: share these checks with lax_control_flow.cond
|
||||
if len(onp.shape(pred)) != 0:
|
||||
raise TypeError(
|
||||
"Pred must be a scalar, got {} of shape {}.".format(pred, onp.shape(pred)))
|
||||
|
||||
try:
|
||||
pred_dtype = onp.result_type(pred)
|
||||
except TypeError:
|
||||
msg = ("Pred type must be either boolean or number, got {}.")
|
||||
raise TypeError(msg.format(pred))
|
||||
|
||||
if pred_dtype.kind != 'b':
|
||||
if pred_dtype.kind in 'iuf':
|
||||
pred = pred != 0
|
||||
else:
|
||||
msg = ("Pred type must be either boolean or number, got {}.")
|
||||
raise TypeError(msg.format(pred_dtype))
|
||||
|
||||
return _BodyTracer(self, _CondBuilder(pred))
|
||||
|
||||
def while_range(self, cond_func):
|
||||
"""Creates an iterator that continues as long as `cond_func` returns true.
|
||||
|
||||
The body is converted to a `lax.while_loop`.
|
||||
The `grad` transformation does not work.
|
||||
|
||||
Usage::
|
||||
|
||||
for _ in scope.while_range(lambda: s.loss > 1.e-5):
|
||||
s.loss = loss(...)
|
||||
|
||||
Args:
|
||||
cond_func: a lambda with no arguments, the condition for the "while".
|
||||
"""
|
||||
return _BodyTracer(self, _WhileBuilder(cond_func))
|
||||
|
||||
def _push_range(self, range_):
|
||||
for ar in self._active_ranges:
|
||||
if ar is range_:
|
||||
raise ValueError("Range is reused nested inside itself.")
|
||||
self._active_ranges.append(range_)
|
||||
|
||||
def _pop_range(self, range_):
|
||||
if not (range_ is self._active_ranges[-1]):
|
||||
self._error_premature_exit_range()
|
||||
self._active_ranges.pop()
|
||||
|
||||
def _error_premature_exit_range(self):
|
||||
"""Raises error about premature exit from a range"""
|
||||
msg = "Some ranges have exited prematurely. The innermost such range is at\n{}"
|
||||
raise ValueError(msg.format(self._active_ranges[-1].location()))
|
||||
|
||||
def __getattr__(self, key):
|
||||
"""Accessor for scope data.
|
||||
|
||||
Called only if the attribute is not found, which will happen when we read
|
||||
scope data that has been stored in self._mutable_state.
|
||||
"""
|
||||
mt_val = self._mutable_state.get(key)
|
||||
if mt_val is None:
|
||||
raise AttributeError(
|
||||
"Reading uninitialized data '{}' from the scope.".format(key))
|
||||
return mt_val
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
"""Update scope data to be functionalized.
|
||||
|
||||
Called for *all* attribute setting.
|
||||
"""
|
||||
if key in ["_active_ranges", "_mutable_state"]:
|
||||
object.__setattr__(self, key, value)
|
||||
else:
|
||||
if self._active_ranges and key not in self._mutable_state:
|
||||
raise ValueError(
|
||||
"New mutable state '{}' cannot be created inside a loop.".format(key))
|
||||
self._mutable_state[key] = value
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
if self._active_ranges: # We have some ranges that we did not exit properly
|
||||
self._error_premature_exit_range()
|
||||
return True
|
||||
else:
|
||||
# The exception may come from inside one or more ranges. We let the current
|
||||
# exception propagate, assuming it terminates the tracing. If not, the
|
||||
# tracers may be left in an inconsistent state.
|
||||
return False # re-raise
|
||||
|
||||
|
||||
class _BodyTracer(object):
|
||||
"""Traces the body of the loop and builds a functional control-flow representation.
|
||||
|
||||
This class is also an iterator, only the first iteration is traced.
|
||||
"""
|
||||
|
||||
def __init__(self, scope, loop_builder):
|
||||
"""
|
||||
Params:
|
||||
scope: the current scope
|
||||
loop_builder: instance of _LoopBuilder
|
||||
"""
|
||||
self.scope = scope
|
||||
self.loop_builder = loop_builder
|
||||
self.first_iteration = True # If we are tracing the first iteration
|
||||
if six.PY3:
|
||||
# Stack trace, without this line and the s.range function
|
||||
self.stack = traceback.StackSummary.from_list(traceback.extract_stack()[:-2])
|
||||
else:
|
||||
self.stack = None
|
||||
|
||||
# Next are state kept from the start of the first iteration to the end of the iteration.
|
||||
self.carried_state_initial = {}
|
||||
# The parameters that were created for state upon entering an arbitrary iteration.
|
||||
self.carried_state_vars = {}
|
||||
|
||||
self.trace = None
|
||||
# List of scope fields carried through the loop
|
||||
self.carried_state_names = None
|
||||
self.init_tree = None # The PyTreeDef corresponding to carried_state_names
|
||||
self.init_vals = None # The values corresponding to self.init_tree
|
||||
|
||||
def location(self):
|
||||
"""A multiline string representing the source location of the range."""
|
||||
if self.stack is not None:
|
||||
return " ".join(self.stack.format())
|
||||
else:
|
||||
return ""
|
||||
|
||||
def __iter__(self):
|
||||
"""Called before starting the first iteration."""
|
||||
self.first_iteration = True # In case we reuse the range
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.first_iteration:
|
||||
self.first_iteration = False
|
||||
self.scope._push_range(self)
|
||||
self.start_tracing_body()
|
||||
return self._index_var
|
||||
else:
|
||||
self.end_tracing_body()
|
||||
self.scope._pop_range(self)
|
||||
raise StopIteration # Trace only one iteration.
|
||||
|
||||
def next(self): # For PY2
|
||||
return self.__next__()
|
||||
|
||||
def start_tracing_body(self):
|
||||
"""Called upon starting the tracing of the loop body."""
|
||||
# Make a copy of the current value of the mutable state
|
||||
self.carried_state_initial = copy.copy(self.scope._mutable_state)
|
||||
# The entire state is carried.
|
||||
self.carried_state_names = sorted(self.scope._mutable_state.keys())
|
||||
|
||||
# TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
|
||||
self.trace = _BodyTracer.start_subtrace()
|
||||
# Set the scope._mutable_state to new tracing variables.
|
||||
for key, initial in self.carried_state_initial.items():
|
||||
mt_aval = _BodyTracer.abstractify(initial)
|
||||
mt_pval = pe.PartialVal((mt_aval, core.unit))
|
||||
mt_var = self.trace.new_arg(mt_pval)
|
||||
self.carried_state_vars[key] = mt_var
|
||||
self.scope._mutable_state[key] = mt_var
|
||||
|
||||
index_var_aval = _BodyTracer.abstractify(0)
|
||||
index_var_pval = pe.PartialVal((index_var_aval, core.unit))
|
||||
self._index_var = self.trace.new_arg(index_var_pval)
|
||||
|
||||
def end_tracing_body(self):
|
||||
"""Called when we are done tracing one iteration of the body."""
|
||||
# We will turn the body of the loop into a function that takes some values
|
||||
# for the scope state (carried_state_names) and returns the values for the
|
||||
# same state fields after one execution of the body. For some of the ranges,
|
||||
# e.g., scope.range, the function will also take the index_var as last parameter.
|
||||
in_tracers = [self.carried_state_vars[ms] for ms in self.carried_state_names]
|
||||
if self.loop_builder.can_use_index_var():
|
||||
in_tracers += [self._index_var]
|
||||
|
||||
# Make the jaxpr for the body of the loop
|
||||
# TODO: See which mutable state was changed in the one iteration.
|
||||
# For now, we assume all state changes.
|
||||
body_out_tracers = tuple([self.scope._mutable_state[ms]
|
||||
for ms in self.carried_state_names])
|
||||
try:
|
||||
# If the body actually uses the index variable, and is not allowed to
|
||||
# (e.g., cond_range and while_range), then in_tracers will not contain
|
||||
# the tracer for the index_var, and trace_to_jaxpr_finalize will throw
|
||||
# an assertion error.
|
||||
body_typed_jaxpr, body_const_vals = _BodyTracer.trace_to_jaxpr_finalize(
|
||||
in_tracers=in_tracers,
|
||||
out_tracers=body_out_tracers,
|
||||
trace=self.trace)
|
||||
except AssertionError as e:
|
||||
if "Encountered unexpected tracer" == str(e):
|
||||
raise ValueError("Body of cond_range or while_range should not use the "
|
||||
"index variable returned by iterator.")
|
||||
raise
|
||||
# End the subtrace for the loop body, before we trace the condition
|
||||
_BodyTracer.end_subtrace()
|
||||
|
||||
carried_init_val = tuple([self.carried_state_initial[ms]
|
||||
for ms in self.carried_state_names])
|
||||
carried_init_vals, carried_tree = tree_util.tree_flatten(carried_init_val)
|
||||
|
||||
carried_out_vals = self.loop_builder.build_output_vals(
|
||||
self.scope, self.carried_state_names, carried_tree,
|
||||
carried_init_vals, body_typed_jaxpr, body_const_vals)
|
||||
carried_mutable_state_unflattened = tree_util.tree_unflatten(carried_tree,
|
||||
carried_out_vals)
|
||||
|
||||
# Update the mutable state with the values of the changed vars, after the loop.
|
||||
for ms, mv in zip(self.carried_state_names, carried_mutable_state_unflattened):
|
||||
self.scope._mutable_state[ms] = mv
|
||||
|
||||
@staticmethod
|
||||
def start_subtrace():
|
||||
"""Starts a nested trace, returns the Trace object."""
|
||||
# TODO: This follows the __enter__ part of core.new_master. share
|
||||
level = core.trace_state.trace_stack.next_level(False)
|
||||
master = core.MasterTrace(level, pe.JaxprTrace)
|
||||
core.trace_state.trace_stack.push(master, False)
|
||||
return pe.JaxprTrace(master, core.cur_sublevel())
|
||||
|
||||
@staticmethod
|
||||
def end_subtrace():
|
||||
# TODO: This follows the __exit__ part of core.new_master
|
||||
core.trace_state.trace_stack.pop(False)
|
||||
|
||||
@staticmethod
|
||||
def abstractify(x):
|
||||
return abstract_arrays.raise_to_shaped(core.get_aval(x))
|
||||
|
||||
@staticmethod
|
||||
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
|
||||
# TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
|
||||
instantiate = [instantiate] * len(out_tracers)
|
||||
out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
|
||||
out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
|
||||
instantiate, out_tracers)
|
||||
jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
|
||||
out_pvals = [t.pval for t in out_tracers]
|
||||
# TODO: this is from partial_eval.trace_to_jaxpr. Share.
|
||||
assert not env
|
||||
|
||||
# TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
|
||||
out_avals = safe_map(abstract_arrays.raise_to_shaped, unzip2(out_pvals)[0])
|
||||
const_avals = tuple(abstract_arrays.raise_to_shaped(core.get_aval(c))
|
||||
for c in consts)
|
||||
|
||||
in_pvals = [t.pval for t in in_tracers]
|
||||
in_avals = tuple(safe_map(abstract_arrays.raise_to_shaped, unzip2(in_pvals)[0]))
|
||||
|
||||
typed_jaxpr = core.TypedJaxpr(pe.closure_convert_jaxpr(jaxpr),
|
||||
(), const_avals + in_avals, out_avals)
|
||||
return typed_jaxpr, consts
|
||||
|
||||
|
||||
class _LoopBuilder(object):
|
||||
"""Abstract superclass for the loop builders"""
|
||||
|
||||
def can_use_index_var(self):
|
||||
"""Whether this kind of loop can use the index var returned by the range iterator."""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
||||
init_vals, body_typed_jaxpr, body_const_vals):
|
||||
"""Builds the output values for the loop carried state.
|
||||
|
||||
Params:
|
||||
scope: the current Scope object.
|
||||
carried_state_names: the list of names of mutable state fields that is
|
||||
carried through the body.
|
||||
carried_tree: the PyTreeDef for the tuple of carried_state_names.
|
||||
init_vals: the initial values on body entry corresponding to the init_tree.
|
||||
body_typed_jaxpr: the Jaxpr for the body returning the new values of
|
||||
carried_state_names.
|
||||
body_const_vals: the constant values for the body.
|
||||
|
||||
Returns:
|
||||
the output tracer corresponding to the lax primitive representing the loop.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _BoundedLoopBuilder(_LoopBuilder):
|
||||
"""Builds a lax operation corresponding to a bounded range iteration."""
|
||||
|
||||
def __init__(self, start, stop, step):
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
self._index_var = None # The parameter for the index variable
|
||||
|
||||
def can_use_index_var(self):
|
||||
return True
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
||||
init_vals, body_typed_jaxpr, body_const_vals):
|
||||
arange_val = jnp.arange(self.start, stop=self.stop, step=self.step)
|
||||
return lax_control_flow.scan_p.bind(*itertools.chain(body_const_vals,
|
||||
init_vals, [arange_val]),
|
||||
forward=True, length=arange_val.shape[0],
|
||||
jaxpr=body_typed_jaxpr,
|
||||
num_consts=len(body_const_vals),
|
||||
num_carry=len(init_vals),
|
||||
linear=(False,) * (len(body_const_vals) +
|
||||
len(init_vals) + 1))
|
||||
|
||||
|
||||
class _CondBuilder(_LoopBuilder):
|
||||
"""Builds a lax.cond operation."""
|
||||
|
||||
def __init__(self, pred):
|
||||
self.pred = pred
|
||||
|
||||
def can_use_index_var(self):
|
||||
return False
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
||||
init_vals, body_typed_jaxpr, body_const_vals):
|
||||
# Simulate a pass-through false branch
|
||||
init_avals = safe_map(_BodyTracer.abstractify, init_vals)
|
||||
false_body_typed_jaxpr, false_body_const_vals, _ = (
|
||||
lax_control_flow._initial_style_jaxpr(lambda *args: args,
|
||||
carried_tree,
|
||||
tuple(init_avals)))
|
||||
return lax_control_flow.cond_p.bind(
|
||||
*itertools.chain([self.pred], body_const_vals,
|
||||
init_vals, false_body_const_vals, init_vals),
|
||||
true_jaxpr=body_typed_jaxpr, false_jaxpr=false_body_typed_jaxpr,
|
||||
true_nconsts=len(body_const_vals), false_nconsts=len(false_body_const_vals))
|
||||
|
||||
|
||||
class _WhileBuilder(_LoopBuilder):
|
||||
"""Builds a lax.while operation."""
|
||||
|
||||
def __init__(self, cond_func):
|
||||
self.cond_func = cond_func # Function with 0 arguments (can reference the scope)
|
||||
|
||||
def can_use_index_var(self):
|
||||
return False
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
||||
init_vals, body_typed_jaxpr, body_const_vals):
|
||||
# Trace the conditional function. cond_func takes 0 arguments, but
|
||||
# for lax.while we need a conditional function that takes the
|
||||
# carried_state_names. _initial_style_jaxpr will start its own trace and
|
||||
# will create tracers for all the carried state. We must put these values
|
||||
# in the scope._mutable_state before we trace the conditional
|
||||
# function.
|
||||
def cond_func_wrapped(*args):
|
||||
assert len(args) == len(carried_state_names)
|
||||
for ms, init_ms in zip(carried_state_names, args):
|
||||
scope._mutable_state[ms] = init_ms
|
||||
res = self.cond_func()
|
||||
# Conditional function is not allowed to modify the scope state
|
||||
for ms, init_ms in zip(carried_state_names, args):
|
||||
if not (scope._mutable_state[ms] is init_ms):
|
||||
msg = "Conditional function modifies scope.{} field."
|
||||
raise ValueError(msg.format(ms))
|
||||
return res
|
||||
|
||||
init_avals = safe_map(_BodyTracer.abstractify, init_vals)
|
||||
cond_jaxpr, cond_consts, cond_tree = (
|
||||
lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
|
||||
carried_tree,
|
||||
tuple(init_avals)))
|
||||
# TODO: share these checks with lax_control_flow.while
|
||||
if not tree_util.treedef_is_leaf(cond_tree):
|
||||
msg = "cond_fun must return a boolean scalar, but got pytree {}."
|
||||
raise TypeError(msg.format(cond_tree))
|
||||
if cond_jaxpr.out_avals != [abstract_arrays.ShapedArray((), onp.bool_)]:
|
||||
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
|
||||
raise TypeError(msg.format(cond_jaxpr.out_avals))
|
||||
|
||||
return lax_control_flow.while_p.bind(*itertools.chain(cond_consts,
|
||||
body_const_vals,
|
||||
init_vals),
|
||||
cond_nconsts=len(cond_consts),
|
||||
cond_jaxpr=cond_jaxpr,
|
||||
body_nconsts=len(body_const_vals),
|
||||
body_jaxpr=body_typed_jaxpr)
|
||||
|
@ -400,7 +400,7 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
|
||||
eqns.append(eqn_tracer_to_var(var, recipe))
|
||||
processed_eqns.add(recipe.eqn_id)
|
||||
elif isinstance(recipe, LambdaBinding):
|
||||
assert any(t is in_tracer for in_tracer in in_tracers)
|
||||
assert any(t is in_tracer for in_tracer in in_tracers), "Encountered unexpected tracer"
|
||||
assert in_tracers, "Lambda binding with no args"
|
||||
elif isinstance(recipe, FreeVar):
|
||||
env[var(t)] = recipe.val
|
||||
|
406
tests/loops_test.py
Normal file
406
tests/loops_test.py
Normal file
@ -0,0 +1,406 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the experimental/loops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as onp
|
||||
import re
|
||||
import six
|
||||
|
||||
from jax import api, lax, ops
|
||||
from jax import numpy as np
|
||||
from jax import test_util as jtu
|
||||
from jax.experimental import loops
|
||||
|
||||
|
||||
class LoopsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_scope_no_loops(self):
|
||||
def f_op(r):
|
||||
with loops.Scope() as s:
|
||||
s.x = r + 1
|
||||
return s.x
|
||||
self.assertAllClose(4.0, f_op(3.), check_dtypes=True)
|
||||
|
||||
def test_loop_empty(self):
|
||||
def f_op(r):
|
||||
with loops.Scope() as s:
|
||||
for _ in s.range(5):
|
||||
pass
|
||||
return r
|
||||
|
||||
self.assertAllClose(3.0, f_op(3.), check_dtypes=True)
|
||||
|
||||
def test_loop_1(self):
|
||||
"""One loop with one state var, with transforms."""
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for _ in s.range(5):
|
||||
s.out += inc
|
||||
return s.out
|
||||
def f_expected(inc):
|
||||
return 10 + 5 * inc
|
||||
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
|
||||
inc_batch = onp.arange(5, dtype=onp.float32)
|
||||
self.assertAllClose(np.array([f_expected(inc) for inc in inc_batch]),
|
||||
api.vmap(f_op)(inc_batch), check_dtypes=True)
|
||||
|
||||
|
||||
def test_loop_2(self):
|
||||
"""One loop, two state fields."""
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out1 = 10.
|
||||
s.out2 = 20.
|
||||
for i in s.range(5):
|
||||
s.out1 += inc
|
||||
s.out2 += 1.
|
||||
return (s.out1, s.out2)
|
||||
|
||||
self.assertAllClose((10. + 2. * 5, 20. + 1. * 5), f_op(2.), check_dtypes=True)
|
||||
|
||||
|
||||
def test_add_vectors(self):
|
||||
def add_vec(x, y):
|
||||
with loops.Scope() as s:
|
||||
n = x.shape[0]
|
||||
assert n == y.shape[0]
|
||||
s.out = np.zeros(shape=[n], dtype=np.float32)
|
||||
for i in s.range(n):
|
||||
s.out = ops.index_add(s.out, i, x[i] + y[i])
|
||||
return s.out
|
||||
|
||||
x = np.array([1., 2., 3.], dtype=np.float32)
|
||||
y = np.array([4., 5., 6.], dtype=np.float32)
|
||||
self.assertAllClose(np.add(x, y), add_vec(x, y), check_dtypes=True)
|
||||
|
||||
def test_matmul(self):
|
||||
def matmul(x, y):
|
||||
with loops.Scope() as s:
|
||||
n, m = x.shape
|
||||
m1, p = y.shape
|
||||
assert m == m1
|
||||
s.out = np.zeros(shape=[n, p], dtype=np.float32)
|
||||
for i in s.range(n):
|
||||
for j in s.range(p):
|
||||
for k in s.range(m):
|
||||
s.out = ops.index_add(s.out, (i, j), x[i, k] * y[k, j])
|
||||
return s.out
|
||||
|
||||
x = np.array([[1., 2., 3.]], dtype=np.float32) # 1x3
|
||||
y = np.array([[4.], [5.], [6.]], dtype=np.float32) # 3x1
|
||||
self.assertAllClose(np.matmul(x, y), matmul(x, y), check_dtypes=True)
|
||||
|
||||
def test_reuse_range(self):
|
||||
"""Ranges can be reused, as long as not nested in each other."""
|
||||
def f_op():
|
||||
with loops.Scope() as s:
|
||||
r1 = s.range(5)
|
||||
s.out = 0
|
||||
for _ in r1:
|
||||
s.out += 1
|
||||
for _ in r1:
|
||||
s.out += 1
|
||||
return s.out
|
||||
|
||||
self.assertEqual(10, f_op())
|
||||
|
||||
|
||||
def test_loop_nested(self):
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for i in s.range(5):
|
||||
s.out += inc
|
||||
for j in s.range(6):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(10. + 5 * (2. + 6 * 2.), f_op(2.), check_dtypes=True)
|
||||
|
||||
def test_example_doc(self):
|
||||
"The example from the module docstring."
|
||||
def f_expected():
|
||||
arr = onp.zeros(5)
|
||||
for i in range(arr.shape[0]):
|
||||
arr[i] += 2.
|
||||
if i % 2 == 0:
|
||||
arr[i] += 1.
|
||||
return arr
|
||||
|
||||
def f_op_jax():
|
||||
arr = onp.zeros(5)
|
||||
def loop_body(i, acc_arr):
|
||||
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
|
||||
return lax.cond(i % 2 == 0,
|
||||
arr1,
|
||||
lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.),
|
||||
arr1,
|
||||
lambda arr1: arr1)
|
||||
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
|
||||
return arr
|
||||
|
||||
def f_op_loops():
|
||||
with loops.Scope() as s:
|
||||
s.arr = np.zeros(5) # Must create the mutable state of the loop as `scope` fields.
|
||||
for i in s.range(s.arr.shape[0]):
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
|
||||
for _ in s.cond_range(i % 2 == 0): # Conditionals are also sugared as loops with 0 or 1 iterations
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
|
||||
return s.arr
|
||||
|
||||
self.assertAllClose(f_expected(), f_op_jax(), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(), f_op_loops(), check_dtypes=True)
|
||||
|
||||
def test_loop_mutable_used_but_not_changed(self):
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.read_only = inc
|
||||
s.out = 10.
|
||||
for i in s.range(5):
|
||||
s.out += s.read_only
|
||||
# It is Ok to use regular Python variables outside loops.
|
||||
save_to_other_var = s.out
|
||||
|
||||
return save_to_other_var
|
||||
|
||||
self.assertAllClose(10. + 5 * 2., f_op(2.), check_dtypes=True)
|
||||
|
||||
def test_range_locations(self):
|
||||
"""Ranges have locations."""
|
||||
if six.PY2: self.skipTest("Source location not implemented for PY2")
|
||||
with loops.Scope() as s:
|
||||
r = s.range(5)
|
||||
cr = s.cond_range(True)
|
||||
wr = s.while_range(lambda: True)
|
||||
for range in [r, cr, wr]:
|
||||
self.assertIn("loops_test.py", range.location())
|
||||
self.assertIn(self._testMethodName, range.location())
|
||||
|
||||
def test_error_reuse_range_nested(self):
|
||||
"""Ranges cannot be reused nested in their own iteration."""
|
||||
def f_op():
|
||||
with loops.Scope() as s:
|
||||
r1 = s.range(5)
|
||||
s.out = 0
|
||||
for _ in r1:
|
||||
for _ in r1:
|
||||
s.out += 1
|
||||
return s.out
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "Range is reused nested inside itself."):
|
||||
f_op()
|
||||
|
||||
def test_error_early_exit_range(self):
|
||||
"""Ranges do not support early exit from loop body."""
|
||||
def bad_function(exit_how="break"):
|
||||
with loops.Scope() as s:
|
||||
for i in s.range(555):
|
||||
if exit_how == "break":
|
||||
break
|
||||
elif exit_how == "return":
|
||||
return 1.
|
||||
elif exit_how == "exception":
|
||||
raise ValueError("test exception")
|
||||
# Start another range, we get here after a "break" above
|
||||
for i in s.range(5):
|
||||
pass
|
||||
return 0.
|
||||
|
||||
if six.PY3:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.compile(("Some ranges have exited prematurely. The innermost such range is at"
|
||||
".*s.range.555."), re.DOTALL)):
|
||||
bad_function("break")
|
||||
with self.assertRaisesRegex(ValueError, "Some ranges have exited prematurely"):
|
||||
bad_function("return")
|
||||
# On exception exit, we let the exception propagate
|
||||
with self.assertRaisesRegex(ValueError, "test exception"):
|
||||
bad_function("exception")
|
||||
|
||||
def test_error_early_exit_range_nested(self):
|
||||
"""Exit early from a nested range."""
|
||||
def bad_function():
|
||||
with loops.Scope() as s:
|
||||
for i in s.range(5): # When we end this range, we'll find the inner range still active
|
||||
for j in s.range(6):
|
||||
break
|
||||
return 0.
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Some ranges have exited prematurely."):
|
||||
bad_function()
|
||||
|
||||
def test_loop_index_var_live_expect_fail(self):
|
||||
"""The index variable is live after the loop."""
|
||||
self.skipTest("Don't know how to check that index variable is not used after loop.")
|
||||
def f_op(r):
|
||||
with loops.Scope() as s:
|
||||
for i in s.range(r):
|
||||
pass
|
||||
return i
|
||||
|
||||
self.assertAllClose(4, f_op(4), check_dtypes=True)
|
||||
|
||||
def test_error_new_state_in_loop(self):
|
||||
"""Error when creating new state in a loop."""
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for i in s.range(5):
|
||||
s.other_state = 1.
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ValueError,
|
||||
"New mutable state 'other_state' cannot be created inside a loop."):
|
||||
f_op(2.)
|
||||
|
||||
def test_error_range_ends_static(self):
|
||||
def f_op(start, end, inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 0.
|
||||
for i in s.range(start, end):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(16., f_op(0, 4, 4.), check_dtypes=True)
|
||||
# Ok to jit, as long as the start and end are static
|
||||
self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.), check_dtypes=True)
|
||||
with self.assertRaisesRegex(TypeError, "Abstract value passed to `int`, which requires a concrete value"):
|
||||
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True)
|
||||
with self.assertRaisesRegex(TypeError, "Abstract value passed to `int`, which requires a concrete value"):
|
||||
self.assertAllClose(16., api.vmap(f_op)(np.zeros(10), np.ones(10), np.array([4.] * 10)), check_dtypes=True)
|
||||
|
||||
def test_cond(self):
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for i in s.cond_range(inc > 0):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(10. + 2., f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(10., f_op(-2.), check_dtypes=True)
|
||||
|
||||
def test_cond_state(self):
|
||||
"""Conditionals predicated on scope fields."""
|
||||
def f_op(init):
|
||||
with loops.Scope() as s:
|
||||
s.out = init
|
||||
for _ in s.cond_range(s.out > 0.):
|
||||
s.out *= 2.
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(2. * 2., f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(-2., f_op(-2.), check_dtypes=True)
|
||||
|
||||
def test_cond_nested(self):
|
||||
"""Nested conditionals."""
|
||||
def f_expected(init):
|
||||
"""Multi-linear function.
|
||||
x in (..0) x + 1.
|
||||
x in [0..10) x + 1 + 2 + 4
|
||||
x in [10..) x + 1 + 2 + 4 + 8
|
||||
"""
|
||||
out = init
|
||||
if out >= 0.:
|
||||
out += 2.
|
||||
if out - 2. >= 10.:
|
||||
out += 8.
|
||||
out += 4.
|
||||
out += 1.
|
||||
return out
|
||||
|
||||
def f_op(init):
|
||||
with loops.Scope() as s:
|
||||
s.out = init
|
||||
for _ in s.cond_range(s.out >= 0.):
|
||||
s.out += 2.
|
||||
for _ in s.cond_range(s.out - 2. >= 10.):
|
||||
s.out += 8.
|
||||
s.out += 4.
|
||||
s.out += 1.
|
||||
return s.out
|
||||
|
||||
for init in [-1., 0., 9., 10.]:
|
||||
self.assertAllClose(f_expected(init), f_op(init), check_dtypes=True)
|
||||
|
||||
|
||||
def test_error_cond_using_index_var(self):
|
||||
"""Conditionals should not use the iteration index value."""
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for i in s.cond_range(inc > 0):
|
||||
s.out += i
|
||||
return s.out
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"Body of cond_range or while_range should not use the index variable returned by iterator."):
|
||||
api.make_jaxpr(f_op)(2.)
|
||||
|
||||
def test_while(self):
|
||||
def f_op(init):
|
||||
with loops.Scope() as s:
|
||||
s.out = init
|
||||
for _ in s.while_range(lambda: s.out < 5.):
|
||||
s.out += 2.
|
||||
s.out += 1.
|
||||
return s.out
|
||||
def f_expected(init):
|
||||
out = init
|
||||
while out < 5.:
|
||||
out += 2.
|
||||
out += 1.
|
||||
return out
|
||||
|
||||
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
|
||||
init_batch = np.array([1., 2., 3.])
|
||||
self.assertAllClose(np.array([f_expected(init) for init in init_batch]),
|
||||
api.vmap(f_op)(init_batch), check_dtypes=True)
|
||||
|
||||
def test_error_while_cond_mutation(self):
|
||||
"""Disallow mutation in the while conditional."""
|
||||
def f_op(init):
|
||||
with loops.Scope() as s:
|
||||
s.out = init
|
||||
|
||||
def cond_func():
|
||||
s.out += 1. # Not allowed
|
||||
return s.out < 5.
|
||||
|
||||
for _ in s.while_range(cond_func):
|
||||
s.out += 2.
|
||||
s.out += 1.
|
||||
return s.out
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ValueError,
|
||||
"Conditional function modifies scope.out field."):
|
||||
f_op(0.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
Loading…
x
Reference in New Issue
Block a user