2022-09-22 12:26:48 -07:00
|
|
|
|
# Copyright 2018 The JAX Authors.
|
2019-04-10 13:00:31 -07:00
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
2019-05-21 07:08:13 -07:00
|
|
|
|
import collections
|
2023-10-27 03:04:39 +03:00
|
|
|
|
import contextlib
|
2019-04-10 13:00:31 -07:00
|
|
|
|
from functools import partial
|
2019-05-16 10:18:53 -07:00
|
|
|
|
import itertools
|
2020-05-21 08:35:17 -07:00
|
|
|
|
import operator
|
2019-09-25 16:04:15 -07:00
|
|
|
|
import re
|
2021-10-04 17:54:18 -07:00
|
|
|
|
import unittest
|
2019-04-10 13:00:31 -07:00
|
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
2019-05-10 10:18:37 -07:00
|
|
|
|
from absl.testing import parameterized
|
2019-04-10 13:00:31 -07:00
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
|
import numpy as np
|
2019-04-10 13:00:31 -07:00
|
|
|
|
|
2021-01-06 10:45:19 -08:00
|
|
|
|
import jax
|
2023-02-14 23:00:40 -08:00
|
|
|
|
from jax._src import core
|
2022-12-02 13:20:30 -08:00
|
|
|
|
from jax import dtypes
|
2019-04-10 13:00:31 -07:00
|
|
|
|
from jax import lax
|
2019-08-29 22:28:03 -04:00
|
|
|
|
from jax import random
|
2021-09-24 07:02:08 -07:00
|
|
|
|
from jax._src import test_util as jtu
|
xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.
The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.
Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.
Examples:
```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})
OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))
OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:11 -07:00
|
|
|
|
from jax import tree_util
|
2021-01-11 14:20:32 -08:00
|
|
|
|
from jax._src.util import unzip2
|
2022-04-30 21:50:18 -07:00
|
|
|
|
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
|
2020-05-05 14:59:16 -04:00
|
|
|
|
import jax.numpy as jnp # scan tests use numpy
|
2019-10-20 18:10:03 -07:00
|
|
|
|
import jax.scipy as jsp
|
2022-08-23 09:36:12 -07:00
|
|
|
|
from jax._src.lax import control_flow as lax_control_flow
|
2022-06-02 11:50:03 -07:00
|
|
|
|
from jax._src.lax.control_flow import for_loop
|
2024-06-20 20:16:01 +03:00
|
|
|
|
from jax._src.interpreters import mlir
|
2019-05-08 11:37:27 -07:00
|
|
|
|
|
2024-04-11 13:23:27 -07:00
|
|
|
|
jax.config.parse_flags_with_absl()
|
2019-10-29 11:04:55 -04:00
|
|
|
|
|
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
# Some tests are useful for testing both lax.cond and lax.switch. This function
|
|
|
|
|
# provides a lax.cond-compatible interface to a two-branch lax.switch. Several
|
|
|
|
|
# tests in this file are parameterized such that they either call into lax.cond
|
|
|
|
|
# or into this function.
|
|
|
|
|
def cond_via_switch(pred, true_fun, false_fun, op, *args):
  """A lax.cond-compatible interface to a two-branch lax.switch.

  Accepts either the single-operand form ``(pred, true_fun, false_fun, op)``
  or the legacy form ``(pred, true_op, true_fun, false_op, false_fun)``.
  """
  if args:  # legacy five-argument form (idiom fix: was `len(args) > 0`;
            # now consistent with cond_with_new_checkpoint below)
    assert len(args) == 1
    true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0]
    op = (false_op, true_op)
    false_fun = lambda op: _false_fun(op[0])
    true_fun = lambda op: _true_fun(op[1])
  # switch dispatches on an integer index: 0 -> false branch, 1 -> true branch.
  index = lax.convert_element_type(pred, np.int32)
  return lax.switch(index, [false_fun, true_fun], op)
|
|
|
|
|
|
2022-07-19 08:53:23 -07:00
|
|
|
|
def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args):
  """Like lax.cond, but runs the underlying two-branch switch under
  jax.checkpoint (new-style remat).

  Supports both the single-operand form and the legacy
  ``(pred, true_op, true_fun, false_op, false_fun)`` form.
  """
  if args:
    assert len(args) == 1
    branch_true_op, branch_true_fun = true_fun, false_fun
    branch_false_op, branch_false_fun = op, args[0]
    op = (branch_false_op, branch_true_op)
    false_fun = lambda pair: branch_false_fun(pair[0])
    true_fun = lambda pair: branch_true_fun(pair[1])
  index = lax.convert_element_type(pred, np.int32)
  dispatch = lambda index, op: lax.switch(index, [false_fun, true_fun], op)
  return new_checkpoint(dispatch)(index, op)
|
|
|
|
|
|
|
|
|
|
# (cond implementation, test-parameter name) pairs: each entry behaves like
# lax.cond, so parameterized tests can run against all equivalent
# implementations.
COND_IMPLS = [
    (lax.cond, 'cond'),
    (cond_via_switch, 'switch'),
    (cond_with_new_checkpoint, 'new_checkpoint'),
]
|
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
|
2022-04-30 21:50:18 -07:00
|
|
|
|
# We wanted to try all scan tests with the scan partial evaluation rule that
|
|
|
|
|
# happens under ad_checkpoint.checkpoint, so we make a scan wrapper which
|
|
|
|
|
# wraps a ad_checkpoint.checkpoint around the computation.
|
|
|
|
|
def scan_with_new_checkpoint(f, *args, **kwargs):
  """Run lax.scan under jax.checkpoint with a save-nothing remat policy."""
  scan_fn = partial(lax.scan, f, **kwargs)
  rematted = new_checkpoint(scan_fn,
                            policy=checkpoint_policies.nothing_saveable)
  return rematted(*args)
|
|
|
|
|
def scan_with_new_checkpoint2(f, *args, **kwargs):
  """Run lax.scan under jax.checkpoint with a save-everything remat policy."""
  scan_fn = partial(lax.scan, f, **kwargs)
  rematted = new_checkpoint(scan_fn,
                            policy=checkpoint_policies.everything_saveable)
  return rematted(*args)
|
|
|
|
|
|
2022-06-28 16:29:38 -07:00
|
|
|
|
def scan_with_for(f, *args, **kwargs):
  # Scan implemented via the for_loop primitive, so scan tests also
  # exercise the for_loop-based lowering path.
  return for_loop.scan(f, *args, **kwargs)
|
2022-04-30 21:50:18 -07:00
|
|
|
|
|
2022-09-02 10:11:58 -07:00
|
|
|
|
def scan_with_remat_for(f, *args, **kwargs):
  # Same as scan_with_for, but wrapped in jax.remat to exercise the
  # rematerialization path of the for_loop-based scan.
  return jax.remat(lambda *args: for_loop.scan(f, *args, **kwargs))(*args)
|
|
|
|
|
|
2022-06-28 16:29:38 -07:00
|
|
|
|
# (scan implementation, test-parameter name) pairs covering the different
# unrolling/remat/for_loop scan code paths; parameterized scan tests run
# against each of these equivalent implementations.
SCAN_IMPLS_WITH_FOR = [
    (lax.scan, 'unroll1'),
    (partial(lax.scan, unroll=2), 'unroll2'),
    (partial(lax.scan, _split_transpose=True), 'split_transpose'),
    (scan_with_new_checkpoint , 'new_checkpoint'),
    (scan_with_new_checkpoint2, 'new_checkpoint2'),
    (scan_with_for, 'for_loop'),
    (scan_with_remat_for, 'for_loop_remat'),
]
|
|
|
|
|
|
2022-07-28 18:04:49 -07:00
|
|
|
|
def while_loop_new_checkpoint(cond_fun, body_fun, init_val):
  # while_loop variant that runs under jax.checkpoint (new-style remat).
  return new_checkpoint(partial(lax.while_loop, cond_fun, body_fun))(init_val)
|
|
|
|
|
|
|
|
|
|
# (while_loop implementation, test-parameter name) pairs used to
# parameterize while_loop tests over equivalent implementations.
WHILE_LOOP_IMPLS = [
    (lax.while_loop, 'while_loop'),
    (while_loop_new_checkpoint, 'new_checkpoint'),
]
|
|
|
|
|
|
|
|
|
|
|
2020-01-15 18:06:31 -08:00
|
|
|
|
def while_loop_reference(cond, body, carry):
  """Pure-Python reference semantics for lax.while_loop."""
  state = carry
  while cond(state):
    state = body(state)
  return state
|
|
|
|
|
|
|
|
|
|
|
2019-05-08 11:37:27 -07:00
|
|
|
|
def scan_reference(f, init, xs):
  """Reference implementation of lax.scan via a Python loop.

  Returns (final_carry, stacked_ys), stacking each per-step output with a
  leading length-1 axis and concatenating along axis 0.
  """
  stacked = []
  carry = init
  for x in xs:
    carry, y = f(carry, x)
    stacked.append(lax.reshape(y, (1,) + np.shape(y)))
  ys = lax.concatenate(stacked, 0)
  return carry, ys
|
2019-04-10 13:00:31 -07:00
|
|
|
|
|
|
|
|
|
|
2021-10-13 10:56:21 -04:00
|
|
|
|
# Decorator that silences the "jit-of-pmap" performance warning in tests
# that intentionally nest pmap under jit.
ignore_jit_of_pmap_warning = partial(
    jtu.ignore_warning, message=".*jit-of-pmap.*")
|
|
|
|
|
|
2024-06-20 20:16:01 +03:00
|
|
|
|
# A JAX primitive whose lowering is a custom call to a non-existent function.
|
|
|
|
|
# A JAX primitive whose lowering is a custom call to a non-existent function.
prim_non_existent_custom_call = core.Primitive("__testing_non_existent_custom_call")
# Abstract eval: output aval is identical to the input aval (shape/dtype
# pass through unchanged).
prim_non_existent_custom_call.def_abstract_eval(lambda x_aval: x_aval)
mlir.register_lowering(
    prim_non_existent_custom_call,
    # Lower to a CustomCallOp whose target name matches no registered
    # callback, so executing/serializing it should surface an error.
    lambda ctx, x: mlir.hlo.CustomCallOp(
        [x.type], [x],
        call_target_name=mlir.ir.StringAttr.get("__testing_non_existent_custom_call")).results)
|
|
|
|
|
|
2020-04-09 00:07:30 -07:00
|
|
|
|
|
2019-04-10 13:00:31 -07:00
|
|
|
|
class LaxControlFlowTest(jtu.JaxTestCase):
|
|
|
|
|
|
2021-01-06 10:45:19 -08:00
|
|
|
|
  def setUp(self):
    """Clear module-level jaxpr caches before each test."""
    super().setUp()
    # NOTE(review): these caches persist across tests; presumably cleared so
    # each test observes fresh tracing behavior rather than cached jaxprs —
    # confirm against the cache definitions in lax_control_flow.
    lax_control_flow._initial_style_open_jaxpr.cache_clear()
    lax_control_flow._initial_style_jaxpr.cache_clear()
    lax_control_flow.common._pad_jaxpr_constvars.cache_clear()
|
2021-01-06 10:45:19 -08:00
|
|
|
|
|
2022-04-13 10:39:01 -07:00
|
|
|
|
def testCallableErrors(self):
|
|
|
|
|
not_callable = 42
|
|
|
|
|
with self.assertRaisesRegex(TypeError, "lax.fori_loop.*callable.*"):
|
|
|
|
|
lax.fori_loop(0, 1, not_callable, 0)
|
|
|
|
|
with self.assertRaisesRegex(TypeError, "lax.while_loop.*callable.*"):
|
|
|
|
|
lax.while_loop(not_callable, not_callable, 0)
|
|
|
|
|
with self.assertRaisesRegex(TypeError, "lax.switch:.*callable.*"):
|
|
|
|
|
lax.switch(0, [not_callable])
|
|
|
|
|
with self.assertRaisesRegex(TypeError, "lax.cond.*callable.*"):
|
|
|
|
|
lax.cond(0, not_callable, not_callable)
|
|
|
|
|
with self.assertRaisesRegex(TypeError, "lax.scan.*callable.*"):
|
|
|
|
|
lax.scan(not_callable, 0, 1)
|
|
|
|
|
with self.assertRaisesRegex(TypeError, "lax.associative_scan.*callable.*"):
|
|
|
|
|
lax.associative_scan(not_callable, 0)
|
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def testWhileWithTuple(self):
|
|
|
|
|
limit = 10
|
2019-04-10 13:00:31 -07:00
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def loop_cond(state):
|
|
|
|
|
pos, _ = state
|
|
|
|
|
return lax.lt(pos, limit)
|
2019-04-10 13:00:31 -07:00
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def loop_body(state):
|
|
|
|
|
pos, count = state
|
|
|
|
|
return (lax.add(pos, 1), lax.add(count, 1))
|
2019-04-10 13:00:31 -07:00
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def loop(init):
|
|
|
|
|
result = lax.while_loop(loop_cond, loop_body, (init, 0))
|
|
|
|
|
_, count = result
|
|
|
|
|
return count
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
cloop = jax.jit(loop)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
|
|
|
|
self.assertEqual(loop(2), limit - 2)
|
|
|
|
|
self.assertEqual(cloop(2), limit - 2)
|
|
|
|
|
self.assertEqual(cloop(2), limit - 2)
|
|
|
|
|
self.assertEqual(cloop(3), limit - 3)
|
|
|
|
|
|
2019-10-08 11:08:42 -07:00
|
|
|
|
def testWhileWithManyArgs(self):
|
|
|
|
|
nargs = 256
|
|
|
|
|
|
|
|
|
|
def loop_cond(state):
|
|
|
|
|
return lax.lt(state[0], 2)
|
|
|
|
|
|
|
|
|
|
def loop_body(state):
|
|
|
|
|
return tuple(lax.add(s, 1) for s in state)
|
|
|
|
|
|
|
|
|
|
_ = lax.while_loop(loop_cond, loop_body, (0,) * nargs)
|
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
  def testNestedWhile(self):
    """A while_loop whose body runs another while_loop."""

    def outer_loop(num): # pylint: disable=missing-docstring
      def cond_fun(state):
        num, i, _ = state
        return lax.lt(i, num)

      def body_fun(state):
        num, i, count = state
        # Each outer step i hands the running count to the inner loop.
        return (num, lax.add(i, 1), inner_loop(i, count))

      init_val = (num, 0, 0)
      _, i, count = lax.while_loop(cond_fun, body_fun, init_val)
      return (i, count)

    def inner_loop(i, count): # pylint: disable=missing-docstring
      def cond_fun(state):
        i, j, _ = state
        # le (not lt), so the inner loop runs i + 1 times (j = 0..i).
        return lax.le(j, i)

      def body_fun(state):
        i, j, count = state
        return (i, lax.add(j, 1), lax.add(count, 1))

      init_val = (i, 0, count)
      _, _, count = lax.while_loop(cond_fun, body_fun, init_val)
      return count

    cloop = jax.jit(outer_loop)

    # count = sum_{i<num}(i + 1) = num * (num + 1) / 2.
    self.assertEqual(outer_loop(3), (3, 6))
    self.assertEqual(cloop(3), (3, 6))
    self.assertEqual(cloop(3), (3, 6))
    self.assertEqual(cloop(2), (2, 3))
    self.assertEqual(cloop(4), (4, 10))
|
|
|
|
|
|
|
|
|
|
  def testWhileWithClosure(self):
    """while_loop whose cond/body close over outer variables; also checks
    that a jit-cached call does not re-execute the Python body."""

    def loop(init, local_limit, inc):

      def loop_cond(state):
        pos, _ = state
        return lax.lt(pos, local_limit)

      def loop_body(state):
        # Side effect marks that this Python function actually ran
        # (i.e. the loop body was traced, not served from a cache).
        effect[0] = True
        pos, count = state
        return (lax.add(pos, 1), lax.add(count, inc))

      result = lax.while_loop(loop_cond, loop_body, (init, 0))
      _, count = result
      return count

    cloop = jax.jit(loop)

    limit = 10
    effect = [False]
    self.assertEqual(loop(2, limit, 1), limit - 2)
    assert effect[0]
    effect[0] = False
    self.assertEqual(cloop(2, limit, 1), limit - 2)
    assert effect[0]
    effect[0] = False
    # Same-shaped arguments: the jit cache is hit, so the Python body
    # must not run again and `effect` stays False.
    self.assertEqual(cloop(2, limit, 1), limit - 2)
    self.assertEqual(cloop(3, limit, 1), limit - 3)
    assert not effect[0]
|
|
|
|
|
|
|
|
|
|
  def testWhileWithClosureJit(self):
    """Like testWhileWithClosure, but with a jit-compiled inner step inside
    the loop body."""

    def loop(init, local_limit, inc):

      def loop_cond(state):
        pos, _ = state
        return lax.lt(pos, local_limit)

      def loop_body(state):
        # Marks that the Python body ran (traced rather than cache-served).
        effect[0] = True
        pos, count = state
        # The step itself is jitted, exercising jit nested inside a
        # while_loop body.
        f = lambda pos, inc: (lax.add(pos, 1), lax.add(count, inc))
        return jax.jit(f)(pos, inc)

      result = lax.while_loop(loop_cond, loop_body, (init, 0))
      _, count = result
      return count

    cloop = jax.jit(loop)

    limit = 10
    effect = [False]
    self.assertEqual(loop(2, limit, 1), limit - 2)
    assert effect[0]
    effect[0] = False
    self.assertEqual(cloop(2, limit, 1), limit - 2)
    assert effect[0]
    effect[0] = False
    # Cache hits: the Python body must not re-run.
    self.assertEqual(cloop(2, limit, 1), limit - 2)
    self.assertEqual(cloop(3, limit, 1), limit - 3)
    assert not effect[0]
|
|
|
|
|
|
2019-11-14 12:26:41 +01:00
|
|
|
|
def testWhileTypeErrors(self):
|
|
|
|
|
"""Test typing error messages for while."""
|
2024-02-26 14:17:18 -08:00
|
|
|
|
tuple_treedef = jax.tree.structure((1., 1.))
|
|
|
|
|
leaf_treedef = jax.tree.structure(0.)
|
2024-08-03 21:49:58 +00:00
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
|
TypeError,
|
xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.
The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.
Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.
Examples:
```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})
OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))
OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:11 -07:00
|
|
|
|
re.escape(f"cond_fun must return a boolean scalar, but got pytree {tuple_treedef}.")):
|
2019-11-14 12:26:41 +01:00
|
|
|
|
lax.while_loop(lambda c: (1., 1.), lambda c: c, 0.)
|
2024-08-03 21:49:58 +00:00
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
|
TypeError,
|
2020-10-08 15:21:57 -07:00
|
|
|
|
re.escape("cond_fun must return a boolean scalar, but got output type(s) [ShapedArray(float32[])].")):
|
2020-07-30 12:59:36 -07:00
|
|
|
|
lax.while_loop(lambda c: np.float32(1.), lambda c: c, np.float32(0.))
|
2024-08-03 21:49:58 +00:00
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
|
TypeError,
|
|
|
|
|
re.escape("while_loop body function carry input and carry output must "
|
|
|
|
|
"have the same pytree structure, but they differ:\n\n"
|
|
|
|
|
"The input carry c is a")):
|
2019-11-14 12:26:41 +01:00
|
|
|
|
lax.while_loop(lambda c: True, lambda c: (1., 1.), 0.)
|
2024-08-03 21:49:58 +00:00
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
|
TypeError,
|
|
|
|
|
r"The input carry component c\[1\] has type float32\[\] but the "
|
|
|
|
|
r"corresponding output carry component has type bool\[\], so the "
|
|
|
|
|
"dtypes do not match."):
|
2021-11-19 13:31:29 -08:00
|
|
|
|
lax.while_loop(lambda c: True, lambda c: (True, True),
|
|
|
|
|
(np.bool_(True), np.float32(0.)))
|
2019-11-14 12:26:41 +01:00
|
|
|
|
|
2024-12-09 10:44:28 -08:00
|
|
|
|
  def testWhileLoopCustomPytreeDiffAuxData(self):
    """while_loop works with a custom pytree node whose keyed and unkeyed
    flatten paths return different aux data ('with_keys' vs 'without_keys')."""
    class Node:
      def __init__(self, x, y):
        self.x = x
        self.y = y
    tree_util.register_pytree_with_keys(
        Node,
        lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys
        lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved)
        lambda o: ((o.x, o.y), 'without_keys'), # flatten
    )
    # Only checks that tracing succeeds despite the differing aux data.
    lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.))
