mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations. PiperOrigin-RevId: 735875835
3072 lines
101 KiB
Python
3072 lines
101 KiB
Python
# Copyright 2018 The JAX Authors.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# https://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
|
||
import collections
|
||
import contextlib
|
||
from functools import partial
|
||
import itertools
|
||
import operator
|
||
import re
|
||
import unittest
|
||
|
||
from absl.testing import absltest
|
||
from absl.testing import parameterized
|
||
|
||
import numpy as np
|
||
|
||
import jax
|
||
from jax._src import core
|
||
from jax import dtypes
|
||
from jax import lax
|
||
from jax import random
|
||
from jax._src import test_util as jtu
|
||
from jax import tree_util
|
||
from jax._src.util import unzip2
|
||
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
|
||
import jax.numpy as jnp # scan tests use numpy
|
||
import jax.scipy as jsp
|
||
from jax._src.lax import control_flow as lax_control_flow
|
||
from jax._src.lax.control_flow import for_loop
|
||
from jax._src.interpreters import mlir
|
||
|
||
# Hook JAX's configuration flags into absl flag parsing (as the function name
# indicates) so they can presumably be set from the test command line.
jax.config.parse_flags_with_absl()
|
||
|
||
|
||
# Some tests are useful for testing both lax.cond and lax.switch. This function
|
||
# provides a lax.cond-compatible interface to a two-branch lax.switch. Several
|
||
# tests in this file are parameterized such that they either call into lax.cond
|
||
# or into this function.
|
||
def cond_via_switch(pred, true_fun, false_fun, op, *args):
  """Emulates `lax.cond` with a two-branch `lax.switch`.

  Accepts the same two calling conventions as `lax.cond`:

  * ``cond_via_switch(pred, true_fun, false_fun, *operands)``: both branches
    receive every operand.
  * ``cond_via_switch(pred, true_operand, true_fun, false_operand, false_fun)``:
    the deprecated per-branch form, where each branch receives only its own
    operand.

  As in `lax.cond`, the per-branch form is chosen only when `true_fun` and
  `false_fun` are not both callable. Multi-operand calls such as
  ``cond_via_switch(p, lax.add, lax.mul, x, y)`` are therefore no longer
  misread as per-branch calls.

  Args:
    pred: scalar predicate, converted to an int32 switch index.
    true_fun: true branch function (per-branch form: the true operand).
    false_fun: false branch function (per-branch form: the true branch).
    op: first operand (per-branch form: the false operand).
    *args: remaining operands (per-branch form: exactly one element, the
      false branch function).

  Returns:
    The output of the selected branch.

  Raises:
    TypeError: if the arguments match neither calling convention.
  """
  operands = (op, *args)
  if not (callable(true_fun) and callable(false_fun)):
    # Deprecated per-branch form: (true_op, true_fun, false_op, false_fun).
    if not (callable(false_fun) and len(args) == 1 and callable(args[0])):
      raise TypeError(
          "cond_via_switch: true_fun and false_fun arguments should be "
          "callable.")
    true_op, true_branch, false_op, false_branch = (
        true_fun, false_fun, op, args[0])
    # Pack both operands into one pytree; each branch unpacks only its own.
    operands = ((false_op, true_op),)
    false_fun = lambda ops: false_branch(ops[0])
    true_fun = lambda ops: true_branch(ops[1])
  # Index 0 selects false_fun and index 1 selects true_fun.
  index = lax.convert_element_type(pred, np.int32)
  return lax.switch(index, [false_fun, true_fun], *operands)
|
||
|
||
def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args):
  """Emulates `lax.cond` with a two-branch `lax.switch` under `checkpoint`.

  Accepts the same two calling conventions as `lax.cond`:

  * ``cond_with_new_checkpoint(pred, true_fun, false_fun, *operands)``.
  * ``cond_with_new_checkpoint(pred, true_operand, true_fun, false_operand,
    false_fun)``: the deprecated per-branch form.

  As in `lax.cond`, the per-branch form is chosen only when `true_fun` and
  `false_fun` are not both callable. Previously any extra argument forced the
  per-branch form, and operands beyond the first extra one were silently
  dropped.

  Args:
    pred: scalar predicate, converted to an int32 switch index.
    true_fun: true branch function (per-branch form: the true operand).
    false_fun: false branch function (per-branch form: the true branch).
    op: first operand (per-branch form: the false operand).
    *args: remaining operands (per-branch form: exactly one element, the
      false branch function).

  Returns:
    The output of the selected branch.

  Raises:
    TypeError: if the arguments match neither calling convention.
  """
  operands = (op, *args)
  if not (callable(true_fun) and callable(false_fun)):
    # Deprecated per-branch form: (true_op, true_fun, false_op, false_fun).
    if not (callable(false_fun) and len(args) == 1 and callable(args[0])):
      raise TypeError(
          "cond_with_new_checkpoint: true_fun and false_fun arguments should "
          "be callable.")
    true_op, true_branch, false_op, false_branch = (
        true_fun, false_fun, op, args[0])
    # Pack both operands into one pytree; each branch unpacks only its own.
    operands = ((false_op, true_op),)
    false_fun = lambda ops: false_branch(ops[0])
    true_fun = lambda ops: true_branch(ops[1])
  # Index 0 selects false_fun and index 1 selects true_fun.
  index = lax.convert_element_type(pred, np.int32)
  # Pass the index and the operands as checkpointed inputs so the whole switch
  # is traced under new_checkpoint.
  fn = lambda index, *ops: lax.switch(index, [false_fun, true_fun], *ops)
  return new_checkpoint(fn)(index, *operands)
|
||
|
||
# (implementation, test-name suffix) pairs that parameterize the cond tests:
# plain lax.cond, a two-branch lax.switch, and a switch under checkpoint.
COND_IMPLS = list((
    (lax.cond, 'cond'),
    (cond_via_switch, 'switch'),
    (cond_with_new_checkpoint, 'new_checkpoint'),
))
|
||
|
||
|
||
# We wanted to try all scan tests with the scan partial evaluation rule that
|
||
# happens under ad_checkpoint.checkpoint, so we make a scan wrapper which
|
||
# wraps a ad_checkpoint.checkpoint around the computation.
|
||
def scan_with_new_checkpoint(f, *args, **kwargs):
  """Runs `lax.scan` under `checkpoint` with the `nothing_saveable` policy.

  `kwargs` are forwarded to `lax.scan`; `args` are the scan's runtime inputs
  (typically `init, xs`).
  """
  scan_fn = partial(lax.scan, f, **kwargs)
  rematted = new_checkpoint(scan_fn,
                            policy=checkpoint_policies.nothing_saveable)
  return rematted(*args)
|
||
def scan_with_new_checkpoint2(f, *args, **kwargs):
  """Runs `lax.scan` under `checkpoint` with the `everything_saveable` policy.

  `kwargs` are forwarded to `lax.scan`; `args` are the scan's runtime inputs
  (typically `init, xs`).
  """
  scan_fn = partial(lax.scan, f, **kwargs)
  rematted = new_checkpoint(scan_fn,
                            policy=checkpoint_policies.everything_saveable)
  return rematted(*args)
|
||
|
||
def scan_with_for(f, *args, **kwargs):
  """Delegates to the `for_loop`-based implementation of scan."""
  scan_impl = for_loop.scan
  return scan_impl(f, *args, **kwargs)
|
||
|
||
def scan_with_remat_for(f, *args, **kwargs):
  """Runs the `for_loop`-based scan under `jax.remat`."""
  def run_scan(*scan_args):
    return for_loop.scan(f, *scan_args, **kwargs)
  return jax.remat(run_scan)(*args)
|
||
|
||
# (implementation, test-name suffix) pairs that parameterize the scan tests:
# lax.scan with several unroll/transpose settings, two checkpointed variants,
# and the for_loop-based scan with and without remat.
SCAN_IMPLS_WITH_FOR = list((
    (lax.scan, 'unroll1'),
    (partial(lax.scan, unroll=2), 'unroll2'),
    (partial(lax.scan, _split_transpose=True), 'split_transpose'),
    (scan_with_new_checkpoint, 'new_checkpoint'),
    (scan_with_new_checkpoint2, 'new_checkpoint2'),
    (scan_with_for, 'for_loop'),
    (scan_with_remat_for, 'for_loop_remat'),
))
|
||
|
||
def while_loop_new_checkpoint(cond_fun, body_fun, init_val):
  """Runs `lax.while_loop` under `ad_checkpoint.checkpoint` (default policy)."""
  def loop(carry):
    return lax.while_loop(cond_fun, body_fun, carry)
  return new_checkpoint(loop)(init_val)
|
||
|
||
# (implementation, test-name suffix) pairs that parameterize the while tests.
WHILE_LOOP_IMPLS = list((
    (lax.while_loop, 'while_loop'),
    (while_loop_new_checkpoint, 'new_checkpoint'),
))
|
||
|
||
|
||
def while_loop_reference(cond, body, carry):
  """Pure-Python reference semantics for `lax.while_loop`.

  Applies `body` to the carry for as long as `cond` holds, then returns it.
  """
  state = carry
  while True:
    if not cond(state):
      return state
    state = body(state)
|
||
|
||
|
||
def scan_reference(f, init, xs):
  """Pure-Python reference implementation of `lax.scan`.

  Iterates over `xs`, threads the carry through `f`, and stacks the per-step
  outputs along a new leading axis. Per-step outputs may be pytrees, as with
  `lax.scan`; each leaf is stacked independently. Single-array outputs behave
  exactly as before.

  Args:
    f: step function ``(carry, x) -> (carry, y)``.
    init: initial carry.
    xs: iterable of per-step inputs (e.g. an array iterated along axis 0).

  Returns:
    ``(final_carry, ys)``, where `ys` has the pytree structure of `y` and each
    leaf has a new leading axis of length ``len(xs)``.

  Raises:
    ValueError: if `xs` is empty, because the stacked output shape cannot be
      inferred without running at least one step.
  """
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    # Give every output leaf a leading length-1 axis so steps can be stacked.
    ys.append(jax.tree.map(
        lambda leaf: lax.reshape(leaf, (1,) + np.shape(leaf)), y))
  if not ys:
    raise ValueError("scan_reference requires a non-empty xs; the shape of "
                     "the stacked outputs cannot be inferred without a step.")
  # Concatenate matching leaves across all steps.
  return carry, jax.tree.map(lambda *leaves: lax.concatenate(leaves, 0), *ys)
|
||
|
||
|
||
# `jtu.ignore_warning` pre-bound to the jit-of-pmap warning message pattern.
ignore_jit_of_pmap_warning = partial(
    jtu.ignore_warning,
    message=".*jit-of-pmap.*",
)
|
||
|
||
# A JAX primitive whose lowering is a custom call to a non-existent function.
# Its abstract evaluation returns the input aval unchanged.
prim_non_existent_custom_call = core.Primitive("__testing_non_existent_custom_call")
prim_non_existent_custom_call.def_abstract_eval(lambda x_aval: x_aval)


def _non_existent_custom_call_lowering(ctx, x):
  """Lowers the primitive to a CustomCallOp whose target symbol does not exist."""
  del ctx  # unused
  target = mlir.ir.StringAttr.get("__testing_non_existent_custom_call")
  call = mlir.hlo.CustomCallOp([x.type], [x], call_target_name=target)
  return call.results


mlir.register_lowering(prim_non_existent_custom_call,
                       _non_existent_custom_call_lowering)
|
||
|
||
|
||
class LaxControlFlowTest(jtu.JaxTestCase):
|
||
|
||
def setUp(self):
  super().setUp()
  # Clear the cached jaxpr-construction helpers so every test starts from a
  # cold cache.
  for cached_fn in (lax_control_flow._initial_style_open_jaxpr,
                    lax_control_flow._initial_style_jaxpr,
                    lax_control_flow.common._pad_jaxpr_constvars):
    cached_fn.cache_clear()
|
||
|
||
def testCallableErrors(self):
  bad_fn = 42
  # Every control-flow entry point must reject a non-callable function
  # argument with a TypeError that names the offending API.
  cases = [
      ("lax.fori_loop.*callable.*", lambda: lax.fori_loop(0, 1, bad_fn, 0)),
      ("lax.while_loop.*callable.*",
       lambda: lax.while_loop(bad_fn, bad_fn, 0)),
      ("lax.switch:.*callable.*", lambda: lax.switch(0, [bad_fn])),
      ("lax.cond.*callable.*", lambda: lax.cond(0, bad_fn, bad_fn)),
      ("lax.scan.*callable.*", lambda: lax.scan(bad_fn, 0, 1)),
      ("lax.associative_scan.*callable.*",
       lambda: lax.associative_scan(bad_fn, 0)),
  ]
  for pattern, thunk in cases:
    with self.assertRaisesRegex(TypeError, pattern):
      thunk()
|
||
|
||
def testWhileWithTuple(self):
  limit = 10

  def loop_cond(state):
    position = state[0]
    return lax.lt(position, limit)

  def loop_body(state):
    position, counter = state
    return lax.add(position, 1), lax.add(counter, 1)

  def loop(init):
    return lax.while_loop(loop_cond, loop_body, (init, 0))[1]

  jitted = jax.jit(loop)

  # The counter equals the number of steps from `init` up to `limit`; the
  # repeated jitted call exercises the compilation cache.
  self.assertEqual(loop(2), limit - 2)
  for start in (2, 2, 3):
    self.assertEqual(jitted(start), limit - start)
|
||
|
||
def testWhileWithManyArgs(self):
  """A while_loop whose carry has many components updates all of them."""
  nargs = 256

  def loop_cond(state):
    return lax.lt(state[0], 2)

  def loop_body(state):
    return tuple(lax.add(s, 1) for s in state)

  out = lax.while_loop(loop_cond, loop_body, (0,) * nargs)
  # Every component starts at 0 and is incremented once per iteration, and
  # the loop runs until state[0] reaches 2, so all components must end at 2.
  # Checking the values, not just that the call succeeds, catches mis-wired
  # carries.
  self.assertEqual([int(v) for v in out], [2] * nargs)
|
||
|
||
def testNestedWhile(self):

  def inner_loop(i, count):  # pylint: disable=missing-docstring
    # Adds i + 1 to `count` by stepping j over 0..i inclusive.
    def cond_fun(carry):
      upper, j, _ = carry
      return lax.le(j, upper)

    def body_fun(carry):
      upper, j, acc = carry
      return upper, lax.add(j, 1), lax.add(acc, 1)

    return lax.while_loop(cond_fun, body_fun, (i, 0, count))[2]

  def outer_loop(num):  # pylint: disable=missing-docstring
    def cond_fun(carry):
      bound, i, _ = carry
      return lax.lt(i, bound)

    def body_fun(carry):
      bound, i, acc = carry
      return bound, lax.add(i, 1), inner_loop(i, acc)

    _, final_i, total = lax.while_loop(cond_fun, body_fun, (num, 0, 0))
    return final_i, total

  jitted = jax.jit(outer_loop)

  self.assertEqual(outer_loop(3), (3, 6))
  # The total is 1 + 2 + ... + num; repeated inputs hit the jit cache.
  for num, expected in ((3, (3, 6)), (3, (3, 6)), (2, (2, 3)), (4, (4, 10))):
    self.assertEqual(jitted(num), expected)
