# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import itertools as it

from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import core
from jax import lax
from jax._src import linear_util as lu
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax._src import test_util as jtu
from jax._src.util import tuple_insert
import jax.numpy as jnp
from jax._src.lax.control_flow import for_loop

try:
  import hypothesis as hp
  import hypothesis.extra.numpy as hnp
  import hypothesis.strategies as hps
  CAN_USE_HYPOTHESIS = True
except (ModuleNotFoundError, ImportError):
  CAN_USE_HYPOTHESIS = False

from jax._src import state

config.parse_flags_with_absl()

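# Illustrative sketch (not used by the tests below): the core pattern these
# tests exercise. A function that reads and writes a mutable `Ref` is traced
# with a `ShapedArrayRef` input aval, producing a jaxpr whose equations are the
# state primitives (`get_p`, `swap_p`, `addupdate_p`) and whose effects record
# which refs are read or written.
def _example_trace_stateful():
  def add_one(x_ref):
    x = state.ref_get(x_ref, ())      # the read shows up as a `get` equation
    state.ref_set(x_ref, (), x + 1.)  # the write shows up as a `swap` equation
    return []
  in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))]
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(add_one), in_avals)
  return jaxpr, consts
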
class StatePrimitivesTest(jtu.JaxTestCase):

  def test_cant_eval_get_primitive(self):
    with self.assertRaises(ValueError):
      state.get_p.bind(jnp.ones(5))

  def test_cant_eval_swap_primitive(self):
    with self.assertRaises(ValueError):
      state.swap_p.bind(jnp.ones(5), jnp.zeros(5))

  def test_cant_eval_addupdate_primitive(self):
    with self.assertRaises(ValueError):
      state.addupdate_p.bind(jnp.ones(5), jnp.zeros(5))

  def test_get_abstract_aval_must_take_in_refs(self):
    ref_aval = core.ShapedArray((), jnp.float32)
    def f(x_ref):
      return [state.ref_get(x_ref, ())]
    with self.assertRaises(ValueError):
      pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])

  @parameterized.named_parameters(
      dict(testcase_name="trivial_get", ref_shape=(1, 2),
           ref_dtype=jnp.float32,
           idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
      dict(testcase_name="get_with_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32,
           idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
      dict(testcase_name="get_with_nonleading_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32,
           idx=(slice(None), 0), out_shape=(1,), out_dtype=jnp.float32),
      dict(testcase_name="get_with_array_index", ref_shape=(1, 2, 3, 4),
           ref_dtype=jnp.float32,
           idx=(np.array([0, 1]),), out_shape=(2, 2, 3, 4),
           out_dtype=jnp.float32),
      dict(testcase_name="get_with_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(np.array([0, 1]), np.array([0, 1])),
           out_shape=(2, 2, 4), out_dtype=jnp.float32),
      dict(testcase_name="get_with_nonleading_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
           out_shape=(2, 1, 2), out_dtype=jnp.float32),
  )
  def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
                             out_dtype=None, should_error=False):
    ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
    def f(x_ref):
      out = state.ref_get(x_ref, idx)
      return [out]
    if should_error:
      with self.assertRaises(Exception):
        pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
    else:
      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval])
      self.assertSetEqual(jaxpr.effects, {state.ReadEffect(ref_aval)})
      self.assertLen(out_avals, 1)
      out_aval, = out_avals
      self.assertIsInstance(out_aval, core.ShapedArray)
      self.assertEqual(out_aval.shape, out_shape)
      self.assertEqual(out_aval.dtype, out_dtype)

  def test_swap_abstract_eval_must_take_in_refs(self):
    ref_aval = core.ShapedArray((), jnp.float32)
    val_aval = core.ShapedArray((), jnp.float32)
    def f(x_ref, val):
      return [state.ref_swap(x_ref, (), val)]
    with self.assertRaises(ValueError):
      pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])

  @parameterized.named_parameters(
      dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="invalid_val_shape_slice", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(slice(None),), should_error=True),
      dict(testcase_name="trivial_swap", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
      dict(testcase_name="bad_dtype", ref_shape=(1, 2),
           ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="swap_with_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_nonleading_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
           idx=(slice(None), 0), out_shape=(1,), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_nonleading_index_bad_val", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(slice(None), 0), should_error=True),
      dict(testcase_name="swap_with_array_index", ref_shape=(1, 2, 3, 4),
           ref_dtype=jnp.float32, val_shape=(2, 2, 3, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]),), out_shape=(2, 2, 3, 4),
           out_dtype=jnp.float32),
      dict(testcase_name="swap_with_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 2, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]), np.array([0, 1])),
           out_shape=(2, 2, 4), out_dtype=jnp.float32),
      dict(testcase_name="swap_with_nonleading_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 1, 2), val_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1])),
           out_shape=(2, 1, 2), out_dtype=jnp.float32),
  )
  def test_swap_abstract_eval(self, ref_shape, ref_dtype,
      val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
      should_error=False):
    ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
    val_aval = core.ShapedArray(val_shape, val_dtype)
    def f(x_ref, val):
      out = state.ref_swap(x_ref, idx, val)
      return [out]
    if should_error:
      with self.assertRaises(Exception):
        pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
    else:
      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval])
      self.assertSetEqual(jaxpr.effects, {state.WriteEffect(ref_aval)})
      self.assertLen(out_avals, 1)
      out_aval, = out_avals
      self.assertIsInstance(out_aval, core.ShapedArray)
      self.assertEqual(out_aval.shape, out_shape)
      self.assertEqual(out_aval.dtype, out_dtype)

  @parameterized.named_parameters(
      dict(testcase_name="invalid_val_shape", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="invalid_val_shape_slice", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(slice(None),), should_error=True),
      dict(testcase_name="trivial_addupdate", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), out_shape=(1, 2), out_dtype=jnp.float32),
      dict(testcase_name="bad_dtype", ref_shape=(1, 2),
           ref_dtype=jnp.int32, val_shape=(1, 2), val_dtype=jnp.float32,
           idx=(), should_error=True),
      dict(testcase_name="addupdate_with_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(0,), out_shape=(2,), out_dtype=jnp.float32),
      dict(testcase_name="addupdate_with_nonleading_index", ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(1,), val_dtype=jnp.float32,
           idx=(slice(None), 0)),
      dict(testcase_name="addupdate_with_nonleading_index_bad_val",
           ref_shape=(1, 2),
           ref_dtype=jnp.float32, val_shape=(2,), val_dtype=jnp.float32,
           idx=(slice(None), 0), should_error=True),
      dict(testcase_name="addupdate_with_array_index", ref_shape=(1, 2, 3, 4),
           ref_dtype=jnp.float32, val_shape=(2, 2, 3, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]),)),
      dict(testcase_name="addupdate_with_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 2, 4), val_dtype=jnp.float32,
           idx=(np.array([0, 1]), np.array([0, 1]))),
      dict(testcase_name="addupdate_with_nonleading_multiple_array_index",
           ref_shape=(1, 3, 2, 4), ref_dtype=jnp.float32,
           val_shape=(2, 1, 2), val_dtype=jnp.float32,
           idx=(slice(None), np.array([0, 1]), slice(None), np.array([0, 1]))),
  )
  def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
      val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
      should_error=False):
    ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
    val_aval = core.ShapedArray(val_shape, val_dtype)
    def f(x_ref, val):
      state.ref_addupdate(x_ref, idx, val)
      return []
    if should_error:
      with self.assertRaises(Exception):
        pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
    else:
      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval])
      self.assertSetEqual(jaxpr.effects, {state.AccumEffect(ref_aval)})
      self.assertLen(out_avals, 0)

  def test_addupdate_abstract_eval_must_take_in_refs(self):
    ref_aval = core.ShapedArray((), jnp.float32)
    val_aval = core.ShapedArray((), jnp.float32)
    def f(x_ref, val):
      return [state.ref_addupdate(x_ref, (), val)]
    with self.assertRaises(ValueError):
      pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])

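  # The next two tests check how Ref operations are represented in jaxprs:
  # `x[()] = v` on a Ref lowers to `swap_p` (the previous value is discarded),
  # `x[()]` lowers to `get_p`, and `ref_addupdate` lowers to `addupdate_p`.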
  def test_can_represent_get_and_swap_in_jaxprs(self):

    def body(x):
      x[()] = jnp.int32(1)
      x[()] = jnp.int32(2)
      return (x[()],)
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertLen(consts, 0)
    self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
    self.assertEqual(jaxpr.eqns[0].primitive, state.swap_p)
    self.assertEqual(jaxpr.eqns[1].primitive, state.swap_p)
    self.assertEqual(jaxpr.eqns[2].primitive, state.get_p)

  def test_can_represent_addupdate_in_jaxprs(self):

    def body(x):
      state.ref_addupdate(x, (), jnp.int32(1))
      return (x[()],)
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertLen(consts, 0)
    self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
    self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)

  def test_get_custom_pretty_printing_rule(self):
    def body(x_ref):
      x = x_ref[()]
      return [x]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))

    def body(x_ref):
      x = x_ref[:, 0]
      return [x]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32)])
    self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))

  def test_set_custom_pretty_printing_rule(self):
    def body(x_ref):
      x_ref[()] = jnp.int32(2)
      return []
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))

    def body(x_ref, val):
      x_ref[:, 0] = val
      return []
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
                             core.ShapedArray((1,), jnp.int32)])
    self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))

  def test_swap_custom_pretty_printing_rule(self):
    def body(x_ref):
      x = state.ref_swap(x_ref, (), jnp.int32(2))
      return [x]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
    self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))

    def body(x_ref, val):
      x = state.ref_swap(x_ref, (slice(None), 0), val)
      return [x]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
                             core.ShapedArray((1,), jnp.int32)])
    self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
                  jaxpr.pretty_print(use_color=False))

  def test_addupdate_custom_pretty_printing_rule(self):
    def body(x_ref):
      state.ref_addupdate(x_ref, (), jnp.int32(2))
      return []
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])

    self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))

    def body(x_ref, val):
      state.ref_addupdate(x_ref, (slice(None), 0), val)
      return []
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
                             core.ShapedArray((1,), jnp.int32)])
    self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))

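  # The JVP tests below trace `jax.jvp` of a function over a Ref (the primal
  # and tangent inputs are both Refs) and check that each read/write appears
  # twice in the jaxpr: once for the primal ref and once for the tangent ref.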
  def test_get_jvp(self):

    def f(r):
      x = r[()]
      return jnp.cos(x)

    def g(r, rdot):
      return jax.jvp(f, (r,), (rdot,))

    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                state.ShapedArrayRef((), jnp.dtype('float32'))]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
    self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
    self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)

  def test_swap_jvp(self):

    def f(a):
      x = a[()]
      a[()] = jnp.sin(x)
      return a[()]

    def g(r, rdot):
      return jax.jvp(f, (r,), (rdot,))

    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                state.ShapedArrayRef((), jnp.dtype('float32'))]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
    self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
    self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
    self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
    self.assertEqual(jaxpr.eqns[3].primitive, lax.cos_p)
    self.assertEqual(jaxpr.eqns[4].primitive, lax.mul_p)
    self.assertEqual(jaxpr.eqns[5].primitive, state.swap_p)
    self.assertEqual(jaxpr.eqns[6].primitive, state.swap_p)

  def test_addupdate_jvp(self):

    def f(a):
      state.ref_addupdate(a, (), jnp.float32(1.))
      return a[()]

    def g(r, rdot):
      return jax.jvp(f, (r,), (rdot,))

    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                state.ShapedArrayRef((), jnp.dtype('float32'))]
    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
    self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
    self.assertEqual(jaxpr.eqns[1].primitive, state.addupdate_p)
    self.assertEqual(jaxpr.eqns[2].primitive, state.get_p)
    self.assertEqual(jaxpr.eqns[3].primitive, state.get_p)

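  # The test below checks that vmapping a stateful function and then
  # discharging its state (discharge-of-vmap) agrees with discharging first
  # and vmapping the resulting pure jaxpr (vmap-of-discharge).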
  @jtu.sample_product(
      [dict(ref_shape=ref_shape, ref_bdim=ref_bdim, idx_shape=idx_shape,
            indexed_dims=indexed_dims, idx_bdims=idx_bdims, out_bdim=out_bdim)
       for ref_shape in [(1,), (2, 3), (4, 5, 6)]
       for ref_bdim in range(1 + len(ref_shape))
       for idx_shape in [(), (1,), (2,), (5, 6)]
       for indexed_dims in it.product([True, False], repeat=len(ref_shape))
       for idx_bdims in it.product([None, *range(1 + len(idx_shape))],
                                   repeat=sum(indexed_dims))
       for out_bdim in range(1 + len(ref_shape) - sum(indexed_dims)
                             + len(idx_shape) * any(indexed_dims))
      ],
      op=[
          lambda x_ref, indexer: [x_ref[indexer]],
          lambda x_ref, indexer: [
              state.ref_swap(x_ref, indexer,
                             jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)])],
          lambda x_ref, indexer: (
              state.ref_addupdate(x_ref, indexer,
                                  jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)])
              or [jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]])
      ],
  )
  def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims,
                idx_bdims, out_bdim, op):

    float_ = (jnp.dtype('float64') if jax.config.jax_enable_x64 else
              jnp.dtype('float32'))
    int_ = (jnp.dtype('int64') if jax.config.jax_enable_x64 else
            jnp.dtype('int32'))
    axis_size = 7
    out_shape = tuple([d for d, b in zip(ref_shape, indexed_dims) if not b])
    if any(indexed_dims):
      out_shape = (*idx_shape, *out_shape)

    def maybe_insert(shape, idx):
      if idx is None:
        return shape
      return tuple_insert(shape, idx, axis_size)

    batched_ref_shape = maybe_insert(ref_shape, ref_bdim)
    ref_aval = state.ShapedArrayRef(ref_shape, float_)
    bat_ref_aval = state.ShapedArrayRef(batched_ref_shape, float_)

    idx_avals = [core.ShapedArray(idx_shape, int_)
                 for _ in idx_bdims]
    bat_idx_avals = [
        core.ShapedArray(maybe_insert(idx_shape, idx_bdim), int_)
        for idx_bdim in idx_bdims]

    def f(x_ref, *idxs):
      idxs_ = iter(idxs)
      indexer = tuple([next(idxs_) if b else slice(None) for b in indexed_dims])
      return op(x_ref, indexer)

    rng = self.rng()
    a = rng.randn(*bat_ref_aval.shape)
    his = [d for d, b in zip(ref_aval.shape, indexed_dims) if b]
    idxs = [rng.randint(low=0, high=hi, size=i.shape)
            for i, hi in zip(bat_idx_avals, his)]

    # discharge-of-vmap
    f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
    stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals])
    jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
    discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)
    # vmap-of-discharge
    stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f), [ref_aval, *idx_avals])
    jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
    f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                         in_axes=(ref_bdim, *idx_bdims),
                         out_axes=[out_bdim, ref_bdim])
    vmap_of_discharge_ans = f_batched(a, *idxs)

    self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                        check_dtypes=False)

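# Illustrative sketch (not a test): `discharge_state` rewrites a stateful jaxpr
# over Refs into a pure jaxpr. The Ref input becomes a regular array input, and
# the final value of each discharged Ref is appended to the outputs, which is
# the behavior the discharge tests below rely on.
def _example_discharge_roundtrip():
  def add_one(x_ref):
    x = state.ref_get(x_ref, ())
    state.ref_set(x_ref, (), x + 1.)
    return []
  in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))]
  stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(add_one), in_avals)
  discharged_jaxpr, discharged_consts = state.discharge_state(
      stateful_jaxpr, consts)
  # The discharged jaxpr is pure, so it can be evaluated directly; it returns
  # the final ref value (here, the input plus one).
  return core.eval_jaxpr(discharged_jaxpr, discharged_consts, jnp.float32(1.))
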
class StateDischargeTest(jtu.JaxTestCase):

  def test_discharge_get(self):
    def f(a_ref):
      a = state.ref_get(a_ref, ())
      return [a + 1]
    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should just turn this into a jaxpr that adds 1.
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 2)
    self.assertEqual(discharged_jaxpr.eqns[0].primitive, lax.add_p)
    # Should be able to evaluate this jaxpr
    self.assertListEqual(core.eval_jaxpr(discharged_jaxpr, (),
                                         jnp.float32(1.)), [2., 1.])

  def test_discharge_get_with_slice(self):
    def f(a_ref):
      a = state.ref_get(a_ref, (0, 1))
      return [a + 1]
    in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should just turn this into a jaxpr that slices and adds 1.
    discharged_jaxpr, () = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 2)
    self.assertIn(lax.dynamic_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    # Should be able to evaluate this jaxpr
    inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
    outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
    self.assertTrue((outval == inval[0, 1] + 1).all())
    self.assertTrue((refval == inval).all())

  def test_discharge_get_with_gather(self):
    def f(a_ref):
      a = a_ref[jnp.array([0, 1])]
      return [a + 1]
    in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f), in_avals)
    discharged_jaxpr, discharged_consts = state.discharge_state(
        stateful_jaxpr, consts)
    inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
    outval, refval = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
    self.assertTrue((outval == inval[jnp.array([0, 1])] + 1).all())
    self.assertTrue((refval == inval).all())

  def test_discharge_set(self):
    def f(a_ref, b):
      state.ref_set(a_ref, (), b + 1)
      return []
    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                core.ShapedArray((), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should turn this into a jaxpr that ignores the first value
    # and returns the second value plus 1.
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 2)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
                                     jnp.float32(1.))[0], 2.)
    self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
                                     jnp.float32(1.))[0], 2.)

  def test_discharge_set_with_slice(self):
    def f(a_ref):
      state.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
      return []
    in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should turn this into a jaxpr that dynamically updates the
    # indexed slice of the ref value.
    discharged_jaxpr, () = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertIn(lax.dynamic_update_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    self.assertIn(lax.dynamic_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    # Should be able to evaluate this jaxpr
    inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
    refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
    self.assertTrue((refval == inval.at[0, 1].set(1.)).all())

  def test_discharge_set_with_gather(self):
    def f(a_ref):
      a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
      return []
    in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, discharged_consts = state.discharge_state(
        stateful_jaxpr, consts)
    inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
    refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
    self.assertTrue((refval == inval.at[jnp.array([0, 1])].set(1.)).all())

  def test_discharge_addupdate(self):
    def f(a_ref, b):
      state.ref_addupdate(a_ref, (), b + 1)
      return []
    in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
                core.ShapedArray((), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    # Discharging should turn this into a jaxpr that adds the first value,
    # the second value, and 1.
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 2)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(0.),
                                     jnp.float32(1.))[0], 2.)
    self.assertEqual(core.eval_jaxpr(discharged_jaxpr, (), jnp.float32(2.),
                                     jnp.float32(1.))[0], 4.)

  def test_discharge_addupdate_with_slice(self):
    def f(a_ref):
      state.ref_addupdate(a_ref, (0, 1),
                          jnp.ones(2, dtype=jnp.dtype('float32')))
      return []
    in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertIn(lax.dynamic_update_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    self.assertIn(lax.add_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    self.assertIn(lax.dynamic_slice_p,
                  set(eqn.primitive for eqn in discharged_jaxpr.eqns))
    inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
    refval, = core.eval_jaxpr(discharged_jaxpr, (), inval)
    self.assertTrue((refval == inval.at[0, 1].add(1.)).all())

  def test_discharge_addupdate_with_gather(self):
    def f(a_ref):
      state.ref_addupdate(a_ref, (jnp.array([0, 1]),),
                          jnp.ones((2, 3), 'float32'))
      return []
    in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, discharged_consts = state.discharge_state(
        stateful_jaxpr, consts)
    inval = jnp.arange(4 * 3, dtype=jnp.float32).reshape((4, 3))
    refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
    self.assertTrue((refval == inval.at[jnp.array([0, 1])].add(1.)).all())

  def test_discharge_jaxpr_with_multiple_outputs(self):
    def f(a_ref):
      a = state.ref_get(a_ref, ())
      b = a + 1
      return [a, b]
    in_avals = [state.ShapedArrayRef((4,), jnp.dtype('float32'))]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
    self.assertLen(discharged_jaxpr.invars, 1)
    self.assertLen(discharged_jaxpr.outvars, 3)
    inval = jnp.arange(4., dtype=jnp.float32)
    a, b, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
    self.assertTrue((a == inval).all())
    self.assertTrue((b == inval + 1).all())
    self.assertTrue((refval == inval).all())

  def test_partially_discharging_jaxpr_keeps_refs(self):
    def f(a_ref, b_ref):
      state.ref_set(a_ref, (), jnp.ones(4, jnp.float32))
      state.ref_set(b_ref, (), jnp.ones(4, jnp.float32))
      return []
    in_avals = [
        state.ShapedArrayRef((4,), jnp.dtype('float32')),
        state.ShapedArrayRef((4,), jnp.dtype('float32'))
    ]
    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, _ = state.discharge_state(
        stateful_jaxpr, consts, should_discharge=[False, True])
    self.assertLen(discharged_jaxpr.invars, 2)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.ShapedArrayRef)
    self.assertIsInstance(discharged_jaxpr.invars[1].aval, core.ShapedArray)
    self.assertEqual(discharged_jaxpr.effects,
                     {state.WriteEffect(discharged_jaxpr.invars[0].aval)})

if CAN_USE_HYPOTHESIS:

  def index_arrays(size, idx_shape):
    valid_idx = hps.integers(min_value=-size, max_value=size - 1)
    return hnp.arrays(np.int32, idx_shape, elements=valid_idx)

  Shape = tuple[int, ...]

  class IndexParam(NamedTuple):
    ref_aval: state.ShapedArrayRef
    ref_shape: Shape
    indexed_dims: list[bool]
    idx_avals: tuple[core.ShapedArray, ...]
    idx_shape: Shape
    slice_aval: core.ShapedArray
    slice_shape: Shape

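  # Hypothesis strategy: draws a random ref shape, a random choice of which
  # dimensions are indexed by integer arrays, a random index-array shape, and
  # derives the shape and aval of the resulting slice.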
  @hps.composite
  def index_params(draw):
    ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7),
                     label='ref_shape')
    indexed_dims = draw(hps.lists(hps.booleans(),
                                  min_size=len(ref_shape),
                                  max_size=len(ref_shape)))
    idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5))
    if any(indexed_dims):
      sliced_shape = (s for s, b in zip(ref_shape, indexed_dims) if not b)
      slice_shape = (*idx_shape, *sliced_shape)
    else:
      slice_shape = ref_shape
    ref_aval = state.ShapedArrayRef(ref_shape, np.float32)
    idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in
                      range(sum(indexed_dims)))
    slice_aval = core.ShapedArray(slice_shape, np.float32)
    return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape,
                      slice_aval, slice_shape)

  class VmappableIndexParam(NamedTuple):
    index_param: IndexParam
    ref_bdim: Optional[int]
    non_slice_idx_bdims: tuple[Optional[int], ...]
    slice_bdim: int
    bat_ref_aval: state.ShapedArrayRef
    bat_ref_shape: Shape
    bat_non_slice_idx_avals: tuple[core.ShapedArray, ...]
    bat_non_slice_idx_shapes: tuple[Shape, ...]
    bat_slice_aval: core.ShapedArray
    bat_slice_shape: Shape

  def maybe_tuple_insert(t: tuple[Any, ...], idx: Optional[int],
                         val: Any) -> tuple[Any, ...]:
    if idx is None:
      return t
    return tuple_insert(t, idx, val)

  @hps.composite
  def vmappable_index_params(draw, *, op_type: str):
    axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size')
    index_param: IndexParam = draw(index_params())
    non_slice_idx_bdims = tuple(
        draw(hps.one_of(
            hps.none(),
            hps.integers(min_value=0, max_value=len(index_param.idx_shape))))
        for b in index_param.indexed_dims if b)
    bat_non_slice_idx_shapes = tuple(
        maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size)
        for idx_bdim in non_slice_idx_bdims)
    if op_type == "swap":
      # In a swap, the ref *must* be batched.
      ref_bdim = draw(hps.integers(min_value=0,
                                   max_value=len(index_param.ref_shape)))
      if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims):
        # In a swap, if the indices are batched, the value must be batched too.
        slice_bdim = draw(hps.integers(
            min_value=0, max_value=len(index_param.slice_shape)))
      else:
        slice_bdim = draw(hps.one_of(hps.none(), hps.integers(
            min_value=0, max_value=len(index_param.slice_shape))))
    elif op_type == "get":
      # In a get, either the indices or the ref must be batched.
      if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims):
        ref_bdim = draw(hps.integers(min_value=0,
                                     max_value=len(index_param.ref_shape)))
      else:
        ref_bdim = draw(hps.one_of(
            hps.none(),
            hps.integers(min_value=0, max_value=len(index_param.ref_shape))))
      slice_bdim = draw(hps.integers(
          min_value=0, max_value=len(index_param.slice_shape)))

    bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim,
                                       axis_size)
    bat_ref_aval = state.ShapedArrayRef(bat_ref_shape, np.float32)
    bat_non_slice_idx_avals = tuple(
        core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes)
    bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim,
                                         axis_size)
    bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32)
    return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims,
                               slice_bdim, bat_ref_aval, bat_ref_shape,
                               bat_non_slice_idx_avals,
                               bat_non_slice_idx_shapes,
                               bat_slice_aval, bat_slice_shape)

  class GetVmapParams(NamedTuple):
    vmap_index_param: VmappableIndexParam
    bat_ref: np.ndarray
    bat_idxs: tuple[np.ndarray, ...]

  @hps.composite
  def get_vmap_params(draw):
    vmap_index_param: VmappableIndexParam = draw(
        vmappable_index_params(op_type="get"))
    bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape))
    bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes)
    bat_idxs = tuple(
        draw(index_arrays(size, next(bat_idx_shapes_)))
        for size, indexed in zip(
            vmap_index_param.index_param.ref_shape,
            vmap_index_param.index_param.indexed_dims)
        if indexed)
    assert next(bat_idx_shapes_, None) is None
    return GetVmapParams(vmap_index_param, bat_ref, bat_idxs)

  class SetVmapParams(NamedTuple):
    vmap_index_param: VmappableIndexParam
    bat_ref: np.ndarray
    bat_val: np.ndarray
    bat_idxs: tuple[np.ndarray, ...]

  @hps.composite
  def set_vmap_params(draw):
    vmap_index_param: VmappableIndexParam = draw(vmappable_index_params(
        op_type="swap"))
    bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape))
    bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes)
    bat_idxs = tuple(
        draw(index_arrays(size, next(bat_idx_shapes_)))
        for size, indexed in zip(
            vmap_index_param.index_param.ref_shape,
            vmap_index_param.index_param.indexed_dims)
        if indexed)
    assert next(bat_idx_shapes_, None) is None
    bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape))
    return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs)

  Indexer = Tuple[Union[int, slice, np.ndarray], ...]

  def _unpack_idx(idx: Indexer
                  ) -> Tuple[Sequence[Union[int, np.ndarray]], Sequence[bool]]:
    indexed_dims = [type(i) != slice for i in idx]
    non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b]
    return non_slice_idx, indexed_dims

  def _pack_idx(non_slice_idx: Sequence[Union[int, np.ndarray]],
                indexed_dims: Sequence[bool]) -> Indexer:
    idx_ = iter(non_slice_idx)
    idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
    assert next(idx_, None) is None
    return idx
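  # For example (illustrative): with indexed_dims == [False, True, False] and a
  # single index array i, _pack_idx((i,), indexed_dims) produces
  # (slice(None), i, slice(None)), and _unpack_idx inverts that packing.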

  class StateHypothesisTest(jtu.JaxTestCase):

    @hp.given(get_vmap_params())
    @hp.settings(deadline=None, print_blob=True, max_examples=50)
    def test_get_vmap(self, get_vmap_param: GetVmapParams):

      indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims

      def f(ref, *non_slice_idx):
        idx = _pack_idx(non_slice_idx, indexed_dims)
        return [state.ref_get(ref, idx)]
      ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval
      bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval
      bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals
      ref_bdim = get_vmap_param.vmap_index_param.ref_bdim
      idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims
      out_bdim = get_vmap_param.vmap_index_param.slice_bdim
      non_slice_idx = get_vmap_param.bat_idxs
      idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals
      ref = get_vmap_param.bat_ref

      f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims),
                           out_axes=[out_bdim])
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals])
      jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
      discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref,
                                              *non_slice_idx)

      # vmap-of-discharge
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, *idx_avals])
      jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
      f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                           in_axes=(ref_bdim, *idx_bdims),
                           out_axes=[out_bdim, ref_bdim])
      vmap_of_discharge_ans = f_batched(ref, *non_slice_idx)

      self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                          check_dtypes=False)

    @hp.given(set_vmap_params())
    @hp.settings(deadline=None, print_blob=True, max_examples=50)
    def test_set_vmap(self, set_vmap_param: SetVmapParams):

      indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims

      def f(ref, val, *non_slice_idx):
        idx = _pack_idx(non_slice_idx, indexed_dims)
        state.ref_set(ref, idx, val)
        return []
      ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval
      bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval
      bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals
      ref_bdim = set_vmap_param.vmap_index_param.ref_bdim
      idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims
      non_slice_idx = set_vmap_param.bat_idxs
      idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals
      ref = set_vmap_param.bat_ref
      val = set_vmap_param.bat_val
      bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval
      val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval
      val_bdim = set_vmap_param.vmap_index_param.slice_bdim

      f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
                           out_axes=[])
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched),
          [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
      jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
      discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val,
                                              *non_slice_idx)

      # vmap-of-discharge
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
      jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
      f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                           in_axes=(ref_bdim, val_bdim, *idx_bdims),
                           out_axes=[ref_bdim])
      vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx)

      self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                          check_dtypes=False)

    @hp.given(set_vmap_params())
    @hp.settings(deadline=None, print_blob=True, max_examples=50)
    def test_addupdate_vmap(self, set_vmap_param: SetVmapParams):

      indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims

      def f(ref, val, *non_slice_idx):
        idx = _pack_idx(non_slice_idx, indexed_dims)
        state.ref_addupdate(ref, idx, val)
        return []
      ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval
      bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval
      bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals
      ref_bdim = set_vmap_param.vmap_index_param.ref_bdim
      idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims
      non_slice_idx = set_vmap_param.bat_idxs
      idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals
      ref = set_vmap_param.bat_ref
      val = set_vmap_param.bat_val
      bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval
      val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval
      val_bdim = set_vmap_param.vmap_index_param.slice_bdim

      f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
                           out_axes=[])
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched),
          [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
      jaxpr, consts = state.discharge_state(stateful_jaxpr, stateful_consts)
      discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val,
                                              *non_slice_idx)

      # vmap-of-discharge
      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
      jaxpr_, consts_ = state.discharge_state(stateful_jaxpr, stateful_consts)
      f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
                           in_axes=(ref_bdim, val_bdim, *idx_bdims),
                           out_axes=[ref_bdim])
      vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx)

      self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans,
                          check_dtypes=False)

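# Illustrative sketch (not a test; assumes the `for_loop.run_state` behavior
# exercised below): `run_state` allocates a Ref holding the initial value, runs
# the body for its effects on that Ref, and returns the Ref's final value.
def _example_run_state():
  def body(x_ref):
    x_ref[()] = x_ref[()] + 1.
  return for_loop.run_state(body, 0.)  # expected to return 1.0
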
class StateControlFlowTest(jtu.JaxTestCase):

  def test_simple_cond(self):
    def f(pred):
      def body(x_ref):
        def true_fun():
          x_ref[()] = 1.
        def false_fun():
          pass
        lax.cond(pred, true_fun, false_fun)
      return for_loop.run_state(body, 0.)
    jaxpr = jax.make_jaxpr(f)(True).jaxpr
    self.assertEmpty(jaxpr.effects)
    self.assertAllClose(jax.jit(f)(True), 1.)
    self.assertAllClose(jax.jit(f)(False), 0.)

  def test_nested_cond(self):
    def f(pred):
      def body(x_ref):
        def true_fun():
          def true_fun_inner():
            x_ref[()] = 1.
          def false_fun_inner():
            pass
          return lax.cond(pred, true_fun_inner, false_fun_inner)
        def false_fun():
          pass
        lax.cond(pred, true_fun, false_fun)
      return for_loop.run_state(body, 0.)
    jaxpr = jax.make_jaxpr(f)(True).jaxpr
    self.assertEmpty(jaxpr.effects)
    self.assertAllClose(jax.jit(f)(True), 1.)
    self.assertAllClose(jax.jit(f)(False), 0.)

  def test_cond_jvp_with_state(self):
    def f(pred, init_value):
      def body(x_ref):
        def true_fun():
          x_ref[()] = x_ref[()] ** 2
        def false_fun():
          pass
        lax.cond(pred, true_fun, false_fun)
      return for_loop.run_state(body, init_value)

    out_primal, out_tangent = jax.jvp(partial(f, True), (3.,), (1.,))
    self.assertAllClose(out_primal, 9.)
    self.assertAllClose(out_tangent, 6.)

    out_primal, out_tangent = jax.jvp(partial(f, False), (3.,), (1.,))
    self.assertAllClose(out_primal, 3.)
    self.assertAllClose(out_tangent, 1.)

  def test_cond_vmap_not_implemented(self):
    @jax.jit
    def f(init_value):
      def body(x_ref):
        def true_fun():
          x_ref[()] = x_ref[()] ** 2
        def false_fun():
          pass
        lax.cond(x_ref[()] < 1, true_fun, false_fun)
      return for_loop.run_state(body, init_value)

    with self.assertRaises(NotImplementedError):
      jax.vmap(f)(jnp.arange(2.))

  def test_cond_grad_not_implemented(self):
    @jax.jit
    def f(init_value):
      def body(x_ref):
        def true_fun():
          x_ref[()] = x_ref[()] ** 2
        def false_fun():
          pass
        lax.cond(True, true_fun, false_fun)
      return for_loop.run_state(body, init_value)

    with self.assertRaises(NotImplementedError):
      jax.grad(f)(3.)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
|