|
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
  def testNestedWhileWithDynamicUpdateSlice(self):
    """Nested while_loops that copy the lower triangle of a matrix, i.e.
    reproduce np.tril, via dynamic_update_slice."""
    num = 5

    def update_entry(arr, val, i, j):
      # Write the scalar `val` at position (i, j) of `arr`.
      val = lax.reshape(val, [1, 1])
      return lax.dynamic_update_slice(arr, val, (i, j))

    def outer_loop(arr): # pylint: disable=missing-docstring

      def cond_fun(state):
        i, num, _, _ = state
        return lax.lt(i, num)

      def body_fun(state):
        i, num, arr, out = state
        return (lax.add(i, 1), num, arr, inner_loop(i, arr, out))

      out = np.zeros(arr.shape, dtype=arr.dtype)
      init_val = (0, num, arr, out)
      _, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
      return out

    def inner_loop(i, arr, out): # pylint: disable=missing-docstring

      def cond_fun(state):
        i, j, _, _ = state
        # j runs 0..i inclusive: copies row i up to the diagonal.
        return lax.le(j, i)

      def body_fun(state):
        i, j, arr, out = state
        arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
        arr_i_j = lax.dynamic_index_in_dim(arr_i, j, 0, False)
        out = update_entry(out, arr_i_j, i, j)
        return (i, lax.add(j, 1), arr, out)

      init_val = (i, 0, arr, out)
      _, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
      return out

    cloop = jax.jit(outer_loop)
    arr = self.rng().randn(5, 5)
    self.assertAllClose(outer_loop(arr), np.tril(arr), check_dtypes=False)
    self.assertAllClose(cloop(arr), np.tril(arr), check_dtypes=False)
    self.assertAllClose(cloop(arr), np.tril(arr), check_dtypes=False)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
|
|
|
|
  def testLoopWithConjunctionCondition(self):
    """while_loop whose condition is a conjunction of two predicates."""
    def sum_first_n(arr, num): # pylint: disable=missing-docstring
      def cond_fun(state):
        arr, num, i, _ = state
        # Stop at min(num, len(arr)): the conjunction guards both bounds.
        return lax.bitwise_and(lax.lt(i, num), lax.lt(i, arr.shape[0]))

      def body_fun(state):
        arr, num, i, total = state
        arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
        return (arr, num, i + 1, total + arr_i)

      init_val = (arr, num, 0, 0.)
      _, _, _, total = lax.while_loop(cond_fun, body_fun, init_val)
      return total

    cfun = jax.jit(sum_first_n)
    x = self.rng().randn(10).astype(jnp.float_)

    # num = 15 exceeds len(x); the second conjunct keeps indexing in-bounds.
    for num in [0, 5, 10, 15]:
      self.assertAllClose(sum_first_n(x, num), np.sum(x[:num]),
                          check_dtypes=False)
      self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
      self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
2019-05-16 07:40:31 -07:00
|
|
|
|
def testWhileLoopBatched(self):
|
|
|
|
|
def fun(x):
|
|
|
|
|
return lax.while_loop(lambda x: x < 3, lambda x: x + 2, x)
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
ans = jax.vmap(fun)(np.array([0, 1, 2, 3]))
|
2020-05-05 14:59:16 -04:00
|
|
|
|
expected = np.array([4, 3, 4, 3])
|
2019-05-16 07:40:31 -07:00
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
fun = jax.jit(fun)
|
|
|
|
|
ans = jax.vmap(fun)(np.array([0, 1, 2, 3]))
|
2020-05-05 14:59:16 -04:00
|
|
|
|
expected = np.array([4, 3, 4, 3])
|
2019-05-16 07:40:31 -07:00
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
2021-03-04 04:13:04 -08:00
|
|
|
|
  def testWhileLoopAxisIndexBatched(self):
    """vmap of a while_loop whose condition reads lax.axis_index."""
    def fun(x):
      # Lane i steps by 2 until x >= i, so lane i yields the smallest
      # multiple of 2 that is >= i.
      return lax.while_loop(lambda x: x < lax.axis_index('i'), lambda x: x + 2, x)

    ans = jax.vmap(fun, axis_name='i')(np.array([0, 0, 0, 0], dtype='int32'))
    expected = np.array([0, 2, 2, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    fun = jax.jit(fun)
    ans = jax.vmap(fun, axis_name='i')(np.array([0, 0, 0, 0], dtype='int32'))
    expected = np.array([0, 2, 2, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # Unbatched carry (in_axes=None for x): the loop must still batch
    # because the condition depends on the batched axis index.
    ans = jax.vmap(lambda _, x: fun(x), axis_name='i', in_axes=(0, None))(
        np.array([0, 0, 0, 0]), 0)
    expected = np.array([0, 2, 2, 4], dtype='int32')
    self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
2021-10-11 10:33:21 +00:00
|
|
|
|
def testWhileLoopBatchedWithConstBody(self):
|
|
|
|
|
def f(x):
|
|
|
|
|
def body_fn(_): return jnp.asarray(0., dtype=jnp.float32)
|
|
|
|
|
def cond_fn(_): return jnp.logical_not(False) == False
|
|
|
|
|
return jax.lax.while_loop(cond_fn, body_fn, x)
|
|
|
|
|
x = jnp.arange(5, dtype=jnp.float32)
|
|
|
|
|
self.assertAllClose(jax.vmap(f)(x), x)
|
2021-08-26 13:34:01 -07:00
|
|
|
|
|
2019-05-16 07:40:31 -07:00
|
|
|
|
def testWhileLoopCondConstsBatched(self):
|
|
|
|
|
def fun(x, y):
|
|
|
|
|
return lax.while_loop(lambda x: x < y, lambda x: x + 2, x)
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
ans = jax.vmap(fun, in_axes=(None, 0))(0, np.array([2, 3]))
|
2020-05-05 14:59:16 -04:00
|
|
|
|
expected = np.array([2, 4])
|
2019-05-16 07:40:31 -07:00
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
|
|
def testWhileLoopBodyConstsBatched(self):
|
|
|
|
|
def fun(x, y):
|
|
|
|
|
return lax.while_loop(lambda x: x < 3, lambda x: x + y, x)
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
ans = jax.vmap(fun, in_axes=(None, 0))(0, jnp.array([2, 3]))
|
2020-05-05 14:59:16 -04:00
|
|
|
|
expected = np.array([4, 3])
|
2019-05-16 07:40:31 -07:00
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
|
|
def testWhileLoopTupleBatched(self):
|
|
|
|
|
def cond_fun(loop_carry):
|
|
|
|
|
x, y = loop_carry
|
|
|
|
|
return x + y < 5
|
|
|
|
|
|
|
|
|
|
def body_fun(loop_carry):
|
|
|
|
|
x, y = loop_carry
|
|
|
|
|
x = x + 1
|
|
|
|
|
return x, y
|
|
|
|
|
|
|
|
|
|
def fun(x, y):
|
|
|
|
|
return lax.while_loop(cond_fun, body_fun, (x, y))
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
ans = jax.vmap(fun)(np.array([0, 0]), np.array([1, 2]))
|
2020-05-05 14:59:16 -04:00
|
|
|
|
expected = (np.array([4, 3]), np.array([1, 2]))
|
2019-05-16 07:40:31 -07:00
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
2020-05-26 10:22:33 +03:00
|
|
|
|
  def test_issue_3204(self):
    # Error during XLA code generation for vmap of nested loops
    # Regression test: only checks that vmap of a nested while_loop
    # traces and executes without raising.
    def test(a, b):
      val = 0
      i = 0
      j = 0

      condfun_1 = lambda inp: inp[1] < a + 1
      condfun_2 = lambda inp: inp[2] < b + 1

      def bodyfun_1(inp):
        # Outer body: resets j and runs the inner accumulation loop.
        val, i, j = inp
        j = 0

        def bodyfun_2(inp):
          val, i, j = inp
          val += i + j
          j += 1
          return (val, i, j)

        result = lax.while_loop(condfun_2, bodyfun_2, (val, i, j))
        val = result[0]
        i += 1
        return (val, i, j)

      result = lax.while_loop(condfun_1, bodyfun_1, (val, i, j))
      return result[0]

    arr = np.arange(5)
    vmap_test = jax.vmap(test, (0, 0))
    vmap_test(arr, arr)
|
|
|
|
|
|
2019-11-14 16:18:00 -05:00
|
|
|
|
  def testForiLoopErrors(self):
    """Test typing error messages for fori_loop."""
    # Mixed-width integer bounds (int16 vs int32) must be rejected.
    with self.assertRaisesRegex(
      TypeError, "arguments to fori_loop must have equal types"):
      lax.fori_loop(np.int16(0), jnp.int32(10), (lambda i, c: c), jnp.float32(7))
|
2019-11-14 16:18:00 -05:00
|
|
|
|
|
2023-04-12 06:31:26 -07:00
|
|
|
|
def testForiLoopScalarLimits(self):
|
|
|
|
|
"""Test that scalar limits passed to fori_loop do not cause typing errors."""
|
|
|
|
|
body = lambda i, c: c + 1
|
|
|
|
|
init = jnp.float32(10)
|
|
|
|
|
|
|
|
|
|
result = lax.fori_loop(np.int16(0), 10, body, init)
|
|
|
|
|
self.assertEqual(result, init + 10)
|
|
|
|
|
|
|
|
|
|
result = lax.fori_loop(0, np.int16(10), body, init)
|
|
|
|
|
self.assertEqual(result, init + 10)
|
|
|
|
|
|
2023-12-04 10:26:25 -08:00
|
|
|
|
def test_fori_loop_supports_unrolling(self):
|
|
|
|
|
"""Test that we can unroll static fori_loops."""
|
|
|
|
|
body = lambda i, c: c + 1
|
|
|
|
|
init = jnp.float32(10)
|
|
|
|
|
|
|
|
|
|
result = lax.fori_loop(np.int16(0), 10, body, init,
|
|
|
|
|
unroll=3)
|
|
|
|
|
self.assertEqual(result, init + 10)
|
|
|
|
|
|
|
|
|
|
result = lax.fori_loop(0, np.int16(10), body, init,
|
|
|
|
|
unroll=2)
|
|
|
|
|
self.assertEqual(result, init + 10)
|
|
|
|
|
|
2023-12-04 11:54:12 -08:00
|
|
|
|
def test_fori_loop_supports_unrolling_with_bool(self):
|
|
|
|
|
"""Test that we can unroll static fori_loops."""
|
|
|
|
|
body = lambda i, c: c + 1
|
|
|
|
|
init = jnp.float32(10)
|
|
|
|
|
|
|
|
|
|
result = lax.fori_loop(np.int16(0), 10, body, init,
|
|
|
|
|
unroll=True)
|
|
|
|
|
self.assertEqual(result, init + 10)
|
|
|
|
|
|
|
|
|
|
result = lax.fori_loop(0, np.int16(10), body, init,
|
|
|
|
|
unroll=False)
|
|
|
|
|
self.assertEqual(result, init + 10)
|
|
|
|
|
|
2023-12-04 10:26:25 -08:00
|
|
|
|
  def test_fori_loop_with_dynamic_indices_cannot_unroll(self):
    """Test that we can't unroll dynamic fori_loops."""
    body = lambda i, c: c + 1
    init = jnp.float32(10)

    @jax.jit
    def f(upper):
      # Inside jit, `upper` is a traced value, so the trip count is
      # dynamic and unrolling must be rejected.
      return lax.fori_loop(np.int16(0), upper, body, init,
                           unroll=3)

    with self.assertRaisesRegex(ValueError, "Can only use `unroll`"):
      f(10)
|
|
|
|
|
|
2024-01-25 16:33:25 +00:00
|
|
|
|
  @parameterized.named_parameters(
      {
          "testcase_name": f"_{jit=}_{upper=}_{unroll=}",
          "jit": jit,
          "upper": upper,
          "unroll": unroll,
      }
      for jit in (False, True)
      for upper in (0, -1)
      for unroll in (False, True)
  )
  def test_fori_loop_returns_init_with_nonpositive_length(
      self, jit, upper, unroll
  ):
    """Test that `length <= 0` behaves like Python `range`."""
    fori_loop_with_static_upper_and_lower = partial(
        lax.fori_loop, 0, upper, lambda i, c: c + 1, unroll=unroll
    )
    if jit:
      fori_loop_with_static_upper_and_lower = jax.jit(
          fori_loop_with_static_upper_and_lower
      )
    init = jnp.float32(10)
    # A nonpositive trip count must run zero iterations and return init.
    self.assertEqual(fori_loop_with_static_upper_and_lower(init), init)
|
|
|
|
|
|
2025-01-09 13:26:29 +02:00
|
|
|
|
|
2019-05-16 07:40:31 -07:00
|
|
|
|
def testForiLoopBatched(self):
|
|
|
|
|
def body_fun(i, loop_carry):
|
|
|
|
|
x, y = loop_carry
|
|
|
|
|
x = x + 1
|
|
|
|
|
y = y + 2
|
|
|
|
|
return x, y
|
|
|
|
|
|
|
|
|
|
def fun(x):
|
|
|
|
|
return lax.fori_loop(0, 10, body_fun, (x, 0))
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
ans = jax.vmap(fun)(np.array([0, 1]))
|
2020-05-05 14:59:16 -04:00
|
|
|
|
expected = (np.array([10, 11]), np.array([20, 20]))
|
2019-05-16 07:40:31 -07:00
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
2019-10-28 13:02:31 -07:00
|
|
|
|
  def testForiLoopBatchedIssue1190(self):
    """Batched while_loop keeps a scalar condition (regression, issue #1190)."""
    cond_fun = lambda carry: carry[0] < 4
    body_fun = lambda carry: (carry[0] + 1, carry[1] + 1)
    f = lambda x: lax.while_loop(cond_fun, body_fun, (0, x))
    jaxpr = jax.make_jaxpr(jax.vmap(f))(jnp.arange(3))
    eqn = jaxpr.jaxpr.eqns[0]
    self.assertIs(eqn.primitive, lax.while_p)
    # The condition only depends on the unbatched carry[0], so its input
    # aval must remain scalar after batching.
    self.assertEqual(eqn.params['cond_jaxpr'].in_avals[0].shape, ())
|
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def testForiLoopBasic(self):
|
2019-09-26 13:34:57 +02:00
|
|
|
|
def body_fun(i, tot):
|
|
|
|
|
return lax.add(tot, i)
|
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def count(num):
|
|
|
|
|
return lax.fori_loop(0, num, body_fun, 0)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(count(2), 1)
|
|
|
|
|
self.assertEqual(count(3), 3)
|
|
|
|
|
self.assertEqual(count(4), 6)
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
2019-09-26 13:34:57 +02:00
|
|
|
|
for args_maker in [lambda: [2], lambda: [3], lambda: [4]]:
|
2020-06-01 17:19:23 -04:00
|
|
|
|
self._CompileAndCheck(count, args_maker)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
|
|
|
|
def testForiLoopClosure(self):
|
|
|
|
|
def count(num):
|
|
|
|
|
def body_fun(i, tot):
|
|
|
|
|
return lax.add(num, lax.add(tot, i))
|
|
|
|
|
return lax.fori_loop(0, num, body_fun, 0)
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
cfun = jax.jit(count)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
|
|
|
|
self.assertEqual(count(2), 1 + 2**2)
|
|
|
|
|
self.assertEqual(count(2), cfun(2))
|
|
|
|
|
self.assertEqual(count(3), 3 + 3**2)
|
|
|
|
|
self.assertEqual(count(3), cfun(3))
|
|
|
|
|
self.assertEqual(count(4), 6 + 4**2)
|
|
|
|
|
self.assertEqual(count(4), cfun(4))
|
|
|
|
|
|
|
|
|
|
def testForiLoopTupleState(self):
  """fori_loop threading a tuple (array, running total) as the carry."""
  def sum_first_n(arr, num):
    def body_fun(i, state):
      arr, total = state
      arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
      return (arr, lax.add(total, arr_i))

    # Zero of the array's own dtype keeps the carry type stable across
    # iterations.
    init_val = (arr, arr.dtype.type(0))
    _, total = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
                             init_val)
    return total

  cfun = jax.jit(sum_first_n)
  x = self.rng().randn(10).astype(jnp.float_)

  # num > len(x) (15) checks the min() clamp on the trip count.
  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), np.sum(x[:num]),
                        check_dtypes=False)
    self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
    # Repeated identical call: exercises the already-compiled path.
    self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
|
|
|
|
def testForiLoopDictState(self):
  """fori_loop threading a dict pytree as the carry."""
  def sum_first_n(arr, num):
    def body_fun(i, state):
      arr, total = state['arr'], state['total']
      arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
      return {'arr': arr, 'total': lax.add(total, arr_i)}

    init_val = {'arr': arr, 'total': arr.dtype.type(0)}
    out_val = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
    return out_val['total']

  cfun = jax.jit(sum_first_n)
  x = self.rng().randn(10).astype(jnp.float_)

  # num > len(x) (15) checks the min() clamp on the trip count.
  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), np.sum(x[:num]),
                        check_dtypes=False)
    self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
    # Repeated identical call: exercises the already-compiled path.
    self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
|
|
|
|
def testForiLoopEmptyTupleInState(self):
  """fori_loop whose carry contains an empty tuple (a pytree with no leaves)."""
  def sum_first_n(arr, num):
    def body_fun(i, state):
      arr, total, _ = state
      arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
      # The empty tuple must round-trip unchanged through the loop.
      return (arr, lax.add(total, arr_i), ())

    init_val = (arr, arr.dtype.type(0), ())
    _, tot, _ = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
    return tot

  cfun = jax.jit(sum_first_n)
  x = self.rng().randn(10).astype(jnp.float_)

  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), np.sum(x[:num]),
                        check_dtypes=False)
    self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
    # Repeated identical call: exercises the already-compiled path.
    self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
2021-10-11 18:56:15 +02:00
|
|
|
|
def testForiLoopIssue8152(self):
  """Regression test for issue #8152: zero-trip-count loops and scans."""
  # A fori_loop with upper == lower runs zero iterations and returns init_val.
  y = lax.fori_loop(lower=0, upper=0, body_fun=lambda x, i: x + i, init_val=1.)
  self.assertAllClose(y, 1., check_dtypes=False)

  # trivial fori_loop should work - even when jit is disabled
  with jax.disable_jit():
    y = lax.fori_loop(lower=0, upper=0, body_fun=lambda x, i: x + i, init_val=1.)
  self.assertAllClose(y, 1., check_dtypes=False)

  # scan with length 0 should work with jit, but raise an error without
  def should_raise_wo_jit():
    carry, out = lax.scan(lambda c, x: (c + x, x), 0., np.array([]))
    return carry
  self.assertAllClose(should_raise_wo_jit(), 0., check_dtypes=False)
  with jax.disable_jit():
    self.assertRaises(ValueError, should_raise_wo_jit)
|
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def testCond(self):
  """lax.cond agrees with a Python if/else reference on both branches."""
  def fun(x):
    if x < 3:
      return (x, x)
    else:
      y = lax.mul(2, x)
      return y, lax.mul(2, y)

  @jax.jit
  def cfun(x):
    def false_fun(x):
      y = lax.mul(2, x)
      return y, lax.mul(2, y)
    return lax.cond(lax.lt(x, 3), lambda x: (x, x), false_fun, x)

  # Inputs 0-2 take the true branch, 3-4 the false branch; each result is
  # also pinned against its literal expected value.
  self.assertEqual(fun(0), cfun(0))
  self.assertEqual(fun(0), (0, 0))
  self.assertEqual(fun(1), cfun(1))
  self.assertEqual(fun(1), (1, 1))
  self.assertEqual(fun(2), cfun(2))
  self.assertEqual(fun(2), (2, 2))
  self.assertEqual(fun(3), cfun(3))
  self.assertEqual(fun(3), (6, 12))
  self.assertEqual(fun(4), cfun(4))
  self.assertEqual(fun(4), (8, 16))
|
|
|
|
|
|
2022-07-26 13:12:16 -07:00
|
|
|
|
def testCondPredIsNone(self):
  """A None predicate raises a TypeError, eagerly and under jit."""
  # see https://github.com/jax-ml/jax/issues/11574
  def f(pred, x):
    return lax.cond(pred, lambda x: x + 1, lambda x: x + 2, x)

  self.assertRaisesRegex(TypeError, "cond predicate is None",
                         lambda: f(None, 1.))
  self.assertRaisesRegex(TypeError, "cond predicate is None",
                         lambda: jax.jit(f)(None, 1.))
