mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

There's no reason why two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions. Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister both vmappable types (on the second call to spec_types.remove). When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types. PiperOrigin-RevId: 737280270
1388 lines
52 KiB
Python
1388 lines
52 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from contextlib import contextmanager
|
|
from functools import partial
|
|
import itertools as it
|
|
from typing import Any, TypeVar, Union
|
|
|
|
import numpy as np
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import jax.scipy as jsp
|
|
from jax._src import core
|
|
from jax._src import dtypes
|
|
from jax._src import test_util as jtu
|
|
from jax import lax
|
|
from jax._src.lax import parallel
|
|
from jax import random
|
|
from jax import jit, grad, jvp, vjp, make_jaxpr, jacfwd, jacrev, hessian
|
|
from jax import vmap
|
|
from jax.interpreters import batching
|
|
from jax.tree_util import register_pytree_node
|
|
|
|
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
# These are 'manual' tests for batching (vmap). The more exhaustive, more
|
|
# systematic tests are in lax_test.py's LaxVmapTest class.
|
|
|
|
class BatchingTest(jtu.JaxTestCase):
|
|
|
|
def testConstantFunction(self):
  # A mapped function that ignores its argument and returns the constant 3 is
  # compared against an array holding 3 for each of the four inputs.
  batched_const = vmap(lambda x: 3)
  result = batched_const(np.ones(4))
  self.assertAllClose(result, 3 * np.ones(4), check_dtypes=False)
|
|
|
|
@jax.default_matmul_precision("float32")
def testNestedBatchingMatMat(self):
  # A matrix-matrix product assembled from two nested vmaps of jnp.vdot.
  row_times_vec = vmap(jnp.vdot, in_axes=(0, None))
  mat_times_mat = vmap(row_times_vec, in_axes=(None, 1), out_axes=1)

  randn = self.rng().randn
  lhs = randn(4, 3)
  rhs = randn(3, 2)

  # The nested vmap has to agree with a plain NumPy dot product.
  self.assertAllClose(mat_times_mat(lhs, rhs), np.dot(lhs, rhs),
                      check_dtypes=False)

  # The traced program is expected to consist of exactly one equation.
  traced = make_jaxpr(mat_times_mat)(lhs, rhs)
  self.assertLen(traced.jaxpr.eqns, 1)
|
|
|
|
def testPerExampleGradients(self):
  # Small MLP: each layer's pre-activation goes through tanh before feeding the
  # next layer; the last layer's pre-activation is what gets returned.
  def predict(params, inputs):
    for W, b in params:
      outputs = jnp.dot(W, inputs) + b
      inputs = jnp.tanh(outputs)
    return outputs

  # Sum of squared errors for a single (inputs, targets) example.
  def loss(params, data):
    inputs, targets = data
    predictions = predict(params, inputs)
    return jnp.sum((predictions - targets)**2)

  batch_size = 5
  layer_sizes = [3, 2, 4]

  R = self.rng().randn
  params = [(R(m, n), R(m))
            for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

  # Derive the data shapes from batch_size / layer_sizes instead of repeating
  # the literals, so the shape assertions below cannot drift from the data.
  input_batch = R(batch_size, layer_sizes[0])
  target_batch = R(batch_size, layer_sizes[-1])
  batch = (input_batch, target_batch)

  ans = vmap(partial(grad(loss), params))(batch)

  # Every per-example gradient carries a leading batch axis on top of the
  # shape of the parameter it belongs to.
  for (dW, db), (W, b) in zip(ans, params):
    self.assertEqual(dW.shape, (batch_size,) + W.shape)
    self.assertEqual(db.shape, (batch_size,) + b.shape)
|
|
|
|
@jax.default_matmul_precision("float32")
def testJacobians(self):
  # Reverse-mode helper: vmap the VJP pullback over a standard basis shaped
  # like the output, then reshape to output-shape + input-shape.
  def jac_from_vjp(f, x):
    y, pullback = vjp(f, x)
    basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
    (flat,) = vmap(pullback, out_axes=np.ndim(y))(basis)
    return flat.reshape(np.shape(y) + np.shape(x))

  # Forward-mode helper: vmap a JVP over a standard basis shaped like the
  # input, then reshape to output-shape + input-shape.
  def jac_from_jvp(f, x):
    def pushfwd(v):
      return jvp(f, (x,), (v,))

    basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x))
    y, flat = vmap(pushfwd, out_axes=(None, 0))(basis)
    return flat.reshape(np.shape(y) + np.shape(x))

  randn = self.rng().randn

  A = randn(4, 3)
  b = randn(4)

  def f(x):
    return jnp.tanh(jnp.dot(A, x) + b)

  x = randn(3)
  # Both helpers have to produce the same array.
  self.assertAllClose(jac_from_jvp(f, x), jac_from_vjp(f, x),
                      check_dtypes=False)
|
|
|
|
def testBatchOfCompile(self):
  # Grows by one entry every time the Python body of f actually runs.
  body_runs = []

  @jit
  def f(x):
    body_runs.append(None)
    return x + x

  g = jit(vmap(f))

  # First call: the body of f has run exactly once.
  self.assertAllClose(g(np.ones(2)), 2 * np.ones(2), check_dtypes=False)
  self.assertEqual(len(body_runs), 1)

  # Second call with different values: the body must not run again.
  self.assertAllClose(g(2 * np.ones(2)), 4 * np.ones(2), check_dtypes=False)
  self.assertEqual(len(body_runs), 1)
|
|
|
|
def testSliceLax(self):
  # lax.slice of [2, 4) applied per mapped row must equal taking columns 2:4
  # of the whole array.
  def take_2_to_4(row):
    return lax.slice(row, (2,), (4,))

  data = self.rng().randn(5, 10)
  self.assertAllClose(vmap(take_2_to_4)(data), data[:, 2:4],
                      check_dtypes=False)
|
|
|
|
def testSliceNumpy(self):
  # Indexing [:, 2] inside the mapped function corresponds to [:, :, 2] on the
  # full array.
  data = self.rng().randn(10, 5, 3, 7)
  batched = vmap(lambda x: x[:, 2])(data)
  self.assertAllClose(batched, data[:, :, 2], check_dtypes=False)
|
|
|
|
def testRevLax(self):
  def reverse_dim0(v):
    return lax.rev(v, [0])

  data = self.rng().randn(2, 3)

  # Default axes: each mapped row is reversed, i.e. the columns flip.
  self.assertAllClose(vmap(reverse_dim0)(data), data[:, ::-1],
                      check_dtypes=False)

  # Mapping over axis 1 (in and out): the rows flip instead.
  self.assertAllClose(vmap(reverse_dim0, (1,), 1)(data), data[::-1, :],
                      check_dtypes=False)
|
|
|
|
def testRevNumpy(self):
  def flip_dim1(v):
    return v[:, ::-1]

  data = self.rng().randn(3, 2, 4)

  # Default batch axes.
  result = vmap(flip_dim1)(data)
  self.assertAllClose(result, data[:, :, ::-1], check_dtypes=False)

  # Batch axis 1 on both input and output.
  result = vmap(flip_dim1, (1,), 1)(data)
  self.assertAllClose(result, data[:, :, ::-1], check_dtypes=False)

  # Batch axis 2 on both input and output.
  result = vmap(flip_dim1, (2,), 2)(data)
  self.assertAllClose(result, data[:, ::-1, :], check_dtypes=False)
|
|
|
|
def testNpMaximum(self):
  # Elementwise maximum with 0.0 under vmap versus NumPy on the full array.
  data = self.rng().randn(10, 5, 3, 7)
  batched = vmap(lambda x: jnp.maximum(x, 0.0))(data)
  self.assertAllClose(batched, np.maximum(data, 0.0), check_dtypes=False)
|
|
|
|
def testNpGtrThan(self):
  # Elementwise comparison against 1.0 under vmap versus the same comparison
  # on the full NumPy array.
  data = self.rng().randn(10, 5, 3, 7)
  batched = vmap(lambda x: x > 1.0)(data)
  self.assertAllClose(batched, data > 1.0)
|
|
|
|
@jax.default_matmul_precision("float32")
def testNpMaximumPerExampleGrad(self):
  randn = self.rng().randn
  x = randn(10, 5)
  W = randn(5, 5)

  def relu_sq_sum(W, x):
    return jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)

  # Gradient w.r.t. W, evaluated separately for every row of x.
  per_example = vmap(partial(grad(relu_sq_sum), W))(x)

  W_t = jnp.transpose(W)
  for i in range(10):
    x_ex = x[i:i + 1]

    # Hand-written reference gradient for example i.
    hidden = jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0)
    reference = jnp.transpose(2.0 * jnp.dot(hidden, x_ex))

    self.assertAllClose(per_example[i], reference, check_dtypes=False)
