rocm_jax/tests/state_test.py
vfdev-5 5a340a9781 Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true
Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
  - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
2025-04-08 21:02:55 +00:00

1851 lines
72 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
import itertools as it
from typing import Any, NamedTuple, Union
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import api_util
from jax import random
from jax import lax
from jax._src import core
from jax._src import config
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src import test_util as jtu
from jax._src.state import types as state_types
from jax._src.util import tuple_insert
import jax.numpy as jnp
from jax._src.lax.control_flow import for_loop
try:
import hypothesis as hp
import hypothesis.extra.numpy as hnp
import hypothesis.strategies as hps
CAN_USE_HYPOTHESIS = True
except (ModuleNotFoundError, ImportError):
CAN_USE_HYPOTHESIS = False
from jax._src.state.discharge import (run_state, run_state_reference,
discharge_state)
from jax._src.state.primitives import (get_p, swap_p, addupdate_p,
ref_addupdate, ref_get, ref_set,
ref_swap)
from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
AccumEffect, AbstractRef)
# Module-level test setup, executed at import time. Going by their names, the
# first call lets absl parse JAX's config flags and the second applies the
# shared hypothesis configuration from test_util -- confirm against
# jax._src.config / jax._src.test_util if the details matter.
config.parse_flags_with_absl()
jtu.setup_hypothesis()
def wrap_init(f: Callable, nr_args: int):
# wrapper for lu.wrap_init with debugging info
return lu.wrap_init(
f,
debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))
class StatePrimitivesTest(jtu.JaxTestCase):
def test_get_abstract_aval_must_take_in_refs(self):
  # A plain array aval (not a ref) is passed where `ref_get` needs a ref;
  # tracing is expected to raise ValueError.
  non_ref_aval = core.ShapedArray((), jnp.float32)

  def f(x_ref):
    value = ref_get(x_ref, ())
    return [value]

  with self.assertRaises(ValueError):
    pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [non_ref_aval])
# Abstract-evaluation cases for `ref_get`: each case supplies the ref's shape
# and dtype, the index handed to `ref_get`, and the expected output shape and
# dtype. `at_indices`, when given, is a sequence of indexers applied one after
# another through `x_ref.at[...]` before the get.
@parameterized.named_parameters(
dict(testcase_name="trivial_get", ref_shape=(1, 2),
ref_dtype=jnp.float32,
idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
dict(testcase_name="get_with_index", ref_shape=(1, 2),
ref_dtype=jnp.float32,
idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
dict(testcase_name="get_with_nonleading_index", ref_shape=(1, 2),
ref_dtype=jnp.float32,
idx=(slice(None), 0), out_shape=(1,), out_dtype=jnp.float32),
dict(testcase_name="get_with_array_index", ref_shape=(1, 2, 3, 4),
ref_dtype=jnp.float32,
idx=(np.array([0, 1]),), out_shape=(2, 2, 3, 4),
out_dtype=jnp.float32),
dict(testcase_name="get_with_multiple_array_index",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(np.array([0, 1]), np.array([0, 1])),
out_shape=(2, 2, 4), out_dtype=jnp.float32),
dict(testcase_name="get_with_nonleading_multiple_array_index",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="get_with_nontrivial_slice",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="get_with_nontrivial_slice2",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
# The remaining cases first narrow the ref with `.at[...]`.
dict(testcase_name="get_with_ref_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(1, 3), slice(None), slice(None)),
out_shape=(2, 2, 4), out_dtype=jnp.float32,
at_indices=((0,),)),
dict(testcase_name="get_with_ref_simple_at2",
ref_shape=(6, 1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 2), slice(0, 1), slice(1, 3), slice(None), slice(None)),
out_shape=(2, 1, 2, 2, 4), out_dtype=jnp.float32,
at_indices=((slice(2, 6),),)),
dict(testcase_name="get_with_ref_multiple_at",
ref_shape=(1, 3, 5, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(None), slice(0, 2)),
out_shape=(3, 1, 2), out_dtype=jnp.float32,
at_indices=((0,), (slice(None), slice(0, 1)))),
)
def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
out_dtype=None, at_indices=(),
should_error=False):
# Trace `f` against an abstract ref of the requested shape and dtype.
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
def f(x_ref):
# Apply each `.at[...]` indexer in turn, then read with `idx`.
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
out = ref_get(x_ref, idx)
return [out]
# NOTE: none of the cases listed above sets should_error, so only the else
# branch is exercised by this table.
if should_error:
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval])
else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
wrap_init(f, 1), [ref_aval])
# The jaxpr's only effect must be a single read effect, indexed by the
# number of constvars (presumably the position of the ref argument).
self.assertSetEqual(jaxpr.effects,
{ReadEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 1)
out_aval, = out_avals
self.assertIsInstance(out_aval, core.ShapedArray)
self.assertEqual(out_aval.shape, out_shape)
self.assertEqual(out_aval.dtype, out_dtype)
def test_swap_abstract_eval_must_take_in_refs(self):
  # Both inputs are plain array avals, so `ref_swap` receives a non-ref as
  # its first argument; tracing is expected to raise ValueError.
  in_avals = [core.ShapedArray((), jnp.float32),
              core.ShapedArray((), jnp.float32)]

  def f(x_ref, val):
    swapped = ref_swap(x_ref, (), val)
    return [swapped]

  with self.assertRaises(ValueError):
    pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), in_avals)
# Abstract-evaluation cases for `ref_swap`: each case supplies the ref's shape
# and dtype, the shape and dtype of the value being swapped in, the index, and
# either the expected output shape/dtype or `should_error=True`. `at_indices`,
# when given, is a sequence of indexers applied through `x_ref.at[...]` before
# the swap.
@parameterized.named_parameters(
dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(), should_error=True),
dict(testcase_name="invalid_val_shape_slice", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(slice(None),), should_error=True),
dict(testcase_name="trivial_swap", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
dict(testcase_name="bad_dtype", ref_shape=(1, 2),
ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
idx=(), should_error=True),
dict(testcase_name="swap_with_index", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nonleading_index", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
idx=(slice(None), 0), out_shape=(1,), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nonleading_index_bad_val", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(slice(None), 0), should_error=True),
dict(testcase_name="swap_with_array_index", ref_shape=(1, 2, 3, 4),
ref_dtype=jnp.float32, val_shape=(2, 2, 3, 4), val_dtype=jnp.float32,
idx=(np.array([0, 1]),), out_shape=(2, 2, 3, 4),
out_dtype=jnp.float32),
dict(testcase_name="swap_with_multiple_array_index",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2, 4), val_dtype=jnp.float32,
idx=(np.array([0, 1]), np.array([0, 1])),
out_shape=(2, 2, 4), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nonleading_multiple_array_index",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 1, 2), val_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nontrivial_slice",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), np.array([0, 1]), slice(None), np.array([0, 1])),
val_shape=(2, 1, 2), val_dtype=jnp.float32,
out_shape=(2, 1, 2), out_dtype=jnp.float32),
dict(testcase_name="swap_with_nontrivial_slice2",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None), slice(None)),
val_shape=(1, 2, 2, 4), val_dtype=jnp.float32,
out_shape=(1, 2, 2, 4), out_dtype=jnp.float32),
# The remaining cases first narrow the ref with `.at[...]`.
dict(testcase_name="swap_with_ref_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(0, 1), slice(1, 3), slice(None),),
val_shape=(1, 1, 4), val_dtype=jnp.float32,
out_shape=(1, 1, 4), out_dtype=jnp.float32,
at_indices=((0,),),),
dict(testcase_name="swap_with_ref_simple_at2",
ref_shape=(4, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
at_indices=((slice(0, 2),),),),
dict(testcase_name="swap_with_ref_multiple_at2",
ref_shape=(1, 4, 3, 2, 4), ref_dtype=jnp.float32,
idx=(slice(None), slice(0, 1), slice(1, 3), slice(None),),
val_shape=(2, 1, 1, 4), val_dtype=jnp.float32,
out_shape=(2, 1, 1, 4), out_dtype=jnp.float32,
at_indices=((slice(None), slice(0, 2),), (0,)),),
)
def test_swap_abstract_eval(self, ref_shape, ref_dtype,
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
at_indices=(), should_error=False):
# Trace `f` against an abstract ref plus an abstract value to swap in.
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
val_aval = core.ShapedArray(val_shape, val_dtype)
def f(x_ref, val):
# Apply each `.at[...]` indexer in turn, then swap `val` in at `idx`.
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
out = ref_swap(x_ref, idx, val)
return [out]
if should_error:
# Error cases only require that tracing raises; any exception type passes.
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
wrap_init(f, 2), [ref_aval, val_aval])
# The jaxpr's only effect must be a single write effect, indexed by the
# number of constvars (presumably the position of the ref argument).
self.assertSetEqual(jaxpr.effects,
{WriteEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 1)
out_aval, = out_avals
self.assertIsInstance(out_aval, core.ShapedArray)
self.assertEqual(out_aval.shape, out_shape)
self.assertEqual(out_aval.dtype, out_dtype)
# Abstract-evaluation cases for `ref_addupdate`: each case supplies the ref's
# shape and dtype, the shape and dtype of the value being accumulated, the
# index, and optionally `should_error=True`. `at_indices`, when given, is a
# sequence of indexers applied through `x_ref.at[...]` before the update.
@parameterized.named_parameters(
dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(), should_error=True),
dict(testcase_name="invalid_val_shape_slice", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(slice(None),), should_error=True),
dict(testcase_name="trivial_addupdate", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
idx=(),),
dict(testcase_name="bad_dtype", ref_shape=(1, 2),
ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
idx=(), should_error=True),
dict(testcase_name="addupdate_with_index", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(0,),),
dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
idx=(slice(None), 0)),
dict(testcase_name="addupdate_with_nonleading_index_bad_val", ref_shape=(1, 2),
ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
idx=(slice(None), 0), should_error=True),
dict(testcase_name="addupdate_with_array_index", ref_shape=(1, 2, 3, 4),
ref_dtype=jnp.float32, val_shape=(2, 2, 3, 4), val_dtype=jnp.float32,
idx=(np.array([0, 1]),)),
dict(testcase_name="addupdate_with_multiple_array_index",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2, 4), val_dtype=jnp.float32,
idx=(np.array([0, 1]), np.array([0, 1]))),
dict(testcase_name="addupdate_with_nonleading_multiple_array_index",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 1, 2), val_dtype=jnp.float32,
idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))),
# The remaining cases first narrow the ref with `.at[...]`.
dict(testcase_name="ref_with_simple_at",
ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((0,),)),
dict(testcase_name="ref_with_simple_at2",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 3, 4), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(0, 3),),)),
dict(testcase_name="ref_with_multiple_at",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(0, 3),), (0,))),
dict(testcase_name="ref_with_multiple_at2",
ref_shape=(3, 3, 2, 4), ref_dtype=jnp.float32,
val_shape=(2, 2), val_dtype=jnp.float32,
idx=(np.array([0, 1]), slice(None), np.array([0, 1])),
at_indices=((slice(None), slice(0, 3),), (0,))),
)
def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
val_shape, val_dtype, idx, at_indices=(), should_error=False):
# Trace `f` against an abstract ref plus an abstract value to accumulate.
ref_aval = AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
val_aval = core.ShapedArray(val_shape, val_dtype)
def f(x_ref, val):
# Apply each `.at[...]` indexer in turn, then accumulate `val` at `idx`.
for at_idx in at_indices:
x_ref = x_ref.at[at_idx]
ref_addupdate(x_ref, idx, val)
# The traced function returns no outputs; only the effect is checked.
return []
if should_error:
# Error cases only require that tracing raises; any exception type passes.
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval])
else:
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
wrap_init(f, 2), [ref_aval, val_aval])
# The jaxpr's only effect must be a single accumulate effect, indexed by
# the number of constvars (presumably the position of the ref argument).
self.assertSetEqual(jaxpr.effects,
{AccumEffect(len(jaxpr.constvars))})
self.assertLen(out_avals, 0)
def test_addupdate_abstract_eval_must_take_in_refs(self):
  # Both inputs are plain array avals, so `ref_addupdate` receives a non-ref
  # as its first argument; tracing is expected to raise ValueError.
  in_avals = [core.ShapedArray((), jnp.float32),
              core.ShapedArray((), jnp.float32)]

  def f(x_ref, val):
    result = ref_addupdate(x_ref, (), val)
    return [result]

  with self.assertRaises(ValueError):
    pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), in_avals)