|
|
|
|
|
|
2021-11-05 09:10:11 -07:00
|
|
|
|
def testCondTwoOperands(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# see https://github.com/jax-ml/jax/issues/8469
|
2021-11-05 09:10:11 -07:00
|
|
|
|
add, mul = lax.add, lax.mul
|
|
|
|
|
|
|
|
|
|
def fun(x):
|
|
|
|
|
return add(x, x) if x == 0 else mul(x, x)
|
|
|
|
|
|
|
|
|
|
def cfun(x):
|
|
|
|
|
return lax.cond(x == 0, add, mul, x, x)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(fun(0), cfun(0))
|
|
|
|
|
self.assertEqual(fun(1), cfun(1))
|
|
|
|
|
cfun = jax.jit(cfun)
|
|
|
|
|
self.assertEqual(fun(0), cfun(0))
|
|
|
|
|
self.assertEqual(fun(1), cfun(1))
|
|
|
|
|
|
|
|
|
|
def testCondThreeOperands(self):
|
|
|
|
|
add = lambda x, y, z: x + y + z
|
|
|
|
|
mul = lambda x, y, z: x * y * z
|
|
|
|
|
|
|
|
|
|
def fun(x):
|
|
|
|
|
return add(x, x, x) if x == 0 else mul(x, x, x)
|
|
|
|
|
|
|
|
|
|
def cfun(x):
|
|
|
|
|
return lax.cond(x == 0, add, mul, x, x, x)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(fun(0), cfun(0))
|
|
|
|
|
self.assertEqual(fun(1), cfun(1))
|
|
|
|
|
cfun = jax.jit(cfun)
|
|
|
|
|
self.assertEqual(fun(0), cfun(0))
|
|
|
|
|
self.assertEqual(fun(1), cfun(1))
|
|
|
|
|
|
2023-06-27 18:48:39 -07:00
|
|
|
|
def testCondCallableOperands(self):
  """Callable pytree operands must be passed through cond, not invoked.

  Regression test for https://github.com/jax-ml/jax/issues/16413.
  """
  # see https://github.com/jax-ml/jax/issues/16413

  @tree_util.register_pytree_node_class
  class Foo:
    # A callable pytree container; __call__ asserts so any accidental
    # invocation of the operand by cond machinery fails loudly.
    def __init__(self, x):
      self.x = x

    def __call__(self, *xs):
      assert False
      return xs

    def tree_flatten(self):
      return (self.x,), None

    @classmethod
    def tree_unflatten(cls, _, xs):
      return cls(*xs)

  # Four variants covering each combination of raw / wrapped operands.
  f_00 = lambda a, b: a + b
  f_01 = lambda a, b: a + b.x
  f_10 = lambda a, b: a.x + b
  f_11 = lambda a, b: a.x + b.x

  # these don't raise
  a = lax.cond(True, f_00, f_00, 3, 4)
  b = lax.cond(True, f_01, f_01, 3, Foo(4))
  c = lax.cond(True, f_10, f_10, Foo(3), 4)
  d = lax.cond(True, f_11, f_11, Foo(3), Foo(4))
  self.assertEqual(a, b)
  self.assertEqual(a, c)
  self.assertEqual(a, d)
|
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
def testSwitch(self):
  """lax.switch over three branches, checked against a Python reference.

  The reference clamps the index the same way switch does: negative
  indices pick branch 0 and large indices pick the last branch.
  """
  def branch(x):
    y = lax.mul(2, x)
    return y, lax.mul(2, y)

  branches = [lambda x: (x, x),
              branch,
              lambda x: (x, -x)]

  def fun(x):
    if x <= 0:
      return branches[0](x)
    elif x == 1:
      return branches[1](x)
    else:
      return branches[2](x)

  def cfun(x):
    return lax.switch(x, branches, x)

  self.assertEqual(fun(-1), cfun(-1))
  self.assertEqual(fun(0), cfun(0))
  self.assertEqual(fun(1), cfun(1))
  self.assertEqual(fun(2), cfun(2))
  self.assertEqual(fun(3), cfun(3))

  cfun = jax.jit(cfun)

  self.assertEqual(fun(-1), cfun(-1))
  self.assertEqual(fun(0), cfun(0))
  self.assertEqual(fun(1), cfun(1))
  self.assertEqual(fun(2), cfun(2))
  self.assertEqual(fun(3), cfun(3))
|
|
|
|
|
|
2021-11-05 09:11:01 -07:00
|
|
|
|
def testSwitchMultiOperands(self):
  """lax.switch forwarding multiple operands to the selected branch."""
  branches = [lax.add, lax.mul]

  def fun(x):
    # Reference: branch 0 for x <= 0, branch 1 otherwise (switch clamps
    # the index, so x = -1 and x = 2 also map onto these two branches).
    i = 0 if x <= 0 else 1
    return branches[i](x, x)

  def cfun(x):
    return lax.switch(x, branches, x, x)

  self.assertEqual(fun(-1), cfun(-1))
  self.assertEqual(fun(0), cfun(0))
  self.assertEqual(fun(1), cfun(1))
  self.assertEqual(fun(2), cfun(2))
  cfun = jax.jit(cfun)
  self.assertEqual(fun(-1), cfun(-1))
  self.assertEqual(fun(0), cfun(0))
  self.assertEqual(fun(1), cfun(1))
  self.assertEqual(fun(2), cfun(2))
|
|
|
|
|
|
2020-06-11 19:17:07 -07:00
|
|
|
|
def testSwitchResidualsMerge(self):
  """Overlapping switch-branch residuals should share slots in grad jaxprs.

  Inspects the forward/backward cond equations produced by grad-of-switch
  and asserts that adding a branch with overlapping residuals does not
  grow the residual count, while a branch needing a new residual grows it
  by exactly one.
  """
  def get_conds(fun):
    # Collect the cond equations (switch lowers to cond) from grad(fun).
    jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
    return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']

  def branch_invars_len(cond_eqn):
    # All branches of one cond must agree on their input arity.
    lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
    assert len(set(lens)) == 1
    return lens[0]

  def branch_outvars_len(cond_eqn):
    # All branches of one cond must agree on their output arity.
    lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
    assert len(set(lens)) == 1
    return lens[0]

  branches1 = [
      lambda x: jnp.sin(x),
      lambda x: jnp.cos(x)]   # branch residuals overlap, should be reused
  branches2 = branches1 + [
      lambda x: jnp.sinh(x)]  # another overlapping residual, expect reuse
  branches3 = branches2 + [
      lambda x: jnp.sin(x) + jnp.cos(x)]  # requires one more residual slot
  def fun1(x, i):
    return lax.switch(i + 1, branches1, x)
  def fun2(x, i):
    return lax.switch(i + 1, branches2, x)
  def fun3(x, i):
    return lax.switch(i + 1, branches3, x)

  fwd1, bwd1 = get_conds(fun1)
  fwd2, bwd2 = get_conds(fun2)
  fwd3, bwd3 = get_conds(fun3)

  # Forward pass: residual outputs reused across overlapping branches.
  fwd1_num_out = branch_outvars_len(fwd1)
  fwd2_num_out = branch_outvars_len(fwd2)
  fwd3_num_out = branch_outvars_len(fwd3)
  assert fwd1_num_out == fwd2_num_out
  assert fwd3_num_out == fwd2_num_out + 1

  # Backward pass: residual inputs mirror the forward residual outputs.
  bwd1_num_in = branch_invars_len(bwd1)
  bwd2_num_in = branch_invars_len(bwd2)
  bwd3_num_in = branch_invars_len(bwd3)
  assert bwd1_num_in == bwd2_num_in
  assert bwd3_num_in == bwd2_num_in + 1
|
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
def testOneBranchSwitch(self):
|
|
|
|
|
branch = lambda x: -x
|
|
|
|
|
f = lambda i, x: lax.switch(i, [branch], x)
|
|
|
|
|
x = 7.
|
|
|
|
|
self.assertEqual(f(-1, x), branch(x))
|
|
|
|
|
self.assertEqual(f(0, x), branch(x))
|
|
|
|
|
self.assertEqual(f(1, x), branch(x))
|
2021-09-13 16:00:22 -04:00
|
|
|
|
cf = jax.jit(f)
|
2020-06-01 19:07:11 -07:00
|
|
|
|
self.assertEqual(cf(-1, x), branch(x))
|
|
|
|
|
self.assertEqual(cf(0, x), branch(x))
|
|
|
|
|
self.assertEqual(cf(1, x), branch(x))
|
2021-09-13 16:00:22 -04:00
|
|
|
|
cf = jax.jit(f, static_argnums=0)
|
2020-06-01 19:07:11 -07:00
|
|
|
|
self.assertEqual(cf(-1, x), branch(x))
|
|
|
|
|
self.assertEqual(cf(0, x), branch(x))
|
|
|
|
|
self.assertEqual(cf(1, x), branch(x))
|
|
|
|
|
|
2019-09-27 17:00:38 +01:00
|
|
|
|
def testIssue1379(self):
  """cond with the predicate also passed as the operand (issue #1379)."""
  def fun(pred):
    return lax.cond(pred, lambda x: (True, x), lambda x: (False, x), pred)

  @jax.jit
  def cfun(pred):
    return fun(pred)

  # NOTE(review): the third argument to assertEqual is the failure *message*,
  # not a third value to compare — the expected tuples below are never
  # actually asserted against. Likely intended as a separate assertion;
  # confirm before tightening.
  self.assertEqual(fun(0), cfun(0), (False,0))
  self.assertEqual(fun(0.), cfun(0.), (False,0.))
  self.assertEqual(fun(1), cfun(1), (True,1))
  self.assertEqual(fun(1.), cfun(1.), (True,1.))

  # test that proper errors are raised for wrong types
  for pred in ["abc", [], [1,2]]:
    for f in [fun, cfun]:
      self.assertRaises(TypeError, f, pred)
|
2019-11-14 16:18:00 -05:00
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testNestedCond(self, cond):
  """A cond nested inside another cond branch, over each cond impl.

  The inner call uses the old-style signature
  cond(pred, true_operand, true_fun, false_operand, false_fun); the inner
  false_fun also closes over the outer operand x.
  """
  def fun(x):
    if x < 2:
      return lax.mul(2, x)
    else:
      if x < 5:
        return lax.mul(3, x)
      else:
        return lax.mul(4, x)

  @jax.jit
  def cfun(x):
    return cond(
        lax.lt(x, 2),
        lambda x: lax.mul(2, x),
        lambda x: cond(lax.lt(x, 5),
                       x, lambda x: lax.mul(3, x),
                       4, lambda y: lax.mul(y, x)),
        x)

  self.assertEqual(cfun(1), 2)
  self.assertEqual(cfun(3), 9)
  self.assertEqual(cfun(6), 24)
  self.assertEqual(cfun(1), fun(1))
  self.assertEqual(cfun(3), fun(3))
  self.assertEqual(cfun(6), fun(6))
|
|
|
|
|
|
2019-11-14 12:26:41 +01:00
|
|
|
|
def testCondTypeErrors(self):
  """Test typing error messages for cond."""
  # Predicate is a function, not a value.
  with self.assertRaisesRegex(TypeError,
      re.escape("Pred type must be either boolean or number, got <function")):
    lax.cond(lambda x: True, lambda top: 2., lambda fop: 3., 1.)
  # Predicate is a non-numeric scalar-like (a string).
  with self.assertRaisesRegex(TypeError,
      re.escape("Pred must be a scalar, got foo of type <class 'str'>")):
    lax.cond("foo", lambda top: 2., lambda fop: 3., 1.)
  # Predicate is a non-scalar container.
  with self.assertRaisesRegex(TypeError,
      re.escape("Pred must be a scalar, got (1.0, 1.0) of type <class 'tuple'>")):
    lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.)
  # Branch outputs disagree in pytree structure (leaf vs. tuple).
  with self.assertRaisesRegex(TypeError,
      re.compile("true_fun output must have same type structure "
                 "as false_fun output, but there are differences:.*"
                 r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)):
    lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
  # Branch outputs agree in structure but differ in aval (shape [1] vs []).
  with self.assertRaisesRegex(
      TypeError,
      "true_fun output and false_fun output must have identical types, got\n"
      r"DIFFERENT ShapedArray\(float32\[1\]\) vs. "
      r"ShapedArray\(float32\[\].*\)."):
    lax.cond(True,
             lambda top: jnp.array([1.], jnp.float32),
             lambda fop: jnp.float32(1.),
             1.)
|
2019-11-14 12:26:41 +01:00
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
def testSwitchErrors(self):
  """Test typing error messages for switch."""
  # Index is a function, not an integer.
  with self.assertRaisesRegex(TypeError,
      re.escape("Index type must be an integer, got <function")):
    lax.switch(lambda x: True, [lambda _: 2., lambda _: 3.], 1.)
  # Index is a string.
  with self.assertRaisesRegex(TypeError,
      re.escape("Index type must be an integer, got foo.")):
    lax.switch("foo", [lambda _: 2., lambda _: 3.], 1.)
  # Index is a non-scalar container.
  with self.assertRaisesRegex(TypeError,
      re.escape("Branch index must be scalar, got (1.0, 1.0) of shape (2,).")):
    lax.switch((1., 1.), [lambda _: 2., lambda _: 3.], 1.)
  # No branches at all.
  with self.assertRaisesRegex(ValueError,
      re.escape("Empty branch sequence")):
    lax.switch(0, [], 1.)
  # Branch outputs disagree in pytree structure (leaf vs. tuple).
  with self.assertRaisesRegex(TypeError,
      re.compile("branch 0 output must have same type structure "
                 "as branch 1 output, but there are differences:.*"
                 r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)):
    lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
  # Branch outputs agree in structure but differ in aval (shape [1] vs []).
  with self.assertRaisesRegex(
      TypeError,
      "branch 0 output and branch 1 output must have identical types, got\n"
      r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) "
      r"vs. ShapedArray\(float32\[\].*\)'}."):
    lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)),
                   lambda _: dict(a=jnp.float32(1.))],
               1.)