|
||
|
||
def testWhileWithClosure(self):
  """while_loop bodies may close over non-carry values; jit caches tracing."""

  def loop(init, local_limit, inc):

    def loop_cond(state):
      pos, _ = state
      return lax.lt(pos, local_limit)

    def loop_body(state):
      # Python-side flag recording that the body was traced.
      effect[0] = True
      pos, count = state
      return (lax.add(pos, 1), lax.add(count, inc))

    result = lax.while_loop(loop_cond, loop_body, (init, 0))
    _, count = result
    return count

  cloop = jax.jit(loop)

  limit = 10
  effect = [False]
  # unittest assertions instead of bare `assert`, which `python -O` strips
  # and so would silently disable these checks.
  self.assertEqual(loop(2, limit, 1), limit - 2)
  self.assertTrue(effect[0])
  effect[0] = False
  self.assertEqual(cloop(2, limit, 1), limit - 2)
  self.assertTrue(effect[0])
  effect[0] = False
  # Same argument types as before: the body must not be traced again.
  self.assertEqual(cloop(2, limit, 1), limit - 2)
  self.assertEqual(cloop(3, limit, 1), limit - 3)
  self.assertFalse(effect[0])
|
||
|
||
def testWhileWithClosureJit(self):
  """Like testWhileWithClosure, but the loop body calls an inner jit."""

  def loop(init, local_limit, inc):

    def loop_cond(state):
      pos, _ = state
      return lax.lt(pos, local_limit)

    def loop_body(state):
      # Python-side flag recording that the body was traced.
      effect[0] = True
      pos, count = state
      f = lambda pos, inc: (lax.add(pos, 1), lax.add(count, inc))
      return jax.jit(f)(pos, inc)

    result = lax.while_loop(loop_cond, loop_body, (init, 0))
    _, count = result
    return count

  cloop = jax.jit(loop)

  limit = 10
  effect = [False]
  # unittest assertions instead of bare `assert`, which `python -O` strips
  # and so would silently disable these checks.
  self.assertEqual(loop(2, limit, 1), limit - 2)
  self.assertTrue(effect[0])
  effect[0] = False
  self.assertEqual(cloop(2, limit, 1), limit - 2)
  self.assertTrue(effect[0])
  effect[0] = False
  # Same argument types as before: the body must not be traced again.
  self.assertEqual(cloop(2, limit, 1), limit - 2)
  self.assertEqual(cloop(3, limit, 1), limit - 3)
  self.assertFalse(effect[0])
|
||
|
||
def testWhileTypeErrors(self):
  """Test typing error messages for while."""
  tuple_treedef = jax.tree.structure((1., 1.))
  leaf_treedef = jax.tree.structure(0.)

  # cond_fun returning a pytree instead of a scalar.
  pytree_cond_msg = re.escape(
      f"cond_fun must return a boolean scalar, but got pytree {tuple_treedef}.")
  with self.assertRaisesRegex(TypeError, pytree_cond_msg):
    lax.while_loop(lambda c: (1., 1.), lambda c: c, 0.)

  # cond_fun returning a non-boolean scalar.
  float_cond_msg = re.escape(
      "cond_fun must return a boolean scalar, but got output type(s) "
      "[ShapedArray(float32[])].")
  with self.assertRaisesRegex(TypeError, float_cond_msg):
    lax.while_loop(lambda c: np.float32(1.), lambda c: c, np.float32(0.))

  # Body output has a different pytree structure than its input.
  structure_msg = re.escape(
      "while_loop body function carry input and carry output must "
      "have the same pytree structure, but they differ:\n\n"
      "The input carry c is a")
  with self.assertRaisesRegex(TypeError, structure_msg):
    lax.while_loop(lambda c: True, lambda c: (1., 1.), 0.)

  # Body output has the right structure but a mismatched dtype.
  dtype_msg = (
      r"The input carry component c\[1\] has type float32\[\] but the "
      r"corresponding output carry component has type bool\[\], so the "
      "dtypes do not match.")
  with self.assertRaisesRegex(TypeError, dtype_msg):
    lax.while_loop(lambda c: True, lambda c: (True, True),
                   (np.bool_(True), np.float32(0.)))
|
||
|
||
def testWhileLoopCustomPytreeDiffAuxData(self):
  """Carries whose flatten and flatten_with_keys aux data differ are accepted."""
  class Node:
    def __init__(self, x, y):
      self.x = x
      self.y = y
  tree_util.register_pytree_with_keys(
      Node,
      lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'),  # flatten_with_keys
      lambda _, xy: Node(xy[0], xy[1]),  # unflatten (no key involved)
      lambda o: ((o.x, o.y), 'without_keys'),  # flatten
  )
  out = lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.))
  # Besides not raising, check the result. x starts at 1., so the body runs
  # once and its Node(0., 0.) is returned, rebuilt through unflatten.
  self.assertIsInstance(out, Node)
  self.assertEqual(float(out.x), 0.)
  self.assertEqual(float(out.y), 0.)
|
||
|
||
def testNestedWhileWithDynamicUpdateSlice(self):
  num = 5

  def update_entry(arr, val, i, j):
    # Write the scalar `val` into position (i, j) of the 2-D array `arr`.
    return lax.dynamic_update_slice(arr, lax.reshape(val, [1, 1]), (i, j))

  def inner_loop(i, arr, out):  # pylint: disable=missing-docstring
    # Copy arr[i, 0..i] into out[i, 0..i].
    def cond_fun(carry):
      row, col, _, _ = carry
      return lax.le(col, row)

    def body_fun(carry):
      row, col, src, dst = carry
      src_row = lax.dynamic_index_in_dim(src, row, 0, False)
      src_elem = lax.dynamic_index_in_dim(src_row, col, 0, False)
      dst = update_entry(dst, src_elem, row, col)
      return row, lax.add(col, 1), src, dst

    return lax.while_loop(cond_fun, body_fun, (i, 0, arr, out))[3]

  def outer_loop(arr):  # pylint: disable=missing-docstring
    def cond_fun(carry):
      i, bound, _, _ = carry
      return lax.lt(i, bound)

    def body_fun(carry):
      i, bound, src, dst = carry
      return lax.add(i, 1), bound, src, inner_loop(i, src, dst)

    zeros = np.zeros(arr.shape, dtype=arr.dtype)
    return lax.while_loop(cond_fun, body_fun, (0, num, arr, zeros))[3]

  cloop = jax.jit(outer_loop)
  arr = self.rng().randn(5, 5)
  # Copying the lower triangle row by row reproduces np.tril.
  expected = np.tril(arr)
  for fn in (outer_loop, cloop, cloop):
    self.assertAllClose(fn(arr), expected, check_dtypes=False)
|
||
|
||
def testLoopWithConjunctionCondition(self):
  def sum_first_n(arr, num):  # pylint: disable=missing-docstring
    # Sum arr[:num], stopping at the array's end if num exceeds its length.
    def cond_fun(carry):
      values, limit, i, _ = carry
      return lax.bitwise_and(lax.lt(i, limit), lax.lt(i, values.shape[0]))

    def body_fun(carry):
      values, limit, i, acc = carry
      elem = lax.dynamic_index_in_dim(values, i, 0, False)
      return values, limit, i + 1, acc + elem

    return lax.while_loop(cond_fun, body_fun, (arr, num, 0, 0.))[3]

  cfun = jax.jit(sum_first_n)
  x = self.rng().randn(10).astype(jnp.float_)

  for num in [0, 5, 10, 15]:
    expected = np.sum(x[:num])
    self.assertAllClose(sum_first_n(x, num), expected, check_dtypes=False)
    # The second jitted call reuses the cached compilation.
    self.assertAllClose(cfun(x, num), expected, check_dtypes=False)
    self.assertAllClose(cfun(x, num), expected, check_dtypes=False)
|
||
|
||
def testWhileLoopBatched(self):
  def fun(x):
    return lax.while_loop(lambda v: v < 3, lambda v: v + 2, x)

  inputs = np.array([0, 1, 2, 3])
  # Each lane keeps adding 2 until it is at least 3.
  expected = np.array([4, 3, 4, 3])
  for f in (fun, jax.jit(fun)):
    self.assertAllClose(jax.vmap(f)(inputs), expected, check_dtypes=False)
|
||
|
||
def testWhileLoopAxisIndexBatched(self):
  def fun(x):
    return lax.while_loop(lambda v: v < lax.axis_index('i'),
                          lambda v: v + 2, x)

  zeros = np.array([0, 0, 0, 0], dtype='int32')
  expected = np.array([0, 2, 2, 4])

  ans = jax.vmap(fun, axis_name='i')(zeros)
  self.assertAllClose(ans, expected, check_dtypes=False)

  fun = jax.jit(fun)
  ans = jax.vmap(fun, axis_name='i')(zeros)
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Here the carry is unbatched; only the loop bound (axis_index) varies.
  mapped = jax.vmap(lambda _, x: fun(x), axis_name='i', in_axes=(0, None))
  ans = mapped(np.array([0, 0, 0, 0]), 0)
  self.assertAllClose(ans, np.array([0, 2, 2, 4], dtype='int32'),
                      check_dtypes=False)
|
||
|
||
def testWhileLoopBatchedWithConstBody(self):
  def f(x):
    # The predicate is the constant `not False == False`, i.e. False, so the
    # loop never runs and x is returned unchanged.
    def cond_fn(_):
      return jnp.logical_not(False) == False

    def body_fn(_):
      return jnp.asarray(0., dtype=jnp.float32)

    return jax.lax.while_loop(cond_fn, body_fn, x)

  xs = jnp.arange(5, dtype=jnp.float32)
  self.assertAllClose(jax.vmap(f)(xs), xs)
|
||
|
||
def testWhileLoopCondConstsBatched(self):
  def fun(x, y):
    def keep_going(v):
      return v < y

    def step(v):
      return v + 2

    return lax.while_loop(keep_going, step, x)

  # y is batched and appears only in the loop condition.
  result = jax.vmap(fun, in_axes=(None, 0))(0, np.array([2, 3]))
  self.assertAllClose(result, np.array([2, 4]), check_dtypes=False)
|
||
|
||
def testWhileLoopBodyConstsBatched(self):
  def fun(x, y):
    def keep_going(v):
      return v < 3

    def step(v):
      return v + y

    return lax.while_loop(keep_going, step, x)

  # y is batched and appears only in the loop body.
  result = jax.vmap(fun, in_axes=(None, 0))(0, jnp.array([2, 3]))
  self.assertAllClose(result, np.array([4, 3]), check_dtypes=False)
|
||
|
||
def testWhileLoopTupleBatched(self):
  def cond_fun(carry):
    return carry[0] + carry[1] < 5

  def body_fun(carry):
    first, second = carry
    return first + 1, second

  def fun(x, y):
    return lax.while_loop(cond_fun, body_fun, (x, y))

  # x grows until x + y reaches 5; y passes through untouched.
  result = jax.vmap(fun)(np.array([0, 0]), np.array([1, 2]))
  self.assertAllClose(result, (np.array([4, 3]), np.array([1, 2])),
                      check_dtypes=False)
|
||
|
||
def test_issue_3204(self):
  # Error during XLA code generation for vmap of nested loops
  def test(a, b):
    val = 0
    i = 0
    j = 0

    condfun_1 = lambda inp: inp[1] < a + 1
    condfun_2 = lambda inp: inp[2] < b + 1

    def bodyfun_1(inp):
      val, i, j = inp
      j = 0

      def bodyfun_2(inp):
        val, i, j = inp
        val += i + j
        j += 1
        return (val, i, j)

      result = lax.while_loop(condfun_2, bodyfun_2, (val, i, j))
      val = result[0]
      i += 1
      return (val, i, j)

    result = lax.while_loop(condfun_1, bodyfun_1, (val, i, j))
    return result[0]

  arr = np.arange(5)
  vmap_test = jax.vmap(test, (0, 0))
  # Check the values, not just that compilation succeeds. For a == b == k
  # the loops visit every (i, j) in [0, k] x [0, k], so
  # sum(i + j) = 2 * (k + 1) * k * (k + 1) / 2 = k * (k + 1) ** 2.
  self.assertAllClose(vmap_test(arr, arr), arr * (arr + 1) ** 2,
                      check_dtypes=False)
|
||
|
||
def testForiLoopErrors(self):
  """Test typing error messages for fori_loop."""
  # Bounds with different integer dtypes (int16 vs int32) must be rejected.
  lower, upper = np.int16(0), jnp.int32(10)
  keep_carry = lambda i, c: c
  with self.assertRaisesRegex(
      TypeError, "arguments to fori_loop must have equal types"):
    lax.fori_loop(lower, upper, keep_carry, jnp.float32(7))
|
||
|
||
def testForiLoopScalarLimits(self):
  """Test that scalar limits passed to fori_loop do not cause typing errors."""
  body = lambda i, c: c + 1
  init = jnp.float32(10)
  # A numpy int16 bound paired with a Python int bound, in both positions.
  for lower, upper in ((np.int16(0), 10), (0, np.int16(10))):
    self.assertEqual(lax.fori_loop(lower, upper, body, init), init + 10)
|
||
|
||
def test_fori_loop_supports_unrolling(self):
  """Test that we can unroll static fori_loops."""
  body = lambda i, c: c + 1
  init = jnp.float32(10)
  # Integer unroll factors, with mixed numpy/Python bounds in both positions.
  for lower, upper, unroll in ((np.int16(0), 10, 3), (0, np.int16(10), 2)):
    result = lax.fori_loop(lower, upper, body, init, unroll=unroll)
    self.assertEqual(result, init + 10)
|
||
|
||
def test_fori_loop_supports_unrolling_with_bool(self):
  """Test that we can unroll static fori_loops."""
  body = lambda i, c: c + 1
  init = jnp.float32(10)
  # Boolean unroll flags, with mixed numpy/Python bounds in both positions.
  for lower, upper, unroll in ((np.int16(0), 10, True),
                               (0, np.int16(10), False)):
    result = lax.fori_loop(lower, upper, body, init, unroll=unroll)
    self.assertEqual(result, init + 10)
|
||
|
||
def test_fori_loop_with_dynamic_indices_cannot_unroll(self):
  """Test that we can't unroll dynamic fori_loops."""
  body = lambda i, c: c + 1
  init = jnp.float32(10)

  def f(upper):
    return lax.fori_loop(np.int16(0), upper, body, init, unroll=3)

  # Under jit `upper` is traced, so unrolling must be refused.
  with self.assertRaisesRegex(ValueError, "Can only use `unroll`"):
    jax.jit(f)(10)
|
||
|
||
@parameterized.named_parameters(
    dict(testcase_name=f"_{jit=}_{upper=}_{unroll=}",
         jit=jit, upper=upper, unroll=unroll)
    for jit, upper, unroll in itertools.product(
        (False, True), (0, -1), (False, True))
)
def test_fori_loop_returns_init_with_nonpositive_length(
    self, jit, upper, unroll
):
  """Test that `length <= 0` behaves like Python `range`."""
  loop = partial(lax.fori_loop, 0, upper, lambda i, c: c + 1, unroll=unroll)
  if jit:
    loop = jax.jit(loop)
  init = jnp.float32(10)
  self.assertEqual(loop(init), init)
|
||
|
||
|
||
def testForiLoopBatched(self):
  def body_fun(i, loop_carry):
    x, y = loop_carry
    return x + 1, y + 2

  def fun(x):
    return lax.fori_loop(0, 10, body_fun, (x, 0))

  # Ten iterations: batched x grows by 10, unbatched y by 20.
  expected = (np.array([10, 11]), np.array([20, 20]))
  self.assertAllClose(jax.vmap(fun)(np.array([0, 1])), expected,
                      check_dtypes=False)
|
||
|
||
def testForiLoopBatchedIssue1190(self):
  def cond_fun(carry):
    return carry[0] < 4

  def body_fun(carry):
    return (carry[0] + 1, carry[1] + 1)

  def f(x):
    return lax.while_loop(cond_fun, body_fun, (0, x))

  jaxpr = jax.make_jaxpr(jax.vmap(f))(jnp.arange(3))
  first_eqn = jaxpr.jaxpr.eqns[0]
  self.assertIs(first_eqn.primitive, lax.while_p)
  # The first input of cond_jaxpr must stay unbatched (scalar shape).
  cond_in_avals = first_eqn.params['cond_jaxpr'].in_avals
  self.assertEqual(cond_in_avals[0].shape, ())