|
|
|
|
# Replace the default TF32 with float32 in order to make it pass on A100
@jax.default_matmul_precision("float32")
def testDotGeneral(self):
  randn = self.rng().randn

  # Used by the first four cases: the same dot_general dimension numbers are
  # applied to every pair of unbatched operands.
  def batched_dot(x, y):
    return lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])

  # Both operands mapped along their leading axis; the reference folds the
  # mapped axis into the batch dimensions of a single dot_general.
  x = randn(10, 3, 4, 5)
  y = randn(10, 3, 5, 6)
  ans = vmap(batched_dot)(x, y)
  expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
  self.assertAllClose(ans, expected)

  # lhs mapped along axis 2, rhs along axis 1.
  x = randn(3, 4, 10, 5)
  y = randn(3, 10, 5, 6)
  ans = vmap(batched_dot, in_axes=(2, 1))(x, y)
  expected = np.stack(
      [batched_dot(x[..., i, :], y[:, i, ...]) for i in range(10)])
  self.assertAllClose(ans, expected)

  # Only lhs mapped (along its last axis).
  x = randn(3, 4, 5, 10)
  y = randn(3, 5, 6)
  ans = vmap(batched_dot, in_axes=(3, None))(x, y)
  expected = np.stack([batched_dot(x[..., i], y) for i in range(10)])
  self.assertAllClose(ans, expected)

  # Only rhs mapped (along axis 2).
  x = randn(3, 4, 5)
  y = randn(3, 5, 10, 6)
  ans = vmap(batched_dot, in_axes=(None, 2))(x, y)
  expected = np.stack([batched_dot(x, y[..., i, :]) for i in range(10)])
  self.assertAllClose(ans, expected)

  # Vector operands with no dot_general batch dimensions; rhs mapped along
  # axis 1.
  def plain_dot(x, y):
    return lax.dot_general(x, y, [((0,), (0,)), ((), ())])

  x = randn(4)
  y = randn(4, 10)
  ans = vmap(plain_dot, in_axes=(None, 1))(x, y)
  expected = np.stack([plain_dot(x, y[..., i]) for i in range(10)])
  self.assertAllClose(ans, expected)
|
|
|
|
def testDot(self):
  # these tests are based on @shoyer's notebook studying gufuncs

  # Wrap jnp.dot in one vmap per extra leading dimension; an operand is mapped
  # along axis 0 only while it still has more dimensions than the current
  # level, otherwise it is passed through unmapped.
  def vecvec(a, b):
    dot = jnp.dot
    for ndim in range(1, max(a.ndim, b.ndim)):
      a_ax = 0 if a.ndim > ndim else None
      b_ax = 0 if b.ndim > ndim else None
      dot = vmap(dot, in_axes=(a_ax, b_ax))
    return dot(a, b)

  # Use unittest assertions rather than bare `assert` so the checks are not
  # stripped under `python -O` and report both shapes on failure.
  self.assertEqual(vecvec(jnp.zeros((3,)), jnp.zeros((3,))).shape, ())
  self.assertEqual(vecvec(jnp.zeros((2, 3)), jnp.zeros((3,))).shape, (2,))
  self.assertEqual(
      vecvec(jnp.zeros((4, 2, 3)), jnp.zeros((3,))).shape, (4, 2))
|
|
|
|
def testDot2(self):
  randn = self.rng().randn
  lhs = randn(10, 3)
  rhs = randn(10, 3)
  # Reference: row-by-row dot products spelled out with einsum.
  reference = np.einsum('ni,ni->n', lhs, rhs)
  self.assertAllClose(vmap(jnp.dot)(lhs, rhs), reference, check_dtypes=False)
|
|
|
|
def testDot3(self):
  randn = self.rng().randn
  lhs = randn(5, 8, 10)
  rhs = randn(10, 1)
  # lhs is mapped along axis 1 while rhs is shared by every mapped call.
  batched = vmap(jnp.dot, in_axes=(1, None))(lhs, rhs)
  reference = np.einsum('inj,jk->nik', lhs, rhs)
  self.assertAllClose(batched, reference, check_dtypes=False)
|
|
|
|
def testDot4(self):
  randn = self.rng().randn
  lhs = randn(3, 2)
  rhs = randn(3)
  # Each column of lhs is dotted with the shared vector rhs.
  batched = vmap(jnp.dot, in_axes=(1, None))(lhs, rhs)
  reference = np.einsum('ij,i->j', lhs, rhs)
  self.assertAllClose(batched, reference, check_dtypes=False)
|
|
|
|
def testPad(self):
  randn = self.rng().randn

  # One padded dimension per example, padding config [(1, 2, 1)].
  def pad_one_dim(v):
    return lax.pad(v, np.float32(0), [(1, 2, 1)])

  x = randn(5, 10).astype(np.float32)
  batched = vmap(pad_one_dim)(x)
  # Reference: pad every row on its own and stack the results.
  reference = jnp.stack([pad_one_dim(row) for row in x])
  self.assertAllClose(batched, reference, check_dtypes=False)

  # Two padded dimensions per example, padding config [(1, 2, 1), (0, 1, 0)].
  def pad_two_dims(v):
    return lax.pad(v, np.float32(0), [(1, 2, 1), (0, 1, 0)])

  x = randn(5, 10, 3).astype(np.float32)
  batched = vmap(pad_two_dims)(x)
  reference = jnp.stack([pad_two_dims(item) for item in x])
  self.assertAllClose(batched, reference, check_dtypes=False)
|
|
|
|
def testConcatenate(self):
  def rand_f32(*shape):
    return self.rng().randn(*shape).astype(np.float32)

  # Concatenate along dimension 0 of the unbatched operands; the three inputs
  # are mapped along axis 0, axis 1, and not at all.
  def concat_dim0(*args):
    return lax.concatenate(args, dimension=0)

  x, y, z = rand_f32(10, 2, 3), rand_f32(1, 10, 3), rand_f32(4, 3)
  batched = vmap(concat_dim0, in_axes=(0, 1, None))(x, y, z)
  pieces = [x, np.swapaxes(y, 0, 1), np.broadcast_to(z, (10, 4, 3))]
  self.assertAllClose(batched, np.concatenate(pieces, 1), check_dtypes=False)

  # Concatenate along dimension 1; inputs are mapped along axis 0, not at all,
  # and axis 2.
  def concat_dim1(*args):
    return lax.concatenate(args, dimension=1)

  x, y, z = rand_f32(10, 2, 1), rand_f32(2, 3), rand_f32(2, 4, 10)
  batched = vmap(concat_dim1, in_axes=(0, None, 2))(x, y, z)
  pieces = [x, np.broadcast_to(y, (10, 2, 3)), np.moveaxis(z, 2, 0)]
  self.assertAllClose(batched, np.concatenate(pieces, 2), check_dtypes=False)
|
|
|
|
def testJacobianIssue54(self):
  # test modeling the code in https://github.com/jax-ml/jax/issues/54

  # Rebuild the array from a Python list of its rows.
  def func(xs):
    rows = list(xs)
    return jnp.array(rows)

  xs = jnp.ones((5, 1))
  # There are no value checks here: both calls only have to finish without
  # raising.
  jacrev(func)(xs)
  jacfwd(func)(xs)
|
|
|
|
def testAny(self):
  # test modeling the code in https://github.com/jax-ml/jax/issues/108

  flags = jnp.array([[True, False], [False, False]])
  # One result per row: the first row contains a True, the second does not.
  self.assertAllClose(vmap(jnp.any)(flags), jnp.array([True, False]))
|
|
|
|
def testHessian(self):
  # test based on code from sindhwani@google
  def fun(x, t):
    return jnp.sum(jnp.power(jnp.maximum(x, 0.0), 2)) + t

  x = np.array([-1., -0.5, 0., 0.5, 1.0])

  def fun_at_t0(x):
    return fun(x, 0.0)

  ans = hessian(fun_at_t0)(x)
  # Expected Hessian is diagonal, with entries 0, 0, 0.5, 2, 2.
  expected = np.array([
      [0., 0., 0., 0., 0.],
      [0., 0., 0., 0., 0.],
      [0., 0., 0.5, 0., 0.],
      [0., 0., 0., 2., 0.],
      [0., 0., 0., 0., 2.],
  ])
  self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
def testDynamicSlice(self):
  # test dynamic_slice via numpy indexing syntax
  # see https://github.com/jax-ml/jax/issues/1613 for an explanation of why we
  # need to use jnp rather than np to create x and idx
  def pick(x, i):
    return x[i]

  x = jnp.arange(30).reshape((10, 3))

  # Batched operand, one shared scalar index.
  ans = vmap(pick, in_axes=(0, None))(x, 1)
  self.assertAllClose(ans, x[:, 1], check_dtypes=False)

  # Batched operand and batched index.
  idx = jnp.array([0, 1, 2, 1, 0] * 2)
  ans = vmap(pick, in_axes=(0, 0))(x, idx)
  self.assertAllClose(ans, x[np.arange(10), idx], check_dtypes=False)

  # Shared operand, batched index.
  x = jnp.arange(3)
  idx = jnp.array([0, 1, 2, 1, 0] * 2)
  ans = vmap(pick, in_axes=(None, 0))(x, idx)
  self.assertAllClose(ans, x[idx], check_dtypes=False)