def test_can_represent_get_and_swap_in_jaxprs(self):
  # Two writes followed by a read are expected to trace to the equation
  # sequence swap, swap, get.
  def body(x):
    x[()] = jnp.int32(1)
    x[()] = jnp.int32(2)
    return (x[()],)

  int_ref_aval = shaped_array_ref((), jnp.int32)
  jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(body, 1), [int_ref_aval])
  self.assertLen(consts, 0)
  self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
  for position, expected_prim in enumerate((swap_p, swap_p, get_p)):
    self.assertEqual(jaxpr.eqns[position].primitive, expected_prim)
def test_can_represent_addupdate_in_jaxprs(self):
  # An accumulate followed by a read: the first equation is expected to be
  # the addupdate primitive.
  def body(x):
    ref_addupdate(x, (), jnp.int32(1))
    return (x[()],)

  int_ref_aval = shaped_array_ref((), jnp.int32)
  jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(body, 1), [int_ref_aval])
  self.assertLen(consts, 0)
  self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
  first_eqn = jaxpr.eqns[0]
  self.assertEqual(first_eqn.primitive, addupdate_p)
def test_get_custom_pretty_printing_rule(self):
  def render(fn, in_avals):
    # Trace `fn` against `in_avals` and return the jaxpr as uncoloured text.
    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        wrap_init(fn, len(in_avals)), in_avals)
    return jaxpr.pretty_print(use_color=False)

  # Reading a scalar ref.
  def body(x_ref):
    x = x_ref[()]
    return [x]
  scalar_text = render(body, [shaped_array_ref((), jnp.int32)])
  self.assertIn("b:i32[] <- a[]", scalar_text)

  # Reading one column of a 2D ref.
  def body(x_ref):
    x = x_ref[:, 0]
    return [x]
  column_text = render(body, [shaped_array_ref((1, 2), jnp.int32)])
  self.assertIn("b:i32[1] <- a[:,0]", column_text)
def test_set_custom_pretty_printing_rule(self):
  def render(fn, in_avals):
    # Trace `fn` against `in_avals` and return the jaxpr as uncoloured text.
    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        wrap_init(fn, len(in_avals)), in_avals)
    return jaxpr.pretty_print(use_color=False)

  # Writing a literal into a scalar ref.
  def body(x_ref):
    x_ref[()] = jnp.int32(2)
    return []
  scalar_text = render(body, [shaped_array_ref((), jnp.int32)])
  self.assertIn("a[] <- 2", scalar_text)

  # Writing a traced value into one column of a 2D ref.
  def body(x_ref, val):
    x_ref[:, 0] = val
    return []
  column_text = render(body, [shaped_array_ref((1, 2), jnp.int32),
                              core.ShapedArray((1,), jnp.int32)])
  self.assertIn("a[:,0] <- b", column_text)
def test_swap_custom_pretty_printing_rule(self):
  def render(fn, in_avals):
    # Trace `fn` against `in_avals` and return the jaxpr as uncoloured text.
    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        wrap_init(fn, len(in_avals)), in_avals)
    return jaxpr.pretty_print(use_color=False)

  # Swapping a literal into a scalar ref.
  def body(x_ref):
    x = ref_swap(x_ref, (), jnp.int32(2))
    return [x]
  scalar_text = render(body, [shaped_array_ref((), jnp.int32)])
  self.assertIn("b:i32[], a[] <- a[], 2", scalar_text)

  # Swapping a traced value into one column of a 2D ref.
  def body(x_ref, val):
    x = ref_swap(x_ref, (slice(None), 0), val)
    return [x]
  column_text = render(body, [shaped_array_ref((1, 2), jnp.int32),
                              core.ShapedArray((1,), jnp.int32)])
  self.assertIn("c:i32[1], a[:,0] <- a[:,0], b", column_text)
def test_addupdate_custom_pretty_printing_rule(self):
  def render(fn, in_avals):
    # Trace `fn` against `in_avals` and return the jaxpr as uncoloured text.
    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        wrap_init(fn, len(in_avals)), in_avals)
    return jaxpr.pretty_print(use_color=False)

  # Accumulating a literal into a scalar ref.
  def body(x_ref):
    ref_addupdate(x_ref, (), jnp.int32(2))
    return []
  scalar_text = render(body, [shaped_array_ref((), jnp.int32)])
  self.assertIn("a[] += 2", scalar_text)

  # Accumulating a traced value into one column of a 2D ref.
  def body(x_ref, val):
    ref_addupdate(x_ref, (slice(None), 0), val)
    return []
  column_text = render(body, [shaped_array_ref((1, 2), jnp.int32),
                              core.ShapedArray((1,), jnp.int32)])
  self.assertIn("a[:,0] += b", column_text)
def test_get_jvp(self):
  # Differentiating a function that reads a ref: the first two equations of
  # the traced JVP are expected to be `get_p`.
  def f(r):
    x = r[()]
    return jnp.cos(x)

  def g(r, rdot):
    return jax.jvp(f, (r,), (rdot,))

  in_avals = [shaped_array_ref((), jnp.dtype('float32')) for _ in range(2)]
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
  for position in range(2):
    self.assertEqual(jaxpr.eqns[position].primitive, get_p)
def test_swap_jvp(self):
  # Differentiating a read-modify-write on a ref; the traced JVP is checked
  # equation by equation against the expected primitive sequence.
  def f(a):
    x = a[()]
    a[()] = jnp.sin(x)
    return a[()]

  def g(r, rdot):
    return jax.jvp(f, (r,), (rdot,))

  in_avals = [shaped_array_ref((), jnp.dtype('float32')) for _ in range(2)]
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
  expected_prims = (get_p, get_p, lax.sin_p, lax.cos_p, lax.mul_p,
                    swap_p, swap_p)
  for position, expected_prim in enumerate(expected_prims):
    self.assertEqual(jaxpr.eqns[position].primitive, expected_prim)
def test_addupdate_jvp(self):
  # Differentiating an accumulate followed by a read; the traced JVP is
  # checked equation by equation against the expected primitive sequence.
  def f(a):
    ref_addupdate(a, (), jnp.float32(1.))
    return a[()]

  def g(r, rdot):
    return jax.jvp(f, (r,), (rdot,))

  in_avals = [shaped_array_ref((), jnp.dtype('float32')) for _ in range(2)]
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals)
  expected_prims = (addupdate_p, addupdate_p, get_p, get_p)
  for position, expected_prim in enumerate(expected_prims):
    self.assertEqual(jaxpr.eqns[position].primitive, expected_prim)