|
||
|
||
def testForiLoopBasic(self):
  def body_fun(i, tot):
    return lax.add(tot, i)

  def count(num):
    return lax.fori_loop(0, num, body_fun, 0)

  # count(n) == 0 + 1 + ... + (n - 1)
  for num, expected in ((2, 1), (3, 3), (4, 6)):
    self.assertEqual(count(num), expected)

  for num in (2, 3, 4):
    self._CompileAndCheck(count, lambda num=num: [num])
|
||
|
||
def testForiLoopClosure(self):
  def count(num):
    def body_fun(i, tot):
      return lax.add(num, lax.add(tot, i))
    return lax.fori_loop(0, num, body_fun, 0)

  cfun = jax.jit(count)

  # Each of the num steps adds i + num, so the total is
  # (0 + 1 + ... + (num - 1)) + num**2.
  for num, triangular in ((2, 1), (3, 3), (4, 6)):
    self.assertEqual(count(num), triangular + num**2)
    self.assertEqual(count(num), cfun(num))
|
||
|
||
def testForiLoopTupleState(self):
  def sum_first_n(arr, num):
    def body_fun(i, state):
      values, acc = state
      elem = lax.dynamic_index_in_dim(values, i, 0, False)
      return values, lax.add(acc, elem)

    start = (arr, arr.dtype.type(0))
    trip_count = lax.min(arr.shape[0], num)
    return lax.fori_loop(0, trip_count, body_fun, start)[1]

  cfun = jax.jit(sum_first_n)
  x = self.rng().randn(10).astype(jnp.float_)

  for num in [0, 5, 10, 15]:
    want = np.sum(x[:num])
    self.assertAllClose(sum_first_n(x, num), want, check_dtypes=False)
    self.assertAllClose(cfun(x, num), want, check_dtypes=False)
    self.assertAllClose(cfun(x, num), want, check_dtypes=False)
|
||
|
||
def testForiLoopDictState(self):
  def sum_first_n(arr, num):
    def body_fun(i, state):
      values = state['arr']
      elem = lax.dynamic_index_in_dim(values, i, 0, False)
      return {'arr': values, 'total': lax.add(state['total'], elem)}

    start = {'arr': arr, 'total': arr.dtype.type(0)}
    trip_count = lax.min(arr.shape[0], num)
    return lax.fori_loop(0, trip_count, body_fun, start)['total']

  cfun = jax.jit(sum_first_n)
  x = self.rng().randn(10).astype(jnp.float_)

  for num in [0, 5, 10, 15]:
    want = np.sum(x[:num])
    self.assertAllClose(sum_first_n(x, num), want, check_dtypes=False)
    self.assertAllClose(cfun(x, num), want, check_dtypes=False)
    self.assertAllClose(cfun(x, num), want, check_dtypes=False)
|
||
|
||
def testForiLoopEmptyTupleInState(self):
  def sum_first_n(arr, num):
    # The trailing () carry component must be threaded through unchanged.
    def body_fun(i, state):
      values, acc, empty = state
      elem = lax.dynamic_index_in_dim(values, i, 0, False)
      return values, lax.add(acc, elem), ()

    start = (arr, arr.dtype.type(0), ())
    trip_count = lax.min(arr.shape[0], num)
    return lax.fori_loop(0, trip_count, body_fun, start)[1]

  cfun = jax.jit(sum_first_n)
  x = self.rng().randn(10).astype(jnp.float_)

  for num in [0, 5, 10, 15]:
    want = np.sum(x[:num])
    self.assertAllClose(sum_first_n(x, num), want, check_dtypes=False)
    self.assertAllClose(cfun(x, num), want, check_dtypes=False)
    self.assertAllClose(cfun(x, num), want, check_dtypes=False)
|
||
|
||
def testForiLoopIssue8152(self):
  # A zero-trip fori_loop must return init_val.
  y = lax.fori_loop(lower=0, upper=0, body_fun=lambda i, c: i + c,
                    init_val=1.)
  self.assertAllClose(y, 1., check_dtypes=False)

  # trivial fori_loop should work - even when jit is disabled
  with jax.disable_jit():
    y = lax.fori_loop(lower=0, upper=0, body_fun=lambda i, c: i + c,
                      init_val=1.)
    self.assertAllClose(y, 1., check_dtypes=False)

  # scan with length 0 should work with jit, but raise an error without
  def should_raise_wo_jit():
    final_carry, _ = lax.scan(lambda c, x: (c + x, x), 0., np.array([]))
    return final_carry

  self.assertAllClose(should_raise_wo_jit(), 0., check_dtypes=False)
  with jax.disable_jit():
    self.assertRaises(ValueError, should_raise_wo_jit)
|
||
|
||
def testCond(self):
  def fun(x):
    # Python reference for the branch structure used in cfun below.
    if x >= 3:
      doubled = lax.mul(2, x)
      return doubled, lax.mul(2, doubled)
    return (x, x)

  @jax.jit
  def cfun(x):
    def false_fun(v):
      doubled = lax.mul(2, v)
      return doubled, lax.mul(2, doubled)
    return lax.cond(lax.lt(x, 3), lambda v: (v, v), false_fun, x)

  expected = {0: (0, 0), 1: (1, 1), 2: (2, 2), 3: (6, 12), 4: (8, 16)}
  for x, want in expected.items():
    self.assertEqual(fun(x), cfun(x))
    self.assertEqual(fun(x), want)
|
||
|
||
def testCondPredIsNone(self):
  # see https://github.com/jax-ml/jax/issues/11574
  def f(pred, x):
    return lax.cond(pred, lambda x: x + 1, lambda x: x + 2, x)

  # Both eager and jitted calls must reject a None predicate.
  for g in (f, jax.jit(f)):
    with self.assertRaisesRegex(TypeError, "cond predicate is None"):
      g(None, 1.)
|
||
|
||
def testCondTwoOperands(self):
  # see https://github.com/jax-ml/jax/issues/8469
  def fun(x):
    if x == 0:
      return lax.add(x, x)
    return lax.mul(x, x)

  def cfun(x):
    return lax.cond(x == 0, lax.add, lax.mul, x, x)

  for impl in (cfun, jax.jit(cfun)):
    for x in (0, 1):
      self.assertEqual(fun(x), impl(x))
|
||
|
||
def testCondThreeOperands(self):
  def add(x, y, z):
    return x + y + z

  def mul(x, y, z):
    return x * y * z

  def fun(x):
    if x == 0:
      return add(x, x, x)
    return mul(x, x, x)

  def cfun(x):
    return lax.cond(x == 0, add, mul, x, x, x)

  for impl in (cfun, jax.jit(cfun)):
    for x in (0, 1):
      self.assertEqual(fun(x), impl(x))
|
||
|
||
def testCondCallableOperands(self):
  # see https://github.com/jax-ml/jax/issues/16413

  @tree_util.register_pytree_node_class
  class Foo:
    def __init__(self, x):
      self.x = x

    def __call__(self, *xs):
      # Foo instances are operands (data), so cond must never call them even
      # though they are callable. Raise explicitly: a bare `assert False` is
      # stripped by `python -O`, which would silently return instead.
      raise AssertionError("cond must not call a callable operand")

    def tree_flatten(self):
      return (self.x,), None

    @classmethod
    def tree_unflatten(cls, _, xs):
      return cls(*xs)

  f_00 = lambda a, b: a + b
  f_01 = lambda a, b: a + b.x
  f_10 = lambda a, b: a.x + b
  f_11 = lambda a, b: a.x + b.x

  # these don't raise
  a = lax.cond(True, f_00, f_00, 3, 4)
  b = lax.cond(True, f_01, f_01, 3, Foo(4))
  c = lax.cond(True, f_10, f_10, Foo(3), 4)
  d = lax.cond(True, f_11, f_11, Foo(3), Foo(4))
  self.assertEqual(a, b)
  self.assertEqual(a, c)
  self.assertEqual(a, d)
|
||
|
||
def testSwitch(self):
  def branch(x):
    doubled = lax.mul(2, x)
    return doubled, lax.mul(2, doubled)

  branches = [lambda x: (x, x), branch, lambda x: (x, -x)]

  def fun(x):
    # Python reference: indices below 0 pick the first branch and indices
    # above 1 pick the last, matching what lax.switch returns below.
    idx = 0 if x <= 0 else (1 if x == 1 else 2)
    return branches[idx](x)

  def cfun(x):
    return lax.switch(x, branches, x)

  for impl in (cfun, jax.jit(cfun)):
    for x in (-1, 0, 1, 2, 3):
      self.assertEqual(fun(x), impl(x))
|
||
|
||
def testSwitchMultiOperands(self):
  branches = [lax.add, lax.mul]

  def fun(x):
    # Python reference: non-positive x selects add, positive x selects mul.
    chosen = branches[1] if x > 0 else branches[0]
    return chosen(x, x)

  def cfun(x):
    return lax.switch(x, branches, x, x)

  for impl in (cfun, jax.jit(cfun)):
    for x in (-1, 0, 1, 2):
      self.assertEqual(fun(x), impl(x))
|
||
|
||
def testSwitchResidualsMerge(self):
  """Residuals shared between switch branches reuse the same output slots."""
  def get_conds(fun):
    jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
    return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']

  def branch_invars_len(cond_eqn):
    lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
    # All branches of one cond must share the same signature.
    self.assertEqual(len(set(lens)), 1)
    return lens[0]

  def branch_outvars_len(cond_eqn):
    lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
    self.assertEqual(len(set(lens)), 1)
    return lens[0]

  branches1 = [
      lambda x: jnp.sin(x),
      lambda x: jnp.cos(x)]  # branch residuals overlap, should be reused
  branches2 = branches1 + [
      lambda x: jnp.sinh(x)]  # another overlapping residual, expect reuse
  branches3 = branches2 + [
      lambda x: jnp.sin(x) + jnp.cos(x)]  # requires one more residual slot
  def fun1(x, i):
    return lax.switch(i + 1, branches1, x)
  def fun2(x, i):
    return lax.switch(i + 1, branches2, x)
  def fun3(x, i):
    return lax.switch(i + 1, branches3, x)

  # Each grad jaxpr contains two cond eqns, named fwd/bwd here (presumably
  # the forward pass and its transpose).
  fwd1, bwd1 = get_conds(fun1)
  fwd2, bwd2 = get_conds(fun2)
  fwd3, bwd3 = get_conds(fun3)

  # unittest assertions instead of bare `assert`: `python -O` strips the
  # latter, and these report both values on failure.
  fwd1_num_out = branch_outvars_len(fwd1)
  fwd2_num_out = branch_outvars_len(fwd2)
  fwd3_num_out = branch_outvars_len(fwd3)
  self.assertEqual(fwd1_num_out, fwd2_num_out)
  self.assertEqual(fwd3_num_out, fwd2_num_out + 1)

  bwd1_num_in = branch_invars_len(bwd1)
  bwd2_num_in = branch_invars_len(bwd2)
  bwd3_num_in = branch_invars_len(bwd3)
  self.assertEqual(bwd1_num_in, bwd2_num_in)
  self.assertEqual(bwd3_num_in, bwd2_num_in + 1)
|
||
|
||
def testOneBranchSwitch(self):
  branch = lambda x: -x
  f = lambda i, x: lax.switch(i, [branch], x)
  x = 7.
  expected = branch(x)
  # With a single branch, every index (in range or not) selects it: eagerly,
  # under jit, and with a static index.
  for g in (f, jax.jit(f), jax.jit(f, static_argnums=0)):
    for i in (-1, 0, 1):
      self.assertEqual(g(i, x), expected)
|
||
|
||
def testIssue1379(self):
  def fun(pred):
    return lax.cond(pred, lambda x: (True, x), lambda x: (False, x), pred)

  @jax.jit
  def cfun(pred):
    return fun(pred)

  # assertEqual's third positional argument is the failure *message*, so the
  # former `assertEqual(fun(p), cfun(p), expected)` never checked `expected`.
  # Check eager vs. jitted agreement and the expected value separately.
  cases = ((0, (False, 0)), (0., (False, 0.)), (1, (True, 1)), (1., (True, 1.)))
  for pred, expected in cases:
    self.assertEqual(fun(pred), cfun(pred))
    self.assertEqual(fun(pred), expected)

  # test that proper errors are raised for wrong types
  for pred in ["abc", [], [1,2]]:
    for f in [fun, cfun]:
      self.assertRaises(TypeError, f, pred)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testNestedCond(self, cond):
  """A cond nested in the false branch of another cond matches Python ifs."""
  def fun(x):
    if x < 2:
      return lax.mul(2, x)
    return lax.mul(3, x) if x < 5 else lax.mul(4, x)

  @jax.jit
  def cfun(x):
    def outer_false(x):
      # Five-argument form: (pred, true_operand, true_fun, false_operand,
      # false_fun). The inner false branch closes over this function's x.
      return cond(lax.lt(x, 5),
                  x, lambda x: lax.mul(3, x),
                  4, lambda y: lax.mul(y, x))
    return cond(lax.lt(x, 2), lambda x: lax.mul(2, x), outer_false, x)

  for arg, want in [(1, 2), (3, 9), (6, 24)]:
    self.assertEqual(cfun(arg), want)
  for arg in (1, 3, 6):
    self.assertEqual(cfun(arg), fun(arg))