|
|
|
|
def testDynamicUpdateSlice(self):
  # Same update function for both cases below.
  def write_at(x, y, i):
    return lax.dynamic_update_index_in_dim(x, y, i, axis=0)

  # Batched operand and batched update value, shared index 1.
  x = self.rng().randn(10, 3)
  y = self.rng().randn(10)
  ans = vmap(write_at, in_axes=(0, 0, None))(x, y, 1)
  reference = x.copy()
  reference[:, 1] = y
  self.assertAllClose(ans, reference, check_dtypes=False)

  # Shared operand, batched update value and batched index.
  x = self.rng().randn(3)
  idx = np.array([0, 1, 2, 1, 0] * 2)
  ans = vmap(write_at, in_axes=(None, 0, 0))(x, y, idx)
  reference = np.broadcast_to(x, (10, 3)).copy()
  reference[np.arange(10), idx] = y
  self.assertAllClose(ans, reference, check_dtypes=False)
|
|
|
|
@jax.legacy_prng_key('allow')
def testRandom(self):
  # Build ten keys with a mapped PRNGKey, then draw a (3, 2) normal sample
  # from each key with a mapped random.normal.
  seeds = vmap(random.PRNGKey)(np.arange(10))
  ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
  # Reference: do the same thing one seed at a time.
  expected = np.stack([random.normal(random.PRNGKey(seed), (3, 2))
                       for seed in np.arange(10)])
  self.assertAllClose(ans, expected, check_dtypes=False)
  # All drawn values must be distinct. Uses a unittest assertion instead of a
  # bare `assert` so the check survives `python -O` and reports the count.
  self.assertEqual(len(np.unique(ans)), 10 * 3 * 2)
|
|
|
|
def testSort(self):
  # Every row of v is in descending order.
  v = np.arange(12)[::-1].reshape(3, 4)
  sort_dim0 = partial(lax.sort, dimension=0)

  # Mapped over rows, sorting dimension 0 of each row.
  self.assertAllClose(vmap(sort_dim0, (0,))(v), v[:, ::-1])

  # Same, but naming the sorted dimension as -1.
  sort_last = partial(lax.sort, dimension=-1)
  self.assertAllClose(vmap(sort_last, (0,))(v), v[:, ::-1])

  # Mapped over columns, with the batch axis leading in the output.
  self.assertAllClose(vmap(sort_dim0, (1,))(v), v[::-1, :].T)

  # Mapped over columns, with the batch axis placed at position 1.
  self.assertAllClose(vmap(sort_dim0, (1,), 1)(v), v[::-1, :])
|
|
|
|
def testSortKeyVal(self):
  # Keys are descending within each row; values are a random permutation.
  k = np.arange(12)[::-1].reshape(3, 4)
  v = self.rng().permutation(12).reshape(3, 4)
  sort_kv = partial(lax.sort_key_val, dimension=0)

  # Keys and values both mapped along axis 0.
  sk, sv = vmap(sort_kv, (0, 0))(k, v)
  self.assertAllClose(sk, k[:, ::-1])
  self.assertAllClose(sv, v[:, ::-1])

  # Both mapped along axis 1, output batch axis 1.
  sk, sv = vmap(sort_kv, (1, 1), 1)(k, v)
  self.assertAllClose(sk, k[::-1, :])
  self.assertAllClose(sv, v[::-1, :])

  # Keys mapped along axis 0, transposed values along axis 1.
  sk, sv = vmap(sort_kv, (0, 1))(k, v.T)
  self.assertAllClose(sk, k[:, ::-1])
  self.assertAllClose(sv, v[:, ::-1])

  # Transposed keys mapped along axis 1, values along axis 0.
  sk, sv = vmap(sort_kv, (1, 0))(k.T, v)
  self.assertAllClose(sk, k[:, ::-1])
  self.assertAllClose(sv, v[:, ::-1])

  # A single key row shared by every batch of values.
  sk, sv = vmap(sort_kv, (None, 0))(k[0], v)
  self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)))
  self.assertAllClose(sv, v[:, ::-1])

  # A single value row shared by every batch of keys.
  sk, sv = vmap(sort_kv, (1, None))(k.T, v[0])
  self.assertAllClose(sk, k[:, ::-1])
  self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)))
|
|
|
|
def testConvGeneralDilated(self):
  W = jnp.array(self.rng().randn(3, 3, 1, 5), dtype=np.float32)
  X = jnp.array(self.rng().randn(10, 5, 5, 1), dtype=np.float32)

  # 'SAME'-padded convolution with every stride/dilation argument set to
  # (1, 1).
  def f(params, x):
    ones = (1, 1)
    dnums = ('NHWC', 'HWIO', 'NHWC')
    return lax.conv_general_dilated(
        x, params, ones, 'SAME', ones, ones, dnums)

  def mean_sq(params, x):
    return jnp.mean(f(params, x) ** 2)

  grad_loss = grad(mean_sq)

  # The same data viewed as ten batches of one example each.
  X_singletons = jnp.reshape(X, (10, 1, 5, 5, 1))

  # Test forward prop.
  per_example = vmap(partial(f, W))(X_singletons)
  per_example = jnp.reshape(per_example, (10, 5, 5, 5))
  per_example_direct = f(W, X)
  self.assertAllClose(per_example, per_example_direct)

  # Test gradients.
  per_example = vmap(partial(grad_loss, W))(X_singletons)
  grads = [grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1))) for i in range(10)]
  per_example_direct = jnp.concatenate(
      [jnp.reshape(g, (1,) + g.shape) for g in grads], axis=0)
  self.assertAllClose(per_example, per_example_direct,
                      rtol=2e-2, atol=2e-3)
|
|
|
|
def testConvGeneralDilatedBatchNotMajor(self):
  W = jnp.array(self.rng().randn(3, 3, 1, 4), dtype=np.float32)
  x = jnp.array(self.rng().randn(3, 5, 7, 5, 1), dtype=np.float32)

  # Convolution whose layout strings do not put N first: 'HNWC' in, 'HWNC' out.
  def f(params, x):
    ones = (1, 1)
    dnums = ('HNWC', 'HWIO', 'HWNC')
    return lax.conv_general_dilated(
        x, params, ones, 'SAME', ones, ones, dnums)

  # vmap result, rearranged to the layout of the single direct call below.
  mapped = vmap(partial(f, W))(x)
  per_example = jnp.reshape(
      jnp.transpose(mapped, (1, 2, 0, 3, 4)), (5, 5, 21, 4))

  # Direct call with the mapped axis merged into the second input axis.
  merged_input = jnp.reshape(
      jnp.transpose(x, (1, 0, 2, 3, 4)), (5, 21, 5, 1))
  per_example_direct = f(W, merged_input)

  self.assertAllClose(per_example, per_example_direct)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_op={name}", "op": op, "unit": unit}
    for name, op, unit in [("max", lax.max, -jnp.inf), ("min", lax.min, jnp.inf)])
def testMinMaxPool(self, op, unit):
  W = jnp.array(self.rng().randn(3, 3, 1, 5), dtype=np.float32)
  X = jnp.array(self.rng().randn(10, 5, 5, 1), dtype=np.float32)

  # Convolution followed by a reduce_window using the parameterized reduction
  # `op` with initial value `unit`.
  def f(params, x):
    ones = (1, 1)
    dnums = ('NHWC', 'HWIO', 'NHWC')
    conv = lax.conv_general_dilated(
        x, params, ones, 'SAME', ones, ones, dnums)
    return lax.reduce_window(
        conv, unit, op, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')

  def mean_sq(params, x):
    return jnp.mean(f(params, x) ** 2)

  grad_loss = grad(mean_sq)

  # The same data viewed as ten batches of one example each.
  X_singletons = jnp.reshape(X, (10, 1, 5, 5, 1))

  # Test forward prop.
  per_example = vmap(partial(f, W))(X_singletons)
  per_example = jnp.reshape(per_example, (10, 5, 5, 5))
  per_example_direct = f(W, X)
  self.assertAllClose(per_example, per_example_direct)

  # Test gradients.
  per_example = vmap(partial(grad_loss, W))(X_singletons)
  grads = [grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1))) for i in range(10)]
  per_example_direct = jnp.concatenate(
      [jnp.reshape(g, (1,) + g.shape) for g in grads], axis=0)
  self.assertAllClose(per_example, per_example_direct, rtol=5e-2, atol=1e-3)
