2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2018-11-17 18:03:33 -08:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2024-06-26 14:44:52 -04:00
|
|
|
from collections.abc import Callable
|
2021-10-06 14:18:07 -07:00
|
|
|
from contextlib import contextmanager
|
2021-09-13 17:24:44 -04:00
|
|
|
from functools import partial
|
2020-09-22 11:19:06 +00:00
|
|
|
import itertools as it
|
2024-06-26 14:44:52 -04:00
|
|
|
from typing import Any, TypeVar, Union
|
2021-10-06 14:18:07 -07:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
import numpy as np
|
2018-11-17 18:03:33 -08:00
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
2020-01-18 08:26:23 -05:00
|
|
|
import jax
|
2020-05-05 14:59:16 -04:00
|
|
|
import jax.numpy as jnp
|
2021-03-19 21:01:00 -07:00
|
|
|
import jax.scipy as jsp
|
2023-02-14 23:00:40 -08:00
|
|
|
from jax._src import core
|
2022-03-30 11:06:56 -07:00
|
|
|
from jax._src import dtypes
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src import test_util as jtu
|
2018-11-19 07:43:23 -08:00
|
|
|
from jax import lax
|
2021-02-08 20:24:19 -08:00
|
|
|
from jax._src.lax import parallel
|
2018-12-30 22:26:22 -08:00
|
|
|
from jax import random
|
2021-04-13 09:42:54 -07:00
|
|
|
from jax import jit, grad, jvp, vjp, make_jaxpr, jacfwd, jacrev, hessian
|
|
|
|
from jax import vmap
|
2021-10-06 14:18:07 -07:00
|
|
|
from jax.interpreters import batching
|
|
|
|
from jax.tree_util import register_pytree_node
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2024-04-11 13:23:27 -07:00
|
|
|
jax.config.parse_flags_with_absl()
|
2018-12-12 09:00:39 -08:00
|
|
|
|
2019-01-10 15:35:15 -08:00
|
|
|
|
2019-09-03 17:09:27 -07:00
|
|
|
# These are 'manual' tests for batching (vmap). The more exhaustive, more
|
|
|
|
# systematic tests are in lax_test.py's LaxVmapTest class.
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class BatchingTest(jtu.JaxTestCase):
|
|
|
|
|
2019-02-10 18:36:21 -08:00
|
|
|
def testConstantFunction(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = vmap(lambda x: 3)(np.ones(4))
|
|
|
|
expected = 3 * np.ones(4)
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2023-06-09 09:22:52 -07:00
|
|
|
@jax.default_matmul_precision("float32")
|
2019-02-10 18:36:21 -08:00
|
|
|
def testNestedBatchingMatMat(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
matvec = vmap(jnp.vdot, in_axes=(0, None))
|
2019-02-10 18:36:21 -08:00
|
|
|
matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)
|
|
|
|
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
A = R(4, 3)
|
|
|
|
B = R(3, 2)
|
|
|
|
|
|
|
|
ans = matmat(A, B)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.dot(A, B)
|
2023-06-09 09:22:52 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
jaxpr = make_jaxpr(matmat)(A, B)
|
2021-10-06 14:18:07 -07:00
|
|
|
self.assertLen(jaxpr.jaxpr.eqns, 1)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def testPerExampleGradients(self):
|
|
|
|
def predict(params, inputs):
|
|
|
|
for W, b in params:
|
2020-05-05 14:59:16 -04:00
|
|
|
outputs = jnp.dot(W, inputs) + b
|
|
|
|
inputs = jnp.tanh(outputs)
|
2019-02-10 18:36:21 -08:00
|
|
|
return outputs
|
|
|
|
|
|
|
|
def loss(params, data):
|
|
|
|
inputs, targets = data
|
|
|
|
predictions = predict(params, inputs)
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sum((predictions - targets)**2)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
batch_size = 5
|
|
|
|
layer_sizes = [3, 2, 4]
|
|
|
|
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
params = [(R(m, n), R(m))
|
|
|
|
for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]
|
|
|
|
|
|
|
|
input_batch = R(5, 3)
|
|
|
|
target_batch = R(5, 4)
|
|
|
|
batch = (input_batch, target_batch)
|
|
|
|
|
|
|
|
ans = vmap(partial(grad(loss), params))(batch)
|
|
|
|
|
|
|
|
for ans_pair, param_pair in zip(ans, params):
|
|
|
|
dW, db = ans_pair
|
|
|
|
W, b = param_pair
|
|
|
|
|
|
|
|
self.assertEqual(dW.shape, (batch_size,) + W.shape)
|
|
|
|
self.assertEqual(db.shape, (batch_size,) + b.shape)
|
|
|
|
|
2023-06-09 09:22:52 -07:00
|
|
|
@jax.default_matmul_precision("float32")
|
2019-02-10 18:36:21 -08:00
|
|
|
def testJacobians(self):
|
|
|
|
def jacbwd(f, x):
|
|
|
|
y, pullback = vjp(f, x)
|
2020-05-05 14:59:16 -04:00
|
|
|
std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
|
|
|
|
jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis)
|
|
|
|
return jac_flat.reshape(np.shape(y) + np.shape(x))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def jacfwd(f, x):
|
|
|
|
pushfwd = lambda v: jvp(f, (x,), (v,))
|
2020-05-05 14:59:16 -04:00
|
|
|
std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x))
|
2019-02-10 18:36:21 -08:00
|
|
|
y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
|
2020-05-05 14:59:16 -04:00
|
|
|
return jac_flat.reshape(np.shape(y) + np.shape(x))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
A = R(4, 3)
|
|
|
|
b = R(4)
|
2020-05-05 14:59:16 -04:00
|
|
|
f = lambda x: jnp.tanh(jnp.dot(A, x) + b)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
x = R(3)
|
2023-06-09 09:22:52 -07:00
|
|
|
self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def testBatchOfCompile(self):
|
|
|
|
side = []
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def f(x):
|
|
|
|
side.append(None)
|
|
|
|
return x + x
|
|
|
|
|
|
|
|
g = jit(vmap(f))
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(g(np.ones(2)), 2 * np.ones(2), check_dtypes=False)
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertEqual(len(side), 1)
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(g(2 * np.ones(2)), 4 * np.ones(2),
|
2019-02-10 18:36:21 -08:00
|
|
|
check_dtypes=False)
|
|
|
|
self.assertEqual(len(side), 1)
|
|
|
|
|
|
|
|
def testSliceLax(self):
|
|
|
|
fun = lambda x: lax.slice(x, (2,), (4,))
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
x = R(5, 10)
|
|
|
|
|
|
|
|
ans = vmap(fun)(x)
|
|
|
|
expected_ans = x[:, 2:4]
|
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
def testSliceNumpy(self):
|
|
|
|
fun = lambda x: x[:, 2]
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
x = R(10, 5, 3, 7)
|
|
|
|
|
|
|
|
ans = vmap(fun)(x)
|
|
|
|
expected_ans = x[:, :, 2]
|
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
def testRevLax(self):
|
|
|
|
fun = lambda x: lax.rev(x, [0])
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
x = R(2, 3)
|
|
|
|
|
|
|
|
ans = vmap(fun)(x)
|
|
|
|
expected_ans = x[:, ::-1]
|
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = vmap(fun, (1,), 1)(x)
|
|
|
|
expected_ans = x[::-1, :]
|
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
def testRevNumpy(self):
|
|
|
|
fun = lambda x: x[:, ::-1]
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
x = R(3, 2, 4)
|
|
|
|
|
|
|
|
ans = vmap(fun)(x)
|
|
|
|
expected_ans = x[:, :, ::-1]
|
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = vmap(fun, (1,), 1)(x)
|
|
|
|
expected_ans = x[:, :, ::-1]
|
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = vmap(fun, (2,), 2)(x)
|
|
|
|
expected_ans = x[:, ::-1, :]
|
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
def testNpMaximum(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
fun = lambda x: jnp.maximum(x, 0.0)
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
x = R(10, 5, 3, 7)
|
|
|
|
|
|
|
|
ans = vmap(fun)(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected_ans = np.maximum(x, 0.0)
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
def testNpGtrThan(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
x = R(10, 5, 3, 7)
|
|
|
|
|
|
|
|
ans = vmap(lambda x: x > 1.0)(x)
|
|
|
|
expected_ans = x > 1.0
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected_ans)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2023-06-09 09:22:52 -07:00
|
|
|
@jax.default_matmul_precision("float32")
|
2019-02-10 18:36:21 -08:00
|
|
|
def testNpMaximumPerExampleGrad(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
x = R(10, 5)
|
|
|
|
W = R(5, 5)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
fun = lambda W, x: jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
ans = vmap(partial(grad(fun), W))(x)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
W_t = jnp.transpose(W)
|
2019-02-10 18:36:21 -08:00
|
|
|
for i in range(10):
|
|
|
|
x_ex = x[i:i + 1]
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
expected_ans = 2.0 * jnp.dot(
|
|
|
|
jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex)
|
|
|
|
expected_ans = jnp.transpose(expected_ans)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2023-06-09 09:22:52 -07:00
|
|
|
self.assertAllClose(ans[i], expected_ans, check_dtypes=False)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2023-09-12 14:30:40 -07:00
|
|
|
# Replace the default TF32 with float32 in order to make it pass on A100
|
|
|
|
@jax.default_matmul_precision("float32")
|
2019-02-10 18:36:21 -08:00
|
|
|
def testDotGeneral(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
x = R(10, 3, 4, 5)
|
|
|
|
y = R(10, 3, 5, 6)
|
|
|
|
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
|
|
|
|
ans = vmap(fun)(x, y)
|
|
|
|
expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
x = R(3, 4, 10, 5)
|
|
|
|
y = R(3, 10, 5, 6)
|
|
|
|
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
|
|
|
|
ans = vmap(fun, in_axes=(2, 1))(x, y)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
x = R(3, 4, 5, 10)
|
|
|
|
y = R(3, 5, 6)
|
|
|
|
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
|
|
|
|
ans = vmap(fun, in_axes=(3, None))(x, y)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([fun(x[..., i], y) for i in range(10)])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
x = R(3, 4, 5)
|
|
|
|
y = R(3, 5, 10, 6)
|
|
|
|
fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
|
|
|
|
ans = vmap(fun, in_axes=(None, 2))(x, y)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([fun(x, y[..., i, :]) for i in range(10)])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2019-10-08 13:06:43 -07:00
|
|
|
x = R(4)
|
|
|
|
y = R(4, 10)
|
|
|
|
fun = lambda x, y: lax.dot_general(x, y, [((0,), (0,)), ((), ())])
|
|
|
|
ans = vmap(fun, in_axes=(None, 1))(x, y)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([fun(x, y[..., i]) for i in range(10)])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-10-08 13:06:43 -07:00
|
|
|
|
2019-02-10 18:36:21 -08:00
|
|
|
def testDot(self):
|
|
|
|
# these tests are based on @shoyer's notebook studying gufuncs
|
|
|
|
|
|
|
|
def vecvec(a, b):
|
2020-05-05 14:59:16 -04:00
|
|
|
dot = jnp.dot
|
2019-02-10 18:36:21 -08:00
|
|
|
for ndim in range(1, max(a.ndim, b.ndim)):
|
|
|
|
a_ax = 0 if a.ndim > ndim else None
|
|
|
|
b_ax = 0 if b.ndim > ndim else None
|
|
|
|
dot = vmap(dot, in_axes=(a_ax, b_ax))
|
|
|
|
return dot(a, b)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
assert vecvec(jnp.zeros((3,)), jnp.zeros((3,))).shape == ()
|
|
|
|
assert vecvec(jnp.zeros((2, 3)), jnp.zeros((3,))).shape == (2,)
|
|
|
|
assert vecvec(jnp.zeros((4, 2, 3)), jnp.zeros((3,))).shape == (4, 2)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2019-02-17 09:34:49 -08:00
|
|
|
def testDot2(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-17 09:34:49 -08:00
|
|
|
xs = R(10, 3)
|
|
|
|
ys = R(10, 3)
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = vmap(jnp.dot)(xs, ys)
|
|
|
|
expected = np.einsum('ni,ni->n', xs, ys)
|
2019-02-17 09:34:49 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-06-05 15:17:06 -07:00
|
|
|
def testDot3(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-06-05 15:17:06 -07:00
|
|
|
xs = R(5, 8, 10)
|
|
|
|
ys = R(10, 1)
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys)
|
|
|
|
expected = np.einsum('inj,jk->nik', xs, ys)
|
2019-06-05 15:17:06 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-06-12 18:02:01 -07:00
|
|
|
def testDot4(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-06-12 18:02:01 -07:00
|
|
|
xs = R(3, 2)
|
|
|
|
ys = R(3)
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys)
|
|
|
|
expected = np.einsum('ij,i->j', xs, ys)
|
2019-06-12 18:02:01 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
2020-01-03 15:46:19 -08:00
|
|
|
|
2019-02-10 18:36:21 -08:00
|
|
|
def testPad(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().randn
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1)])
|
|
|
|
x = R(5, 10).astype(np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(fun)(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected_ans = jnp.stack(list(map(fun, x)))
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1), (0, 1, 0)])
|
|
|
|
x = R(5, 10, 3).astype(np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(fun)(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected_ans = jnp.stack(list(map(fun, x)))
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
def testConcatenate(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
R = lambda *shape: self.rng().randn(*shape).astype(np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
fun = lambda *args: lax.concatenate(args, dimension=0)
|
|
|
|
x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
|
|
|
|
ans = vmap(fun, in_axes=(0, 1, None))(x, y, z)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected_ans = np.concatenate([x, np.swapaxes(y, 0, 1),
|
|
|
|
np.broadcast_to(z, (10, 4, 3))], 1)
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
fun = lambda *args: lax.concatenate(args, dimension=1)
|
|
|
|
x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
|
|
|
|
ans = vmap(fun, in_axes=(0, None, 2))(x, y, z)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected_ans = np.concatenate([x, np.broadcast_to(y, (10, 2, 3)),
|
|
|
|
np.moveaxis(z, 2, 0)], 2)
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
|
|
|
|
|
|
|
def testJacobianIssue54(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# test modeling the code in https://github.com/jax-ml/jax/issues/54
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def func(xs):
|
2020-09-17 21:51:18 +05:30
|
|
|
return jnp.array(list(xs))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
xs = jnp.ones((5, 1))
|
2019-02-10 18:36:21 -08:00
|
|
|
jacrev(func)(xs) # don't crash
|
|
|
|
jacfwd(func)(xs) # don't crash
|
|
|
|
|
|
|
|
def testAny(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# test modeling the code in https://github.com/jax-ml/jax/issues/108
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]]))
|
|
|
|
expected = jnp.array([True, False])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def testHessian(self):
|
|
|
|
# test based on code from sindhwani@google
|
|
|
|
def fun(x, t):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sum(jnp.power(jnp.maximum(x, 0.0), 2)) + t
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.array([-1., -0.5, 0., 0.5, 1.0])
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
ans = hessian(lambda x: fun(x, 0.0))(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.array([[0., 0., 0., 0., 0.],
|
2019-02-10 18:36:21 -08:00
|
|
|
[0., 0., 0., 0., 0.],
|
|
|
|
[0., 0.,0.5, 0., 0.],
|
|
|
|
[0., 0., 0., 2., 0.],
|
|
|
|
[0., 0., 0., 0., 2.]])
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def testDynamicSlice(self):
|
|
|
|
# test dynamic_slice via numpy indexing syntax
|
2024-09-20 07:51:48 -07:00
|
|
|
# see https://github.com/jax-ml/jax/issues/1613 for an explanation of why we
|
2020-05-05 14:59:16 -04:00
|
|
|
# need to use np rather than np to create x and idx
|
|
|
|
x = jnp.arange(30).reshape((10, 3))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
|
|
|
|
expected = x[:, 1]
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
idx = jnp.array([0, 1, 2, 1, 0] * 2)
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = x[np.arange(10), idx]
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
x = jnp.arange(3)
|
|
|
|
idx = jnp.array([0, 1, 2, 1, 0] * 2)
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
|
|
|
|
expected = x[idx]
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-04-30 11:48:53 -04:00
|
|
|
def testDynamicUpdateSlice(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
x = self.rng().randn(10, 3)
|
|
|
|
y = self.rng().randn(10)
|
2019-04-30 11:48:53 -04:00
|
|
|
ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
|
|
|
|
in_axes=(0, 0, None))(x, y, 1)
|
|
|
|
expected = x.copy()
|
|
|
|
expected[:, 1] = y
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2021-12-10 10:32:09 -08:00
|
|
|
x = self.rng().randn(3)
|
2020-05-05 14:59:16 -04:00
|
|
|
idx = np.array([0, 1, 2, 1, 0] * 2)
|
2019-04-30 11:48:53 -04:00
|
|
|
ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
|
|
|
|
in_axes=(None, 0, 0))(x, y, idx)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.broadcast_to(x, (10, 3)).copy()
|
|
|
|
expected[np.arange(10), idx] = y
|
2019-04-30 11:48:53 -04:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2023-08-25 14:11:19 -07:00
|
|
|
@jax.legacy_prng_key('allow')
|
2019-02-10 18:36:21 -08:00
|
|
|
def testRandom(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
seeds = vmap(random.PRNGKey)(np.arange(10))
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([random.normal(random.PRNGKey(seed), (3, 2))
|
|
|
|
for seed in np.arange(10)])
|
2019-02-10 18:36:21 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
2020-05-05 14:59:16 -04:00
|
|
|
assert len(np.unique(ans)) == 10 * 3 * 2
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2019-08-01 12:39:33 -04:00
|
|
|
def testSort(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
v = np.arange(12)[::-1].reshape(3, 4)
|
2019-08-01 12:39:33 -04:00
|
|
|
|
|
|
|
sv = vmap(partial(lax.sort, dimension=0), (0,))(v)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sv, v[:, ::-1])
|
2019-08-01 12:39:33 -04:00
|
|
|
|
|
|
|
sv = vmap(partial(lax.sort, dimension=-1), (0,))(v)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sv, v[:, ::-1])
|
2019-08-01 12:39:33 -04:00
|
|
|
|
|
|
|
sv = vmap(partial(lax.sort, dimension=0), (1,))(v)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sv, v[::-1, :].T)
|
2019-08-01 12:39:33 -04:00
|
|
|
|
|
|
|
sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sv, v[::-1, :])
|
2019-08-01 12:39:33 -04:00
|
|
|
|
2019-02-10 18:36:21 -08:00
|
|
|
def testSortKeyVal(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
k = np.arange(12)[::-1].reshape(3, 4)
|
2021-12-10 10:32:09 -08:00
|
|
|
v = self.rng().permutation(12).reshape(3, 4)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sk, k[:, ::-1])
|
|
|
|
self.assertAllClose(sv, v[:, ::-1])
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sk, k[::-1, :])
|
|
|
|
self.assertAllClose(sv, v[::-1, :])
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sk, k[:, ::-1])
|
|
|
|
self.assertAllClose(sv, v[:, ::-1])
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sk, k[:, ::-1])
|
|
|
|
self.assertAllClose(sv, v[:, ::-1])
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)))
|
|
|
|
self.assertAllClose(sv, v[:, ::-1])
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(sk, k[:, ::-1])
|
|
|
|
self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def testConvGeneralDilated(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
W = jnp.array(self.rng().randn(3, 3, 1, 5), dtype=np.float32)
|
|
|
|
X = jnp.array(self.rng().randn(10, 5, 5, 1), dtype=np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def f(params, x):
|
|
|
|
one = (1, 1)
|
|
|
|
dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
|
|
|
|
y = lax.conv_general_dilated(
|
|
|
|
x, params, one, 'SAME', one, one, dimension_numbers)
|
|
|
|
return y
|
2020-05-05 14:59:16 -04:00
|
|
|
grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
# Test forward prop.
|
2020-05-05 14:59:16 -04:00
|
|
|
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
|
|
|
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct = f(W, X)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(per_example, per_example_direct)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
# Test gradients.
|
2020-05-05 14:59:16 -04:00
|
|
|
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct = []
|
|
|
|
for i in range(10):
|
2020-05-05 14:59:16 -04:00
|
|
|
g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct += [
|
2020-05-05 14:59:16 -04:00
|
|
|
jnp.reshape(g, (1,) + g.shape)]
|
|
|
|
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(per_example, per_example_direct,
|
2020-10-12 09:35:39 -04:00
|
|
|
rtol=2e-2, atol=2e-3)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2019-04-03 12:41:14 -07:00
|
|
|
def testConvGeneralDilatedBatchNotMajor(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
W = jnp.array(self.rng().randn(3, 3, 1, 4), dtype=np.float32)
|
|
|
|
x = jnp.array(self.rng().randn(3, 5, 7, 5, 1), dtype=np.float32)
|
2019-04-03 12:41:14 -07:00
|
|
|
|
|
|
|
def f(params, x):
|
|
|
|
one = (1, 1)
|
|
|
|
dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
|
|
|
|
y = lax.conv_general_dilated(
|
|
|
|
x, params, one, 'SAME', one, one, dimension_numbers)
|
|
|
|
return y
|
|
|
|
|
|
|
|
per_example = vmap(partial(f, W))(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
per_example = jnp.reshape(jnp.transpose(per_example, (1, 2, 0, 3, 4)),
|
2019-04-03 12:41:14 -07:00
|
|
|
(5, 5, 21, 4))
|
2020-05-05 14:59:16 -04:00
|
|
|
per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)),
|
2019-04-03 12:41:14 -07:00
|
|
|
(5, 21, 5, 1)))
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(per_example, per_example_direct)
|
2019-04-03 12:41:14 -07:00
|
|
|
|
2019-06-26 10:19:42 -04:00
|
|
|
@parameterized.named_parameters(
|
2022-05-12 19:13:00 +01:00
|
|
|
{"testcase_name": f"_op={name}", "op": op, "unit": unit}
|
2020-05-05 14:59:16 -04:00
|
|
|
for name, op, unit in [("max", lax.max, -jnp.inf), ("min", lax.min, jnp.inf)])
|
2019-06-26 10:19:42 -04:00
|
|
|
def testMinMaxPool(self, op, unit):
|
2021-12-10 10:32:09 -08:00
|
|
|
W = jnp.array(self.rng().randn(3, 3, 1, 5), dtype=np.float32)
|
|
|
|
X = jnp.array(self.rng().randn(10, 5, 5, 1), dtype=np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def f(params, x):
|
|
|
|
one = (1, 1)
|
|
|
|
dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
|
|
|
|
y = lax.conv_general_dilated(
|
|
|
|
x, params, one, 'SAME', one, one, dimension_numbers)
|
|
|
|
y = lax.reduce_window(
|
2019-06-26 10:19:42 -04:00
|
|
|
y, unit, op, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
|
2019-02-10 18:36:21 -08:00
|
|
|
return y
|
2020-05-05 14:59:16 -04:00
|
|
|
grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
# Test forward prop.
|
2020-05-05 14:59:16 -04:00
|
|
|
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
|
|
|
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct = f(W, X)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(per_example, per_example_direct)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
# Test gradients.
|
2020-05-05 14:59:16 -04:00
|
|
|
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct = []
|
|
|
|
for i in range(10):
|
2020-05-05 14:59:16 -04:00
|
|
|
g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct += [
|
2020-05-05 14:59:16 -04:00
|
|
|
jnp.reshape(g, (1,) + g.shape)]
|
|
|
|
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
|
2020-07-14 14:10:13 -04:00
|
|
|
self.assertAllClose(per_example, per_example_direct, rtol=5e-2, atol=1e-3)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def testSumPool(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
W = jnp.array(self.rng().randn(3, 3, 1, 5), dtype=np.float32)
|
|
|
|
X = jnp.array(self.rng().randn(10, 5, 5, 1), dtype=np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def f(params, x):
|
|
|
|
one = (1, 1)
|
|
|
|
dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
|
|
|
|
y = lax.conv_general_dilated(
|
|
|
|
x, params, one, 'SAME', one, one, dimension_numbers)
|
|
|
|
y = lax.reduce_window(
|
|
|
|
y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
|
|
|
|
return y
|
2020-05-05 14:59:16 -04:00
|
|
|
grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
# Test forward prop.
|
2020-05-05 14:59:16 -04:00
|
|
|
per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
|
|
|
per_example = jnp.reshape(per_example, (10, 5, 5, 5))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct = f(W, X)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(per_example, per_example_direct)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
# Test gradients.
|
2020-05-05 14:59:16 -04:00
|
|
|
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct = []
|
|
|
|
for i in range(10):
|
2020-05-05 14:59:16 -04:00
|
|
|
g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
|
2019-02-10 18:36:21 -08:00
|
|
|
per_example_direct += [
|
2020-05-05 14:59:16 -04:00
|
|
|
jnp.reshape(g, (1,) + g.shape)]
|
|
|
|
per_example_direct = jnp.concatenate(per_example_direct, axis=0)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(per_example, per_example_direct,
|
2020-10-12 09:35:39 -04:00
|
|
|
rtol=3e-2, atol=1e-3)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2019-07-13 10:22:26 -04:00
|
|
|
def testCumProd(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
x = jnp.arange(9).reshape(3, 3) + 1
|
|
|
|
y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
|
2022-12-02 13:20:30 -08:00
|
|
|
self.assertAllClose(jnp.cumprod(x, axis=1), y)
|
2019-07-13 10:22:26 -04:00
|
|
|
|
2019-02-10 18:36:21 -08:00
|
|
|
def testSelect(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
pred = np.array([True, False])
|
|
|
|
on_true = np.array([0, 1])
|
|
|
|
on_false = np.array([2, 3])
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(lax.select)(pred, on_true, on_false)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.array([0, 3])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
pred = np.array([False, True])
|
|
|
|
on_true = np.array([0, 1])
|
|
|
|
on_false = np.array([2, 3])
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.array([[2, 3],
|
2019-02-10 18:36:21 -08:00
|
|
|
[0, 1]])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
pred = True
|
2020-05-05 14:59:16 -04:00
|
|
|
on_true = np.array([0, 1], np.float32)
|
|
|
|
on_false = np.array(3, np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.array([0, 1], np.float32)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
pred = np.array([False, True])
|
|
|
|
on_true = np.array([0, 1], np.float32)
|
|
|
|
on_false = np.array(3, np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.array([3, 1], np.float32)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
pred = np.array([False, True])
|
|
|
|
on_true = np.array([2], np.float32)
|
|
|
|
on_false = np.array([[3, 4]], np.float32)
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.array([[3, 2]], np.float32)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def testLaxLinalgCholesky(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
a = self.rng().randn(10, 5, 5).astype(np.float32)
|
2020-05-05 14:59:16 -04:00
|
|
|
a = np.matmul(a, np.conj(np.swapaxes(a, -1, -2)))
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-11-04 13:33:30 -08:00
|
|
|
ans = vmap(lax.linalg.cholesky)(a)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.linalg.cholesky(a)
|
2021-12-10 10:32:09 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False, atol=1E-3)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2021-12-10 10:32:09 -08:00
|
|
|
b = self.rng().randn(10, 5, 5).astype(np.float32)
|
2020-05-05 14:59:16 -04:00
|
|
|
b = np.matmul(b, np.conj(np.swapaxes(b, -1, -2)))
|
|
|
|
b_trans = np.swapaxes(b, 0, 1) # shape is (5, 10, 5)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2020-11-04 13:33:30 -08:00
|
|
|
ans = vmap(lax.linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.linalg.cholesky(b)
|
2019-11-16 13:51:42 -05:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)
|
2019-02-06 10:58:41 -08:00
|
|
|
|
2019-03-10 17:31:51 -04:00
|
|
|
def testLaxLinalgTriangularSolve(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
a = self.rng().randn(4, 10, 4).astype(np.float32)
|
2020-05-05 14:59:16 -04:00
|
|
|
a += np.eye(4, dtype=jnp.float32)[:, None, :]
|
2021-12-10 10:32:09 -08:00
|
|
|
b = self.rng().randn(5, 4, 10).astype(np.float32)
|
2019-03-10 17:31:51 -04:00
|
|
|
|
2020-11-04 13:33:30 -08:00
|
|
|
ans = vmap(lax.linalg.triangular_solve, in_axes=(1, 2))(a, b)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack(
|
2020-11-04 13:33:30 -08:00
|
|
|
[lax.linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
|
2023-06-23 09:21:32 -07:00
|
|
|
self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
|
2019-03-10 17:31:51 -04:00
|
|
|
|
2020-11-04 13:33:30 -08:00
|
|
|
ans = vmap(lax.linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack(
|
2020-11-04 13:33:30 -08:00
|
|
|
[lax.linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)])
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-03-10 17:31:51 -04:00
|
|
|
|
2020-11-04 13:33:30 -08:00
|
|
|
ans = vmap(lax.linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack(
|
2020-11-04 13:33:30 -08:00
|
|
|
[lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
|
2023-06-23 09:21:32 -07:00
|
|
|
self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
|
2019-03-10 17:31:51 -04:00
|
|
|
|
2023-08-10 16:25:23 -07:00
|
|
|
def testLaxLinalgTridiagonalSolve(self):
|
|
|
|
dl = self.rng().randn(4, 10).astype(np.float32)
|
|
|
|
d = self.rng().randn(4, 10).astype(np.float32) + 1.
|
|
|
|
du = self.rng().randn(4, 10).astype(np.float32)
|
|
|
|
b = self.rng().randn(4, 5, 10).astype(np.float32)
|
|
|
|
|
|
|
|
ans = vmap(lax.linalg.tridiagonal_solve, in_axes=(1, 1, 1, 2))(dl, d, du, b)
|
|
|
|
expected = np.stack(
|
|
|
|
[lax.linalg.tridiagonal_solve(
|
|
|
|
dl[:, i], d[:, i], du[:, i], b[..., i]) for i in range(10)])
|
|
|
|
self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
|
|
|
|
|
|
|
|
ans = vmap(lax.linalg.tridiagonal_solve, in_axes=(None, None, None, 2))(
|
|
|
|
dl[:, 0], d[:, 0], du[:, 0], b)
|
|
|
|
expected = np.stack(
|
|
|
|
[lax.linalg.tridiagonal_solve(
|
|
|
|
dl[:, 0], d[:, 0], du[:, 0], b[..., i]) for i in range(10)])
|
|
|
|
self.assertAllClose(ans, expected)
|
|
|
|
|
|
|
|
ans = vmap(lax.linalg.tridiagonal_solve, in_axes=(1, 1, 1, None))(
|
|
|
|
dl, d, du, b[..., 0])
|
|
|
|
expected = np.stack(
|
|
|
|
[lax.linalg.tridiagonal_solve(
|
|
|
|
dl[:, i], d[:, i], du[:, i], b[..., 0]) for i in range(10)])
|
|
|
|
self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
|
|
|
|
|
|
|
|
|
2019-02-06 10:58:41 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
|
|
|
|
slice_sizes),
|
|
|
|
"axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
|
2020-12-11 13:47:46 -08:00
|
|
|
"slice_sizes": slice_sizes}
|
2020-05-05 14:59:16 -04:00
|
|
|
for dtype in [np.float32, np.int32]
|
2019-02-06 10:58:41 -08:00
|
|
|
for axis, shape, idxs, dnums, slice_sizes in [
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
|
|
|
(1,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
|
|
|
(2,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
|
|
|
(1, 3)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
|
2019-03-01 11:59:54 -05:00
|
|
|
lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,),
|
|
|
|
start_index_map=(0, 1)),
|
|
|
|
(1, 3)),
|
2020-12-11 13:47:46 -08:00
|
|
|
])
|
|
|
|
def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-02-06 10:58:41 -08:00
|
|
|
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
|
|
|
operand = rng(shape, dtype)
|
|
|
|
ans = vmap(fun, (axis, None))(operand, idxs)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([fun(operand[(slice(None),) * axis + (i,)], idxs)
|
2019-02-06 10:58:41 -08:00
|
|
|
for i in range(operand.shape[axis])])
|
2019-02-05 08:39:03 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-02-11 10:24:21 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
|
|
|
|
slice_sizes),
|
|
|
|
"axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
|
2020-12-11 13:47:46 -08:00
|
|
|
"slice_sizes": slice_sizes}
|
2020-05-05 14:59:16 -04:00
|
|
|
for dtype in [np.float32, np.float64]
|
2019-02-11 10:24:21 -08:00
|
|
|
for axis, shape, idxs, dnums, slice_sizes in [
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
|
|
|
(1,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
|
|
|
(2,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
|
|
|
(1, 3)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
|
2019-03-01 11:59:54 -05:00
|
|
|
lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,),
|
|
|
|
start_index_map=(0, 1)),
|
2020-12-11 13:47:46 -08:00
|
|
|
(1, 3))
|
|
|
|
])
|
|
|
|
def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-02-11 10:24:21 -08:00
|
|
|
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
2020-05-05 14:59:16 -04:00
|
|
|
gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
|
2019-02-11 10:24:21 -08:00
|
|
|
operand = rng(shape, dtype)
|
|
|
|
ans = vmap(gfun, (axis, None))(operand, idxs)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs)
|
2019-02-11 10:24:21 -08:00
|
|
|
for i in range(operand.shape[axis])])
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-02-10 18:36:21 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
|
|
|
|
slice_sizes),
|
|
|
|
"axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
|
2020-12-11 13:47:46 -08:00
|
|
|
"slice_sizes": slice_sizes}
|
2020-05-05 14:59:16 -04:00
|
|
|
for dtype in [np.float32, np.int32]
|
2019-02-10 18:36:21 -08:00
|
|
|
for axis, shape, idxs, dnums, slice_sizes in [
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
|
2019-02-10 18:36:21 -08:00
|
|
|
lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
|
2019-02-10 18:36:21 -08:00
|
|
|
lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, (10, 5), np.array([[[0, 1], [2, 0]],
|
2019-03-01 11:59:54 -05:00
|
|
|
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
|
2020-12-11 13:47:46 -08:00
|
|
|
])
|
|
|
|
def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-02-10 18:36:21 -08:00
|
|
|
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
|
|
|
operand = rng(shape, dtype)
|
|
|
|
ans = vmap(fun, (None, axis))(operand, idxs)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([fun(operand, idxs[(slice(None),) * axis + (i,)])
|
2019-02-10 18:36:21 -08:00
|
|
|
for i in range(idxs.shape[axis])])
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-02-11 10:24:21 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
|
|
|
|
slice_sizes),
|
|
|
|
"axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
|
2020-12-11 13:47:46 -08:00
|
|
|
"slice_sizes": slice_sizes}
|
2020-05-05 14:59:16 -04:00
|
|
|
for dtype in [np.float32, np.float64]
|
2019-02-11 10:24:21 -08:00
|
|
|
for axis, shape, idxs, dnums, slice_sizes in [
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
|
2019-02-11 10:24:21 -08:00
|
|
|
lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
|
2019-02-11 10:24:21 -08:00
|
|
|
lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, (10, 5), np.array([[[0, 1], [2, 0]],
|
2019-03-01 11:59:54 -05:00
|
|
|
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
|
2020-12-11 13:47:46 -08:00
|
|
|
])
|
|
|
|
def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-02-11 10:24:21 -08:00
|
|
|
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
2020-05-05 14:59:16 -04:00
|
|
|
gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
|
2019-02-11 10:24:21 -08:00
|
|
|
operand = rng(shape, dtype)
|
|
|
|
ans = vmap(gfun, (None, axis))(operand, idxs)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([gfun(operand, idxs[(slice(None),) * axis + (i,)])
|
2019-02-11 10:24:21 -08:00
|
|
|
for i in range(idxs.shape[axis])])
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-02-11 09:28:21 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
|
|
|
|
dnums, slice_sizes),
|
|
|
|
"op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
|
2020-12-11 13:47:46 -08:00
|
|
|
dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes}
|
2020-05-05 14:59:16 -04:00
|
|
|
for dtype in [np.float32, np.int32]
|
2019-02-11 09:28:21 -08:00
|
|
|
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
|
2019-02-11 09:28:21 -08:00
|
|
|
lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
|
|
|
(1,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
|
2019-02-11 09:28:21 -08:00
|
|
|
lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
|
|
|
(2,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
|
2019-03-01 11:59:54 -05:00
|
|
|
lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
|
|
|
(1, 3)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]],
|
2019-03-01 11:59:54 -05:00
|
|
|
[[1, 0], [2, 0]]]),
|
|
|
|
lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
|
|
|
|
(1, 3)),
|
2020-12-11 13:47:46 -08:00
|
|
|
])
|
|
|
|
def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-02-11 09:28:21 -08:00
|
|
|
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
|
|
|
operand = rng(shape, dtype)
|
|
|
|
assert operand.shape[op_axis] == idxs.shape[idxs_axis]
|
|
|
|
ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([fun(operand[(slice(None),) * op_axis + (i,)],
|
2019-02-11 09:28:21 -08:00
|
|
|
idxs[(slice(None),) * idxs_axis + (i,)])
|
|
|
|
for i in range(idxs.shape[idxs_axis])])
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2019-02-11 10:24:21 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
|
|
|
|
dnums, slice_sizes),
|
|
|
|
"op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
|
2020-12-11 13:47:46 -08:00
|
|
|
dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes}
|
2020-05-05 14:59:16 -04:00
|
|
|
for dtype in [np.float32]
|
2019-02-11 10:24:21 -08:00
|
|
|
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
|
2019-03-01 11:59:54 -05:00
|
|
|
lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
|
|
|
(1,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
|
2019-02-11 10:24:21 -08:00
|
|
|
lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
|
|
|
(2,)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
|
2019-02-11 10:24:21 -08:00
|
|
|
lax.GatherDimensionNumbers(
|
2019-03-01 11:59:54 -05:00
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
|
|
|
(1, 3)),
|
2020-05-05 14:59:16 -04:00
|
|
|
(2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]],
|
2019-03-01 11:59:54 -05:00
|
|
|
[[1, 0], [2, 0]]]),
|
|
|
|
lax.GatherDimensionNumbers(
|
|
|
|
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
|
|
|
|
(1, 3)),
|
2020-12-11 13:47:46 -08:00
|
|
|
])
|
2019-02-11 10:24:21 -08:00
|
|
|
def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
|
2020-12-11 13:47:46 -08:00
|
|
|
slice_sizes):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-02-11 10:24:21 -08:00
|
|
|
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
2020-05-05 14:59:16 -04:00
|
|
|
gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
|
2019-02-11 10:24:21 -08:00
|
|
|
operand = rng(shape, dtype)
|
|
|
|
assert operand.shape[op_axis] == idxs.shape[idxs_axis]
|
|
|
|
ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([gfun(operand[(slice(None),) * op_axis + (i,)],
|
2019-02-11 10:24:21 -08:00
|
|
|
idxs[(slice(None),) * idxs_axis + (i,)])
|
|
|
|
for i in range(idxs.shape[idxs_axis])])
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-02-10 18:36:21 -08:00
|
|
|
def testNumpyIndexing1(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
|
|
|
|
ind = np.array([[0, 1],
|
2019-02-10 18:36:21 -08:00
|
|
|
[2, 0]])
|
|
|
|
def f(a, ind):
|
|
|
|
return a[:, ind]
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
|
2019-02-10 18:36:21 -08:00
|
|
|
ans = vmap(f, (None, 0))(a, ind)
|
2020-05-05 14:59:16 -04:00
|
|
|
assert np.all(ans == expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
|
|
|
def testNumpyIndexing2(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
|
2019-02-10 18:36:21 -08:00
|
|
|
def f(a):
|
2020-05-05 14:59:16 -04:00
|
|
|
inds = jnp.array([0, 2])
|
2019-02-10 18:36:21 -08:00
|
|
|
return a[:, inds]
|
|
|
|
ans = vmap(f)(a)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
|
|
|
|
assert np.all(ans == expected)
|
2019-02-10 18:36:21 -08:00
|
|
|
|
2019-02-12 07:26:32 -08:00
|
|
|
def testTranspose(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(4 * 3 * 3).reshape((4, 3, 3))
|
2019-02-12 07:26:32 -08:00
|
|
|
ans = vmap(lambda x: x + x.T)(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = x + np.swapaxes(x, -1, -2)
|
2019-02-12 07:26:32 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def testTransposePermutation(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
|
|
|
|
ans = vmap(lambda x: jnp.transpose(x, (1, 0, 2)))(x)
|
|
|
|
expected = np.transpose(x, (0, 2, 1, 3))
|
2019-02-12 07:26:32 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
|
|
|
|
ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)))(x)
|
|
|
|
expected = np.transpose(x, (0, 2, 3, 1))
|
2019-02-12 07:26:32 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
|
|
|
|
ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)), in_axes=2)(x)
|
|
|
|
expected = np.transpose(x, (2, 1, 3, 0))
|
2019-02-12 07:26:32 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-02-11 16:18:13 -08:00
|
|
|
def testIssue354(self):
|
2021-12-10 10:32:09 -08:00
|
|
|
psd_mat = self.rng().randn(20, 10)
|
2019-02-11 16:18:13 -08:00
|
|
|
psd_mat = psd_mat.T.dot(psd_mat)
|
2021-12-10 10:32:09 -08:00
|
|
|
vec = self.rng().randn(10)
|
2019-02-11 16:18:13 -08:00
|
|
|
|
|
|
|
def f(scale):
|
2022-01-28 08:16:30 -08:00
|
|
|
scaled_mat = scale[jnp.newaxis] * psd_mat
|
2020-05-05 14:59:16 -04:00
|
|
|
chol = jnp.linalg.cholesky(scaled_mat)
|
|
|
|
return -0.5 * jnp.sum((jnp.einsum('ij,j->i', chol, vec))**2)
|
2019-02-11 16:18:13 -08:00
|
|
|
vmapped_f = vmap(f)
|
2020-05-05 14:59:16 -04:00
|
|
|
vmapped_f_grad = grad(lambda x: jnp.sum(vmapped_f(x)))
|
2019-02-11 16:18:13 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
scales = np.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
|
2019-02-11 16:18:13 -08:00
|
|
|
ans = vmapped_f_grad(scales) # don't crash!
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.stack([grad(f)(scale) for scale in scales])
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False,
|
|
|
|
rtol=jtu.default_gradient_tolerance)
|
2019-02-11 16:18:13 -08:00
|
|
|
|
2019-02-15 21:14:11 -08:00
|
|
|
def testIssue387(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/387
|
2021-12-10 10:32:09 -08:00
|
|
|
R = self.rng().rand(100, 2)
|
2019-02-15 21:14:11 -08:00
|
|
|
|
|
|
|
def dist_sq(R):
|
2020-05-05 14:59:16 -04:00
|
|
|
dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :]
|
|
|
|
zero = jnp.zeros_like(dR)
|
|
|
|
dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR))
|
|
|
|
return jnp.sum(dR ** 2, axis=2)
|
2019-02-15 21:14:11 -08:00
|
|
|
|
|
|
|
@jit
|
|
|
|
def f(R):
|
2020-06-02 19:25:47 -07:00
|
|
|
_ = dist_sq(R)
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sum(R ** 2)
|
2019-02-15 21:14:11 -08:00
|
|
|
|
2020-06-02 19:25:47 -07:00
|
|
|
_ = hessian(f)(R) # don't crash on UnshapedArray
|
2019-02-15 21:14:11 -08:00
|
|
|
|
2023-08-25 14:11:19 -07:00
|
|
|
@jax.legacy_prng_key('allow')
|
2019-03-10 20:53:53 -07:00
|
|
|
def testIssue489(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/489
|
2019-03-10 20:53:53 -07:00
|
|
|
def f(key):
|
|
|
|
def body_fn(uk):
|
|
|
|
key = uk[1]
|
2020-12-08 13:03:30 -08:00
|
|
|
u = random.uniform(key, ())
|
2019-03-10 20:53:53 -07:00
|
|
|
key, _ = random.split(key)
|
|
|
|
return u, key
|
|
|
|
|
2020-12-08 13:03:30 -08:00
|
|
|
u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
|
2019-03-10 20:53:53 -07:00
|
|
|
return u
|
|
|
|
|
2024-03-21 10:47:16 -07:00
|
|
|
with jax.debug_key_reuse(False):
|
2023-12-11 12:03:48 -08:00
|
|
|
print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash
|
2019-03-10 20:53:53 -07:00
|
|
|
|
2019-05-21 21:34:42 -04:00
|
|
|
def testEmptyTuples(self):
|
|
|
|
# Ensure there is no crash when a vectorized input contains empty tuples.
|
2020-05-05 14:59:16 -04:00
|
|
|
result = vmap(lambda x, _: x + 1)(np.array([0, 1]), ())
|
|
|
|
self.assertAllClose(result, np.array([1, 2]), check_dtypes=False)
|
2019-05-21 21:34:42 -04:00
|
|
|
# Ensure there is no crash when a vectorized output contains empty tuples.
|
2020-05-05 14:59:16 -04:00
|
|
|
result, empty_tuple = vmap(lambda x: (x + 1, ()))(np.array([0, 1]))
|
|
|
|
self.assertAllClose(result, np.array([1, 2]), check_dtypes=False)
|
2019-05-21 21:34:42 -04:00
|
|
|
self.assertEqual((), empty_tuple)
|
|
|
|
|
2019-05-29 17:13:46 -04:00
|
|
|
def testIndexAddBatchedIndexesOnly(self):
|
2021-09-13 16:40:45 -04:00
|
|
|
f = lambda x, idx, y: jnp.asarray(x).at[idx].add(y)
|
2020-05-05 14:59:16 -04:00
|
|
|
result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10,), 1.)
|
|
|
|
self.assertAllClose(result, np.eye(10), check_dtypes=False)
|
2019-05-29 17:13:46 -04:00
|
|
|
|
2019-08-12 18:03:25 -07:00
|
|
|
def testIssue1170(self):
|
|
|
|
def f(index1, index2):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.arange(36).reshape(6, 6)[index1, index2]
|
2019-08-12 18:03:25 -07:00
|
|
|
g = jax.jit(jax.pmap(f))
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = g(index1=np.asarray([1]), index2=np.asarray([2]))
|
|
|
|
expected = g(np.asarray([1]), np.asarray([2]))
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-08-12 18:03:25 -07:00
|
|
|
|
2020-07-29 03:39:32 +02:00
|
|
|
def testIssue3883(self):
|
|
|
|
def scalar_f(x):
|
|
|
|
return lax.dynamic_slice(x, [], [])
|
|
|
|
|
|
|
|
xs = jnp.array([1, 2, 3, 4])
|
|
|
|
ans = vmap(scalar_f)(xs)
|
|
|
|
expected = jnp.array([scalar_f(x) for x in xs])
|
|
|
|
self.assertAllClose(ans, expected)
|
|
|
|
|
|
|
|
def scalar_f2(x):
|
|
|
|
return lax.dynamic_update_slice(x, 7, [])
|
|
|
|
|
|
|
|
xs = jnp.array([1, 2, 3, 4])
|
|
|
|
ans = vmap(scalar_f2)(xs)
|
|
|
|
expected = jnp.array([scalar_f2(x) for x in xs])
|
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-03 09:52:33 -08:00
|
|
|
|
2020-08-14 18:22:04 +02:00
|
|
|
@parameterized.named_parameters(
|
2020-11-24 09:58:44 -08:00
|
|
|
{"testcase_name": "_{}_vmap_names={}_collective_names={}".format(
|
2021-04-26 11:41:26 -07:00
|
|
|
collective.__name__.replace(" ", ""),
|
|
|
|
"".join(vmap_names), "".join(collective_names)),
|
2020-11-24 09:58:44 -08:00
|
|
|
"collective": collective, "bulk_op": bulk_op, "vmap_names": vmap_names,
|
|
|
|
"collective_names": collective_names}
|
|
|
|
for collective, bulk_op in [(lax.psum, jnp.sum),
|
|
|
|
(lax.pmax, jnp.max),
|
|
|
|
(lax.pmin, jnp.min)]
|
2021-04-26 11:41:26 -07:00
|
|
|
for vmap_names in [('i',), ('i', 'j'), ('i', 'j', 'k')]
|
|
|
|
for subset_size in range(1, len(vmap_names) + 1)
|
|
|
|
for collective_subset in it.combinations(vmap_names, subset_size)
|
|
|
|
for collective_names in it.permutations(collective_subset))
|
2020-11-24 09:58:44 -08:00
|
|
|
def testCommAssocCollective(self, collective, bulk_op, vmap_names, collective_names):
|
2021-04-26 11:41:26 -07:00
|
|
|
shape = (2, 2, 2)
|
|
|
|
x = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
|
2020-11-24 09:58:44 -08:00
|
|
|
|
|
|
|
# To test relative permutations of the order in which the axis names appear
|
|
|
|
# in the primitive call versus the order the vmaps are applied, we always
|
|
|
|
# apply vmaps in the order of the `vmap_names` argument, and apply the
|
|
|
|
# collective with names according to the `collective_names` argument.
|
|
|
|
f = lambda x: x - collective(x, collective_names)
|
2021-04-26 11:41:26 -07:00
|
|
|
# Use non-zero in and out axes to improve the coverage
|
|
|
|
for i, axis_name in enumerate(vmap_names):
|
|
|
|
f = vmap(f, axis_name=axis_name, in_axes=i, out_axes=i)
|
|
|
|
pos_axis = [i for i, name in enumerate(vmap_names) if name in collective_names]
|
|
|
|
self.assertAllClose(f(x), x - bulk_op(x, axis=pos_axis, keepdims=True))
|
2020-08-14 18:22:04 +02:00
|
|
|
|
2021-02-04 14:01:56 +00:00
|
|
|
if collective is lax.psum:
|
|
|
|
jtu.check_grads(f, (x,), 2, eps=1)
|
|
|
|
|
2020-08-28 20:03:39 +02:00
|
|
|
def testPPermute(self):
|
2020-08-18 12:02:28 +02:00
|
|
|
nelem = 10
|
|
|
|
ntests = 10
|
|
|
|
x = np.arange(nelem)
|
2021-12-10 10:32:09 -08:00
|
|
|
rng = self.rng()
|
2020-08-18 12:02:28 +02:00
|
|
|
for i in range(ntests):
|
|
|
|
perm = np.arange(nelem)
|
|
|
|
rng.shuffle(perm)
|
|
|
|
perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
|
|
|
|
rng.shuffle(perm_pairs)
|
|
|
|
self.assertAllClose(
|
2020-08-18 09:14:38 +00:00
|
|
|
vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs), axis_name='i')(x),
|
2021-11-18 13:40:25 -08:00
|
|
|
x - x[np.argsort(perm)])
|
2020-08-18 12:02:28 +02:00
|
|
|
|
2020-09-22 11:19:06 +00:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
|
|
|
|
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
|
|
|
|
for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
|
2021-03-05 12:24:56 +00:00
|
|
|
def testAllToAll(self, vmap_axis, split_axis, concat_axis):
|
|
|
|
shape = (4, 4, 4, 4)
|
2020-09-22 11:19:06 +00:00
|
|
|
x = np.arange(np.prod(shape)).reshape(shape)
|
2021-03-05 12:24:56 +00:00
|
|
|
f = vmap(lambda x: lax.all_to_all(x, 'i', split_axis, concat_axis),
|
|
|
|
in_axes=vmap_axis, axis_name='i')
|
|
|
|
y = f(x)
|
|
|
|
ref = jnp.moveaxis(x, (vmap_axis, split_axis + (vmap_axis <= split_axis)),
|
|
|
|
(concat_axis + 1, 0))
|
|
|
|
self.assertAllClose(y, ref)
|
2020-09-22 11:19:06 +00:00
|
|
|
|
2020-09-22 13:05:08 +00:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
|
|
|
|
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
|
|
|
|
for split_axis, concat_axis, vmap_axis in it.product(range(2), range(2), range(3)))
|
|
|
|
def testAllToAllSplitAxis(self, vmap_axis, split_axis, concat_axis):
|
|
|
|
shape = (4, 4, 4)
|
|
|
|
x = np.arange(np.prod(shape)).reshape(shape)
|
|
|
|
|
|
|
|
@partial(vmap, in_axes=vmap_axis, axis_name='i')
|
|
|
|
@partial(vmap, in_axes=vmap_axis, axis_name='j')
|
|
|
|
def f(x):
|
|
|
|
return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)
|
|
|
|
|
|
|
|
unroll_shape = (2, 2, *shape[1:])
|
|
|
|
unroll_shape = list(shape)
|
|
|
|
unroll_shape[vmap_axis:vmap_axis+1] = (2, 2)
|
|
|
|
x_unroll = x.reshape(unroll_shape)
|
|
|
|
y_unrolled = f(x_unroll)
|
|
|
|
y = y_unrolled.reshape(shape)
|
|
|
|
|
|
|
|
if vmap_axis <= split_axis:
|
|
|
|
split_axis += 1
|
|
|
|
ref = jnp.moveaxis(x, (vmap_axis, split_axis),
|
|
|
|
(concat_axis + 1, 0))
|
|
|
|
self.assertAllClose(y, ref)
|
|
|
|
|
2020-08-24 20:21:19 -04:00
|
|
|
def testNegativeAxes(self):
|
|
|
|
x = np.arange(3*4*5).reshape(3, 4, 5)
|
|
|
|
self.assertAllClose(jax.vmap(jnp.sum, in_axes=-3)(x),
|
|
|
|
jnp.sum(x, axis=(1, 2)))
|
|
|
|
self.assertAllClose(jax.vmap(jnp.sum, in_axes=-2)(x),
|
|
|
|
jnp.sum(x, axis=(0, 2)))
|
|
|
|
self.assertAllClose(jax.vmap(jnp.sum, in_axes=-1)(x),
|
|
|
|
jnp.sum(x, axis=(0, 1)))
|
|
|
|
|
2021-07-14 11:39:52 +00:00
|
|
|
|
|
|
|
error = (r"vmap was requested to map its argument along axis -4, which "
|
|
|
|
r"implies that its rank should be at least 4, but is only 3 "
|
|
|
|
r"\(its shape is \(3, 4, 5\)\)")
|
|
|
|
with self.assertRaisesRegex(ValueError, error):
|
2020-08-24 20:21:19 -04:00
|
|
|
jax.vmap(jnp.sum, in_axes=-4)(x)
|
|
|
|
|
|
|
|
id = lambda y: y
|
|
|
|
self.assertAllClose(x, jax.vmap(id, in_axes=0, out_axes=-3)(x))
|
|
|
|
self.assertAllClose(x.transpose(1, 0, 2),
|
|
|
|
jax.vmap(id, in_axes=0, out_axes=-2)(x))
|
|
|
|
self.assertAllClose(x.transpose(1, 2, 0),
|
|
|
|
jax.vmap(id, in_axes=0, out_axes=-1)(x))
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "axis -4 is out of bounds.*"):
|
|
|
|
jax.vmap(id, in_axes=0, out_axes=-4)(x)
|
|
|
|
|
|
|
|
self.assertAllClose(
|
|
|
|
np.full((5,), 7),
|
|
|
|
jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -1))(
|
|
|
|
np.arange(5), 7)[1])
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "axis -2 is out of bounds.*"):
|
|
|
|
jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -2))(
|
|
|
|
np.arange(5), 7)
|
|
|
|
|
2020-08-28 20:03:39 +02:00
|
|
|
def testAxisIndex(self):
  # Under vmap, lax.axis_index('i') yields each example's position along 'i'.
  x = np.arange(10, dtype='int32')
  offsets = np.arange(x.shape[0], dtype='int32')
  ans = vmap(lambda xi: xi - lax.axis_index('i'), axis_name='i')(x)
  self.assertAllClose(ans, x - offsets)
