# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

import jax
from jax import random
from jax._src import test_util as jtu
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp

jax.config.parse_flags_with_absl()
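
# Wrappers around for_loop.for_loop used to run the same test bodies under
# different tracing regimes (jit, remat, nesting, unrolling); each wrapper is
# expected to be semantically identical to for_loop itself.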
def remat_of_for_loop(nsteps, body, state, **kwargs):
  return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state,
                                                   **kwargs))(state)

def nested_for_loop(nsteps, body, state, **kwargs):
  def outer_body(_, refs):
    def inner_body(i, _):
      body(i, refs)
      return
    for_loop.for_loop(nsteps, inner_body, ())
  return for_loop.for_loop(1, outer_body, state)
FOR_LOOP_IMPLS = [
(for_loop.for_loop, 'for_loop'),
(jax.jit(for_loop.for_loop, static_argnums=(0, 1)), 'jit_for_loop'),
(remat_of_for_loop, 'remat_for_loop'),
(nested_for_loop, 'nested_for_loop'),
(partial(for_loop.for_loop, unroll=3), 'unrolled_for_loop'),
]
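
# Illustrative sketch (not used by the tests): every implementation above is
# called as impl(nsteps, body, init_state), where `body(i, refs)` reads and
# writes mutable Refs and the final Ref contents are returned. The helper
# below is hypothetical, for exposition only.
def _example_body(i, x_ref):
  x_ref[()] += jnp.float32(1.)  # increment the scalar Ref once per step
# e.g. for_loop.for_loop(3, _example_body, jnp.float32(0.)) returns
# Array(3., dtype=float32).
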
def _for_loop_impls(f):
return parameterized.named_parameters(
dict(testcase_name=impl_name, for_impl=for_impl)
for for_impl, impl_name in FOR_LOOP_IMPLS
)(f)
class ForLoopTest(jtu.JaxTestCase):
@_for_loop_impls
def test_for_loop_impl_trivial(self, for_impl):
out = for_impl(5, lambda i, _: None, None)
self.assertIsNone(out)
@_for_loop_impls
def test_for_loop_can_write_to_ref(self, for_impl):
def body(_, x_ref):
x_ref[()] = jnp.float32(1.)
out = for_impl(1, body, jnp.float32(0.))
self.assertEqual(out, 1.)
def body2(i, x_ref):
x_ref[()] = jnp.float32(i)
out = for_impl(2, body2, jnp.float32(0.))
self.assertEqual(out, 1.)
def body3(i, x_ref):
x_ref[()] = jnp.float32(i) * 2.
out = for_impl(2, body3, jnp.float32(0.))
self.assertEqual(out, 2.)
@_for_loop_impls
def test_for_loop_can_write_to_multiple_refs(self, for_impl):
def body(_, refs):
x_ref, y_ref = refs
x_ref[()] = jnp.float32(1.)
y_ref[()] = jnp.float32(2.)
x, y = for_impl(1, body, (jnp.float32(0.), jnp.float32(0.)))
self.assertEqual(x, 1.)
self.assertEqual(y, 2.)
@_for_loop_impls
def test_for_loop_can_read_from_ref(self, for_impl):
def body(_, x_ref):
x_ref[()] # pylint: disable=pointless-statement
x = for_impl(1, body, jnp.float32(0.))
self.assertEqual(x, 0.)
@_for_loop_impls
def test_for_loop_can_read_from_and_write_to_ref(self, for_impl):
def body(_, x_ref):
x = x_ref[()]
x_ref[()] = x + jnp.float32(1.)
x = for_impl(5, body, jnp.float32(0.))
self.assertEqual(x, 5.)
@_for_loop_impls
def test_for_loop_can_read_from_and_write_to_refs(self, for_impl):
def body2(_, refs):
x_ref, y_ref = refs
x = x_ref[()]
y_ref[()] = x + 1.
x_ref[()] = x + 1.
x, y = for_impl(5, body2, (0., 0.))
self.assertEqual(x, 5.)
self.assertEqual(y, 5.)
@_for_loop_impls
def test_for_loop_can_read_from_and_write_to_ref_slice(self, for_impl):
def body(i, x_ref):
x = x_ref[i]
x_ref[i] = x + jnp.float32(1.)
x = for_impl(4, body, jnp.ones(4, jnp.float32))
np.testing.assert_allclose(x, 2 * jnp.ones(4, jnp.float32))
def body2(i, x_ref):
x = x_ref[i, 0]
x_ref[i, 1] = x + x_ref[i, 1]
x = for_impl(4, body2, jnp.arange(8.).reshape((4, 2)))
np.testing.assert_allclose(
x, jnp.array([[0., 1.], [2., 5.], [4., 9.], [6., 13.]]))
@_for_loop_impls
@jax.legacy_prng_key('allow')
def test_for_loop_can_implement_cumsum(self, for_impl):
def cumsum(x):
def body(i, refs):
x_ref, accum_ref = refs
accum_ref[i + 1] = accum_ref[i] + x_ref[i]
accum = jnp.zeros(x.shape[0] + 1, x.dtype)
_, accum_out = for_impl(x.shape[0], body, (x, accum))
return accum_out[1:]
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8,))
np.testing.assert_allclose(cumsum(x), jnp.cumsum(x), rtol=1e-6)
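

# Body/reference pairs used by the transformation tests below: each
# for_body_* mutates its Refs in place at index i, and each *_ref computes
# the same final values as a pure function, so their JVPs, linearizations,
# and gradients can be compared.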
def for_body_swap(i, refs):
a_ref, b_ref = refs
a, b = a_ref[i], b_ref[i]
b_ref[i] = a
a_ref[i] = b
def swap_ref(a, b):
return b, a
def for_body_swap_swap(i, refs):
for_body_swap(i, refs)
for_body_swap(i, refs)
swap_swap_ref = lambda a, b: (a, b)
def for_body_sincos(i, refs):
a_ref, b_ref = refs
a = a_ref[i]
b_ref[i] = jnp.sin(jnp.cos(a))
sincos_ref = lambda x, y: (x, jnp.sin(jnp.cos(x)))
def for_body_sincostan(i, refs):
a_ref, b_ref = refs
a = a_ref[i]
b_ref[i] = jnp.tan(jnp.sin(jnp.cos(a)))
sincostan_ref = lambda x, y: (x, jnp.tan(jnp.sin(jnp.cos(x))))
def for_body_accum(i, refs):
x_ref, accum_ref = refs
accum_ref[i + 1] = accum_ref[i] + x_ref[i]
def accum_ref(x, accum):
for i in range(x.shape[0] - 1):
accum = accum.at[i + 1].set(accum[i] + x[i])
return x, accum
def for_body_sin_sq(i, refs):
x_ref, y_ref = refs
x = x_ref[i]
y = x
y_ref[i] = y
y = y_ref[i]
y_ref[i] = jnp.sin(y * y)
sin_sq_ref = lambda x, y: (x, jnp.sin(x * x))
def for_body_reverse(i, refs):
x_ref, y_ref = refs
j = y_ref.shape[0] - i - 1
y_ref[i] = x_ref[j]
reverse_ref = lambda x, y: (x, x[::-1])
def for_body_noop(i, refs):
pass
noop_ref = lambda x, y: (x, y)
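
# `for_loop.discharged_for_loop` is used below as a reference implementation:
# it discharges the Ref reads and writes in a body into a pure function of the
# carried values, so `for_impl(n, f, args)` and `for_reference(n, f, args)`
# should agree, including under jvp/linearize/grad. This gloss on "discharge"
# summarizes its role in these tests, not its internals.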
for_reference = for_loop.discharged_for_loop
class ForLoopTransformationTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(for_body_name=for_body_name, f=for_body, ref=ref,
body_shapes=body_shapes, n=nsteps)
for for_body_name, for_body, ref, body_shapes, nsteps in [
("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
]
],
[dict(for_impl=for_impl, impl_name=impl_name)
for for_impl, impl_name in FOR_LOOP_IMPLS],
)
@jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts?
def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name,
impl_name):
for_ = for_impl
rng = self.rng()
args = [rng.randn(*s) for s in body_shapes]
tol = {np.float64: 1e-12, np.float32: 1e-4}
    ans = jax.jvp(lambda *args: for_(n, f, args), args, args)
ans_discharged = jax.jvp(lambda *args: for_reference(n, f, args), args, args)
expected = jax.jvp(ref, args, args)
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(partial(for_, n, f), (args,), order=2, modes=["fwd"])
@jtu.sample_product(
[dict(for_body_name=for_body_name, f=for_body, ref=ref,
body_shapes=body_shapes, n=nsteps)
for for_body_name, for_body, ref, body_shapes, nsteps in [
("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
]
],
[dict(for_impl=for_impl, impl_name=impl_name)
for for_impl, impl_name in FOR_LOOP_IMPLS],
)
@jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts?
def test_for_linearize(self, f, ref, body_shapes, n, for_impl, for_body_name,
impl_name):
for_ = for_impl
rng = self.rng()
args = [rng.randn(*s) for s in body_shapes]
tol = {np.float64: 1e-12, np.float32: 1e-4}
    ans = jax.linearize(lambda *args: for_(n, f, args), *args)[1](*args)
ans_discharged = jax.linearize(lambda *args: for_reference(n, f, args),
*args)[1](*args)
expected = jax.linearize(ref, *args)[1](*args)
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
def test_for_loop_invar(self):
def f(x):
s = jnp.ones((2, 32), x.dtype)
def body(i, refs):
x_ref, y_ref = refs
y_ref[i] = s * x_ref[i] * jnp.cos(s)
# We should save `s` and `jnp.cos(s)` as residuals and not broadcast
# them.
return for_loop.for_loop(x.shape[0], body, (x, jnp.zeros_like(x)))
_, f_vjp = jax.linearize(f, jnp.ones((5, 2, 32)))
jaxpr = jax.make_jaxpr(f_vjp)(jnp.ones((5, 2, 32)))
consts = [v.aval for v in jaxpr.jaxpr.constvars
if v.aval.shape == (2, 32)]
self.assertLen(consts, 2)
def loss(A):
def step(x, _):
return jnp.matmul(A, x), None
init_x = jnp.zeros(A.shape[-1:])
last_x, _ = for_loop.scan(step, init_x, jnp.arange(10))
return jnp.sum(last_x)
A = jnp.zeros((3, 3))
    # The second dynamic-update-slice (DUS) was unnecessarily replicating `A`
    # across time. We check the XLA text because _scan_impl is "underneath"
    # the jaxpr language.
s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo')
assert s.count("dynamic-update-slice(") < 2
@_for_loop_impls
def test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals(
self, for_impl):
def body(i, refs):
a_ref, b_ref, c_ref = refs
a = a_ref[i]
b = b_ref[()]
x = jnp.sin(a)
b_ref[()] = jnp.sin(b * x)
c_ref[i] = x * b
def f(a, b):
c = jnp.zeros_like(a)
_, b, c = for_impl(5, body, (a, b, c))
return b, c
a = jnp.arange(5.) + 1.
b = jnp.ones_like(a[0])
_, f_lin = jax.linearize(f, a, b)
expected_tangents = f_lin(a, b)
_, actual_tangents = jax.jvp(f, (a, b), (a, b))
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0],
rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1],
rtol=1e-6, atol=1e-6)
    def body2(_, refs):
      # Here we use `i_ref` as a loop counter, so the index is loop-varying
      # state that the fixpoint computation must identify, rather than the
      # loop's own induction variable.
a_ref, b_ref, c_ref, i_ref = refs
i = i_ref[()]
a = a_ref[i]
b = b_ref[()]
x = jnp.sin(a)
b_ref[()] = jnp.sin(b * x)
c_ref[i] = x * b
i_ref[()] = i + 1
def g(a, b):
c = jnp.zeros_like(a)
_, b, c, _ = for_impl(5, body2, (a, b, c, 0))
return b, c
a = jnp.arange(5.) + 1.
b = jnp.ones_like(a[0])
    _, g_lin = jax.linearize(g, a, b)
expected_tangents = g_lin(a, b)
_, actual_tangents = jax.jvp(g, (a, b), (a, b))
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0])
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1],
rtol=1e-6)
@jtu.sample_product(
[dict(for_body_name=for_body_name, f=for_body, ref=ref,
body_shapes=body_shapes, n=nsteps)
for for_body_name, for_body, ref, body_shapes, nsteps in [
("noop", for_body_noop, noop_ref, [(4,), (4,)], 4),
("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
]
],
[dict(for_impl=for_impl, impl_name=impl_name)
for for_impl, impl_name in FOR_LOOP_IMPLS],
)
@jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts?
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name,
impl_name):
for_ = for_impl
rng = self.rng()
args = [rng.randn(*s) for s in body_shapes]
tol = {np.float64: 1e-12, np.float32: 1e-4}
    ans = jax.grad(lambda args: for_(n, f, args)[1].sum())(args)
ans_discharged = jax.grad(
lambda args: for_reference(n, f, args)[1].sum())(args)
expected = jax.grad(lambda args: ref(*args)[1].sum())(args)
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol,
atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=2,
rtol=7e-3, atol=1e-2)
@jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts?
@jax.legacy_prng_key('allow')
def test_grad_of_triple_nested_for_loop(self):
func = lambda x: jnp.sin(x) + 1.
@jax.jit
def f(x):
out = jnp.zeros_like(x)
def body(i, j, k, refs):
x_ref, out_ref = refs
y = func(x_ref[i, j, k])
out_ref[i, j, k] += y
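      # Passing a tuple of trip counts (here x.shape) requests a nested loop;
      # the body takes one index argument per loop dimension.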
return for_loop.for_loop(x.shape, body, (x, out))[1].sum()
x = random.normal(random.PRNGKey(0), (5, 4, 3))
ref = lambda x: jax.vmap(jax.vmap(jax.vmap(func)))(x).sum()
self.assertAllClose(f(x), ref(x))
jtu.check_grads(f, (x,), order=2, atol=0.1, rtol=0.1)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())