|
|
|
|
def testSumPool(self):
  W = jnp.array(self.rng().randn(3, 3, 1, 5), dtype=np.float32)
  X = jnp.array(self.rng().randn(10, 5, 5, 1), dtype=np.float32)

  # Convolution followed by a reduce_window that adds, starting from 0.0.
  def f(params, x):
    ones = (1, 1)
    dnums = ('NHWC', 'HWIO', 'NHWC')
    conv = lax.conv_general_dilated(
        x, params, ones, 'SAME', ones, ones, dnums)
    return lax.reduce_window(
        conv, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')

  def mean_sq(params, x):
    return jnp.mean(f(params, x) ** 2)

  grad_loss = grad(mean_sq)

  # The same data viewed as ten batches of one example each.
  X_singletons = jnp.reshape(X, (10, 1, 5, 5, 1))

  # Test forward prop.
  per_example = vmap(partial(f, W))(X_singletons)
  per_example = jnp.reshape(per_example, (10, 5, 5, 5))
  per_example_direct = f(W, X)
  self.assertAllClose(per_example, per_example_direct)

  # Test gradients.
  per_example = vmap(partial(grad_loss, W))(X_singletons)
  grads = [grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1))) for i in range(10)]
  per_example_direct = jnp.concatenate(
      [jnp.reshape(g, (1,) + g.shape) for g in grads], axis=0)
  self.assertAllClose(per_example, per_example_direct,
                      rtol=3e-2, atol=1e-3)
|
|
|
|
def testCumProd(self):
  # Cumulative product over the last axis of each mapped row versus cumprod
  # along axis 1 of the full 3x3 array.
  data = jnp.arange(9).reshape(3, 3) + 1

  def row_cumprod(row):
    return jnp.cumprod(row, axis=-1)

  batched = vmap(row_cumprod)(data)
  self.assertAllClose(jnp.cumprod(data, axis=1), batched)
|
|
|
|
def testSelect(self):
  # All three operands mapped along axis 0.
  pred = np.array([True, False])
  on_true = np.array([0, 1])
  on_false = np.array([2, 3])
  result = vmap(lax.select)(pred, on_true, on_false)
  self.assertAllClose(result, np.array([0, 3]))

  # Only the predicate is mapped; both branches are shared whole.
  pred = np.array([False, True])
  on_true = np.array([0, 1])
  on_false = np.array([2, 3])
  result = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
  self.assertAllClose(result, np.array([[2, 3],
                                        [0, 1]]))

  # Python-bool predicate and scalar false branch; only on_true is mapped.
  pred = True
  on_true = np.array([0, 1], np.float32)
  on_false = np.array(3, np.float32)
  result = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
  self.assertAllClose(result, np.array([0, 1], np.float32))

  # Predicate and on_true mapped; scalar false branch shared.
  pred = np.array([False, True])
  on_true = np.array([0, 1], np.float32)
  on_false = np.array(3, np.float32)
  result = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
  self.assertAllClose(result, np.array([3, 1], np.float32))

  # Predicate mapped along axis 0, on_false along axis 1, output batch axis 1.
  pred = np.array([False, True])
  on_true = np.array([2], np.float32)
  on_false = np.array([[3, 4]], np.float32)
  result = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
  self.assertAllClose(result, np.array([[3, 2]], np.float32))
|
|
|
|
def testLaxLinalgCholesky(self):
  # Draw a random float32 stack m and return m @ conj(m^T), taken over the
  # last two axes.
  def gram_stack():
    m = self.rng().randn(10, 5, 5).astype(np.float32)
    return np.matmul(m, np.conj(np.swapaxes(m, -1, -2)))

  # Batch along the leading axis.
  a = gram_stack()
  ans = vmap(lax.linalg.cholesky)(a)
  self.assertAllClose(ans, np.linalg.cholesky(a), check_dtypes=False,
                      atol=1E-3)

  # Batch along axis 1 of the input, leading axis of the output.
  b = gram_stack()
  b_trans = np.swapaxes(b, 0, 1)  # shape is (5, 10, 5)
  ans = vmap(lax.linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
  self.assertAllClose(ans, np.linalg.cholesky(b), check_dtypes=False,
                      rtol=1e-4)
|
|
|
|
def testLaxLinalgTriangularSolve(self):
  solve = lax.linalg.triangular_solve

  a = self.rng().randn(4, 10, 4).astype(np.float32)
  # Add an identity matrix to each of the ten (4, 4) slices a[:, i].
  a += np.eye(4, dtype=jnp.float32)[:, None, :]
  b = self.rng().randn(5, 4, 10).astype(np.float32)

  # Both operands batched (a along axis 1, b along axis 2).
  ans = vmap(solve, in_axes=(1, 2))(a, b)
  expected = np.stack([solve(a[:, i], b[..., i]) for i in range(10)])
  self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)

  # One shared matrix, batched right-hand sides.
  ans = vmap(solve, in_axes=(None, 2))(a[:, 0], b)
  expected = np.stack([solve(a[:, 0], b[..., i]) for i in range(10)])
  self.assertAllClose(ans, expected)

  # Batched matrices, one shared right-hand side.
  ans = vmap(solve, in_axes=(1, None))(a, b[..., 0])
  expected = np.stack([solve(a[:, i], b[..., 0]) for i in range(10)])
  self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
|
|
|
|
def testLaxLinalgTridiagonalSolve(self):
  solve = lax.linalg.tridiagonal_solve

  dl = self.rng().randn(4, 10).astype(np.float32)
  d = self.rng().randn(4, 10).astype(np.float32) + 1.
  du = self.rng().randn(4, 10).astype(np.float32)
  b = self.rng().randn(4, 5, 10).astype(np.float32)

  # Every operand batched along its last axis.
  ans = vmap(solve, in_axes=(1, 1, 1, 2))(dl, d, du, b)
  expected = np.stack(
      [solve(dl[:, i], d[:, i], du[:, i], b[..., i]) for i in range(10)])
  self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)

  # Shared diagonals, batched right-hand sides.
  ans = vmap(solve, in_axes=(None, None, None, 2))(
      dl[:, 0], d[:, 0], du[:, 0], b)
  expected = np.stack(
      [solve(dl[:, 0], d[:, 0], du[:, 0], b[..., i]) for i in range(10)])
  self.assertAllClose(ans, expected)

  # Batched diagonals, one shared right-hand side.
  ans = vmap(solve, in_axes=(1, 1, 1, None))(dl, d, du, b[..., 0])
  expected = np.stack(
      [solve(dl[:, i], d[:, i], du[:, i], b[..., 0]) for i in range(10)])
  self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
|
|
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
        slice_sizes),
     "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
     "slice_sizes": slice_sizes}
    for dtype in [np.float32, np.int32]
    for axis, shape, idxs, dnums, slice_sizes in [
        (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1,)),
        (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
         (2,)),
        (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1, 3)),
        (2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,),
             start_index_map=(0, 1)),
         (1, 3)),
    ])
def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes):
  make_operand = jtu.rand_default(self.rng())
  gather = partial(lax.gather, dimension_numbers=dnums,
                   slice_sizes=slice_sizes)
  operand = make_operand(shape, dtype)

  # Operand mapped along `axis`; the indices are shared by every mapped call.
  batched = vmap(gather, (axis, None))(operand, idxs)

  # Reference: gather from each slice of the operand along `axis` separately.
  lead = (slice(None),) * axis
  per_slice = [gather(operand[lead + (i,)], idxs)
               for i in range(operand.shape[axis])]
  self.assertAllClose(batched, np.stack(per_slice), check_dtypes=False)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
        slice_sizes),
     "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
     "slice_sizes": slice_sizes}
    for dtype in [np.float32, np.float64]
    for axis, shape, idxs, dnums, slice_sizes in [
        (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1,)),
        (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
         (2,)),
        (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1, 3)),
        (2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,),
             start_index_map=(0, 1)),
         (1, 3))
    ])