# Checks that batching and state-discharge commute for get / swap / addupdate:
# discharging a vmapped stateful function must agree with vmapping the
# discharged function. The first argument enumerates consistent combinations of
# ref shape, ref batch axis, index-array shape, which ref dims are indexed, the
# batch axis (or None) of each index array, and the output batch axis; `op`
# selects which ref operation is exercised.
@jtu.sample_product(
[dict(ref_shape=ref_shape, ref_bdim=ref_bdim, idx_shape=idx_shape,
indexed_dims=indexed_dims, idx_bdims=idx_bdims, out_bdim=out_bdim)
for ref_shape in [(1,), (2, 3), (4, 5, 6)]
for ref_bdim in range(1 + len(ref_shape))
for idx_shape in [(), (1,), (2,), (5, 6)]
for indexed_dims in it.product([True, False], repeat=len(ref_shape))
for idx_bdims in it.product([None, *range(1 + len(idx_shape))],
repeat=sum(indexed_dims))
for out_bdim in range(1 + len(ref_shape) - sum(indexed_dims)
+ len(idx_shape) * any(indexed_dims))
],
op=[
# get: read the indexed slice.
lambda x_ref, indexer: [x_ref[indexer]],
# swap: write ones of the indexed slice's shape, returning the old values.
lambda x_ref, indexer: [
ref_swap(x_ref, indexer,
jnp.ones(x_ref.shape, x_ref.dtype)[None][(0,
*indexer)])],
# addupdate: accumulate ones, then return ones. The `or` relies on
# ref_addupdate returning a falsy value (presumably None).
lambda x_ref, indexer: (
ref_addupdate(x_ref, indexer,
jnp.ones(x_ref.shape, x_ref.dtype)[None][(0,
*indexer)])
or [jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]])
],
)
def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims,
idx_bdims, out_bdim, op):
# Pick dtypes that match the current x64 setting.
float_ = (jnp.dtype('float64') if config.enable_x64.value else
jnp.dtype('float32'))
int_ = (jnp.dtype('int64') if config.enable_x64.value else
jnp.dtype('int32'))
# Size of the batch axis inserted by maybe_insert below.
axis_size = 7
# NOTE(review): out_shape is computed here but not used anywhere below.
out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b)
if any(indexed_dims):
out_shape = (*idx_shape, *out_shape)
def maybe_insert(shape, idx):
# Insert the batch axis at position `idx`, or leave `shape` alone for None.
if idx is None:
return shape
return tuple_insert(shape, idx, axis_size)
batched_ref_shape = maybe_insert(ref_shape, ref_bdim)
# Unbatched and batched avals for the ref and for one index array per
# indexed dimension.
ref_aval = shaped_array_ref(ref_shape, float_)
bat_ref_aval = shaped_array_ref(batched_ref_shape, float_)
idx_avals = [core.ShapedArray(idx_shape, int_)
for _ in idx_bdims]
bat_idx_avals = [
core.ShapedArray(maybe_insert(idx_shape, idx_bdim), int_)
for idx_bdim in idx_bdims]
def f(x_ref, *idxs):
# Build the indexer: the next index array for each indexed dim, a full
# slice for every other dim.
idxs_ = iter(idxs)
indexer = tuple(next(idxs_) if b else slice(None) for b in indexed_dims)
return op(x_ref, indexer)
rng = self.rng()
a = rng.randn(*bat_ref_aval.shape)
# Upper bounds for the random indices: the sizes of the indexed ref dims.
his = [d for d, b in zip(ref_aval.shape, indexed_dims) if b]
idxs = [rng.randint(low=0, high=hi, size=i.shape)
for i, hi in zip(bat_idx_avals, his)]
# discharge-of-vmap
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
# The discharged function returns the op's output followed by the final ref
# value, hence the two out_axes entries.
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
in_axes=(ref_bdim, *idx_bdims),
out_axes=[out_bdim, ref_bdim])
vmap_of_discharge_ans = f_batched(a, *idxs)
self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
check_dtypes=False)
class StateDischargeTest(jtu.JaxTestCase):
def test_discharge_get(self):
  def f(a_ref):
    a = ref_get(a_ref, ())
    return [a + 1]

  ref_avals = [shaped_array_ref((), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  # Once discharged, the jaxpr takes the ref's value as a plain input, adds
  # 1, and returns the (unchanged) ref value as an extra output.
  discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
  self.assertLen(discharged_jaxpr.invars, 1)
  self.assertLen(discharged_jaxpr.outvars, 2)
  first_eqn = discharged_jaxpr.eqns[0]
  self.assertEqual(first_eqn.primitive, lax.add_p)
  # The discharged jaxpr must be directly evaluable.
  results = core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(1.))
  self.assertListEqual(results, [2., 1.])
def test_discharge_get_with_slice(self):
  def f(a_ref):
    a = ref_get(a_ref, (0, 1))
    return [a + 1]

  ref_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  # Discharging should turn the indexed read into a dynamic slice of the
  # input, with the ref value returned as an extra output.
  discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
  self.assertLen(discharged_jaxpr.invars, 1)
  self.assertLen(discharged_jaxpr.outvars, 2)
  eqn_prims = {eqn.primitive for eqn in discharged_jaxpr.eqns}
  self.assertIn(lax.dynamic_slice_p, eqn_prims)
  # The discharged jaxpr must be directly evaluable.
  ref_init = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
  out_val, ref_final = core.eval_jaxpr(discharged_jaxpr, (), ref_init)
  self.assertTrue((out_val == ref_init[0, 1] + 1).all())
  self.assertTrue((ref_final == ref_init).all())
def test_discharge_get_with_gather(self):
  # Read through an integer-array index and check the discharged jaxpr
  # against ordinary array indexing.
  def f(a_ref):
    a = a_ref[jnp.array([0, 1])]
    return [a + 1]

  ref_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  discharged_jaxpr, discharged_consts = discharge_state(
      stateful_jaxpr, consts)
  ref_init = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
  out_val, ref_final = core.eval_jaxpr(
      discharged_jaxpr, discharged_consts, ref_init)
  expected_out = ref_init[jnp.array([0, 1])] + 1
  self.assertTrue((out_val == expected_out).all())
  self.assertTrue((ref_final == ref_init).all())
def test_discharge_set(self):
  def f(a_ref, b):
    ref_set(a_ref, (), b + 1)
    return []

  in_avals = [shaped_array_ref((), jnp.dtype('float32')),
              core.ShapedArray((), jnp.dtype('float32'))]
  # `f` takes two positional arguments (the ref and the value), so the debug
  # info built by wrap_init must describe two arguments, matching in_avals.
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 2), in_avals)
  # Discharging should just turn this into a jaxpr that ignores the first
  # value and returns second value plus 1.
  discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
  self.assertLen(discharged_jaxpr.invars, 2)
  self.assertLen(discharged_jaxpr.outvars, 1)
  # The result is the same whatever the ref initially held.
  self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
                                   jnp.float32(1.))[0], 2.)
  self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
                                   jnp.float32(1.))[0], 2.)
def test_discharge_set_with_slice(self):
  def f(a_ref):
    ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
    return []

  ref_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  # Discharging should turn the indexed write into slice / update-slice
  # operations on the input, returning the updated ref value.
  discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
  self.assertLen(discharged_jaxpr.invars, 1)
  self.assertLen(discharged_jaxpr.outvars, 1)
  eqn_prims = {eqn.primitive for eqn in discharged_jaxpr.eqns}
  self.assertIn(lax.dynamic_update_slice_p, eqn_prims)
  self.assertIn(lax.dynamic_slice_p, eqn_prims)
  # The discharged jaxpr must be directly evaluable.
  ref_init = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
  ref_final, = core.eval_jaxpr(discharged_jaxpr, (), ref_init)
  self.assertTrue((ref_final == ref_init.at[0, 1].set(1.)).all())
def test_discharge_set_with_gather(self):
  # Write through an integer-array index and check the discharged jaxpr
  # against `.at[...].set(...)`.
  def f(a_ref):
    a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
    return []

  ref_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  discharged_jaxpr, discharged_consts = discharge_state(
      stateful_jaxpr, consts)
  ref_init = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
  ref_final, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, ref_init)
  expected_ref = ref_init.at[jnp.array([0, 1])].set(1.)
  self.assertTrue((ref_final == expected_ref).all())
def test_discharge_swap(self):
  # Swap zeros into a region reached through two chained `.at[...]` views
  # plus the swap's own index; per the assertions below this is the region
  # [1:3, 1:3, 0] of the original array.
  def f(a_ref):
    narrowed_ref = a_ref.at[0:4, 0:3, 0:2].at[1:3, :, 0]
    a = ref_swap(
        narrowed_ref,
        (slice(None), slice(1, 3)),
        jnp.zeros((2, 2), jnp.float32))
    return [a + 1]

  ref_avals = [shaped_array_ref((4, 3, 2), jnp.float32)]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
  self.assertLen(discharged_jaxpr.invars, 1)
  self.assertLen(discharged_jaxpr.outvars, 2)
  ref_init = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
  out_val, ref_final = core.eval_jaxpr(discharged_jaxpr, (), ref_init)
  self.assertArraysEqual(out_val, ref_init[1:3, 1:3, 0] + 1)
  self.assertArraysEqual(ref_final, ref_init.at[1:3, 1:3, 0].set(0))
def test_discharge_addupdate(self):
  def f(a_ref, b):
    ref_addupdate(a_ref, (), b + 1)
    return []

  f32 = jnp.dtype('float32')
  in_avals = [shaped_array_ref((), f32), core.ShapedArray((), f32)]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 2), in_avals)
  # After discharging, the jaxpr should compute
  # (initial ref value) + (second argument) + 1.
  discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
  self.assertLen(discharged_jaxpr.invars, 2)
  self.assertLen(discharged_jaxpr.outvars, 1)
  from_zero = core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
                              jnp.float32(1.))
  self.assertEqual(from_zero[0], 2.)
  from_two = core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
                             jnp.float32(1.))
  self.assertEqual(from_two[0], 4.)
def test_discharge_addupdate_with_slice(self):
  def f(a_ref):
    ones = jnp.ones(2, dtype=jnp.dtype('float32'))
    ref_addupdate(a_ref, (0, 1), ones)
    return []

  ref_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
  self.assertLen(discharged_jaxpr.invars, 1)
  self.assertLen(discharged_jaxpr.outvars, 1)
  # The discharged accumulate is expected to contain an update-slice, an add
  # and a slice (checked in that order, as in the assertions below).
  eqn_prims = {eqn.primitive for eqn in discharged_jaxpr.eqns}
  for required_prim in (lax.dynamic_update_slice_p, lax.add_p,
                        lax.dynamic_slice_p):
    self.assertIn(required_prim, eqn_prims)
  ref_init = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
  ref_final, = core.eval_jaxpr(discharged_jaxpr, (), ref_init)
  self.assertTrue((ref_final == ref_init.at[0, 1].add(1.)).all())
def test_discharge_addupdate_with_gather(self):
  # Accumulate through an integer-array index and check the discharged jaxpr
  # against `.at[...].add(...)`.
  def f(a_ref):
    row_index = (jnp.array([0, 1]),)
    ref_addupdate(a_ref, row_index, jnp.ones((2, 3), 'float32'))
    return []

  ref_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  discharged_jaxpr, discharged_consts = discharge_state(
      stateful_jaxpr, consts)
  ref_init = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
  ref_final, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, ref_init)
  expected_ref = ref_init.at[jnp.array([0, 1])].add(1.)
  self.assertTrue((ref_final == expected_ref).all())
