mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
An implementation of an experimental syntactic sugar for 'for' loops.
See description in jax/experimental/loops.py.
This commit is contained in:
parent
9b853a4255
commit
d24c374d59
6
docs/jax.experimental.loops.rst
Normal file
6
docs/jax.experimental.loops.rst
Normal file
@ -0,0 +1,6 @@
|
||||
jax.experimental.loops module
|
||||
=============================
|
||||
|
||||
.. automodule:: jax.experimental.loops
|
||||
:members:
|
||||
:show-inheritance:
|
@ -4,6 +4,7 @@ jax.experimental package
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.experimental.loops
|
||||
jax.experimental.optimizers
|
||||
jax.experimental.stax
|
||||
jax.experimental.vectorize
|
||||
|
504
jax/experimental/loops.py
Normal file
504
jax/experimental/loops.py
Normal file
@ -0,0 +1,504 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Loops is an **experimental** module for syntactic sugar for loops and control-flow in JAX.
|
||||
|
||||
The current implementation should convert loops correctly to JAX internal representation, and most transformation
|
||||
should work (see below), but we have not yet fine-tuned the performance of the resulting XLA compilation!
|
||||
|
||||
By default, loops and control-flow in JAX are executed and inlined during tracing. For example, in the following code
|
||||
the `for` loop is unrolled during JAX tracing::
|
||||
|
||||
arr = np.zeros(5)
|
||||
for i in range(arr.shape[0]):
|
||||
arr[i] = i + 2
|
||||
if i % 2 == 0:
|
||||
arr[i] += 1
|
||||
|
||||
In order to capture the structured control-flow one has to use the higher-order JAX operations, which require
|
||||
you to express the body of the loops and conditionals as functions,
|
||||
and the array updates using a functional style that returns an updated array, e.g.::
|
||||
|
||||
arr = np.zeros(5)
|
||||
def loop_body(i, acc_arr):
|
||||
let arr1 = lax.index_update(acc_arr, i, acc_arr[i] + 2)
|
||||
return lax.cond(i % 2 == 0,
|
||||
arr1,
|
||||
lambda arr1: lax.index_update(arr1, i, arr1[i] + 1),
|
||||
arr1,
|
||||
lambda arr1: arr1)
|
||||
arr = lax.fori(0, arr.shape[0], loop_body, arr)
|
||||
|
||||
With the utilities in this module you can write loops and conditionals that look closer to plain Python,
|
||||
as long as you keep the loop-carried state in a special `loops.scope` object and use `for` loops
|
||||
over special `scope.range` iterators::
|
||||
|
||||
from jax.experimental import loops
|
||||
with loops.scope() as s:
|
||||
s.arr = np.zeros(5) # Must create the mutable state of the loop as `scope` fields.
|
||||
for i in s.range(s.arr.shape[0]):
|
||||
s.arr = lax.index_update(s.arr, i, s.arr[i] + 2)
|
||||
for _ in s.cond_range(i % 2 == 0): # Conditionals are also sugared as loops with 0 or 1 iterations
|
||||
s.arr = lax.index_update(s.arr, i, s.arr[i] + 1)
|
||||
|
||||
|
||||
Notes:
|
||||
* Loops and conditionals to be functionalized can appear only inside scopes constructed with `loops.Scope`
|
||||
and they must use one of the `Scope.range` iterators. All other loops are unrolled during tracing, as usual in JAX.
|
||||
* Only scope data (stored in fields of the scope object) is functionalized. All other state, e.g., in other
|
||||
Python variables, will not be considered as being part of the loop output.
|
||||
All references to the mutable state should be through the scope: `s.arr`.
|
||||
* Conceptually, this model is still "functional" in the sense that a loop over a `Scope.range` behaves
|
||||
as a function whose input and output is the scope data.
|
||||
* Scopes should be passed down to callees that need to use loop functionalization, or they may be nested.
|
||||
* The programming model is that the loop body over a `scope.range` is traced only once, using abstract shape values,
|
||||
similar to how JAX traces function bodies.
|
||||
|
||||
Restrictions:
|
||||
* The tracing of the loop body should not exit prematurely with `return`, `exception`, `break`.
|
||||
This would be detected and reported as errors when we encounter unnested scopes.
|
||||
* The loop index variable should not be used after the loop. Similarly, one should not use outside the loop
|
||||
data computed in the loop body, except data stored in fields of the scope object.
|
||||
* No new mutable state can be created inside a loop to be functionalized. All mutable state must be created outside
|
||||
all loops and conditionals.
|
||||
* For a `while` loop, the conditional function is not allowed to modify the scope state. This is a checked error.
|
||||
Also, for `while` loops the `grad` transformation does not work. An alternative that allows `grad` is a bounded
|
||||
loop (`range`) with a conditional to ignore a suffix of iterations.
|
||||
|
||||
Transformations:
|
||||
* All transformations are supported, except `grad` is not supported for `Scope.while_range` loops.
|
||||
* `vmap` is very useful for such loops because it pushes more work into the inner-loops, which should help
|
||||
performance for accelerators.
|
||||
|
||||
For usage example, see tests/loops_test.py.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
import itertools
|
||||
import numpy as onp
|
||||
import traceback
|
||||
from warnings import warn
|
||||
|
||||
from jax import abstract_arrays
|
||||
from jax import lax, core
|
||||
from jax.lax import lax_control_flow
|
||||
from jax import tree_util
|
||||
from jax import numpy as jnp
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.util import unzip2, safe_map
|
||||
|
||||
|
||||
class Scope(object):
|
||||
"""A scope context manager to keep the state of loop bodies for functionalization.
|
||||
|
||||
Usage::
|
||||
|
||||
with Scope() as s:
|
||||
s.data = 0.
|
||||
for i in s.range(5):
|
||||
s.data += 1.
|
||||
return s.data
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._mutable_state = {} # Keep here the state to be functionalized, indexed by name.
|
||||
self._active_ranges = [] # The stack of active ranges, the last one is the innermost.
|
||||
|
||||
def range(self, first, second=None, third=None):
|
||||
"""Creates a range for bounded iterations to be functionalized.
|
||||
|
||||
The body is converted to a `lax.scan`, for which all JAX transformations work.
|
||||
|
||||
Usage::
|
||||
|
||||
range(5) # start=0, end=5, step=1
|
||||
range(1, 5) # start=1, end=5, step=1
|
||||
range(1, 5, 2) # start=1, end=5, step=2
|
||||
|
||||
s.out = 1.
|
||||
for i in scope.range(5):
|
||||
s.out += 1.
|
||||
"""
|
||||
if third is not None:
|
||||
start = int(first)
|
||||
stop = int(second)
|
||||
step = int(third)
|
||||
else:
|
||||
step = 1
|
||||
if second is not None:
|
||||
start = int(first)
|
||||
stop = int(second)
|
||||
else:
|
||||
start = 0
|
||||
stop = int(first)
|
||||
return _BodyTracer(self, _BoundedLoopBuilder(start, stop, step))
|
||||
|
||||
def cond_range(self, pred):
|
||||
"""Creates a conditional range with 0 or 1 iterations based on the boolean "pred".
|
||||
|
||||
The body is converted to a `lax.cond`. All JAX transformations work.
|
||||
|
||||
Usage::
|
||||
|
||||
for _ in scope.cond_range(s.field < 0.):
|
||||
s.field = - s.field
|
||||
"""
|
||||
# TODO: share these checks with lax_control_flow.cond
|
||||
if len(onp.shape(pred)) != 0:
|
||||
raise TypeError("Pred must be a scalar, got {} of shape {}.".format(pred, onp.shape(pred)))
|
||||
|
||||
try:
|
||||
pred_dtype = onp.result_type(pred)
|
||||
except TypeError:
|
||||
msg = ("Pred type must be either boolean or number, got {}.")
|
||||
raise TypeError(msg.format(pred))
|
||||
|
||||
if pred_dtype.kind != 'b':
|
||||
if pred_dtype.kind in 'iuf':
|
||||
pred = pred != 0
|
||||
else:
|
||||
msg = ("Pred type must be either boolean or number, got {}.")
|
||||
raise TypeError(msg.format(pred_dtype))
|
||||
|
||||
return _BodyTracer(self, _CondBuilder(pred))
|
||||
|
||||
def while_range(self, cond_func):
|
||||
"""Creates a range that continues as long as `cond_func` returns true.
|
||||
|
||||
The body is converted to a `lax.while_loop`. The `grad` transformation does not work.
|
||||
|
||||
Usage::
|
||||
|
||||
for _ in scope.while_range(lambda: s.loss > 1.e-5):
|
||||
s.loss = loss(...)
|
||||
|
||||
Args:
|
||||
cond_func: a lambda with no arguments, the condition for the "while".
|
||||
"""
|
||||
return _BodyTracer(self, _WhileBuilder(cond_func))
|
||||
|
||||
def _push_range(self, range):
|
||||
for ar in self._active_ranges:
|
||||
if ar is range:
|
||||
raise ValueError("Range is reused nested inside itself.") # No need to include location, range is in stacktrace
|
||||
self._active_ranges.append(range)
|
||||
|
||||
def _pop_range(self, range):
|
||||
if not (range is self._active_ranges[-1]):
|
||||
self._error_premature_exit_range()
|
||||
self._active_ranges.pop()
|
||||
|
||||
def _error_premature_exit_range(self):
|
||||
"""Raises error about premature exit from a range"""
|
||||
msg = "Some ranges have exited prematurely. The innermost such range is at\n{}"
|
||||
raise ValueError(msg.format(self._active_ranges[-1].location()))
|
||||
|
||||
def __getattr__(self, key):
|
||||
"""Accessor for scope data.
|
||||
|
||||
Called only if the attribute is not found, which will happen when we read
|
||||
scope data that has been stored in self._mutable_state.
|
||||
"""
|
||||
mt_val = self._mutable_state.get(key)
|
||||
if mt_val is None:
|
||||
raise AttributeError("Reading uninitialized data '{}' from the scope.".format(key))
|
||||
return mt_val
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
"""Update scope data to be functionalized.
|
||||
|
||||
Called for *all* attribute setting.
|
||||
"""
|
||||
if key in ["_active_ranges", "_mutable_state"]:
|
||||
object.__setattr__(self, key, value)
|
||||
else:
|
||||
if self._active_ranges and key not in self._mutable_state:
|
||||
raise ValueError("New mutable state '{}' cannot be created inside a loop.".format(key))
|
||||
self._mutable_state[key] = value
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
if self._active_ranges: # We have some ranges that we did not exit properly
|
||||
self._error_premature_exit_range()
|
||||
return True
|
||||
else:
|
||||
# The exception may come from inside one or more ranges. We let the current
|
||||
# exception propagate, assuming it terminates the tracing. If not, the
|
||||
# tracers may be left in an inconsistent state.
|
||||
return False # re-raise
|
||||
|
||||
|
||||
class _BodyTracer(object):
|
||||
"""Traces the body of the loop and builds a functional control-flow representation.
|
||||
|
||||
This class is also an iterator, only the first iteration is traced.
|
||||
"""
|
||||
|
||||
def __init__(self, scope, loop_builder):
|
||||
"""
|
||||
Params:
|
||||
scope: the current scope
|
||||
loop_builder: instance of _LoopBuilder
|
||||
"""
|
||||
self.scope = scope
|
||||
self.loop_builder = loop_builder
|
||||
self.first_iteration = True # If we are tracing the first iteration
|
||||
self.stack = traceback.StackSummary.from_list(traceback.extract_stack()[:-2]) # Stack trace, without this line and the s.range function
|
||||
|
||||
# The rest is state kept from the start of the first iteration to the end of the iteration.
|
||||
self.carried_state_initial = {} # The initial values of the mutable state upon entering the range body.
|
||||
self.carried_state_vars = {} # The parameters that were created for state upon entering an arbitrary iteration.
|
||||
|
||||
self.trace = None
|
||||
self.carried_state_names = None # List of scope fields carried through the loop
|
||||
self.init_tree = None # The PyTreeDef corresponding to carried_state_names
|
||||
self.init_vals = None # The values corresponding to self.init_tree
|
||||
|
||||
def location(self):
|
||||
"""A multiline string representing the source location of the range."""
|
||||
return " ".join(self.stack.format())
|
||||
|
||||
def __iter__(self):
|
||||
"""Called before starting the first iteration."""
|
||||
self.first_iteration = True # In case we reuse the range
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.first_iteration:
|
||||
self.first_iteration = False
|
||||
self.scope._push_range(self)
|
||||
self.start_tracing_body()
|
||||
return self._index_var
|
||||
else:
|
||||
self.end_tracing_body()
|
||||
self.scope._pop_range(self)
|
||||
raise StopIteration # Trace only one iteration.
|
||||
|
||||
def start_tracing_body(self):
|
||||
"""Called upon starting the tracing of the loop body."""
|
||||
# Make a copy of the current value of the mutable state
|
||||
self.carried_state_initial = copy.copy(self.scope._mutable_state)
|
||||
# The entire state is carried.
|
||||
self.carried_state_names = sorted(self.scope._mutable_state.keys())
|
||||
|
||||
# TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
|
||||
self.trace = _BodyTracer.start_subtrace()
|
||||
# Set the scope._mutable_state to new tracing variables.
|
||||
for key, initial in self.carried_state_initial.items():
|
||||
mt_aval = _BodyTracer.abstractify(initial)
|
||||
mt_pval = pe.PartialVal((mt_aval, core.unit))
|
||||
mt_var = self.trace.new_arg(mt_pval)
|
||||
self.carried_state_vars[key] = mt_var
|
||||
self.scope._mutable_state[key] = mt_var
|
||||
|
||||
index_var_aval = _BodyTracer.abstractify(0)
|
||||
index_var_pval = pe.PartialVal((index_var_aval, core.unit))
|
||||
self._index_var = self.trace.new_arg(index_var_pval)
|
||||
|
||||
def end_tracing_body(self):
|
||||
"""Called when we are done tracing one iteration of the body."""
|
||||
# We will turn the body of the loop into a function that takes some values for the scope state (carried_state_names)
|
||||
# and returns the values for the same state fields after one execution of the body. For some of the ranges,
|
||||
# e.g., scope.range, the function will also take the index_var as last parameter.
|
||||
in_tracers = [self.carried_state_vars[ms] for ms in self.carried_state_names]
|
||||
if self.loop_builder.can_use_index_var():
|
||||
in_tracers += [self._index_var]
|
||||
|
||||
# Make the jaxpr for the body of the loop
|
||||
# TODO: See which mutable state was changed in the one iteration. For now, we assume all state changes.
|
||||
body_out_tracers = tuple([self.scope._mutable_state[ms] for ms in self.carried_state_names])
|
||||
try:
|
||||
# If the body actually uses the index variable, and is not allowed to (e.g., cond_range and while_range),
|
||||
# then in_tracers will not contain the tracer for the index_var, and trace_to_jaxpr_finalize will throw
|
||||
# an assertion error.
|
||||
body_typed_jaxpr, body_const_vals = _BodyTracer.trace_to_jaxpr_finalize(in_tracers=in_tracers,
|
||||
out_tracers=body_out_tracers,
|
||||
trace=self.trace)
|
||||
except AssertionError as e:
|
||||
if "Encountered unexpected tracer" == str(e):
|
||||
raise ValueError("Body of cond_range or while_range should not use the index variable returned by iterator.")
|
||||
raise
|
||||
_BodyTracer.end_subtrace() # End the subtrace for the loop body, before we trace the condition
|
||||
|
||||
carried_init_val = tuple([self.carried_state_initial[ms] for ms in self.carried_state_names])
|
||||
carried_init_vals, carried_tree = tree_util.tree_flatten(carried_init_val)
|
||||
|
||||
carried_out_vals = self.loop_builder.build_output_vals(
|
||||
self.scope, self.carried_state_names, carried_tree, carried_init_vals, body_typed_jaxpr, body_const_vals)
|
||||
carried_mutable_state_unflattened = tree_util.tree_unflatten(carried_tree, carried_out_vals)
|
||||
|
||||
# Update the mutable state with the values of the changed vars, after the loop.
|
||||
for ms, mv in zip(self.carried_state_names, carried_mutable_state_unflattened):
|
||||
self.scope._mutable_state[ms] = mv
|
||||
|
||||
@staticmethod
|
||||
def start_subtrace():
|
||||
"""Starts a nested trace, returns the Trace object."""
|
||||
# TODO: This follows the __enter__ part of core.new_master. share
|
||||
level = core.trace_state.trace_stack.next_level(False)
|
||||
master = core.MasterTrace(level, pe.JaxprTrace)
|
||||
core.trace_state.trace_stack.push(master, False)
|
||||
return pe.JaxprTrace(master, core.cur_sublevel())
|
||||
|
||||
@staticmethod
|
||||
def end_subtrace():
|
||||
# TODO: This follows the __exit__ part of core.new_master
|
||||
core.trace_state.trace_stack.pop(False)
|
||||
|
||||
@staticmethod
|
||||
def abstractify(x):
|
||||
return abstract_arrays.raise_to_shaped(core.get_aval(x))
|
||||
|
||||
@staticmethod
|
||||
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
|
||||
# TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
|
||||
instantiate = [instantiate] * len(out_tracers)
|
||||
out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
|
||||
out_tracers = safe_map(partial(pe.instantiate_const_at, trace), instantiate, out_tracers)
|
||||
jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
|
||||
out_pvals = [t.pval for t in out_tracers]
|
||||
# TODO: this is from partial_eval.trace_to_jaxpr. Share.
|
||||
assert not env
|
||||
|
||||
# TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
|
||||
out_avals = safe_map(abstract_arrays.raise_to_shaped, unzip2(out_pvals)[0])
|
||||
const_avals = tuple(abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)
|
||||
|
||||
in_pvals = [t.pval for t in in_tracers]
|
||||
in_avals = tuple(safe_map(abstract_arrays.raise_to_shaped, unzip2(in_pvals)[0]))
|
||||
|
||||
typed_jaxpr = core.TypedJaxpr(pe.closure_convert_jaxpr(jaxpr),
|
||||
(), const_avals + in_avals, out_avals)
|
||||
return typed_jaxpr, consts
|
||||
|
||||
|
||||
class _LoopBuilder(object):
|
||||
"""Abstract superclass for the loop builders"""
|
||||
|
||||
def can_use_index_var(self):
|
||||
"""Whether this kind of loop can use the index var returned by the range iterator."""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree, init_vals, body_typed_jaxpr, body_const_vals):
|
||||
"""Builds the output values for the loop carried state.
|
||||
|
||||
Params:
|
||||
scope: the current Scope object.
|
||||
carried_state_names: the list of names of mutable state fields that is carried through the body.
|
||||
carried_tree: the PyTreeDef for the tuple of carried_state_names.
|
||||
init_vals: the initial values on body entry corresponding to the init_tree.
|
||||
body_typed_jaxpr: the Jaxpr for the body returning the new values of carried_state_names.
|
||||
body_const_vals: the constant values for the body.
|
||||
|
||||
Returns:
|
||||
the output tracer corresponding to the lax primitive representing the loop.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _BoundedLoopBuilder(_LoopBuilder):
|
||||
"""Builds a lax operation corresponding to a bounded range iteration."""
|
||||
|
||||
def __init__(self, start, stop, step):
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
self._index_var = None # The parameter for the index variable
|
||||
|
||||
def can_use_index_var(self):
|
||||
return True
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree, init_vals, body_typed_jaxpr, body_const_vals):
|
||||
arange_val = jnp.arange(self.start, stop=self.stop, step=self.step)
|
||||
return lax_control_flow.scan_p.bind(*itertools.chain(body_const_vals, init_vals, [arange_val]),
|
||||
forward=True, length=arange_val.shape[0], jaxpr=body_typed_jaxpr,
|
||||
num_consts=len(body_const_vals), num_carry=len(init_vals),
|
||||
linear=(False,) * (len(body_const_vals) + len(init_vals) + 1))
|
||||
|
||||
|
||||
class _CondBuilder(_LoopBuilder):
|
||||
"""Builds a lax.cond operation."""
|
||||
|
||||
def __init__(self, pred):
|
||||
self.pred = pred
|
||||
|
||||
def can_use_index_var(self):
|
||||
return False
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree, init_vals, body_typed_jaxpr, body_const_vals):
|
||||
# Simulate a pass-through false branch
|
||||
init_avals = safe_map(_BodyTracer.abstractify, init_vals)
|
||||
false_body_typed_jaxpr, false_body_const_vals, _ = lax_control_flow._initial_style_jaxpr(lambda *args: args,
|
||||
carried_tree,
|
||||
tuple(init_avals))
|
||||
return lax_control_flow.cond_p.bind(
|
||||
*itertools.chain([self.pred], body_const_vals, init_vals, false_body_const_vals, init_vals),
|
||||
true_jaxpr=body_typed_jaxpr, false_jaxpr=false_body_typed_jaxpr,
|
||||
true_nconsts=len(body_const_vals), false_nconsts=len(false_body_const_vals))
|
||||
|
||||
|
||||
class _WhileBuilder(_LoopBuilder):
|
||||
"""Builds a lax.while operation."""
|
||||
|
||||
def __init__(self, cond_func):
|
||||
self.cond_func = cond_func # Function with 0 arguments (can reference the scope)
|
||||
|
||||
def can_use_index_var(self):
|
||||
return False
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree, init_vals, body_typed_jaxpr, body_const_vals):
|
||||
# Trace the conditional function. cond_func takes 0 arguments, but for lax.while we need a conditional function
|
||||
# that takes the carried_state_names. _initial_style_jaxpr will start its own trace and will create tracers for
|
||||
# all the carried state. We must put these values in the scope._mutable_state before we trace the conditional
|
||||
# function.
|
||||
def cond_func_wrapped(*args):
|
||||
assert len(args) == len(carried_state_names)
|
||||
for ms, init_ms in zip(carried_state_names, args):
|
||||
scope._mutable_state[ms] = init_ms
|
||||
res = self.cond_func()
|
||||
# Conditional function is not allowed to modify the scope state
|
||||
for ms, init_ms in zip(carried_state_names, args):
|
||||
if not (scope._mutable_state[ms] is init_ms):
|
||||
msg = "Conditional function modifies scope.{} field."
|
||||
raise ValueError(msg.format(ms))
|
||||
return res
|
||||
|
||||
init_avals = safe_map(_BodyTracer.abstractify, init_vals)
|
||||
cond_jaxpr, cond_consts, cond_tree = lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
|
||||
carried_tree,
|
||||
tuple(init_avals))
|
||||
# TODO: share these checks with lax_control_flow.while
|
||||
if not tree_util.treedef_is_leaf(cond_tree):
|
||||
msg = "cond_fun must return a boolean scalar, but got pytree {}."
|
||||
raise TypeError(msg.format(cond_tree))
|
||||
if cond_jaxpr.out_avals != [abstract_arrays.ShapedArray((), onp.bool_)]:
|
||||
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
|
||||
raise TypeError(msg.format(cond_jaxpr.out_avals))
|
||||
|
||||
return lax_control_flow.while_p.bind(*itertools.chain(cond_consts, body_const_vals, init_vals),
|
||||
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
|
||||
body_nconsts=len(body_const_vals), body_jaxpr=body_typed_jaxpr)
|
||||
|
@ -400,7 +400,7 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
|
||||
eqns.append(eqn_tracer_to_var(var, recipe))
|
||||
processed_eqns.add(recipe.eqn_id)
|
||||
elif isinstance(recipe, LambdaBinding):
|
||||
assert any(t is in_tracer for in_tracer in in_tracers)
|
||||
assert any(t is in_tracer for in_tracer in in_tracers), "Encountered unexpected tracer"
|
||||
assert in_tracers, "Lambda binding with no args"
|
||||
elif isinstance(recipe, FreeVar):
|
||||
env[var(t)] = recipe.val
|
||||
|
369
tests/loops_test.py
Normal file
369
tests/loops_test.py
Normal file
@ -0,0 +1,369 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the experimental/loops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as onp
|
||||
import re
|
||||
|
||||
from jax import api, ops
|
||||
from jax import numpy as np
|
||||
from jax import test_util as jtu
|
||||
from jax.experimental import loops
|
||||
|
||||
|
||||
class LoopsSugarTest(jtu.JaxTestCase):
|
||||
|
||||
def test_scope_no_loops(self):
|
||||
def f_op(r):
|
||||
with loops.Scope() as s:
|
||||
s.x = r + 1
|
||||
return s.x
|
||||
self.assertAllClose(4.0, f_op(3.), check_dtypes=True)
|
||||
|
||||
def test_loop_empty(self):
|
||||
def f_op(r):
|
||||
with loops.Scope() as s:
|
||||
for _ in s.range(5):
|
||||
pass
|
||||
return r
|
||||
|
||||
self.assertAllClose(3.0, f_op(3.), check_dtypes=True)
|
||||
|
||||
def test_loop_1(self):
|
||||
"""One loop with one state var, with transforms."""
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for _ in s.range(5):
|
||||
s.out += inc
|
||||
return s.out
|
||||
def f_expected(inc):
|
||||
return 10 + 5 * inc
|
||||
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
|
||||
inc_batch = onp.arange(5, dtype=onp.float32)
|
||||
self.assertAllClose(np.array([f_expected(inc) for inc in inc_batch]),
|
||||
api.vmap(f_op)(inc_batch), check_dtypes=True)
|
||||
|
||||
|
||||
def test_loop_2(self):
|
||||
"""One loop, two state fields."""
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out1 = 10.
|
||||
s.out2 = 20.
|
||||
for i in s.range(5):
|
||||
s.out1 += inc
|
||||
s.out2 += 1.
|
||||
return (s.out1, s.out2)
|
||||
|
||||
self.assertAllClose((10. + 2. * 5, 20. + 1. * 5), f_op(2.), check_dtypes=True)
|
||||
|
||||
|
||||
def test_add_vectors(self):
|
||||
def add_vec(x, y):
|
||||
with loops.Scope() as s:
|
||||
n = x.shape[0]
|
||||
assert n == y.shape[0]
|
||||
s.out = np.zeros(shape=[n], dtype=np.float32)
|
||||
for i in s.range(n):
|
||||
s.out = ops.index_add(s.out, i, x[i] + y[i])
|
||||
return s.out
|
||||
|
||||
x = np.array([1., 2., 3.])
|
||||
y = np.array([4., 5., 6.])
|
||||
self.assertAllClose(np.add(x, y), add_vec(x, y), check_dtypes=True)
|
||||
|
||||
def test_matmul(self):
|
||||
def matmul(x, y):
|
||||
with loops.Scope() as s:
|
||||
n, m = x.shape
|
||||
m1, p = y.shape
|
||||
assert m == m1
|
||||
s.out = np.zeros(shape=[n, p], dtype=np.float32)
|
||||
for i in s.range(n):
|
||||
for j in s.range(p):
|
||||
for k in s.range(m):
|
||||
s.out = ops.index_add(s.out, (i, j), x[i, k] * y[k, j])
|
||||
return s.out
|
||||
|
||||
x = np.array([[1., 2., 3.]]) # 1x3
|
||||
y = np.array([[4.], [5.], [6.]]) # 3x1
|
||||
self.assertAllClose(np.matmul(x, y), matmul(x, y), check_dtypes=True)
|
||||
|
||||
def test_reuse_range(self):
|
||||
"""Ranges can be reused, as long as not nested in each other."""
|
||||
def f_op():
|
||||
with loops.Scope() as s:
|
||||
r1 = s.range(5)
|
||||
s.out = 0
|
||||
for _ in r1:
|
||||
s.out += 1
|
||||
for _ in r1:
|
||||
s.out += 1
|
||||
return s.out
|
||||
|
||||
self.assertEqual(10, f_op())
|
||||
|
||||
|
||||
def test_loop_nested(self):
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for i in s.range(5):
|
||||
s.out += inc
|
||||
for j in s.range(6):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(10. + 5 * (2. + 6 * 2.), f_op(2.), check_dtypes=True)
|
||||
|
||||
def test_loop_mutable_used_but_not_changed(self):
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.read_only = inc
|
||||
s.out = 10.
|
||||
for i in s.range(5):
|
||||
s.out += s.read_only
|
||||
# It is Ok to use regular Python variables outside loops.
|
||||
save_to_other_var = s.out
|
||||
|
||||
return save_to_other_var
|
||||
|
||||
self.assertAllClose(10. + 5 * 2., f_op(2.), check_dtypes=True)
|
||||
|
||||
def test_range_locations(self):
|
||||
"""Ranges have locations."""
|
||||
with loops.Scope() as s:
|
||||
r = s.range(5)
|
||||
cr = s.cond_range(True)
|
||||
wr = s.while_range(lambda: True)
|
||||
for range in [r, cr, wr]:
|
||||
self.assertIn("loops_test.py", range.location())
|
||||
self.assertIn(self._testMethodName, range.location())
|
||||
|
||||
def test_error_reuse_range_nested(self):
|
||||
"""Ranges cannot be reused nested in their own iteration."""
|
||||
def f_op():
|
||||
with loops.Scope() as s:
|
||||
r1 = s.range(5)
|
||||
s.out = 0
|
||||
for _ in r1:
|
||||
for _ in r1:
|
||||
s.out += 1
|
||||
return s.out
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "Range is reused nested inside itself."):
|
||||
f_op()
|
||||
|
||||
def test_error_early_exit_range(self):
|
||||
"""Ranges do not support early exit from loop body."""
|
||||
def bad_function(exit_how="break"):
|
||||
with loops.Scope() as s:
|
||||
for i in s.range(555):
|
||||
if exit_how == "break":
|
||||
break
|
||||
elif exit_how == "return":
|
||||
return 1.
|
||||
elif exit_how == "exception":
|
||||
raise ValueError("test exception")
|
||||
# Start another range, we get here after a "break" above
|
||||
for i in s.range(5):
|
||||
pass
|
||||
return 0.
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.compile(("Some ranges have exited prematurely. The innermost such range is at"
|
||||
".*s.range.555."), re.DOTALL)):
|
||||
bad_function("break")
|
||||
with self.assertRaisesRegex(ValueError, "Some ranges have exited prematurely"):
|
||||
bad_function("return")
|
||||
# On exception exit, we let the exception propagate
|
||||
with self.assertRaisesRegex(ValueError, "test exception"):
|
||||
bad_function("exception")
|
||||
|
||||
def test_error_early_exit_range_nested(self):
|
||||
"""Exit early from a nested range."""
|
||||
def bad_function():
|
||||
with loops.Scope() as s:
|
||||
for i in s.range(5): # When we end this range, we'll find the inner range still active
|
||||
for j in s.range(6):
|
||||
break
|
||||
return 0.
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Some ranges have exited prematurely."):
|
||||
bad_function()
|
||||
|
||||
def test_loop_index_var_live_expect_fail(self):
|
||||
"""The index variable is live after the loop."""
|
||||
self.skipTest("Don't know how to check that index variable is not used after loop.")
|
||||
def f_op(r):
|
||||
with loops.Scope() as s:
|
||||
for i in s.range(r):
|
||||
pass
|
||||
return i
|
||||
|
||||
self.assertAllClose(4, f_op(4), check_dtypes=True)
|
||||
|
||||
def test_error_new_state_in_loop(self):
|
||||
"""Error when creating new state in a loop."""
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for i in s.range(5):
|
||||
s.other_state = 1.
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ValueError,
|
||||
"New mutable state 'other_state' cannot be created inside a loop."):
|
||||
f_op(2.)
|
||||
|
||||
def test_error_range_ends_static(self):
|
||||
def f_op(start, end, inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 0.
|
||||
for i in s.range(start, end):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(16., f_op(0, 4, 4.), check_dtypes=True)
|
||||
# Ok to jit, as long as the start and end are static
|
||||
self.assertAllClose(16., api.jit(f_op, static_argnums=(0, 1))(0, 4, 4.), check_dtypes=True)
|
||||
with self.assertRaisesRegex(TypeError, "Abstract value passed to `int`, which requires a concrete value"):
|
||||
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True)
|
||||
with self.assertRaisesRegex(TypeError, "Abstract value passed to `int`, which requires a concrete value"):
|
||||
self.assertAllClose(16., api.vmap(f_op)(np.zeros(10), np.ones(10), np.array([4.] * 10)), check_dtypes=True)
|
||||
|
||||
def test_cond(self):
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for i in s.cond_range(inc > 0):
|
||||
s.out += inc
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(10. + 2., f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(10., f_op(-2.), check_dtypes=True)
|
||||
|
||||
def test_cond_state(self):
|
||||
"""Conditionals predicated on scope fields."""
|
||||
def f_op(init):
|
||||
with loops.Scope() as s:
|
||||
s.out = init
|
||||
for _ in s.cond_range(s.out > 0.):
|
||||
s.out *= 2.
|
||||
return s.out
|
||||
|
||||
self.assertAllClose(2. * 2., f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(-2., f_op(-2.), check_dtypes=True)
|
||||
|
||||
def test_cond_nested(self):
|
||||
"""Nested conditionals."""
|
||||
def f_expected(init):
|
||||
"""Multi-linear function.
|
||||
x in (..0) x + 1.
|
||||
x in [0..10) x + 1 + 2 + 4
|
||||
x in [10..) x + 1 + 2 + 4 + 8
|
||||
"""
|
||||
out = init
|
||||
if out >= 0.:
|
||||
out += 2.
|
||||
if out - 2. >= 10.:
|
||||
out += 8.
|
||||
out += 4.
|
||||
out += 1.
|
||||
return out
|
||||
|
||||
def f_op(init):
|
||||
with loops.Scope() as s:
|
||||
s.out = init
|
||||
for _ in s.cond_range(s.out >= 0.):
|
||||
s.out += 2.
|
||||
for _ in s.cond_range(s.out - 2. >= 10.):
|
||||
s.out += 8.
|
||||
s.out += 4.
|
||||
s.out += 1.
|
||||
return s.out
|
||||
|
||||
for init in [-1., 0., 9., 10.]:
|
||||
self.assertAllClose(f_expected(init), f_op(init), check_dtypes=True)
|
||||
|
||||
|
||||
def test_error_cond_using_index_var(self):
|
||||
"""Conditionals should not use the iteration index value."""
|
||||
def f_op(inc):
|
||||
with loops.Scope() as s:
|
||||
s.out = 10.
|
||||
for i in s.cond_range(inc > 0):
|
||||
s.out += i
|
||||
return s.out
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"Body of cond_range or while_range should not use the index variable returned by iterator."):
|
||||
api.make_jaxpr(f_op)(2.)
|
||||
|
||||
def test_while(self):
|
||||
def f_op(init):
|
||||
with loops.Scope() as s:
|
||||
s.out = init
|
||||
for _ in s.while_range(lambda: s.out < 5.):
|
||||
s.out += 2.
|
||||
s.out += 1.
|
||||
return s.out
|
||||
def f_expected(init):
|
||||
out = init
|
||||
while out < 5.:
|
||||
out += 2.
|
||||
out += 1.
|
||||
return out
|
||||
|
||||
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
|
||||
init_batch = np.array([1., 2., 3.])
|
||||
self.assertAllClose(np.array([f_expected(init) for init in init_batch]),
|
||||
api.vmap(f_op)(init_batch), check_dtypes=True)
|
||||
|
||||
def test_error_while_cond_mutation(self):
|
||||
"""Disallow mutation in the while conditional."""
|
||||
def f_op(init):
|
||||
with loops.Scope() as s:
|
||||
s.out = init
|
||||
|
||||
def cond_func():
|
||||
s.out += 1. # Not allowed
|
||||
return s.out < 5.
|
||||
|
||||
for _ in s.while_range(cond_func):
|
||||
s.out += 2.
|
||||
s.out += 1.
|
||||
return s.out
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ValueError,
|
||||
"Conditional function modifies scope.out field."):
|
||||
f_op(0.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
Loading…
x
Reference in New Issue
Block a user