|
||
|
||
def testCondTypeErrors(self):
|
||
"""Test typing error messages for cond."""
|
||
with self.assertRaisesRegex(TypeError,
|
||
re.escape("Pred type must be either boolean or number, got <function")):
|
||
lax.cond(lambda x: True, lambda top: 2., lambda fop: 3., 1.)
|
||
with self.assertRaisesRegex(TypeError,
|
||
re.escape("Pred must be a scalar, got foo of type <class 'str'>")):
|
||
lax.cond("foo", lambda top: 2., lambda fop: 3., 1.)
|
||
with self.assertRaisesRegex(TypeError,
|
||
re.escape("Pred must be a scalar, got (1.0, 1.0) of type <class 'tuple'>")):
|
||
lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.)
|
||
with self.assertRaisesRegex(TypeError,
|
||
re.compile("true_fun output must have same type structure "
|
||
"as false_fun output, but there are differences:.*"
|
||
r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)):
|
||
lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
|
||
with self.assertRaisesRegex(
|
||
TypeError,
|
||
"true_fun output and false_fun output must have identical types, got\n"
|
||
r"DIFFERENT ShapedArray\(float32\[1\]\) vs. "
|
||
r"ShapedArray\(float32\[\].*\)."):
|
||
lax.cond(True,
|
||
lambda top: jnp.array([1.], jnp.float32),
|
||
lambda fop: jnp.float32(1.),
|
||
1.)
|
||
|
||
def testSwitchErrors(self):
|
||
"""Test typing error messages for switch."""
|
||
with self.assertRaisesRegex(TypeError,
|
||
re.escape("Index type must be an integer, got <function")):
|
||
lax.switch(lambda x: True, [lambda _: 2., lambda _: 3.], 1.)
|
||
with self.assertRaisesRegex(TypeError,
|
||
re.escape("Index type must be an integer, got foo.")):
|
||
lax.switch("foo", [lambda _: 2., lambda _: 3.], 1.)
|
||
with self.assertRaisesRegex(TypeError,
|
||
re.escape("Branch index must be scalar, got (1.0, 1.0) of shape (2,).")):
|
||
lax.switch((1., 1.), [lambda _: 2., lambda _: 3.], 1.)
|
||
with self.assertRaisesRegex(ValueError,
|
||
re.escape("Empty branch sequence")):
|
||
lax.switch(0, [], 1.)
|
||
with self.assertRaisesRegex(TypeError,
|
||
re.compile("branch 0 output must have same type structure "
|
||
"as branch 1 output, but there are differences:.*"
|
||
r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)):
|
||
lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
|
||
with self.assertRaisesRegex(
|
||
TypeError,
|
||
"branch 0 output and branch 1 output must have identical types, got\n"
|
||
r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) "
|
||
r"vs. ShapedArray\(float32\[\].*\)'}."):
|
||
lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)),
|
||
lambda _: dict(a=jnp.float32(1.))],
|
||
1.)
|
||
|
||
def testCondOneBranchConstant(self):
|
||
def fun(x):
|
||
if x < 3:
|
||
return 5.
|
||
else:
|
||
return x
|
||
|
||
@jax.jit
|
||
def cfun(x):
|
||
return lax.cond(lax.lt(x, 3), lambda x: 5, lambda x: x, x)
|
||
|
||
self.assertEqual(fun(0), cfun(0))
|
||
self.assertEqual(cfun(0), 5)
|
||
self.assertEqual(fun(4), cfun(4))
|
||
self.assertEqual(cfun(4), 4)
|
||
|
||
def testCondOneBranchConstantTuple(self):
|
||
def fun(x):
|
||
if x < 3:
|
||
return (1., 2., 3.)
|
||
else:
|
||
return (x, 2., 4.)
|
||
|
||
@jax.jit
|
||
def cfun(x):
|
||
return lax.cond(lax.lt(x, 3),
|
||
lambda x: (1, 2., 3.),
|
||
lambda x: (x, 2., 4.),
|
||
x)
|
||
|
||
self.assertEqual(fun(0), cfun(0))
|
||
self.assertEqual(cfun(0), (1, 2., 3.))
|
||
self.assertEqual(fun(4), cfun(4))
|
||
self.assertEqual(cfun(4), (4, 2., 4.))
|
||
|
||
def testCondBatched(self):
  """vmap of cond: unbatched predicates keep a cond, batched ones use select.

  Each case checks the numeric result and inspects the printed jaxpr for a
  `select` primitive.
  """
  def fun(x, y, z):
    pred = lax.lt(x, 3)
    true_fun = lambda y: y
    false_fun = lambda z: lax.neg(z)
    return lax.cond(pred, y, true_fun, z, false_fun)

  def check(f, in_axes, args, expected, expect_select):
    ans = jax.vmap(f, in_axes)(*args)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jaxpr_text = str(jax.make_jaxpr(jax.vmap(f, in_axes))(*args))
    # Use unittest assertions instead of bare `assert`: Python strips bare
    # asserts under -O, which would silently turn these checks into no-ops.
    if expect_select:
      self.assertIn("select", jaxpr_text)
    else:
      self.assertNotIn("select", jaxpr_text)

  y = jnp.array([1, 2])
  z = jnp.array([3, 4])

  # these cases stay as cond
  x = jnp.array(2)
  check(fun, (None, 0, 0), (x, y, z), np.array([1, 2]), expect_select=False)

  x = jnp.array(4)
  check(fun, (None, 0, 0), (x, y, z), np.array([-3, -4]), expect_select=False)

  fun = jax.jit(fun)
  ans = jax.vmap(fun, (None, 0, 0))(x, y, z)
  self.assertAllClose(ans, np.array([-3, -4]), check_dtypes=False)

  z = jnp.array(5)
  check(fun, (None, 0, None), (x, y, z), np.array([-5, -5]),
        expect_select=False)

  # these cases become select
  x = jnp.array([2, 4])
  check(fun, (0, 0, None), (x, y, z), np.array([1, -5]), expect_select=True)

  z = jnp.array([3, 4])
  check(fun, 0, (x, y, z), np.array([1, -4]), expect_select=True)
|
||
|
||
def testSwitchBatched(self):
  """vmap of switch: unbatched indices keep a cond, batched ones use select.

  Note that `fun(index, x, y, z)` is fed the test's (x, y, z, w), so inside
  `fun` the switched-over operands are the test's y, z and w.
  """
  def fun(index, x, y, z):
    branches = [lambda xyz: xyz[0],
                lambda xyz: lax.neg(xyz[1]),
                lambda xyz: lax.sign(xyz[2])]
    return lax.switch(index, branches, (x, y, z))

  def check(f, in_axes, args, expected, expect_select):
    ans = jax.vmap(f, in_axes)(*args)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jaxpr_text = str(jax.make_jaxpr(jax.vmap(f, in_axes))(*args))
    # Use unittest assertions instead of bare `assert`: Python strips bare
    # asserts under -O, which would silently turn these checks into no-ops.
    if expect_select:
      self.assertIn("select", jaxpr_text)
    else:
      self.assertNotIn("select", jaxpr_text)

  y = jnp.array([1, 2])
  z = jnp.array([3, 4])
  w = jnp.array(9)

  # these cases stay as cond
  x = jnp.array(0)
  check(fun, (None, 0, 0, None), (x, y, z, w), np.array([1, 2]),
        expect_select=False)

  x = jnp.array(1)
  check(fun, (None, 0, 0, None), (x, y, z, w), np.array([-3, -4]),
        expect_select=False)

  fun = jax.jit(fun)
  ans = jax.vmap(fun, (None, 0, 0, None))(x, y, z, w)
  self.assertAllClose(ans, np.array([-3, -4]), check_dtypes=False)

  z = jnp.array(5)
  check(fun, (None, 0, None, None), (x, y, z, w), np.array([-5, -5]),
        expect_select=False)

  # these cases become select
  x = jnp.array([0, 1])
  check(fun, (0, 0, None, None), (x, y, z, w), np.array([1, -5]),
        expect_select=True)

  z = jnp.array([3, 4])
  w = jnp.array([9, 9])
  check(fun, 0, (x, y, z, w), np.array([1, -4]), expect_select=True)
|
||
|
||
def testCondJVP(self):
  """Forward-mode derivatives through cond match Python control flow."""
  def doubled(x):
    y = 2 * x
    return y, 2 * y

  def fun_ref(x):
    return (x, x) if x < 3 else doubled(x)

  def fun(x):
    return lax.cond(x < 3, lambda x: (x, x), doubled, x)

  # 3.14 exercises the false branch, 2.72 the true branch.
  for x in (3.14, 2.72):
    ans = jax.jvp(fun, (x,), (x,))
    expected = jax.jvp(fun_ref, (x,), (x,))
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
|
||
|
||
def testSwitchJVP(self):
  """Forward-mode derivatives through switch, including clamped indices."""
  def doubled(x):
    y = 2 * x
    return y, 2 * y

  branches = [lambda x: (x, x), doubled, lambda x: (x, -x)]

  def fun_ref(x):
    # Python reference: indices <= 0 pick branch 0, >= 2 pick branch 2.
    idx = x // 1
    if idx <= 0:
      return branches[0](x)
    if idx == 1:
      return branches[1](x)
    return branches[2](x)

  def fun(x):
    return lax.switch(lax.convert_element_type(x // 1, np.int32), branches, x)

  for x in [-0.7, 0.7, 1.7, 2.7, 3.7]:
    ans = jax.jvp(fun, (x,), (x,))
    expected = jax.jvp(fun_ref, (x,), (x,))
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondJVP2(self, cond):
  """JVP through cond whose true branch takes a None operand."""
  def fun_ref(x):
    return 2. if x < 3 else 2. * x

  def fun(x):
    return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x)

  for x in (3.14, 2.72):
    ans = jax.jvp(fun, (x,), (x,))
    expected = jax.jvp(fun_ref, (x,), (x,))
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
|
||
|
||
def testCondGrad(self):
  """Reverse-mode gradients through a jitted cond match Python branching."""
  def f_ref(x):
    if x < 2:
      return 3. * x
    return jnp.sin(x)

  @jax.jit
  def f(x):
    return lax.cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)

  for x in (2.14, 1.72):
    self.assertAllClose(jax.grad(f)(x), jax.grad(f_ref)(x), check_dtypes=False)
    jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
|
||
|
||
def testCondGradVmapNan(self):
|
||
eps = 1e-3
|
||
|
||
def safe1(x):
|
||
return lax.cond(x < eps, lambda _: eps, lambda _: jnp.sqrt(x), ())
|
||
|
||
out = jax.grad(lambda x: jax.vmap(safe1)(x).sum())(np.zeros(10))
|
||
self.assertFalse(np.isnan(out).any())
|
||
|
||
def testSwitchGrad(self):
  """Reverse-mode gradients through switch, including clamped indices."""
  branches = [lambda x: 3. * x,
              lambda x: jnp.sin(x),
              lambda x: -x]

  def f_ref(x):
    # Python reference: indices <= 0 pick branch 0, >= 2 pick branch 2.
    idx = x // 1
    if idx <= 0:
      return branches[0](x)
    if idx == 1:
      return branches[1](x)
    return branches[2](x)

  @jax.jit
  def f(x):
    return lax.switch(lax.convert_element_type(x // 1, np.int32), branches, x)

  for x in [-0.7, 0.7, 1.7, 2.7, 3.7]:
    self.assertAllClose(jax.grad(f)(x), jax.grad(f_ref)(x), check_dtypes=False)
    jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
|
||
|
||
def testSwitchGradWithWeakTypeMismatch(self):  # issue #4696, PR #4896
  """switch grads work when branches disagree on the weak type of x."""
  # float64 canonicalizes to float32 when x64 is disabled.
  canonical = dtypes.canonicalize_dtype(np.float64)
  dtype = jnp.float32 if canonical == jnp.float32 else jnp.float64

  branches = [
      lambda x: x,  # This preserves the weak type of x.
      lambda x: x + dtype(1),  # This strips the weak type of x.
  ]

  def f_ref(x):
    return branches[x.astype(jnp.int32)](x)

  def f(x):
    return lax.switch(x.astype(jnp.int32), branches, x)

  for x in [0., 1.]:
    self.assertAllClose(jax.grad(f)(x), jax.grad(f_ref)(x), check_dtypes=False)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondGrad2(self, cond):
  """Grads through a jitted vector-valued cond, reduced by a sum outside."""
  def f_ref(x):
    if x[0] < 2:
      z = jnp.array([1., 2.], x.dtype) * x
    else:
      z = jnp.sin(x)
    return z.sum()

  def _f(x):
    scale = lambda x: jnp.array([1., 2.], x.dtype) * x
    return cond(x[0] < 2, scale, lambda x: jnp.sin(x), x)

  f = lambda x: jax.jit(_f)(x).sum()

  # (input scale, extra check_grads kwargs); the true branch needs looser rtol.
  cases = [
      (2.14, {}),
      (1.72, {"rtol": {jnp.float32: 1e-2, jnp.float64: 2e-3}}),
  ]
  for scale, extra in cases:
    x = scale * jnp.ones(2)
    self.assertAllClose(jax.grad(f)(x), jax.grad(f_ref)(x), check_dtypes=False)
    jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"], **extra)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondGrad3(self, cond):
  """Reverse-mode grads through cond whose true branch takes a None operand."""
  def fun_ref(x):
    return 2. if x < 3 else 2. * x

  def fun(x):
    return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x)

  for x in (3.14, 2.72):
    self.assertAllClose(jax.grad(fun)(x), jax.grad(fun_ref)(x),
                        check_dtypes=False)
    jtu.check_grads(fun, (x,), order=2, modes=["fwd", "rev"])
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondGrad4(self, cond):
  """Gradient w.r.t. a value (y) closed over by only one branch of cond."""
  if cond is cond_with_new_checkpoint and jtu.test_device_matches(['tpu']):
    raise unittest.SkipTest("tpu bug")  # TODO(parkers): tpu bug exhibited here

  def fun_ref(x, y):
    if x < 3:
      return 2. * jnp.sin(y)
    else:
      # Must mirror the false branch of `fun` below. Only d/dy is compared
      # against this reference, so a mismatched false branch (e.g.
      # 2. * jnp.cos(x)) would go unnoticed.
      return 2. * x

  @jax.jit
  def fun(x, y):
    return cond(
        x < 3,
        None, lambda _: 2. * jnp.sin(y),
        x, lambda x: 2. * x)

  y = 5.8
  # 3.14 takes the false branch (gradient w.r.t. y is 0); 2.72 the true one.
  for x in (3.14, 2.72):
    ans = jax.grad(fun, 1)(x, y)
    expected = jax.grad(fun_ref, 1)(x, y)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x, y), order=2, modes=["fwd", "rev"])
|
||
|
||
def testCondLinearize(self):
  """jax.linearize through cond yields the primal and the right tangent map."""
  def f(x):
    return lax.cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)

  # (primal point, expected f(x), expected f_lin(2.))
  cases = [
      (1., 3., 6.),
      (4., jnp.sin(4.), jnp.cos(4.) * 2.),
  ]
  for x0, want_y, want_tangent in cases:
    y, f_lin = jax.linearize(f, x0)
    self.assertAllClose(y, want_y, check_dtypes=False)
    self.assertAllClose(f_lin(2.), want_tangent, check_dtypes=False)