|
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def testCondOneBranchConstant(self):
|
|
|
|
|
def fun(x):
|
|
|
|
|
if x < 3:
|
|
|
|
|
return 5.
|
|
|
|
|
else:
|
|
|
|
|
return x
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
@jax.jit
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def cfun(x):
|
2020-05-11 15:47:16 -07:00
|
|
|
|
return lax.cond(lax.lt(x, 3), lambda x: 5, lambda x: x, x)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
|
|
|
|
self.assertEqual(fun(0), cfun(0))
|
|
|
|
|
self.assertEqual(cfun(0), 5)
|
|
|
|
|
self.assertEqual(fun(4), cfun(4))
|
|
|
|
|
self.assertEqual(cfun(4), 4)
|
|
|
|
|
|
|
|
|
|
def testCondOneBranchConstantTuple(self):
|
|
|
|
|
def fun(x):
|
|
|
|
|
if x < 3:
|
|
|
|
|
return (1., 2., 3.)
|
|
|
|
|
else:
|
|
|
|
|
return (x, 2., 4.)
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
@jax.jit
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def cfun(x):
|
2020-05-11 15:47:16 -07:00
|
|
|
|
return lax.cond(lax.lt(x, 3),
|
2020-04-28 22:37:44 -07:00
|
|
|
|
lambda x: (1, 2., 3.),
|
2020-05-11 15:47:16 -07:00
|
|
|
|
lambda x: (x, 2., 4.),
|
|
|
|
|
x)
|
2019-05-09 21:58:21 -07:00
|
|
|
|
|
|
|
|
|
self.assertEqual(fun(0), cfun(0))
|
|
|
|
|
self.assertEqual(cfun(0), (1, 2., 3.))
|
|
|
|
|
self.assertEqual(fun(4), cfun(4))
|
|
|
|
|
self.assertEqual(cfun(4), (4, 2., 4.))
|
|
|
|
|
|
2019-08-23 11:02:19 -07:00
|
|
|
|
def testCondBatched(self):
  """vmap of cond: stays a cond when the predicate is unbatched, becomes
  a select when the predicate itself is batched."""
  def fun(x, y, z):
    # Old-style cond signature: (pred, true_operand, true_fun,
    # false_operand, false_fun).
    pred = lax.lt(x, 3)
    true_fun = lambda y: y
    false_fun = lambda z: lax.neg(z)
    return lax.cond(pred, y, true_fun, z, false_fun)

  # these cases stay as cond
  x = jnp.array(2)
  y = jnp.array([1, 2])
  z = jnp.array([3, 4])
  ans = jax.vmap(fun, (None, 0, 0))(x, y, z)
  jaxpr = jax.make_jaxpr(jax.vmap(fun, (None, 0, 0)))(x, y, z)
  expected = np.array([1, 2])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  # Same batching, predicate now false (x = 4 >= 3).
  x = jnp.array(4)
  ans = jax.vmap(fun, (None, 0, 0))(x, y, z)
  jaxpr = jax.make_jaxpr(jax.vmap(fun, (None, 0, 0)))(x, y, z)
  expected = np.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  fun = jax.jit(fun)
  ans = jax.vmap(fun, (None, 0, 0))(x, y, z)
  expected = np.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Unbatched false operand broadcasts across the batch.
  z = jnp.array(5)
  ans = jax.vmap(fun, (None, 0, None))(x, y, z)
  jaxpr = jax.make_jaxpr(jax.vmap(fun, (None, 0, None)))(x, y, z)
  expected = np.array([-5, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  # these cases become select
  x = jnp.array([2, 4])
  ans = jax.vmap(fun, (0, 0, None))(x, y, z)
  jaxpr = jax.make_jaxpr(jax.vmap(fun, (0, 0, None)))(x, y, z)
  expected = np.array([1, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" in str(jaxpr)

  z = jnp.array([3, 4])
  ans = jax.vmap(fun)(x, y, z)
  jaxpr = jax.make_jaxpr(jax.vmap(fun))(x, y, z)
  expected = np.array([1, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" in str(jaxpr)
|
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
def testSwitchBatched(self):
  """vmap of switch: stays a cond when the index is unbatched, becomes a
  select when the index itself is batched."""
  def fun(index, x, y, z):
    branches = [lambda xyz: xyz[0],
                lambda xyz: lax.neg(xyz[1]),
                lambda xyz: lax.sign(xyz[2])]
    return lax.switch(index, branches, (x, y, z))

  # these cases stay as cond
  x = jnp.array(0)
  y = jnp.array([1, 2])
  z = jnp.array([3, 4])
  w = jnp.array(9)
  ans = jax.vmap(fun, (None, 0, 0, None))(x, y, z, w)
  jaxpr = jax.make_jaxpr(jax.vmap(fun, (None, 0, 0, None)))(x, y, z, w)
  expected = np.array([1, 2])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  # Same batching, index now selects the negation branch.
  x = jnp.array(1)
  ans = jax.vmap(fun, (None, 0, 0, None))(x, y, z, w)
  jaxpr = jax.make_jaxpr(jax.vmap(fun, (None, 0, 0, None)))(x, y, z, w)
  expected = np.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  fun = jax.jit(fun)
  ans = jax.vmap(fun, (None, 0, 0, None))(x, y, z, w)
  expected = np.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Unbatched branch operand broadcasts across the batch.
  z = jnp.array(5)
  ans = jax.vmap(fun, (None, 0, None, None))(x, y, z, w)
  jaxpr = jax.make_jaxpr(jax.vmap(fun, (None, 0, None, None)))(x, y, z, w)
  expected = np.array([-5, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  # these cases become select
  x = jnp.array([0, 1])
  ans = jax.vmap(fun, (0, 0, None, None))(x, y, z, w)
  jaxpr = jax.make_jaxpr(jax.vmap(fun, (0, 0, None, None)))(x, y, z, w)
  expected = np.array([1, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" in str(jaxpr)

  z = jnp.array([3, 4])
  w = jnp.array([9, 9])
  ans = jax.vmap(fun)(x, y, z, w)
  jaxpr = jax.make_jaxpr(jax.vmap(fun))(x, y, z, w)
  expected = np.array([1, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" in str(jaxpr)
|
|
|
|
|
|
2020-01-22 17:15:36 -08:00
|
|
|
|
def testCondJVP(self):
  """Forward-mode differentiation through cond matches a Python reference."""
  def fun_ref(x):
    if x < 3:
      return (x, x)
    else:
      y = 2 * x
      return y, 2 * y

  def fun(x):
    def false_fun(x):
      y = 2 * x
      return y, 2 * y
    return lax.cond(x < 3, lambda x: (x, x), false_fun, x)

  # x = 3.14 takes the false branch.
  x = 3.14
  ans = jax.jvp(fun, (x,), (x,))
  expected = jax.jvp(fun_ref, (x,), (x,))
  self.assertAllClose(ans, expected, check_dtypes=False)
  jtu.check_grads(fun, (x,), order=2, modes=["fwd"])

  # x = 2.72 takes the true branch.
  x = 2.72
  ans = jax.jvp(fun, (x,), (x,))
  expected = jax.jvp(fun_ref, (x,), (x,))
  self.assertAllClose(ans, expected, check_dtypes=False)
  jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
|
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
def testSwitchJVP(self):
  """Forward-mode differentiation through switch matches a Python reference."""
  def branch(x):
    y = 2 * x
    return y, 2 * y

  branches = [lambda x: (x, x),
              branch,
              lambda x: (x, -x)]

  def fun_ref(x):
    # floor(x) as index, with the same clamping semantics as switch.
    idx = x // 1
    if idx <= 0:
      return branches[0](x)
    elif idx == 1:
      return branches[1](x)
    else:
      return branches[2](x)

  def fun(x):
    idx = lax.convert_element_type(x // 1, np.int32)
    return lax.switch(idx, branches, x)

  # Inputs chosen to land in every branch, including the clamped ends.
  for x in [-0.7, 0.7, 1.7, 2.7, 3.7]:
    ans = jax.jvp(fun, (x,), (x,))
    expected = jax.jvp(fun_ref, (x,), (x,))
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
|
|
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_{name}", "cond": cond}
    for cond, name in COND_IMPLS)
def testCondJVP2(self, cond):
  """JVP through cond with per-branch operands, over each cond impl.

  Uses the old-style signature cond(pred, true_operand, true_fun,
  false_operand, false_fun); the true branch ignores its operand (None).
  """
  def fun_ref(x):
    if x < 3:
      return 2.
    else:
      return 2. * x

  def fun(x):
    return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x)

  # x = 3.14 takes the false branch.
  x = 3.14
  ans = jax.jvp(fun, (x,), (x,))
  expected = jax.jvp(fun_ref, (x,), (x,))
  self.assertAllClose(ans, expected, check_dtypes=False)
  jtu.check_grads(fun, (x,), order=2, modes=["fwd"])

  # x = 2.72 takes the true branch (constant output, zero tangent).
  x = 2.72
  ans = jax.jvp(fun, (x,), (x,))
  expected = jax.jvp(fun_ref, (x,), (x,))
  self.assertAllClose(ans, expected, check_dtypes=False)
  jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
|
|
|
|
|
|
|
|
|
|
def testCondGrad(self):
  """Reverse-mode gradient through cond matches a Python reference."""
  def f_ref(x):
    return 3. * x if x < 2 else jnp.sin(x)

  @jax.jit
  def f(x):
    return lax.cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)

  # x = 2.14 takes the sin branch.
  x = 2.14
  ans = jax.grad(f)(x)
  expected = jax.grad(f_ref)(x)
  self.assertAllClose(ans, expected, check_dtypes=False)
  jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])

  # x = 1.72 takes the linear branch.
  x = 1.72
  ans = jax.grad(f)(x)
  expected = jax.grad(f_ref)(x)
  self.assertAllClose(ans, expected, check_dtypes=False)
  jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
|
|
|
|
|
|
2021-02-05 16:25:23 -08:00
|
|
|
|
def testCondGradVmapNan(self):
  """grad-of-vmap of a guarded sqrt must not produce NaNs at the guard.

  The cond keeps sqrt from being differentiated at x = 0, where its
  gradient would be infinite; the whole pipeline should stay NaN-free.
  """
  eps = 1e-3

  def safe1(x):
    return lax.cond(x < eps, lambda _: eps, lambda _: jnp.sqrt(x), ())

  out = jax.grad(lambda x: jax.vmap(safe1)(x).sum())(np.zeros(10))
  self.assertFalse(np.isnan(out).any())
|
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
  def testSwitchGrad(self):
    """grad through lax.switch matches a Python if/elif reference per branch."""
    branches = [lambda x: 3. * x,
                lambda x: jnp.sin(x),
                lambda x: -x]

    # Reference: dispatch on floor(x) with ordinary Python control flow.
    def f_ref(x):
      idx = x // 1
      if idx <= 0:
        return branches[0](x)
      elif idx == 1:
        return branches[1](x)
      else:
        return branches[2](x)

    @jax.jit
    def f(x):
      idx = lax.convert_element_type(x // 1, np.int32)
      return lax.switch(idx, branches, x)

    # Inputs chosen to hit each branch, including out-of-range indices
    # (switch clamps them; the reference mirrors that with <= / else).
    for x in [-0.7, 0.7, 1.7, 2.7, 3.7]:
      ans = jax.grad(f)(x)
      expected = jax.grad(f_ref)(x)
      self.assertAllClose(ans, expected, check_dtypes=False)
      jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
  @parameterized.parameters(itertools.product(range(4), repeat=3))
  @jtu.run_on_devices("cpu")
  def testSwitchGradWithForwarding(self, seed, num_input_fwd, num_output_fwd):
    """Gradients through switch branches that forward some inputs to outputs.

    Each branch permutes its inputs, passes `num_input_fwd` of them through
    unchanged, and transforms the rest, so input/output forwarding paths in
    the cond/switch partial-eval machinery are exercised.
    """
    num_args = 3
    num_branches = 4
    rng = np.random.RandomState(seed)  # seeded so each parameterization is deterministic
    in_perm = rng.permutation(num_args)
    out_perm = rng.permutation(num_args)

    def branch(s, inputs):
      inputs = [inputs[i] for i in in_perm]
      outputs = inputs[:num_input_fwd] + [
          s * jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i])
          for i in range(num_args - num_input_fwd)]
      return [outputs[i] for i in out_perm]

    branches = [partial(branch, i) for i in range(num_branches)]

    @jax.jit
    def f_(idx, inputs):
      idx = lax.convert_element_type(idx // 1, np.int32)
      return lax.switch(idx, branches, inputs)

    # Check gradients against numerical differences for every branch.
    for idx in range(num_branches):
      f = partial(f_, idx)
      jtu.check_grads(f, (jnp.arange(float(num_args)),),
                      order=1, modes=['fwd', 'rev'], atol=1e-2, rtol=1e-2)
  def testSwitchGradWithWeakTypeMismatch(self):  # issue #4696, PR #4896
    """grad through switch whose branches disagree on weak typing."""
    dtype = dtypes.canonicalize_dtype(np.float64)
    dtype = jnp.float32 if dtype == jnp.float32 else jnp.float64

    branches = [
        lambda x: x,  # This preserves the weak type of x.
        lambda x: x + dtype(1),  # This strips the weak type of x.
    ]

    # Reference dispatches in Python on the same integer index.
    def f_ref(x):
      i = x.astype(jnp.int32)
      return branches[i](x)

    def f(x):
      return lax.switch(x.astype(jnp.int32), branches, x)

    for x in [0., 1.]:
      ans = jax.grad(f)(x)
      expected = jax.grad(f_ref)(x)
      self.assertAllClose(ans, expected, check_dtypes=False)
  @parameterized.named_parameters(
      {"testcase_name": f"_{name}", "cond": cond}
      for cond, name in COND_IMPLS)
  def testCondGrad2(self, cond):
    """grad through a vector-valued cond (under jit) matches the reference."""
    def f_ref(x):
      z = jnp.array([1., 2.], x.dtype) * x if x[0] < 2 else jnp.sin(x)
      return z.sum()

    def _f(x):
      return cond(
          x[0] < 2,
          lambda x: jnp.array([1., 2.], x.dtype) * x,
          lambda x: jnp.sin(x),
          x)

    f = lambda x: jax.jit(_f)(x).sum()

    x = 2.14 * jnp.ones(2)  # false branch (sin)
    ans = jax.grad(f)(x)
    expected = jax.grad(f_ref)(x)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])

    x = 1.72 * jnp.ones(2)  # true branch (elementwise scale)
    ans = jax.grad(f)(x)
    expected = jax.grad(f_ref)(x)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"],
                    rtol={jnp.float32: 1e-2, jnp.float64: 2e-3})
  @parameterized.named_parameters(
      {"testcase_name": f"_{name}", "cond": cond}
      for cond, name in COND_IMPLS)
  def testCondGrad3(self, cond):
    """grad (fwd and rev) through old-style cond matches the Python reference."""
    def fun_ref(x):
      if x < 3:
        return 2.
      else:
        return 2. * x

    def fun(x):
      # Old-style cond signature: (pred, true_op, true_fun, false_op, false_fun).
      return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x)

    x = 3.14  # false branch
    ans = jax.grad(fun)(x)
    expected = jax.grad(fun_ref)(x)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x,), order=2, modes=["fwd", "rev"])

    x = 2.72  # true branch (constant output, zero gradient)
    ans = jax.grad(fun)(x)
    expected = jax.grad(fun_ref)(x)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x,), order=2, modes=["fwd", "rev"])
  @parameterized.named_parameters(
      {"testcase_name": f"_{name}", "cond": cond}
      for cond, name in COND_IMPLS)
  def testCondGrad4(self, cond):
    """grad w.r.t. a value closed over by one cond branch (issue of branch
    closures capturing outer tracers).

    NOTE(review): only the gradient w.r.t. `y` (argnum 1) is compared against
    the reference; the false branches of `fun` and `fun_ref` differ in x but
    both are constant in y.
    """
    if cond is cond_with_new_checkpoint and jtu.test_device_matches(['tpu']):
      raise unittest.SkipTest("tpu bug")  # TODO(parkers): tpu bug exhibited here
    def fun_ref(x, y):
      if x < 3:
        return 2. * jnp.sin(y)
      else:
        return 2. * jnp.cos(x)

    @jax.jit
    def fun(x, y):
      return cond(
          x < 3,
          None, lambda _: 2. * jnp.sin(y),  # closes over y
          x, lambda x: 2. * x)

    y = 5.8
    x = 3.14  # false branch
    ans = jax.grad(fun, 1)(x, y)
    expected = jax.grad(fun_ref, 1)(x, y)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x, y), order=2, modes=["fwd", "rev"])

    x = 2.72  # true branch
    ans = jax.grad(fun, 1)(x, y)
    expected = jax.grad(fun_ref, 1)(x, y)
    self.assertAllClose(ans, expected, check_dtypes=False)
    jtu.check_grads(fun, (x, y), order=2, modes=["fwd", "rev"])
def testCondLinearize(self):
|
|
|
|
|
def f(x):
|
2020-05-11 15:47:16 -07:00
|
|
|
|
return lax.cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
y, f_lin = jax.linearize(f, 1.)
|
2020-01-30 15:03:00 -08:00
|
|
|
|
self.assertAllClose(y, 3., check_dtypes=False)
|
|
|
|
|
self.assertAllClose(f_lin(2.), 6., check_dtypes=False)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
y, f_lin = jax.linearize(f, 4.)
|
2020-05-05 14:59:16 -04:00
|
|
|
|
self.assertAllClose(y, jnp.sin(4.), check_dtypes=False)
|
|
|
|
|
self.assertAllClose(f_lin(2.), jnp.cos(4.) * 2., check_dtypes=False)
|
2020-01-30 15:03:00 -08:00
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
  def testSwitchLinearize(self):
    """jax.linearize through lax.switch selects the right branch's tangent map.

    Branch indices below/above range exercise switch's index clamping.
    """
    branches = [lambda x: 3. * x,
                lambda x: jnp.sin(x),
                lambda x: -x]
    def f(x):
      idx = lax.convert_element_type(x // 1, np.int32)
      return lax.switch(idx, branches, x)

    # branch 0 (including clamped negative index at x = -1)
    y, f_lin = jax.linearize(f, -1.)
    self.assertAllClose(y, -3., check_dtypes=False)
    self.assertAllClose(f_lin(2.), 6., check_dtypes=False)
    y, f_lin = jax.linearize(f, 0.)
    self.assertAllClose(y, 0., check_dtypes=False)
    self.assertAllClose(f_lin(2.), 6., check_dtypes=False)

    # branch 1
    y, f_lin = jax.linearize(f, 1.)
    self.assertAllClose(y, jnp.sin(1.), check_dtypes=False)
    self.assertAllClose(f_lin(2.), jnp.cos(1.) * 2., check_dtypes=False)

    # branch 2 (including clamped too-large index at x = 3)
    y, f_lin = jax.linearize(f, 2.)
    self.assertAllClose(y, -2., check_dtypes=False)
    self.assertAllClose(f_lin(2.), -2., check_dtypes=False)
    y, f_lin = jax.linearize(f, 3.)
    self.assertAllClose(y, -3., check_dtypes=False)
    self.assertAllClose(f_lin(2.), -2., check_dtypes=False)
  @parameterized.named_parameters(
      {"testcase_name": f"_{name}", "cond": cond}
      for cond, name in COND_IMPLS)
  def testCondLinearize2(self, cond):
    """Vector-valued cond linearizes like the Python-branch reference."""
    def f_ref(x):
      z = jnp.array([1., 2.], x.dtype) * x if x[0] < 2 else jnp.cos(jnp.sin(x))
      return z.sum()

    def f(x):
      return cond(
          x[0] < 2,
          lambda x: jnp.array([1., 2.], x.dtype) * x,
          lambda x: jnp.cos(jnp.sin(x)),
          x).sum()

    x = 2.14 * jnp.ones(2)  # false branch
    y, f_lin = jax.linearize(f, x)
    y_ref, f_lin_ref = jax.linearize(f_ref, x)
    self.assertAllClose(y, y_ref, check_dtypes=False)
    self.assertAllClose(f_lin(x), f_lin_ref(x), check_dtypes=False)

    x = -2.14 * jnp.ones(2)  # true branch
    y, f_lin = jax.linearize(f, x)
    y_ref, f_lin_ref = jax.linearize(f_ref, x)
    self.assertAllClose(y, y_ref, check_dtypes=False)
    self.assertAllClose(f_lin(x), f_lin_ref(x), check_dtypes=False)

    # Repeat the false-branch check with f under jit.
    f = jax.jit(f)
    x = 2.14 * jnp.ones(2)
    y, f_lin = jax.linearize(f, x)
    y_ref, f_lin_ref = jax.linearize(f_ref, x)
    self.assertAllClose(y, y_ref, check_dtypes=False)
    self.assertAllClose(f_lin(x), f_lin_ref(x), check_dtypes=False)
def testCondJit(self):
|
|
|
|
|
def f(x):
|
2020-05-11 15:47:16 -07:00
|
|
|
|
return lax.cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
y = jax.jit(f)(1.)
|
2020-01-30 15:03:00 -08:00
|
|
|
|
expected = f(1.)
|
2020-04-30 06:19:46 -07:00
|
|
|
|
self.assertAllClose(y, expected, check_dtypes=False)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
y = jax.jit(f)(4.)
|
2020-04-30 06:19:46 -07:00
|
|
|
|
expected = f(4.)
|
|
|
|
|
self.assertAllClose(y, expected, check_dtypes=False)
|
|
|
|
|
|
2020-06-01 19:07:11 -07:00
|
|
|
|
def testSwitchJit(self):
|
|
|
|
|
branches = [lambda x: 3. * x,
|
|
|
|
|
lambda x: jnp.sin(x),
|
|
|
|
|
lambda x: -x]
|
|
|
|
|
def f(x):
|
|
|
|
|
idx = lax.convert_element_type(x // 1, np.int32)
|
|
|
|
|
return lax.switch(idx, branches, x)
|
|
|
|
|
for x in [-1., 0., 1., 2., 3.]:
|
2021-09-13 16:00:22 -04:00
|
|
|
|
y = jax.jit(f)(x)
|
2020-06-01 19:07:11 -07:00
|
|
|
|
expected = f(x)
|
|
|
|
|
self.assertAllClose(y, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"_{name}", "cond": cond}
      for cond, name in COND_IMPLS)
  def testCondJitDisabled(self, cond):
    """cond still works under jax.disable_jit, with and without an outer jit."""
    def f_ref(x):
      return 3. * x if x < 2 else jnp.sin(x)
    def f(x):
      return cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)

    with jax.disable_jit():
      y = f(1.)
      expected = f_ref(1.)
      self.assertAllClose(y, expected, check_dtypes=False)

    with jax.disable_jit():
      # jit becomes a no-op inside disable_jit, so this runs eagerly too.
      y = jax.jit(f)(1.)
      expected = f(1.)
      self.assertAllClose(y, expected, check_dtypes=False)
  @parameterized.named_parameters(
      {"testcase_name": f"_{name}", "cond": cond}
      for cond, name in COND_IMPLS)
  def testCondWithConsts(self, cond):
    """cond branches that close over numpy array constants produce correct values."""
    def f(x):
      return cond(x < 2,
                  lambda x: np.array([1., 2.]) * x,
                  lambda x: np.array([3., 4.]) * jnp.sin(x),
                  x)

    def f_ref(x):
      if x < 2:
        return np.array([1., 2.]) * x
      else:
        return np.array([3., 4.]) * np.sin(x)

    y = f(1.)  # true branch
    expected = f_ref(1.)
    self.assertAllClose(y, expected, check_dtypes=False)
    y = f(4.)  # false branch
    expected = f_ref(4.)
    self.assertAllClose(y, expected, check_dtypes=False)
  @parameterized.named_parameters(
      {"testcase_name": f"_{name}", "cond": cond}
      for cond, name in COND_IMPLS)
  def testCondJitWithConsts(self, cond):
    """Jitted cond with branch-captured numpy constants matches eager."""
    def f(x):
      return cond(x < 2,
                  lambda x: np.array([1., 2.]) * x,
                  lambda x: np.array([3., 4.]) * jnp.sin(x),
                  x)

    y = jax.jit(f)(1.)  # true branch
    expected = f(1.)
    self.assertAllClose(y, expected, check_dtypes=False)
    y = jax.jit(f)(4.)  # false branch
    expected = f(4.)
    self.assertAllClose(y, expected, check_dtypes=False)
  @parameterized.named_parameters(
      {"testcase_name": f"_{name}", "cond": cond}
      for cond, name in COND_IMPLS)
  def testCondVmapGrad(self, cond):
    """vmap-of-grad through cond matches the jnp.where formulation.

    Regression test for https://github.com/jax-ml/jax/issues/2264.
    """
    def f_1(x): return x ** 2
    def f_2(x): return x ** 3

    def f(x): return cond(x > 0, f_1, f_2, x)
    def g(x): return jnp.where(x > 0, f_1(x), f_2(x))

    x = jnp.linspace(-1, 1, 20)  # spans both branches
    ans = jax.vmap(jax.grad(f))(x)
    expected = jax.vmap(jax.grad(g))(x)
    self.assertAllClose(ans, expected, check_dtypes=False)
  @jax.legacy_prng_key('allow')
  def testIssue1263(self):
    """vmap over a fori_loop whose body cond-switches on a random predicate."""
    def f(rng, x):
      cond = random.bernoulli(rng)
      # Old-style cond signature with separate per-branch operands.
      return lax.cond(cond, x, lambda x: x, jnp.abs(x) - 1., lambda x: x)

    def body_fn(i, state):
      rng, x = state
      key, subkey = random.split(rng)
      return key, f(subkey, x)

    def g(rng, x):
      return lax.fori_loop(0, 10, body_fn, (rng, x))

    # Just check this batched composition doesn't crash.
    jax.vmap(g)(random.split(random.PRNGKey(0), 3), jnp.ones((3, 4)))