|
2020-08-28 20:03:39 +02:00
|
|
|
|
2021-01-12 19:37:19 -08:00
|
|
|
def testVmapKwargs(self):
  # https://github.com/jax-ml/jax/issues/912
  def f(a, b):
    return (2*a, 3*b)

  x = vmap(f)(jnp.array([1]), jnp.array([2]))
  # Passing the batched arguments by keyword used to fail (see the issue
  # above); it must now give the same result as the positional call.
  y = vmap(f)(a=jnp.array([1]), b=jnp.array([2]))
  self.assertAllClose(x, y)
|
|
|
|
|
2021-04-09 12:43:40 +00:00
|
|
|
def testGradOfPsum(self):
  # The eager result must match evaluating the traced jaxpr of the same
  # function.
  a = jnp.ones(5)
  neg_psum = lambda x: -lax.psum(x, 'i')
  f = vmap(jax.grad(neg_psum), out_axes=None, axis_name='i')
  direct = f(a)
  via_jaxpr = core.jaxpr_as_fun(jax.make_jaxpr(f)(a))(a)[0]
  self.assertEqual(direct, via_jaxpr)
|
2021-04-09 12:43:40 +00:00
|
|
|
|
2021-01-25 16:52:38 -05:00
|
|
|
def testAllGatherToUnmapped(self):
  def gather(v):
    return lax.all_gather(v, axis_name='i')

  x = jnp.arange(15).reshape((3, 5))
  # Original mapped axis becomes first axis of unmapped return value.
  result = vmap(gather, axis_name='i', in_axes=1, out_axes=None)(x)
  self.assertAllClose(result, x.T)