def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes):
  make_operand = jtu.rand_default(self.rng())
  gather = partial(lax.gather, dimension_numbers=dnums,
                   slice_sizes=slice_sizes)

  # Gradient, w.r.t. the operand, of sum(sin(gather(operand, indices))).
  def sin_sum(x, idx):
    return jnp.sum(jnp.sin(gather(x, idx)))

  gfun = grad(sin_sum)
  operand = make_operand(shape, dtype)

  # Operand mapped along `axis`; the indices are shared by every mapped call.
  batched = vmap(gfun, (axis, None))(operand, idxs)

  # Reference: differentiate each slice of the operand separately.
  lead = (slice(None),) * axis
  per_slice = [gfun(operand[lead + (i,)], idxs)
               for i in range(operand.shape[axis])]
  self.assertAllClose(batched, np.stack(per_slice), check_dtypes=False)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
        slice_sizes),
     "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
     "slice_sizes": slice_sizes}
    for dtype in [np.float32, np.int32]
    for axis, shape, idxs, dnums, slice_sizes in [
        (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
        (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
        (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
        (0, (10, 5), np.array([[[0, 1], [2, 0]],
                               [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
    ])
def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes):
  make_operand = jtu.rand_default(self.rng())
  gather = partial(lax.gather, dimension_numbers=dnums,
                   slice_sizes=slice_sizes)
  operand = make_operand(shape, dtype)

  # Indices mapped along `axis`; the operand is shared by every mapped call.
  batched = vmap(gather, (None, axis))(operand, idxs)

  # Reference: gather once per slice of the indices along `axis`.
  lead = (slice(None),) * axis
  per_slice = [gather(operand, idxs[lead + (i,)])
               for i in range(idxs.shape[axis])]
  self.assertAllClose(batched, np.stack(per_slice), check_dtypes=False)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
        slice_sizes),
     "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
     "slice_sizes": slice_sizes}
    for dtype in [np.float32, np.float64]
    for axis, shape, idxs, dnums, slice_sizes in [
        (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
        (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
        (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
        (0, (10, 5), np.array([[[0, 1], [2, 0]],
                               [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
    ])
def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes):
  make_operand = jtu.rand_default(self.rng())
  gather = partial(lax.gather, dimension_numbers=dnums,
                   slice_sizes=slice_sizes)

  # Gradient, w.r.t. the operand, of sum(sin(gather(operand, indices))).
  def sin_sum(x, idx):
    return jnp.sum(jnp.sin(gather(x, idx)))

  gfun = grad(sin_sum)
  operand = make_operand(shape, dtype)

  # Indices mapped along `axis`; the operand is shared by every mapped call.
  batched = vmap(gfun, (None, axis))(operand, idxs)

  # Reference: differentiate once per slice of the indices along `axis`.
  lead = (slice(None),) * axis
  per_slice = [gfun(operand, idxs[lead + (i,)])
               for i in range(idxs.shape[axis])]
  self.assertAllClose(batched, np.stack(per_slice), check_dtypes=False)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
        dnums, slice_sizes),
     "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
     dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes}
    for dtype in [np.float32, np.int32]
    for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
        (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
         lax.GatherDimensionNumbers(
             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1,)),
        (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
         (2,)),
        (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1, 3)),
        (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]],
                                     [[1, 0], [2, 0]]]),
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
         (1, 3)),
    ])
def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes):
  rng = jtu.rand_default(self.rng())
  fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
  operand = rng(shape, dtype)
  # Sanity-check the test parameters: both mapped axes must have the same
  # length. A unittest assertion (not a bare `assert`) keeps the check active
  # under `python -O` and reports both sizes on failure.
  self.assertEqual(operand.shape[op_axis], idxs.shape[idxs_axis])

  # Operand mapped along `op_axis` and indices along `idxs_axis` together.
  ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)

  # Reference: gather from matching operand / index slices one pair at a time.
  op_lead = (slice(None),) * op_axis
  idxs_lead = (slice(None),) * idxs_axis
  expected = np.stack([fun(operand[op_lead + (i,)], idxs[idxs_lead + (i,)])
                       for i in range(idxs.shape[idxs_axis])])
  self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
        dnums, slice_sizes),
     "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape,
     "dtype": dtype, "idxs": idxs, "dnums": dnums,
     "slice_sizes": slice_sizes}
    for dtype in [np.float32]
    for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
        (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
         lax.GatherDimensionNumbers(
             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1,)),
        (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
         (2,)),
        (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1, 3)),
        (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]],
                                     [[1, 0], [2, 0]]]),
         lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
         (1, 3)),
    ])
def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
                              slice_sizes):
  """Checks the gradient of a gather-based loss with both arguments mapped.

  The loss is `sum(sin(gather(x, idx)))`; its vmapped gradient is compared
  against stacking the gradient computed for each operand/indices slice.
  """
  gather = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)

  def loss(x, idx):
    return jnp.sum(jnp.sin(gather(x, idx)))

  loss_grad = grad(loss)
  operand = jtu.rand_default(self.rng())(shape, dtype)

  # Sanity check on the parameterization: both mapped axes must agree in size.
  batch_size = idxs.shape[idxs_axis]
  assert operand.shape[op_axis] == batch_size

  def take(arr, axis, i):
    # Select index `i` along `axis`, keeping all leading axes whole.
    return arr[(slice(None),) * axis + (i,)]

  batched = vmap(loss_grad, (op_axis, idxs_axis))(operand, idxs)
  per_example = [loss_grad(take(operand, op_axis, i), take(idxs, idxs_axis, i))
                 for i in range(batch_size)]
  self.assertAllClose(batched, np.stack(per_example), check_dtypes=False)
|
|
|
|
def testNumpyIndexing1(self):
  """Checks NumPy-style array indexing when only the index array is mapped."""
  a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
  ind = np.array([[0, 1],
                  [2, 0]])

  def f(a, ind):
    return a[:, ind]

  # Reference: apply `f` to each row of `ind` and stack the results.
  expected = np.stack([f(a, row) for row in ind])
  ans = vmap(f, (None, 0))(a, ind)
  # Use the test-case assertion rather than a bare `assert`: it is not stripped
  # under `python -O`, also checks shapes, and reports the mismatch on failure.
  self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
def testNumpyIndexing2(self):
  """Checks NumPy-style indexing with an index array created inside `f`."""
  a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))

  def f(a):
    inds = jnp.array([0, 2])
    return a[:, inds]

  ans = vmap(f)(a)
  # Reference: apply `f` to each slice along axis 1 and stack along axis 1.
  expected = np.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
  # Use the test-case assertion rather than a bare `assert`: it is not stripped
  # under `python -O`, also checks shapes, and reports the mismatch on failure.
  self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
def testTranspose(self):
  """Checks that `.T` inside vmap transposes only the per-example axes."""
  batch = np.arange(4 * 3 * 3).reshape((4, 3, 3))

  def add_transpose(mat):
    return mat + mat.T

  batched = vmap(add_transpose)(batch)
  # For a stack of matrices, the per-example transpose swaps the last two axes.
  reference = batch + np.swapaxes(batch, -1, -2)
  self.assertAllClose(batched, reference, check_dtypes=False)
|
|
|
|
def testTransposePermutation(self):
  """Checks `jnp.transpose` with explicit permutations under vmap."""
  # Each case: (input shape, per-example permutation, mapped input axis,
  #             equivalent permutation of the full, unbatched array).
  cases = [
      ((6, 3, 4, 5), (1, 0, 2), 0, (0, 2, 1, 3)),
      ((6, 3, 4, 5), (1, 2, 0), 0, (0, 2, 3, 1)),
      ((3, 4, 6, 5), (1, 2, 0), 2, (2, 1, 3, 0)),
  ]
  for shape, perm, mapped_axis, full_perm in cases:
    x = np.arange(6 * 3 * 4 * 5).reshape(shape)
    ans = vmap(lambda x: jnp.transpose(x, perm), in_axes=mapped_axis)(x)
    expected = np.transpose(x, full_perm)
    self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
def testIssue354(self):
  """Differentiates through a vmapped Cholesky-based function without crashing."""
  # Build a positive semi-definite matrix as factor^T @ factor.
  factor = self.rng().randn(20, 10)
  psd_mat = factor.T.dot(factor)
  vec = self.rng().randn(10)

  def f(scale):
    chol = jnp.linalg.cholesky(scale[jnp.newaxis] * psd_mat)
    projected = jnp.einsum('ij,j->i', chol, vec)
    return -0.5 * jnp.sum(projected**2)

  batched_f = vmap(f)

  def total(x):
    return jnp.sum(batched_f(x))

  scales = np.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
  ans = grad(total)(scales)  # don't crash!
  # Reference: gradient of `f` evaluated separately for every scale.
  expected = np.stack([grad(f)(scale) for scale in scales])
  self.assertAllClose(ans, expected, check_dtypes=False,
                      rtol=jtu.default_gradient_tolerance)
|
|
|
|
@jax.legacy_prng_key('allow')
def testIssue489(self):
  """Regression test: vmap of a `while_loop` carrying a PRNG key must not crash."""
  # https://github.com/jax-ml/jax/issues/489
  def f(key):
    def keep_sampling(carry):
      return carry[0] > 0.5

    def sample_once(carry):
      _, key = carry
      u = random.uniform(key, ())
      next_key, _ = random.split(key)
      return u, next_key

    u, _ = lax.while_loop(keep_sampling, sample_once, (1., key))
    return u

  # The loop body draws from and then splits the same key, so key-reuse
  # checking is switched off for this call.
  with jax.debug_key_reuse(False):
    keys = random.split(random.PRNGKey(0), 2)
    print(vmap(f)(keys))  # no crash