def testIssue514(self):
|
|
|
|
|
# just check this doesn't crash
|
|
|
|
|
lax.cond(True,
|
|
|
|
|
(0, 0), lambda x: (x[0], 0),
|
|
|
|
|
(1, 1), lambda x: x)
|
2019-05-08 16:27:23 -07:00
|
|
|
|
|
2019-05-09 21:58:21 -07:00
|
|
|
|
def testIssue649(self):
|
|
|
|
|
from jax import lax
|
|
|
|
|
|
|
|
|
|
def body(x):
|
|
|
|
|
a, b = x
|
|
|
|
|
return (7, b + 1)
|
|
|
|
|
|
|
|
|
|
def cond(x):
|
|
|
|
|
a, b = x
|
|
|
|
|
return b < 10
|
|
|
|
|
|
|
|
|
|
out = lax.while_loop(cond, body, (33, 4))
|
|
|
|
|
self.assertEqual(out, (7, 10))
|
|
|
|
|
|
2019-05-10 10:18:37 -07:00
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}",
       "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl,
       "impl_name": scan_name}
      for jit_scan in [False, True]
      for jit_f in [False, True]
      for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
  def testScanImpl(self, jit_scan, jit_f, scan, impl_name):
    """Each scan implementation matches the pure-Python scan_reference."""
    rng = self.rng()

    d = rng.randn(2)  # closed-over constant, to exercise capture in the body
    def f(c, a):
      assert a.shape == (3,)
      assert c.shape == (4,)
      b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
      c = jnp.sin(c * b)
      assert b.shape == ()
      return c, b

    if jit_f:
      f = jax.jit(f)
    if jit_scan:
      scan = jax.jit(scan, static_argnums=(0,))

    as_ = rng.randn(5, 3)
    c = rng.randn(4)

    ans = scan(f, c, as_)
    expected = scan_reference(f, c, as_)
    rtol = {np.float64: 1.4e-15}
    atol = {np.float64: 8e-15}
    if impl_name == "for":
      # Looser float32 tolerances for the for-loop-based implementation.
      rtol[np.float32] = 8e-5
      atol[np.float32] = 3e-5
    self.assertAllClose(
        ans,
        expected,
        check_dtypes=False,
        rtol=rtol,
        atol=atol)
  @parameterized.named_parameters(
      {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}",
       "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
      for jit_scan in [False, True]
      for jit_f in [False, True]
      for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
  def testScanJVP(self, jit_scan, jit_f, scan):
    """jvp through each scan implementation matches jvp through scan_reference."""
    rng = self.rng()

    d = rng.randn(2)  # closed-over constant
    def f(c, a):
      assert a.shape == (3,)
      assert c.shape == (4,)
      b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
      c = jnp.sin(c * b)
      assert b.shape == ()
      return c, b

    if jit_f:
      f = jax.jit(f)
    if jit_scan:
      scan = jax.jit(scan, static_argnums=(0,))

    as_ = rng.randn(5, 3)
    c = rng.randn(4)

    ans = jax.jvp( lambda c, as_: scan(f, c, as_), (c, as_), (c, as_))
    expected = jax.jvp(lambda c, as_: scan_reference(f, c, as_), (c, as_), (c, as_))
    tol = {np.float64: 1e-12, np.float32: 1e-4}
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=tol, atol=tol)

    # Second-order forward-mode numerical check.
    jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"],
                    rtol={jnp.float32: 2e-1})
  @parameterized.named_parameters(
      {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}",
       "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
      for jit_scan in [False, True]
      for jit_f in [False, True]
      for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
  def testScanLinearize(self, jit_scan, jit_f, scan):
    """linearize through each scan implementation matches scan_reference."""
    rng = self.rng()

    d = rng.randn(2)  # closed-over constant
    def f(c, a):
      assert a.shape == (3,)
      assert c.shape == (4,)
      b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
      c = jnp.sin(c * b)
      assert b.shape == ()
      return c, b

    if jit_f:
      f = jax.jit(f)
    if jit_scan:
      scan = jax.jit(scan, static_argnums=(0,))

    as_ = rng.randn(5, 3)
    c = rng.randn(4)

    # Per-implementation float64 tolerance; float32 tolerance is shared.
    if scan is scan_with_new_checkpoint2:
      rtol = {np.float64: 1e-12, np.float32: 1e-4}
    elif scan is scan_with_for:
      rtol = {np.float64: 1e-12, np.float32: 1e-4}
    else:
      rtol = {np.float64: 1e-14, np.float32: 1e-4}

    ans = jax.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_)
    expected = jax.linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol)
  @parameterized.named_parameters(
      {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}",
       "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
      for jit_scan in [False, True]
      for jit_f in [False, True]
      for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testScanGrad(self, jit_scan, jit_f, scan):
    """grad through each scan implementation matches scan_reference."""
    rng = self.rng()

    d = rng.randn(2)  # closed-over constant
    def f(c, a):
      assert a.shape == (3,)
      assert c.shape == (4,)
      b = jnp.sum(jnp.sin(a)) + jnp.sum(jnp.sin(c)) + jnp.sum(jnp.sin(d))
      c = jnp.sin(c * b)
      assert b.shape == ()
      return c, b

    # Tolerances tuned per scan implementation.
    if scan is scan_with_new_checkpoint:
      rtol = {np.float32: 5e-5, np.float64: 1e-13}
      atol = 1e-5
    elif scan is scan_with_for:
      rtol = {np.float32: 2e-5, np.float64: 1e-13}
      atol = {np.float32: 6e-2, np.float64: 1e-13}
    else:
      rtol = {np.float32: 2e-4, np.float64: 1e-13}
      atol = {np.float32: 8e-5, np.float64: 1e-13}

    if jit_f:
      f = jax.jit(f)
    if jit_scan:
      scan = jax.jit(scan, static_argnums=(0,))

    as_ = rng.randn(5, 3)
    c = rng.randn(4)

    ans = jax.grad(lambda c, as_: list( scan(f, c, as_))[0].sum())(c, as_)
    expected = jax.grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol, atol=atol)

    # Second-order reverse-mode numerical check with looser tolerances.
    rtol = 5e-3 if scan is not scan_with_new_checkpoint2 else 5e-2
    atol = 5e-2 if jtu.test_device_matches(["tpu"]) else 1e-3
    jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
                    atol=atol, rtol=rtol)
  @jtu.skip_on_devices("tpu")  # TPU lacks precision for this test.
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testScanRnn(self):
    """End-to-end RNN built on lax.scan: eval, jvp, linearize, grad, vmap."""
    r = self.rng()

    # Tiny network so the test stays fast.
    n_in = 4
    n_hid = 2
    n_out = 1
    length = 3

    W_trans = r.randn(n_hid, n_hid + n_in).astype(jnp.float_)
    W_out = r.randn(n_out, n_hid + n_in).astype(jnp.float_)
    params = W_trans, W_out

    inputs = r.randn(length, n_in).astype(jnp.float_)
    targets = r.randn(length, n_out).astype(jnp.float_)

    def step(params, state, input):
      # One RNN cell: tanh readout and tanh state transition.
      W_trans, W_out = params
      stacked = jnp.concatenate([state, input])
      output = jnp.tanh(jnp.dot(W_out, stacked))
      next_state = jnp.tanh(jnp.dot(W_trans, stacked))
      return next_state, output

    def rnn(params, inputs):
      init_state = jnp.zeros(n_hid)
      _, outputs = lax.scan(partial(step, params), init_state, inputs)
      return outputs

    @jax.jit
    def loss(params, inputs, targets):
      predictions = rnn(params, inputs)
      return jnp.sum((predictions - targets)**2)

    # evaluation doesn't crash
    loss(params, inputs, targets)

    # jvp evaluation doesn't crash
    jax.jvp(lambda params: loss(params, inputs, targets), (params,), (params,))

    # jvp numerical check passes
    jtu.check_grads(loss, (params, inputs, targets), order=2, modes=["fwd"],
                    rtol={np.float32: 2e-2, np.float64: 1e-6})

    # linearize works
    _, expected = jax.jvp(loss, (params, inputs, targets),
                          (params, inputs, targets))
    _, linfun = jax.linearize(loss, params, inputs, targets)
    ans = linfun(params, inputs, targets)
    self.assertAllClose(ans, expected, check_dtypes=False)

    # gradient evaluation doesn't crash
    jax.grad(loss)(params, inputs, targets)

    # gradient check passes
    jtu.check_grads(loss, (params, inputs, targets), order=2, rtol=2e-2)

    # we can vmap to batch things
    batch_size = 7
    batched_inputs = r.randn(batch_size, length, n_in).astype(jnp.float_)
    batched_targets = r.randn(batch_size, length, n_out).astype(jnp.float_)
    batched_loss = jax.vmap(lambda x, y: loss(params, x, y))
    losses = batched_loss(batched_inputs, batched_targets)
    expected = np.stack(list(map(lambda x, y: loss(params, x, y),
                                 batched_inputs, batched_targets)))
    self.assertAllClose(losses, expected, check_dtypes=False, rtol=1e-2)
  @parameterized.named_parameters(
      {"testcase_name": f"_impl={scan_name}", "scan": scan_impl}
      for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
  def testIssue711(self, scan):
    # Tests reverse-mode differentiation through a scan for which the scanned
    # function also involves reverse-mode differentiation.
    # See https://github.com/jax-ml/jax/issues/711
    def harmonic_bond(conf, params):
      return jnp.sum(conf * params)

    def minimize_structure(test_params):
      energy_fn = partial(harmonic_bond, params=test_params)

      def apply_carry(carry, _):
        i, x = carry
        # One gradient-descent step on the energy (inner reverse-mode AD).
        new_x = x - 0.1 * jax.grad(energy_fn)(x)
        new_carry = (i+1, new_x)
        return new_carry, _

      x0 = jnp.array([1., 2., 3.])
      # 75 scan iterations with an empty per-step input.
      carry_final, _ = scan(apply_carry, (0, x0), jnp.zeros((75, 0)))
      _, x_final = carry_final
      return x_final

    initial_params = 0.5
    minimize_structure(initial_params)  # doesn't crash

    def loss(test_params):
      x_final = minimize_structure(test_params)
      return jnp.sum(jnp.sin(1.0 - x_final))

    jax.grad(loss)(0.25)  # doesn't crash
def testIssue744(self):
|
|
|
|
|
Point = collections.namedtuple('Point', ['x', 'y'])
|
2020-05-05 14:59:16 -04:00
|
|
|
|
p0 = Point(x=jnp.array(1), y=jnp.array(2))
|
2019-05-21 07:08:13 -07:00
|
|
|
|
|
|
|
|
|
def plus_one(p, iter_idx):
|
|
|
|
|
return Point(p.x+1, p.y+1), iter_idx
|
|
|
|
|
|
2019-10-24 10:10:04 +02:00
|
|
|
|
self.assertRaisesRegex(
|
2019-05-21 07:08:13 -07:00
|
|
|
|
ValueError,
|
2019-05-21 11:18:44 -07:00
|
|
|
|
'scan got value with no leading axis to scan over.*',
|
2019-05-21 07:08:13 -07:00
|
|
|
|
lambda: lax.scan(plus_one, p0, list(range(5))))
|
|
|
|
|
|
2023-03-23 14:39:40 -07:00
|
|
|
|
def testScanBodyOutputError(self):
|
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
|
TypeError,
|
2020-10-08 15:21:57 -07:00
|
|
|
|
re.escape("scan body output must be a pair, got ShapedArray(float32[]).")):
|
2023-03-23 14:39:40 -07:00
|
|
|
|
lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.))
|
|
|
|
|
|
2024-12-16 09:26:01 -08:00
|
|
|
|
def testScanMetadataError(self):
|
|
|
|
|
# Regression test for https://github.com/jax-ml/jax/issues/25507
|
|
|
|
|
def f(loop_i, x):
|
|
|
|
|
return {'T': jnp.array([0.5])}
|
|
|
|
|
|
|
|
|
|
init_val = {'t': jnp.array([1.0])}
|
|
|
|
|
msg = r".*with pytree metadata \('t',\).*with pytree metadata \('T',\)"
|
|
|
|
|
with self.assertRaisesRegex(TypeError, msg):
|
|
|
|
|
jax.lax.fori_loop(0, 1, f, init_val)
|
|
|
|
|
|
2023-03-23 14:39:40 -07:00
|
|
|
|
  def testScanBodyCarryPytreeMismatchErrors(self):
    # Each case below provokes a different carry-in/carry-out pytree
    # structure mismatch and checks the exact error text.

    # Tuple length mismatch at the top level (carry named `c`).
    with self.assertRaisesRegex(
        TypeError,
        re.escape("function carry input and carry output must have "
                  "the same pytree structure, but they differ:\n\n"
                  "The input carry c is a tuple of length 2")):
      lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), jnp.arange(5.))

    # Tuple length mismatch: 2-tuple in, 1-tuple out.
    with self.assertRaisesRegex(
        TypeError,
        re.escape("function carry input and carry output must have the "
                  "same pytree structure, but they differ:\n\n"
                  "The input carry x is a tuple of length 2")):
      lax.scan(lambda x, _: ((x[0].astype('float32'),), None),
               (jnp.array(0, 'int32'),) * 2, None, length=1)

    # Node-type mismatch: tuple in, list out.
    with self.assertRaisesRegex(
        TypeError,
        re.escape("function carry input and carry output must have the "
                  "same pytree structure, but they differ:\n\n"
                  "The input carry x is a <class 'tuple'> but the corres")):
      jax.lax.scan(lambda x, _: ([x[0].astype('float32'),] * 2, None),
                   (jnp.array(0, 'int32'),) * 2, None, length=1)

    # Dict child-count mismatch at the top level.
    with self.assertRaisesRegex(
        TypeError,
        re.escape("function carry input and carry output must have the "
                  "same pytree structure, but they differ:\n\n"
                  "The input carry x is a <class 'dict'> with 1 child but")):
      jax.lax.scan(lambda x, _: ({'a': x['a'], 'b': x['a']}, None),
                   {'a': jnp.array(0, 'int32')}, None, length=1)

    # Nested mismatch: the error must point at component x[0] and use the
    # bulleted (" * ") multi-difference format.
    with self.assertRaisesRegex(
        TypeError,
        re.escape("function carry input and carry output must have the "
                  "same pytree structure, but they differ:\n\n"
                  " * the input carry component x[0] is a <class 'dict'> with "
                  "1 child but the corresponding component of the carry "
                  "output is a <class 'dict'> with 2 children")):
      jax.lax.scan(lambda x, _: (({'a': x[0]['a'], 'b': x[0]['a']},) * 2, None),
                   ({'a': jnp.array(0, 'int32')},) * 2, None, length=1)
|
|
|
|
|
|
|
|
|
|
  def testScanBodyCarryTypeMismatchErrors(self):
    # Same pytree structure but differing avals (dtype/shape): each case
    # checks the exact wording of the type-mismatch error.

    # Scalar carry, dtype changes int32 -> float32.
    with self.assertRaisesRegex(
        TypeError,
        re.escape("function carry input and carry output must have equal "
                  "types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
                  "The input carry x has type int32[] but the corresponding "
                  "output carry component has type float32[], so the dtypes do "
                  "not match"
                  )):
      jax.lax.scan(lambda x, _: (x.astype('float32'), None),
                   jnp.array(0, 'int32'), None, length=1)

    # Only the second tuple component changes dtype; the error must name x[1].
    with self.assertRaisesRegex(
        TypeError,
        re.escape("function carry input and carry output must have equal "
                  "types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
                  "The input carry component x[1] has type int32[] but the "
                  "corresponding output carry component has type float32[], "
                  "so the dtypes do not match"
                  )):
      jax.lax.scan(lambda x, _: ((x[0], x[1].astype('float32')), None),
                   (jnp.array(0, 'int32'),) * 2, None, length=1)

    # Two simultaneous mismatches (dtype only, and dtype + shape) must be
    # reported as a bulleted list, with the unchanged x[2] omitted.
    with self.assertRaisesRegex(
        TypeError,
        re.escape("function carry input and carry output must have equal "
                  "types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
                  " * the input carry component x[0] has type int32[] but the "
                  "corresponding output carry component has type float32[], "
                  "so the dtypes do not match;\n"
                  " * the input carry component x[1] has type int32[] but the "
                  "corresponding output carry component has type float32[1,1], "
                  "so the dtypes do not match and also the shapes do not match."
                  )):
      jax.lax.scan(lambda x, _: ((x[0].astype('float32'),
                                  x[1].astype('float32').reshape(1, 1),
                                  x[2]), None),
                   (jnp.array(0, 'int32'),) * 3, None, length=1)
|
2019-10-24 10:10:04 +02:00
|
|
|
|
|
2024-05-15 15:01:01 -04:00
|
|
|
|
@jax.enable_checks(False)
|
|
|
|
|
def testScanInvalidUnrollRaises(self):
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "`unroll` must be"):
|
|
|
|
|
jax.lax.scan(lambda x, _: (x, x), 0, jnp.arange(5), unroll=-1)
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "`unroll` must be"):
|
|
|
|
|
jax.lax.scan(lambda x, _: (x, x), 0, jnp.arange(5), unroll=0)
|
|
|
|
|
|
2020-07-15 11:00:50 -07:00
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"_{scan_name}",
       "scan": scan_impl}
      for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
  def testScanHigherOrderDifferentiation(self, scan):
    # Checks second-order reverse-mode gradients through a scan whose body
    # mixes the carry and the scanned inputs nonlinearly.
    d = 0.75
    def f(c, a):
      b = jnp.sin(c * jnp.sum(jnp.cos(d * a)))
      c = 0.9 * jnp.cos(d * jnp.sum(jnp.sin(c * a)))
      return c, b

    as_ = jnp.arange(6.).reshape((3, 2))
    # Match the carry dtype to as_ so the scan carry type is consistent.
    c = jnp.array(1, dtype=as_.dtype)

    jtu.check_grads(lambda c, as_: scan(f, c, as_), (c, as_),
                    modes=["rev"], order=2, rtol={np.float32: 6e-3})
|
2019-05-21 13:54:37 -07:00
|
|
|
|
|
2019-05-16 10:18:53 -07:00
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"_{jit_scan=}_{jit_f=}_{in_axes=}_impl={scan_name}",
       "jit_scan": jit_scan, "jit_f": jit_f, "in_axes": in_axes,
       "scan": scan_impl}
      for jit_scan in [False, True]
      for jit_f in [False, True]
      for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR
      for in_axes in itertools.product([None, 0, 1], [None, 0, 1, 2])
      # (None, None) would mean nothing is batched, so skip it.
      if in_axes != (None, None))
  def testScanVmap(self, jit_scan, jit_f, in_axes, scan):
    # Compares vmap-of-scan against vmap of a Python reference scan for
    # every combination of batched carry axis and batched xs axis.
    rng = self.rng()

    d = rng.randn(2)
    def f(c, a):
      # Shape asserts verify vmap presents unbatched per-element views.
      assert a.shape == (3,)
      assert c.shape == (4,)
      b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
      c = jnp.sin(c * b)
      assert b.shape == ()
      return c, b

    if jit_f:
      f = jax.jit(f)
    if jit_scan:
      scan = jax.jit(scan, static_argnums=(0,))

    as_shape = [5, 3]
    c_shape = [4]

    # Insert the batch dimension (size 7) at the requested axis positions.
    c_bdim, as_bdim = in_axes
    if c_bdim is not None:
      c_shape.insert(c_bdim, 7)
    if as_bdim is not None:
      as_shape.insert(as_bdim, 7)

    as_ = rng.randn(*as_shape)
    c = rng.randn(*c_shape)

    ans = jax.vmap(lambda c, as_: scan(f, c, as_), in_axes)(c, as_)
    expected = jax.vmap(lambda c, as_: scan_reference(f, c, as_), in_axes)(c, as_)
    self.assertAllClose(ans, expected, check_dtypes=False,
                        rtol=1e-5, atol=1e-5)
|
2019-05-16 10:18:53 -07:00
|
|
|
|
|
|
|
|
|
  def testScanVmapTuples(self):
    # vmap over a scan with tuple-valued carry and xs, with different batch
    # axes per leaf; compared against an explicit per-batch-element loop.
    def f(c, a):
      a1, a2 = a
      c1, c2 = c
      b = jnp.sum(jnp.cos(a1)) * jnp.sum(c2 * a2)
      c = c1 * jnp.sin(jnp.sum(a1 * a2)), c2 * jnp.cos(jnp.sum(a1))
      return c, b

    # Carry leaves batched along axis 0; xs leaves along axes 1 and 2.
    in_axes = (0, (1, 2))

    r = self.rng()
    as_ = (r.randn(3, 7), r.randn(3, 4, 7))
    c = (r.randn(7, 2), r.randn(7))

    # Build the expected result by running the scan once per batch element
    # (batch size 7) and stacking.
    expected_c_out, expected_bs = [], []
    for i in range(7):
      c_out, bs = lax.scan(f, (c[0][i], c[1][i]), (as_[0][:,i], as_[1][:,:,i]))
      expected_c_out.append(c_out)
      expected_bs.append(bs)
    expected_c_out_0, expected_c_out_1 = unzip2(expected_c_out)
    expected_c_out = (jnp.stack(expected_c_out_0), jnp.stack(expected_c_out_1))
    expected_bs = jnp.stack(expected_bs)
    expected = expected_c_out, expected_bs

    ans = jax.vmap(lambda c, as_: lax.scan(f, c, as_), in_axes)(c, as_)
    self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