|
|
|
|
|
|
|
|
def testBatchedAllGather(self):
  def gather_i(v):
    return lax.all_gather(v, axis_name='i')

  x = jnp.arange(15).reshape((3, 5))

  # 'i' names the inner vmap (over columns): each row is gathered whole.
  inner_named = vmap(gather_i, axis_name='i', out_axes=None)
  self.assertAllClose(vmap(inner_named, axis_name='j')(x), x)

  # 'i' names the outer vmap (over rows): each column is gathered, giving x.T.
  outer_named = vmap(vmap(gather_i, axis_name='j'), axis_name='i',
                     out_axes=None)
  self.assertAllClose(outer_named(x), x.T)
|
|
|
|
|
2021-10-15 14:37:38 +00:00
|
|
|
def testAllGatherTiled(self):
  def tiled_gather(v):
    return lax.all_gather(v, axis_name='i', tiled=True)

  x = jnp.arange(60).reshape((4, 3, 5))
  # With tiled=True the gathered slices are concatenated along their leading
  # axis rather than stacked on a new one.
  expected = x.transpose((1, 0, 2)).reshape(-1, 5)
  res = vmap(tiled_gather, axis_name='i', in_axes=(1,), out_axes=None)(x)
  self.assertAllClose(res, expected)
|
|
|
|
|
|
|
|
def testBatchedAllGatherTiled(self):
  def tiled_gather(v):
    return lax.all_gather(v, axis_name='i', tiled=True)

  x = jnp.arange(60).reshape((4, 3, 5))
  # An extra unnamed vmap inside the named one must not change the result.
  inner = vmap(tiled_gather, in_axes=1, out_axes=1)
  res = vmap(inner, axis_name='i', in_axes=1, out_axes=None)(x)
  self.assertAllClose(res, x.transpose((1, 0, 2)).reshape(-1, 5))