def test_discharge_jaxpr_with_multiple_outputs(self):
  # Two regular outputs plus the discharged ref value: three outvars total.
  def f(a_ref):
    a = ref_get(a_ref, ())
    b = a + 1
    return [a, b]

  ref_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), ref_avals)
  discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
  self.assertLen(discharged_jaxpr.invars, 1)
  self.assertLen(discharged_jaxpr.outvars, 3)
  ref_init = jnp.arange(4., dtype=jnp.float32)
  a, b, ref_final = core.eval_jaxpr(discharged_jaxpr, (), ref_init)
  self.assertTrue((a == ref_init).all())
  self.assertTrue((b == ref_init + 1).all())
  self.assertTrue((ref_final == ref_init).all())
def test_partially_discharging_jaxpr_keeps_refs(self):
  def f(a_ref, b_ref):
    ref_set(a_ref, (), jnp.ones(4, jnp.float32))
    ref_set(b_ref, (), jnp.ones(4, jnp.float32))
    return []

  in_avals = [shaped_array_ref((4,), jnp.dtype('float32')),
              shaped_array_ref((4,), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 2), in_avals)
  # Only the second ref is discharged; the first must stay a ref.
  discharged_jaxpr, _ = discharge_state(
      stateful_jaxpr, consts, should_discharge=[False, True])
  self.assertLen(discharged_jaxpr.invars, 2)
  self.assertLen(discharged_jaxpr.outvars, 1)
  kept_ref_var, discharged_var = discharged_jaxpr.invars
  self.assertIsInstance(kept_ref_var.aval, AbstractRef)
  self.assertIsInstance(discharged_var.aval, core.ShapedArray)
  # The one remaining effect is the write to the ref that was kept.
  remaining_effect = WriteEffect(len(discharged_jaxpr.constvars))
  self.assertEqual(discharged_jaxpr.effects, {remaining_effect})
def test_ellipsis_index(self):
  # Smoke test: `...` must be accepted as an index by ref_set / ref_get and
  # by the subscript syntax. Nothing is asserted beyond tracing succeeding.
  def f(ref):
    ref_set(ref, ..., jnp.array(0., dtype=jnp.float32))
    ref_get(ref, ...)
    ref[...] = jnp.array(0., dtype=jnp.float32)
    ref[...]
    return []

  scalar_ref_aval = shaped_array_ref((), jnp.float32)
  pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [scalar_ref_aval])
def test_partial_discharge(self):
  def f(a_ref, b_ref):
    a_ref[...] = jnp.array(0., dtype=jnp.float32)
    b_ref[...] = jnp.array(1., dtype=jnp.float32)
    return a_ref[...], b_ref[...]

  scalar_ref_1 = shaped_array_ref((), jnp.float32)
  scalar_ref_2 = shaped_array_ref((), jnp.float32)
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 2), [scalar_ref_1, scalar_ref_2])
  discharged_jaxpr, _ = discharge_state(
      jaxpr, (), should_discharge=[False, True])

  def count_eqns(prim, jpr):
    # Number of equations in `jpr` whose primitive is `prim`.
    return sum(eqn.primitive == prim for eqn in jpr.eqns)

  # Discharging one of the two refs should leave half of the swaps and half
  # of the gets in place.
  for prim in (swap_p, get_p):
    self.assertEqual(count_eqns(prim, jaxpr) // 2,
                     count_eqns(prim, discharged_jaxpr))
def test_partial_fori_discharge(self):
  def f(a_ref, b_ref):
    def body(i, st):
      a_ref[...] += 2 * i
      b_ref[...] += i
      return ()
    lax.fori_loop(0, 5, body, init_val=())
    return a_ref[...], b_ref[...]

  def ref(x):
    # Abstract ref wrapping the aval of the example value `x`.
    return AbstractRef(core.get_aval(x))

  f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.))
  jaxpr, _ = discharge_state(
      f_jaxpr.jaxpr, (), should_discharge=[False, True])
  # Effects on b_ref were discharged away but not the effects on a_ref.
  all_effects = {ReadEffect(0), WriteEffect(0), ReadEffect(1), WriteEffect(1)}
  self.assertEqual(f_jaxpr.effects, all_effects)
  self.assertEqual(jaxpr.effects, {ReadEffect(0), WriteEffect(0)})
  # The a_ref argument is still a reference but b_ref is discharged.
  kept_ref_var, discharged_var = jaxpr.invars[0], jaxpr.invars[1]
  self.assertNotIsInstance(discharged_var.aval, AbstractRef)
  self.assertIsInstance(kept_ref_var.aval, AbstractRef)
  # The discharged ref's value is returned in addition to the two original
  # outputs.
  self.assertLen(f_jaxpr.out_avals, 2)
  self.assertLen(jaxpr.outvars, 3)
if CAN_USE_HYPOTHESIS:
def index_arrays(size, idx_shape):
  # Hypothesis strategy producing int32 arrays of shape `idx_shape` whose
  # elements are integers in [-size, size - 1].
  lowest, highest = -size, size - 1
  element_strategy = hps.integers(min_value=lowest, max_value=highest)
  return hnp.arrays(np.int32, idx_shape, elements=element_strategy)
# Alias for an array shape.
Shape = tuple[int, ...]


class IndexParam(NamedTuple):
  # Abstract ref being indexed. Annotated with the type (`AbstractRef`), not
  # with the `shaped_array_ref` factory function that builds it.
  ref_aval: AbstractRef
  # Shape of the ref.
  ref_shape: Shape
  # One flag per ref dimension: True if that dimension is indexed by an
  # integer array rather than a full slice.
  indexed_dims: list[bool]
  # One abstract index array per indexed dimension, all of shape `idx_shape`.
  idx_avals: tuple[core.ShapedArray, ...]
  idx_shape: Shape
  # Aval and shape of the slice selected by the indexing.
  slice_aval: core.ShapedArray
  slice_shape: Shape
@hps.composite
def index_params(draw):
  # Composite strategy producing an IndexParam: a random ref shape, a flag
  # per ref dimension saying whether it is indexed, and a common shape for
  # the index arrays. The draws happen in the same order as listed here.
  ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape')
  n_dims = len(ref_shape)
  indexed_dims = draw(
      hps.lists(hps.booleans(), min_size=n_dims, max_size=n_dims))
  idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5))
  if not any(indexed_dims):
    # Nothing is indexed, so the slice is the whole ref.
    slice_shape = ref_shape
  else:
    # Index-array dims come first, followed by the ref dims left unindexed.
    kept_dims = tuple(s for s, b in zip(ref_shape, indexed_dims) if not b)
    slice_shape = (*idx_shape, *kept_dims)
  n_indexed = sum(indexed_dims)
  idx_avals = tuple(
      core.ShapedArray(idx_shape, np.int32) for _ in range(n_indexed))
  return IndexParam(
      ref_aval=shaped_array_ref(ref_shape, np.float32),
      ref_shape=ref_shape,
      indexed_dims=indexed_dims,
      idx_avals=idx_avals,
      idx_shape=idx_shape,
      slice_aval=core.ShapedArray(slice_shape, np.float32),
      slice_shape=slice_shape)
class VmappableIndexParam(NamedTuple):
  """An `IndexParam` plus the batching layout used when vmapping over it.

  A batch dimension of None means the corresponding operand is not batched.
  """
  index_param: IndexParam
  ref_bdim: int | None
  # One entry per indexed dimension of the ref.
  non_slice_idx_bdims: tuple[int | None, ...]
  # Can be None: the "swap" strategy may leave the value unbatched.
  slice_bdim: int | None
  # NOTE(review): annotated with the `shaped_array_ref` factory rather than a
  # type -- confirm the intended annotation.
  bat_ref_aval: shaped_array_ref
  # Shapes/avals below have the batch axis inserted at the matching bdim.
  bat_ref_shape: Shape
  bat_non_slice_idx_avals: tuple[core.ShapedArray, ...]
  bat_non_slice_idx_shapes: tuple[Shape, ...]
  bat_slice_aval: core.ShapedArray
  bat_slice_shape: Shape
def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None,
                       val: Any) -> tuple[Any, ...]:
  """Returns `t` with `val` inserted at `idx`, or `t` itself if `idx` is None."""
  return t if idx is None else tuple_insert(t, idx, val)
@hps.composite
def vmappable_index_params(draw, *, op_type: str):
  """Draws an `IndexParam` together with a batching layout for vmap.

  Args:
    draw: the draw function supplied by `hps.composite`.
    op_type: "swap" or "get"; selects which operands must be batched.

  Returns:
    A `VmappableIndexParam`.

  Raises:
    ValueError: if `op_type` is neither "swap" nor "get".
  """
  if op_type not in ("swap", "get"):
    # Without this check an unknown op_type only surfaces further down as an
    # UnboundLocalError on `ref_bdim`.
    raise ValueError(f"op_type must be 'swap' or 'get', got {op_type!r}")

  def bdim(shape):
    # A batch axis may be inserted at any position, including the end.
    return hps.integers(min_value=0, max_value=len(shape))

  def optional_bdim(shape):
    return hps.one_of(hps.none(), bdim(shape))

  axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size')
  index_param: IndexParam = draw(index_params())
  non_slice_idx_bdims = tuple(
      draw(optional_bdim(index_param.idx_shape))
      for b in index_param.indexed_dims if b)
  bat_non_slice_idx_shapes = tuple(
      maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size)
      for idx_bdim in non_slice_idx_bdims)
  any_idx_batched = any(
      idx_bdim is not None for idx_bdim in non_slice_idx_bdims)

  if op_type == "swap":
    # In a swap, the ref *must* be batched.
    ref_bdim = draw(bdim(index_param.ref_shape))
    if any_idx_batched:
      # If it's a swap, if indices are batched, val must be batched.
      slice_bdim = draw(bdim(index_param.slice_shape))
    else:
      slice_bdim = draw(optional_bdim(index_param.slice_shape))
  else:  # op_type == "get"
    # In a get, the indices must be batched or ref is batched.
    if any_idx_batched:
      ref_bdim = draw(optional_bdim(index_param.ref_shape))
    else:
      ref_bdim = draw(bdim(index_param.ref_shape))
    slice_bdim = draw(bdim(index_param.slice_shape))

  bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size)
  bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32)
  bat_non_slice_idx_avals = tuple(
      core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes)
  bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim,
                                       axis_size)
  bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32)
  return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims,
                             slice_bdim, bat_ref_aval, bat_ref_shape,
                             bat_non_slice_idx_avals, bat_non_slice_idx_shapes,
                             bat_slice_aval, bat_slice_shape)