2022-04-30 21:50:18 -07:00
|
|
|
|
@parameterized.named_parameters(
|
2022-12-01 09:12:01 -08:00
|
|
|
|
{"testcase_name": f"_impl={scan_name}", "scan": scan_impl}
|
2022-08-16 16:44:50 -07:00
|
|
|
|
for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
|
2022-04-30 21:50:18 -07:00
|
|
|
|
def testScanVmapFixpoint(self, scan):
|
2019-11-05 15:00:21 -08:00
|
|
|
|
def f(carry_init):
|
|
|
|
|
def scan_body(c, x):
|
|
|
|
|
# The carry is a 4-tuple, the last element starts batched,
|
|
|
|
|
# and the carry is shifted left at each iteration.
|
|
|
|
|
return ((c[1], c[2], c[3], 0.), None)
|
2022-04-30 21:50:18 -07:00
|
|
|
|
return scan(scan_body, (0., 1., 2., carry_init), jnp.zeros(2))
|
2020-05-05 14:59:16 -04:00
|
|
|
|
carry_init = jnp.array([3., 4., 5.])
|
2021-09-13 16:00:22 -04:00
|
|
|
|
carry_out, _ = jax.vmap(f)(carry_init)
|
2020-05-05 14:59:16 -04:00
|
|
|
|
self.assertAllClose(carry_out[3], jnp.array([0., 0., 0.]), check_dtypes=False)
|
|
|
|
|
self.assertAllClose(carry_out[2], jnp.array([0., 0., 0.]), check_dtypes = False)
|
2019-11-05 15:00:21 -08:00
|
|
|
|
# After two shifts, we get the carry_init
|
|
|
|
|
self.assertAllClose(carry_out[1], carry_init, check_dtypes=False)
|
2020-05-05 14:59:16 -04:00
|
|
|
|
self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes = False)
|
2019-11-05 15:00:21 -08:00
|
|
|
|
|
2019-05-22 13:33:51 -07:00
|
|
|
|
def testIssue757(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# code from https://github.com/jax-ml/jax/issues/757
|
2019-05-22 13:33:51 -07:00
|
|
|
|
def fn(a):
|
2021-11-24 10:02:35 -08:00
|
|
|
|
return jnp.cos(a)
|
2019-05-22 13:33:51 -07:00
|
|
|
|
|
|
|
|
|
def loop(val):
|
2021-11-24 10:02:35 -08:00
|
|
|
|
iterations = 10
|
|
|
|
|
|
|
|
|
|
def apply_carry(x, i):
|
|
|
|
|
return jax.grad(fn, argnums=(0,))(x)[0], i
|
2019-05-22 13:33:51 -07:00
|
|
|
|
|
2021-11-24 10:02:35 -08:00
|
|
|
|
final_val, _ = lax.scan(apply_carry, val, jnp.arange(iterations))
|
|
|
|
|
return final_val
|
2019-05-22 13:33:51 -07:00
|
|
|
|
|
|
|
|
|
arg = 0.5
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash
|
2019-05-22 13:33:51 -07:00
|
|
|
|
|
enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is sued for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes #804 also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-02 13:17:31 -07:00
|
|
|
|
def testIssue804(self):
|
2025-03-04 18:30:34 -08:00
|
|
|
|
# https://github.com/jax-ml/jax/issues/804
|
2021-09-23 06:33:25 -07:00
|
|
|
|
num_devices = jax.device_count()
|
enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is sued for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes #804 also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-02 13:17:31 -07:00
|
|
|
|
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash
|
enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is sued for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes #804 also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-02 13:17:31 -07:00
|
|
|
|
|
2019-08-05 12:13:07 -07:00
|
|
|
|
def testMap(self):
|
|
|
|
|
f = lambda x: x ** 2
|
2020-05-05 14:59:16 -04:00
|
|
|
|
xs = jnp.arange(10)
|
2019-08-05 12:13:07 -07:00
|
|
|
|
expected = xs ** 2
|
|
|
|
|
actual = lax.map(f, xs)
|
2020-06-01 17:19:23 -04:00
|
|
|
|
self.assertAllClose(actual, expected)
|
2019-08-05 12:13:07 -07:00
|
|
|
|
|
2020-03-13 15:15:55 -07:00
|
|
|
|
def testMapEmpty(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# https://github.com/jax-ml/jax/issues/2412
|
2020-05-05 14:59:16 -04:00
|
|
|
|
ans = lax.map(lambda x: x * x, jnp.array([]))
|
|
|
|
|
expected = jnp.array([])
|
2020-06-01 17:19:23 -04:00
|
|
|
|
self.assertAllClose(ans, expected)
|
2020-03-13 15:15:55 -07:00
|
|
|
|
|
2025-01-10 06:58:01 -08:00
|
|
|
|
  @jtu.thread_unsafe_test()  # Cache eviction means we might retrace
  def testCaching(self):
    # Verifies that the second while_loop call with the same function
    # objects hits the trace cache: the asserts inside cond/body read the
    # enclosing `python_should_be_executing` flag, which is False on the
    # second call, so any retrace would fail the assertion.
    def cond(x):
      assert python_should_be_executing
      return x < 5

    def body(x):
      assert python_should_be_executing
      return x + 2

    python_should_be_executing = True
    lax.while_loop(cond, body, 0)

    # Cache hit expected: cond/body must not run as Python again.
    python_should_be_executing = False
    lax.while_loop(cond, body, 0)
|
|
|
|
|
|
2021-10-04 17:54:18 -07:00
|
|
|
|
  # This second caching test shows a different kind of caching that we haven't
  # implemented (but could!), namely that Python functions that are distinct
  # objects but are equivalent functions trigger cache hits. This kind of
  # caching could be salient when using lambda functions with control flow:
  #
  #   lax.while_loop(lambda x: x < 5, lambda x: x + 2, 0)
  #   lax.while_loop(lambda x: x < 5, lambda x: x + 2, 0)
  #
  # To get a cache hit on the second line we'd need to form a jaxpr and
  # compare them for equality (including the literals on identity). We could
  # implement that by adding a __hash__/__eq__ to core.Jaxpr and
  # core.ClosedJaxpr (see #1221).
  @unittest.skip("not implemented")
  def testCaching2(self):
    # First pair of cond/body: traced normally.
    def cond(x):
      assert python_should_be_executing
      return x < 5

    def body(x):
      assert python_should_be_executing
      return x + 2

    python_should_be_executing = True
    lax.while_loop(cond, body, 0)

    # Redefine equivalent-but-distinct cond/body objects; the (unimplemented)
    # jaxpr-equality cache would have to hit here without re-executing Python.
    def cond(x):
      assert python_should_be_executing
      return x < 5

    def body(x):
      assert python_should_be_executing
      return x + 2

    python_should_be_executing = False
    lax.while_loop(cond, body, 0)
|
|
|
|
|
|
2022-06-29 11:32:52 -07:00
|
|
|
|
def test_caches_depend_on_axis_env(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# https://github.com/jax-ml/jax/issues/9187
|
2022-06-29 11:32:52 -07:00
|
|
|
|
scanned_f = lambda _, __: (lax.psum(1, 'i'), None)
|
|
|
|
|
f = lambda: lax.scan(scanned_f, 0, None, length=1)[0]
|
|
|
|
|
ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
|
|
|
|
|
self.assertEqual(ans, 2)
|
|
|
|
|
ans = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
|
|
|
|
|
self.assertEqual(ans, 3)
|
|
|
|
|
|
2019-08-23 16:40:49 -07:00
|
|
|
|
def testWhileCondConstant(self):
|
|
|
|
|
out = lax.while_loop(lambda _: False, lambda _: (), ()) # doesn't crash
|
|
|
|
|
self.assertEqual(out, ())
|
|
|
|
|
|
2020-05-06 20:05:47 +03:00
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"_{jit_loop=}_{jit_body=}_{jit_cond=}",
       "jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond}
      for jit_loop in [False, True]
      for jit_body in [False, True]
      for jit_cond in [False, True])
  def testWhileJVP(self, jit_loop=True, jit_body=False, jit_cond=True):
    # Forward-mode AD through while_loop, compared against a Python
    # reference implementation, under all jit placements.
    cond = lambda x: x[0, 2] <= 8
    body = lambda x: x * x

    if jit_cond:
      cond = jax.jit(cond)
    if jit_body:
      body = jax.jit(body)

    loop = partial(lax.while_loop, cond, body)
    if jit_loop:
      loop = jax.jit(loop)

    loop_ref = partial(while_loop_reference, cond, body)

    x = jnp.arange(9.).reshape((3, 3))
    ans = jax.jvp(loop, (x,), (x,))
    expected = jax.jvp(loop_ref, (x,), (x,))
    self.assertAllClose(ans, expected, check_dtypes=False)

    # Second-order forward-mode gradients against numerical differences.
    jtu.check_grads(loop, (x,), order=2, modes=["fwd"])
|
|
|
|
|
|
2022-04-23 15:32:30 -07:00
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"_{jit_loop=}_{jit_body=}_{jit_cond=}_impl={while_name}",
       "jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond,
       "while_loop": while_impl}
      for jit_loop in [False, True]
      for jit_body in [False, True]
      for jit_cond in [False, True]
      for while_impl, while_name in WHILE_LOOP_IMPLS)
  def testWhileLinearize(self, while_loop, jit_loop=True, jit_body=False,
                         jit_cond=True):
    # jax.linearize through while_loop, compared against jvp of a Python
    # reference loop, for every while_loop implementation and jit placement.
    cond = lambda x: x[0, 2] <= 8
    body = lambda x: x * x

    if jit_cond:
      cond = jax.jit(cond)
    if jit_body:
      body = jax.jit(body)

    loop = partial(while_loop, cond, body)
    if jit_loop:
      loop = jax.jit(loop)

    loop_ref = partial(while_loop_reference, cond, body)

    x = jnp.arange(9.).reshape((3, 3))
    # Primal and linearized-tangent outputs must both match the reference.
    y, f_lin = jax.linearize(loop, x)
    ydot = f_lin(x)
    y_expected, ydot_expected = jax.jvp(loop_ref, (x,), (x,))
    self.assertAllClose(y, y_expected, check_dtypes=False)
    self.assertAllClose(ydot, ydot_expected, check_dtypes=False)
|
|
|
|
|
|
2020-01-15 18:06:31 -08:00
|
|
|
|
def testWhileJVPViaForiLoop(self):
|
|
|
|
|
f = lambda x: lax.fori_loop(0, 3, lambda i, x: x * 2, x)
|
|
|
|
|
self.assertAllClose(f(2.), 16., check_dtypes=False)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
self.assertAllClose(jax.jvp(f, (2.,), (1.,)), (16., 8.), check_dtypes=False)
|
2020-01-15 18:06:31 -08:00
|
|
|
|
jtu.check_grads(f, (2.,), order=2, modes=["fwd"])
|
|
|
|
|
|
|
|
|
|
f = lambda x: lax.fori_loop(0, 3, lambda i, x: x * (i + 1), x)
|
|
|
|
|
self.assertAllClose(f(2.), 12., check_dtypes=False)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
self.assertAllClose(jax.jvp(f, (2.,), (1.,)), (12., 6.), check_dtypes=False)
|
2020-01-15 18:06:31 -08:00
|
|
|
|
jtu.check_grads(f, (2.,), order=2, modes=["fwd"])
|
|
|
|
|
|
|
|
|
|
  def testWhileJVPWithGrowingNonzeroTangents(self):
    # The set of nonzero carry tangents grows across iterations (x seeds y,
    # then y seeds z); the jvp fixpoint must handle this, matching a Python
    # reference loop.
    rng = self.rng()

    def cond(state):
      i, x, y, z = state
      return i < 2

    def body(state):
      i, x, y, z = state
      y = x * x
      z = y * y
      return i + 1, x, y, z

    y, z = rng.randn(2), rng.randn(2)
    def loop(loop_impl, x):
      # Returns the (unchanged) x component of the final carry.
      return loop_impl(cond, body, (0, x, y, z))[1]

    loop_lax = partial(loop, lax.while_loop)
    loop_ref = partial(loop, while_loop_reference)

    x = rng.randn(2)
    ans = jax.jvp(loop_lax, (x,), (x,))
    expected = jax.jvp(loop_ref, (x,), (x,))
    self.assertAllClose(ans, expected, check_dtypes=False)

    jtu.check_grads(loop_lax, (x,), order=2, modes=["fwd"])