|
||
|
||
def testSwitchLinearize(self):
  """jax.linearize through switch for every branch, including clamped ones."""
  branches = [lambda x: 3. * x,
              lambda x: jnp.sin(x),
              lambda x: -x]

  def f(x):
    return lax.switch(lax.convert_element_type(x // 1, np.int32), branches, x)

  # (primal point, expected f(x), expected f_lin(2.)); index -1 selects
  # branch 0 and index 3 selects branch 2.
  cases = [
      (-1., -3., 6.),                       # branch 0
      (0., 0., 6.),                         # branch 0
      (1., jnp.sin(1.), jnp.cos(1.) * 2.),  # branch 1
      (2., -2., -2.),                       # branch 2
      (3., -3., -2.),                       # branch 2
  ]
  for x0, want_y, want_tangent in cases:
    y, f_lin = jax.linearize(f, x0)
    self.assertAllClose(y, want_y, check_dtypes=False)
    self.assertAllClose(f_lin(2.), want_tangent, check_dtypes=False)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondLinearize2(self, cond):
  """linearize of a vector-valued cond (eager and jitted) vs a Python ref."""
  def f_ref(x):
    if x[0] < 2:
      z = jnp.array([1., 2.], x.dtype) * x
    else:
      z = jnp.cos(jnp.sin(x))
    return z.sum()

  def f(x):
    out = cond(x[0] < 2,
               lambda x: jnp.array([1., 2.], x.dtype) * x,
               lambda x: jnp.cos(jnp.sin(x)),
               x)
    return out.sum()

  def check(fn, x):
    y, f_lin = jax.linearize(fn, x)
    y_ref, f_lin_ref = jax.linearize(f_ref, x)
    self.assertAllClose(y, y_ref, check_dtypes=False)
    self.assertAllClose(f_lin(x), f_lin_ref(x), check_dtypes=False)

  check(f, 2.14 * jnp.ones(2))
  check(f, -2.14 * jnp.ones(2))
  check(jax.jit(f), 2.14 * jnp.ones(2))
|
||
|
||
def testCondJit(self):
  """jax.jit of a function containing cond matches eager execution."""
  def f(x):
    return lax.cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)

  compiled = jax.jit(f)
  for x in (1., 4.):
    self.assertAllClose(compiled(x), f(x), check_dtypes=False)
|
||
|
||
def testSwitchJit(self):
  """jax.jit of a function containing switch matches eager execution."""
  branches = [lambda x: 3. * x,
              lambda x: jnp.sin(x),
              lambda x: -x]

  def f(x):
    return lax.switch(lax.convert_element_type(x // 1, np.int32), branches, x)

  compiled = jax.jit(f)
  for x in [-1., 0., 1., 2., 3.]:
    self.assertAllClose(compiled(x), f(x), check_dtypes=False)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondJitDisabled(self, cond):
  """cond evaluates correctly under jax.disable_jit, with and without jit."""
  def f_ref(x):
    return 3. * x if x < 2 else jnp.sin(x)

  def f(x):
    return cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)

  # Each pair is (computation, reference), both evaluated with jit disabled.
  comparisons = [
      (lambda: f(1.), lambda: f_ref(1.)),
      (lambda: jax.jit(f)(1.), lambda: f(1.)),
  ]
  for compute, reference in comparisons:
    with jax.disable_jit():
      y = compute()
      expected = reference()
    self.assertAllClose(y, expected, check_dtypes=False)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondWithConsts(self, cond):
  """cond branches that close over NumPy array constants."""
  def f(x):
    return cond(x < 2,
                lambda x: np.array([1., 2.]) * x,
                lambda x: np.array([3., 4.]) * jnp.sin(x),
                x)

  def f_ref(x):
    if x < 2:
      return np.array([1., 2.]) * x
    return np.array([3., 4.]) * np.sin(x)

  for x in (1., 4.):
    self.assertAllClose(f(x), f_ref(x), check_dtypes=False)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondJitWithConsts(self, cond):
  """Jitted cond whose branches close over NumPy array constants."""
  def f(x):
    return cond(x < 2,
                lambda x: np.array([1., 2.]) * x,
                lambda x: np.array([3., 4.]) * jnp.sin(x),
                x)

  compiled = jax.jit(f)
  for x in (1., 4.):
    self.assertAllClose(compiled(x), f(x), check_dtypes=False)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondVmapGrad(self, cond):
  # https://github.com/jax-ml/jax/issues/2264
  def square(x): return x ** 2
  def cube(x): return x ** 3

  def via_cond(x): return cond(x > 0, square, cube, x)
  def via_where(x): return jnp.where(x > 0, square(x), cube(x))

  xs = jnp.linspace(-1, 1, 20)
  self.assertAllClose(jax.vmap(jax.grad(via_cond))(xs),
                      jax.vmap(jax.grad(via_where))(xs),
                      check_dtypes=False)
|
||
|
||
@jax.legacy_prng_key('allow')
|
||
def testIssue1263(self):
|
||
def f(rng, x):
|
||
cond = random.bernoulli(rng)
|
||
return lax.cond(cond, x, lambda x: x, jnp.abs(x) - 1., lambda x: x)
|
||
|
||
def body_fn(i, state):
|
||
rng, x = state
|
||
key, subkey = random.split(rng)
|
||
return key, f(subkey, x)
|
||
|
||
def g(rng, x):
|
||
return lax.fori_loop(0, 10, body_fn, (rng, x))
|
||
|
||
jax.vmap(g)(random.split(random.PRNGKey(0), 3), jnp.ones((3, 4)))
|
||
|
||
def testIssue514(self):
|
||
# just check this doesn't crash
|
||
lax.cond(True,
|
||
(0, 0), lambda x: (x[0], 0),
|
||
(1, 1), lambda x: x)
|
||
|
||
def testIssue649(self):
|
||
from jax import lax
|
||
|
||
def body(x):
|
||
a, b = x
|
||
return (7, b + 1)
|
||
|
||
def cond(x):
|
||
a, b = x
|
||
return b < 10
|
||
|
||
out = lax.while_loop(cond, body, (33, 4))
|
||
self.assertEqual(out, (7, 10))
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}",
     "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl,
     "impl_name": scan_name}
    for jit_scan in [False, True]
    for jit_f in [False, True]
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def testScanImpl(self, jit_scan, jit_f, scan, impl_name):
  """Forward evaluation of each scan implementation vs the Python loop."""
  rng = self.rng()

  d = rng.randn(2)
  def body(carry, x):
    assert x.shape == (3,)
    assert carry.shape == (4,)
    y = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(carry)) + jnp.sum(d))
    new_carry = jnp.sin(carry * y)
    assert y.shape == ()
    return new_carry, y

  f = jax.jit(body) if jit_f else body
  if jit_scan:
    scan = jax.jit(scan, static_argnums=(0,))

  # Draw order from `rng` (d, then as_, then c) is kept fixed.
  as_ = rng.randn(5, 3)
  c = rng.randn(4)

  ans = scan(f, c, as_)
  expected = scan_reference(f, c, as_)

  rtol = {np.float64: 1.4e-15}
  atol = {np.float64: 8e-15}
  if impl_name == "for":
    # Looser float32 tolerances for the for-loop based implementation.
    rtol[np.float32] = 8e-5
    atol[np.float32] = 3e-5
  self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol, atol=atol)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}",
     "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
    for jit_scan in [False, True]
    for jit_f in [False, True]
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def testScanJVP(self, jit_scan, jit_f, scan):
  """Forward-mode derivatives of each scan implementation vs the Python loop."""
  rng = self.rng()

  d = rng.randn(2)
  def body(carry, x):
    assert x.shape == (3,)
    assert carry.shape == (4,)
    y = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(carry)) + jnp.sum(d))
    new_carry = jnp.sin(carry * y)
    assert y.shape == ()
    return new_carry, y

  f = jax.jit(body) if jit_f else body
  if jit_scan:
    scan = jax.jit(scan, static_argnums=(0,))

  as_ = rng.randn(5, 3)
  c = rng.randn(4)

  # The inputs double as their own tangents.
  primals = tangents = (c, as_)
  ans = jax.jvp(lambda c, as_: scan(f, c, as_), primals, tangents)
  expected = jax.jvp(lambda c, as_: scan_reference(f, c, as_),
                     primals, tangents)
  tol = {np.float64: 1e-12, np.float32: 1e-4}
  self.assertAllClose(ans, expected, check_dtypes=False, rtol=tol, atol=tol)

  jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"],
                  rtol={jnp.float32: 2e-1})
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}",
     "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
    for jit_scan in [False, True]
    for jit_f in [False, True]
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def testScanLinearize(self, jit_scan, jit_f, scan):
  """jax.linearize of each scan implementation vs the Python reference loop."""
  rng = self.rng()

  d = rng.randn(2)
  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
    c = jnp.sin(c * b)
    assert b.shape == ()
    return c, b

  # Choose the tolerance while `scan` is still the raw implementation.
  # After the optional jax.jit wrapping below, the `is` checks would always
  # be False for jit_scan=True, and the per-implementation tolerance would
  # silently be ignored.
  if scan is scan_with_new_checkpoint2 or scan is scan_with_for:
    rtol = {np.float64: 1e-12, np.float32: 1e-4}
  else:
    rtol = {np.float64: 1e-14, np.float32: 1e-4}

  if jit_f:
    f = jax.jit(f)
  if jit_scan:
    scan = jax.jit(scan, static_argnums=(0,))

  as_ = rng.randn(5, 3)
  c = rng.randn(4)

  ans = jax.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
  expected = jax.linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
  self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}",
     "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
    for jit_scan in [False, True]
    for jit_f in [False, True]
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testScanGrad(self, jit_scan, jit_f, scan):
  """Reverse-mode gradients of each scan implementation vs the Python loop."""
  rng = self.rng()

  d = rng.randn(2)
  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    b = jnp.sum(jnp.sin(a)) + jnp.sum(jnp.sin(c)) + jnp.sum(jnp.sin(d))
    c = jnp.sin(c * b)
    assert b.shape == ()
    return c, b

  # Choose every implementation-specific tolerance here, while `scan` is
  # still the raw implementation. After the optional jax.jit wrapping below,
  # `scan is scan_with_...` would always be False for jit_scan=True.
  if scan is scan_with_new_checkpoint:
    rtol = {np.float32: 5e-5, np.float64: 1e-13}
    atol = 1e-5
  elif scan is scan_with_for:
    rtol = {np.float32: 2e-5, np.float64: 1e-13}
    atol = {np.float32: 6e-2, np.float64: 1e-13}
  else:
    rtol = {np.float32: 2e-4, np.float64: 1e-13}
    atol = {np.float32: 8e-5, np.float64: 1e-13}
  # Tolerances for the numerical check_grads pass at the end.
  grads_rtol = 5e-2 if scan is scan_with_new_checkpoint2 else 5e-3
  grads_atol = 5e-2 if jtu.test_device_matches(["tpu"]) else 1e-3

  if jit_f:
    f = jax.jit(f)
  if jit_scan:
    scan = jax.jit(scan, static_argnums=(0,))

  as_ = rng.randn(5, 3)
  c = rng.randn(4)

  ans = jax.grad(lambda c, as_: list(scan(f, c, as_))[0].sum())(c, as_)
  expected = jax.grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
  self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol, atol=atol)

  jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
                  atol=grads_atol, rtol=grads_rtol)
|
||
|
||
@jtu.skip_on_devices("tpu")  # TPU lacks precision for this test.
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testScanRnn(self):
  """A small scan-based RNN supports eval, JVP, linearize, grad and vmap."""
  r = self.rng()

  n_in, n_hid, n_out, length = 4, 2, 1, 3

  W_trans = r.randn(n_hid, n_hid + n_in).astype(jnp.float_)
  W_out = r.randn(n_out, n_hid + n_in).astype(jnp.float_)
  params = W_trans, W_out

  inputs = r.randn(length, n_in).astype(jnp.float_)
  targets = r.randn(length, n_out).astype(jnp.float_)

  def step(params, state, x_t):
    W_trans, W_out = params
    stacked = jnp.concatenate([state, x_t])
    output = jnp.tanh(jnp.dot(W_out, stacked))
    next_state = jnp.tanh(jnp.dot(W_trans, stacked))
    return next_state, output

  def rnn(params, inputs):
    _, outputs = lax.scan(partial(step, params), jnp.zeros(n_hid), inputs)
    return outputs

  @jax.jit
  def loss(params, inputs, targets):
    return jnp.sum((rnn(params, inputs) - targets)**2)

  all_args = (params, inputs, targets)

  # evaluation doesn't crash
  loss(*all_args)

  # jvp evaluation doesn't crash
  jax.jvp(lambda params: loss(params, inputs, targets), (params,), (params,))

  # jvp numerical check passes
  jtu.check_grads(loss, all_args, order=2, modes=["fwd"],
                  rtol={np.float32: 2e-2, np.float64: 1e-6})

  # linearize agrees with jvp
  _, expected = jax.jvp(loss, all_args, all_args)
  _, linfun = jax.linearize(loss, *all_args)
  self.assertAllClose(linfun(*all_args), expected, check_dtypes=False)

  # gradient evaluation doesn't crash
  jax.grad(loss)(*all_args)

  # gradient check passes
  jtu.check_grads(loss, all_args, order=2, rtol=2e-2)

  # we can vmap to batch things
  batch_size = 7
  batched_inputs = r.randn(batch_size, length, n_in).astype(jnp.float_)
  batched_targets = r.randn(batch_size, length, n_out).astype(jnp.float_)
  batched_loss = jax.vmap(lambda x, y: loss(params, x, y))
  losses = batched_loss(batched_inputs, batched_targets)
  expected = np.stack([loss(params, x, y)
                       for x, y in zip(batched_inputs, batched_targets)])
  self.assertAllClose(losses, expected, check_dtypes=False, rtol=1e-2)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_impl={scan_name}", "scan": scan_impl}
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def testIssue711(self, scan):
  # Tests reverse-mode differentiation through a scan for which the scanned
  # function also involves reverse-mode differentiation.
  # See https://github.com/jax-ml/jax/issues/711
  def harmonic_bond(conf, params):
    return jnp.sum(conf * params)

  def minimize_structure(test_params):
    energy_fn = partial(harmonic_bond, params=test_params)

    def apply_carry(carry, unused_x):
      step_count, x = carry
      x = x - 0.1 * jax.grad(energy_fn)(x)
      return (step_count + 1, x), unused_x

    x0 = jnp.array([1., 2., 3.])
    (_, x_final), _ = scan(apply_carry, (0, x0), jnp.zeros((75, 0)))
    return x_final

  minimize_structure(0.5)  # doesn't crash

  def loss(test_params):
    return jnp.sum(jnp.sin(1.0 - minimize_structure(test_params)))

  jax.grad(loss)(0.25)  # doesn't crash
|
||
|
||
def testIssue744(self):
|
||
Point = collections.namedtuple('Point', ['x', 'y'])
|
||
p0 = Point(x=jnp.array(1), y=jnp.array(2))
|
||
|
||
def plus_one(p, iter_idx):
|
||
return Point(p.x+1, p.y+1), iter_idx
|
||
|
||
self.assertRaisesRegex(
|
||
ValueError,
|
||
'scan got value with no leading axis to scan over.*',
|
||
lambda: lax.scan(plus_one, p0, list(range(5))))
|
||
|
||
def testScanBodyOutputError(self):
|
||
with self.assertRaisesRegex(
|
||
TypeError,
|
||
re.escape("scan body output must be a pair, got ShapedArray(float32[]).")):
|
||
lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.))
|
||
|
||
def testScanMetadataError(self):
|
||
# Regression test for https://github.com/jax-ml/jax/issues/25507
|
||
def f(loop_i, x):
|
||
return {'T': jnp.array([0.5])}
|
||
|
||
init_val = {'t': jnp.array([1.0])}
|
||
msg = r".*with pytree metadata \('t',\).*with pytree metadata \('T',\)"
|
||
with self.assertRaisesRegex(TypeError, msg):
|
||
jax.lax.fori_loop(0, 1, f, init_val)
|
||
|
||
def testScanBodyCarryPytreeMismatchErrors(self):
  # Each case returns a carry whose pytree structure differs from the input
  # carry, and checks the start of the resulting TypeError message.

  # A 3-tuple is returned for a 2-tuple carry.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("function carry input and carry output must have "
                "the same pytree structure, but they differ:\n\n"
                "The input carry c is a tuple of length 2")):
    lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), jnp.arange(5.))

  # A 1-tuple is returned for a 2-tuple carry.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("function carry input and carry output must have the "
                "same pytree structure, but they differ:\n\n"
                "The input carry x is a tuple of length 2")):
    lax.scan(lambda x, _: ((x[0].astype('float32'),), None),
             (jnp.array(0, 'int32'),) * 2, None, length=1)

  # A list is returned for a tuple carry (same length, different container).
  with self.assertRaisesRegex(
      TypeError,
      re.escape("function carry input and carry output must have the "
                "same pytree structure, but they differ:\n\n"
                "The input carry x is a <class 'tuple'> but the corres")):
    jax.lax.scan(lambda x, _: ([x[0].astype('float32'),] * 2, None),
                 (jnp.array(0, 'int32'),) * 2, None, length=1)

  # A two-key dict is returned for a one-key dict carry.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("function carry input and carry output must have the "
                "same pytree structure, but they differ:\n\n"
                "The input carry x is a <class 'dict'> with 1 child but")):
    jax.lax.scan(lambda x, _: ({'a': x['a'], 'b': x['a']}, None),
                 {'a': jnp.array(0, 'int32')}, None, length=1)

  # Nested mismatch in both tuple components; the regex checks the first
  # listed component difference.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("function carry input and carry output must have the "
                "same pytree structure, but they differ:\n\n"
                " * the input carry component x[0] is a <class 'dict'> with "
                "1 child but the corresponding component of the carry "
                "output is a <class 'dict'> with 2 children")):
    jax.lax.scan(lambda x, _: (({'a': x[0]['a'], 'b': x[0]['a']},) * 2, None),
                 ({'a': jnp.array(0, 'int32')},) * 2, None, length=1)
