# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools as it
from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

import jax
from jax import core
from jax import lax
from jax._src import linear_util as lu
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax._src import test_util as jtu
from jax._src.util import tuple_insert
import jax.numpy as jnp
from jax._src.lax.control_flow import for_loop

try:
  import hypothesis as hp
  import hypothesis.extra.numpy as hnp
  import hypothesis.strategies as hps
  CAN_USE_HYPOTHESIS = True
except (ModuleNotFoundError, ImportError):
  CAN_USE_HYPOTHESIS = False

from jax._src import state

config.parse_flags_with_absl()
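
# The tests below exercise the `jax._src.state` machinery: abstract evaluation
# and pretty-printing of the get/swap/addupdate primitives on `ShapedArrayRef`
# avals, their JVP and vmap rules, and discharging of stateful jaxprs into pure
# jaxprs that take the initial reference values as inputs and return the final
# values as extra outputs.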


class StatePrimitivesTest(jtu.JaxTestCase):

  def test_cant_eval_get_primitive(self):
    with self.assertRaises(ValueError):
      state.get_p.bind(jnp.ones(5))

  def test_cant_eval_swap_primitive(self):
    with self.assertRaises(ValueError):
      state.swap_p.bind(jnp.ones(5), jnp.zeros(5))

  def test_cant_eval_addupdate_primitive(self):
    with self.assertRaises(ValueError):
      state.addupdate_p.bind(jnp.ones(5), jnp.zeros(5))

  def test_get_abstract_aval_must_take_in_refs(self):
    ref_aval = core.ShapedArray((), jnp.float32)
    def f(x_ref):
      return [state.ref_get(x_ref, ())]
    with self.assertRaises(ValueError):
      pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])

  @parameterized.named_parameters(
      dict(testcase_name="trivial_get", ref_shape=(1, 2),
           ref_dtype=jnp.float32,
           idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
      dict(testcase_name="get_with_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32,
           idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
      dict(testcase_name="get_with_nonleading_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32,
           idx=(slice(None), 0), out_shape=(1,), out_dtype=jnp.float32),
      dict(testcase_name="get_with_array_index", ref_shape=(1, 2, 3, 4),
           ref_dtype=jnp.float32,
           idx=(np.array([0, 1]),), out_shape=(2, 2, 3, 4),
           out_dtype=jnp.float32),
      dict(testcase_name="get_with_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(np.array([0, 1]), np.array([0, 1])),
           out_shape=(2, 2, 4), out_dtype=jnp.float32),
      dict(testcase_name="get_with_nonleading_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
           out_shape=(2, 1, 2), out_dtype=jnp.float32),
  )
  def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
                             out_dtype=None, should_error=False):
    ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
    def f(x_ref):
      out = state.ref_get(x_ref, idx)
      return [out]
    if should_error:
      with self.assertRaises(Exception):
        pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
    else:
      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval])
      self.assertSetEqual(jaxpr.effects, {state.ReadEffect(ref_aval)})
      self.assertLen(out_avals, 1)
      out_aval, = out_avals
      self.assertIsInstance(out_aval, core.ShapedArray)
      self.assertEqual(out_aval.shape, out_shape)
      self.assertEqual(out_aval.dtype, out_dtype)

  def test_swap_abstract_eval_must_take_in_refs(self):
    ref_aval = core.ShapedArray((), jnp.float32)
    val_aval = core.ShapedArray((), jnp.float32)
    def f(x_ref, val):
      return [state.ref_swap(x_ref, (), val)]
    with self.assertRaises(ValueError):
      pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])

  @parameterized.named_parameters(
      dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="invalid_val_shape_slice", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(slice(None),), should_error=True),
      dict(testcase_name="trivial_swap", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
      dict(testcase_name="bad_dtype", ref_shape=(1, 2),
           ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="swap_with_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_nonleading_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
           idx=(slice(None), 0), out_shape=(1,), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_nonleading_index_bad_val", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(slice(None), 0), should_error=True),
      dict(testcase_name="swap_with_array_index", ref_shape=(1, 2, 3, 4),
           ref_dtype=jnp.float32, val_shape=(2, 2, 3, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]),), out_shape=(2, 2, 3, 4),
           out_dtype=jnp.float32),
      dict(testcase_name="swap_with_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 2, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]), np.array([0, 1])),
           out_shape=(2, 2, 4), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_nonleading_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 1, 2), val_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
           out_shape=(2, 1, 2), out_dtype=jnp.float32),
  )
  def test_swap_abstract_eval(self, ref_shape, ref_dtype,
                              val_shape, val_dtype, idx, out_shape=None,
                              out_dtype=None, should_error=False):
    ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
    val_aval = core.ShapedArray(val_shape, val_dtype)
    def f(x_ref, val):
      out = state.ref_swap(x_ref, idx, val)
      return [out]
    if should_error:
      with self.assertRaises(Exception):
        pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
    else:
      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval])
      self.assertSetEqual(jaxpr.effects, {state.WriteEffect(ref_aval)})
      self.assertLen(out_avals, 1)
      out_aval, = out_avals
      self.assertIsInstance(out_aval, core.ShapedArray)
      self.assertEqual(out_aval.shape, out_shape)
      self.assertEqual(out_aval.dtype, out_dtype)

  @parameterized.named_parameters(
      dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="invalid_val_shape_slice", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(slice(None),), should_error=True),
      dict(testcase_name="trivial_addupdate", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
      dict(testcase_name="bad_dtype", ref_shape=(1, 2),
           ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="addupdate_with_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
      dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
           idx=(slice(None), 0)),
      dict(testcase_name="addupdate_with_nonleading_index_bad_val", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(slice(None), 0), should_error=True),
      dict(testcase_name="addupdate_with_array_index", ref_shape=(1, 2, 3, 4),
           ref_dtype=jnp.float32, val_shape=(2, 2, 3, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]),)),
      dict(testcase_name="addupdate_with_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 2, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]), np.array([0, 1]))),
      dict(testcase_name="addupdate_with_nonleading_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 1, 2), val_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))),
  )
  def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
                                   val_shape, val_dtype, idx, out_shape=None,
                                   out_dtype=None, should_error=False):
    ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
    val_aval = core.ShapedArray(val_shape, val_dtype)
    def f(x_ref, val):
      state.ref_addupdate(x_ref, idx, val)
      return []
    if should_error:
      with self.assertRaises(Exception):
        pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
    else:
      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval])
      self.assertSetEqual(jaxpr.effects, {state.AccumEffect(ref_aval)})
      self.assertLen(out_avals, 0)

  def test_addupdate_abstract_eval_must_take_in_refs(self):
    ref_aval = core.ShapedArray((), jnp.float32)
    val_aval = core.ShapedArray((), jnp.float32)
    def f(x_ref, val):
      return [state.ref_addupdate(x_ref, (), val)]
    with self.assertRaises(ValueError):
      pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])

  def test_can_represent_get_and_swap_in_jaxprs(self):
    def body(x):
      x[()] = jnp.int32(1)
      x[()] = jnp.int32(2)
      return (x[()],)
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertLen(consts, 0)
    self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
    self.assertEqual(jaxpr.eqns[0].primitive, state.swap_p)
    self.assertEqual(jaxpr.eqns[1].primitive, state.swap_p)
    self.assertEqual(jaxpr.eqns[2].primitive, state.get_p)

  def test_can_represent_addupdate_in_jaxprs(self):
    def body(x):
      state.ref_addupdate(x, (), jnp.int32(1))
      return (x[()],)
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertLen(consts, 0)
    self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
    self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
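
  # Note: as the two tests above check, indexing sugar on refs maps onto the
  # state primitives: `x[idx]` traces to `get_p` and `x[idx] = v` traces to
  # `swap_p` (with the old value discarded), while accumulation goes through
  # `state.ref_addupdate` / `addupdate_p`.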

  def test_get_custom_pretty_printing_rule(self):
    def body(x_ref):
      x = x_ref[()]
      return [x]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))

    def body(x_ref):
      x = x_ref[:, 0]
      return [x]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32)])
    self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))

  def test_set_custom_pretty_printing_rule(self):
    def body(x_ref):
      x_ref[()] = jnp.int32(2)
      return []
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))

    def body(x_ref, val):
      x_ref[:, 0] = val
      return []
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
                             core.ShapedArray((1,), jnp.int32)])
    self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))

  def test_swap_custom_pretty_printing_rule(self):
    def body(x_ref):
      x = state.ref_swap(x_ref, (), jnp.int32(2))
      return [x]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))

    def body(x_ref, val):
      x = state.ref_swap(x_ref, (slice(None), 0), val)
      return [x]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
                             core.ShapedArray((1,), jnp.int32)])
    self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
                  jaxpr.pretty_print(use_color=False))

  def test_addupdate_custom_pretty_printing_rule(self):
    def body(x_ref):
      state.ref_addupdate(x_ref, (), jnp.int32(2))
      return []
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))

    def body(x_ref, val):
      state.ref_addupdate(x_ref, (slice(None), 0), val)
      return []
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
                             core.ShapedArray((1,), jnp.int32)])
    self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))

  def test_get_jvp(self):
    def f(r):
      x = r[()]
      return jnp.cos(x)

    def g(r, rdot):
      return jax.jvp(f, (r,), (rdot,))

    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                state.ShapedArrayRef((), jnp.dtype('float32'))]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
    self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
    self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)

  def test_swap_jvp(self):
    def f(a):
      x = a[()]
      a[()] = jnp.sin(x)
      return a[()]

    def g(r, rdot):
      return jax.jvp(f, (r,), (rdot,))

    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                state.ShapedArrayRef((), jnp.dtype('float32'))]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
    self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
    self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
    self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
    self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p)
    self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p)
    self.assertEqual(jaxpr.eqns[5].primitive, state.swap_p)
    self.assertEqual(jaxpr.eqns[6].primitive, state.swap_p)

  def test_addupdate_jvp(self):
    def f(a):
      state.ref_addupdate(a, (), jnp.float32(1.))
      return a[()]

    def g(r, rdot):
      return jax.jvp(f, (r,), (rdot,))

    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                state.ShapedArrayRef((), jnp.dtype('float32'))]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
    self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
    self.assertEqual(jaxpr.eqns[1].primitive, state.addupdate_p)
    self.assertEqual(jaxpr.eqns[2].primitive, state.get_p)
    self.assertEqual(jaxpr.eqns[3].primitive, state.get_p)
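
  # In the JVP tests above, each primal ref is paired with a tangent ref
  # (`r`, `rdot`), so every stateful operation shows up twice in the traced
  # jaxpr: once on the primal ref and once on the tangent ref. The equation
  # ordering assertions check exactly that.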

  @jtu.sample_product(
      [dict(ref_shape=ref_shape, ref_bdim=ref_bdim, idx_shape=idx_shape,
            indexed_dims=indexed_dims, idx_bdims=idx_bdims, out_bdim=out_bdim)
       for ref_shape in [(1,), (2, 3), (4, 5, 6)]
       for ref_bdim in range(1 + len(ref_shape))
       for idx_shape in [(), (1,), (2,), (5, 6)]
       for indexed_dims in it.product([True, False], repeat=len(ref_shape))
       for idx_bdims in it.product([None, *range(1 + len(idx_shape))],
                                   repeat=sum(indexed_dims))
       for out_bdim in range(1 + len(ref_shape) - sum(indexed_dims)
                             + len(idx_shape) * any(indexed_dims))
      ],
      op=[
          lambda x_ref, indexer: [x_ref[indexer]],
          lambda x_ref, indexer: [
              state.ref_swap(x_ref, indexer,
                             jnp.ones(x_ref.shape, x_ref.dtype)[None][(0,
                                                                       *indexer)])],
          lambda x_ref, indexer: (
              state.ref_addupdate(x_ref, indexer,
                                  jnp.ones(x_ref.shape, x_ref.dtype)[None][(0,
                                                                            *indexer)])
              or [jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]])
      ],
  )
  def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims,
                idx_bdims, out_bdim, op):
    float_ = (jnp.dtype('float64') if jax.config.jax_enable_x64 else
              jnp.dtype('float32'))
    int_ = (jnp.dtype('int64') if jax.config.jax_enable_x64 else
            jnp.dtype('int32'))
    axis_size = 7
    out_shape = tuple([d for d, b in zip(ref_shape, indexed_dims) if not b])
    if any(indexed_dims):
      out_shape = (*idx_shape, *out_shape)

    def maybe_insert(shape, idx):
      if idx is None:
        return shape
      return tuple_insert(shape, idx, axis_size)

    batched_ref_shape = maybe_insert(ref_shape, ref_bdim)
    ref_aval = state.ShapedArrayRef(ref_shape, float_)
    bat_ref_aval = state.ShapedArrayRef(batched_ref_shape, float_)
    idx_avals = [core.ShapedArray(idx_shape, int_)
                 for _ in idx_bdims]
    bat_idx_avals = [
        core.ShapedArray(maybe_insert(idx_shape, idx_bdim), int_)
        for idx_bdim in idx_bdims]

    def f(x_ref, *idxs):
      idxs_ = iter(idxs)
      indexer = tuple([next(idxs_) if b else slice(None) for b in indexed_dims])
      return op(x_ref, indexer)

    rng = self.rng()
    a = rng.randn(*bat_ref_aval.shape)
    his = [d for d, b in zip(ref_aval.shape, indexed_dims) if b]
    idxs = [rng.randint(low=0, high=hi, size=i.shape)
            for i, hi in zip(bat_idx_avals, his)]

    # discharge-of-vmap
    f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
    stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals])
    jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
    discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)

    # vmap-of-discharge
    stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f), [ref_aval, *idx_avals])
    jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
    f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                         in_axes=(ref_bdim, *idx_bdims),
                         out_axes=[out_bdim, ref_bdim])
    vmap_of_discharge_ans = f_batched(a, *idxs)

    self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                        check_dtypes=False)
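

# The `test_vmap` case above and the hypothesis-based tests further below check
# the same commuting property: discharging the state out of a vmapped stateful
# function gives the same results as vmapping the already-discharged (pure)
# jaxpr, with the ref argument threaded through as an explicit input/output.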


class StateDischargeTest(jtu.JaxTestCase):

  def test_discharge_get(self):
    def f(a_ref):
      a = state.ref_get(a_ref, ())
      return [a + 1]
    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should just turn this into a jaxpr that just adds 1.
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 2)
    self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p)
    # Should be able to evaluate this jaxpr
    self.assertListEqual(core.eval_jaxpr(discharged_jaxpr, (),
                                         jnp.float32(1.)), [2., 1.])
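
  # To make the preceding test concrete: the stateful function reads a scalar
  # ref and returns it plus one. After discharge_state, the jaxpr is the pure
  # function `x -> (x + 1, x)`: it takes the initial ref value as a plain array
  # input and returns the original outputs followed by the final ref value,
  # which is why evaluating it at 1.0 yields [2.0, 1.0].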

  def test_discharge_get_with_slice(self):
    def f(a_ref):
      a = state.ref_get(a_ref, (0, 1))
      return [a + 1]
    in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should turn this into a jaxpr that slices out a_ref[0, 1]
    # and adds 1.
    discharged_jaxpr, () = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 2)
    self.assertIn(lax.dynamic_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    # Should be able to evaluate this jaxpr
    inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
    outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
    self.assertTrue((outval == inval[0, 1] + 1).all())
    self.assertTrue((refval == inval).all())

  def test_discharge_get_with_gather(self):
    def f(a_ref):
      a = a_ref[jnp.array([0, 1])]
      return [a + 1]
    in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f), in_avals)
    discharged_jaxpr, discharged_consts = state.discharge_state(
        stateful_jaxpr, consts)
    inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
    outval, refval = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
    self.assertTrue((outval == inval[jnp.array([0, 1])] + 1).all())
    self.assertTrue((refval == inval).all())

  def test_discharge_set(self):
    def f(a_ref, b):
      state.ref_set(a_ref, (), b + 1)
      return []
    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                core.ShapedArray((), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should just turn this into a jaxpr that ignores the first
    # value and returns second value plus 1.
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 2)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
                                     jnp.float32(1.))[0], 2.)
    self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
                                     jnp.float32(1.))[0], 2.)

  def test_discharge_set_with_slice(self):
    def f(a_ref):
      state.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
      return []
    in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should turn this into a jaxpr that writes ones into
    # a_ref[0, 1].
    discharged_jaxpr, () = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertIn(lax.dynamic_update_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    self.assertIn(lax.dynamic_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    # Should be able to evaluate this jaxpr
    inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
    refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
    self.assertTrue((refval == inval.at[0, 1].set(1.)).all())

  def test_discharge_set_with_gather(self):
    def f(a_ref):
      a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
      return []
    in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, discharged_consts = state.discharge_state(
        stateful_jaxpr, consts)
    inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
    refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
    self.assertTrue((refval == inval.at[jnp.array([0, 1])].set(1.)).all())

  def test_discharge_addupdate(self):
    def f(a_ref, b):
      state.ref_addupdate(a_ref, (), b + 1)
      return []
    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                core.ShapedArray((), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should just turn this into a jaxpr that adds the first value,
    # second value, and 1.
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 2)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
                                     jnp.float32(1.))[0], 2.)
    self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
                                     jnp.float32(1.))[0], 4.)

  def test_discharge_addupdate_with_slice(self):
    def f(a_ref):
      state.ref_addupdate(a_ref, (0, 1),
                          jnp.ones(2, dtype=jnp.dtype('float32')))
      return []
    in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertIn(lax.dynamic_update_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    self.assertIn(lax.add_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    self.assertIn(lax.dynamic_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
    refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
    self.assertTrue((refval == inval.at[0, 1].add(1.)).all())

  def test_discharge_addupdate_with_gather(self):
    def f(a_ref):
      state.ref_addupdate(a_ref, (jnp.array([0, 1]),),
                          jnp.ones((2, 3), 'float32'))
      return []
    in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, discharged_consts = state.discharge_state(
        stateful_jaxpr, consts)
    inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
    refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
    self.assertTrue((refval == inval.at[jnp.array([0, 1])].add(1.)).all())

  def test_discharge_jaxpr_with_multiple_outputs(self):
    def f(a_ref):
      a = state.ref_get(a_ref, ())
      b = a + 1
      return [a, b]
    in_avals = [state.ShapedArrayRef((4,), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 3)
    inval = jnp.arange(4., dtype=jnp.float32)
    a, b, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
    self.assertTrue((a == inval).all())
    self.assertTrue((b == inval + 1).all())
    self.assertTrue((refval == inval).all())

  def test_partially_discharging_jaxpr_keeps_refs(self):
    def f(a_ref, b_ref):
      state.ref_set(a_ref, (), jnp.ones(4, jnp.float32))
      state.ref_set(b_ref, (), jnp.ones(4, jnp.float32))
      return []
    in_avals = [
        state.ShapedArrayRef((4,), jnp.dtype('float32')),
        state.ShapedArrayRef((4,), jnp.dtype('float32'))
    ]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, _ = state.discharge_state(
        stateful_jaxpr, consts, should_discharge=[False, True])
    self.assertLen(discharged_jaxpr.invars, 2)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.ShapedArrayRef)
    self.assertIsInstance(discharged_jaxpr.invars[1].aval, core.ShapedArray)
    self.assertEqual(discharged_jaxpr.effects,
                     {state.WriteEffect(discharged_jaxpr.invars[0].aval)})
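
  # As checked above, partial discharge (`should_discharge=[False, True]`)
  # leaves the non-discharged ref as a `ShapedArrayRef` input and keeps its
  # write effect, while the discharged ref becomes an ordinary array input with
  # its final value appended to the outputs.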


if CAN_USE_HYPOTHESIS:
  def index_arrays(size, idx_shape):
    valid_idx = hps.integers(min_value=-size, max_value=size - 1)
    return hnp.arrays(np.int32, idx_shape, elements=valid_idx)

  Shape = tuple[int, ...]

  class IndexParam(NamedTuple):
    ref_aval: state.ShapedArrayRef
    ref_shape: Shape
    indexed_dims: list[bool]
    idx_avals: tuple[core.ShapedArray, ...]
    idx_shape: Shape
    slice_aval: core.ShapedArray
    slice_shape: Shape

  @hps.composite
  def index_params(draw):
    ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape')
    indexed_dims = draw(hps.lists(hps.booleans(),
                                  min_size=len(ref_shape),
                                  max_size=len(ref_shape)))
    idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5))
    if any(indexed_dims):
      sliced_shape = (s for s, b in zip(ref_shape, indexed_dims) if not b)
      slice_shape = (*idx_shape, *sliced_shape)
    else:
      slice_shape = ref_shape
    ref_aval = state.ShapedArrayRef(ref_shape, np.float32)
    idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in
                      range(sum(indexed_dims)))
    slice_aval = core.ShapedArray(slice_shape, np.float32)
    return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape,
                      slice_aval, slice_shape)
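
  # `index_params` draws a ref shape, a mask of which dimensions are indexed
  # with integer arrays (all index arrays share `idx_shape`), and computes the
  # shape of the resulting slice: the index-array shape followed by the
  # remaining unindexed dimensions, or the full ref shape if nothing is
  # indexed.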

  class VmappableIndexParam(NamedTuple):
    index_param: IndexParam
    ref_bdim: Optional[int]
    non_slice_idx_bdims: tuple[Optional[int], ...]
    slice_bdim: int
    bat_ref_aval: state.ShapedArrayRef
    bat_ref_shape: Shape
    bat_non_slice_idx_avals: tuple[core.ShapedArray, ...]
    bat_non_slice_idx_shapes: tuple[Shape, ...]
    bat_slice_aval: core.ShapedArray
    bat_slice_shape: Shape

  def maybe_tuple_insert(t: tuple[Any, ...], idx: Optional[int],
                         val: Any) -> tuple[Any, ...]:
    if idx is None:
      return t
    return tuple_insert(t, idx, val)

  @hps.composite
  def vmappable_index_params(draw, *, op_type: str):
    axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size')
    index_param: IndexParam = draw(index_params())
    non_slice_idx_bdims = tuple(
        draw(hps.one_of(
            hps.none(),
            hps.integers(min_value=0, max_value=len(index_param.idx_shape))))
        for b in index_param.indexed_dims if b)
    bat_non_slice_idx_shapes = tuple(
        maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size)
        for idx_bdim in non_slice_idx_bdims)
    if op_type == "swap":
      # In a swap, the ref *must* be batched.
      ref_bdim = draw(hps.integers(min_value=0,
                                   max_value=len(index_param.ref_shape)))
      if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims):
        # In a swap, if the indices are batched, the val must be batched too.
        slice_bdim = draw(hps.integers(
            min_value=0, max_value=len(index_param.slice_shape)))
      else:
        slice_bdim = draw(hps.one_of(hps.none(), hps.integers(
            min_value=0, max_value=len(index_param.slice_shape))))
    elif op_type == "get":
      # In a get, either the indices or the ref must be batched.
      if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims):
        ref_bdim = draw(hps.integers(min_value=0,
                                     max_value=len(index_param.ref_shape)))
      else:
        ref_bdim = draw(hps.one_of(hps.none(),
            hps.integers(min_value=0, max_value=len(index_param.ref_shape))))
      slice_bdim = draw(hps.integers(
          min_value=0, max_value=len(index_param.slice_shape)))
    bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size)
    bat_ref_aval = state.ShapedArrayRef(bat_ref_shape, np.float32)
    bat_non_slice_idx_avals = tuple(
        core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes)
    bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size)
    bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32)
    return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims,
                               slice_bdim, bat_ref_aval, bat_ref_shape,
                               bat_non_slice_idx_avals, bat_non_slice_idx_shapes,
                               bat_slice_aval, bat_slice_shape)
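
  # `vmappable_index_params` additionally draws batching dimensions subject to
  # the constraints the batching rules require: for "swap" the ref must carry a
  # batch dimension (and the value must be batched whenever any index is), and
  # for "get" at least one of the ref or the index arrays must be batched.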

  class GetVmapParams(NamedTuple):
    vmap_index_param: VmappableIndexParam
    bat_ref: np.ndarray
    bat_idxs: tuple[np.ndarray, ...]

  @hps.composite
  def get_vmap_params(draw):
    vmap_index_param: VmappableIndexParam = draw(
        vmappable_index_params(op_type="get"))
    bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape))
    bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes)
    bat_idxs = tuple(
        draw(index_arrays(size, next(bat_idx_shapes_)))
        for size, indexed in zip(
            vmap_index_param.index_param.ref_shape,
            vmap_index_param.index_param.indexed_dims)
        if indexed)
    assert next(bat_idx_shapes_, None) is None
    return GetVmapParams(vmap_index_param, bat_ref, bat_idxs)

  class SetVmapParams(NamedTuple):
    vmap_index_param: VmappableIndexParam
    bat_ref: np.ndarray
    bat_val: np.ndarray
    bat_idxs: tuple[np.ndarray, ...]

  @hps.composite
  def set_vmap_params(draw):
    vmap_index_param: VmappableIndexParam = draw(vmappable_index_params(
        op_type="swap"))
    bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape))
    bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes)
    bat_idxs = tuple(
        draw(index_arrays(size, next(bat_idx_shapes_)))
        for size, indexed in zip(
            vmap_index_param.index_param.ref_shape,
            vmap_index_param.index_param.indexed_dims)
        if indexed)
    assert next(bat_idx_shapes_, None) is None
    bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape))
    return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs)

  Indexer = Tuple[Union[int, slice, np.ndarray], ...]

  def _unpack_idx(idx: Indexer
                  ) -> Tuple[Sequence[Union[int, np.ndarray]], Sequence[bool]]:
    indexed_dims = [type(i) != slice for i in idx]
    non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b]
    return non_slice_idx, indexed_dims

  def _pack_idx(non_slice_idx: Sequence[Union[int, np.ndarray]],
                indexed_dims: Sequence[bool]) -> Indexer:
    idx_ = iter(non_slice_idx)
    idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
    assert next(idx_, None) is None
    return idx
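
  # For example (illustrative only): with idx = (0, slice(None), np.array([1])),
  # _unpack_idx returns ([0, np.array([1])], [True, False, True]), and
  # _pack_idx inverts that, putting `slice(None)` back in the unindexed slots.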

  class StateHypothesisTest(jtu.JaxTestCase):

    @hp.given(get_vmap_params())
    @hp.settings(deadline=None, print_blob=True, max_examples=50)
    def test_get_vmap(self, get_vmap_param: GetVmapParams):
      indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims

      def f(ref, *non_slice_idx):
        idx = _pack_idx(non_slice_idx, indexed_dims)
        return [state.ref_get(ref, idx)]

      ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval
      bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval
      bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals
      ref_bdim = get_vmap_param.vmap_index_param.ref_bdim
      idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims
      out_bdim = get_vmap_param.vmap_index_param.slice_bdim
      non_slice_idx = get_vmap_param.bat_idxs
      idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals
      ref = get_vmap_param.bat_ref

      f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals])
      jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
      discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx)

      # vmap-of-discharge
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, *idx_avals])
      jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
      f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                           in_axes=(ref_bdim, *idx_bdims),
                           out_axes=[out_bdim, ref_bdim])
      vmap_of_discharge_ans = f_batched(ref, *non_slice_idx)

      self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                          check_dtypes=False)

    @hp.given(set_vmap_params())
    @hp.settings(deadline=None, print_blob=True, max_examples=50)
    def test_set_vmap(self, set_vmap_param: SetVmapParams):
      indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims

      def f(ref, val, *non_slice_idx):
        idx = _pack_idx(non_slice_idx, indexed_dims)
        state.ref_set(ref, idx, val)
        return []

      ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval
      bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval
      bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals
      ref_bdim = set_vmap_param.vmap_index_param.ref_bdim
      idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims
      non_slice_idx = set_vmap_param.bat_idxs
      idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals
      ref = set_vmap_param.bat_ref
      val = set_vmap_param.bat_val
      bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval
      val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval
      val_bdim = set_vmap_param.vmap_index_param.slice_bdim

      f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
                           out_axes=[])
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
      jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
      discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)

      # vmap-of-discharge
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
      jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
      f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                           in_axes=(ref_bdim, val_bdim, *idx_bdims),
                           out_axes=[ref_bdim])
      vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx)

      self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                          check_dtypes=False)

    @hp.given(set_vmap_params())
    @hp.settings(deadline=None, print_blob=True, max_examples=50)
    def test_addupdate_vmap(self, set_vmap_param: SetVmapParams):
      indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims

      def f(ref, val, *non_slice_idx):
        idx = _pack_idx(non_slice_idx, indexed_dims)
        state.ref_addupdate(ref, idx, val)
        return []

      ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval
      bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval
      bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals
      ref_bdim = set_vmap_param.vmap_index_param.ref_bdim
      idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims
      non_slice_idx = set_vmap_param.bat_idxs
      idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals
      ref = set_vmap_param.bat_ref
      val = set_vmap_param.bat_val
      bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval
      val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval
      val_bdim = set_vmap_param.vmap_index_param.slice_bdim

      f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
                           out_axes=[])
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
      jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
      discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)

      # vmap-of-discharge
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
      jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
      f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                           in_axes=(ref_bdim, val_bdim, *idx_bdims),
                           out_axes=[ref_bdim])
      vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx)

      self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                          check_dtypes=False)


class StateControlFlowTest(jtu.JaxTestCase):

  def test_simple_cond(self):
    def f(pred):
      def body(x_ref):
        def true_fun():
          x_ref[()] = 1.
        def false_fun():
          pass
        lax.cond(pred, true_fun, false_fun)
      return for_loop.run_state(body, 0.)
    jaxpr = jax.make_jaxpr(f)(True).jaxpr
    self.assertEmpty(jaxpr.effects)
    self.assertAllClose(jax.jit(f)(True), 1.)
    self.assertAllClose(jax.jit(f)(False), 0.)

  def test_nested_cond(self):
    def f(pred):
      def body(x_ref):
        def true_fun():
          def true_fun_inner():
            x_ref[()] = 1.
          def false_fun_inner():
            pass
          return lax.cond(pred, true_fun_inner, false_fun_inner)
        def false_fun():
          pass
        lax.cond(pred, true_fun, false_fun)
      return for_loop.run_state(body, 0.)
    jaxpr = jax.make_jaxpr(f)(True).jaxpr
    self.assertEmpty(jaxpr.effects)
    self.assertAllClose(jax.jit(f)(True), 1.)
    self.assertAllClose(jax.jit(f)(False), 0.)

  def test_cond_jvp_with_state(self):
    def f(pred, init_value):
      def body(x_ref):
        def true_fun():
          x_ref[()] = x_ref[()] ** 2
        def false_fun():
          pass
        lax.cond(pred, true_fun, false_fun)
      return for_loop.run_state(body, init_value)

    out_primal, out_tangent = jax.jvp(partial(f, True), (3.,), (1.,))
    self.assertAllClose(out_primal, 9.)
    self.assertAllClose(out_tangent, 6.)

    out_primal, out_tangent = jax.jvp(partial(f, False), (3.,), (1.,))
    self.assertAllClose(out_primal, 3.)
    self.assertAllClose(out_tangent, 1.)

  def test_cond_vmap_not_implemented(self):
    @jax.jit
    def f(init_value):
      def body(x_ref):
        def true_fun():
          x_ref[()] = x_ref[()] ** 2
        def false_fun():
          pass
        lax.cond(x_ref[()] < 1, true_fun, false_fun)
      return for_loop.run_state(body, init_value)
    with self.assertRaises(NotImplementedError):
      jax.vmap(f)(jnp.arange(2.))

  def test_cond_grad_not_implemented(self):
    @jax.jit
    def f(init_value):
      def body(x_ref):
        def true_fun():
          x_ref[()] = x_ref[()] ** 2
        def false_fun():
          pass
        lax.cond(True, true_fun, false_fun)
      return for_loop.run_state(body, init_value)
    with self.assertRaises(NotImplementedError):
      jax.grad(f)(3.)
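

# The control-flow tests above rely on `for_loop.run_state` to introduce the
# ref and discharge it again, so `lax.cond` branches may freely read and write
# the enclosing ref; batching (vmap) and differentiation (grad) of such
# state-effectful conds are not implemented at this point, which the last two
# tests assert via NotImplementedError.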


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())