|
|
|
|
def testEmptyTuples(self):
  """Checks that empty tuples in vmap inputs or outputs do not cause a crash."""
  inputs = np.array([0, 1])
  incremented = np.array([1, 2])

  # Empty tuple passed as one of the vectorized inputs.
  out = vmap(lambda x, _: x + 1)(inputs, ())
  self.assertAllClose(out, incremented, check_dtypes=False)

  # Empty tuple returned as one of the vectorized outputs.
  out, empty_tuple = vmap(lambda x: (x + 1, ()))(inputs)
  self.assertAllClose(out, incremented, check_dtypes=False)
  self.assertEqual((), empty_tuple)
|
|
|
|
def testIndexAddBatchedIndexesOnly(self):
  """Checks `.at[idx].add` under vmap when only the index is mapped."""
  def add_at(x, idx, y):
    return jnp.asarray(x).at[idx].add(y)

  base = np.zeros((10,))
  positions = np.arange(10,)
  result = vmap(add_at, (None, 0, None))(base, positions, 1.)
  # Adding 1 at position i of a zero vector, for each i, yields the identity.
  self.assertAllClose(result, np.eye(10), check_dtypes=False)
|
|
|
|
def testIssue1170(self):
  """Checks that a jitted pmap gives the same result for keyword and positional args."""
  def lookup(index1, index2):
    table = jnp.arange(36).reshape(6, 6)
    return table[index1, index2]

  jitted = jax.jit(jax.pmap(lookup))
  rows = np.asarray([1])
  cols = np.asarray([2])
  by_keyword = jitted(index1=rows, index2=cols)
  by_position = jitted(rows, cols)
  self.assertAllClose(by_keyword, by_position)
|
|
|
|
def testIssue3883(self):
  """Checks vmap of `dynamic_slice` / `dynamic_update_slice` on scalars."""
  def scalar_f(x):
    # Slicing a rank-0 value: no start indices and no slice sizes.
    return lax.dynamic_slice(x, [], [])

  def scalar_f2(x):
    # Updating a rank-0 value: no start indices.
    return lax.dynamic_update_slice(x, 7, [])

  for fn in (scalar_f, scalar_f2):
    xs = jnp.array([1, 2, 3, 4])
    batched = vmap(fn)(xs)
    elementwise = jnp.array([fn(x) for x in xs])
    self.assertAllClose(batched, elementwise)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_{}_vmap_names={}_collective_names={}".format(
        collective.__name__.replace(" ", ""),
        "".join(vmap_names), "".join(collective_names)),
     "collective": collective, "bulk_op": bulk_op, "vmap_names": vmap_names,
     "collective_names": collective_names}
    for collective, bulk_op in [(lax.psum, jnp.sum),
                                (lax.pmax, jnp.max),
                                (lax.pmin, jnp.min)]
    for vmap_names in [('i',), ('i', 'j'), ('i', 'j', 'k')]
    for subset_size in range(1, len(vmap_names) + 1)
    for collective_subset in it.combinations(vmap_names, subset_size)
    for collective_names in it.permutations(collective_subset))
def testCommAssocCollective(self, collective, bulk_op, vmap_names, collective_names):
  """Checks psum/pmax/pmin over named vmap axes against the bulk reduction."""
  shape = (2, 2, 2)
  x = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)

  # The vmaps are always applied in the order given by `vmap_names`, while the
  # collective receives its axis names in the order of `collective_names`, so
  # the parameterization covers relative permutations of the two orders.
  def subtract_collective(v):
    return v - collective(v, collective_names)

  batched = subtract_collective
  # Non-zero in/out axes are used to improve coverage.
  for depth, axis_name in enumerate(vmap_names):
    batched = vmap(batched, axis_name=axis_name, in_axes=depth, out_axes=depth)

  # Positions of the array axes that the collective reduces over.
  reduced_axes = [pos for pos, name in enumerate(vmap_names)
                  if name in collective_names]
  reference = x - bulk_op(x, axis=reduced_axes, keepdims=True)
  self.assertAllClose(batched(x), reference)

  # Gradients are only checked for psum.
  if collective is lax.psum:
    jtu.check_grads(batched, (x,), 2, eps=1)
|
|
|
|
def testPPermute(self):
  """Checks `lax.ppermute` over a vmap axis for several random permutations."""
  nelem = 10
  ntests = 10
  x = np.arange(nelem)
  rng = self.rng()
  for _ in range(ntests):
    perm = np.arange(nelem)
    rng.shuffle(perm)
    # (source, destination) pairs, themselves listed in shuffled order.
    perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
    rng.shuffle(perm_pairs)

    def subtract_permuted(v):
      return v - lax.ppermute(v, 'i', perm_pairs)

    batched = vmap(subtract_permuted, axis_name='i')(x)
    self.assertAllClose(batched, x - x[np.argsort(perm)])
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
     "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
    for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
def testAllToAll(self, vmap_axis, split_axis, concat_axis):
  """Checks `lax.all_to_all` over a vmap axis against an equivalent `moveaxis`."""
  shape = (4, 4, 4, 4)
  x = np.arange(np.prod(shape)).reshape(shape)

  def exchange(v):
    return lax.all_to_all(v, 'i', split_axis, concat_axis)

  y = vmap(exchange, in_axes=vmap_axis, axis_name='i')(x)

  # `split_axis` indexes the per-example array; shift it by one when the mapped
  # axis sits at or before it in the full array.
  if vmap_axis <= split_axis:
    full_split_axis = split_axis + 1
  else:
    full_split_axis = split_axis
  ref = jnp.moveaxis(x, (vmap_axis, full_split_axis), (concat_axis + 1, 0))
  self.assertAllClose(y, ref)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
     "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
    for split_axis, concat_axis, vmap_axis in it.product(range(2), range(2), range(3)))
def testAllToAllSplitAxis(self, vmap_axis, split_axis, concat_axis):
  """Checks `lax.all_to_all` over a pair of nested vmap axes `('i', 'j')`.

  The mapped axis of size 4 is reshaped into two axes of size 2, each consumed
  by one vmap; the result is reshaped back and compared with the `moveaxis`
  reference used for the single-axis case.
  """
  shape = (4, 4, 4)
  x = np.arange(np.prod(shape)).reshape(shape)

  @partial(vmap, in_axes=vmap_axis, axis_name='i')
  @partial(vmap, in_axes=vmap_axis, axis_name='j')
  def f(x):
    return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)

  # Replace the mapped axis (size 4) by two adjacent axes of size 2.
  unroll_shape = list(shape)
  unroll_shape[vmap_axis:vmap_axis+1] = (2, 2)
  x_unroll = x.reshape(unroll_shape)
  y_unrolled = f(x_unroll)
  y = y_unrolled.reshape(shape)

  # `split_axis` indexes the per-example array; shift it by one when the mapped
  # axis sits at or before it in the full array. A separate name is used so
  # the value closed over by `f` is never rebound.
  ref_split_axis = split_axis + 1 if vmap_axis <= split_axis else split_axis
  ref = jnp.moveaxis(x, (vmap_axis, ref_split_axis),
                     (concat_axis + 1, 0))
  self.assertAllClose(y, ref)
|
|
|
|
def testNegativeAxes(self):
  """Checks negative `in_axes` / `out_axes` values and their out-of-range errors."""
  x = np.arange(3*4*5).reshape(3, 4, 5)

  # Negative in_axes count from the end of the argument's shape.
  self.assertAllClose(jax.vmap(jnp.sum, in_axes=-3)(x),
                      jnp.sum(x, axis=(1, 2)))
  self.assertAllClose(jax.vmap(jnp.sum, in_axes=-2)(x),
                      jnp.sum(x, axis=(0, 2)))
  self.assertAllClose(jax.vmap(jnp.sum, in_axes=-1)(x),
                      jnp.sum(x, axis=(0, 1)))

  error = (r"vmap was requested to map its argument along axis -4, which "
           r"implies that its rank should be at least 4, but is only 3 "
           r"\(its shape is \(3, 4, 5\)\)")
  with self.assertRaisesRegex(ValueError, error):
    jax.vmap(jnp.sum, in_axes=-4)(x)

  # Named `identity` (not `id`) so the builtin is not shadowed.
  def identity(y):
    return y

  # Negative out_axes count from the end of the output's shape.
  self.assertAllClose(x, jax.vmap(identity, in_axes=0, out_axes=-3)(x))
  self.assertAllClose(x.transpose(1, 0, 2),
                      jax.vmap(identity, in_axes=0, out_axes=-2)(x))
  self.assertAllClose(x.transpose(1, 2, 0),
                      jax.vmap(identity, in_axes=0, out_axes=-1)(x))

  with self.assertRaisesRegex(ValueError, "axis -4 is out of bounds.*"):
    jax.vmap(identity, in_axes=0, out_axes=-4)(x)

  # An unmapped scalar input returned with out_axes=-1 comes back as a vector
  # of length 5 (the mapped axis size).
  self.assertAllClose(
      np.full((5,), 7),
      jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -1))(
          np.arange(5), 7)[1])

  with self.assertRaisesRegex(ValueError, "axis -2 is out of bounds.*"):
    jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -2))(
        np.arange(5), 7)