|
|
|
|
|
2021-01-25 16:52:38 -05:00
|
|
|
def testAllGatherVjp(self):
  def gather(v):
    return lax.all_gather(v, axis_name='i')

  rng = self.rng()
  x = rng.randn(3, 4)
  y_bar = rng.randn(3, 3, 4)

  def pullback(xi, yi_bar):
    _, gather_vjp = vjp(gather, xi)
    return gather_vjp(yi_bar)

  # Expected: x_bar[i] == sum_j y_bar[j, i].
  x_bar, = vmap(pullback, axis_name='i')(x, y_bar)
  self.assertAllClose(x_bar, np.sum(y_bar, axis=0))
|
2020-08-24 20:21:19 -04:00
|
|
|
|
2021-01-25 17:27:39 -05:00
|
|
|
def testAllGatherOfConst(self):
  # all_gather of values that do not depend on the mapped input still yields
  # one copy per example along the gathered axis.
  def f(v):
    gathered_ones = lax.all_gather(jnp.ones_like(v), axis_name='i')
    gathered_scalar = lax.all_gather(1, axis_name='i')
    return gathered_ones, gathered_scalar

  x = jnp.arange(15).reshape((3, 5))
  ones_out, scalar_out = vmap(f, axis_name='i', in_axes=1, out_axes=None)(x)
  self.assertAllClose(ones_out, jnp.ones(shape=(5, 3), dtype=x.dtype))
  self.assertAllClose(scalar_out,
                      jnp.ones(shape=(5,), dtype=scalar_out.dtype))