class GetVmapParams(NamedTuple):
  """Concrete inputs for a batched `get`: the layout plus drawn arrays."""
  vmap_index_param: VmappableIndexParam
  # float32 array of shape `vmap_index_param.bat_ref_shape`.
  bat_ref: np.ndarray
  # One index array per indexed dimension of the ref.
  bat_idxs: tuple[np.ndarray, ...]
@hps.composite
def get_vmap_params(draw):
  """Draws a batching layout for "get" plus a ref array and index arrays."""
  vmap_index_param: VmappableIndexParam = draw(
      vmappable_index_params(op_type="get"))
  index_param = vmap_index_param.index_param
  bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape))

  # Walk the ref dimensions, pairing each indexed one with the next batched
  # index shape.
  remaining_shapes = iter(vmap_index_param.bat_non_slice_idx_shapes)
  drawn_idxs = []
  for size, indexed in zip(index_param.ref_shape, index_param.indexed_dims):
    if indexed:
      drawn_idxs.append(draw(index_arrays(size, next(remaining_shapes))))
  bat_idxs = tuple(drawn_idxs)
  # Every batched index shape must have been consumed.
  assert next(remaining_shapes, None) is None
  return GetVmapParams(vmap_index_param, bat_ref, bat_idxs)
class SetVmapParams(NamedTuple):
  """Concrete inputs for a batched write: the layout plus drawn arrays."""
  vmap_index_param: VmappableIndexParam
  # float32 array of shape `vmap_index_param.bat_ref_shape`.
  bat_ref: np.ndarray
  # float32 array of shape `vmap_index_param.bat_slice_shape`.
  bat_val: np.ndarray
  # One index array per indexed dimension of the ref.
  bat_idxs: tuple[np.ndarray, ...]
@hps.composite
def set_vmap_params(draw):
  """Draws a batching layout for "swap" plus ref, index and value arrays."""
  vmap_index_param: VmappableIndexParam = draw(
      vmappable_index_params(op_type="swap"))
  index_param = vmap_index_param.index_param
  bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape))

  # Walk the ref dimensions, pairing each indexed one with the next batched
  # index shape.
  remaining_shapes = iter(vmap_index_param.bat_non_slice_idx_shapes)
  drawn_idxs = []
  for size, indexed in zip(index_param.ref_shape, index_param.indexed_dims):
    if indexed:
      drawn_idxs.append(draw(index_arrays(size, next(remaining_shapes))))
  bat_idxs = tuple(drawn_idxs)
  # Every batched index shape must have been consumed.
  assert next(remaining_shapes, None) is None

  bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape))
  return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs)
Indexer = tuple[Union[int, slice, np.ndarray]]
def _unpack_idx(idx: Indexer
) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]:
indexed_dims = [type(i) != slice for i in idx]
non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b]
return non_slice_idx, indexed_dims
def _pack_idx(non_slice_idx: Sequence[int | np.ndarray],
indexed_dims: Sequence[bool]) -> Indexer:
idx_ = iter(non_slice_idx)
idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
assert next(idx_, None) is None
return idx
@jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe
class StateHypothesisTest(jtu.JaxTestCase):
@hp.given(get_vmap_params())
@hp.settings(deadline=None, print_blob=True,
             max_examples=jtu.NUM_GENERATED_CASES.value)
def test_get_vmap(self, get_vmap_param: GetVmapParams):
  """Compares discharge-of-vmap with vmap-of-discharge for `ref_get`."""
  vmap_param = get_vmap_param.vmap_index_param
  index_param = vmap_param.index_param
  indexed_dims = index_param.indexed_dims

  def f(ref, *non_slice_idx):
    idx = _pack_idx(non_slice_idx, indexed_dims)
    return [ref_get(ref, idx)]

  in_axes = (vmap_param.ref_bdim, *vmap_param.non_slice_idx_bdims)
  out_bdim = vmap_param.slice_bdim
  args = (get_vmap_param.bat_ref, *get_vmap_param.bat_idxs)

  # discharge-of-vmap
  bat_avals = [vmap_param.bat_ref_aval, *vmap_param.bat_non_slice_idx_avals]
  f_batched = jax.vmap(f, in_axes=in_axes, out_axes=[out_bdim])
  stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f_batched, len(bat_avals)), bat_avals)
  jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
  discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, *args)

  # vmap-of-discharge
  avals = [index_param.ref_aval, *index_param.idx_avals]
  stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, len(avals)), avals)
  jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
  f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                       in_axes=in_axes,
                       out_axes=[out_bdim, vmap_param.ref_bdim])
  vmap_of_discharge_ans = f_batched(*args)

  self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                      check_dtypes=False)
@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
             max_examples=jtu.NUM_GENERATED_CASES.value)
def test_set_vmap(self, set_vmap_param: SetVmapParams):
  """Compares discharge-of-vmap with vmap-of-discharge for `ref_set`."""
  if jtu.test_device_matches(["gpu"]):
    self.skipTest("Scatter is nondeterministic on GPU")
  vmap_param = set_vmap_param.vmap_index_param
  index_param = vmap_param.index_param
  indexed_dims = index_param.indexed_dims

  def f(ref, val, *non_slice_idx):
    idx = _pack_idx(non_slice_idx, indexed_dims)
    ref_set(ref, idx, val)
    return []

  in_axes = (vmap_param.ref_bdim, vmap_param.slice_bdim,
             *vmap_param.non_slice_idx_bdims)
  args = (set_vmap_param.bat_ref, set_vmap_param.bat_val,
          *set_vmap_param.bat_idxs)

  # discharge-of-vmap
  bat_avals = [vmap_param.bat_ref_aval, vmap_param.bat_slice_aval,
               *vmap_param.bat_non_slice_idx_avals]
  f_batched = jax.vmap(f, in_axes=in_axes, out_axes=[])
  stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f_batched, len(bat_avals)), bat_avals)
  jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
  discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, *args)

  # vmap-of-discharge
  avals = [index_param.ref_aval, index_param.slice_aval,
           *index_param.idx_avals]
  stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, len(avals)), avals)
  jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
  f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                       in_axes=in_axes,
                       out_axes=[vmap_param.ref_bdim])
  vmap_of_discharge_ans = f_batched(*args)

  self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                      check_dtypes=False)
@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
             max_examples=jtu.NUM_GENERATED_CASES.value)
def test_addupdate_vmap(self, set_vmap_param: SetVmapParams):
  """Compares discharge-of-vmap with vmap-of-discharge for `ref_addupdate`."""
  vmap_param = set_vmap_param.vmap_index_param
  index_param = vmap_param.index_param
  indexed_dims = index_param.indexed_dims

  def f(ref, val, *non_slice_idx):
    idx = _pack_idx(non_slice_idx, indexed_dims)
    ref_addupdate(ref, idx, val)
    return []

  in_axes = (vmap_param.ref_bdim, vmap_param.slice_bdim,
             *vmap_param.non_slice_idx_bdims)
  args = (set_vmap_param.bat_ref, set_vmap_param.bat_val,
          *set_vmap_param.bat_idxs)

  # discharge-of-vmap
  bat_avals = [vmap_param.bat_ref_aval, vmap_param.bat_slice_aval,
               *vmap_param.bat_non_slice_idx_avals]
  f_batched = jax.vmap(f, in_axes=in_axes, out_axes=[])
  stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f_batched, len(bat_avals)), bat_avals)
  jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
  discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, *args)

  # vmap-of-discharge
  avals = [index_param.ref_aval, index_param.slice_aval,
           *index_param.idx_avals]
  stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, len(avals)), avals)
  jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
  f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                       in_axes=in_axes,
                       out_axes=[vmap_param.ref_bdim])
  vmap_of_discharge_ans = f_batched(*args)

  self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                      check_dtypes=False)
class StateControlFlowTest(jtu.JaxTestCase):
def test_simple_cond(self):
  """A ref written in only one `cond` branch leaves no effects after run_state."""
  def f(pred):
    def body(x_ref):
      def true_fun():
        x_ref[()] = 1.
      def false_fun():
        pass
      lax.cond(pred, true_fun, false_fun)
    return for_loop.run_state(body, 0.)

  self.assertEmpty(jax.make_jaxpr(f)(True).jaxpr.effects)
  jitted = jax.jit(f)
  for pred, expected in ((True, 1.), (False, 0.)):
    self.assertAllClose(jitted(pred), expected)
def test_simple_cond_with_return(self):
  """A `cond` may both write a ref and return a value stored in another ref."""
  def f(pred):
    def body(refs):
      x_ref, y_ref = refs
      def true_fun():
        x_ref[()] = 1.
        return 4.
      def false_fun():
        return 5.
      out = lax.cond(pred, true_fun, false_fun)
      y_ref[...] = out
    return for_loop.run_state(body, (0., 0.))

  self.assertEmpty(jax.make_jaxpr(f)(True).jaxpr.effects)
  jitted = jax.jit(f)
  self.assertTupleEqual(jitted(True), (1., 4.))
  self.assertTupleEqual(jitted(False), (0., 5.))
