rocm_jax/tests/lax_control_flow_test.py
2019-08-23 16:40:49 -07:00

962 lines
28 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from functools import partial
import itertools
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
import numpy.random as npr
from jax import api
from jax import core
from jax import lax
from jax import test_util as jtu
from jax.util import unzip2
from jax.lib import xla_bridge
import jax.numpy as np # scan tests use numpy
def scan_reference(f, init, xs):
carry = init
ys = []
for x in xs:
(carry, y) = f(carry, x)
ys.append(lax.reshape(y, (1,) + onp.shape(y)))
ys = lax.concatenate(ys, 0)
return carry, ys
class LaxControlFlowTest(jtu.JaxTestCase):
def testWhileWithTuple(self):
limit = 10
def loop_cond(state):
pos, _ = state
return lax.lt(pos, limit)
def loop_body(state):
pos, count = state
return (lax.add(pos, 1), lax.add(count, 1))
def loop(init):
result = lax.while_loop(loop_cond, loop_body, (init, 0))
_, count = result
return count
cloop = api.jit(loop)
self.assertEqual(loop(2), limit - 2)
self.assertEqual(cloop(2), limit - 2)
self.assertEqual(cloop(2), limit - 2)
self.assertEqual(cloop(3), limit - 3)
def testNestedWhile(self):
def outer_loop(num): # pylint: disable=missing-docstring
def cond_fun(state):
num, i, _ = state
return lax.lt(i, num)
def body_fun(state):
num, i, count = state
return (num, lax.add(i, 1), inner_loop(i, count))
init_val = (num, 0, 0)
_, i, count = lax.while_loop(cond_fun, body_fun, init_val)
return (i, count)
def inner_loop(i, count): # pylint: disable=missing-docstring
def cond_fun(state):
i, j, _ = state
return lax.le(j, i)
def body_fun(state):
i, j, count = state
return (i, lax.add(j, 1), lax.add(count, 1))
init_val = (i, 0, count)
_, _, count = lax.while_loop(cond_fun, body_fun, init_val)
return count
cloop = api.jit(outer_loop)
self.assertEqual(outer_loop(3), (3, 6))
self.assertEqual(cloop(3), (3, 6))
self.assertEqual(cloop(3), (3, 6))
self.assertEqual(cloop(2), (2, 3))
self.assertEqual(cloop(4), (4, 10))
def testWhileWithClosure(self):
def loop(init, local_limit, inc):
def loop_cond(state):
pos, _ = state
return lax.lt(pos, local_limit)
def loop_body(state):
effect[0] = True
pos, count = state
return (lax.add(pos, 1), lax.add(count, inc))
result = lax.while_loop(loop_cond, loop_body, (init, 0))
_, count = result
return count
cloop = api.jit(loop)
limit = 10
effect = [False]
self.assertEqual(loop(2, limit, 1), limit - 2)
assert effect[0]
effect[0] = False
self.assertEqual(cloop(2, limit, 1), limit - 2)
assert effect[0]
effect[0] = False
self.assertEqual(cloop(2, limit, 1), limit - 2)
self.assertEqual(cloop(3, limit, 1), limit - 3)
assert not effect[0]
def testWhileWithClosureJit(self):
def loop(init, local_limit, inc):
def loop_cond(state):
pos, _ = state
return lax.lt(pos, local_limit)
def loop_body(state):
effect[0] = True
pos, count = state
f = lambda pos, inc: (lax.add(pos, 1), lax.add(count, inc))
return api.jit(f)(pos, inc)
result = lax.while_loop(loop_cond, loop_body, (init, 0))
_, count = result
return count
cloop = api.jit(loop)
limit = 10
effect = [False]
self.assertEqual(loop(2, limit, 1), limit - 2)
assert effect[0]
effect[0] = False
self.assertEqual(cloop(2, limit, 1), limit - 2)
assert effect[0]
effect[0] = False
self.assertEqual(cloop(2, limit, 1), limit - 2)
self.assertEqual(cloop(3, limit, 1), limit - 3)
assert not effect[0]
def testNestedWhileWithDynamicUpdateSlice(self):
num = 5
def update_entry(arr, val, i, j):
val = lax.reshape(val, [1, 1])
return lax.dynamic_update_slice(arr, val, (i, j))
def outer_loop(arr): # pylint: disable=missing-docstring
def cond_fun(state):
i, num, _, _ = state
return lax.lt(i, num)
def body_fun(state):
i, num, arr, out = state
return (lax.add(i, 1), num, arr, inner_loop(i, arr, out))
out = onp.zeros(arr.shape, dtype=arr.dtype)
init_val = (0, num, arr, out)
_, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
return out
def inner_loop(i, arr, out): # pylint: disable=missing-docstring
def cond_fun(state):
i, j, _, _ = state
return lax.le(j, i)
def body_fun(state):
i, j, arr, out = state
arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
arr_i_j = lax.dynamic_index_in_dim(arr_i, j, 0, False)
out = update_entry(out, arr_i_j, i, j)
return (i, lax.add(j, 1), arr, out)
init_val = (i, 0, arr, out)
_, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
return out
cloop = api.jit(outer_loop)
arr = npr.RandomState(0).randn(5, 5)
self.assertAllClose(outer_loop(arr), onp.tril(arr), check_dtypes=False)
self.assertAllClose(cloop(arr), onp.tril(arr), check_dtypes=False)
self.assertAllClose(cloop(arr), onp.tril(arr), check_dtypes=False)
def testLoopWithConjunctionCondition(self):
def sum_first_n(arr, num): # pylint: disable=missing-docstring
def cond_fun(state):
arr, num, i, _ = state
return lax.bitwise_and(lax.lt(i, num), lax.lt(i, arr.shape[0]))
def body_fun(state):
arr, num, i, total = state
arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
return (arr, num, lax.add(i, 1), lax.add(total, arr_i))
init_val = (arr, num, 0, 0.)
_, _, _, total = lax.while_loop(cond_fun, body_fun, init_val)
return total
cfun = api.jit(sum_first_n)
x = npr.RandomState(0).randn(10)
for num in [0, 5, 10, 15]:
self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
def testWhileLoopBatched(self):
def fun(x):
return lax.while_loop(lambda x: x < 3, lambda x: x + 2, x)
ans = api.vmap(fun)(onp.array([0, 1, 2, 3]))
expected = onp.array([4, 3, 4, 3])
self.assertAllClose(ans, expected, check_dtypes=False)
fun = api.jit(fun)
ans = api.vmap(fun)(onp.array([0, 1, 2, 3]))
expected = onp.array([4, 3, 4, 3])
self.assertAllClose(ans, expected, check_dtypes=False)
def testWhileLoopCondConstsBatched(self):
def fun(x, y):
return lax.while_loop(lambda x: x < y, lambda x: x + 2, x)
ans = api.vmap(fun, in_axes=(None, 0))(0, onp.array([2, 3]))
expected = onp.array([2, 4])
self.assertAllClose(ans, expected, check_dtypes=False)
def testWhileLoopBodyConstsBatched(self):
def fun(x, y):
return lax.while_loop(lambda x: x < 3, lambda x: x + y, x)
ans = api.vmap(fun, in_axes=(None, 0))(0, onp.array([2, 3]))
expected = onp.array([4, 3])
self.assertAllClose(ans, expected, check_dtypes=False)
def testWhileLoopTupleBatched(self):
def cond_fun(loop_carry):
x, y = loop_carry
return x + y < 5
def body_fun(loop_carry):
x, y = loop_carry
x = x + 1
return x, y
def fun(x, y):
return lax.while_loop(cond_fun, body_fun, (x, y))
ans = api.vmap(fun)(onp.array([0, 0]), onp.array([1, 2]))
expected = (onp.array([4, 3]), onp.array([1, 2]))
self.assertAllClose(ans, expected, check_dtypes=False)
def testForiLoopBatched(self):
def body_fun(i, loop_carry):
x, y = loop_carry
x = x + 1
y = y + 2
return x, y
def fun(x):
return lax.fori_loop(0, 10, body_fun, (x, 0))
ans = api.vmap(fun)(onp.array([0, 1]))
expected = (onp.array([10, 11]), onp.array([20, 20]))
self.assertAllClose(ans, expected, check_dtypes=False)
def testForiLoopBasic(self):
def count(num):
def body_fun(i, tot):
return lax.add(tot, i)
return lax.fori_loop(0, num, body_fun, 0)
cfun = api.jit(count)
self.assertEqual(count(2), 1)
self.assertEqual(count(2), cfun(2))
self.assertEqual(count(3), 3)
self.assertEqual(count(3), cfun(3))
self.assertEqual(count(4), 6)
self.assertEqual(count(4), cfun(4))
def testForiLoopClosure(self):
def count(num):
def body_fun(i, tot):
return lax.add(num, lax.add(tot, i))
return lax.fori_loop(0, num, body_fun, 0)
cfun = api.jit(count)
self.assertEqual(count(2), 1 + 2**2)
self.assertEqual(count(2), cfun(2))
self.assertEqual(count(3), 3 + 3**2)
self.assertEqual(count(3), cfun(3))
self.assertEqual(count(4), 6 + 4**2)
self.assertEqual(count(4), cfun(4))
def testForiLoopTupleState(self):
def sum_first_n(arr, num):
def body_fun(i, state):
arr, total = state
arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
return (arr, lax.add(total, arr_i))
init_val = (arr, 0.)
_, total = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
init_val)
return total
cfun = api.jit(sum_first_n)
x = npr.RandomState(0).randn(10)
for num in [0, 5, 10, 15]:
self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
def testForiLoopDictState(self):
def sum_first_n(arr, num):
def body_fun(i, state):
arr, total = state['arr'], state['total']
arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
return {'arr': arr, 'total': lax.add(total, arr_i)}
init_val = {'arr': arr, 'total': 0.}
out_val = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
return out_val['total']
cfun = api.jit(sum_first_n)
x = npr.RandomState(0).randn(10)
for num in [0, 5, 10, 15]:
self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
def testForiLoopEmptyTupleInState(self):
def sum_first_n(arr, num):
def body_fun(i, state):
arr, total, _ = state
arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
return (arr, lax.add(total, arr_i), ())
init_val = (arr, 0., ())
_, tot, _ = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
return tot
cfun = api.jit(sum_first_n)
x = npr.RandomState(0).randn(10)
for num in [0, 5, 10, 15]:
self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
def testCond(self):
def fun(x):
if x < 3:
return (x, x)
else:
y = lax.mul(2, x)
return y, lax.mul(2, y)
@api.jit
def cfun(x):
def false_fun(x):
y = lax.mul(2, x)
return y, lax.mul(2, y)
return lax.cond(lax.lt(x, 3), x, lambda x: (x, x), x, false_fun)
self.assertEqual(fun(0), cfun(0))
self.assertEqual(fun(0), (0, 0))
self.assertEqual(fun(1), cfun(1))
self.assertEqual(fun(1), (1, 1))
self.assertEqual(fun(2), cfun(2))
self.assertEqual(fun(2), (2, 2))
self.assertEqual(fun(3), cfun(3))
self.assertEqual(fun(3), (6, 12))
self.assertEqual(fun(4), cfun(4))
self.assertEqual(fun(4), (8, 16))
def testNestedCond(self):
def fun(x):
if x < 2:
return lax.mul(2, x)
else:
if x < 5:
return lax.mul(3, x)
else:
return lax.mul(4, x)
@api.jit
def cfun(x):
return lax.cond(
lax.lt(x, 2),
x, lambda x: lax.mul(2, x),
x, lambda x: lax.cond(lax.lt(x, 5),
x, lambda x: lax.mul(3, x),
4, lambda y: lax.mul(y, x)))
self.assertEqual(cfun(1), 2)
self.assertEqual(cfun(3), 9)
self.assertEqual(cfun(6), 24)
self.assertEqual(cfun(1), fun(1))
self.assertEqual(cfun(3), fun(3))
self.assertEqual(cfun(6), fun(6))
def testCondOneBranchConstant(self):
def fun(x):
if x < 3:
return 5.
else:
return x
@api.jit
def cfun(x):
return lax.cond(lax.lt(x, 3), x, lambda x: 5, x, lambda x: x)
self.assertEqual(fun(0), cfun(0))
self.assertEqual(cfun(0), 5)
self.assertEqual(fun(4), cfun(4))
self.assertEqual(cfun(4), 4)
def testCondOneBranchConstantTuple(self):
def fun(x):
if x < 3:
return (1., 2., 3.)
else:
return (x, 2., 4.)
@api.jit
def cfun(x):
return lax.cond(lax.lt(x, 3),
x, lambda x: (1, 2., 3.),
x, lambda x: (x, 2., 4.))
self.assertEqual(fun(0), cfun(0))
self.assertEqual(cfun(0), (1, 2., 3.))
self.assertEqual(fun(4), cfun(4))
self.assertEqual(cfun(4), (4, 2., 4.))
def testCondBatched(self):
def fun(x, y, z):
pred = lax.lt(x, 3)
true_fun = lambda y: y
false_fun = lambda z: lax.neg(z)
return lax.cond(pred, y, true_fun, z, false_fun)
# these cases stay as cond
x = onp.array(2)
y = onp.array([1, 2])
z = onp.array([3, 4])
ans = api.vmap(fun, (None, 0, 0))(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
expected = onp.array([1, 2])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
x = onp.array(4)
ans = api.vmap(fun, (None, 0, 0))(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
expected = onp.array([-3, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
fun = api.jit(fun)
ans = api.vmap(fun, (None, 0, 0))(x, y, z)
expected = onp.array([-3, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
z = onp.array(5)
ans = api.vmap(fun, (None, 0, None))(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None)))(x, y, z)
expected = onp.array([-5, -5])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
# these cases become select
x = onp.array([2, 4])
ans = api.vmap(fun, (0, 0, None))(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None)))(x, y, z)
expected = onp.array([1, -5])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" in str(jaxpr)
z = onp.array([3, 4])
ans = api.vmap(fun)(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z)
expected = onp.array([1, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" in str(jaxpr)
def testIssue514(self):
# just check this doesn't crash
lax.cond(True,
(0, 0), lambda x: (x[0], 0),
(1, 1), lambda x: x)
def testIssue649(self):
from jax import lax
def body(x):
a, b = x
return (7, b + 1)
def cond(x):
a, b = x
return b < 10
out = lax.while_loop(cond, body, (33, 4))
self.assertEqual(out, (7, 10))
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
"jit_scan": jit_scan, "jit_f": jit_f}
for jit_scan in [False, True]
for jit_f in [False, True])
def testScanImpl(self, jit_scan, jit_f):
rng = onp.random.RandomState(0)
d = rng.randn(2)
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = np.cos(np.sum(np.sin(a)) + np.sum(np.cos(c)) + np.sum(np.tan(d)))
c = np.sin(c * b)
assert b.shape == ()
return c, b
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
as_ = rng.randn(5, 3)
c = rng.randn(4)
ans = scan(f, c, as_)
expected = scan_reference(f, c, as_)
self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
"jit_scan": jit_scan, "jit_f": jit_f}
for jit_scan in [False, True]
for jit_f in [False, True])
def testScanJVP(self, jit_scan, jit_f):
rng = onp.random.RandomState(0)
d = rng.randn(2)
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = np.cos(np.sum(np.sin(a)) + np.sum(np.cos(c)) + np.sum(np.tan(d)))
c = np.sin(c * b)
assert b.shape == ()
return c, b
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
as_ = rng.randn(5, 3)
c = rng.randn(4)
ans = api.jvp(lambda c, as_: scan(f, c, as_), (c, as_), (c, as_))
expected = api.jvp(lambda c, as_: scan_reference(f, c, as_), (c, as_), (c, as_))
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"])
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
"jit_scan": jit_scan, "jit_f": jit_f}
for jit_scan in [False, True]
for jit_f in [False, True])
def testScanLinearize(self, jit_scan, jit_f):
rng = onp.random.RandomState(0)
d = rng.randn(2)
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = np.cos(np.sum(np.sin(a)) + np.sum(np.cos(c)) + np.sum(np.tan(d)))
c = np.sin(c * b)
assert b.shape == ()
return c, b
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
as_ = rng.randn(5, 3)
c = rng.randn(4)
ans = api.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
expected = api.linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
"jit_scan": jit_scan, "jit_f": jit_f}
for jit_scan in [False, True]
for jit_f in [False, True])
def testScanGrad(self, jit_scan, jit_f):
rng = onp.random.RandomState(0)
d = rng.randn(2)
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = np.sum(np.sin(a)) + np.sum(np.sin(c)) + np.sum(np.sin(d))
c = np.sin(c * b)
assert b.shape == ()
return c, b
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
as_ = rng.randn(5, 3)
c = rng.randn(4)
ans = api.grad(lambda c, as_: list( scan(f, c, as_))[0].sum())(c, as_)
expected = api.grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
atol=1e-3, rtol=1e-3)
def testScanRnn(self):
r = npr.RandomState(0)
n_in = 4
n_hid = 2
n_out = 1
length = 3
W_trans = r.randn(n_hid, n_hid + n_in)
W_out = r.randn(n_out, n_hid + n_in)
params = W_trans, W_out
inputs = r.randn(length, n_in)
targets = r.randn(length, n_out)
def step(params, state, input):
W_trans, W_out = params
stacked = np.concatenate([state, input])
output = np.tanh(np.dot(W_out, stacked))
next_state = np.tanh(np.dot(W_trans, stacked))
return next_state, output
def rnn(params, inputs):
init_state = np.zeros(n_hid)
_, outputs = lax.scan(partial(step, params), init_state, inputs)
return outputs
def loss(params, inputs, targets):
predictions = rnn(params, inputs)
return np.sum((predictions - targets)**2)
# evaluation doesn't crash
loss(params, inputs, targets)
# jvp evaluation doesn't crash
api.jvp(lambda params: loss(params, inputs, targets), (params,), (params,))
# jvp numerical check passes
jtu.check_grads(loss, (params, inputs, targets), order=2, modes=["fwd"])
# linearize works
_, expected = api.jvp(loss, (params, inputs, targets),
(params, inputs, targets))
_, linfun = api.linearize(loss, params, inputs, targets)
ans = linfun(params, inputs, targets)
self.assertAllClose(ans, expected, check_dtypes=False)
# gradient evaluation doesn't crash
api.grad(loss)(params, inputs, targets)
# gradient check passes
jtu.check_grads(loss, (params, inputs, targets), order=2)
# we can vmap to batch things
batch_size = 7
batched_inputs = r.randn(batch_size, length, n_in)
batched_targets = r.randn(batch_size, length, n_out)
batched_loss = api.vmap(lambda x, y: loss(params, x, y))
losses = batched_loss(batched_inputs, batched_targets)
expected = onp.stack(list(map(lambda x, y: loss(params, x, y),
batched_inputs, batched_targets)))
self.assertAllClose(losses, expected, check_dtypes=False)
def testIssue711(self):
# Tests reverse-mode differentiation through a scan for which the scanned
# function also involves reverse-mode differentiation.
# See https://github.com/google/jax/issues/711
def harmonic_bond(conf, params):
return np.sum(conf * params)
def minimize_structure(test_params):
energy_fn = partial(harmonic_bond, params=test_params)
grad_fn = api.grad(energy_fn)
def apply_carry(carry, _):
i, x = carry
new_x = x - 0.1 * api.grad(energy_fn)(x)
new_carry = (i+1, new_x)
return new_carry, _
x0 = np.array([1., 2., 3.])
carry_final, _ = lax.scan(apply_carry, (0, x0), np.zeros((75, 0)))
_, x_final = carry_final
return x_final
initial_params = 0.5
minimize_structure(initial_params) # doesn't crash
def loss(test_params):
x_final = minimize_structure(test_params)
return np.sum(np.sin(1.0 - x_final))
api.grad(loss)(0.25) # doesn't crash
def testIssue744(self):
Point = collections.namedtuple('Point', ['x', 'y'])
p0 = Point(x=np.array(1), y=np.array(2))
def plus_one(p, iter_idx):
return Point(p.x+1, p.y+1), iter_idx
self.assertRaisesRegexp(
ValueError,
'scan got value with no leading axis to scan over.*',
lambda: lax.scan(plus_one, p0, list(range(5))))
def testScanHigherOrderDifferentiation(self):
d = 0.75
def f(c, a):
b = np.sin(c * np.sum(np.cos(d * a)))
c = 0.9 * np.cos(d * np.sum(np.sin(c * a)))
return c, b
as_ = np.arange(6.).reshape((3, 2))
c = 1.
jtu.check_grads(lambda c, as_: lax.scan(f, c, as_), (c, as_),
modes=["rev"], order=2)
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}_in_axes={}".format(
jit_scan, jit_f, in_axes),
"jit_scan": jit_scan, "jit_f": jit_f, "in_axes": in_axes}
for jit_scan in [False, True]
for jit_f in [False, True]
for in_axes in itertools.product([None, 0, 1], [None, 0, 1, 2])
if in_axes != (None, None))
def testScanVmap(self, jit_scan, jit_f, in_axes):
rng = onp.random.RandomState(0)
d = rng.randn(2)
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = np.cos(np.sum(np.sin(a)) + np.sum(np.cos(c)) + np.sum(np.tan(d)))
c = np.sin(c * b)
assert b.shape == ()
return c, b
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
as_shape = [5, 3]
c_shape = [4]
c_bdim, as_bdim = in_axes
if c_bdim is not None:
c_shape.insert(c_bdim, 7)
if as_bdim is not None:
as_shape.insert(as_bdim, 7)
as_ = rng.randn(*as_shape)
c = rng.randn(*c_shape)
ans = api.vmap(lambda c, as_: scan(f, c, as_), in_axes)(c, as_)
expected = api.vmap(lambda c, as_: scan_reference(f, c, as_), in_axes)(c, as_)
self.assertAllClose(ans, expected, check_dtypes=False)
def testScanVmapTuples(self):
def f(c, a):
a1, a2 = a
c1, c2 = c
b = np.sum(np.cos(a1)) * np.sum(np.tan(c2 * a2))
c = c1 * np.sin(np.sum(a1 * a2)), c2 * np.cos(np.sum(a1))
return c, b
in_axes = (0, (1, 2))
r = onp.random.RandomState(0)
as_ = (r.randn(3, 7), r.randn(3, 4, 7))
c = (r.randn(7, 2), r.randn(7))
expected_c_out, expected_bs = [], []
for i in range(7):
c_out, bs = lax.scan(f, (c[0][i], c[1][i]), (as_[0][:,i], as_[1][:,:,i]))
expected_c_out.append(c_out)
expected_bs.append(bs)
expected_c_out_0, expected_c_out_1 = unzip2(expected_c_out)
expected_c_out = (np.stack(expected_c_out_0), np.stack(expected_c_out_1))
expected_bs = np.stack(expected_bs)
expected = expected_c_out, expected_bs
ans = api.vmap(lambda c, as_: lax.scan(f, c, as_), in_axes)(c, as_)
self.assertAllClose(ans, expected, check_dtypes=False)
# TODO(mattjj, dougalm): fix this test when skip_checks is False
def testIssue757(self):
# code from https://github.com/google/jax/issues/757
def fn(a):
return np.cos(a)
def loop(val):
iterations = 10
def apply_carry(x, i):
return api.grad(fn, argnums=(0,))(x)[0], i
final_val, _ = lax.scan(
apply_carry,
val,
np.arange(iterations)
)
return final_val
arg = 0.5
api.jit(api.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash
# TODO(mattjj): add a test for "the David Sussillo bug"
def testIssue804(self):
num_devices = xla_bridge.device_count()
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
api.pmap(f, axis_name="i")(np.ones((num_devices, 4))) # doesn't crash
def testMap(self):
f = lambda x: x ** 2
xs = np.arange(10)
expected = xs ** 2
actual = lax.map(f, xs)
self.assertAllClose(actual, expected, check_dtypes=True)
def testCaching(self):
def cond(x):
assert python_should_be_executing
return x < 5
def body(x):
assert python_should_be_executing
return x + 2
python_should_be_executing = True
lax.while_loop(cond, body, 0)
python_should_be_executing = False
lax.while_loop(cond, body, 0)
def testCaching2(self):
# This second caching test shows a different kind of caching that we haven't
# implemented (but could!), namely that Python functions that are distinct
# objects but are equivalent functions trigger cache hits. This kind of
# caching could be salient when using lambda functions with control flow:
#
# lax.while_loop(lambda x: x < 5, lambda x: x + 2, 0)
# lax.while_loop(lambda x: x < 5, lambda x: x + 2, 0)
#
# To get a cache hit on the second line we'd need to form a jaxpr and
# compare them for equality (including the literals on identity). We could
# implement that by adding a __hash__/__eq__ to core.Jaxpr and
# core.TypedJaxpr (see #1221).
raise SkipTest("not implemented")
def cond(x):
assert python_should_be_executing
return x < 5
def body(x):
assert python_should_be_executing
return x + 2
python_should_be_executing = True
lax.while_loop(cond, body, 0)
def cond(x):
assert python_should_be_executing
return x < 5
def body(x):
assert python_should_be_executing
return x + 2
python_should_be_executing = False
lax.while_loop(cond, body, 0)
def testWhileCondConstant(self):
out = lax.while_loop(lambda _: False, lambda _: (), ()) # doesn't crash
self.assertEqual(out, ())
if __name__ == '__main__':
absltest.main()