|
||
|
||
def testScanBodyCarryTypeMismatchErrors(self):
  # Each case keeps the carry's pytree structure but changes leaf types, and
  # checks the resulting TypeError message.

  # Scalar carry: int32 in, float32 out.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("function carry input and carry output must have equal "
                "types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
                "The input carry x has type int32[] but the corresponding "
                "output carry component has type float32[], so the dtypes do "
                "not match"
                )):
    jax.lax.scan(lambda x, _: (x.astype('float32'), None),
                 jnp.array(0, 'int32'), None, length=1)

  # Tuple carry: only component x[1] changes dtype.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("function carry input and carry output must have equal "
                "types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
                "The input carry component x[1] has type int32[] but the "
                "corresponding output carry component has type float32[], "
                "so the dtypes do not match"
                )):
    jax.lax.scan(lambda x, _: ((x[0], x[1].astype('float32')), None),
                 (jnp.array(0, 'int32'),) * 2, None, length=1)

  # Two of three components differ (x[0] in dtype; x[1] in dtype and shape);
  # the message lists each mismatch.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("function carry input and carry output must have equal "
                "types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
                " * the input carry component x[0] has type int32[] but the "
                "corresponding output carry component has type float32[], "
                "so the dtypes do not match;\n"
                " * the input carry component x[1] has type int32[] but the "
                "corresponding output carry component has type float32[1,1], "
                "so the dtypes do not match and also the shapes do not match."
                )):
    jax.lax.scan(lambda x, _: ((x[0].astype('float32'),
                                x[1].astype('float32').reshape(1, 1),
                                x[2]), None),
                 (jnp.array(0, 'int32'),) * 3, None, length=1)
|
||
|
||
@jax.enable_checks(False)
def testScanInvalidUnrollRaises(self):
  # Negative and zero unroll factors must both be rejected.
  body = lambda x, _: (x, x)
  for bad_unroll in (-1, 0):
    with self.assertRaisesRegex(ValueError, "`unroll` must be"):
      jax.lax.scan(body, 0, jnp.arange(5), unroll=bad_unroll)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{scan_name}",
     "scan": scan_impl}
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def testScanHigherOrderDifferentiation(self, scan):
  d = 0.75

  def body(carry, a):
    out = jnp.sin(carry * jnp.sum(jnp.cos(d * a)))
    new_carry = 0.9 * jnp.cos(d * jnp.sum(jnp.sin(carry * a)))
    return new_carry, out

  xs = jnp.arange(6.).reshape((3, 2))
  init = jnp.array(1, dtype=xs.dtype)

  # Second-order reverse-mode gradient check through the scan.
  jtu.check_grads(lambda c, as_: scan(body, c, as_), (init, xs),
                  modes=["rev"], order=2, rtol={np.float32: 6e-3})
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{jit_scan=}_{jit_f=}_{in_axes=}_impl={scan_name}",
     "jit_scan": jit_scan, "jit_f": jit_f, "in_axes": in_axes,
     "scan": scan_impl}
    for jit_scan in [False, True]
    for jit_f in [False, True]
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR
    for in_axes in itertools.product([None, 0, 1], [None, 0, 1, 2])
    if in_axes != (None, None))
def testScanVmap(self, jit_scan, jit_f, in_axes, scan):
  rng = self.rng()
  d = rng.randn(2)

  def body(carry, a):
    assert a.shape == (3,)
    assert carry.shape == (4,)
    out = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(carry)) + jnp.sum(d))
    new_carry = jnp.sin(carry * out)
    assert out.shape == ()
    return new_carry, out

  f = jax.jit(body) if jit_f else body
  scan_fn = jax.jit(scan, static_argnums=(0,)) if jit_scan else scan

  def with_batch_dim(shape, bdim):
    # Insert a batch dimension of size 7 at position `bdim`, if requested.
    shape = list(shape)
    if bdim is not None:
      shape.insert(bdim, 7)
    return shape

  c_bdim, as_bdim = in_axes
  # Draw xs before the carry to keep the RNG sequence unchanged.
  as_ = rng.randn(*with_batch_dim([5, 3], as_bdim))
  c = rng.randn(*with_batch_dim([4], c_bdim))

  ans = jax.vmap(lambda c, as_: scan_fn(f, c, as_), in_axes)(c, as_)
  expected = jax.vmap(lambda c, as_: scan_reference(f, c, as_), in_axes)(c, as_)
  self.assertAllClose(ans, expected, check_dtypes=False,
                      rtol=1e-5, atol=1e-5)
|
||
|
||
def testScanVmapTuples(self):
  def body(carry, a):
    a1, a2 = a
    c1, c2 = carry
    out = jnp.sum(jnp.cos(a1)) * jnp.sum(c2 * a2)
    new_carry = c1 * jnp.sin(jnp.sum(a1 * a2)), c2 * jnp.cos(jnp.sum(a1))
    return new_carry, out

  in_axes = (0, (1, 2))

  r = self.rng()
  as_ = (r.randn(3, 7), r.randn(3, 4, 7))
  c = (r.randn(7, 2), r.randn(7))

  # Reference: run the unbatched scan once per batch element, then stack.
  per_example = [
      lax.scan(body, (c[0][i], c[1][i]), (as_[0][:, i], as_[1][:, :, i]))
      for i in range(7)]
  carries, outs = unzip2(per_example)
  first, second = unzip2(carries)
  expected = ((jnp.stack(first), jnp.stack(second)), jnp.stack(outs))

  ans = jax.vmap(lambda c, as_: lax.scan(body, c, as_), in_axes)(c, as_)
  self.assertAllClose(ans, expected, check_dtypes=False)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_impl={scan_name}", "scan": scan_impl}
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def testScanVmapFixpoint(self, scan):
  def f(carry_init):
    # The carry is a 4-tuple whose last slot starts batched; each step shifts
    # every slot one position left and fills the end with 0.
    def shift_left(c, _):
      return (c[1], c[2], c[3], 0.), None
    return scan(shift_left, (0., 1., 2., carry_init), jnp.zeros(2))

  carry_init = jnp.array([3., 4., 5.])
  carry_out, _ = jax.vmap(f)(carry_init)
  zeros = jnp.array([0., 0., 0.])
  self.assertAllClose(carry_out[3], zeros, check_dtypes=False)
  self.assertAllClose(carry_out[2], zeros, check_dtypes=False)
  # After two shifts the batched initial value has reached slot 1.
  self.assertAllClose(carry_out[1], carry_init, check_dtypes=False)
  self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes=False)
|
||
|
||
def testIssue757(self):
  # code from https://github.com/jax-ml/jax/issues/757
  def fn(a):
    return jnp.cos(a)

  def loop(val):
    def step(x, i):
      (dfdx,) = jax.grad(fn, argnums=(0,))(x)
      return dfdx, i

    final_val, _ = lax.scan(step, val, jnp.arange(10))
    return final_val

  jax.jit(jax.jacfwd(loop, argnums=(0,)))(0.5)  # doesn't crash
|
||
|
||
def testIssue804(self):
  # https://github.com/jax-ml/jax/issues/804
  n_devices = jax.device_count()

  def body(carry, x):
    return carry + lax.psum(x, "i"), carry

  f = partial(lax.scan, body, 0.)
  jax.pmap(f, axis_name="i")(jnp.ones((n_devices, 4)))  # doesn't crash
|
||
|
||
def testMap(self):
  xs = jnp.arange(10)
  squares = xs ** 2
  self.assertAllClose(lax.map(lambda x: x ** 2, xs), squares)
|
||
|
||
def testMapEmpty(self):
  # https://github.com/jax-ml/jax/issues/2412
  result = lax.map(lambda x: x * x, jnp.array([]))
  self.assertAllClose(result, jnp.array([]))
|
||
|
||
@jtu.thread_unsafe_test()  # Cache eviction means we might retrace
def testCaching(self):
  # Checks that calling while_loop again with the same cond/body function
  # objects does not re-execute the Python functions.
  # Both functions assert on the closure variable `python_should_be_executing`,
  # which is read at call time, so flipping it to False before the second
  # call makes any retrace fail.
  def cond(x):
    assert python_should_be_executing
    return x < 5

  def body(x):
    assert python_should_be_executing
    return x + 2

  python_should_be_executing = True
  lax.while_loop(cond, body, 0)

  # Second call with the same function objects: should not run Python.
  python_should_be_executing = False
  lax.while_loop(cond, body, 0)
|
||
|
||
# This second caching test shows a different kind of caching that we haven't
# implemented (but could!), namely that Python functions that are distinct
# objects but are equivalent functions trigger cache hits. This kind of
# caching could be salient when using lambda functions with control flow:
#
#   lax.while_loop(lambda x: x < 5, lambda x: x + 2, 0)
#   lax.while_loop(lambda x: x < 5, lambda x: x + 2, 0)
#
# To get a cache hit on the second line we'd need to form jaxprs and
# compare them for equality (including comparing literals by identity). We
# could implement that by adding a __hash__/__eq__ to core.Jaxpr and
# core.ClosedJaxpr (see #1221).
@unittest.skip("not implemented")
def testCaching2(self):
  # Like testCaching, but the second while_loop call uses freshly defined
  # (distinct but equivalent) cond/body functions. Because those would
  # currently be retraced, the asserts would fire, so the test is skipped.
  def cond(x):
    assert python_should_be_executing
    return x < 5

  def body(x):
    assert python_should_be_executing
    return x + 2

  python_should_be_executing = True
  lax.while_loop(cond, body, 0)

  # Redefine equivalent functions as new objects.
  def cond(x):
    assert python_should_be_executing
    return x < 5

  def body(x):
    assert python_should_be_executing
    return x + 2

  python_should_be_executing = False
  lax.while_loop(cond, body, 0)
|
||
|
||
def test_caches_depend_on_axis_env(self):
  # https://github.com/jax-ml/jax/issues/9187
  # psum(1, 'i') equals the vmap axis size, so a scan traced under one axis
  # size must not be reused from the cache under another.
  scanned_f = lambda _, __: (lax.psum(1, 'i'), None)
  f = lambda: lax.scan(scanned_f, 0, None, length=1)[0]
  for size in (2, 3):
    ans = jax.vmap(f, axis_name='i', axis_size=size, out_axes=None)()
    self.assertEqual(ans, size)
|
||
|
||
def testWhileCondConstant(self):
  # A constant-False predicate with an empty carry must not crash.
  result = lax.while_loop(lambda _: False, lambda _: (), ())
  self.assertEqual(result, ())
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{jit_loop=}_{jit_body=}_{jit_cond=}",
     "jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond}
    for jit_loop in [False, True]
    for jit_body in [False, True]
    for jit_cond in [False, True])
def testWhileJVP(self, jit_loop=True, jit_body=False, jit_cond=True):
  cond = lambda x: x[0, 2] <= 8
  body = lambda x: x * x
  cond = jax.jit(cond) if jit_cond else cond
  body = jax.jit(body) if jit_body else body

  loop = partial(lax.while_loop, cond, body)
  loop = jax.jit(loop) if jit_loop else loop
  loop_ref = partial(while_loop_reference, cond, body)

  x = jnp.arange(9.).reshape((3, 3))
  # Compare against the Python reference loop, then check forward-mode
  # derivatives up to second order.
  self.assertAllClose(jax.jvp(loop, (x,), (x,)),
                      jax.jvp(loop_ref, (x,), (x,)), check_dtypes=False)
  jtu.check_grads(loop, (x,), order=2, modes=["fwd"])
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{jit_loop=}_{jit_body=}_{jit_cond=}_impl={while_name}",
     "jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond,
     "while_loop": while_impl}
    for jit_loop in [False, True]
    for jit_body in [False, True]
    for jit_cond in [False, True]
    for while_impl, while_name in WHILE_LOOP_IMPLS)
def testWhileLinearize(self, while_loop, jit_loop=True, jit_body=False,
                       jit_cond=True):
  cond = lambda x: x[0, 2] <= 8
  body = lambda x: x * x
  cond = jax.jit(cond) if jit_cond else cond
  body = jax.jit(body) if jit_body else body

  loop = partial(while_loop, cond, body)
  loop = jax.jit(loop) if jit_loop else loop
  loop_ref = partial(while_loop_reference, cond, body)

  x = jnp.arange(9.).reshape((3, 3))
  # linearize + apply must agree with the reference loop's JVP.
  primal_out, lin_fn = jax.linearize(loop, x)
  tangent_out = lin_fn(x)
  ref_primal, ref_tangent = jax.jvp(loop_ref, (x,), (x,))
  self.assertAllClose(primal_out, ref_primal, check_dtypes=False)
  self.assertAllClose(tangent_out, ref_tangent, check_dtypes=False)
|
||
|
||
def testWhileJVPViaForiLoop(self):
  def check(body, primal_out, tangent_out):
    # Runs 3 iterations of `body` from x=2. with tangent 1.
    f = lambda x: lax.fori_loop(0, 3, body, x)
    self.assertAllClose(f(2.), primal_out, check_dtypes=False)
    self.assertAllClose(jax.jvp(f, (2.,), (1.,)), (primal_out, tangent_out),
                        check_dtypes=False)
    jtu.check_grads(f, (2.,), order=2, modes=["fwd"])

  check(lambda i, x: x * 2, 16., 8.)
  check(lambda i, x: x * (i + 1), 12., 6.)
|
||
|
||
def testWhileJVPWithGrowingNonzeroTangents(self):
  rng = self.rng()

  def cond(state):
    return state[0] < 2

  def body(state):
    # y and z start with zero tangents but become functions of x.
    i, x, _, _ = state
    x_sq = x * x
    return i + 1, x, x_sq, x_sq * x_sq

  y, z = rng.randn(2), rng.randn(2)

  def loop(loop_impl, x):
    _, x_out, _, _ = loop_impl(cond, body, (0, x, y, z))
    return x_out

  loop_lax = partial(loop, lax.while_loop)
  loop_ref = partial(loop, while_loop_reference)

  x = rng.randn(2)
  self.assertAllClose(jax.jvp(loop_lax, (x,), (x,)),
                      jax.jvp(loop_ref, (x,), (x,)), check_dtypes=False)
  jtu.check_grads(loop_lax, (x,), order=2, modes=["fwd"])
|
||
|
||
def testStaticForiGrad(self):
  # fori_loop bounds derived from the (non-jitted) differentiated argument;
  # both reverse-mode grad and linearize are expected to succeed.
  def func(x):
    return lax.fori_loop(x, x + 2., lambda i, c: c, x)

  jax.grad(func)(1.)  # doesn't crash
  jax.linearize(func, 1.)  # doesn't crash
|
||
|
||
@parameterized.named_parameters(
    dict(testcase_name=f"_{loop=}", loop=loop)
    for loop in ["while", "fori_inside_cond", "fori_inside_scan"])
def testWhileGradError(self, loop: str = "fori_inside_scan"):
  # Reverse-mode differentiation of these loops must raise the while_loop
  # reverse-mode error, while forward-mode linearization still works.
  def build_func():
    if loop == "while":
      return lambda x: lax.while_loop(lambda i: i < 5., lambda i: i + 1., x)
    if loop == "fori_inside_jit":
      # NOTE(review): "fori_inside_jit" is not in the parameter list above,
      # so this branch is currently never exercised.
      return jax.jit(lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x))
    if loop == "fori_inside_cond":
      # Uses the five-argument (pred, true_operand, true_fun, false_operand,
      # false_fun) form of lax.cond.
      return lambda x: lax.cond(
          True,
          x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x),
          1., lambda x: x)
    if loop == "fori_inside_scan":
      return lambda x: lax.scan(
          lambda c, x: (lax.fori_loop(x, x + 2., lambda i, c1: c1 * c, x), None),
          x, np.ones(2))[0]
    assert False

  func = build_func()
  with self.assertRaisesRegex(
      ValueError,
      "Reverse-mode differentiation does not work for lax.while_loop"):
    jax.grad(func)(1.)

  jax.linearize(func, 1.)  # Linearization works