def test_cond_discharge(self):
  """Partially discharges refs written in different branches of a `cond`."""
  def f0(pred, x_ref, y_ref):
    def true_fun():
      x_ref[...] = 1.
    def false_fun():
      y_ref[...] = 2.
    lax.cond(pred, true_fun, false_fun)
    return x_ref[...], y_ref[...]

  def abstract_ref(x):
    return AbstractRef(core.get_aval(x))

  f_jaxpr = jax.make_jaxpr(f0)(False, abstract_ref(3.), abstract_ref(4.))
  jaxpr, _ = discharge_state(
      f_jaxpr.jaxpr, (), should_discharge=[False, False, True])

  # Effects on y_ref (arg 2) were discharged away but not those on x_ref
  # (arg 1).
  x_effects = {ReadEffect(1), WriteEffect(1)}
  y_effects = {ReadEffect(2), WriteEffect(2)}
  self.assertEqual(f_jaxpr.effects, x_effects | y_effects)
  self.assertEqual(jaxpr.effects, x_effects)

  # x_ref arg is still a reference but y_ref is discharged.
  x_var, y_var = jaxpr.invars[1], jaxpr.invars[2]
  self.assertNotIsInstance(y_var.aval, AbstractRef)
  self.assertIsInstance(x_var.aval, AbstractRef)

  # The discharged jaxpr has one more output than the original two.
  self.assertLen(f_jaxpr.out_avals, 2)
  self.assertLen(jaxpr.outvars, 3)
def test_cond_with_ref_reuse(self):
  """Both `cond` branches write the same ref with different values."""
  def f(pred):
    def body(x_ref):
      def true_fun():
        x_ref[()] = 1.
      def false_fun():
        x_ref[()] = 2.
      lax.cond(pred, true_fun, false_fun)
    return run_state(body)(0.)

  self.assertEmpty(jax.make_jaxpr(f)(True).jaxpr.effects)
  jitted = jax.jit(f)
  self.assertAllClose(jitted(True), 1.)
  self.assertAllClose(jitted(False), 2.)
def test_cond_readonly_refs(self):
  """After discharge, `cond` outputs only the ref value it updates."""
  def f(pred):
    def body(refs):
      x_ref, y_ref, z_ref = refs
      def true_fun():
        y_ref[()] = x_ref[()]
      def false_fun():
        y_ref[()] = x_ref[()] + z_ref[()]
      lax.cond(pred, true_fun, false_fun)
    return run_state(body)((1., 0., 2.))

  [run_state_eqn] = jax.make_jaxpr(f)(True).jaxpr.eqns
  discharged_jaxpr, _ = discharge_state(run_state_eqn.params["jaxpr"], ())
  *_, cond_eqn = discharged_jaxpr.eqns
  self.assertIs(cond_eqn.primitive, lax.cond_p)
  self.assertLen(cond_eqn.invars, 4)  # pred + 3x ref values
  self.assertLen(cond_eqn.outvars, 1)  # only the updated ref value

  jitted = jax.jit(f)
  self.assertAllClose(jitted(True), (1., 1., 2.))
  self.assertAllClose(jitted(False), (1., 3., 2.))
def test_simple_cond_using_multiple_refs_with_interleaved_consts(self):
  """`cond` branches that write different refs and close over constants."""
  def f(pred):
    def body(refs):
      x_ref, y_ref, z_ref = refs
      a = jnp.array(2.)
      b = jnp.array(2.)
      def true_fun():
        x_ref[()] = 1. * a
        y_ref[()] = 1.
      def false_fun():
        z_ref[()] = 1.
        x_ref[()] = 2.
        x_ref[()] = 2. * b
      lax.cond(pred, true_fun, false_fun)
    return run_state(body)((0., 0., 0.))

  self.assertEmpty(jax.make_jaxpr(f)(True).jaxpr.effects)
  jitted = jax.jit(f)
  self.assertAllClose(jitted(True), (2., 1., 0.))
  self.assertAllClose(jitted(False), (4., 0., 1.))
def test_cond_nested_run_state_using_multiple_refs_with_interleaved_consts(
    self):
  """`cond` writing refs that belong to three nested `run_state` scopes."""
  def f(pred):
    @run_state
    def body(refs):
      x_ref, o_ref1, o_ref2 = refs
      @run_state
      def body2(refs):
        y_ref, o_ref3 = refs
        @run_state
        def body3(z_ref):
          a = jnp.array(2.)
          b = jnp.array(2.)
          # The branches write x_ref, y_ref and z_ref, each owned by a
          # different enclosing run_state.
          def true_fun():
            x_ref[()] = 1. * a
            y_ref[()] = 1.
          def false_fun():
            z_ref[()] = 1.
            x_ref[()] = 2.
            x_ref[()] = 2. * b
          lax.cond(pred, true_fun, false_fun)
        # Forward the final z value to the middle scope.
        o_ref3[...] = body3(0.)
      # Forward the final (y, z) values to the outer scope.
      o_ref1[...], o_ref2[...] = body2((0., 0.))
    return body((0., 0., 0.))
  jaxpr = jax.make_jaxpr(f)(True).jaxpr
  self.assertEmpty(jaxpr.effects)
  out_true = jax.jit(f)(True)
  expected_true = (2., 1., 0.)
  self.assertAllClose(out_true, expected_true)
  out_false = jax.jit(f)(False)
  expected_false = (4., 0., 1.)
  self.assertAllClose(out_false, expected_false)
def test_nested_cond_using_multiple_refs_with_interleaved_consts(self):
  """Nested `cond`s whose inner branches write different subsets of refs."""
  def f(pred1, pred2):
    def body(refs):
      x_ref, y_ref, z_ref = refs
      a = jnp.array(2.)
      def true_fun():
        b = jnp.array(2.)
        def inner_true_fun():
          x_ref[()] = 1. * a
          y_ref[()] = 1.
        def inner_false_fun():
          z_ref[()] = 1.
          x_ref[()] = 2.
          x_ref[()] = 2. * b
        lax.cond(pred2, inner_true_fun, inner_false_fun)
      def false_fun():
        a = jnp.array(2.)
        b = jnp.array(2.)
        def inner_true_fun():
          x_ref[()] = 1. * a
          z_ref[()] = 4.
        def inner_false_fun():
          x_ref[()] = 2. * b
        lax.cond(pred2, inner_true_fun, inner_false_fun)
      lax.cond(pred1, true_fun, false_fun)
    return run_state(body)((0., 0., 0.))

  self.assertEmpty(jax.make_jaxpr(f)(True, False).jaxpr.effects)
  jitted = jax.jit(f)
  # ((pred1, pred2), expected final (x, y, z))
  cases = (
      ((True, True), (2., 1., 0.)),
      ((True, False), (4., 0., 1.)),
      ((False, True), (2., 0., 4.)),
      ((False, False), (4., 0., 0.)),
  )
  for preds, expected in cases:
    self.assertTupleEqual(jitted(*preds), expected)
def test_nested_cond(self):
  """A ref written inside a `cond` nested in another `cond`."""
  def f(pred):
    def body(x_ref):
      def true_fun():
        def true_fun_inner():
          x_ref[()] = 1.
        def false_fun_inner():
          pass
        return lax.cond(pred, true_fun_inner, false_fun_inner)
      def false_fun():
        pass
      lax.cond(pred, true_fun, false_fun)
    return for_loop.run_state(body, 0.)

  self.assertEmpty(jax.make_jaxpr(f)(True).jaxpr.effects)
  jitted = jax.jit(f)
  for pred, expected in ((True, 1.), (False, 0.)):
    self.assertAllClose(jitted(pred), expected)
def test_cond_jvp_with_state(self):
  """JVP through a `cond` that squares a ref in its true branch."""
  def f(pred, init_value):
    def body(x_ref):
      def true_fun():
        x_ref[()] = x_ref[()] ** 2
      def false_fun():
        pass
      lax.cond(pred, true_fun, false_fun)
    return for_loop.run_state(body, init_value)

  primals, tangents = (3.,), (1.,)
  # (pred, expected primal, expected tangent): x**2 when taken, x otherwise.
  for pred, want_primal, want_tangent in ((True, 9., 6.), (False, 3., 1.)):
    out_primal, out_tangent = jax.jvp(partial(f, pred), primals, tangents)
    self.assertAllClose(out_primal, want_primal)
    self.assertAllClose(out_tangent, want_tangent)
def test_cond_vmap_not_implemented(self):
  """vmap over a stateful `cond` raises NotImplementedError."""
  @jax.jit
  def f(init_value):
    def body(x_ref):
      def true_fun():
        x_ref[()] = x_ref[()] ** 2
      def false_fun():
        pass
      lax.cond(x_ref[()] < 1, true_fun, false_fun)
    return for_loop.run_state(body, init_value)

  self.assertRaises(NotImplementedError, jax.vmap(f), jnp.arange(2.))
def test_cond_grad_not_implemented(self):
  """grad through a stateful `cond` raises NotImplementedError."""
  @jax.jit
  def f(init_value):
    def body(x_ref):
      def true_fun():
        x_ref[()] = x_ref[()] ** 2
      def false_fun():
        pass
      lax.cond(True, true_fun, false_fun)
    return for_loop.run_state(body, init_value)

  self.assertRaises(NotImplementedError, jax.grad(f), 3.)
def test_while_with_state_in_body(self):
  """A `while_loop` body that adds `z` to a ref on each of `y` iterations."""
  def f(x, y, z):
    @run_state
    def body(x_ref):
      def keep_going(i):
        return i < y
      def step(i):
        x_ref[...] += z
        return i + 1
      lax.while_loop(keep_going, step, 0)
    return body(x)

  self.assertEmpty(jax.make_jaxpr(f)(0, 5, 2).jaxpr.effects)
  jitted = jax.jit(f)
  # Expected result is x + y * z.
  self.assertAllClose(jitted(0, 5, 2), 10)
  self.assertAllClose(jitted(1, 2, 3), 7)
def test_scan_with_state_in_body(self):
  """A `scan` body that accumulates the carry and the xs into two refs."""
  def f(x, w, y, zs):
    @run_state
    def body(refs):
      x_ref, w_ref = refs
      def step(carry, z):
        x_ref[...] += carry
        w_ref[...] += z
        return carry + 1, ()
      lax.scan(step, y, zs)
    return body((x, w))

  zs = jnp.arange(5)
  self.assertEmpty(jax.make_jaxpr(f)(0, 1, 5, zs).jaxpr.effects)
  jitted = jax.jit(f)
  # x gains y + (y + 1) + ... + (y + 4); w gains sum(zs) == 10.
  self.assertAllClose(jitted(0, 1, 5, zs), (35, 11))
  self.assertAllClose(jitted(1, 1, 2, zs), (21, 11))