|
2021-01-25 17:27:39 -05:00
|
|
|
|
2021-02-08 20:24:19 -08:00
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                      f"_axis={axis}"
                      f"_collective={collective.__name__.replace(' ', '')}",
     "shape": shape, "dtype": dtype, "axis": axis,
     "collective": collective, "bulk_op": bulk_op}
    for collective, bulk_op in [(parallel.pargmax, jnp.argmax),
                                (parallel.pargmin, jnp.argmin)]
    for dtype in [np.float32, np.int32]
    for shape in [(7,), (5, 8)]
    for axis in range(len(shape))
)
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
  """A named-axis arg-reduction under vmap matches the bulk op on that axis."""
  x = jtu.rand_default(self.rng())(shape, dtype)
  reduce_over_i = lambda v: collective(v, 'i')
  ans = vmap(reduce_over_i, in_axes=axis, out_axes=None, axis_name='i')(x)
  self.assertAllClose(ans, bulk_op(x, axis=axis), check_dtypes=False)
|
|
|
|
|
2022-11-15 11:02:58 -08:00
|
|
|
def testReduceScatterAutodiff(self):
  # Numerically check derivatives up to order 2, in both fwd and rev modes.
  x = self.rng().randn(3, 3, 4)
  scatter = vmap(partial(lax.psum_scatter, axis_name='i'), axis_name='i')
  jtu.check_grads(scatter, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
|
|
|
|
|
2021-02-16 16:46:19 -05:00
|
|
|
def testNonJaxTypedOutput(self):
  # Returning a non-array Python value from a batched function is a TypeError.
  expected_msg = "Output from batched function.*is not a valid JAX type"
  returns_str = lambda x: "hello"
  with self.assertRaisesRegex(TypeError, expected_msg):
    vmap(returns_str)(np.arange(5))
|
|
|
|
|
2021-03-19 21:01:00 -07:00
|
|
|
def testIssue6096(self):
  # vmap of betainc whose first argument is a constant (unbatched) array.
  def betainc_row(x):
    return jsp.special.betainc(jnp.ones(3), 1., x)

  unbatched_shape = betainc_row(jnp.ones(3)).shape
  self.assertEqual(unbatched_shape, (3,))
  batched_shape = jax.vmap(betainc_row)(jnp.ones((2, 3))).shape
  self.assertEqual(batched_shape, (2, 3))
|
|
|
|
|
2021-12-14 10:42:05 -08:00
|
|
|
def testPpermuteBatcherTrivial(self):
  # https://github.com/jax-ml/jax/issues/8688
  # Parameter is `x` rather than `input` so the builtin is not shadowed.
  def ppermute(x):
    return jax.lax.ppermute(x, axis_name="i", perm=[[0, 1], [1, 0]])

  grad_fn = jax.grad(ppermute)

  vmapped_gradients_fn = jax.vmap(grad_fn, axis_name="i")

  vector = jnp.array([1., 2.])
  ans = vmapped_gradients_fn(vector)  # doesn't crash
  self.assertAllClose(ans, jnp.ones(2), check_dtypes=False)
|
|
|
|
|
2022-03-30 11:06:56 -07:00
|
|
|
def testBatchingPreservesWeakType(self):
  # Regression test for https://github.com/jax-ml/jax/issues/10025
  x = jnp.ravel(1)
  self.assertTrue(dtypes.is_weakly_typed(x))

  def check_weak(x):
    # Checked on the per-example value that vmap passes into the function.
    self.assertTrue(dtypes.is_weakly_typed(x), f"{x} is not weakly-typed")
    return x

  y = vmap(check_weak)(x)
  self.assertTrue(dtypes.is_weakly_typed(y))
|
|
|
|
|
2021-03-19 21:01:00 -07:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
# Loose aliases for annotating the NamedArray vmappable helpers.
Array = ArrayElt = Any
# An integer that may be either a concrete Python int or a JAX tracer.
Int = Union[int, core.Tracer]
|
2021-10-06 14:18:07 -07:00
|
|
|
|
|
|
|
# Can't use NamedTuple here because NamedTuples are pytrees.
|
|
|
|
class NamedArray:
  """An array paired with one axis name per dimension.

  Attributes:
    names: axis names, one per dimension of ``data``.
    data: the underlying array.
  """
  names: list[str]
  data: Array

  def __init__(self, names, data):
    # Exactly one name per dimension of `data`.
    assert len(names) == data.ndim
    self.names, self.data = names, data

  def __repr__(self) -> str:
    fields = f'names={self.names}, data={self.data}'
    return f'NamedArray({fields})'
|
|
|
|
|
|
|
|
class NamedMapSpec:
  """An in_axes/out_axes spec for vmapping over a NamedArray.

  Either both ``name`` and ``axis`` are None (the value is unmapped), or
  ``name`` is the axis name being mapped and ``axis`` is its position.
  """
  name: str | None
  axis: int | None

  # `name` is annotated Optional to match the class attribute and the
  # both-or-neither invariant below (it was previously annotated `str`).
  def __init__(self, name: str | None, axis: int | None):
    # `name` and `axis` must be both set or both None.
    assert (name is None) == (axis is None)
    self.name = name
    self.axis = axis
|
|
|
|
|
|
|
|
def named_mul(x: NamedArray, y: NamedArray) -> NamedArray:
  """Elementwise product of two NamedArrays with identical axis names.

  Raises:
    ValueError: if ``x.names`` and ``y.names`` differ. ValueError subclasses
      Exception, so callers catching the previous bare Exception still work.
  """
  if x.names != y.names:
    raise ValueError(
        f"named_mul requires matching axis names, got {x.names} and {y.names}")
  return NamedArray(x.names, lax.mul(x.data, y.data))
|
|
|
|
|
|
|
|
# TODO(mattjj): don't make this a pytree
# Flatten a NamedArray to its data (the single child), carrying the names as
# static auxiliary data; unflatten rebuilds it from the two.
register_pytree_node(
    NamedArray,
    lambda arr: ((arr.data,), arr.names),
    lambda names, children: NamedArray(names, children[0]),
)
|
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def named_to_elt(cont: Callable[[Array, int | None], ArrayElt],
                 _: Int, val: NamedArray, spec: NamedMapSpec) -> NamedArray:
  """Strip the mapped named axis from ``val`` to get one element's value.

  Passed to ``batching.register_vmappable`` as the to-element hook for
  NamedArray.

  Args:
    cont: callback invoked as ``cont(val.data, spec.axis)`` to produce the
      element's data.
    _: unused.
    val: the NamedArray being mapped over.
    spec: which named axis to map; ``spec.name is None`` means unmapped.

  Returns:
    ``val`` itself when unmapped; otherwise a NamedArray without the mapped
    name, whose data is ``cont(val.data, spec.axis)``.

  Raises:
    ValueError: if the name at position ``spec.axis`` is not ``spec.name``.
  """
  if spec.name is None:
    return val
  elt_names, mapped_name = list_pop(val.names, spec.axis)
  if mapped_name != spec.name:
    raise ValueError(
        f"expected axis {spec.axis} to be named {spec.name!r}, but it is "
        f"named {mapped_name!r} (names={val.names})")
  return NamedArray(elt_names, cont(val.data, spec.axis))
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def named_from_elt(cont: Callable[[int, ArrayElt, int | None], Array],
                   axis_size: int, elt: NamedArray, annotation: NamedMapSpec
                   ) -> NamedArray:
  """Rebuild a batched NamedArray from one element's result.

  ``cont`` is called with ``axis_size``, the element's data and
  ``annotation.axis``. When ``annotation.axis`` is not None,
  ``annotation.name`` is inserted into the names at that position; otherwise
  the element's names are kept as-is.
  """
  batched = cont(axis_size, elt.data, annotation.axis)
  if annotation.axis is not None:
    out_names = list_insert(elt.names, annotation.axis, annotation.name)
  else:
    out_names = elt.names
  return NamedArray(out_names, batched)
|
|
|
|
|
|
|
|
@contextmanager
def temporarily_register_named_array_vmappable():
  """Register NamedArray with vmap for the duration of a ``with`` block.

  The registration is removed on exit, including when the body raises.
  """
  registration = (NamedArray, NamedMapSpec, int,
                  named_to_elt, named_from_elt, None)
  batching.register_vmappable(*registration)
  try:
    yield
  finally:
    batching.unregister_vmappable(NamedArray)
|
|
|
|
|
|
|
|
a = TypeVar('a')


def list_pop(lst: list[a], idx: int) -> tuple[list[a], a]:
  """Return a copy of ``lst`` without element ``idx``, plus that element.

  The input list is not modified. (Previously annotated ``-> a`` although it
  has always returned a ``(remaining, popped)`` pair.)
  """
  remaining = list(lst)
  popped = remaining.pop(idx)
  return remaining, popped


def list_insert(lst: list[a], idx: int, val: a) -> list[a]:
  """Return a copy of ``lst`` with ``val`` inserted at position ``idx``.

  The input list is not modified.
  """
  out = list(lst)
  out.insert(idx, val)
  return out
|
|
|
|
|
|
|
|
|
|
|
|
class VmappableTest(jtu.JaxTestCase):
  """Tests for vmap/jit over the custom NamedArray type."""

  def test_basic(self):
    with temporarily_register_named_array_vmappable():
      square = lambda arr: named_mul(arr, arr)
      data = jnp.arange(12.).reshape(3, 4)
      g = jax.vmap(square,
                   in_axes=NamedMapSpec('i', 0),
                   out_axes=NamedMapSpec('i', 1),
                   axis_size=3)
      ans = g(NamedArray(['i', 'j'], data))
      # Mapping 'i' in at axis 0 and out at axis 1 moves it last.
      expected = NamedArray(['j', 'i'], data.T ** 2)

      self.assertEqual(ans.names, expected.names)
      self.assertAllClose(ans.data, expected.data)

  def test_basic_jit(self):
    with temporarily_register_named_array_vmappable():
      square = lambda arr: named_mul(arr, arr)
      data = jnp.arange(12.).reshape(3, 4)
      ans = jax.jit(square)(NamedArray(['i', 'j'], data))
      expected = NamedArray(['i', 'j'], data ** 2)

      self.assertEqual(ans.names, expected.names)
      self.assertAllClose(ans.data, expected.data)
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
if __name__ == '__main__':
  # Run the suite through absltest using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|