|
||
|
||
@jax.legacy_prng_key('allow')
def testIssue1316(self):
  # Differentiating through a scan whose carry also holds a PRNG key.
  def body(carry, _):
    c, key = carry
    next_key, _ = random.split(key)
    return (c, next_key), ()

  key = random.PRNGKey(0)

  def first_carry(c):
    return lax.scan(body, (c, key), np.ones(3))[0][0]

  jax.grad(first_carry)(0.)  # doesn't crash
|
||
|
||
def testIssue1361(self):
  @jax.jit
  def jit_run_scan(x):
    def double_first(carry, _):
      val, _ = carry
      return (2 * val, 0.), None

    (out, _), _ = lax.scan(double_first, (x, 0.), jnp.arange(3))
    return out

  jax.grad(lambda x: jit_run_scan(x))(0.)  # doesn't crash
|
||
|
||
def testIssue810(self):
  # Regression test for the scan VJP needlessly stacking a loop-invariant
  # operand (A) across iterations.
  def loss(A):
    def step(x, i):
      return jnp.matmul(A, x), None
    init_x = jnp.zeros(A.shape[-1:])
    last_x, _ = lax.scan(step, init_x, jnp.arange(10))
    return jnp.sum(last_x)

  A = jnp.zeros((3, 3))
  # The second DUS was unnecessarily replicating A across time.
  # We check XLA because _scan_impl is "underneath" the jaxpr language.
  s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo')
  # Use the unittest assertion rather than a bare `assert`: a bare assert is
  # stripped under `python -O` and would not report the count on failure.
  self.assertLess(s.count("dynamic-update-slice("), 2)
|
||
|
||
def testScanLengthArg(self):
  # With xs=None, `length` alone sets the number of iterations.
  def arange(n):
    _, counts = lax.scan(lambda c, _: (c + 1, c), 0, None, length=n)
    return counts

  self.assertAllClose(arange(10), np.arange(10), check_dtypes=False)
|
||
|
||
@ignore_jit_of_pmap_warning()
def test_while_loop_of_pmap(self):
  # Avoid accuracy issue caused by too many devices.
  DEVICE_LIMITATION = 4
  count = min(jax.device_count(), DEVICE_LIMITATION)
  devices = jax.devices()[:count]

  # code from jsnoek@
  def pmap_body(i, x):
    result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), devices=devices,
                      axis_name='i')(x)
    return result + x
  ans = lax.fori_loop(0, 3, pmap_body, jnp.ones(count))

  # Same computation without pmap: the psum becomes a plain sum broadcast
  # back to every element.
  def reference_body(i, x):
    result = jnp.broadcast_to(jnp.sin(x).sum(), x.shape)
    return result + x
  expected = lax.fori_loop(0, 3, reference_body, jnp.ones(count))

  self.assertAllClose(ans, expected, check_dtypes=False)
|
||
|
||
@ignore_jit_of_pmap_warning()
def test_while_loop_of_pmap_error_message(self):
  # A pmap inside a fori_loop body over an input twice as wide as the device
  # count must fail with a replica-count error for the computation `scan`.
  def body(i, x):
    result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), axis_name='i')(x)
    return result + x

  too_big = 2 * jax.device_count()
  message = re.escape(
      "compiling computation `scan` that requires {} "
      "replicas, but only {} XLA devices are available."
      .format(too_big, jax.device_count()))

  with self.assertRaisesRegex(ValueError, message):
    lax.fori_loop(0, 3, body, jnp.ones(too_big))
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{scan_name}",
     "scan": scan_impl}
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def test_scan_reverse(self, scan):
  def cumsum(x, reverse):
    return scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]

  x = np.array([3, 1, 4, 1, 5, 9])
  expectations = ((False, np.cumsum(x)), (True, np.cumsum(x[::-1])[::-1]))

  for reverse, want in expectations:
    self.assertAllClose(want, cumsum(x, reverse), check_dtypes=False)

  # Repeat each check eagerly, in its own disable_jit context.
  for reverse, want in expectations:
    with jax.disable_jit():
      self.assertAllClose(want, cumsum(x, reverse), check_dtypes=False)
|
||
|
||
def test_scan_unroll(self):
  d = jnp.ones(2)
  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
    c = jnp.sin(c * b)
    assert b.shape == ()
    return c, b

  xs = jnp.ones((20, 3))
  c = jnp.ones(4)

  def make_scan(**kwargs):
    return lambda c, xs: lax.scan(f, c, xs, **kwargs)

  scan = make_scan()
  scan_unrolled = make_scan(unroll=2)
  scan_fully_unrolled = make_scan(unroll=True)

  # jaxprs should be the same size
  jaxpr_len = lambda fn: len(str(jax.make_jaxpr(fn)(c, xs)))
  self.assertEqual(jaxpr_len(scan), jaxpr_len(scan_unrolled))

  # but HLO should grow due to unrolling
  hlo_of = lambda fn: str(jax.jit(fn).lower(c, xs).as_text("hlo"))
  scan_hlo = hlo_of(scan)
  scan_unrolled_hlo = hlo_of(scan_unrolled)
  scan_fully_unrolled_hlo = hlo_of(scan_fully_unrolled)

  self.assertLess(len(scan_hlo), len(scan_unrolled_hlo))
  self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo))

  # and the lowering should contain a while loop, unless the scan is fully
  # unrolled
  for hlo in (scan_hlo, scan_unrolled_hlo):
    self.assertIn("while(", hlo)
  self.assertNotIn("while(", scan_fully_unrolled_hlo)
|
||
|
||
def test_scan_xs_none(self):
  # xs may be omitted entirely when `length` is given.
  length = 20
  final, _ = lax.scan(lambda h, _: (h + 1, None), 0, length=length)
  self.assertEqual(final, length)
|
||
|
||
def test_disable_jit_cond_with_vmap(self):
  # https://github.com/jax-ml/jax/issues/3093
  # Uses the five-argument (pred, true_operand, true_fun, false_operand,
  # false_fun) form of lax.cond.
  batched = jax.vmap(lambda t: lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1))

  with jax.disable_jit():
    _ = batched(jnp.array([1]))  # doesn't crash
|
||
|
||
def test_disable_jit_while_loop_with_vmap(self):
  # https://github.com/jax-ml/jax/issues/2823
  def trivial_while(y):
    return lax.while_loop(lambda x: x < 10.0, lambda x: x + 1.0, y)

  batched = jax.vmap(trivial_while)
  with jax.disable_jit():
    batched(jnp.array([3.0, 4.0]))  # doesn't crash
|
||
|
||
def test_vmaps_of_while_loop(self):
  # https://github.com/jax-ml/jax/issues/3164
  def f(x, n):
    return lax.fori_loop(0, n, lambda _, x: x + 1, x)

  # Inner vmap batches the trip count; outer vmap batches the initial value.
  over_n = jax.vmap(f, (None, 0))
  over_x_and_n = jax.vmap(over_n, (0, None))
  over_x_and_n(jnp.arange(3), jnp.arange(4))  # doesn't crash
|
||
|
||
def test_disable_jit_while_loop_with_mutation(self):
  # https://github.com/jax-ml/jax/issues/27019
  # The carry starts as the same mutable numpy array in both slots; the
  # result with jit disabled must match the jitted result.
  def body_fun(carry):
    x, y = carry
    x += 1  # in-place if x is mutable
    return x, y + x

  def cond_fun(carry):
    return carry[0] < 10

  def run():
    start = np.array(1.0)  # mutable value
    return jax.lax.while_loop(cond_fun, body_fun, (start, start))[1]

  results = {}
  for disabled in (False, True):
    with jax.disable_jit(disabled):
      results[disabled] = run()
  self.assertEqual(results[False], results[True])
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_{shape}_{axis=}",
     "shape": shape, "axis": axis}
    for shape in [
        [0], [1], [2], [3], [5], [10], [1000],
        [2, 3], [7, 5], [5, 6, 7]
    ]
    for axis in range(-len(shape), len(shape) - 1))
def testAssociativeScanUnstructured(self, shape, axis):
  # associative_scan with addition must match numpy's cumulative sum.
  data = np.arange(np.prod(shape)).reshape(shape) + 7
  self.assertAllClose(lax.associative_scan(operator.add, data, axis=axis),
                      np.cumsum(data, axis=axis), check_dtypes=False)
|
||
|
||
def testAssociativeScanUnstructured1000Reverse(self):
  # reverse=True must match a cumulative sum taken from the end.
  data = np.arange(1000) + 32
  suffix_sums = np.cumsum(data[::-1])[::-1]
  self.assertAllClose(lax.associative_scan(operator.add, data, reverse=True),
                      suffix_sums, check_dtypes=False)
|
||
|
||
def testAssociativeScanStructured3(self):
  # associative_scan over a namedtuple pytree with a field-wise sum.
  pair = collections.namedtuple('pair', ('first', 'second'))
  data = pair(first=np.array([0., 1., 2.]),
              second=np.array([0., 10., 20.]))

  def add_pairs(a, b):
    return pair(first=a.first + b.first,
                second=a.second + b.second)

  result = lax.associative_scan(add_pairs, elems=data)
  expected = pair(first=np.array([0., 1., 3.]),
                  second=np.array([0., 10., 30.]))
  self.assertAllClose(result.first, expected.first, check_dtypes=False)
  self.assertAllClose(result.second, expected.second, check_dtypes=False)
|
||
|
||
def testAssociativeScanOfBools(self):
  # Prefix XOR (running parity) over a boolean vector.
  bits = jnp.array([False, True, True, True, False, True])
  parity = lax.associative_scan(lax.bitwise_xor, bits)
  expected = np.array([False, True, False, True, True, False])
  self.assertArraysEqual(expected, parity)
|
||
|
||
@parameterized.named_parameters({"testcase_name": f"_{shape}", "shape": shape}
                                for shape in [2, 43, 100])
def testAssociativeScanSolvingRegressionTest(self, shape):
  # This test checks that the batching rule doesn't raise for a batch
  # sensitive function (solve).
  ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)
  vs = np.ones((shape, 2))

  def combine(a, b):
    (m1, v1), (m2, v2) = a, b
    return m1 + m2, jsp.linalg.solve(m1, v2) + jsp.linalg.solve(m2, v1)

  _ = lax.associative_scan(jax.vmap(combine), elems=(ms, vs))
|
||
|
||
def test_scan_typecheck_param(self):
  d = jnp.ones(2)
  def f(c, a):
    b = jnp.cos(jnp.sum(a) + jnp.sum(c) + jnp.sum(d))
    c = jnp.sin(c * b)
    return c, b

  xs = jnp.ones((5, 3))
  c = jnp.ones(4)
  scan_fun = lambda c, xs: lax.scan(f, c, xs)

  def new_jaxpr():
    # Wrap in a fresh partial each call so every case traces a new jaxpr
    # (the cases below mutate eqn params in place).
    jaxpr = jax.make_jaxpr(partial(scan_fun))(c, xs).jaxpr
    scan = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'scan')
    return jaxpr, scan

  # (param name, invalid value, expected message)
  cases = [
      ('reverse', 4,
       'invalid scan param reverse of type int, bool required: 4'),
      ('num_consts', -3,
       'invalid scan param num_consts of type int, '
       'non-negative int required: -3'),
  ]
  for param, bad_value, message in cases:
    jaxpr, eqn = new_jaxpr()
    eqn.params[param] = bad_value
    with self.assertRaisesRegex(core.JaxprTypeError, re.escape(message)):
      core.check_jaxpr(jaxpr)
|
||
|
||
def test_cond_typecheck_param(self):
  def new_jaxpr():
    jaxpr = jax.make_jaxpr(
        lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
    cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
    return jaxpr, cond

  jaxpr, eqn = new_jaxpr()
  # Replace the branch jaxprs with ints, which check_jaxpr must reject.
  eqn.params['branches'] = (4, 2)
  with self.assertRaisesRegex(
      core.JaxprTypeError,
      re.escape('invalid cond param branches of type tuple, '
                'tuple of ClosedJaxpr required: (4, 2)')):
    core.check_jaxpr(jaxpr)
|
||
|
||
def test_cond_transformation_rule_with_consts(self):
  # https://github.com/jax-ml/jax/pull/9731
  # A custom_jvp whose JVP rule introduces a numpy constant, used as a cond
  # branch and then differentiated with jvp.
  @jax.custom_jvp
  def f(x):
    return x

  @f.defjvp
  def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    scale = np.arange(3, dtype=x.dtype)
    return x * scale, x_dot * scale

  g = lambda x: jax.lax.cond(True, f, lambda x: x, x)
  x = np.arange(3, dtype='float32')
  jax.jvp(g, (x,), (x,))  # doesn't crash
|
||
|
||
@jtu.thread_unsafe_test()
def test_cond_excessive_compilation(self):
  # Regression test for https://github.com/jax-ml/jax/issues/14058
  def f(x):
    return x + 1

  def g(x):
    return x + 2

  with jtu.count_jit_and_pmap_lowerings() as count:
    for i in range(10):
      lax.cond(i, f, g, i)
  # Should observe a maximum of 4 compiles: convert_element_type, f, g, cond
  # In #14058, this was observed to be 31 compiles.
  self.assertLess(count(), 5)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
    for dtype in jtu.dtypes.all_integer)
def test_scan_init_weak_type(self, dtype):
  def accumulate(carry, x):
    return carry + x, x

  xs = jnp.ones(5, dtype=dtype)
  # The initial carry 0 is a Python scalar, hence weakly-typed.
  carry, ys = lax.scan(accumulate, 0, xs)
  self.assertEqual(carry, xs.sum(dtype=carry.dtype))
  self.assertArraysEqual(ys, xs)
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
    for dtype in jtu.dtypes.all_integer)
def test_while_loop_init_weak_type(self, dtype):
  # This tests whether lax.while_loop can properly handle weakly-typed
  # initial values.
  increment = jnp.array(1, dtype=dtype)
  # The initial value 0 is a Python scalar, hence weakly-typed.
  result = lax.while_loop(lambda val: val < 2,
                          lambda val: val + increment,
                          0)
  self.assertArraysEqual(result, jnp.full_like(increment, 2))
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"{suffix}", "remat": remat}
    for suffix, remat in [
        ('', None),
        ('new_remat', new_checkpoint),
    ])
def test_scan_vjp_forwards_extensive_residuals(self, remat):
  # https://github.com/jax-ml/jax/issues/4510
  # Checks that the scan VJP keeps the input `x` itself as a residual
  # (asserted by object identity below) rather than a copy of it.
  def cumprod(x):
    s = jnp.ones((2, 32), jnp.float32)
    return lax.scan(lambda s, x: (x*s, s), s, x)
  if remat is not None:
    cumprod = remat(cumprod)

  rng = self.rng()
  x = jnp.asarray(rng.randn(32, 2, 32).astype('float32'))
  _, vjp_fun = jax.vjp(cumprod, x)

  # Need to spelunk into vjp_fun. This is fragile, and if it causes problems
  # just skip this test and make an issue for mattjj.
  # The last element of vjp_fun.args[0].args[0] is taken as the extensive
  # residual.
  *_, ext_res = vjp_fun.args[0].args[0]
  self.assertIs(ext_res, x)

  if remat is not None:
    # TODO(mattjj): make the numpy.ndarray test pass w/ remat
    raise unittest.SkipTest("new-remat-of-scan doesn't convert numpy.ndarray")

  # With a numpy input, the forwarded residual should be a jax.Array.
  x = rng.randn(32, 2, 32).astype('float32')  # numpy.ndarray, not Array
  _, vjp_fun = jax.vjp(cumprod, x)
  *_, ext_res = vjp_fun.args[0].args[0]
  self.assertIsInstance(ext_res, jax.Array)