|
|
|
|
|
|
2021-04-23 08:31:11 -07:00
|
|
|
|
def testStaticForiGrad(self):
|
|
|
|
|
func = lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jax.grad(func)(1.) # doesn't crash
|
|
|
|
|
jax.linearize(func, 1.) # doesn't crash
|
2021-04-23 08:31:11 -07:00
|
|
|
|
|
2020-04-17 20:08:24 +03:00
|
|
|
|
@parameterized.named_parameters(
|
2022-12-01 09:12:01 -08:00
|
|
|
|
dict(testcase_name=f"_{loop=}", loop=loop)
|
2021-04-23 08:31:11 -07:00
|
|
|
|
for loop in ["while", "fori_inside_cond", "fori_inside_scan"])
|
2020-04-17 20:08:24 +03:00
|
|
|
|
def testWhileGradError(self, loop: str = "fori_inside_scan"):
|
|
|
|
|
# Raise error for vjp for loops
|
|
|
|
|
if loop == "while":
|
|
|
|
|
func = lambda x: lax.while_loop(lambda i: i < 5., lambda i: i + 1., x)
|
|
|
|
|
elif loop == "fori_inside_jit":
|
2021-09-13 16:00:22 -04:00
|
|
|
|
func = jax.jit(lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x))
|
2020-04-17 20:08:24 +03:00
|
|
|
|
elif loop == "fori_inside_cond":
|
2021-04-23 08:31:11 -07:00
|
|
|
|
func = lambda x: lax.cond(
|
|
|
|
|
True,
|
|
|
|
|
x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x),
|
|
|
|
|
1., lambda x: x)
|
2020-04-17 20:08:24 +03:00
|
|
|
|
elif loop == "fori_inside_scan":
|
2021-04-23 08:31:11 -07:00
|
|
|
|
func = lambda x: lax.scan(
|
|
|
|
|
lambda c, x: (lax.fori_loop(x, x + 2., lambda i, c1: c1 * c, x), None),
|
|
|
|
|
x, np.ones(2))[0]
|
2020-04-17 20:08:24 +03:00
|
|
|
|
else:
|
|
|
|
|
assert False
|
|
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "Reverse-mode differentiation does not work for lax.while_loop"):
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jax.grad(func)(1.)
|
2020-04-17 20:08:24 +03:00
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jax.linearize(func, 1.) # Linearization works
|
2020-04-17 20:08:24 +03:00
|
|
|
|
|
2023-08-25 14:11:19 -07:00
|
|
|
|
@jax.legacy_prng_key('allow')
|
2019-09-09 18:25:16 -07:00
|
|
|
|
def testIssue1316(self):
|
|
|
|
|
def f(carry, _):
|
|
|
|
|
c, key = carry
|
|
|
|
|
key, _ = random.split(key)
|
|
|
|
|
return (c, key), ()
|
|
|
|
|
|
|
|
|
|
key = random.PRNGKey(0)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jax.grad(lambda c: lax.scan(f, (c, key), np.ones(3))[0][0])(0.) # doesn't crash
|
2019-09-09 18:25:16 -07:00
|
|
|
|
|
2019-09-20 15:42:05 -07:00
|
|
|
|
def testIssue1361(self):
|
2021-09-13 16:00:22 -04:00
|
|
|
|
@jax.jit
|
2019-09-20 15:42:05 -07:00
|
|
|
|
def jit_run_scan(x):
|
|
|
|
|
def fun(carry, _):
|
|
|
|
|
x, _ = carry
|
|
|
|
|
return (2 * x, 0.), None
|
2020-05-05 14:59:16 -04:00
|
|
|
|
(x, _), _ = lax.scan(fun, (x, 0.), jnp.arange(3))
|
2019-09-20 15:42:05 -07:00
|
|
|
|
return x
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jax.grad(lambda x: jit_run_scan(x))(0.) # doesn't crash
|
2019-09-20 15:42:05 -07:00
|
|
|
|
|
2019-11-11 07:45:31 -08:00
|
|
|
|
def testIssue810(self):
|
|
|
|
|
def loss(A):
|
|
|
|
|
def step(x, i):
|
2020-05-05 14:59:16 -04:00
|
|
|
|
return jnp.matmul(A, x), None
|
|
|
|
|
init_x = jnp.zeros(A.shape[-1:])
|
|
|
|
|
last_x, _ = lax.scan(step, init_x, jnp.arange(10))
|
|
|
|
|
return jnp.sum(last_x)
|
2019-11-11 07:45:31 -08:00
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
|
A = jnp.zeros((3, 3))
|
2019-11-11 07:45:31 -08:00
|
|
|
|
# The second DUS was unnecessarily replicating A across time.
|
|
|
|
|
# We check XLA because _scan_impl is "underneath" the jaxpr language.
|
2024-06-18 15:46:41 -07:00
|
|
|
|
s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo')
|
2019-11-11 07:45:31 -08:00
|
|
|
|
assert s.count("dynamic-update-slice(") < 2
|
|
|
|
|
|
2019-12-02 17:44:58 -08:00
|
|
|
|
def testScanLengthArg(self):
|
|
|
|
|
def arange(n):
|
|
|
|
|
return lax.scan(lambda c, _: (c + 1, c), 0, None, length=n)[1]
|
|
|
|
|
|
|
|
|
|
ans = arange(10)
|
2020-05-05 14:59:16 -04:00
|
|
|
|
expected = np.arange(10)
|
2019-12-02 17:44:58 -08:00
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
2021-10-13 10:56:21 -04:00
|
|
|
|
@ignore_jit_of_pmap_warning()
|
2020-03-30 17:48:07 -07:00
|
|
|
|
def test_while_loop_of_pmap(self):
|
2023-12-21 16:45:49 +08:00
|
|
|
|
# Avoid accuracy issue caused by too many devices.
|
|
|
|
|
DEVICE_LIMITATION = 4
|
|
|
|
|
devices = jax.devices()
|
|
|
|
|
count = jax.device_count()
|
|
|
|
|
if jax.device_count() >= DEVICE_LIMITATION:
|
|
|
|
|
devices = devices[:DEVICE_LIMITATION]
|
|
|
|
|
count = DEVICE_LIMITATION
|
2020-03-31 11:54:57 -07:00
|
|
|
|
|
2023-12-21 16:45:49 +08:00
|
|
|
|
# code from jsnoek@
|
2020-03-30 17:48:07 -07:00
|
|
|
|
def body(i, x):
|
2023-12-21 16:45:49 +08:00
|
|
|
|
result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), devices=devices, axis_name='i')(x)
|
2020-03-30 17:48:07 -07:00
|
|
|
|
return result + x
|
2020-06-02 19:25:47 -07:00
|
|
|
|
f_loop = lambda x: lax.fori_loop(0, 3, body, x) # noqa: F821
|
2023-12-21 16:45:49 +08:00
|
|
|
|
ans = f_loop(jnp.ones(count))
|
2020-03-30 17:48:07 -07:00
|
|
|
|
del body, f_loop
|
|
|
|
|
|
|
|
|
|
def body2(i, x):
|
2020-05-05 14:59:16 -04:00
|
|
|
|
result = jnp.broadcast_to(jnp.sin(x).sum(), x.shape)
|
2020-03-30 17:48:07 -07:00
|
|
|
|
return result + x
|
|
|
|
|
g_loop = lambda x: lax.fori_loop(0, 3, body2, x)
|
2023-12-21 16:45:49 +08:00
|
|
|
|
expected = g_loop(jnp.ones(count))
|
2020-03-30 17:48:07 -07:00
|
|
|
|
|
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
2021-10-13 10:56:21 -04:00
|
|
|
|
@ignore_jit_of_pmap_warning()
|
2020-03-31 13:07:59 -07:00
|
|
|
|
def test_while_loop_of_pmap_error_message(self):
|
|
|
|
|
|
|
|
|
|
def body(i, x):
|
2021-09-13 16:00:22 -04:00
|
|
|
|
result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), axis_name='i')(x)
|
2020-03-31 13:07:59 -07:00
|
|
|
|
return result + x
|
|
|
|
|
f_loop = lambda x: lax.fori_loop(0, 3, body, x)
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
too_big = 2 * jax.device_count()
|
2020-03-31 13:07:59 -07:00
|
|
|
|
|
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
|
ValueError,
|
|
|
|
|
re.escape(
|
2021-10-13 10:56:21 -04:00
|
|
|
|
"compiling computation `scan` that requires {} "
|
|
|
|
|
"replicas, but only {} XLA devices are available."
|
|
|
|
|
.format(too_big, jax.device_count())),
|
2020-05-05 14:59:16 -04:00
|
|
|
|
lambda: f_loop(jnp.ones(too_big)))
|
2020-03-31 13:07:59 -07:00
|
|
|
|
|
2020-07-15 11:00:50 -07:00
|
|
|
|
@parameterized.named_parameters(
|
2022-05-12 19:13:00 +01:00
|
|
|
|
{"testcase_name": f"_{scan_name}",
|
2020-07-15 11:00:50 -07:00
|
|
|
|
"scan": scan_impl}
|
2022-06-28 16:29:38 -07:00
|
|
|
|
for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
|
2020-07-15 11:00:50 -07:00
|
|
|
|
def test_scan_reverse(self, scan):
|
2020-05-04 19:44:22 -07:00
|
|
|
|
def cumsum(x, reverse):
|
2020-07-15 11:00:50 -07:00
|
|
|
|
return scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
|
2020-05-04 19:44:22 -07:00
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
|
x = np.array([3, 1, 4, 1, 5, 9])
|
|
|
|
|
self.assertAllClose(np.cumsum(x), cumsum(x, False), check_dtypes=False)
|
|
|
|
|
self.assertAllClose(np.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
|
2020-05-04 19:44:22 -07:00
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
with jax.disable_jit():
|
2020-05-05 14:59:16 -04:00
|
|
|
|
self.assertAllClose(np.cumsum(x), cumsum(x, False), check_dtypes=False)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
with jax.disable_jit():
|
2020-05-05 14:59:16 -04:00
|
|
|
|
self.assertAllClose(np.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
|
2020-05-04 19:44:22 -07:00
|
|
|
|
|
2020-07-15 11:00:50 -07:00
|
|
|
|
def test_scan_unroll(self):
|
|
|
|
|
d = jnp.ones(2)
|
|
|
|
|
def f(c, a):
|
|
|
|
|
assert a.shape == (3,)
|
|
|
|
|
assert c.shape == (4,)
|
2023-01-19 15:20:02 -08:00
|
|
|
|
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
|
2020-07-15 11:00:50 -07:00
|
|
|
|
c = jnp.sin(c * b)
|
|
|
|
|
assert b.shape == ()
|
|
|
|
|
return c, b
|
|
|
|
|
|
2025-03-11 13:44:37 -07:00
|
|
|
|
xs = jnp.ones((20, 3))
|
2020-07-15 11:00:50 -07:00
|
|
|
|
c = jnp.ones(4)
|
|
|
|
|
|
|
|
|
|
scan = lambda c, xs: lax.scan(f, c, xs)
|
|
|
|
|
scan_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=2)
|
2024-10-30 06:19:54 -07:00
|
|
|
|
scan_fully_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=True)
|
2020-07-15 11:00:50 -07:00
|
|
|
|
|
|
|
|
|
# jaxprs should be the same size
|
|
|
|
|
self.assertEqual(
|
2021-09-13 16:00:22 -04:00
|
|
|
|
len(str(jax.make_jaxpr(scan)(c, xs))),
|
|
|
|
|
len(str(jax.make_jaxpr(scan_unrolled)(c, xs))))
|
2020-07-15 11:00:50 -07:00
|
|
|
|
|
|
|
|
|
# but HLO should grow due to unrolling
|
2024-10-30 06:19:54 -07:00
|
|
|
|
scan_hlo = str(jax.jit(scan).lower(c, xs).as_text("hlo"))
|
|
|
|
|
scan_unrolled_hlo = str(jax.jit(scan_unrolled).lower(c, xs).as_text("hlo"))
|
|
|
|
|
scan_fully_unrolled_hlo = str(
|
|
|
|
|
jax.jit(scan_fully_unrolled).lower(c, xs).as_text("hlo"))
|
|
|
|
|
|
|
|
|
|
self.assertLess(len(scan_hlo), len(scan_unrolled_hlo))
|
|
|
|
|
self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo))
|
|
|
|
|
|
|
|
|
|
# and the lowering should contain a while loop, unless the scan is fully
|
|
|
|
|
# unrolled
|
|
|
|
|
self.assertIn("while(", scan_hlo)
|
|
|
|
|
self.assertIn("while(", scan_unrolled_hlo)
|
|
|
|
|
self.assertNotIn("while(", scan_fully_unrolled_hlo)
|
2020-07-15 11:00:50 -07:00
|
|
|
|
|
2024-04-23 17:26:23 -04:00
|
|
|
|
def test_scan_xs_none(self):
|
|
|
|
|
def f(h, _):
|
|
|
|
|
return h + 1, None
|
|
|
|
|
|
|
|
|
|
length = 20
|
|
|
|
|
h, _ = lax.scan(f, 0, length=length)
|
|
|
|
|
self.assertEqual(h, length)
|
|
|
|
|
|
2020-05-19 18:14:10 -07:00
|
|
|
|
def test_disable_jit_cond_with_vmap(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# https://github.com/jax-ml/jax/issues/3093
|
2020-05-19 18:14:10 -07:00
|
|
|
|
def fn(t):
|
|
|
|
|
return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
fn = jax.vmap(fn)
|
2020-05-19 18:14:10 -07:00
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
with jax.disable_jit():
|
2020-06-02 19:25:47 -07:00
|
|
|
|
_ = fn(jnp.array([1])) # doesn't crash
|
2020-05-19 18:14:10 -07:00
|
|
|
|
|
|
|
|
|
def test_disable_jit_while_loop_with_vmap(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# https://github.com/jax-ml/jax/issues/2823
|
2020-05-19 18:14:10 -07:00
|
|
|
|
def trivial_while(y):
|
|
|
|
|
return lax.while_loop(lambda x: x < 10.0, lambda x: x + 1.0, y)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
with jax.disable_jit():
|
|
|
|
|
jax.vmap(trivial_while)(jnp.array([3.0,4.0])) # doesn't crash
|
2020-05-19 18:14:10 -07:00
|
|
|
|
|
2020-05-20 20:21:41 -07:00
|
|
|
|
def test_vmaps_of_while_loop(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# https://github.com/jax-ml/jax/issues/3164
|
2020-05-20 20:21:41 -07:00
|
|
|
|
def f(x, n): return lax.fori_loop(0, n, lambda _, x: x + 1, x)
|
|
|
|
|
x, n = jnp.arange(3), jnp.arange(4)
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash
|
2020-05-20 20:21:41 -07:00
|
|
|
|
|
2025-03-11 09:20:11 -07:00
|
|
|
|
def test_disable_jit_while_loop_with_mutation(self):
|
|
|
|
|
# https://github.com/jax-ml/jax/issues/27019
|
|
|
|
|
|
|
|
|
|
def body_fun(carry):
|
|
|
|
|
x, y = carry
|
|
|
|
|
x += 1 # in-place if x is mutable
|
|
|
|
|
return x, y + x
|
|
|
|
|
|
|
|
|
|
def cond_fun(carry):
|
|
|
|
|
x, _ = carry
|
|
|
|
|
return x < 10
|
|
|
|
|
|
|
|
|
|
def f():
|
|
|
|
|
val = np.array(1.0) # mutable value
|
|
|
|
|
return jax.lax.while_loop(cond_fun, body_fun, (val, val))[1]
|
|
|
|
|
|
|
|
|
|
with jax.disable_jit(False):
|
|
|
|
|
result_jit = f()
|
|
|
|
|
with jax.disable_jit(True):
|
|
|
|
|
result_nojit = f()
|
|
|
|
|
self.assertEqual(result_jit, result_nojit)
|
|
|
|
|
|
Optimize lax.associative_scan, reimplement cumsum, etc. on top of associative_scan.
Add support for an axis= parameter to associative_scan.
We previously had two associative scan implementations, namely lax.associative_scan, and the implementations of cumsum, cumprod, etc.
lax.associative_scan was more efficient in some ways because unlike the cumsum implementation it did not pad the input array to the nearest power of two size. This appears to have been a significant cause of https://github.com/google/jax/issues/4135.
The cumsum/cummax implementation used slightly more efficient code to slice and
interleave arrays, which this change adds to associative_scan as well. Since we
are now using lax primitives that make it easy to select an axis, add support
for user-chosen scan axes as well.
We can also simplify the implementation of associative_scan: one of the
recursive base cases seems unnecessary, and we can simplify the code by removing
it.
Benchmarks from #4135 on my workstation:
Before:
bench_cumsum: 0.900s
bench_associative_scan: 0.597s
bench_scan: 0.359s
bench_np: 1.619s
After:
bench_cumsum: 0.435s
bench_associative_scan: 0.435s
bench_scan: 0.362s
bench_np: 1.669s
Before, with taskset -c 0:
bench_cumsum: 1.989s
bench_associative_scan: 1.556s
bench_scan: 0.428s
bench_np: 1.670s
After, with taskset -c 0:
bench_cumsum: 1.271s
bench_associative_scan: 1.275s
bench_scan: 0.438s
bench_np: 1.673s
2020-10-15 20:26:29 -04:00
|
|
|
|
@parameterized.named_parameters(
|
2022-12-01 09:12:01 -08:00
|
|
|
|
{"testcase_name": f"_{shape}_{axis=}",
|
Optimize lax.associative_scan, reimplement cumsum, etc. on top of associative_scan.
Add support for an axis= parameter to associative_scan.
We previously had two associative scan implementations, namely lax.associative_scan, and the implementations of cumsum, cumprod, etc.
lax.associative_scan was more efficient in some ways because unlike the cumsum implementation it did not pad the input array to the nearest power of two size. This appears to have been a significant cause of https://github.com/google/jax/issues/4135.
The cumsum/cummax implementation used slightly more efficient code to slice and
interleave arrays, which this change adds to associative_scan as well. Since we
are now using lax primitives that make it easy to select an axis, add support
for user-chosen scan axes as well.
We can also simplify the implementation of associative_scan: one of the
recursive base cases seems unnecessary, and we can simplify the code by removing
it.
Benchmarks from #4135 on my workstation:
Before:
bench_cumsum: 0.900s
bench_associative_scan: 0.597s
bench_scan: 0.359s
bench_np: 1.619s
After:
bench_cumsum: 0.435s
bench_associative_scan: 0.435s
bench_scan: 0.362s
bench_np: 1.669s
Before, with taskset -c 0:
bench_cumsum: 1.989s
bench_associative_scan: 1.556s
bench_scan: 0.428s
bench_np: 1.670s
After, with taskset -c 0:
bench_cumsum: 1.271s
bench_associative_scan: 1.275s
bench_scan: 0.438s
bench_np: 1.673s
2020-10-15 20:26:29 -04:00
|
|
|
|
"shape": shape, "axis": axis}
|
|
|
|
|
for shape in [
|
|
|
|
|
[0], [1], [2], [3], [5], [10], [1000],
|
|
|
|
|
[2, 3], [7, 5], [5, 6, 7]
|
|
|
|
|
]
|
|
|
|
|
for axis in range(-len(shape), len(shape) - 1))
|
|
|
|
|
def testAssociativeScanUnstructured(self, shape, axis):
|
|
|
|
|
data = np.arange(np.prod(shape)).reshape(shape) + 7
|
|
|
|
|
expected = np.cumsum(data, axis=axis)
|
|
|
|
|
result = lax.associative_scan(operator.add, data, axis=axis)
|
2020-05-21 08:35:17 -07:00
|
|
|
|
self.assertAllClose(result, expected, check_dtypes=False)
|
|
|
|
|
|
2020-09-04 19:21:43 +03:00
|
|
|
|
def testAssociativeScanUnstructured1000Reverse(self):
|
Optimize lax.associative_scan, reimplement cumsum, etc. on top of associative_scan.
Add support for an axis= parameter to associative_scan.
We previously had two associative scan implementations, namely lax.associative_scan, and the implementations of cumsum, cumprod, etc.
lax.associative_scan was more efficient in some ways because unlike the cumsum implementation it did not pad the input array to the nearest power of two size. This appears to have been a significant cause of https://github.com/google/jax/issues/4135.
The cumsum/cummax implementation used slightly more efficient code to slice and
interleave arrays, which this change adds to associative_scan as well. Since we
are now using lax primitives that make it easy to select an axis, add support
for user-chosen scan axes as well.
We can also simplify the implementation of associative_scan: one of the
recursive base cases seems unnecessary, and we can simplify the code by removing
it.
Benchmarks from #4135 on my workstation:
Before:
bench_cumsum: 0.900s
bench_associative_scan: 0.597s
bench_scan: 0.359s
bench_np: 1.619s
After:
bench_cumsum: 0.435s
bench_associative_scan: 0.435s
bench_scan: 0.362s
bench_np: 1.669s
Before, with taskset -c 0:
bench_cumsum: 1.989s
bench_associative_scan: 1.556s
bench_scan: 0.428s
bench_np: 1.670s
After, with taskset -c 0:
bench_cumsum: 1.271s
bench_associative_scan: 1.275s
bench_scan: 0.438s
bench_np: 1.673s
2020-10-15 20:26:29 -04:00
|
|
|
|
data = np.arange(1000) + 32
|
2020-09-04 19:21:43 +03:00
|
|
|
|
expected = np.cumsum(data[::-1])[::-1]
|
|
|
|
|
result = lax.associative_scan(operator.add, data, reverse=True)
|
|
|
|
|
self.assertAllClose(result, expected, check_dtypes=False)
|
|
|
|
|
|
2020-05-21 08:35:17 -07:00
|
|
|
|
def testAssociativeScanStructured3(self):
|
|
|
|
|
pair = collections.namedtuple('pair', ('first', 'second'))
|
|
|
|
|
data = pair(first=np.array([0., 1., 2.]),
|
|
|
|
|
second=np.array([0., 10., 20.]))
|
|
|
|
|
|
|
|
|
|
def fn(a, b):
|
|
|
|
|
return pair(first=a.first + b.first,
|
|
|
|
|
second=a.second + b.second)
|
|
|
|
|
|
|
|
|
|
result = lax.associative_scan(fn, elems=data)
|
|
|
|
|
self.assertAllClose(result.first, np.array([0., 1., 3.]),
|
|
|
|
|
check_dtypes=False)
|
|
|
|
|
self.assertAllClose(result.second, np.array([0., 10., 30.]),
|
|
|
|
|
check_dtypes=False)
|
|
|
|
|
|
2021-05-12 10:28:55 -04:00
|
|
|
|
def testAssociativeScanOfBools(self):
|
|
|
|
|
x = jnp.array([False, True, True, True, False, True])
|
|
|
|
|
y = lax.associative_scan(lax.bitwise_xor, x)
|
|
|
|
|
self.assertArraysEqual(np.array([False, True, False, True, True, False]), y)
|
|
|
|
|
|
2021-05-12 14:42:00 -04:00
|
|
|
|
@parameterized.named_parameters({"testcase_name": f"_{shape}", "shape": shape}
|
|
|
|
|
for shape in [2, 43, 100])
|
|
|
|
|
def testAssociativeScanSolvingRegressionTest(self, shape):
|
|
|
|
|
# This test checks that the batching rule doesn't raise for a batch
|
|
|
|
|
# sensitive function (solve).
|
|
|
|
|
ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)
|
|
|
|
|
vs = np.ones((shape, 2))
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
@jax.vmap
|
2021-05-12 14:42:00 -04:00
|
|
|
|
def fn(a, b):
|
|
|
|
|
m1, v1 = a
|
|
|
|
|
m2, v2 = b
|
|
|
|
|
return m1 + m2, jsp.linalg.solve(m1, v2) + jsp.linalg.solve(m2, v1)
|
|
|
|
|
|
|
|
|
|
_ = lax.associative_scan(fn, elems=(ms, vs))
|
|
|
|
|
|
2020-07-30 14:02:48 -07:00
|
|
|
|
def test_scan_typecheck_param(self):
|
|
|
|
|
d = jnp.ones(2)
|
|
|
|
|
def f(c, a):
|
|
|
|
|
b = jnp.cos(jnp.sum(a) + jnp.sum(c) + jnp.sum(d))
|
|
|
|
|
c = jnp.sin(c * b)
|
|
|
|
|
return c, b
|
|
|
|
|
|
|
|
|
|
xs = jnp.ones((5, 3))
|
|
|
|
|
c = jnp.ones(4)
|
|
|
|
|
scan_fun = lambda c, xs: lax.scan(f, c, xs)
|
|
|
|
|
|
|
|
|
|
def new_jaxpr():
|
2024-06-07 13:47:49 -07:00
|
|
|
|
jaxpr = jax.make_jaxpr(partial(scan_fun))(c, xs).jaxpr
|
2020-07-30 14:02:48 -07:00
|
|
|
|
scan = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'scan')
|
|
|
|
|
return jaxpr, scan
|
|
|
|
|
|
|
|
|
|
jaxpr, eqn = new_jaxpr()
|
|
|
|
|
eqn.params['reverse'] = 4
|
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
|
core.JaxprTypeError,
|
|
|
|
|
re.escape('invalid scan param reverse of type int, bool required: 4'),
|
|
|
|
|
lambda: core.check_jaxpr(jaxpr))
|
|
|
|
|
|
|
|
|
|
jaxpr, eqn = new_jaxpr()
|
|
|
|
|
eqn.params['num_consts'] = -3
|
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
|
core.JaxprTypeError,
|
|
|
|
|
re.escape('invalid scan param num_consts of type int, '
|
|
|
|
|
'non-negative int required: -3'),
|
|
|
|
|
lambda: core.check_jaxpr(jaxpr))
|
|
|
|
|
|
|
|
|
|
def test_cond_typecheck_param(self):
|
|
|
|
|
def new_jaxpr():
|
2021-09-13 16:00:22 -04:00
|
|
|
|
jaxpr = jax.make_jaxpr(
|
2020-07-30 14:02:48 -07:00
|
|
|
|
lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
|
|
|
|
|
cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
|
|
|
|
|
return jaxpr, cond
|
|
|
|
|
|
|
|
|
|
jaxpr, eqn = new_jaxpr()
|
|
|
|
|
eqn.params['branches'] = (4, 2)
|
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
|
core.JaxprTypeError,
|
|
|
|
|
re.escape('invalid cond param branches of type tuple, '
|
2020-09-18 10:07:13 -07:00
|
|
|
|
'tuple of ClosedJaxpr required: (4, 2)'),
|
2020-07-30 14:02:48 -07:00
|
|
|
|
lambda: core.check_jaxpr(jaxpr))
|
|
|
|
|
|
2022-03-01 13:10:28 -08:00
|
|
|
|
def test_cond_transformation_rule_with_consts(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# https://github.com/jax-ml/jax/pull/9731
|
2022-03-01 13:10:28 -08:00
|
|
|
|
|
|
|
|
|
@jax.custom_jvp
|
|
|
|
|
def f(x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
|
(x,), (xdot,) = primals, tangents
|
2022-06-17 14:14:23 -07:00
|
|
|
|
const = np.arange(3, dtype=x.dtype)
|
2022-03-01 13:10:28 -08:00
|
|
|
|
return x * const, xdot * const
|
|
|
|
|
|
|
|
|
|
g = lambda x: jax.lax.cond(True, f, lambda x: x, x)
|
|
|
|
|
x = np.arange(3, dtype='float32')
|
|
|
|
|
jax.jvp(g, (x,), (x,)) # doesn't crash
|
|
|
|
|
|
2025-02-26 06:04:26 -05:00
|
|
|
|
@jtu.thread_unsafe_test()
|
2023-01-18 10:17:01 -08:00
|
|
|
|
def test_cond_excessive_compilation(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# Regression test for https://github.com/jax-ml/jax/issues/14058
|
2023-01-18 10:17:01 -08:00
|
|
|
|
def f(x):
|
|
|
|
|
return x + 1
|
|
|
|
|
|
|
|
|
|
def g(x):
|
|
|
|
|
return x + 2
|
|
|
|
|
|
2024-08-09 20:03:06 -07:00
|
|
|
|
with jtu.count_jit_and_pmap_lowerings() as count:
|
2023-01-18 10:17:01 -08:00
|
|
|
|
for x in range(10):
|
|
|
|
|
lax.cond(x, f, g, x)
|
|
|
|
|
# Should observe a maximum of 4 compiles: convert_element_type, f, g, cond
|
|
|
|
|
# In #14058, this was observed to be 31 compiles.
|
2024-12-11 16:54:52 -05:00
|
|
|
|
self.assertLess(count(), 5)
|
2023-01-18 10:17:01 -08:00
|
|
|
|
|
2020-10-08 11:52:07 -07:00
|
|
|
|
@parameterized.named_parameters(
|
|
|
|
|
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
|
|
|
|
|
for dtype in jtu.dtypes.all_integer)
|
|
|
|
|
def test_scan_init_weak_type(self, dtype):
|
|
|
|
|
def func(carry, x):
|
|
|
|
|
return carry + x, x
|
|
|
|
|
init_weak = 0 # Python scalars are weakly-typed.
|
|
|
|
|
x = jnp.ones(5, dtype=dtype)
|
|
|
|
|
carry, result = lax.scan(func, init_weak, x)
|
2022-06-17 14:14:23 -07:00
|
|
|
|
self.assertEqual(carry, x.sum(dtype=carry.dtype))
|
2020-10-07 15:25:01 -07:00
|
|
|
|
self.assertArraysEqual(result, x)
|
|
|
|
|
|
2020-10-08 11:52:07 -07:00
|
|
|
|
@parameterized.named_parameters(
|
|
|
|
|
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
|
|
|
|
|
for dtype in jtu.dtypes.all_integer)
|
|
|
|
|
def test_while_loop_init_weak_type(self, dtype):
|
|
|
|
|
# This tests whether lax.while_loop can properly handle weakly-typed
|
|
|
|
|
# initial values.
|
|
|
|
|
def cond_fun(val):
|
|
|
|
|
return val < 2
|
|
|
|
|
def body_fun(val):
|
|
|
|
|
return val + increment
|
|
|
|
|
increment = jnp.array(1, dtype=dtype)
|
|
|
|
|
init_weak = 0 # Python scalars are weakly-typed.
|
|
|
|
|
result = lax.while_loop(cond_fun, body_fun, init_weak)
|
|
|
|
|
self.assertArraysEqual(result, jnp.full_like(increment, 2))
|
|
|
|
|
|
2022-04-30 21:50:18 -07:00
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
          ('', None),
          ('new_remat', new_checkpoint),
      ])
  def test_scan_vjp_forwards_extensive_residuals(self, remat):
    """Checks that scan's VJP forwards extensive residuals by identity.

    The extensive input `x` should be stored in the VJP closure as the very
    same object (no copy), and a numpy.ndarray input should have been
    converted to a jax.Array. See https://github.com/jax-ml/jax/issues/4510.
    """
    def cumprod(x):
      s = jnp.ones((2, 32), jnp.float32)
      return lax.scan(lambda s, x: (x*s, s), s, x)
    if remat is not None:
      cumprod = remat(cumprod)

    rng = self.rng()
    x = jnp.asarray(rng.randn(32, 2, 32).astype('float32'))
    _, vjp_fun = jax.vjp(cumprod, x)

    # Need to spelunk into vjp_fun. This is fragile, and if it causes problems
    # just skip this test and make an issue for mattjj.
    *_, ext_res = vjp_fun.args[0].args[0]
    self.assertIs(ext_res, x)  # forwarded by identity, not copied

    if remat is not None:
      # TODO(mattjj): make the numpy.ndarray test pass w/ remat
      raise unittest.SkipTest("new-remat-of-scan doesn't convert numpy.ndarray")

    x = rng.randn(32, 2, 32).astype('float32')  # numpy.ndarray, not Array
    _, vjp_fun = jax.vjp(cumprod, x)
    *_, ext_res = vjp_fun.args[0].args[0]
    self.assertIsInstance(ext_res, jax.Array)
def test_scan_vmap_collectives(self):
|
|
|
|
|
def scan_f(state, x):
|
|
|
|
|
s = lax.psum(state, 'i') * x
|
|
|
|
|
return state, s
|
|
|
|
|
|
|
|
|
|
def scan(state, xs):
|
|
|
|
|
return lax.scan(scan_f, state, xs)
|
|
|
|
|
|
2021-09-13 16:00:22 -04:00
|
|
|
|
scan_v = jax.vmap(scan, in_axes=0, out_axes=0, axis_name='i')
|
2020-10-26 10:11:13 +00:00
|
|
|
|
self.assertAllClose(
|
2022-06-17 14:14:23 -07:00
|
|
|
|
scan_v(jnp.ones([1]), jnp.arange(5.).reshape((1, 5))),
|
2022-12-02 13:20:30 -08:00
|
|
|
|
(jnp.array([1.]), jnp.array([[0., 1., 2., 3., 4.]])), check_dtypes=False)
|
2020-10-09 10:35:47 -07:00
|
|
|
|
|
2021-03-05 12:47:20 -08:00
|
|
|
|
def test_xla_cpu_gpu_loop_cond_bug(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# https://github.com/jax-ml/jax/issues/5900
|
2021-03-05 12:47:20 -08:00
|
|
|
|
def deriv(f):
|
|
|
|
|
return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0)
|
|
|
|
|
|
|
|
|
|
def _while_loop(cond_fun, body_fun, init_val, max_iter):
|
|
|
|
|
def _iter(val):
|
|
|
|
|
next_val = body_fun(val)
|
|
|
|
|
next_cond = True
|
|
|
|
|
return next_val, next_cond
|
|
|
|
|
|
|
|
|
|
def _fun(tup, _):
|
|
|
|
|
val, cond = tup
|
|
|
|
|
return jax.lax.cond(cond, _iter, lambda x: (x, False), val), _
|
|
|
|
|
|
|
|
|
|
init = (init_val, cond_fun(init_val))
|
|
|
|
|
return jax.lax.scan(_fun, init, None, length=max_iter)[0][0]
|
|
|
|
|
|
|
|
|
|
def my_pow(x, y):
|
|
|
|
|
def body_fun(val):
|
|
|
|
|
return val * x
|
|
|
|
|
def cond_fun(val):
|
|
|
|
|
return True
|
|
|
|
|
return _while_loop(cond_fun, body_fun, 1.0, y)
|
|
|
|
|
|
|
|
|
|
self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False)
|
|
|
|
|
|
2019-04-10 13:00:31 -07:00
|
|
|
|
|
2021-07-07 11:03:59 -07:00
|
|
|
|
def test_while_loop_fixed_point_with_batched_pred_and_consts(self):
|
|
|
|
|
def f(i, x):
|
|
|
|
|
def cond(carry):
|
|
|
|
|
i, x = carry
|
|
|
|
|
return i < 5
|
|
|
|
|
def body(carry):
|
|
|
|
|
i, z = carry
|
|
|
|
|
# Close over const with batch dim = 1
|
|
|
|
|
return i + 1, z + x
|
|
|
|
|
return lax.while_loop(cond, body, (i, jnp.ones(3)))[1]
|
|
|
|
|
jax.vmap(f, in_axes=(0, 1))(jnp.arange(4), jnp.ones((3, 4)))
|
|
|
|
|
|
2022-03-22 17:39:31 -07:00
|
|
|
|
def test_cond_ad_batched_unit(self):
|
|
|
|
|
# see issue #9985
|
|
|
|
|
def cond_id(x):
|
|
|
|
|
return lax.cond(x < 0., lambda x: x, lambda x: x, x)
|
|
|
|
|
jax.vmap(jax.jacrev(lambda x: cond_id(cond_id(x))))(jnp.ones(1))
|
|
|
|
|
|
2022-07-26 01:02:00 -07:00
|
|
|
|
@parameterized.named_parameters(
|
2023-07-21 14:20:39 -04:00
|
|
|
|
{"testcase_name": f"impl={scan_name}", "scan": scan_impl}
|
2022-08-16 16:44:50 -07:00
|
|
|
|
for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
|
2022-07-26 01:02:00 -07:00
|
|
|
|
def test_scan_hoisting_consts(self, scan):
|
|
|
|
|
A = jnp.arange(4.).reshape(2, 2)
|
|
|
|
|
B = jnp.arange(4.).reshape(2, 2) + 1.
|
|
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
|
def body(c, _):
|
|
|
|
|
c1, c2, c3 = c
|
|
|
|
|
return (jnp.dot(A, c1), jnp.dot(B, c2), jnp.dot(jnp.sin(B), c3)), None
|
|
|
|
|
init_carry = (x * jnp.ones(2), x * jnp.ones(2), x * jnp.ones(2))
|
|
|
|
|
(c1, c2, c3), _ = scan(body, init_carry, None, length=3)
|
|
|
|
|
return jnp.sum(c1) + jnp.sum(c2) + jnp.sum(c3)
|
|
|
|
|
|
|
|
|
|
jax.grad(f)(1.) # doesn't crash
|
|
|
|
|
|
2023-01-16 10:29:59 -08:00
|
|
|
|
def test_custom_jvp_tangent_cond_transpose(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# https://github.com/jax-ml/jax/issues/14026
|
2023-01-16 10:29:59 -08:00
|
|
|
|
def mask_fun(arr, choice):
|
|
|
|
|
out = (1 - choice) * arr.sum() + choice * (1 - arr.sum())
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def switch_fun(arr, choice):
|
|
|
|
|
choice = jnp.floor(choice).astype(jnp.int32)
|
|
|
|
|
out = jax.lax.switch(choice, [lambda x: x.sum(), lambda x: 1 - x.sum()], arr)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
test_arr = jnp.arange(3.)
|
|
|
|
|
test_val = 0.
|
|
|
|
|
|
|
|
|
|
expected1 = jax.grad(mask_fun)(test_arr, test_val)
|
|
|
|
|
expected2 = jax.grad(switch_fun)(test_arr, test_val)
|
|
|
|
|
|
|
|
|
|
def good_switchfun_jvp(primals, tangents):
|
|
|
|
|
arr, choice = primals
|
|
|
|
|
arr_dot, choice_dot = tangents
|
|
|
|
|
return switch_fun(arr, choice), mask_fun(arr_dot, choice)
|
|
|
|
|
|
|
|
|
|
def bad_switchfun_jvp(primals, tangents):
|
|
|
|
|
arr, choice = primals
|
|
|
|
|
arr_dot, choice_dot = tangents
|
|
|
|
|
return switch_fun(arr, choice), switch_fun(arr_dot, choice)
|
|
|
|
|
|
|
|
|
|
good_custom_switchfun = jax.custom_jvp(switch_fun)
|
|
|
|
|
good_custom_switchfun.defjvp(good_switchfun_jvp)
|
|
|
|
|
expected3 = jax.grad(good_custom_switchfun)(test_arr, test_val)
|
|
|
|
|
|
|
|
|
|
bad_custom_switchfun = jax.custom_jvp(switch_fun)
|
|
|
|
|
bad_custom_switchfun.defjvp(bad_switchfun_jvp)
|
|
|
|
|
actual = jax.grad(bad_custom_switchfun)(test_arr, test_val)
|
|
|
|
|
|
|
|
|
|
self.assertAllClose(expected1, expected2)
|
|
|
|
|
self.assertAllClose(expected2, expected3)
|
|
|
|
|
self.assertAllClose(expected3, actual)
|
|
|
|
|
|
2023-10-27 03:04:39 +03:00
|
|
|
|
def test_platform_dependent(self):
|
|
|
|
|
def f(x):
|
|
|
|
|
return lax.platform_dependent(x, cpu=jnp.sin, default=jnp.cos)
|
|
|
|
|
|
|
|
|
|
x = np.arange(3, dtype=np.float32)
|
|
|
|
|
res = f(x)
|
|
|
|
|
self.assertAllClose(
|
|
|
|
|
res,
|
|
|
|
|
np.sin(x) if jtu.device_under_test() == "cpu" else np.cos(x))
|
|
|
|
|
|
|
|
|
|
def test_platform_dependent_no_args(self):
|
|
|
|
|
def f(x):
|
|
|
|
|
return lax.platform_dependent(cpu=lambda: jnp.sin(x),
|
|
|
|
|
default=lambda: jnp.cos(x))
|
|
|
|
|
|
|
|
|
|
x = np.arange(3, dtype=np.float32)
|
|
|
|
|
res = f(x)
|
|
|
|
|
self.assertAllClose(
|
|
|
|
|
res,
|
|
|
|
|
np.sin(x) if jtu.device_under_test() == "cpu" else np.cos(x))
|
|
|
|
|
|
|
|
|
|
def test_platform_dependent_lowering(self):
|
|
|
|
|
def f(x):
|
|
|
|
|
return lax.platform_dependent(x, cpu=jnp.sin, default=jnp.cos)
|
|
|
|
|
|
|
|
|
|
x = np.arange(3, dtype=np.float32)
|
|
|
|
|
lowered = jax.jit(f).lower(x)
|
|
|
|
|
stablehlo = lowered.as_text()
|
|
|
|
|
self.assertIn("stablehlo.case", stablehlo)
|
|
|
|
|
self.assertIn("stablehlo.sine", stablehlo)
|
|
|
|
|
self.assertIn("stablehlo.cosine", stablehlo)
|
|
|
|
|
|
|
|
|
|
# The HLO has been canonicalized and contains only the branch we need
|
|
|
|
|
hlo = lowered.as_text("hlo")
|
|
|
|
|
if jtu.device_under_test() == "cpu":
|
|
|
|
|
self.assertIn(" sine", hlo)
|
|
|
|
|
self.assertNotIn(" cosine", hlo)
|
|
|
|
|
else:
|
|
|
|
|
self.assertNotIn(" sine", hlo)
|
|
|
|
|
self.assertIn(" cosine", hlo)
|
|
|
|
|
|
2024-06-20 20:16:01 +03:00
|
|
|
|
  def test_platform_dependent_with_non_existent_custom_call(self):
    """Unused platform branches may contain primitives that cannot be lowered.

    Branches for other platforms (here, a custom call whose target does not
    exist) must still appear in the platform-independent IR, but must never
    be compiled or executed on the current (CPU) platform.
    """
    if not jtu.test_device_matches(["cpu"]):
      self.skipTest("Only for CPU")

    def f(x):
      # One use with the bad custom call on a different platform branch
      x1 = lax.platform_dependent(x,
                                  cpu=jnp.sin,
                                  other=prim_non_existent_custom_call.bind)
      # and with the bad custom call in the default branch
      x2 = lax.platform_dependent(x,
                                  cpu=jnp.sin,
                                  default=prim_non_existent_custom_call.bind)
      # and one use where the current platform is the default
      x3 = lax.platform_dependent(x,
                                  other=prim_non_existent_custom_call.bind,
                                  default=jnp.sin)
      return x1 + x2 + x3

    x = np.arange(3, dtype=np.float32)
    # The platform-independent IR mentions the bad custom call once per use.
    hlo = str(jax.jit(f).lower(x).compiler_ir())
    occurrences = re.findall(prim_non_existent_custom_call.name, hlo)
    self.assertLen(occurrences, 3)

    # On CPU all three uses resolve to sin, so f(x) == 3 * sin(x) under
    # eager, jit, vmap, jvp, and grad.
    res_eager = f(x)
    self.assertAllClose(res_eager, 3. * np.sin(x))
    res_jit = jax.jit(f)(x)
    self.assertAllClose(res_jit, 3 * np.sin(x))

    res_vmap = jax.vmap(f)(x)
    self.assertAllClose(res_vmap, 3. * np.sin(x))

    _, res_jvp = jax.jvp(f, (x,), (np.full(x.shape, .1, dtype=x.dtype),))
    self.assertAllClose(res_jvp, .3 * np.cos(x))

    res_grad = jax.grad(f)(1.)
    self.assertAllClose(res_grad, 3. * np.cos(1.))
def test_platform_dependent_multiple_identical_branches(self):
|
|
|
|
|
x = np.arange(3, dtype=np.float32)
|
|
|
|
|
def f(x):
|
|
|
|
|
return lax.platform_dependent(
|
|
|
|
|
x,
|
|
|
|
|
cpu=jnp.sin,
|
|
|
|
|
tpu=jnp.sin,
|
|
|
|
|
default=lambda x: x)
|
|
|
|
|
res = f(x)
|
|
|
|
|
self.assertAllClose(
|
|
|
|
|
res,
|
|
|
|
|
np.sin(x) if jtu.device_under_test() in ["cpu", "tpu"] else x)
|
|
|
|
|
# We only lower the common branches once
|
|
|
|
|
stablehlo = jax.jit(f).lower(x).as_text()
|
|
|
|
|
sines = re.findall(r"stablehlo.sine", stablehlo)
|
|
|
|
|
self.assertEqual(1, len(sines))
|
|
|
|
|
|
|
|
|
|
def test_platform_dependent_no_default(self):
|
|
|
|
|
ctx = contextlib.ExitStack()
|
|
|
|
|
if jtu.device_under_test() != "tpu":
|
|
|
|
|
ctx.enter_context(
|
2023-11-03 08:28:59 +01:00
|
|
|
|
self.assertRaisesRegex(NotImplementedError,
|
2023-10-27 03:04:39 +03:00
|
|
|
|
"translation rule .* not found for platform"))
|
|
|
|
|
with ctx:
|
|
|
|
|
lax.platform_dependent(
|
|
|
|
|
3.,
|
|
|
|
|
tpu=lambda x: x + 2.)
|
|
|
|
|
|
|
|
|
|
def test_platform_dependent_batched(self):
|
|
|
|
|
def f(x):
|
|
|
|
|
return lax.platform_dependent(x, cpu=jnp.sin, default=jnp.cos)
|
|
|
|
|
|
|
|
|
|
xs = np.arange(3, dtype=np.float32)
|
|
|
|
|
self.assertAllClose(
|
|
|
|
|
jax.vmap(f)(xs),
|
|
|
|
|
np.sin(xs) if jtu.device_under_test() == "cpu" else np.cos(xs))
|
|
|
|
|
# We can still fold the un-needed branch
|
|
|
|
|
hlo = jax.jit(jax.vmap(f)).lower(xs).as_text('hlo')
|
|
|
|
|
expect_a_sine = (jtu.device_under_test() == "cpu")
|
|
|
|
|
self.assertEqual(expect_a_sine, " sine(" in hlo)
|
|
|
|
|
self.assertEqual(not expect_a_sine, " cosine(" in hlo)
|
|
|
|
|
|
|
|
|
|
  def test_platform_dependent_grad(self):
    """Gradients of platform_dependent use only the local platform's residuals."""
    # For a function "lax.dot(x, x)", we choose two branches with very different
    # implementations (a dot and a scan), and therefore different residuals,
    # so that we can verify whether the residuals are as we expect (we don't
    # get residuals from a different platform).
    x = np.arange(8, dtype=np.float32)
    def f_impl_dot(x):  # x: f32[8]
      return jnp.dot(x, x)
    def f_impl_scan(x):
      # Equivalent sum-of-squares, but via scan so it lowers to a while loop.
      def scan_body(carry, x_i):
        return (carry + x_i * x_i, None)
      return lax.scan(scan_body, np.float32(0.), x)[0]

    def f(x):
      return jnp.sin(lax.platform_dependent(x,
                                            cpu=f_impl_dot,
                                            default=f_impl_scan))
    # Both branches compute the same value, so the gradient must match the
    # dot-based reference regardless of which branch ran.
    self.assertAllClose(
      jax.grad(f)(x),
      jax.grad(lambda x: jnp.sin(f_impl_dot(x)))(x))

    # Check that we do not have contamination of computations across platforms
    hlo = jax.jit(jax.grad(f)).lower(x).as_text('hlo')
    expect_a_dot = (jtu.device_under_test() == "cpu")
    self.assertEqual(expect_a_dot, " dot(" in hlo)
    self.assertEqual(not expect_a_dot, " while(" in hlo)
def test_scan_lowering_doesnt_introduce_singleton(self):
|
|
|
|
|
b = 4
|
|
|
|
|
i = 2
|
|
|
|
|
|
|
|
|
|
def scan(y):
|
|
|
|
|
def body(carry, x):
|
|
|
|
|
return carry, jnp.dot(x, x)
|
|
|
|
|
return jax.lax.scan(body, 1.0, y, unroll=False)
|
|
|
|
|
|
|
|
|
|
fn = jax.jit(scan)
|
|
|
|
|
|
|
|
|
|
init = np.array(np.arange(b * i * i), dtype=np.float32).reshape((b, i, i))
|
|
|
|
|
hlo_text = fn.lower(init).as_text('hlo')
|
|
|
|
|
self.assertNotIn('4,1,2,2', hlo_text)
|
|
|
|
|
|
2024-08-24 17:29:20 -07:00
|
|
|
|
def test_scan_length_concrete_error(self):
|
|
|
|
|
f = jax.jit(lambda n, x: jax.lax.scan(lambda c, z: (c, z), x, (), n))
|
|
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
|
core.ConcretizationTypeError,
|
|
|
|
|
"The `length` argument to `scan` expects a concrete `int` value.*"):
|
|
|
|
|
f(3, 1.)
|
|
|
|
|
|
2024-08-24 22:39:28 -07:00
|
|
|
|
def test_scan_unroll_concrete_error(self):
|
|
|
|
|
f = jax.jit(lambda n, x: jax.lax.scan(
|
|
|
|
|
lambda c, z: (c, z), x, (), 10, unroll=n))
|
|
|
|
|
|
|
|
|
|
msg = ("The `unroll` argument to `scan` expects a concrete `int` or "
|
|
|
|
|
"`bool` value.*")
|
|
|
|
|
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
|
|
|
|
f(3, 1.)
|
|
|
|
|
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
|
|
|
|
f(True, 1.)
|
|
|
|
|
|
2024-06-05 16:39:55 +00:00
|
|
|
|
  def test_cond_vmap_forwarding_doesnt_promote(self):
    """vmap of cond must forward the unbatched operand unchanged.

    `y` is returned untouched by both branches, so vmap with in_axes=None /
    out_axes=None for it should hand back the very same array object rather
    than promoting it to a batched (broadcast) array.
    """
    def f(x, y):
      # Both branches leave `y` alone; only `x` is transformed.
      x, y = jax.lax.cond(
          x < 3,
          lambda x, y: (x * 2, y),
          lambda x, y: (x * 3, y),
          x, y
      )
      return x, y

    x = jnp.arange(3)
    y = jnp.array(3.)

    x2, y2 = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(x, y)  # don't crash

    # `x` was batched and transformed, so a fresh array comes back; `y` must
    # be the identical object, proving it was forwarded, not promoted.
    assert x is not x2
    assert y is y2
def test_cond_casting(self):
|
|
|
|
|
x = 1.0
|
|
|
|
|
identity = lambda x: x
|
2024-06-05 16:39:55 +00:00
|
|
|
|
|
2024-06-11 15:33:05 +01:00
|
|
|
|
y = lax.cond(True, identity, identity, x)
|
|
|
|
|
self.assertEqual(y, x)
|
|
|
|
|
self.assertIsInstance(y, jax.Array)
|
2023-01-16 10:29:59 -08:00
|
|
|
|
|
2025-01-10 06:58:01 -08:00
|
|
|
|
  @jtu.thread_unsafe_test()  # live_arrays count isn't thread-safe
  def test_cond_memory_leak(self):
    # Regression test: a jitted cond whose branches close over a device array
    # must not keep that array alive after all references are dropped.
    # https://github.com/jax-ml/jax/issues/12719

    def leak():
      data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1)
      def g():
        return jax.lax.cond(
            True,
            lambda: data[0],  # noqa: F821
            lambda: data[1],  # noqa: F821
        )
      jg = jax.jit(g)
      _ = jg().block_until_ready()
      # Drop every reference so the buffer behind `data` can be freed.
      del g, jg, data, _

    # The live-array count must return to the baseline after every call.
    nbufs = lambda: len(jax.live_arrays())
    base = nbufs()
    leak()
    self.assertEqual(base, nbufs())
    leak()
    self.assertEqual(base, nbufs())
    leak()
    self.assertEqual(base, nbufs())
def test_grad_remat_while_fixpoint(self):
|
|
|
|
|
@jax.remat
|
|
|
|
|
def f(x, y):
|
|
|
|
|
def cond(_):
|
|
|
|
|
return False
|
|
|
|
|
def body(c):
|
|
|
|
|
x, y = c
|
|
|
|
|
return (y, x)
|
|
|
|
|
x, y = jax.lax.while_loop(cond, body, (x, y))
|
|
|
|
|
return x + y
|
|
|
|
|
jax.linearize(f, 1., 2.) # don't crash
|
|
|
|
|
|
2024-07-23 23:17:06 +00:00
|
|
|
|
|
2019-04-10 13:00:31 -07:00
|
|
|
|
if __name__ == '__main__':
  # Entry point: run the tests with jtu's custom test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())