|
|
|
|
def testAxisIndex(self):
  """Checks that `lax.axis_index` under vmap yields each element's position."""
  x = np.arange(10, dtype='int32')

  def subtract_position(v):
    return v - lax.axis_index('i')

  batched = vmap(subtract_position, axis_name='i')(x)
  reference = x - np.arange(x.shape[0], dtype='int32')
  self.assertAllClose(batched, reference)
|
|
|
|
def testVmapKwargs(self):
  """Checks that vmap gives the same result for keyword and positional args."""
  # https://github.com/jax-ml/jax/issues/912

  def f(a, b):
    return (2*a, 3*b)

  batched = vmap(f)
  positional = batched(jnp.array([1]), jnp.array([2]))
  # Keyword arguments are the case reported in the issue linked above.
  keyword = batched(a=jnp.array([1]), b=jnp.array([2]))
  self.assertAllClose(positional, keyword)
|
|
|
|
def testGradOfPsum(self):
  """Checks that grad-of-psum under vmap agrees with evaluating its own jaxpr."""
  a = jnp.ones(5)

  def neg_psum(x):
    return -lax.psum(x, 'i')

  f = vmap(jax.grad(neg_psum), out_axes=None, axis_name='i')
  direct = f(a)
  # Trace `f` to a jaxpr and evaluate that jaxpr on the same input.
  replayed = core.jaxpr_as_fun(jax.make_jaxpr(f)(a))(a)[0]
  self.assertEqual(direct, replayed)
|
|
|
|
def testAllGatherToUnmapped(self):
  """Checks `lax.all_gather` returned as an unmapped (`out_axes=None`) output."""
  def gather_all(v):
    return lax.all_gather(v, axis_name='i')

  x = jnp.arange(15).reshape((3, 5))
  gathered = vmap(gather_all, axis_name='i', in_axes=1, out_axes=None)(x)
  # Original mapped axis becomes first axis of unmapped return value.
  self.assertAllClose(gathered, x.T)
|
|
|
|
def testBatchedAllGather(self):
  """Checks `lax.all_gather` over axis 'i' when nested with a second vmap 'j'."""
  def gather_i(v):
    return lax.all_gather(v, axis_name='i')

  x = jnp.arange(15).reshape((3, 5))

  # 'i' is the inner vmap; the gathered value is unmapped w.r.t. 'i'.
  inner_gather = vmap(vmap(gather_i, axis_name='i', out_axes=None), axis_name='j')
  self.assertAllClose(inner_gather(x), x)

  # 'i' is the outer vmap; the gathered value is unmapped w.r.t. 'i'.
  outer_gather = vmap(vmap(gather_i, axis_name='j'), axis_name='i', out_axes=None)
  self.assertAllClose(outer_gather(x), x.T)
|
|
|
|
def testAllGatherTiled(self):
  """Checks `lax.all_gather(..., tiled=True)` over a vmap axis."""
  def gather_tiled(v):
    return lax.all_gather(v, axis_name='i', tiled=True)

  x = jnp.arange(60).reshape((4, 3, 5))
  res = vmap(gather_tiled, axis_name='i', in_axes=(1,), out_axes=None)(x)
  reference = x.transpose((1, 0, 2)).reshape(-1, 5)
  self.assertAllClose(res, reference)
|
|
|
|
def testBatchedAllGatherTiled(self):
  """Checks tiled `lax.all_gather` over 'i' with an extra, unnamed inner vmap."""
  def gather_tiled(v):
    return lax.all_gather(v, axis_name='i', tiled=True)

  x = jnp.arange(60).reshape((4, 3, 5))
  inner = vmap(gather_tiled, in_axes=1, out_axes=1)
  res = vmap(inner, axis_name='i', in_axes=1, out_axes=None)(x)
  reference = x.transpose((1, 0, 2)).reshape(-1, 5)
  self.assertAllClose(res, reference)
|
|
|
|
def testAllGatherVjp(self):
  """Checks the VJP of `lax.all_gather` under vmap."""
  def gather_all(v):
    return lax.all_gather(v, axis_name='i')

  rng = self.rng()
  x = rng.randn(3, 4)
  y_bar = rng.randn(3, 3, 4)

  def pullback(primal, cotangent):
    _, vjp_fn = vjp(gather_all, primal)
    return vjp_fn(cotangent)

  x_bar, = vmap(pullback, axis_name='i')(x, y_bar)
  # The expected cotangent is the output cotangents summed over axis 0.
  self.assertAllClose(x_bar, np.sum(y_bar, axis=0))
|
|
|
|
def testAllGatherOfConst(self):
  """Checks `lax.all_gather` applied to values that do not depend on the mapped axis data."""
  def gather_constants(v):
    from_ones_like = lax.all_gather(jnp.ones_like(v), axis_name='i')
    from_scalar = lax.all_gather(1, axis_name='i')
    return from_ones_like, from_scalar

  x = jnp.arange(15).reshape((3, 5))
  a, b = vmap(gather_constants, axis_name='i', in_axes=1, out_axes=None)(x)
  self.assertAllClose(a, jnp.ones(shape=(5, 3), dtype=x.dtype))
  self.assertAllClose(b, jnp.ones(shape=(5,), dtype=b.dtype))
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_axis={}_collective={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        axis, collective.__name__.replace(" ", "")),
     "shape": shape, "dtype": dtype, "axis": axis,
     "collective": collective, "bulk_op": bulk_op}
    for collective, bulk_op in [(parallel.pargmax, jnp.argmax),
                                (parallel.pargmin, jnp.argmin)]
    for dtype in [np.float32, np.int32]
    for shape in [(7,), (5, 8)]
    for axis in range(len(shape))
)
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
  """Checks pargmax/pargmin over a vmap axis against jnp.argmax/argmin."""
  x = jtu.rand_default(self.rng())(shape, dtype)

  def reduce_over_i(v):
    return collective(v, 'i')

  batched = vmap(reduce_over_i, in_axes=axis, out_axes=None, axis_name='i')(x)
  reference = bulk_op(x, axis=axis)
  self.assertAllClose(batched, reference, check_dtypes=False)
|
|
|
|
def testReduceScatterAutodiff(self):
  """Checks first- and second-order gradients of a vmapped `lax.psum_scatter`."""
  scatter = partial(lax.psum_scatter, axis_name='i')
  batched = vmap(scatter, axis_name='i')
  x = self.rng().randn(3, 3, 4)
  jtu.check_grads(batched, (x,), 2, modes=["fwd", "rev"],
                  atol=1e-2, rtol=1e-2, eps=1.)
|
|
|
|
def testNonJaxTypedOutput(self):
  """Checks that a batched function returning a non-JAX value raises TypeError."""
  def returns_string(_):
    return "hello"

  expected_message = "Output from batched function.*is not a valid JAX type"
  with self.assertRaisesRegex(TypeError, expected_message):
    vmap(returns_string)(np.arange(5))
|
|
|
|
def testIssue6096(self):
  """Checks output shapes of `jsp.special.betainc` with and without vmap."""
  def f(x):
    return jsp.special.betainc(jnp.ones(3), 1., x)

  unbatched = f(jnp.ones(3))
  self.assertEqual(unbatched.shape, (3,))

  batched = jax.vmap(f)(jnp.ones((2, 3)))
  self.assertEqual(batched.shape, (2, 3))
|
|
|
|
def testPpermuteBatcherTrivial(self):
  """Regression test: vmap of the gradient of `ppermute` must not crash."""
  # https://github.com/jax-ml/jax/issues/8688
  # The parameter is named `value` (not `input`) so the builtin is not shadowed.
  def ppermute(value):
    return jax.lax.ppermute(value, axis_name="i", perm=[[0, 1], [1, 0]])

  grad_fn = jax.grad(ppermute)

  vmapped_gradients_fn = jax.vmap(grad_fn, axis_name="i")

  vector = jax.numpy.array([1., 2.])
  ans = vmapped_gradients_fn(vector)  # doesn't crash
  self.assertAllClose(ans, jnp.ones(2), check_dtypes=False)
|
|
|
|
def testBatchingPreservesWeakType(self):
  """Checks that weak typing survives both into and out of a vmapped function."""
  # Regression test for https://github.com/jax-ml/jax/issues/10025
  weak_input = jnp.ravel(1)
  self.assertTrue(dtypes.is_weakly_typed(weak_input))

  @vmap
  def passthrough(elt):
    # The value seen inside the batched function must still be weakly typed.
    self.assertTrue(dtypes.is_weakly_typed(elt), f"{elt} is not weakly-typed")
    return elt

  weak_output = passthrough(weak_input)
  self.assertTrue(dtypes.is_weakly_typed(weak_output))