|
||
|
||
def test_scan_vmap_collectives(self):
  # A psum over the vmap axis name, used inside a scan body.
  def scan_f(state, x):
    return state, lax.psum(state, 'i') * x

  def scan(state, xs):
    return lax.scan(scan_f, state, xs)

  scan_v = jax.vmap(scan, in_axes=0, out_axes=0, axis_name='i')
  got = scan_v(jnp.ones([1]), jnp.arange(5.).reshape((1, 5)))
  want = (jnp.array([1.]), jnp.array([[0., 1., 2., 3., 4.]]))
  self.assertAllClose(got, want, check_dtypes=False)
|
||
|
||
def test_xla_cpu_gpu_loop_cond_bug(self):
  # https://github.com/jax-ml/jax/issues/5900
  def deriv(f):
    return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0)

  def _while_loop(cond_fun, body_fun, init_val, max_iter):
    # Emulates a bounded loop as a scan of `max_iter` steps. Each step applies
    # body_fun under a lax.cond gated by a flag that starts as
    # cond_fun(init_val); note the flag is never re-evaluated, because a
    # taken step always sets it back to True.
    def _iter(val):
      return body_fun(val), True

    def _fun(tup, unused):
      val, active = tup
      return jax.lax.cond(active, _iter, lambda x: (x, False), val), unused

    init = (init_val, cond_fun(init_val))
    (final_val, _), _ = jax.lax.scan(_fun, init, None, length=max_iter)
    return final_val

  def my_pow(x, y):
    return _while_loop(lambda val: True, lambda val: val * x, 1.0, y)

  self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False)
|
||
|
||
|
||
def test_while_loop_fixed_point_with_batched_pred_and_consts(self):
  # The loop counter is batched (so the predicate is batched) and the body
  # closes over a batched constant.
  def f(i, x):
    def cond(carry):
      return carry[0] < 5

    def body(carry):
      i, z = carry
      # Close over const with batch dim = 1
      return i + 1, z + x

    return lax.while_loop(cond, body, (i, jnp.ones(3)))[1]

  jax.vmap(f, in_axes=(0, 1))(jnp.arange(4), jnp.ones((3, 4)))  # doesn't crash
|
||
|
||
def test_cond_ad_batched_unit(self):
  # see issue #9985
  def cond_id(x):
    return lax.cond(x < 0., lambda x: x, lambda x: x, x)

  nested = lambda x: cond_id(cond_id(x))
  jax.vmap(jax.jacrev(nested))(jnp.ones(1))  # doesn't crash
|
||
|
||
@parameterized.named_parameters(
    {"testcase_name": f"impl={scan_name}", "scan": scan_impl}
    for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def test_scan_hoisting_consts(self, scan):
  A = jnp.arange(4.).reshape(2, 2)
  B = jnp.arange(4.).reshape(2, 2) + 1.

  def f(x):
    # The body closes over A, B and a derived constant sin(B).
    def body(c, _):
      c1, c2, c3 = c
      return (jnp.dot(A, c1), jnp.dot(B, c2), jnp.dot(jnp.sin(B), c3)), None

    init_carry = tuple(x * jnp.ones(2) for _ in range(3))
    (c1, c2, c3), _ = scan(body, init_carry, None, length=3)
    return jnp.sum(c1) + jnp.sum(c2) + jnp.sum(c3)

  jax.grad(f)(1.)  # doesn't crash
|
||
|
||
def test_custom_jvp_tangent_cond_transpose(self):
  # https://github.com/jax-ml/jax/issues/14026
  # Gradients through a custom_jvp whose tangent is computed with
  # lax.switch must agree with an arithmetic masking formulation.
  def mask_fun(arr, choice):
    return (1 - choice) * arr.sum() + choice * (1 - arr.sum())

  def switch_fun(arr, choice):
    idx = jnp.floor(choice).astype(jnp.int32)
    return jax.lax.switch(idx, [lambda x: x.sum(), lambda x: 1 - x.sum()], arr)

  test_arr = jnp.arange(3.)
  test_val = 0.

  expected1 = jax.grad(mask_fun)(test_arr, test_val)
  expected2 = jax.grad(switch_fun)(test_arr, test_val)

  def make_custom(tangent_fun):
    # custom_jvp around switch_fun whose tangent is tangent_fun(arr_dot, choice).
    def jvp_rule(primals, tangents):
      arr, choice = primals
      arr_dot, _ = tangents
      return switch_fun(arr, choice), tangent_fun(arr_dot, choice)
    wrapped = jax.custom_jvp(switch_fun)
    wrapped.defjvp(jvp_rule)
    return wrapped

  # "Good" tangent uses mask_fun; "bad" tangent goes through lax.switch.
  expected3 = jax.grad(make_custom(mask_fun))(test_arr, test_val)
  actual = jax.grad(make_custom(switch_fun))(test_arr, test_val)

  self.assertAllClose(expected1, expected2)
  self.assertAllClose(expected2, expected3)
  self.assertAllClose(expected3, actual)
|
||
|
||
def test_platform_dependent(self):
  # platform_dependent picks the `cpu` branch on CPU and `default` elsewhere.
  def pick_by_platform(operand):
    return lax.platform_dependent(operand, cpu=jnp.sin, default=jnp.cos)

  inputs = np.arange(3, dtype=np.float32)
  got = pick_by_platform(inputs)
  if jtu.device_under_test() == "cpu":
    want = np.sin(inputs)
  else:
    want = np.cos(inputs)
  self.assertAllClose(got, want)
|
||
|
||
def test_platform_dependent_no_args(self):
  # Branches may take no operands and close over their inputs instead.
  def pick_by_platform(operand):
    on_cpu = lambda: jnp.sin(operand)
    elsewhere = lambda: jnp.cos(operand)
    return lax.platform_dependent(cpu=on_cpu, default=elsewhere)

  inputs = np.arange(3, dtype=np.float32)
  got = pick_by_platform(inputs)
  if jtu.device_under_test() == "cpu":
    want = np.sin(inputs)
  else:
    want = np.cos(inputs)
  self.assertAllClose(got, want)
|
||
|
||
def test_platform_dependent_lowering(self):
  def pick_by_platform(operand):
    return lax.platform_dependent(operand, cpu=jnp.sin, default=jnp.cos)

  lowered = jax.jit(pick_by_platform).lower(np.arange(3, dtype=np.float32))

  # StableHLO keeps every branch, wrapped in a case op.
  stablehlo_text = lowered.as_text()
  for needle in ("stablehlo.case", "stablehlo.sine", "stablehlo.cosine"):
    self.assertIn(needle, stablehlo_text)

  # The HLO has been canonicalized and contains only the branch we need.
  hlo_text = lowered.as_text("hlo")
  if jtu.device_under_test() != "cpu":
    self.assertNotIn(" sine", hlo_text)
    self.assertIn(" cosine", hlo_text)
  else:
    self.assertIn(" sine", hlo_text)
    self.assertNotIn(" cosine", hlo_text)
|
||
|
||
def test_platform_dependent_with_non_existent_custom_call(self):
  """A primitive used only in non-CPU branches appears in the lowered IR but
  is never executed on CPU, under eager, jit, vmap, jvp and grad."""
  if not jtu.test_device_matches(["cpu"]):
    self.skipTest("Only for CPU")

  def f(x):
    # One use with the bad custom call on a different platform branch
    x1 = lax.platform_dependent(x,
                                cpu=jnp.sin,
                                other=prim_non_existent_custom_call.bind)
    # and with the bad custom call in the default branch
    x2 = lax.platform_dependent(x,
                                cpu=jnp.sin,
                                default=prim_non_existent_custom_call.bind)
    # and one use where the current platform is the default
    x3 = lax.platform_dependent(x,
                                other=prim_non_existent_custom_call.bind,
                                default=jnp.sin)
    return x1 + x2 + x3

  x = np.arange(3, dtype=np.float32)
  hlo = str(jax.jit(f).lower(x).compiler_ir())
  # Escape the name so it is matched literally, not as a regex pattern.
  occurrences = re.findall(re.escape(prim_non_existent_custom_call.name), hlo)
  self.assertLen(occurrences, 3)

  res_eager = f(x)
  self.assertAllClose(res_eager, 3. * np.sin(x))
  res_jit = jax.jit(f)(x)
  self.assertAllClose(res_jit, 3. * np.sin(x))

  res_vmap = jax.vmap(f)(x)
  self.assertAllClose(res_vmap, 3. * np.sin(x))

  _, res_jvp = jax.jvp(f, (x,), (np.full(x.shape, .1, dtype=x.dtype),))
  self.assertAllClose(res_jvp, .3 * np.cos(x))

  res_grad = jax.grad(f)(1.)
  self.assertAllClose(res_grad, 3. * np.cos(1.))
|
||
|
||
def test_platform_dependent_multiple_identical_branches(self):
  """The same function given for several platforms is lowered only once."""
  x = np.arange(3, dtype=np.float32)
  def f(x):
    return lax.platform_dependent(
        x,
        cpu=jnp.sin,
        tpu=jnp.sin,
        default=lambda x: x)
  res = f(x)
  self.assertAllClose(
      res,
      np.sin(x) if jtu.device_under_test() in ["cpu", "tpu"] else x)
  # We only lower the common branches once
  stablehlo = jax.jit(f).lower(x).as_text()
  # Escape the dot so only the literal op name is matched.
  sines = re.findall(r"stablehlo\.sine", stablehlo)
  self.assertLen(sines, 1)
|
||
|
||
def test_platform_dependent_no_default(self):
  # With only a `tpu` branch and no default, the call is expected to raise
  # NotImplementedError on any device other than TPU.
  if jtu.device_under_test() == "tpu":
    expectation = contextlib.nullcontext()
  else:
    expectation = self.assertRaisesRegex(
        NotImplementedError, "translation rule .* not found for platform")
  with expectation:
    lax.platform_dependent(3., tpu=lambda x: x + 2.)
|
||
|
||
def test_platform_dependent_batched(self):
  on_cpu = jtu.device_under_test() == "cpu"

  def pick_by_platform(operand):
    return lax.platform_dependent(operand, cpu=jnp.sin, default=jnp.cos)

  batched = jax.vmap(pick_by_platform)
  xs = np.arange(3, dtype=np.float32)
  self.assertAllClose(batched(xs), np.sin(xs) if on_cpu else np.cos(xs))

  # Batching must not prevent folding away the other platform's branch.
  hlo_text = jax.jit(batched).lower(xs).as_text('hlo')
  self.assertEqual(on_cpu, " sine(" in hlo_text)
  self.assertEqual(not on_cpu, " cosine(" in hlo_text)
|
||
|
||
def test_platform_dependent_grad(self):
  # Differentiate sin(dot(x, x)), where dot(x, x) has two very different
  # platform-specific implementations (a dot and a scan) and therefore
  # different residuals. This lets us check that the gradient does not pull in
  # residuals or computations from the other platform's branch.
  x = np.arange(8, dtype=np.float32)

  def f_impl_dot(v):  # v: f32[8]
    return jnp.dot(v, v)

  def f_impl_scan(v):
    def accumulate(acc, v_i):
      return acc + v_i * v_i, None
    total, _ = lax.scan(accumulate, np.float32(0.), v)
    return total

  def f(v):
    inner = lax.platform_dependent(v, cpu=f_impl_dot, default=f_impl_scan)
    return jnp.sin(inner)

  reference_grad = jax.grad(lambda v: jnp.sin(f_impl_dot(v)))
  self.assertAllClose(jax.grad(f)(x), reference_grad(x))

  # Only the current platform's implementation may appear in the HLO.
  hlo_text = jax.jit(jax.grad(f)).lower(x).as_text('hlo')
  on_cpu = jtu.device_under_test() == "cpu"
  self.assertEqual(on_cpu, " dot(" in hlo_text)
  self.assertEqual(not on_cpu, " while(" in hlo_text)
|
||
|
||
def test_scan_lowering_doesnt_introduce_singleton(self):
  """The HLO for a scan over a (b, i, i) operand must not mention a
  (b, 1, i, i) shape, i.e. lowering should not add a singleton dimension."""
  b = 4
  i = 2

  def scan(y):
    def body(carry, x):
      return carry, jnp.dot(x, x)
    return jax.lax.scan(body, 1.0, y, unroll=False)

  fn = jax.jit(scan)

  init = np.arange(b * i * i, dtype=np.float32).reshape((b, i, i))
  hlo_text = fn.lower(init).as_text('hlo')
  # Derive the forbidden shape from b and i so it stays in sync with them.
  self.assertNotIn(f'{b},1,{i},{i}', hlo_text)
|
||
|
||
def test_scan_length_concrete_error(self):
  # Passing `length` as a jit argument makes it a tracer, which scan rejects.
  def scan_with_length(n, x):
    return jax.lax.scan(lambda c, z: (c, z), x, (), n)

  jitted = jax.jit(scan_with_length)
  expected = ("The `length` argument to `scan` expects a concrete `int` "
              "value.*")
  with self.assertRaisesRegex(core.ConcretizationTypeError, expected):
    jitted(3, 1.)
|
||
|
||
def test_scan_unroll_concrete_error(self):
  # Passing `unroll` as a jit argument makes it a tracer, which scan rejects
  # for both integer and boolean values.
  def scan_with_unroll(n, x):
    return jax.lax.scan(lambda c, z: (c, z), x, (), 10, unroll=n)

  jitted = jax.jit(scan_with_unroll)
  expected = ("The `unroll` argument to `scan` expects a concrete `int` or "
              "`bool` value.*")
  for unroll in (3, True):
    with self.assertRaisesRegex(core.ConcretizationTypeError, expected):
      jitted(unroll, 1.)
|
||
|
||
def test_cond_vmap_forwarding_doesnt_promote(self):
  """Under vmap, an unbatched operand that cond forwards unchanged comes back
  as the very same object, while the batched operand does not."""
  def f(x, y):
    x, y = jax.lax.cond(
        x < 3,
        lambda x, y: (x * 2, y),
        lambda x, y: (x * 3, y),
        x, y
    )
    return x, y

  x = jnp.arange(3)
  y = jnp.array(3.)

  x2, y2 = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(x, y)  # don't crash

  # TestCase assertions instead of bare `assert`, which `python -O` strips.
  self.assertIsNot(x, x2)
  self.assertIs(y, y2)
|
||
|
||
def test_cond_casting(self):
  # A Python float routed through cond comes back as a jax.Array that
  # compares equal to the input.
  value = 1.0

  def passthrough(v):
    return v

  result = lax.cond(True, passthrough, passthrough, value)
  self.assertEqual(result, value)
  self.assertIsInstance(result, jax.Array)
|
||
|
||
@jtu.thread_unsafe_test()  # live_arrays count isn't thread-safe
def test_cond_memory_leak(self):
  # https://github.com/jax-ml/jax/issues/12719
  # Jitting a cond whose branches close over a device array must not keep
  # that array alive after every reference to it is dropped.

  def leak():
    buf = jax.device_put(np.zeros((1024), dtype=np.float32) + 1)
    def g():
      return jax.lax.cond(
          True,
          lambda: buf[0],  # noqa: F821
          lambda: buf[1],  # noqa: F821
      )
    jg = jax.jit(g)
    _ = jg().block_until_ready()
    del g, jg, buf, _

  def count_live_arrays():
    return len(jax.live_arrays())

  baseline = count_live_arrays()
  # Check after each of several consecutive calls.
  for _ in range(3):
    leak()
    self.assertEqual(baseline, count_live_arrays())
|
||
|
||
|
||
if __name__ == '__main__':
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|