def test_scan_with_state_in_body_nested(self):
  """A `scan` inside a nested `run_state` that also writes an outer ref."""
  @run_state
  def g(refs):
    a_ref, x_ref, w_ref, y_ref, zs_ref = refs
    def f(x, w, y, zs):
      @run_state
      def loop(refs):
        # These shadow the outer x_ref/w_ref; a_ref is the outer ref.
        x_ref, w_ref = refs
        def body(y, z):
          x_ref[...] += y
          w_ref[...] += z
          a_ref[...] += z
          return y + 1, ()
        lax.scan(body, y, zs)
      # a_ref is updated before and after the inner run_state.
      a_ref[...] += 1
      out = loop((x, w))
      a_ref[...] += 1
      return out
    x, w = f(x_ref[...], w_ref[...], y_ref[...], zs_ref[...])
    x_ref[...] = x
    w_ref[...] = w
  zs = jnp.arange(5)
  jaxpr = jax.make_jaxpr(g)((1, 0, 1, 5, zs)).jaxpr
  self.assertEmpty(jaxpr.effects)
  # Only the first three outputs (a, x, w) are compared.
  self.assertAllClose(jax.jit(g)((1, 0, 1, 5, zs))[:3], (13, 35, 11))
  self.assertAllClose(jax.jit(g)((1, 1, 1, 2, zs))[:3], (13, 21, 11))
class GeneralRefTest(jtu.JaxTestCase):
def test_token(self):
  """Get, set and addupdate can be traced on a ref holding a token."""
  def f(x_ref):
    x = x_ref[...]
    x_ref[...] = x
    ref_addupdate(x_ref, (), x)
    return [x]

  token_ref_aval = AbstractRef(core.AbstractToken())
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), [token_ref_aval])
  out_aval = jaxpr.outvars[0].aval
  self.assertIs(type(out_aval), core.AbstractToken)
def test_ref_of_ref(self):
  """Reading a ref-of-ref yields the inner ref."""
  def f(x_ref_ref):
    x_ref = x_ref_ref[...]
    return [x_ref]

  # Not sure why you'd ever want to do this, but it works!
  inner_ref_aval = AbstractRef(core.ShapedArray((), jnp.int32))
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
      wrap_init(f, 1), [AbstractRef(inner_ref_aval)])
  self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef)
  self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray)
class RunStateTest(jtu.JaxTestCase):
def test_simple_run_state(self):
  """A body that does nothing returns the initial value unchanged."""
  def do_nothing(_):
    return None

  self.assertEqual(run_state(do_nothing)(1), 1)
def test_nontrivial_run_state(self):
  """Reads and writes across two refs produce the expected final values."""
  def f(refs):
    x_ref, y_ref = refs
    x = x_ref[...] * y_ref[...]
    y_ref[...] = x * 2
    x_ref[...] = y_ref[...] + x_ref[...]

  # Final values are (x + x * y * 2, x * y * 2).
  x_out, y_out = run_state(f)((2, 3))
  self.assertEqual(x_out, 2 + 2 * 3 * 2)
  self.assertEqual(y_out, 2 * 3 * 2)
def test_run_state_with_uninitialized_input(self):
  """`run_state` accepts an input type registered as an uninitialized ref."""
  def f(refs):
    x_ref, y_ref = refs
    # y_ref is uninitialized so we shouldn't read from it until we write into
    # it.
    x = x_ref[...]
    y_ref[...] = x * 2
    x_ref[...] = y_ref[...] + x_ref[...]
  # x + x * 2, x * 2

  # jax.ShapeDtypeStruct is weirdly special to JAX, so we make our own class.
  class MyArrayType:
    pass

  state_types._ref_type_aval_mappings[MyArrayType] = lambda _: (
      AbstractRef(core.ShapedArray((), jnp.int32)),
      state_types.uninitialized,
  )
  # The registry is module-level state; remove the entry again so this test
  # does not leak it into later tests.
  # NOTE(review): assumes the registry is a dict (supports `.pop`) -- confirm.
  self.addCleanup(state_types._ref_type_aval_mappings.pop, MyArrayType, None)

  x, y = run_state(f)((jnp.int32(2), MyArrayType()))
  self.assertEqual(x, 2 + 2 * 2)
  self.assertEqual(y, 2 * 2)
def test_nontrivial_run_state_jit(self):
  """Refs can be closed over and mutated inside a jitted function."""
  def f(refs):
    x_ref, y_ref = refs
    @jax.jit
    def g():
      x = x_ref[...] * y_ref[...]
      y_ref[...] = x * 2
      x_ref[...] = y_ref[...] + x_ref[...]
    g()

  # Final values are (x + x * y * 2, x * y * 2).
  x_out, y_out = run_state(f)((2, 3))
  self.assertEqual(x_out, 2 + 2 * 3 * 2)
  self.assertEqual(y_out, 2 * 3 * 2)
def test_simple_run_state_with_multiple_refs(self):
  """A body that does nothing returns both initial values unchanged."""
  def do_nothing(_):
    return None

  first, second = run_state(do_nothing)((1, 2))
  self.assertEqual(first, 1)
  self.assertEqual(second, 2)
def test_simple_run_state_with_tuple(self):
  """A tuple of initial values passes through a do-nothing body unchanged."""
  # NOTE(review): this exercises the same input as
  # test_simple_run_state_with_multiple_refs -- confirm that is intended.
  def do_nothing(_):
    return None

  first, second = run_state(do_nothing)((1, 2))
  self.assertEqual(first, 1)
  self.assertEqual(second, 2)
def test_can_stage_run_state(self):
  """Staging `run_state` with make_jaxpr keeps debug info on the jaxpr."""
  def f(x):
    return run_state(lambda _: None)(x)

  staged = jax.make_jaxpr(f)(2)
  self.assertIsNotNone(staged.jaxpr.debug_info)
  self.assertIsNotNone(staged.jaxpr.debug_info.func_src_info)
def test_can_stage_run_state_leaked_tracer_error(self):
  """A tracer leaked from a `run_state` body names that body in the error."""
  leaked = []

  def f(x):
    # The name `my_fun` is matched by the regex below.
    def my_fun(x):
      leaked.append(x)
      return None
    return run_state(my_fun)(x)

  jax.make_jaxpr(f)(2)
  expected_message = (
      "The function being traced when the value leaked was .*my_fun")
  with self.assertRaisesRegex(jax.errors.UnexpectedTracerError,
                              expected_message):
    jax.jit(lambda _: leaked[0])(1)
def test_nested_run_state_captures_effects(self):
  """Read effects are visible on inner jaxprs but not on the run_state eqns."""
  def f(x):
    def body(x_ref):
      def inner(y_ref):
        # Bare reads: only their effects matter here.
        y_ref[...]
        x_ref[...]
      run_state(inner)(1)
    return run_state(body)(x)

  closed = jax.make_jaxpr(f)(2)
  self.assertEmpty(closed.effects)
  outer_eqn = closed.jaxpr.eqns[0]
  self.assertEmpty(outer_eqn.effects)
  outer_body = outer_eqn.params["jaxpr"]
  self.assertSetEqual(outer_body.effects, {ReadEffect(0)})
  inner_body = outer_body.eqns[0].params["jaxpr"]
  self.assertSetEqual(inner_body.effects, {ReadEffect(0), ReadEffect(1)})
def test_jvp_of_run_state(self):
  """JVP through a `run_state` that writes sin(x) into the second ref."""
  @run_state
  def f(refs):
    x_ref, y_ref = refs
    y_ref[...] = jnp.sin(x_ref[...])

  primals, tangents = (2., 1.), (3., 1.)
  xy, xy_t = jax.jvp(f, (primals,), (tangents,))
  # Outputs are (x, sin(x)) with tangents (t, cos(x) * t).
  self.assertAllClose(xy, (2., np.sin(2.)))
  self.assertAllClose(xy_t, (3., 3 * np.cos(2.)))

  def second_output(x):
    return f((x, 0.))[1]

  y, y_t = jax.jvp(second_output, (2.,), (3.,))
  self.assertAllClose(y, np.sin(2.))
  self.assertAllClose(y_t, 3 * np.cos(2.))
def test_jvp_of_run_state_with_zero_tangent(self):
  """JVP through `run_state` when one ref is never touched by the body."""
  @run_state
  def f(refs):
    x_ref, z_ref, y_ref = refs
    del z_ref
    y_ref[...] = jnp.sin(x_ref[...])

  def last_output(x):
    return f((x, 0., 0.,))[2]

  y, y_t = jax.jvp(last_output, (2.,), (3.,))
  self.assertAllClose(y, np.sin(2.))
  self.assertAllClose(y_t, 3 * np.cos(2.))
def test_linearize_of_run_state(self):
  """`jax.linearize` through a `run_state` computing sin."""
  @run_state
  def f(refs):
    x_ref, y_ref = refs
    y_ref[...] = jnp.sin(x_ref[...])

  primals_out, f_lin = jax.linearize(f, (1., 0.))
  x, y = primals_out
  self.assertAllClose(x, 1.)
  self.assertAllClose(y, np.sin(1.))

  tangents_out = f_lin((2., 1.))
  x_t, y_t = tangents_out
  self.assertAllClose(x_t, 2.)
  self.assertAllClose(y_t, 2. * np.cos(1.))
def test_grad_of_run_state(self):
  """First, second and third derivatives of sin computed via `run_state`."""
  @run_state
  def f(refs):
    x_ref, y_ref = refs
    y_ref[...] = jnp.sin(x_ref[...])

  def sin(x):
    return f((x, 0.))[1]

  d1 = jax.grad(sin)
  self.assertAllClose(d1(1.), np.cos(1.))
  d2 = jax.grad(jax.grad(sin))
  self.assertAllClose(d2(1.), -np.sin(1.))
  d3 = jax.grad(jax.grad(jax.grad(sin)))
  self.assertAllClose(d3(1.), -np.cos(1.))
def test_vjp_of_run_state(self):
  """`jax.vjp` through a `run_state` computing sin."""
  @run_state
  def f(refs):
    x_ref, y_ref = refs
    y_ref[...] = jnp.sin(x_ref[...])

  primals_out, f_vjp = jax.vjp(f, (1., 0.))
  x, y = primals_out
  self.assertAllClose(x, 1.)
  self.assertAllClose(y, np.sin(1.))

  # Pull back a cotangent on the second output only.
  (input_cts,) = f_vjp((0., 1.))
  x_ct, y_ct = input_cts
  self.assertAllClose(x_ct, np.cos(1.))
  self.assertAllClose(y_ct, 0.)