|
|
|
|
|
|
# Loose type aliases used only in the annotations of the vmappable helpers
# below; `Array` and `ArrayElt` are deliberately unconstrained.
Array = Any
ArrayElt = Any
# An axis size that is either a concrete int or a traced value.
Int = Union[int, core.Tracer]
|
|
|
|
# Can't use NamedTuple here b/c those are pytrees
class NamedArray:
  """An array paired with one name per dimension.

  Attributes are documented inline; the constructor asserts that the number of
  names equals `data.ndim`.
  """
  # One name for each axis of `data`, in axis order.
  names: list[str]
  # The wrapped array-like value; must expose `.ndim`.
  data: Array

  def __init__(self, names, data):
    n_names = len(names)
    assert n_names == data.ndim
    self.data = data
    self.names = names

  def __repr__(self) -> str:
    return f'NamedArray(names={self.names}, data={self.data})'
|
|
|
|
class NamedMapSpec:
  """Axis specification for mapping over a `NamedArray`.

  `name` and `axis` must either both be given or both be `None`; the
  constructor asserts this.
  """
  # Name of the mapped dimension, or None when nothing is mapped.
  name: str | None
  # Position of the mapped dimension, or None when nothing is mapped.
  axis: int | None

  def __init__(self, name: str | None, axis: int | None):
    # `name` is annotated as optional because `None` is explicitly accepted
    # here (together with `axis=None`).
    assert (name is None) == (axis is None)
    self.name = name
    self.axis = axis
|
|
|
|
def named_mul(x: NamedArray, y: NamedArray) -> NamedArray:
  """Multiplies two `NamedArray`s elementwise with `lax.mul`.

  Args:
    x: left operand.
    y: right operand; must carry exactly the same names as `x`.

  Returns:
    A `NamedArray` with the names of `x` and data `lax.mul(x.data, y.data)`.

  Raises:
    ValueError: if the two operands' names differ.
  """
  if x.names != y.names:
    raise ValueError(
        f"named_mul requires matching names, got {x.names} and {y.names}")
  return NamedArray(x.names, lax.mul(x.data, y.data))
|
|
|
|
# TODO(mattjj): don't make this a pytree
# Flattening exposes `data` as the only child and keeps `names` as auxiliary
# data; unflattening rebuilds the NamedArray from those two pieces.
register_pytree_node(
    NamedArray,
    lambda x: ((x.data,), x.names),
    lambda names, xs: NamedArray(names, xs[0]),
)
|
|
|
|
|
|
def named_to_elt(cont: Callable[[Array, int | None], ArrayElt],
                 _: Int, val: NamedArray, spec: NamedMapSpec) -> NamedArray:
  """Converts a `NamedArray` into its per-element form for a given spec.

  Args:
    cont: continuation called as `cont(val.data, spec.axis)`; its result
      becomes the data of the returned `NamedArray`.
    _: unused.
    val: the value to convert.
    spec: which named axis is mapped; `spec.name is None` means `val` is not
      mapped.

  Returns:
    `val` itself when `spec.name` is None; otherwise a `NamedArray` whose names
    are those of `val` with the entry at `spec.axis` removed.

  Raises:
    ValueError: if the name found at `spec.axis` is not `spec.name`.
  """
  if spec.name is None:
    return val

  elt_names, mapped_name = list_pop(val.names, spec.axis)
  if mapped_name != spec.name:
    raise ValueError(
        f"axis {spec.axis} of {val.names} is named {mapped_name!r}, "
        f"but the spec expects {spec.name!r}")
  elt = cont(val.data, spec.axis)
  return NamedArray(elt_names, elt)
|
|
|
|
def named_from_elt(cont: Callable[[int, ArrayElt, int | None], Array],
                   axis_size: int, elt: NamedArray, annotation: NamedMapSpec
                   ) -> NamedArray:
  """Rebuilds a `NamedArray` from its per-element form.

  `cont` is called as `cont(axis_size, elt.data, annotation.axis)` and its
  result becomes the data of the returned value. When `annotation.axis` is not
  None, `annotation.name` is inserted into the names at that position;
  otherwise the names of `elt` are reused unchanged.
  """
  out_axis = annotation.axis
  data = cont(axis_size, elt.data, out_axis)
  if out_axis is not None:
    out_names = list_insert(elt.names, out_axis, annotation.name)
  else:
    out_names = elt.names
  return NamedArray(out_names, data)
|
|
|
|
@contextmanager
def temporarily_register_named_array_vmappable():
  """Registers `NamedArray` as vmappable for the duration of a `with` block.

  The registration uses `NamedMapSpec` as the spec type and `named_to_elt` /
  `named_from_elt` as converters, and is removed on exit even if the body
  raises.
  """
  registration = (NamedArray, NamedMapSpec, int,
                  named_to_elt, named_from_elt, None)
  batching.register_vmappable(*registration)
  try:
    yield
  finally:
    batching.unregister_vmappable(NamedArray)
|
|
|
|
# Generic element type for the list helpers below.
a = TypeVar('a')
|
|
|
|
def list_pop(lst: list[a], idx: int) -> tuple[list[a], a]:
  """Removes the element at `idx` without mutating the input.

  Args:
    lst: the source sequence; it is copied, never modified.
    idx: index of the element to remove.

  Returns:
    A pair `(remaining, popped)`: a new list without the element at `idx`, and
    that element.

  Raises:
    IndexError: if `idx` is out of range.
  """
  lst = list(lst)
  # The call arguments are evaluated left to right, but `lst` is the same
  # list object that `pop` mutates, so the first item already reflects the
  # removal.
  return lst, lst.pop(idx)
|
|
|
|
def list_insert(lst: list[a], idx: int, val: a) -> list[a]:
  """Returns a new list equal to `lst` with `val` inserted before index `idx`.

  The input is copied first, so the caller's sequence is left untouched.
  """
  items = list(lst)
  # Splicing around `idx` matches `list.insert`, including its clamping of
  # out-of-range and negative indices.
  return [*items[:idx], val, *items[idx:]]
|
|
|
|
|
|
@jtu.thread_unsafe_test_class()  # temporary registration isn't thread-safe
class VmappableTest(jtu.JaxTestCase):
  """Tests for user-defined vmappable types, using `NamedArray` as the example."""

  def test_basic(self):
    """vmap over a NamedArray with NamedMapSpec in/out axes."""
    with temporarily_register_named_array_vmappable():
      def f(x):
        return named_mul(x, x)

      x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
      g = jax.vmap(f,
                   in_axes=NamedMapSpec('i', 0),
                   out_axes=NamedMapSpec('i', 1),
                   axis_size=3)
      ans = g(x)
      # 'i' is mapped in at axis 0 and out at axis 1, so names and data come
      # back transposed.
      expected = NamedArray(['j', 'i'], jnp.arange(12.).reshape(3, 4).T ** 2)

      self.assertEqual(ans.names, expected.names)
      self.assertAllClose(ans.data, expected.data)

  def test_basic_jit(self):
    """jit of a function over NamedArray while it is registered as vmappable."""
    with temporarily_register_named_array_vmappable():
      def f(x):
        return named_mul(x, x)

      x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
      ans = jax.jit(f)(x)
      expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)

      self.assertEqual(ans.names, expected.names)
      self.assertAllClose(ans.data, expected.data)

  def test_types_with_same_spec(self):
    """Two vmappable types may share a spec type and be unregistered separately."""
    # We register NamedArray. Going through the shared context manager ensures
    # it is unregistered again even when an assertion below fails, so a failure
    # here cannot leak the registration into other tests.
    with temporarily_register_named_array_vmappable():
      # We then register another type that uses NamedMapSpec as the spec_type
      # too, and immediately unregister it.
      class Foo:
        pass
      batching.register_vmappable(Foo, NamedMapSpec, int,
                                  named_to_elt, named_from_elt, None)
      batching.unregister_vmappable(Foo)

      # We should still be able to run a function over NamedArray.
      # NOTE(review): this goes through jax.jit only; consider also exercising
      # jax.vmap with a NamedMapSpec here, as test_basic does.
      def f(x):
        return named_mul(x, x)

      x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
      ans = jax.jit(f)(x)
      expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)

      self.assertEqual(ans.names, expected.names)
      self.assertAllClose(ans.data, expected.data)
    # Leaving the `with` block unregisters NamedArray; any exception raised by
    # that second unregistration propagates and fails the test.
|
|
|
|
# Run the tests with JAX's test loader when this file is executed as a script.
if __name__ == '__main__':
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|