def test_vjp_of_run_state_single(self):
  """linearize, vjp and check_grads through a nested single-ref `run_state`."""
  @run_state
  def f(x_ref):
    x = x_ref[...]
    def _body(ref):
      ref[...] = jnp.sin(ref[...])
    x = run_state(_body)(x)
    x_ref[...] = x

  sin_1, cos_1 = np.sin(1.), np.cos(1.)

  y, f_lin = jax.linearize(f, 1.)
  self.assertAllClose(y, sin_1)
  self.assertAllClose(f_lin(1.), cos_1)

  y, f_vjp = jax.vjp(f, 1.)
  self.assertAllClose(y, sin_1)
  (x_ct,) = f_vjp(1.)
  self.assertAllClose(x_ct, cos_1)

  jtu.check_grads(f, (0.5,), order=3)
if CAN_USE_HYPOTHESIS:
class FuncSpec(NamedTuple):
  """A stateful function together with bounds on the input shapes it accepts."""
  fun: Callable[..., Any]
  name: str
  min_rank: int = 0
  max_rank: int = 4
  min_dim: int = 0
  max_dim: int = 4

  def call(self, *args):
    """Runs `fun` through `run_state`."""
    stateful = run_state(self.fun)
    return stateful(*args)

  def ref(self, *args):
    """Runs `fun` through `run_state_reference`."""
    reference = run_state_reference(self.fun)
    return reference(*args)
def sin_stateful(refs):
  """Writes sin of the first ref's value into the second ref."""
  x_ref, y_ref = refs
  x = x_ref[...]
  y_ref[...] = jnp.sin(x)
sin_spec = FuncSpec(sin_stateful, "sin")
def cos_stateful(refs):
  """Writes cos of the first ref's value into the second ref."""
  x_ref, y_ref = refs
  x = x_ref[...]
  y_ref[...] = jnp.cos(x)
cos_spec = FuncSpec(cos_stateful, "cos")
def mul2_stateful(refs):
  """Leaves x + x in `y_ref`, built from a write then a read-modify-write."""
  x_ref, y_ref = refs
  # First copy x into y, then add x to what was just written.
  y_ref[...] = x_ref[...]
  y_ref[...] = y_ref[...] + x_ref[...]
mul2_spec = FuncSpec(mul2_stateful, "mul2")
def mul2_stateful_with_constant(refs):
  """Writes 2 * x into `y_ref`, using a NumPy array constant for the factor."""
  x_ref, y_ref = refs
  twos = 2. * np.ones(x_ref.shape, x_ref.dtype)
  y_ref[...] = twos * x_ref[...]
mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c")
def crazy_identity_stateful(refs):
  """Copies x into `y_ref` through redundant reads and writes of both refs."""
  x_ref, y_ref = refs
  x = x_ref[...]
  # (x + x) / 2 rewrites x_ref with an average of x with itself.
  x_ref[...] = (x + x) / 2
  y_ref[...] = x_ref[...]
  y = y_ref[...]
  # Same rewrite for y_ref.
  y_ref[...] = (y + y) / 2
crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id")
def func_spec(depth: int = 4):
  """Strategy for `FuncSpec`s, recursing up to `depth` levels of combinators."""
  base_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec,
                                 mul2_constant_spec, crazy_identity_spec])
  if not depth > 0:
    # Recursion budget exhausted: only the primitive specs remain.
    return base_specs
  shallower = depth - 1
  return hps.one_of([base_specs, nest_spec(shallower), add_spec(shallower),
                     compose_spec(shallower)])
@hps.composite
def compose_spec(draw, depth):
  """Draws two specs and runs them one after the other on the same args."""
  first = draw(func_spec(depth))
  second = draw(func_spec(depth))

  def wrapped_impl(*args):
    first.fun(*args)
    second.fun(*args)

  # Shape bounds are the intersection of the two specs' bounds.
  return FuncSpec(
      wrapped_impl,
      f"({second.name} . {first.name})",
      min_rank=max(first.min_rank, second.min_rank),
      max_rank=min(first.max_rank, second.max_rank),
      min_dim=max(first.min_dim, second.min_dim),
      max_dim=min(first.max_dim, second.max_dim))
@hps.composite
def nest_spec(draw, depth):
  """Draws a spec and wraps it in an inner `run_state` call."""
  inner = draw(func_spec(depth))

  def wrapped_impl(refs):
    x_ref, y_ref = refs
    x, y = x_ref[...], y_ref[...]
    x, y = run_state(inner.fun)((x, y))
    x_ref[...], y_ref[...] = x, y

  return FuncSpec(
      wrapped_impl,
      f"nest({inner.name})",
      min_rank=inner.min_rank,
      max_rank=inner.max_rank,
      min_dim=inner.min_dim,
      max_dim=inner.max_dim)
@hps.composite
def add_spec(draw, depth):
  """Draws two specs, runs each on the same inputs and sums their results."""
  first = draw(func_spec(depth))
  second = draw(func_spec(depth))

  def wrapped_impl(refs):
    x_ref, y_ref = refs
    x, y = x_ref[...], y_ref[...]
    x1, y1 = run_state(first.fun)((x, y))
    x2, y2 = run_state(second.fun)((x, y))
    x_ref[...], y_ref[...] = x1 + x2, y1 + y2

  # Shape bounds are the intersection of the two specs' bounds.
  return FuncSpec(
      wrapped_impl,
      f"({second.name} + {first.name})",
      min_rank=max(first.min_rank, second.min_rank),
      max_rank=min(first.max_rank, second.max_rank),
      min_dim=max(first.min_dim, second.min_dim),
      max_dim=min(first.max_dim, second.max_dim))
@jtu.thread_unsafe_test_class() # because of hypothesis
class RunStateHypothesisTest(jtu.JaxTestCase):
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
             max_examples=jtu.NUM_GENERATED_CASES.value)
def test_jvp(self, data):
  """JVP of `run_state` matches JVP of `run_state_reference`."""
  spec = data.draw(func_spec())

  def impl(x):
    return spec.call((x, jnp.zeros_like(x)))[1]

  def ref(x):
    return spec.ref((x, jnp.zeros_like(x)))[1]

  primal_key, tangent_key = random.split(random.PRNGKey(0))
  shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank,
                                     max_dims=spec.max_rank,
                                     min_side=spec.min_dim,
                                     max_side=spec.max_dim))
  x = random.normal(primal_key, shape)
  t = random.normal(tangent_key, x.shape)
  primals, tangents = (x,), (t,)
  y, y_t = jax.jvp(impl, primals, tangents)
  y_ref, y_ref_t = jax.jvp(ref, primals, tangents)
  self.assertAllClose(y, y_ref)
  self.assertAllClose(y_t, y_ref_t)
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
             max_examples=jtu.NUM_GENERATED_CASES.value)
def test_linearize(self, data):
  """`jax.linearize` of `run_state` matches that of `run_state_reference`."""
  spec = data.draw(func_spec())

  def impl(x):
    return spec.call((x, jnp.zeros_like(x)))[1]

  def ref(x):
    return spec.ref((x, jnp.zeros_like(x)))[1]

  primal_key, tangent_key = random.split(random.PRNGKey(0))
  shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank,
                                     max_dims=spec.max_rank,
                                     min_side=spec.min_dim,
                                     max_side=spec.max_dim))
  tolerance = dict(atol=1e-2, rtol=1e-2)
  x = random.normal(primal_key, shape)
  y, impl_lin = jax.linearize(impl, x)
  y_ref, ref_lin = jax.linearize(ref, x)
  self.assertAllClose(y, y_ref, **tolerance)
  t = random.normal(tangent_key, x.shape)
  self.assertAllClose(impl_lin(t), ref_lin(t), **tolerance)
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
             max_examples=jtu.NUM_GENERATED_CASES.value)
def test_vjp(self, data):
  """linearize/vjp of `run_state` match `run_state_reference`.

  First-order checks always run. Second-order checks (linearize and vjp of
  the first-order vjp) are skipped when `jtu.SKIP_SLOW_TESTS` is set.
  """
  spec = data.draw(func_spec())

  def impl(x):
    return spec.call((x, jnp.zeros_like(x)))[1]

  def ref(x):
    return spec.ref((x, jnp.zeros_like(x)))[1]

  key, k1, k2 = random.split(random.PRNGKey(0), 3)
  shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank,
                                     max_dims=spec.max_rank,
                                     min_side=spec.min_dim,
                                     max_side=spec.max_dim))
  x = random.normal(k1, shape)

  # First order
  y, impl_lin = jax.linearize(impl, x)
  y_ref, ref_lin = jax.linearize(ref, x)
  self.assertAllClose(y, y_ref)
  t = random.normal(k2, x.shape)
  self.assertAllClose(impl_lin(t), ref_lin(t))

  y, impl_vjp = jax.vjp(impl, x)
  y_ref, ref_vjp = jax.vjp(ref, x)
  self.assertAllClose(y, y_ref)
  # k2 was already consumed above, so it is cloned before being reused.
  t = random.normal(jax.random.clone(k2), x.shape)
  self.assertAllClose(impl_vjp(t), ref_vjp(t))

  if jtu.SKIP_SLOW_TESTS.value:
    # Skip second order tests if JAX_SKIP_SLOW_TESTS=true
    return

  # Second order
  key, k1, k2 = random.split(key, 3)
  t2 = random.normal(k2, t.shape)

  (x,), impl_lin2 = jax.linearize(impl_vjp, t2)
  (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2)
  self.assertAllClose(x, x_ref)
  y2 = random.normal(k1, y.shape)
  self.assertAllClose(impl_lin2(y2), ref_lin2(y2))

  (x,), impl_vjp2 = jax.vjp(impl_vjp, t2)
  (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2)
  self.assertAllClose(x, x_ref)
  y2 = random.normal(jax.random.clone(k1), y.shape)
  self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,)))
if __name__ == '__main__':
  # Run the tests in this module with JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)