Replace np -> jnp, onp -> np in more places. (#2973)

* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
This commit is contained in:
Peter Hawkins 2020-05-05 16:40:41 -04:00 committed by GitHub
parent d59ecddfe8
commit b1bc841ae5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 888 additions and 897 deletions

View File

@ -21,13 +21,13 @@ To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2).
from absl import app
import jax
from jax import numpy as np
from jax import numpy as jnp
from jax import pmap
from jax.config import config
from benchmarks import benchmark
import numpy as onp
import numpy as np
def pmap_shard_sharded_device_array_benchmark():
@ -38,9 +38,9 @@ def pmap_shard_sharded_device_array_benchmark():
"""
def get_benchmark_fn(nargs, nshards):
pmap_fn = pmap(lambda *args: np.sum(args))
pmap_fn = pmap(lambda *args: jnp.sum(args))
shape = (nshards, 4)
args = [onp.random.random(shape) for _ in range(nargs)]
args = [np.random.random(shape) for _ in range(nargs)]
sharded_args = pmap(lambda x: x)(args)
assert all(isinstance(arg, jax.pxla.ShardedDeviceArray)
for arg in sharded_args)
@ -68,9 +68,9 @@ def pmap_shard_device_array_benchmark():
"""
def get_benchmark_fn(nargs, nshards):
pmap_fn = pmap(lambda *args: np.sum(args))
pmap_fn = pmap(lambda *args: jnp.sum(args))
shape = (nshards, 4)
args = [np.array(onp.random.random(shape)) for _ in range(nargs)]
args = [jnp.array(np.random.random(shape)) for _ in range(nargs)]
assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)
def benchmark_fn():
for _ in range(10):
@ -96,7 +96,7 @@ def pmap_shard_outputs_benchmark():
def get_benchmark_fn(nouts, nshards):
pmap_fn = pmap(lambda x: [x + i for i in range(nouts)])
shape = (nshards, 4)
arg = onp.random.random(shape)
arg = np.random.random(shape)
def benchmark_fn():
for _ in range(100):
pmap_fn(arg)
@ -118,7 +118,7 @@ def sharded_device_array_indexing_benchmark():
nshards = min(8, jax.local_device_count())
shape = (nshards, 8, 8)
def benchmark_fn():
arr = pmap(lambda x: x)(np.arange(np.prod(shape)).reshape(shape))
arr = pmap(lambda x: x)(jnp.arange(jnp.prod(shape)).reshape(shape))
indices = indices_fn()
for idx in indices:
arr[idx]

View File

@ -375,7 +375,7 @@ class Tracer(object):
"You might have\n"
" import numpy as np\n"
"instead of\n"
" import jax.numpy as np")
" import jax.numpy as jnp")
def __init__(self, trace):
self._trace = trace

View File

@ -107,18 +107,18 @@ class custom_jvp:
For example::
import jax.numpy as np
import jax.numpy as jnp
@jax.custom_jvp
def f(x, y):
return np.sin(x) * y
return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = np.cos(x) * x_dot * y - np.sin(x) * y_dot
tangent_out = jnp.cos(x) * x_dot * y - jnp.sin(x) * y_dot
return primal_out, tangent_out
For a more detailed introduction, see the tutorial_.
@ -150,18 +150,18 @@ class custom_jvp:
Example::
import jax.numpy as np
import jax.numpy as jnp
@jax.custom_jvp
def f(x, y):
return np.sin(x) * y
return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = np.cos(x) * x_dot * y - np.sin(x) * y_dot
tangent_out = jnp.cos(x) * x_dot * y - jnp.sin(x) * y_dot
return primal_out, tangent_out
"""
self.jvp = jvp
@ -184,10 +184,10 @@ class custom_jvp:
@jax.custom_jvp
def f(x, y):
return np.sin(x) * y
return jnp.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: np.cos(x) * x_dot * y,
lambda y_dot, primal_out, x, y: -np.sin(x) * y_dot)
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
lambda y_dot, primal_out, x, y: -jnp.sin(x) * y_dot)
"""
if self.nondiff_argnums:
raise TypeError("Can't use ``defjvps`` with ``nondiff_argnums``.")
@ -370,14 +370,14 @@ class custom_vjp:
For example::
import jax.numpy as np
import jax.numpy as jnp
@jax.custom_vjp
def f(x, y):
return np.sin(x) * y
return jnp.sin(x) * y
def f_fwd(x, y):
return f(x, y), (np.cos(x), np.sin(x), y)
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res
@ -424,14 +424,14 @@ class custom_vjp:
Example::
import jax.numpy as np
import jax.numpy as jnp
@jax.custom_vjp
def f(x, y):
return np.sin(x) * y
return jnp.sin(x) * y
def f_fwd(x, y):
return f(x, y), (np.cos(x), np.sin(x), y)
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res

View File

@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as onp
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import jax.numpy as np
import jax.numpy as jnp
from jax import core
from jax.core import Trace, Tracer, new_master
@ -76,14 +75,14 @@ def _contains_query(vals, query):
if isinstance(query, tuple):
return map(partial(_contains_query, vals), query)
if np.isnan(query):
if np.any(np.isnan(vals)):
if jnp.isnan(query):
if jnp.any(jnp.isnan(vals)):
raise FoundValue('NaN')
elif np.isinf(query):
if np.any(np.isinf(vals)):
elif jnp.isinf(query):
if jnp.any(jnp.isinf(vals)):
raise FoundValue('Found Inf')
elif np.isscalar(query):
if np.any(vals == query):
elif jnp.isscalar(query):
if jnp.any(vals == query):
raise FoundValue(str(query))
else:
raise ValueError('Malformed Query: {}'.format(query))

View File

@ -27,7 +27,7 @@ from functools import partial
import operator as op
import jax
import jax.numpy as np
import jax.numpy as jnp
from jax import core
from jax import lax
from jax import ops
@ -52,12 +52,12 @@ def ravel_first_arg_(unravel, y_flat, *args):
def interp_fit_dopri(y0, y1, k, dt):
# Fit a polynomial to the results of a Runge-Kutta step.
dps_c_mid = np.array([
dps_c_mid = jnp.array([
6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2,
-2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2,
-1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2])
y_mid = y0 + dt * np.dot(dps_c_mid, k)
return np.array(fit_4th_order_polynomial(y0, y1, y_mid, k[0], k[-1], dt))
y_mid = y0 + dt * jnp.dot(dps_c_mid, k)
return jnp.array(fit_4th_order_polynomial(y0, y1, y_mid, k[0], k[-1], dt))
def fit_4th_order_polynomial(y0, y1, y_mid, dy0, dy1, dt):
a = -2.*dt*dy0 + 2.*dt*dy1 - 8.*y0 - 8.*y1 + 16.*y_mid
@ -71,26 +71,26 @@ def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
# Algorithm from:
# E. Hairer, S. P. Norsett G. Wanner,
# Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
scale = atol + np.abs(y0) * rtol
d0 = np.linalg.norm(y0 / scale)
d1 = np.linalg.norm(f0 / scale)
scale = atol + jnp.abs(y0) * rtol
d0 = jnp.linalg.norm(y0 / scale)
d1 = jnp.linalg.norm(f0 / scale)
h0 = np.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)
h0 = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)
y1 = y0 + h0 * f0
f1 = fun(y1, t0 + h0)
d2 = np.linalg.norm((f1 - f0) / scale) / h0
d2 = jnp.linalg.norm((f1 - f0) / scale) / h0
h1 = np.where((d1 <= 1e-15) & (d2 <= 1e-15),
np.maximum(1e-6, h0 * 1e-3),
(0.01 / np.max(d1 + d2)) ** (1. / (order + 1.)))
h1 = jnp.where((d1 <= 1e-15) & (d2 <= 1e-15),
jnp.maximum(1e-6, h0 * 1e-3),
(0.01 / jnp.max(d1 + d2)) ** (1. / (order + 1.)))
return np.minimum(100. * h0, h1)
return jnp.minimum(100. * h0, h1)
def runge_kutta_step(func, y0, f0, t0, dt):
# Dopri5 Butcher tableaux
alpha = np.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
beta = np.array([
alpha = jnp.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
beta = jnp.array([
[1 / 5, 0, 0, 0, 0, 0, 0],
[3 / 40, 9 / 40, 0, 0, 0, 0, 0],
[44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
@ -98,49 +98,49 @@ def runge_kutta_step(func, y0, f0, t0, dt):
[9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]
])
c_sol = np.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0])
c_error = np.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
c_sol = jnp.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0])
c_error = jnp.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
125 / 192 - 451 / 720, -2187 / 6784 - -12231 / 42400,
11 / 84 - 649 / 6300, -1. / 60.])
def body_fun(i, k):
ti = t0 + dt * alpha[i-1]
yi = y0 + dt * np.dot(beta[i-1, :], k)
yi = y0 + dt * jnp.dot(beta[i-1, :], k)
ft = func(yi, ti)
return ops.index_update(k, jax.ops.index[i, :], ft)
k = ops.index_update(np.zeros((7, f0.shape[0])), ops.index[0, :], f0)
k = ops.index_update(jnp.zeros((7, f0.shape[0])), ops.index[0, :], f0)
k = lax.fori_loop(1, 7, body_fun, k)
y1 = dt * np.dot(c_sol, k) + y0
y1_error = dt * np.dot(c_error, k)
y1 = dt * jnp.dot(c_sol, k) + y0
y1_error = dt * jnp.dot(c_error, k)
f1 = k[-1]
return y1, f1, y1_error, k
def error_ratio(error_estimate, rtol, atol, y0, y1):
err_tol = atol + rtol * np.maximum(np.abs(y0), np.abs(y1))
err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
err_ratio = error_estimate / err_tol
return np.mean(err_ratio ** 2)
return jnp.mean(err_ratio ** 2)
def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
dfactor=0.2, order=5.0):
"""Compute optimal Runge-Kutta stepsize."""
mean_error_ratio = np.max(mean_error_ratio)
dfactor = np.where(mean_error_ratio < 1, 1.0, dfactor)
mean_error_ratio = jnp.max(mean_error_ratio)
dfactor = jnp.where(mean_error_ratio < 1, 1.0, dfactor)
err_ratio = np.sqrt(mean_error_ratio)
factor = np.maximum(1.0 / ifactor,
np.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
return np.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)
err_ratio = jnp.sqrt(mean_error_ratio)
factor = jnp.maximum(1.0 / ifactor,
jnp.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
return jnp.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
"""Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.
Args:
func: function to evaluate the time derivative of the solution `y` at time
`t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
y0: array or pytree of arrays representing the initial value for the state.
t: array of float times for evaluation, like `np.linspace(0., 10., 101)`,
t: array of float times for evaluation, like `jnp.linspace(0., 10., 101)`,
in which the values must be strictly increasing.
*args: tuple of additional arguments for `func`, which must be arrays
scalars, or (nested) standard Python containers (tuples, lists, dicts,
@ -189,20 +189,20 @@ def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
old = [i + 1, y, f, t, dt, last_t, interp_coeff]
return map(partial(np.where, np.all(error_ratios <= 1.)), new, old)
return map(partial(jnp.where, jnp.all(error_ratios <= 1.)), new, old)
_, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
_, _, t, _, last_t, interp_coeff = carry
relative_output_time = (target_t - last_t) / (t - last_t)
y_target = np.polyval(interp_coeff, relative_output_time)
y_target = jnp.polyval(interp_coeff, relative_output_time)
return carry, y_target
f0 = func_(y0, ts[0])
dt = initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0)
interp_coeff = np.array([y0] * 5)
interp_coeff = jnp.array([y0] * 5)
init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
_, ys = lax.scan(scan_fun, init_carry, ts[1:])
return np.concatenate((y0[None], ys))
return jnp.concatenate((y0[None], ys))
def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
@ -226,22 +226,22 @@ def _odeint_rev(func, rtol, atol, mxstep, res, g):
def scan_fun(carry, i):
y_bar, t0_bar, args_bar = carry
# Compute effect of moving measurement time
t_bar = np.dot(func(ys[i], ts[i], *args), g[i])
t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i])
t0_bar = t0_bar - t_bar
# Run augmented system backwards to previous observation
_, y_bar, t0_bar, args_bar = odeint(
aug_dynamics, (ys[i], y_bar, t0_bar, args_bar),
np.array([-ts[i], -ts[i - 1]]),
jnp.array([-ts[i], -ts[i - 1]]),
*args, rtol=rtol, atol=atol, mxstep=mxstep)
y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
# Add gradient from current output
y_bar = y_bar + g[i - 1]
return (y_bar, t0_bar, args_bar), t_bar
init_carry = (g[-1], 0., tree_map(np.zeros_like, args))
init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
(y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
scan_fun, init_carry, np.arange(len(ts) - 1, 0, -1))
ts_bar = np.concatenate([np.array([t0_bar]), rev_ts_bar[::-1]])
scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
return (y_bar, ts_bar, *args_bar)
_odeint.defvjp(_odeint_fwd, _odeint_rev)

View File

@ -71,7 +71,7 @@ from collections import namedtuple
import functools
import operator
import jax.numpy as np
import jax.numpy as jnp
from jax.util import partial, safe_zip, safe_map, unzip2
from jax import tree_util
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
@ -208,7 +208,7 @@ def momentum(step_size, mass):
"""
step_size = make_schedule(step_size)
def init(x0):
v0 = np.zeros_like(x0)
v0 = jnp.zeros_like(x0)
return x0, v0
def update(i, g, state):
x, velocity = state
@ -235,7 +235,7 @@ def nesterov(step_size, mass):
"""
step_size = make_schedule(step_size)
def init(x0):
v0 = np.zeros_like(x0)
v0 = jnp.zeros_like(x0)
return x0, v0
def update(i, g, state):
x, velocity = state
@ -266,14 +266,14 @@ def adagrad(step_size, momentum=0.9):
step_size = make_schedule(step_size)
def init(x0):
g_sq = np.zeros_like(x0)
m = np.zeros_like(x0)
g_sq = jnp.zeros_like(x0)
m = jnp.zeros_like(x0)
return x0, g_sq, m
def update(i, g, state):
x, g_sq, m = state
g_sq += g**2
g_sq_inv_sqrt = np.where(g_sq > 0, 1. / np.sqrt(g_sq), 0.0)
g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
x = x - step_size(i) * m
return x, g_sq, m
@ -300,12 +300,12 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):
"""
step_size = make_schedule(step_size)
def init(x0):
avg_sq_grad = np.zeros_like(x0)
avg_sq_grad = jnp.zeros_like(x0)
return x0, avg_sq_grad
def update(i, g, state):
x, avg_sq_grad = state
avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
x = x - step_size(i) * g / np.sqrt(avg_sq_grad + eps)
x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
return x, avg_sq_grad
def get_params(state):
x, _ = state
@ -332,13 +332,13 @@ def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
"""
step_size = make_schedule(step_size)
def init(x0):
avg_sq_grad = np.zeros_like(x0)
mom = np.zeros_like(x0)
avg_sq_grad = jnp.zeros_like(x0)
mom = jnp.zeros_like(x0)
return x0, avg_sq_grad, mom
def update(i, g, state):
x, avg_sq_grad, mom = state
avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
mom = momentum * mom + step_size(i) * g / np.sqrt(avg_sq_grad + eps)
mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
x = x - mom
return x, avg_sq_grad, mom
def get_params(state):
@ -366,8 +366,8 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
"""
step_size = make_schedule(step_size)
def init(x0):
m0 = np.zeros_like(x0)
v0 = np.zeros_like(x0)
m0 = jnp.zeros_like(x0)
v0 = jnp.zeros_like(x0)
return x0, m0, v0
def update(i, g, state):
x, m, v = state
@ -375,7 +375,7 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
v = (1 - b2) * (g ** 2) + b2 * v # Second moment estimate.
mhat = m / (1 - b1 ** (i + 1)) # Bias correction.
vhat = v / (1 - b2 ** (i + 1))
x = x - step_size(i) * mhat / (np.sqrt(vhat) + eps)
x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
return x, m, v
def get_params(state):
x, m, v = state
@ -402,13 +402,13 @@ def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
"""
step_size = make_schedule(step_size)
def init(x0):
m0 = np.zeros_like(x0)
u0 = np.zeros_like(x0)
m0 = jnp.zeros_like(x0)
u0 = jnp.zeros_like(x0)
return x0, m0, u0
def update(i, g, state):
x, m, u = state
m = (1 - b1) * g + b1 * m # First moment estimate.
u = np.maximum(b2 * u, np.abs(g)) # Update exponentially weighted infinity norm.
u = jnp.maximum(b2 * u, jnp.abs(g)) # Update exponentially weighted infinity norm.
x = x - (step_size(i) / (1 - b1 ** (i + 1))) * m / (u + eps)
return x, m, u
def get_params(state):
@ -444,14 +444,14 @@ def sm3(step_size, momentum=0.9):
return x[tuple(idx)]
def init(x0):
vs = [np.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
return x0, np.zeros_like(x0), vs
vs = [jnp.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
return x0, jnp.zeros_like(x0), vs
def update(i, g, state):
x, m, vs = state
vs = [broadcast_into(g.ndim, v, i) for i, v in enumerate(vs)]
accum = functools.reduce(np.minimum, vs) + g ** 2
accum_inv_sqrt = np.where(accum > 0, 1. / np.sqrt(accum), 0)
accum = functools.reduce(jnp.minimum, vs) + g ** 2
accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
x = x - step_size(i) * m
vs = [accum.max(splice(range(x.ndim), j, [])) for j in range(x.ndim)]
@ -479,7 +479,7 @@ def exponential_decay(step_size, decay_steps, decay_rate):
def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
if staircase:
def schedule(i):
return step_size / (1 + decay_rate * np.floor(i / decay_steps))
return step_size / (1 + decay_rate * jnp.floor(i / decay_steps))
else:
def schedule(i):
return step_size / (1 + decay_rate * i / decay_steps)
@ -487,28 +487,28 @@ def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
def polynomial_decay(step_size, decay_steps, final_step_size, power=1.0):
def schedule(step_num):
step_num = np.minimum(step_num, decay_steps)
step_num = jnp.minimum(step_num, decay_steps)
step_mult = (1 - step_num / decay_steps) ** power
return step_mult * (step_size - final_step_size) + final_step_size
return schedule
def piecewise_constant(boundaries, values):
boundaries = np.array(boundaries)
values = np.array(values)
boundaries = jnp.array(boundaries)
values = jnp.array(values)
if not boundaries.ndim == values.ndim == 1:
raise ValueError("boundaries and values must be sequences")
if not boundaries.shape[0] == values.shape[0] - 1:
raise ValueError("boundaries length must be one longer than values length")
def schedule(i):
return values[np.sum(i > boundaries)]
return values[jnp.sum(i > boundaries)]
return schedule
def make_schedule(scalar_or_schedule):
if callable(scalar_or_schedule):
return scalar_or_schedule
elif np.ndim(scalar_or_schedule) == 0:
elif jnp.ndim(scalar_or_schedule) == 0:
return constant(scalar_or_schedule)
else:
raise TypeError(type(scalar_or_schedule))
@ -519,12 +519,12 @@ def make_schedule(scalar_or_schedule):
def l2_norm(tree):
"""Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
leaves, _ = tree_flatten(tree)
return np.sqrt(sum(np.vdot(x, x) for x in leaves))
return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
def clip_grads(grad_tree, max_norm):
"""Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
norm = l2_norm(grad_tree)
normalize = lambda g: np.where(norm < max_norm, g, g * (max_norm / norm))
normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
return tree_map(normalize, grad_tree)

View File

@ -22,11 +22,9 @@ import functools
import itertools
import operator as op
import numpy as onp
from jax import lax
from jax import random
import jax.numpy as np
import jax.numpy as jnp
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
leaky_relu, selu, gelu, normalize)
@ -56,7 +54,7 @@ def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
return np.dot(inputs, W) + b
return jnp.dot(inputs, W) + b
return init_fun, apply_fun
@ -123,7 +121,7 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
"""Layer construction function for a batch normalization layer."""
_beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
_gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
axis = (axis,) if np.isscalar(axis) else axis
axis = (axis,) if jnp.isscalar(axis) else axis
def init_fun(rng, input_shape):
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
k1, k2 = random.split(rng)
@ -131,9 +129,9 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
return input_shape, (beta, gamma)
def apply_fun(params, x, **kwargs):
beta, gamma = params
# TODO(phawkins): np.expand_dims should accept an axis tuple.
# TODO(phawkins): jnp.expand_dims should accept an axis tuple.
# (https://github.com/numpy/numpy/issues/12290)
ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
z = normalize(x, axis, epsilon=epsilon)
if center and scale: return gamma[ed] * z + beta[ed]
if center: return z + beta[ed]
@ -147,9 +145,9 @@ def elementwise(fun, **fun_kwargs):
init_fun = lambda rng, input_shape: (input_shape, ())
apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
return init_fun, apply_fun
Tanh = elementwise(np.tanh)
Tanh = elementwise(jnp.tanh)
Relu = elementwise(relu)
Exp = elementwise(np.exp)
Exp = elementwise(jnp.exp)
LogSoftmax = elementwise(log_softmax, axis=-1)
Softmax = elementwise(softmax, axis=-1)
Softplus = elementwise(softplus)
@ -185,7 +183,7 @@ def _pooling_layer(reducer, init_val, rescaler=None):
return rescale(out, inputs, spec) if rescale else out
return init_fun, apply_fun
return PoolingLayer
MaxPool = _pooling_layer(lax.max, -np.inf)
MaxPool = _pooling_layer(lax.max, -jnp.inf)
SumPool = _pooling_layer(lax.add, 0.)
@ -199,10 +197,10 @@ def _normalize_by_window_size(dims, strides, padding):
spatial_shape = tuple(inputs.shape[i]
for i in range(inputs.ndim)
if i not in non_spatial_axes)
one = np.ones(spatial_shape, dtype=inputs.dtype)
one = jnp.ones(spatial_shape, dtype=inputs.dtype)
window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
for i in sorted(non_spatial_axes):
window_sizes = np.expand_dims(window_sizes, i)
window_sizes = jnp.expand_dims(window_sizes, i)
return outputs / window_sizes
return rescale
@ -215,7 +213,7 @@ def Flatten():
output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
return output_shape, ()
def apply_fun(params, inputs, **kwargs):
return np.reshape(inputs, (inputs.shape[0], -1))
return jnp.reshape(inputs, (inputs.shape[0], -1))
return init_fun, apply_fun
Flatten = Flatten()
@ -251,7 +249,7 @@ def FanInConcat(axis=-1):
out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
return np.concatenate(inputs, axis)
return jnp.concatenate(inputs, axis)
return init_fun, apply_fun
@ -269,7 +267,7 @@ def Dropout(rate, mode='train'):
raise ValueError(msg)
if mode == 'train':
keep = random.bernoulli(rng, rate, inputs.shape)
return np.where(keep, inputs / rate, 0)
return jnp.where(keep, inputs / rate, 0)
else:
return inputs
return init_fun, apply_fun

View File

@ -17,7 +17,7 @@ from .tree_util import tree_flatten, tree_unflatten
from . import linear_util as lu
from .util import safe_zip
import jax.numpy as np
import jax.numpy as jnp
from jax.api import vjp
zip = safe_zip
@ -30,7 +30,7 @@ def ravel_pytree(pytree):
return flat, unravel_pytree
def ravel_list(*lst):
return np.concatenate([np.ravel(elt) for elt in lst]) if lst else np.array([])
return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([])
@lu.transformation_with_aux

View File

@ -15,13 +15,13 @@
"""Shared neural network activations and other functions."""
import numpy as onp
import numpy as np
from jax import custom_jvp
from jax import dtypes
from jax import lax
from jax.scipy.special import expit
import jax.numpy as np
import jax.numpy as jnp
# activations
@ -34,7 +34,7 @@ def relu(x):
.. math::
\mathrm{relu}(x) = \max(x, 0)
"""
return np.maximum(x, 0)
return jnp.maximum(x, 0)
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
def softplus(x):
@ -45,7 +45,7 @@ def softplus(x):
.. math::
\mathrm{softplus}(x) = \log(1 + e^x)
"""
return np.logaddexp(x, 0)
return jnp.logaddexp(x, 0)
def soft_sign(x):
r"""Soft-sign activation function.
@ -55,7 +55,7 @@ def soft_sign(x):
.. math::
\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
"""
return x / (np.abs(x) + 1)
return x / (jnp.abs(x) + 1)
def sigmoid(x):
r"""Sigmoid activation function.
@ -98,8 +98,8 @@ def elu(x, alpha=1.0):
\alpha \exp(x - 1), & x \le 0
\end{cases}
"""
safe_x = np.where(x > 0, 0., x)
return np.where(x > 0, x, alpha * np.expm1(safe_x))
safe_x = jnp.where(x > 0, 0., x)
return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))
def leaky_relu(x, negative_slope=1e-2):
r"""Leaky rectified linear unit activation function.
@ -114,7 +114,7 @@ def leaky_relu(x, negative_slope=1e-2):
where :math:`\alpha` = :code:`negative_slope`.
"""
return np.where(x >= 0, x, negative_slope * x)
return jnp.where(x >= 0, x, negative_slope * x)
def hard_tanh(x):
r"""Hard :math:`\mathrm{tanh}` activation function.
@ -128,7 +128,7 @@ def hard_tanh(x):
1, & 1 < x
\end{cases}
"""
return np.where(x > 1, 1, np.where(x < -1, -1, x))
return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x))
def celu(x, alpha=1.0):
r"""Continuously-differentiable exponential linear unit activation.
@ -144,7 +144,7 @@ def celu(x, alpha=1.0):
For more information, see
`Continuously Differentiable Exponential Linear Units
<https://arxiv.org/pdf/1704.07483.pdf>`_."""
return np.where(x > 0, x, alpha * np.expm1(x / alpha))
return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))
def selu(x):
r"""Scaled exponential linear unit activation.
@ -181,15 +181,15 @@ def gelu(x):
speed. For more information, see `Gaussian Error Linear Units (GELUs)
<https://arxiv.org/abs/1606.08415>`_, section 2.
"""
sqrt_2_over_pi = onp.sqrt(2 / onp.pi).astype(x.dtype)
cdf = 0.5 * (1.0 + np.tanh(sqrt_2_over_pi * (x + 0.044715 * x**3)))
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * x**3)))
return x * cdf
def glu(x, axis=-1):
"""Gated linear unit activation function."""
size = x.shape[axis]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = np.split(x, 2, axis)
x1, x2 = jnp.split(x, 2, axis)
return x1 * sigmoid(x2)
# other functions
@ -209,7 +209,7 @@ def log_softmax(x, axis=-1):
computed. Either an integer or a tuple of integers.
"""
shifted = x - x.max(axis, keepdims=True)
return shifted - np.log(np.sum(np.exp(shifted), axis, keepdims=True))
return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis, keepdims=True))
def softmax(x, axis=-1):
r"""Softmax function.
@ -225,35 +225,35 @@ def softmax(x, axis=-1):
softmax output summed across these dimensions should sum to :math:`1`.
Either an integer or a tuple of integers.
"""
unnormalized = np.exp(x - x.max(axis, keepdims=True))
unnormalized = jnp.exp(x - x.max(axis, keepdims=True))
return unnormalized / unnormalized.sum(axis, keepdims=True)
def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
"""Normalizes an array by subtracting mean and dividing by sqrt(var)."""
if mean is None:
mean = np.mean(x, axis, keepdims=True)
mean = jnp.mean(x, axis, keepdims=True)
if variance is None:
# this definition is traditionally seen as less accurate than np.var's
# this definition is traditionally seen as less accurate than jnp.var's
# mean((x - mean(x))**2) but may be faster and even, given typical
# activation distributions and low-precision arithmetic, more accurate
# when used in neural network normalization layers
variance = np.mean(x**2, axis, keepdims=True) - mean**2
variance = jnp.mean(x**2, axis, keepdims=True) - mean**2
return (x - mean) * lax.rsqrt(variance + epsilon)
def one_hot(x, num_classes, *, dtype=np.float64):
def one_hot(x, num_classes, *, dtype=jnp.float64):
"""One-hot encodes the given indicies.
Each index in the input ``x`` is encoded as a vector of zeros of length
``num_classes`` with the element at ``index`` set to one::
>>> jax.nn.one_hot(np.array([0, 1, 2]), 3)
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
Indicies outside the range [0, num_classes) will be encoded as zeros::
>>> jax.nn.one_hot(np.array([-1, 3]), 3)
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
@ -264,10 +264,10 @@ def one_hot(x, num_classes, *, dtype=np.float64):
jax_enable_x64 is true, otherwise float32).
"""
dtype = dtypes.canonicalize_dtype(dtype)
x = np.asarray(x)
lhs = x[..., np.newaxis]
rhs = lax.broadcast_to_rank(np.arange(num_classes, dtype=x.dtype), lhs.ndim)
return np.array(lhs == rhs, dtype=dtype)
x = jnp.asarray(x)
lhs = x[..., jnp.newaxis]
rhs = lax.broadcast_to_rank(jnp.arange(num_classes, dtype=x.dtype), lhs.ndim)
return jnp.array(lhs == rhs, dtype=dtype)
def relu6(x):
r"""Rectified Linear Unit 6 activation function.
@ -277,7 +277,7 @@ def relu6(x):
.. math::
\mathrm{relu6}(x) = \min(\max(x, 0), 6)
"""
return np.minimum(np.maximum(x, 0), 6.)
return jnp.minimum(jnp.maximum(x, 0), 6.)
def hard_sigmoid(x):
r"""Hard Sigmoid activation function.

View File

@ -20,33 +20,33 @@ used in Keras and Sonnet.
from functools import partial
import numpy as onp
import numpy as np
import jax.numpy as np
import jax.numpy as jnp
from jax import lax
from jax import ops
from jax import random
def zeros(key, shape, dtype=np.float32): return np.zeros(shape, dtype)
def ones(key, shape, dtype=np.float32): return np.ones(shape, dtype)
def zeros(key, shape, dtype=jnp.float32): return jnp.zeros(shape, dtype)
def ones(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype)
def uniform(scale=1e-2, dtype=np.float32):
def uniform(scale=1e-2, dtype=jnp.float32):
def init(key, shape, dtype=dtype):
return random.uniform(key, shape, dtype) * scale
return init
def normal(stddev=1e-2, dtype=np.float32):
def normal(stddev=1e-2, dtype=jnp.float32):
def init(key, shape, dtype=dtype):
return random.normal(key, shape, dtype) * stddev
return init
def _compute_fans(shape, in_axis=-2, out_axis=-1):
receptive_field_size = onp.prod(shape) / shape[in_axis] / shape[out_axis]
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=np.float32):
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float32):
def init(key, shape, dtype=dtype):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
if mode == "fan_in": denominator = fan_in
@ -55,15 +55,15 @@ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=n
else:
raise ValueError(
"invalid mode for variance scaling initializer: {}".format(mode))
variance = np.array(scale / denominator, dtype=dtype)
variance = jnp.array(scale / denominator, dtype=dtype)
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
stddev = np.sqrt(variance) / np.array(.87962566103423978, dtype)
stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
elif distribution == "normal":
return random.normal(key, shape, dtype) * np.sqrt(variance)
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
elif distribution == "uniform":
return random.uniform(key, shape, dtype, -1) * onp.sqrt(3 * variance)
return random.uniform(key, shape, dtype, -1) * np.sqrt(3 * variance)
else:
raise ValueError("invalid distribution for variance scaling initializer")
return init
@ -75,7 +75,7 @@ lecun_normal = partial(variance_scaling, 1.0, "fan_in", "truncated_normal")
kaiming_uniform = he_uniform = partial(variance_scaling, 2.0, "fan_in", "uniform")
kaiming_normal = he_normal = partial(variance_scaling, 2.0, "fan_in", "truncated_normal")
def orthogonal(scale=1.0, column_axis=-1, dtype=np.float32):
def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
"""
Construct an initializer for uniformly distributed orthogonal matrices.
@ -85,20 +85,20 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=np.float32):
def init(key, shape, dtype=dtype):
if len(shape) < 2:
raise ValueError("orthogonal initializer requires at least a 2D shape")
n_rows, n_cols = onp.prod(shape) // shape[column_axis], shape[column_axis]
n_rows, n_cols = np.prod(shape) // shape[column_axis], shape[column_axis]
matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
A = random.normal(key, matrix_shape, dtype)
Q, R = np.linalg.qr(A)
diag_sign = lax.broadcast_to_rank(np.sign(np.diag(R)), rank=Q.ndim)
Q, R = jnp.linalg.qr(A)
diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
Q *= diag_sign # needed for a uniform distribution
if n_rows < n_cols: Q = Q.T
Q = np.reshape(Q, tuple(onp.delete(shape, column_axis)) + (shape[column_axis],))
Q = np.moveaxis(Q, -1, column_axis)
Q = jnp.reshape(Q, tuple(np.delete(shape, column_axis)) + (shape[column_axis],))
Q = jnp.moveaxis(Q, -1, column_axis)
return scale * Q
return init
def delta_orthogonal(scale=1.0, column_axis=-1, dtype=np.float32):
def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
"""
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
@ -112,7 +112,7 @@ def delta_orthogonal(scale=1.0, column_axis=-1, dtype=np.float32):
raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
ortho_init = orthogonal(scale=scale, column_axis=column_axis, dtype=dtype)
ortho_matrix = ortho_init(key, shape[-2:])
W = np.zeros(shape, dtype=dtype)
W = jnp.zeros(shape, dtype=dtype)
if len(shape) == 3:
k = shape[0]
return ops.index_update(W, ops.index[(k-1)//2, ...], ortho_matrix)

View File

@ -26,7 +26,7 @@ from .. import ops as jaxops
def _fft_core(func_name, fft_type, a, s, axes, norm):
# TODO(skye): implement padding/cropping based on 's'.
full_name = "jax.np.fft." + func_name
full_name = "jax.numpy.fft." + func_name
if s is not None:
raise NotImplementedError("%s only supports s=None, got %s" % (full_name, s))
if norm is not None:
@ -93,7 +93,7 @@ def irfftn(a, s=None, axes=None, norm=None):
def _fft_core_1d(func_name, fft_type, a, s, axis, norm):
full_name = "jax.np.fft." + func_name
full_name = "jax.numpy.fft." + func_name
if isinstance(axis, (list, tuple)):
raise ValueError(
"%s does not support multiple axes. Please use %sn. "
@ -125,7 +125,7 @@ def irfft(a, n=None, axis=-1, norm=None):
def _fft_core_2d(func_name, fft_type, a, s, axes, norm):
full_name = "jax.np.fft." + func_name
full_name = "jax.numpy.fft." + func_name
if len(axes) != 2:
raise ValueError(
"%s only supports 2 axes. Got axes = %r."
@ -159,12 +159,12 @@ def irfft2(a, s=None, axes=(-2,-1), norm=None):
def fftfreq(n, d=1.0):
if isinstance(n, list) or isinstance(n, tuple):
raise ValueError(
"The n argument of jax.np.fft.fftfreq only takes an int. "
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
"Got n = %s." % list(n))
elif isinstance(d, list) or isinstance(d, tuple):
raise ValueError(
"The d argument of jax.np.fft.fftfreq only takes a single value. "
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
"Got d = %s." % list(d))
k = np.zeros(n)
@ -191,12 +191,12 @@ def fftfreq(n, d=1.0):
def rfftfreq(n, d=1.0):
if isinstance(n, list) or isinstance(n, tuple):
raise ValueError(
"The n argument of jax.np.fft.rfftfreq only takes an int. "
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
"Got n = %s." % list(n))
elif isinstance(d, list) or isinstance(d, tuple):
raise ValueError(
"The d argument of jax.np.fft.rfftfreq only takes a single value. "
"The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
"Got d = %s." % list(d))
if n % 2 == 0:

View File

@ -28,10 +28,10 @@ See tensorflow/compiler/xla/service/hlo_runner.h.
Usage:
$ cat prog.py
import jax.numpy as np
import jax.numpy as jnp
def fn(x, y, z):
return np.dot(x, y) / z
return jnp.dot(x, y) / z
$ python jax_to_hlo.py \
--fn prog.fn \
@ -48,7 +48,7 @@ The order of elements in input_shapes determines the order of parameters in the
resulting HLO program.
Values of `constants` which are lists are converted to Numpy arrays using
np.asarray. In addition, you can specify constants using the flag
jnp.asarray. In addition, you can specify constants using the flag
--evaled_constants; values there that are strings are first evaluated using
ast.literal_eval. --evaled_constants is primarily useful for genrules; Skylark
doesn't support floating-point types, so genrules need to deal in strings.
@ -71,7 +71,7 @@ import functools
from absl import app
from absl import flags
import jax.api
import jax.numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
FLAGS = flags.FLAGS
@ -119,7 +119,7 @@ def jax_to_hlo(fn, input_shapes, constants=None):
raise ValueError('Shape %s has a non-default layout, but only '
'the default layout is allowed.' % str(shape))
args.append(np.zeros(shape.dimensions(), dtype=shape.numpy_dtype()))
args.append(jnp.zeros(shape.dimensions(), dtype=shape.numpy_dtype()))
# Curry `constants` into the function.
fn_curried = functools.partial(fn, **constants)
@ -153,14 +153,14 @@ def main(argv):
constants = {}
for k, v in literal_eval(FLAGS.constants).items():
if isinstance(v, list):
v = np.asarray(v)
v = jnp.asarray(v)
constants[k] = v
for k, v in literal_eval(FLAGS.evaled_constants).items():
if isinstance(v, str):
v = literal_eval(v)
if isinstance(v, list):
v = np.asarray(v)
v = jnp.asarray(v)
if k in constants:
raise ValueError(
'Argument appears in both --constants and --evaled_constants: %s' % k)

View File

@ -180,7 +180,7 @@ class APITest(jtu.JaxTestCase):
jtu.check_raises(lambda: grad(f)(np.zeros(3)), Exception,
"Tracer can't be used with raw numpy functions. "
"You might have\n import numpy as np\ninstead of\n"
" import jax.numpy as np")
" import jax.numpy as jnp")
def test_binop_mismatch(self):
def f(x, y):

View File

@ -18,12 +18,12 @@ import gc
import itertools as it
import operator
import numpy as onp
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
from jax import core
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
@ -35,24 +35,24 @@ from jax.interpreters import partial_eval as pe
from jax.config import config
config.parse_flags_with_absl()
_ = pe.PartialVal.unknown(UnshapedArray(onp.float32))
__ = pe.PartialVal.unknown(ShapedArray((), onp.float32))
_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))
def call(f, *args):
return jit(f)(*args)
def simple_fun(x, y):
return np.sin(x * y)
return jnp.sin(x * y)
def simple_fun_fanout(x, y):
return np.sin(x * y) * x
return jnp.sin(x * y) * x
def fun_with_call(x):
return call(np.sin, x)
return call(jnp.sin, x)
def fun_with_nested_calls(x):
def f(y):
y2 = np.sin(y) + 1.0 + (2.0 * x)
y2 = jnp.sin(y) + 1.0 + (2.0 * x)
@jit
def g(z):
@ -73,7 +73,7 @@ def fun_with_nested_calls_2(x):
q = call(lambda x: y, x)
q = q + call(lambda: y)
q = q + call(lambda y: w + y, y)
q = call(lambda w: call(np.sin, x) * y, 1.0) + q
q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q
return q
p, t = jvp(baz, (x + 1.0,), (y,))
return t + (x * p)
@ -87,22 +87,22 @@ def fun_call_jitted(x):
return call(g, x)
def fun_with_two_calls(x):
return call(np.sin, x) + call(np.cos, x)
return call(jnp.sin, x) + call(jnp.cos, x)
def fun_with_call_closure(x):
def foo(y, z):
return (x * x) * np.sin(y) * z
return (x * x) * jnp.sin(y) * z
return call(foo, x, np.cos(x)) + x
return call(foo, x, jnp.cos(x)) + x
def product_io_fun(x, y):
xa = x['a']
xb = x['b']
y1, (y2, y3) = y
return np.sin(xa + y2), [xb, (y1, y3)]
return jnp.sin(xa + y2), [xb, (y1, y3)]
R = onp.random.randn
R = np.random.randn
CallSpec = namedtuple('CallSpec', ['fun', 'args'])
test_specs_base = [
CallSpec(simple_fun, (R(3, 2), R(3, 2))),
@ -173,12 +173,12 @@ class CoreTest(jtu.JaxTestCase):
@parameterized.parameters(test_specs)
def test_jvp(self, f, args):
jtu.check_jvp(f, partial(jvp, f), args, rtol={onp.float32: 3e-2})
jtu.check_jvp(f, partial(jvp, f), args, rtol={np.float32: 3e-2})
def test_jvp_zeros(self):
def foo(x):
def bar(y):
return np.sin(x * y)
return jnp.sin(x * y)
return jvp(bar, (3 * x,), (2 * x,))
jtu.check_eq(jit(foo)(0.5), foo(0.5))
@ -186,18 +186,18 @@ class CoreTest(jtu.JaxTestCase):
@parameterized.parameters(test_specs)
def test_jvp_linearized(self, f, args):
jtu.check_jvp(f, partial(jvp_unlinearized, f), args,
rtol={onp.float32: 3e-2})
rtol={np.float32: 3e-2})
@parameterized.parameters(test_specs)
def test_vjp(self, f, args):
jtu.check_vjp(f, partial(vjp, f), args,
rtol={onp.float32: 3e-1, onp.float64: 1e-5},
atol={onp.float32: 1e-2, onp.float64: 1e-5})
rtol={np.float32: 3e-1, np.float64: 1e-5},
atol={np.float32: 1e-2, np.float64: 1e-5})
def test_jvp_closure(self):
def foo(x):
def bar(y):
return np.multiply(x, y)
return jnp.multiply(x, y)
return jvp(bar, (3.0,), (1.0,))[1]
ans = jvp(foo, (1.0,), (2.0,))
assert ans == (1.0, 2.0), ans
@ -220,15 +220,15 @@ class CoreTest(jtu.JaxTestCase):
foo2 = jit(foo)
foo3 = jit(foo2)
x1, y1 = onp.array(1.0), onp.array(2.0)
x1, y1 = np.array(1.0), np.array(2.0)
assert foo(x1) == y1
assert foo2(x1) == y1
assert foo3(x1) == y1
x2, y2 = onp.array([1.0, 2.0]), onp.array([3.0, 4.0])
assert onp.all(foo(x2) == y2)
assert onp.all(foo2(x2) == y2)
assert onp.all(foo3(x2) == y2)
x2, y2 = np.array([1.0, 2.0]), np.array([3.0, 4.0])
assert np.all(foo(x2) == y2)
assert np.all(foo2(x2) == y2)
assert np.all(foo3(x2) == y2)
def test_product_jit(self):
def foo(x, tup):
@ -247,7 +247,7 @@ class CoreTest(jtu.JaxTestCase):
assert foo3(*args) == foo(*args)
def test_jvp_2(self):
d_sin = fwd_deriv(np.sin)
d_sin = fwd_deriv(jnp.sin)
d2_sin = fwd_deriv(d_sin)
d3_sin = fwd_deriv(d2_sin)
@ -262,7 +262,7 @@ class CoreTest(jtu.JaxTestCase):
return x.sum()
fn = partial(linearize, f)
params = np.zeros([])
params = jnp.zeros([])
debug = gc.get_debug()
try:

View File

@ -17,11 +17,9 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
import jax
from jax import test_util as jtu
from jax import numpy as np
from jax import numpy as jnp
from jax.config import config
config.parse_flags_with_absl()
@ -36,18 +34,18 @@ class DebugNaNsTest(jtu.JaxTestCase):
config.update("jax_debug_nans", self.cfg)
def testSingleResultPrimitiveNoNaN(self):
A = np.array([[1., 2.], [2., 3.]])
B = np.tanh(A)
A = jnp.array([[1., 2.], [2., 3.]])
B = jnp.tanh(A)
def testMultipleResultPrimitiveNoNaN(self):
A = np.array([[1., 2.], [2., 3.]])
D, V = np.linalg.eig(A)
A = jnp.array([[1., 2.], [2., 3.]])
D, V = jnp.linalg.eig(A)
def testJitComputationNoNaN(self):
A = np.array([[1., 2.], [2., 3.]])
B = jax.jit(np.tanh)(A)
A = jnp.array([[1., 2.], [2., 3.]])
B = jax.jit(jnp.tanh)(A)
def testSingleResultPrimitiveNaN(self):
A = np.array(0.)
A = jnp.array(0.)
with self.assertRaises(FloatingPointError):
B = 0. / A

View File

@ -21,12 +21,12 @@ import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
import numpy as np
import jax
from jax import core
from jax import dtypes
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu
from jax.interpreters import xla
@ -34,41 +34,41 @@ from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
bool_dtypes = [onp.dtype('bool')]
bool_dtypes = [np.dtype('bool')]
signed_dtypes = [onp.dtype('int8'), onp.dtype('int16'), onp.dtype('int32'),
onp.dtype('int64')]
signed_dtypes = [np.dtype('int8'), np.dtype('int16'), np.dtype('int32'),
np.dtype('int64')]
unsigned_dtypes = [onp.dtype('uint8'), onp.dtype('uint16'), onp.dtype('uint32'),
onp.dtype('uint64')]
unsigned_dtypes = [np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
np.dtype('uint64')]
onp_float_dtypes = [onp.dtype('float16'), onp.dtype('float32'),
onp.dtype('float64')]
np_float_dtypes = [np.dtype('float16'), np.dtype('float32'),
np.dtype('float64')]
float_dtypes = [onp.dtype(dtypes.bfloat16)] + onp_float_dtypes
float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes
complex_dtypes = [onp.dtype('complex64'), onp.dtype('complex128')]
complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')]
all_dtypes = (bool_dtypes + signed_dtypes + unsigned_dtypes + float_dtypes +
complex_dtypes)
scalar_types = [np.bool_, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.bfloat16, np.float16, np.float32, np.float64,
np.complex64, np.complex128]
scalar_types = [jnp.bool_, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
jnp.complex64, jnp.complex128]
class DtypesTest(jtu.JaxTestCase):
@parameterized.named_parameters(
{"testcase_name": "_type={}".format(type.__name__), "type": type,
"dtype": dtype}
for type, dtype in [(bool, np.bool_), (int, np.int_), (float, np.float_),
(complex, np.complex_)])
for type, dtype in [(bool, jnp.bool_), (int, jnp.int_), (float, jnp.float_),
(complex, jnp.complex_)])
def testDefaultTypes(self, type, dtype):
for f in [np.array, jax.jit(np.array), jax.jit(lambda x: x)]:
for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]:
y = f(type(0))
self.assertTrue(isinstance(y, np.ndarray), msg=(f, y))
self.assertTrue(isinstance(y, jnp.ndarray), msg=(f, y))
self.assertEqual(y.dtype, dtypes.canonicalize_dtype(dtype), msg=(f, y))
@parameterized.named_parameters(
@ -78,51 +78,51 @@ class DtypesTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # F16 not supported on TPU
def testBinaryPromotion(self, swap, jit):
testcases = [
(np.array(1.), 0., np.float_),
(np.array(1.), np.array(0.), np.float_),
(np.array(1.), np.array(0., dtype=np.float16), np.float_),
(np.array(1.), np.array(0., dtype=np.float32), np.float_),
(np.array(1.), np.array(0., dtype=np.float64), np.float64),
(np.array(1., dtype=np.float16), 0., np.float16),
(np.array(1., dtype=np.float32), 0., np.float32),
(np.array(1., dtype=np.float64), 0., np.float64),
(np.array(1., dtype=np.float16), np.array(0., dtype=np.float16), np.float16),
(np.array(1., dtype=np.float16), np.array(0., dtype=np.float32), np.float32),
(np.array(1., dtype=np.float16), np.array(0., dtype=np.float64), np.float64),
(np.array(1., dtype=np.float32), np.array(0., dtype=np.float32), np.float32),
(np.array(1., dtype=np.float32), np.array(0., dtype=np.float64), np.float64),
(np.array(1., dtype=np.float64), np.array(0., dtype=np.float64), np.float64),
(np.array([1.]), 0., np.float_),
(np.array([1.]), np.array(0.), np.float_),
(np.array([1.]), np.array(0., dtype=np.float16), np.float_),
(np.array([1.]), np.array(0., dtype=np.float32), np.float_),
(np.array([1.]), np.array(0., dtype=np.float64), np.float64),
(np.array([1.], dtype=np.float32), np.array(0., dtype=np.float16), np.float32),
(np.array([1.], dtype=np.float16), np.array(0., dtype=np.float32), np.float32),
(np.array([1.], dtype=np.float16), 0., np.float16),
(jnp.array(1.), 0., jnp.float_),
(jnp.array(1.), jnp.array(0.), jnp.float_),
(jnp.array(1.), jnp.array(0., dtype=jnp.float16), jnp.float_),
(jnp.array(1.), jnp.array(0., dtype=jnp.float32), jnp.float_),
(jnp.array(1.), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array(1., dtype=jnp.float16), 0., jnp.float16),
(jnp.array(1., dtype=jnp.float32), 0., jnp.float32),
(jnp.array(1., dtype=jnp.float64), 0., jnp.float64),
(jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float16), jnp.float16),
(jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
(jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float32), jnp.float32),
(jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array(1., dtype=jnp.float64), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array([1.]), 0., jnp.float_),
(jnp.array([1.]), jnp.array(0.), jnp.float_),
(jnp.array([1.]), jnp.array(0., dtype=jnp.float16), jnp.float_),
(jnp.array([1.]), jnp.array(0., dtype=jnp.float32), jnp.float_),
(jnp.array([1.]), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array([1.], dtype=jnp.float32), jnp.array(0., dtype=jnp.float16), jnp.float32),
(jnp.array([1.], dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
(jnp.array([1.], dtype=jnp.float16), 0., jnp.float16),
]
op = jax.jit(operator.add) if jit else operator.add
for x, y, dtype in testcases:
x, y = (y, x) if swap else (x, y)
z = x + y
self.assertTrue(isinstance(z, np.ndarray), msg=(x, y, z))
self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))
def testPromoteDtypes(self):
for t1 in all_dtypes:
self.assertEqual(t1, dtypes.promote_types(t1, t1))
self.assertEqual(t1, dtypes.promote_types(t1, onp.bool_))
self.assertEqual(onp.dtype(onp.complex128),
dtypes.promote_types(t1, onp.complex128))
self.assertEqual(t1, dtypes.promote_types(t1, np.bool_))
self.assertEqual(np.dtype(np.complex128),
dtypes.promote_types(t1, np.complex128))
for t2 in all_dtypes:
# Symmetry
self.assertEqual(dtypes.promote_types(t1, t2),
dtypes.promote_types(t2, t1))
self.assertEqual(onp.dtype(onp.float32),
dtypes.promote_types(onp.float16, dtypes.bfloat16))
self.assertEqual(np.dtype(np.float32),
dtypes.promote_types(np.float16, dtypes.bfloat16))
# Promotions of non-inexact types against inexact types always prefer
# the inexact types.
@ -132,47 +132,47 @@ class DtypesTest(jtu.JaxTestCase):
# Promotions between exact types, or between inexact types, match NumPy.
for groups in [bool_dtypes + signed_dtypes + unsigned_dtypes,
onp_float_dtypes + complex_dtypes]:
np_float_dtypes + complex_dtypes]:
for t1, t2 in itertools.combinations(groups, 2):
self.assertEqual(onp.promote_types(t1, t2),
self.assertEqual(np.promote_types(t1, t2),
dtypes.promote_types(t1, t2))
def testScalarInstantiation(self):
for t in [np.bool_, np.int32, np.bfloat16, np.float32, np.complex64]:
for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]:
a = t(1)
self.assertEqual(a.dtype, np.dtype(t))
self.assertEqual(a.dtype, jnp.dtype(t))
self.assertIsInstance(a, xla.DeviceArray)
self.assertEqual(0, np.ndim(a))
self.assertEqual(0, jnp.ndim(a))
def testIsSubdtype(self):
for t in scalar_types:
self.assertTrue(dtypes.issubdtype(t, t))
self.assertTrue(dtypes.issubdtype(onp.dtype(t).type, t))
self.assertTrue(dtypes.issubdtype(t, onp.dtype(t).type))
if t != np.bfloat16:
for category in [onp.generic, np.inexact, np.integer, np.signedinteger,
np.unsignedinteger, np.floating, np.complexfloating]:
self.assertTrue(dtypes.issubdtype(np.dtype(t).type, t))
self.assertTrue(dtypes.issubdtype(t, np.dtype(t).type))
if t != jnp.bfloat16:
for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger,
jnp.unsignedinteger, jnp.floating, jnp.complexfloating]:
self.assertEqual(dtypes.issubdtype(t, category),
onp.issubdtype(onp.dtype(t).type, category))
np.issubdtype(np.dtype(t).type, category))
self.assertEqual(dtypes.issubdtype(t, category),
onp.issubdtype(onp.dtype(t).type, category))
np.issubdtype(np.dtype(t).type, category))
def testArrayCasts(self):
for t in [np.bool_, np.int32, np.bfloat16, np.float32, np.complex64]:
a = onp.array([1, 2.5, -3.7])
self.assertEqual(a.astype(t).dtype, np.dtype(t))
self.assertEqual(np.array(a).astype(t).dtype, np.dtype(t))
for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]:
a = np.array([1, 2.5, -3.7])
self.assertEqual(a.astype(t).dtype, jnp.dtype(t))
self.assertEqual(jnp.array(a).astype(t).dtype, jnp.dtype(t))
def testEnumPromotion(self):
class AnEnum(enum.IntEnum):
A = 42
B = 101
onp.testing.assert_equal(onp.array(42), onp.array(AnEnum.A))
np.testing.assert_equal(np.array(42), np.array(AnEnum.A))
with core.skipping_checks():
# Passing AnEnum.A to np.array fails the type check in bind
onp.testing.assert_equal(np.array(42), np.array(AnEnum.A))
onp.testing.assert_equal(onp.int32(101), onp.int32(AnEnum.B))
onp.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
# Passing AnEnum.A to jnp.array fails the type check in bind
np.testing.assert_equal(jnp.array(42), jnp.array(AnEnum.A))
np.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
np.testing.assert_equal(jnp.int32(101), jnp.int32(AnEnum.B))
if __name__ == "__main__":
absltest.main()

View File

@ -16,26 +16,26 @@
import itertools
import unittest
import numpy as onp
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
from jax import lax
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()
float_dtypes = [onp.float32, onp.float64]
# TODO(b/144573940): onp.complex128 isn't supported by XLA, and the JAX
float_dtypes = [np.float32, np.float64]
# TODO(b/144573940): np.complex128 isn't supported by XLA, and the JAX
# implementation casts to complex64.
complex_dtypes = [onp.complex64]
complex_dtypes = [np.complex64]
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = [onp.int32, onp.int64]
bool_dtypes = [onp.bool_]
int_dtypes = [np.int32, np.int64]
bool_dtypes = [np.bool_]
real_dtypes = float_dtypes + int_dtypes + bool_dtypes
all_dtypes = real_dtypes + complex_dtypes
@ -84,7 +84,7 @@ def _zero_for_irfft(z, axes):
else:
parts = [lax.slice_in_dim(z.real, 0, 1, axis=axis).real,
lax.slice_in_dim(z.real, 1, size, axis=axis)]
return np.concatenate(parts, axis=axis)
return jnp.concatenate(parts, axis=axis)
class FftTest(jtu.JaxTestCase):
@ -103,19 +103,19 @@ class FftTest(jtu.JaxTestCase):
def testFftn(self, inverse, real, shape, dtype, axes, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_op = _get_fftn_func(jnp.fft, inverse, real)
np_op = _get_fftn_func(np.fft, inverse, real)
onp_op = _get_fftn_func(onp.fft, inverse, real)
np_fn = lambda a: np_op(a, axes=axes)
onp_fn = lambda a: onp_op(a, axes=axes) if axes is None or axes else a
jnp_fn = lambda a: jnp_op(a, axes=axes)
np_fn = lambda a: np_op(a, axes=axes) if axes is None or axes else a
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
# Test gradient for differentiable types.
if dtype in (float_dtypes if real and not inverse else inexact_dtypes):
# TODO(skye): can we be more precise?
tol = 0.15
jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol)
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
@ -129,20 +129,20 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
func = _get_fftn_func(np.fft, inverse, real)
func = _get_fftn_func(jnp.fft, inverse, real)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} only supports 1D, 2D, and 3D FFTs. "
"jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
"Got axes None with input rank 4.".format(name),
lambda: func(rng([2, 3, 4, 5], dtype=onp.float64), axes=None))
lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} does not support repeated axes. Got axes \\[1, 1\\].".format(name),
lambda: func(rng([2, 3], dtype=onp.float64), axes=[1, 1]))
"jax.numpy.fft.{} does not support repeated axes. Got axes \\[1, 1\\].".format(name),
lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1]))
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[2]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2]))
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[-3]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3]))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_shape={}_axis={}".format(
@ -163,14 +163,14 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
jnp_op = getattr(jnp.fft, name)
np_op = getattr(np.fft, name)
onp_op = getattr(onp.fft, name)
jnp_fn = lambda a: jnp_op(a, axis=axis)
np_fn = lambda a: np_op(a, axis=axis)
onp_fn = lambda a: onp_op(a, axis=axis)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_op, np_op, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
@ -184,26 +184,26 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
func = getattr(np.fft, name)
func = getattr(jnp.fft, name)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} does not support multiple axes. "
"Please use jax.np.fft.{}n. "
"jax.numpy.fft.{} does not support multiple axes. "
"Please use jax.numpy.fft.{}n. "
"Got axis = \\[1, 1\\].".format(name, name),
lambda: func(rng([2, 3], dtype=onp.float64), axis=[1, 1])
lambda: func(rng([2, 3], dtype=np.float64), axis=[1, 1])
)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} does not support multiple axes. "
"Please use jax.np.fft.{}n. "
"jax.numpy.fft.{} does not support multiple axes. "
"Please use jax.numpy.fft.{}n. "
"Got axis = \\(1, 1\\).".format(name, name),
lambda: func(rng([2, 3], dtype=onp.float64), axis=(1, 1))
lambda: func(rng([2, 3], dtype=np.float64), axis=(1, 1))
)
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axis=[2]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[2]))
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axis=[-3]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[-3]))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format(
@ -224,12 +224,12 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
jnp_op = getattr(jnp.fft, name)
np_op = getattr(np.fft, name)
onp_op = getattr(onp.fft, name)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_op, np_op, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
@ -243,24 +243,24 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
func = getattr(np.fft, name)
func = getattr(jnp.fft, name)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} only supports 2 axes. "
"jax.numpy.fft.{} only supports 2 axes. "
"Got axes = \\[0\\].".format(name),
lambda: func(rng([2, 3], dtype=onp.float64), axes=[0])
lambda: func(rng([2, 3], dtype=np.float64), axes=[0])
)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} only supports 2 axes. "
"jax.numpy.fft.{} only supports 2 axes. "
"Got axes = \\(0, 1, 2\\).".format(name),
lambda: func(rng([2, 3, 3], dtype=onp.float64), axes=(0, 1, 2))
lambda: func(rng([2, 3, 3], dtype=np.float64), axes=(0, 1, 2))
)
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[2, 3]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2, 3]))
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[-3, -4]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4]))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_size={}_d={}".format(
@ -273,18 +273,18 @@ class FftTest(jtu.JaxTestCase):
def testFftfreq(self, size, d, dtype, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: (rng([size], dtype),)
jnp_op = jnp.fft.fftfreq
np_op = np.fft.fftfreq
onp_op = onp.fft.fftfreq
jnp_fn = lambda a: jnp_op(size, d=d)
np_fn = lambda a: np_op(size, d=d)
onp_fn = lambda a: onp_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
# Test gradient for differentiable types.
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol)
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}".format(n),
@ -292,16 +292,16 @@ class FftTest(jtu.JaxTestCase):
for n in [[0,1,2]]))
def testFftfreqErrors(self, n):
name = 'fftfreq'
func = np.fft.fftfreq
func = jnp.fft.fftfreq
self.assertRaisesRegex(
ValueError,
"The n argument of jax.np.fft.{} only takes an int. "
"The n argument of jax.numpy.fft.{} only takes an int. "
"Got n = \\[0, 1, 2\\].".format(name),
lambda: func(n=n)
)
self.assertRaisesRegex(
ValueError,
"The d argument of jax.np.fft.{} only takes a single value. "
"The d argument of jax.numpy.fft.{} only takes a single value. "
"Got d = \\[0, 1, 2\\].".format(name),
lambda: func(n=10, d=n)
)
@ -317,18 +317,18 @@ class FftTest(jtu.JaxTestCase):
def testRfftfreq(self, size, d, dtype, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: (rng([size], dtype),)
jnp_op = jnp.fft.rfftfreq
np_op = np.fft.rfftfreq
onp_op = onp.fft.rfftfreq
jnp_fn = lambda a: jnp_op(size, d=d)
np_fn = lambda a: np_op(size, d=d)
onp_fn = lambda a: onp_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
# Test gradient for differentiable types.
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol)
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}".format(n),
@ -336,16 +336,16 @@ class FftTest(jtu.JaxTestCase):
for n in [[0, 1, 2]]))
def testRfftfreqErrors(self, n):
name = 'rfftfreq'
func = np.fft.rfftfreq
func = jnp.fft.rfftfreq
self.assertRaisesRegex(
ValueError,
"The n argument of jax.np.fft.{} only takes an int. "
"The n argument of jax.numpy.fft.{} only takes an int. "
"Got n = \\[0, 1, 2\\].".format(name),
lambda: func(n=n)
)
self.assertRaisesRegex(
ValueError,
"The d argument of jax.np.fft.{} only takes a single value. "
"The d argument of jax.numpy.fft.{} only takes a single value. "
"Got d = \\[0, 1, 2\\].".format(name),
lambda: func(n=10, d=n)
)
@ -361,9 +361,9 @@ class FftTest(jtu.JaxTestCase):
def testFftshift(self, shape, dtype, rng_factory, axes):
rng = rng_factory(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes)
np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
onp_fn = lambda arg: onp.fft.fftshift(arg, axes=axes)
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "dtype={}_axes={}".format(
@ -376,9 +376,9 @@ class FftTest(jtu.JaxTestCase):
def testIfftshift(self, shape, dtype, rng_factory, axes):
rng = rng_factory(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
onp_fn = lambda arg: onp.fft.ifftshift(arg, axes=axes)
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)
if __name__ == "__main__":
absltest.main()

View File

@ -16,14 +16,12 @@ from functools import partial
import itertools
import unittest
import numpy as onp
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config
@ -41,9 +39,9 @@ class VectorizeTest(jtu.JaxTestCase):
((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
]))
def test_matmat(self, left_shape, right_shape, result_shape):
matmat = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,k)')
self.assertEqual(matmat(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
self.assertEqual(matmat(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
@ -55,9 +53,9 @@ class VectorizeTest(jtu.JaxTestCase):
((5, 4, 2, 3), (1, 3), (5, 4, 2)),
]))
def test_matvec(self, left_shape, right_shape, result_shape):
matvec = np.vectorize(np.dot, signature='(n,m),(m)->(n)')
self.assertEqual(matvec(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
matvec = jnp.vectorize(jnp.dot, signature='(n,m),(m)->(n)')
self.assertEqual(matvec(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
@ -68,9 +66,9 @@ class VectorizeTest(jtu.JaxTestCase):
((4, 2, 3), (3,), (4, 2)),
]))
def test_vecmat(self, left_shape, right_shape, result_shape):
vecvec = np.vectorize(np.dot, signature='(m),(m)->()')
self.assertEqual(vecvec(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
vecvec = jnp.vectorize(jnp.dot, signature='(m),(m)->()')
self.assertEqual(vecvec(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -84,11 +82,11 @@ class VectorizeTest(jtu.JaxTestCase):
size = 1
for x in shape:
size *= x
inputs = np.arange(size).reshape(shape)
inputs = jnp.arange(size).reshape(shape)
@partial(np.vectorize, signature='(n)->()')
@partial(jnp.vectorize, signature='(n)->()')
def magnitude(x):
return np.dot(x, x)
return jnp.dot(x, x)
self.assertEqual(magnitude(inputs).shape, result_shape)
@ -101,8 +99,8 @@ class VectorizeTest(jtu.JaxTestCase):
((1, 2, 3, 4), (1, 2, 3)),
]))
def test_mean(self, shape, result_shape):
mean = np.vectorize(np.mean, signature='(n)->()')
self.assertEqual(mean(np.zeros(shape)).shape, result_shape)
mean = jnp.vectorize(jnp.mean, signature='(n)->()')
self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -113,104 +111,104 @@ class VectorizeTest(jtu.JaxTestCase):
]))
def test_stack_plus_minus(self, shape, result_shape):
@partial(np.vectorize, signature='()->(n)')
@partial(jnp.vectorize, signature='()->(n)')
def stack_plus_minus(x):
return np.stack([x, -x])
return jnp.stack([x, -x])
self.assertEqual(stack_plus_minus(np.zeros(shape)).shape, result_shape)
self.assertEqual(stack_plus_minus(jnp.zeros(shape)).shape, result_shape)
def test_center(self):
@partial(np.vectorize, signature='(n)->(),(n)')
@partial(jnp.vectorize, signature='(n)->(),(n)')
def center(array):
bias = np.mean(array)
bias = jnp.mean(array)
debiased = array - bias
return bias, debiased
b, a = center(np.arange(3))
b, a = center(jnp.arange(3))
self.assertEqual(a.shape, (3,))
self.assertEqual(b.shape, ())
self.assertAllClose(1.0, b, check_dtypes=False)
b, a = center(np.arange(6).reshape(2, 3))
b, a = center(jnp.arange(6).reshape(2, 3))
self.assertEqual(a.shape, (2, 3))
self.assertEqual(b.shape, (2,))
self.assertAllClose(np.array([1.0, 4.0]), b, check_dtypes=False)
self.assertAllClose(jnp.array([1.0, 4.0]), b, check_dtypes=False)
def test_exclude_first(self):
@partial(np.vectorize, excluded={0})
@partial(jnp.vectorize, excluded={0})
def f(x, y):
assert x == 'foo'
assert y.ndim == 0
return y
x = np.arange(3)
x = jnp.arange(3)
self.assertAllClose(x, f('foo', x), check_dtypes=True)
self.assertAllClose(x, jax.jit(f, 0)('foo', x), check_dtypes=True)
def test_exclude_second(self):
@partial(np.vectorize, excluded={1})
@partial(jnp.vectorize, excluded={1})
def f(x, y):
assert x.ndim == 0
assert y == 'foo'
return x
x = np.arange(3)
x = jnp.arange(3)
self.assertAllClose(x, f(x, 'foo'), check_dtypes=True)
self.assertAllClose(x, jax.jit(f, 1)(x, 'foo'), check_dtypes=True)
def test_exclude_errors(self):
with self.assertRaisesRegex(
TypeError, "jax.numpy.vectorize can only exclude"):
np.vectorize(lambda x: x, excluded={'foo'})
jnp.vectorize(lambda x: x, excluded={'foo'})
with self.assertRaisesRegex(
ValueError, r"excluded=\{-1\} contains negative numbers"):
np.vectorize(lambda x: x, excluded={-1})
jnp.vectorize(lambda x: x, excluded={-1})
f = np.vectorize(lambda x: x, excluded={1})
f = jnp.vectorize(lambda x: x, excluded={1})
with self.assertRaisesRegex(
ValueError, r"excluded=\{1\} is invalid for 1 argument\(s\)"):
f(1.0)
def test_bad_inputs(self):
matmat = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,k)')
matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
with self.assertRaisesRegex(
TypeError, "wrong number of positional arguments"):
matmat(np.zeros((3, 2)))
matmat(jnp.zeros((3, 2)))
with self.assertRaisesRegex(
ValueError,
r"input with shape \(2,\) does not have enough dimensions"):
matmat(np.zeros((2,)), np.zeros((2, 2)))
matmat(jnp.zeros((2,)), jnp.zeros((2, 2)))
with self.assertRaisesRegex(
ValueError, r"inconsistent size for core dimension 'm'"):
matmat(np.zeros((2, 3)), np.zeros((4, 5)))
matmat(jnp.zeros((2, 3)), jnp.zeros((4, 5)))
def test_wrong_output_type(self):
f = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,k),()')
f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k),()')
with self.assertRaisesRegex(
TypeError, "output must be a tuple"):
f(np.zeros((2, 2)), np.zeros((2, 2)))
f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))
def test_wrong_num_outputs(self):
f = np.vectorize(lambda *args: args, signature='(),()->(),(),()')
f = jnp.vectorize(lambda *args: args, signature='(),()->(),(),()')
with self.assertRaisesRegex(
TypeError, "wrong number of output arguments"):
f(1, 2)
def test_wrong_output_shape(self):
f = np.vectorize(np.dot, signature='(n,m),(m,k)->(n)')
f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n)')
with self.assertRaisesRegex(
ValueError, r"output shape \(2, 2\) does not match"):
f(np.zeros((2, 2)), np.zeros((2, 2)))
f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))
def test_inconsistent_output_size(self):
f = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,n)')
f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,n)')
with self.assertRaisesRegex(
ValueError, r"inconsistent size for core dimension 'n'"):
f(np.zeros((2, 3)), np.zeros((3, 4)))
f(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
if __name__ == "__main__":

View File

@ -19,7 +19,7 @@ import itertools
import unittest
import sys
import numpy as onp
import numpy as np
import scipy as osp
from absl.testing import absltest
@ -30,7 +30,7 @@ import jax.lib
from jax import jit, grad, jvp, vmap
from jax import lax
from jax import lax_linalg
from jax import numpy as np
from jax import numpy as jnp
from jax import scipy as jsp
from jax import test_util as jtu
from jax.lib import xla_bridge
@ -40,16 +40,16 @@ from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
T = lambda x: onp.swapaxes(x, -1, -2)
T = lambda x: np.swapaxes(x, -1, -2)
float_types = [onp.float32, onp.float64]
complex_types = [onp.complex64, onp.complex128]
float_types = [np.float32, np.float64]
complex_types = [np.complex64, np.complex128]
def _skip_if_unsupported_type(dtype):
dtype = onp.dtype(dtype)
dtype = np.dtype(dtype)
if (not FLAGS.jax_enable_x64 and
dtype in (onp.dtype('float64'), onp.dtype('complex128'))):
dtype in (np.dtype('float64'), np.dtype('complex128'))):
raise unittest.SkipTest("--jax_enable_x64 is not set")
@ -69,25 +69,25 @@ class NumpyLinalgTest(jtu.JaxTestCase):
def args_maker():
factor_shape = shape[:-1] + (2 * shape[-1],)
a = rng(factor_shape, dtype)
return [onp.matmul(a, np.conj(T(a)))]
return [np.matmul(a, jnp.conj(T(a)))]
if (np.issubdtype(dtype, np.complexfloating) and
if (jnp.issubdtype(dtype, jnp.complexfloating) and
jtu.device_under_test() == "tpu"):
self.skipTest("Unimplemented case for complex Cholesky decomposition.")
self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker,
self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.cholesky, args_maker, check_dtypes=True)
if np.finfo(dtype).bits == 64:
jtu.check_grads(np.linalg.cholesky, args_maker(), order=2)
if jnp.finfo(dtype).bits == 64:
jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)
def testCholeskyGradPrecision(self):
rng = jtu.rand_default(self.rng())
a = rng((3, 3), onp.float32)
a = onp.dot(a, a.T)
a = rng((3, 3), np.float32)
a = np.dot(a, a.T)
jtu.assert_dot_precision(
lax.Precision.HIGHEST, partial(jvp, np.linalg.cholesky), (a,), (a,))
lax.Precision.HIGHEST, partial(jvp, jnp.linalg.cholesky), (a,), (a,))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -101,14 +101,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype)]
self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker,
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True,
rtol={onp.float64: 1e-13, onp.complex128: 1e-13})
self._CompileAndCheck(jnp.linalg.det, args_maker, check_dtypes=True,
rtol={np.float64: 1e-13, np.complex128: 1e-13})
def testDetOfSingularMatrix(self):
x = np.array([[-1., 3./2], [2./3, -1.]], dtype=onp.float32)
self.assertAllClose(onp.float32(0), jsp.linalg.det(x), check_dtypes=True)
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
self.assertAllClose(np.float32(0), jsp.linalg.det(x), check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -123,30 +123,30 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
a = rng(shape, dtype)
jtu.check_grads(np.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
jtu.check_grads(jnp.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
# make sure there are no NaNs when a matrix is zero
if len(shape) == 2:
pass
jtu.check_grads(
np.linalg.det, (np.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
jnp.linalg.det, (jnp.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
else:
a[0] = 0
jtu.check_grads(np.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
def testDetGradOfSingularMatrixCorank1(self):
# Rank 2 matrix with nonzero gradient
a = np.array([[ 50, -30, 45],
a = jnp.array([[ 50, -30, 45],
[-30, 90, -81],
[ 45, -81, 81]], dtype=np.float32)
jtu.check_grads(np.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
[ 45, -81, 81]], dtype=jnp.float32)
jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
@jtu.skip_on_devices("tpu") # TODO(mattjj,pfau): nan on tpu, investigate
def testDetGradOfSingularMatrixCorank2(self):
# Rank 1 matrix with zero gradient
b = np.array([[ 36, -42, 18],
b = jnp.array([[ 36, -42, 18],
[-42, 49, -21],
[ 18, -21, 9]], dtype=np.float32)
jtu.check_grads(np.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1)
[ 18, -21, 9]], dtype=jnp.float32)
jtu.check_grads(jnp.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -178,16 +178,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng(b_shape + Q, dtype), # = a
rng(b_shape, dtype)] # = b
a, b = args_maker()
result = np.linalg.tensorsolve(*args_maker())
result = jnp.linalg.tensorsolve(*args_maker())
self.assertEqual(result.shape, Q)
self._CheckAgainstNumpy(onp.linalg.tensorsolve,
np.linalg.tensorsolve, args_maker,
self._CheckAgainstNumpy(np.linalg.tensorsolve,
jnp.linalg.tensorsolve, args_maker,
check_dtypes=True,
tol={onp.float32: 1e-2, onp.float64: 1e-3})
self._CompileAndCheck(np.linalg.tensorsolve,
tol={np.float32: 1e-2, np.float64: 1e-3})
self._CompileAndCheck(jnp.linalg.tensorsolve,
args_maker, check_dtypes=True,
rtol={onp.float64: 1e-13})
rtol={np.float64: 1e-13})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -204,9 +204,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.slogdet, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -221,13 +221,13 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
a = rng(shape, dtype)
jtu.check_grads(np.linalg.slogdet, (a,), 2, atol=1e-1, rtol=1e-1)
jtu.check_grads(jnp.linalg.slogdet, (a,), 2, atol=1e-1, rtol=1e-1)
def testIssue1213(self):
for n in range(5):
mat = np.array([onp.diag(onp.ones([5], dtype=onp.float32))*(-.01)] * 2)
mat = jnp.array([np.diag(np.ones([5], dtype=np.float32))*(-.01)] * 2)
args_maker = lambda: [mat]
self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
check_dtypes=True, tol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
@ -249,14 +249,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# Norm, adjusted for dimension and type.
def norm(x):
norm = onp.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * np.finfo(dtype).eps)
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * jnp.finfo(dtype).eps)
a, = args_maker()
w, v = np.linalg.eig(a)
self.assertTrue(onp.all(norm(onp.matmul(a, v) - w[..., None, :] * v) < 100))
w, v = jnp.linalg.eig(a)
self.assertTrue(np.all(norm(np.matmul(a, v) - w[..., None, :] * v) < 100))
self._CompileAndCheck(partial(np.linalg.eig), args_maker,
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
check_dtypes=True, rtol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
@ -276,15 +276,15 @@ class NumpyLinalgTest(jtu.JaxTestCase):
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
w1, _ = np.linalg.eig(a)
w2 = np.linalg.eigvals(a)
w1, _ = jnp.linalg.eig(a)
w2 = jnp.linalg.eigvals(a)
self.assertAllClose(w1, w2, check_dtypes=True)
@jtu.skip_on_devices("gpu", "tpu")
def testEigvalsInf(self):
# https://github.com/google/jax/issues/2661
x = np.array([[np.inf]], np.float64)
self.assertTrue(np.all(np.isnan(np.linalg.eigvals(x))))
x = jnp.array([[jnp.inf]], jnp.float64)
self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -299,9 +299,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
shape = (10,) + shape
args = rng(shape, dtype)
ws, vs = vmap(np.linalg.eig)(args)
self.assertTrue(onp.all(onp.linalg.norm(
onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
ws, vs = vmap(jnp.linalg.eig)(args)
self.assertTrue(np.all(np.linalg.norm(
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_lower={}".format(
@ -316,7 +316,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
tol = 30
if jtu.device_under_test() == "tpu":
if np.issubdtype(dtype, onp.complexfloating):
if jnp.issubdtype(dtype, np.complexfloating):
raise unittest.SkipTest("No complex eigh on TPU")
# TODO(phawkins): this tolerance is unpleasantly high.
tol = 1500
@ -326,17 +326,17 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# Norm, adjusted for dimension and type.
def norm(x):
norm = onp.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * np.finfo(dtype).eps)
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * jnp.finfo(dtype).eps)
a, = args_maker()
a = (a + onp.conj(a.T)) / 2
w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
a = (a + np.conj(a.T)) / 2
w, v = jnp.linalg.eigh(np.tril(a) if lower else np.triu(a),
UPLO=uplo, symmetrize_input=False)
self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
self.assertTrue(norm(onp.matmul(a, v) - w * v) < tol)
self.assertTrue(norm(np.eye(n) - np.matmul(np.conj(T(v)), v)) < 5)
self.assertTrue(norm(np.matmul(a, v) - w * v) < tol)
self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
check_dtypes=True, rtol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
@ -350,14 +350,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
if jtu.device_under_test() == "tpu":
if np.issubdtype(dtype, np.complexfloating):
if jnp.issubdtype(dtype, jnp.complexfloating):
raise unittest.SkipTest("No complex eigh on TPU")
n = shape[-1]
def args_maker():
a = rng((n, n), dtype)
a = (a + onp.conj(a.T)) / 2
a = (a + np.conj(a.T)) / 2
return [a]
self._CheckAgainstNumpy(onp.linalg.eigvalsh, np.linalg.eigvalsh, args_maker,
self._CheckAgainstNumpy(np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker,
check_dtypes=True, tol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
@ -374,16 +374,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.skipTest("Test fails with numeric errors.")
uplo = "L" if lower else "U"
a = rng(shape, dtype)
a = (a + onp.conj(T(a))) / 2
ones = onp.ones((a.shape[-1], a.shape[-1]), dtype=dtype)
a *= onp.tril(ones) if lower else onp.triu(ones)
a = (a + np.conj(T(a))) / 2
ones = np.ones((a.shape[-1], a.shape[-1]), dtype=dtype)
a *= np.tril(ones) if lower else np.triu(ones)
# Gradient checks will fail without symmetrization as the eigh jvp rule
# is only correct for tangents in the symmetric subspace, whereas the
# checker checks against unconstrained (co)tangents.
if dtype not in complex_types:
f = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)
f = partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)
else: # only check eigenvalue grads for complex matrices
f = lambda a: partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
f = lambda a: partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
jtu.check_grads(f, (a,), 2, rtol=1e-1)
@parameterized.named_parameters(jtu.cases_from_list(
@ -409,32 +409,32 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# eigenvectors. You only ever want to optimize eigenvector directions, not coordinates!
uplo = "L" if lower else "U"
a = rng(shape, dtype)
a = (a + onp.conj(a.T)) / 2
a = onp.tril(a) if lower else onp.triu(a)
a = (a + np.conj(a.T)) / 2
a = np.tril(a) if lower else np.triu(a)
a_dot = eps * rng(shape, dtype)
a_dot = (a_dot + onp.conj(a_dot.T)) / 2
a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)
a_dot = (a_dot + np.conj(a_dot.T)) / 2
a_dot = np.tril(a_dot) if lower else np.triu(a_dot)
# evaluate eigenvector gradient and groundtruth eigensystem for perturbed input matrix
f = partial(np.linalg.eigh, UPLO=uplo)
f = partial(jnp.linalg.eigh, UPLO=uplo)
(w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
new_a = a + a_dot
new_w, new_v = f(new_a)
new_a = (new_a + onp.conj(new_a.T)) / 2
new_a = (new_a + np.conj(new_a.T)) / 2
# Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
RTOL=1e-2
assert onp.max(
onp.abs((onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
assert np.max(
np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
# Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
assert onp.max(
onp.linalg.norm(onp.abs(new_w*(v+dv) - onp.dot(new_a, (v+dv))), axis=0) /
onp.linalg.norm(onp.abs(new_w*(v+dv)), axis=0)
assert np.max(
np.linalg.norm(np.abs(new_w*(v+dv) - np.dot(new_a, (v+dv))), axis=0) /
np.linalg.norm(np.abs(new_w*(v+dv)), axis=0)
) < RTOL
def testEighGradPrecision(self):
rng = jtu.rand_default(self.rng())
a = rng((3, 3), onp.float32)
a = rng((3, 3), np.float32)
jtu.assert_dot_precision(
lax.Precision.HIGHEST, partial(jvp, np.linalg.eigh), (a,), (a,))
lax.Precision.HIGHEST, partial(jvp, jnp.linalg.eigh), (a,), (a,))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -447,14 +447,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
if (jtu.device_under_test() == "tpu" and
np.issubdtype(dtype, onp.complexfloating)):
jnp.issubdtype(dtype, np.complexfloating)):
raise unittest.SkipTest("No complex eigh on TPU")
shape = (10,) + shape
args = rng(shape, dtype)
args = (args + onp.conj(T(args))) / 2
args = (args + np.conj(T(args))) / 2
ws, vs = vmap(jsp.linalg.eigh)(args)
self.assertTrue(onp.all(onp.linalg.norm(
onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
self.assertTrue(np.all(np.linalg.norm(
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format(
@ -469,11 +469,11 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for keepdims in [False, True]
for ord in (
[None] if axis is None and len(shape) > 2
else [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf]
else [None, 0, 1, 2, 3, -1, -2, -3, jnp.inf, -jnp.inf]
if (axis is None and len(shape) == 1) or
isinstance(axis, int) or
(isinstance(axis, tuple) and len(axis) == 1)
else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
else [None, 'fro', 1, 2, -1, -2, jnp.inf, -jnp.inf, 'nuc'])
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default])) # type: ignore
def testNorm(self, shape, dtype, ord, axis, keepdims, rng_factory):
@ -485,9 +485,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
raise unittest.SkipTest("No adequate SVD implementation available")
args_maker = lambda: [rng(shape, dtype)]
onp_fn = partial(onp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
np_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
self._CheckAgainstNumpy(np_fn, np_fn, args_maker,
check_dtypes=False, tol=1e-3)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
@ -512,39 +512,39 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# Norm, adjusted for dimension and type.
def norm(x):
norm = onp.linalg.norm(x, axis=(-2, -1))
return norm / (max(m, n) * np.finfo(dtype).eps)
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / (max(m, n) * jnp.finfo(dtype).eps)
a, = args_maker()
out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
if compute_uv:
# Check the reconstructed matrices
if full_matrices:
k = min(m, n)
if m < n:
self.assertTrue(onp.all(
norm(a - onp.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :])) < 50))
self.assertTrue(np.all(
norm(a - np.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :])) < 50))
else:
self.assertTrue(onp.all(
norm(a - onp.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2])) < 350))
self.assertTrue(np.all(
norm(a - np.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2])) < 350))
else:
self.assertTrue(onp.all(
norm(a - onp.matmul(out[1][..., None, :] * out[0], out[2])) < 350))
self.assertTrue(np.all(
norm(a - np.matmul(out[1][..., None, :] * out[0], out[2])) < 350))
# Check the unitary properties of the singular vector matrices.
self.assertTrue(onp.all(norm(onp.eye(out[0].shape[-1]) - onp.matmul(onp.conj(T(out[0])), out[0])) < 15))
self.assertTrue(np.all(norm(np.eye(out[0].shape[-1]) - np.matmul(np.conj(T(out[0])), out[0])) < 15))
if m >= n:
self.assertTrue(onp.all(norm(onp.eye(out[2].shape[-1]) - onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
self.assertTrue(np.all(norm(np.eye(out[2].shape[-1]) - np.matmul(np.conj(T(out[2])), out[2])) < 10))
else:
self.assertTrue(onp.all(norm(onp.eye(out[2].shape[-2]) - onp.matmul(out[2], onp.conj(T(out[2])))) < 20))
self.assertTrue(np.all(norm(np.eye(out[2].shape[-2]) - np.matmul(out[2], np.conj(T(out[2])))) < 20))
else:
self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out), atol=1e-4, rtol=1e-4))
self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False), np.asarray(out), atol=1e-4, rtol=1e-4))
self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
args_maker, check_dtypes=True)
if not (compute_uv and full_matrices):
svd = partial(np.linalg.svd, full_matrices=full_matrices,
svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
compute_uv=compute_uv)
# TODO(phawkins): these tolerances seem very loose.
jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=5e-2, atol=2e-1)
@ -561,7 +561,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
def testQr(self, shape, dtype, full_matrices, rng_factory):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
if (np.issubdtype(dtype, onp.complexfloating) and
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex QR implementation")
m, n = shape[-2:]
@ -572,42 +572,42 @@ class NumpyLinalgTest(jtu.JaxTestCase):
mode, k = "reduced", min(m, n)
a = rng(shape, dtype)
lq, lr = np.linalg.qr(a, mode=mode)
lq, lr = jnp.linalg.qr(a, mode=mode)
# onp.linalg.qr doesn't support batch dimensions. But it seems like an
# np.linalg.qr doesn't support batch dimensions. But it seems like an
# inevitable extension so we support it in our version.
nq = onp.zeros(shape[:-2] + (m, k), dtype)
nr = onp.zeros(shape[:-2] + (k, n), dtype)
for index in onp.ndindex(*shape[:-2]):
nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)
nq = np.zeros(shape[:-2] + (m, k), dtype)
nr = np.zeros(shape[:-2] + (k, n), dtype)
for index in np.ndindex(*shape[:-2]):
nq[index], nr[index] = np.linalg.qr(a[index], mode=mode)
max_rank = max(m, n)
# Norm, adjusted for dimension and type.
def norm(x):
n = onp.linalg.norm(x, axis=(-2, -1))
return n / (max_rank * np.finfo(dtype).eps)
n = np.linalg.norm(x, axis=(-2, -1))
return n / (max_rank * jnp.finfo(dtype).eps)
def compare_orthogonal(q1, q2):
# Q is unique up to sign, so normalize the sign first.
sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
sum_of_ratios = np.sum(np.divide(q1, q2), axis=-2, keepdims=True)
phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
q1 *= phases
self.assertTrue(onp.all(norm(q1 - q2) < 30))
self.assertTrue(np.all(norm(q1 - q2) < 30))
# Check a ~= qr
self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))
self.assertTrue(np.all(norm(a - np.matmul(lq, lr)) < 30))
# Compare the first 'k' vectors of Q; the remainder form an arbitrary
# orthonormal basis for the null space.
compare_orthogonal(nq[..., :k], lq[..., :k])
# Check that q is close to unitary.
self.assertTrue(onp.all(
norm(onp.eye(k) -onp.matmul(onp.conj(T(lq)), lq)) < 5))
self.assertTrue(np.all(
norm(np.eye(k) -np.matmul(np.conj(T(lq)), lq)) < 5))
if not full_matrices and m >= n:
jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,), atol=3e-3)
jtu.check_jvp(jnp.linalg.qr, partial(jvp, jnp.linalg.qr), (a,), atol=3e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
@ -619,16 +619,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testQrBatching(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
args = rng(shape, np.float32)
args = rng(shape, jnp.float32)
qs, rs = vmap(jsp.linalg.qr)(args)
self.assertTrue(onp.all(onp.linalg.norm(args - onp.matmul(qs, rs)) < 1e-3))
self.assertTrue(np.all(np.linalg.norm(args - np.matmul(qs, rs)) < 1e-3))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_pnorm={}".format(jtu.format_shape_dtype_string(shape, dtype), pnorm),
"shape": shape, "pnorm": pnorm, "dtype": dtype}
for shape in [(1, 1), (4, 4), (2, 3, 5), (5, 5, 5), (20, 20), (5, 10)]
for pnorm in [np.inf, -np.inf, 1, -1, 2, -2, 'fro']
for pnorm in [jnp.inf, -jnp.inf, 1, -1, 2, -2, 'fro']
for dtype in float_types + complex_types))
@jtu.skip_on_devices("tpu") # SVD is not implemented on the TPU backend
@jtu.skip_on_devices("gpu") # TODO(#2203): numerical errors
@ -648,12 +648,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):
args_maker = args_gen(pnorm)
if pnorm not in [2, -2] and len(set(shape[-2:])) != 1:
with self.assertRaises(onp.linalg.LinAlgError):
np.linalg.cond(*args_maker())
with self.assertRaises(np.linalg.LinAlgError):
jnp.linalg.cond(*args_maker())
else:
self._CheckAgainstNumpy(onp.linalg.cond, np.linalg.cond, args_maker,
self._CheckAgainstNumpy(np.linalg.cond, jnp.linalg.cond, args_maker,
check_dtypes=False, tol=1e-3)
partial_norm = partial(np.linalg.cond, p=pnorm)
partial_norm = partial(jnp.linalg.cond, p=pnorm)
self._CompileAndCheck(partial_norm, lambda: [gen_mat()],
check_dtypes=False, rtol=1e-03, atol=1e-03)
@ -674,16 +674,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
while not invertible:
a = rng(shape, dtype)
try:
onp.linalg.inv(a)
np.linalg.inv(a)
invertible = True
except onp.linalg.LinAlgError:
except np.linalg.LinAlgError:
pass
return a
args_maker = lambda: [tensor_maker(), int(onp.floor(len(shape) / 2))]
self._CheckAgainstNumpy(onp.linalg.tensorinv, np.linalg.tensorinv, args_maker,
args_maker = lambda: [tensor_maker(), int(np.floor(len(shape) / 2))]
self._CheckAgainstNumpy(np.linalg.tensorinv, jnp.linalg.tensorinv, args_maker,
check_dtypes=False, tol=1e-3)
partial_inv = partial(np.linalg.tensorinv, ind=int(onp.floor(len(shape) / 2)))
partial_inv = partial(jnp.linalg.tensorinv, ind=int(np.floor(len(shape) / 2)))
self._CompileAndCheck(partial_inv, lambda: [tensor_maker()], check_dtypes=False, rtol=1e-03, atol=1e-03)
@parameterized.named_parameters(jtu.cases_from_list(
@ -707,9 +707,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker,
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.solve, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -730,15 +730,15 @@ class NumpyLinalgTest(jtu.JaxTestCase):
while not invertible:
a = rng(shape, dtype)
try:
onp.linalg.inv(a)
np.linalg.inv(a)
invertible = True
except onp.linalg.LinAlgError:
except np.linalg.LinAlgError:
pass
return [a]
self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker,
self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.inv, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -754,26 +754,26 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
check_dtypes=True, tol=1e-2)
self._CompileAndCheck(np.linalg.pinv, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.pinv, args_maker, check_dtypes=True)
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
jtu.check_grads(np.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
@jtu.skip_on_devices("tpu") # SVD is not implemented on the TPU backend
def testPinvGradIssue2792(self):
def f(p):
a = np.array([[0., 0.],[-p, 1.]], np.float32) * 1 / (1 + p**2)
return np.linalg.pinv(a)
j = jax.jacobian(f)(np.float32(2.))
self.assertAllClose(np.array([[0., -1.], [ 0., 0.]], np.float32), j,
a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2)
return jnp.linalg.pinv(a)
j = jax.jacobian(f)(jnp.float32(2.))
self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j,
check_dtypes=True)
expected = np.array([[[[-1., 0.], [ 0., 0.]], [[0., -1.], [0., 0.]]],
expected = jnp.array([[[[-1., 0.], [ 0., 0.]], [[0., -1.], [0., 0.]]],
[[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]],
dtype=np.float32)
dtype=jnp.float32)
self.assertAllClose(
expected, jax.jacobian(np.linalg.pinv)(np.eye(2, dtype=np.float32)),
expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)),
check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
@ -791,10 +791,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
self._CheckAgainstNumpy(partial(onp.linalg.matrix_power, n=n),
partial(np.linalg.matrix_power, n=n),
self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
partial(jnp.linalg.matrix_power, n=n),
args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(partial(np.linalg.matrix_power, n=n), args_maker,
self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker,
check_dtypes=True, rtol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
@ -811,9 +811,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
self._CheckAgainstNumpy(onp.linalg.matrix_rank, np.linalg.matrix_rank,
self._CheckAgainstNumpy(np.linalg.matrix_rank, jnp.linalg.matrix_rank,
args_maker, check_dtypes=False, tol=1e-3)
self._CompileAndCheck(np.linalg.matrix_rank, args_maker,
self._CompileAndCheck(jnp.linalg.matrix_rank, args_maker,
check_dtypes=False, rtol=1e-3)
@parameterized.named_parameters(jtu.cases_from_list(
@ -832,12 +832,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [[rng(shape, dtype) for shape in shapes]]
onp_fun = onp.linalg.multi_dot
jnp_fun = partial(np.linalg.multi_dot, precision=lax.Precision.HIGHEST)
tol = {onp.float32: 1e-4, onp.float64: 1e-10,
onp.complex64: 1e-4, onp.complex128: 1e-10}
np_fun = np.linalg.multi_dot
jnp_fun = partial(jnp.linalg.multi_dot, precision=lax.Precision.HIGHEST)
tol = {np.float32: 1e-4, np.float64: 1e-10,
np.complex64: 1e-4, np.complex128: 1e-10}
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True,
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
atol=tol, rtol=tol)
@ -846,39 +846,39 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU.
def testIssue669(self):
def test(x):
val, vec = np.linalg.eigh(x)
return np.real(np.sum(val))
val, vec = jnp.linalg.eigh(x)
return jnp.real(jnp.sum(val))
grad_test_jc = jit(grad(jit(test)))
xc = onp.eye(3, dtype=onp.complex)
xc = np.eye(3, dtype=np.complex)
self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testIssue1151(self):
A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32)
b = np.array(onp.random.randn(100, 3), dtype=np.float32)
x = np.linalg.solve(A, b)
self.assertAllClose(vmap(np.dot)(A, x), b, atol=1e-3, rtol=1e-2,
A = jnp.array(np.random.randn(100, 3, 3), dtype=jnp.float32)
b = jnp.array(np.random.randn(100, 3), dtype=jnp.float32)
x = jnp.linalg.solve(A, b)
self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=1e-3, rtol=1e-2,
check_dtypes=True)
jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A, b)
jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A, b)
jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A[0], b[0])
jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A[0], b[0])
jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A, b)
jac1 = jax.jacobian(jnp.linalg.solve, argnums=1)(A, b)
jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A[0], b[0])
jac1 = jax.jacobian(jnp.linalg.solve, argnums=1)(A[0], b[0])
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testIssue1383(self):
seed = jax.random.PRNGKey(0)
tmp = jax.random.uniform(seed, (2,2))
a = np.dot(tmp, tmp.T)
a = jnp.dot(tmp, tmp.T)
def f(inp):
val, vec = np.linalg.eigh(inp)
return np.dot(np.dot(vec, inp), vec.T)
val, vec = jnp.linalg.eigh(inp)
return jnp.dot(jnp.dot(vec, inp), vec.T)
grad_func = jax.jacfwd(f)
hess_func = jax.jacfwd(grad_func)
cube_func = jax.jacfwd(hess_func)
self.assertFalse(onp.any(onp.isnan(cube_func(a))))
self.assertFalse(np.any(np.isnan(cube_func(a))))
class ScipyLinalgTest(jtu.JaxTestCase):
@ -890,8 +890,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
(1,),
(7, -2),
(3, 4, 5),
(onp.ones((3, 4), dtype=np.float_), 5,
onp.random.randn(5, 2).astype(np.float_)),
(np.ones((3, 4), dtype=jnp.float_), 5,
np.random.randn(5, 2).astype(jnp.float_)),
])))
def testBlockDiag(self, args):
args_maker = lambda: args
@ -914,15 +914,15 @@ class ScipyLinalgTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
x, = args_maker()
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True,
rtol={onp.float32: 1e-3, onp.float64: 1e-12,
onp.complex64: 1e-3, onp.complex128: 1e-12})
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True,
rtol={np.float32: 1e-3, np.float64: 1e-12,
np.complex64: 1e-3, np.complex128: 1e-12})
self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)
def testLuOfSingularMatrix(self):
x = np.array([[-1., 3./2], [2./3, -1.]], dtype=onp.float32)
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True)
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -945,18 +945,18 @@ class ScipyLinalgTest(jtu.JaxTestCase):
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype, "rng_factory": rng_factory}
for shape in [(4, 5), (6, 5)]
for dtype in [np.float32]
for dtype in [jnp.float32]
for rng_factory in [jtu.rand_default]))
def testLuBatching(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
args = [rng(shape, np.float32) for _ in range(10)]
args = [rng(shape, jnp.float32) for _ in range(10)]
expected = list(osp.linalg.lu(x) for x in args)
ps = onp.stack([out[0] for out in expected])
ls = onp.stack([out[1] for out in expected])
us = onp.stack([out[2] for out in expected])
ps = np.stack([out[0] for out in expected])
ls = np.stack([out[1] for out in expected])
us = np.stack([out[2] for out in expected])
actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
self.assertAllClose(ps, actual_ps, check_dtypes=True)
self.assertAllClose(ls, actual_ls, check_dtypes=True)
self.assertAllClose(us, actual_us, check_dtypes=True)
@ -976,11 +976,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
x, = args_maker()
lu, piv = jsp.linalg.lu_factor(x)
l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
u = onp.triu(lu)
l = np.tril(lu, -1) + np.eye(n, dtype=dtype)
u = np.triu(lu)
for i in range(n):
x[[i, piv[i]],] = x[[piv[i], i],]
self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3,
self.assertAllClose(x, np.matmul(l, u), check_dtypes=True, rtol=1e-3,
atol=1e-3)
self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)
@ -1039,7 +1039,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng_factory):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
if (sym_pos and np.issubdtype(dtype, onp.complexfloating) and
if (sym_pos and jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest(
"Complex Cholesky decomposition not implemented on TPU")
@ -1049,8 +1049,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def args_maker():
a = rng(lhs_shape, dtype)
if sym_pos:
a = onp.matmul(a, onp.conj(T(a)))
a = onp.tril(a) if lower else onp.triu(a)
a = np.matmul(a, np.conj(T(a)))
a = np.tril(a) if lower else np.triu(a)
return [a, rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
@ -1081,22 +1081,22 @@ class ScipyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
rng = rng_factory(self.rng())
k = rng(lhs_shape, dtype)
l = onp.linalg.cholesky(onp.matmul(k, T(k))
+ lhs_shape[-1] * onp.eye(lhs_shape[-1]))
l = np.linalg.cholesky(np.matmul(k, T(k))
+ lhs_shape[-1] * np.eye(lhs_shape[-1]))
l = l.astype(k.dtype)
b = rng(rhs_shape, dtype)
if unit_diagonal:
a = onp.tril(l, -1) + onp.eye(lhs_shape[-1], dtype=dtype)
a = np.tril(l, -1) + np.eye(lhs_shape[-1], dtype=dtype)
else:
a = l
a = a if lower else T(a)
inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
inv = np.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
if len(lhs_shape) == len(rhs_shape):
onp_ans = onp.matmul(inv, b)
np_ans = np.matmul(inv, b)
else:
onp_ans = onp.einsum("...ij,...j->...i", inv, b)
np_ans = np.einsum("...ij,...j->...i", inv, b)
# The standard scipy.linalg.solve_triangular doesn't support broadcasting.
# But it seems like an inevitable extension so we support it.
@ -1104,8 +1104,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower,
unit_diagonal=unit_diagonal)
self.assertAllClose(onp_ans, ans, check_dtypes=True,
rtol={onp.float32: 1e-4, onp.float64: 1e-11})
self.assertAllClose(np_ans, ans, check_dtypes=True,
rtol={np.float32: 1e-4, np.float64: 1e-11})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -1122,7 +1122,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types
for transpose_a in [False, True]
for conjugate_a in (
[False] if np.issubdtype(dtype, np.floating) else [False, True])
[False] if jnp.issubdtype(dtype, jnp.floating) else [False, True])
for left_side, a_shape, b_shape in [
(False, (4, 4), (4,)),
(False, (4, 4), (1, 4,)),
@ -1141,7 +1141,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
# Test lax_linalg.triangular_solve instead of scipy.linalg.solve_triangular
# because it exposes more options.
A = np.tril(rng(a_shape, dtype) + 5 * onp.eye(a_shape[-1], dtype=dtype))
A = jnp.tril(rng(a_shape, dtype) + 5 * np.eye(a_shape[-1], dtype=dtype))
A = A if lower else T(A)
B = rng(b_shape, dtype)
f = partial(lax_linalg.triangular_solve, lower=lower,
@ -1165,21 +1165,21 @@ class ScipyLinalgTest(jtu.JaxTestCase):
]))
def testTriangularSolveBatching(self, left_side, a_shape, b_shape, bdims):
rng = jtu.rand_default(self.rng())
A = np.tril(rng(a_shape, onp.float32)
+ 5 * onp.eye(a_shape[-1], dtype=onp.float32))
B = rng(b_shape, onp.float32)
A = jnp.tril(rng(a_shape, np.float32)
+ 5 * np.eye(a_shape[-1], dtype=np.float32))
B = rng(b_shape, np.float32)
solve = partial(lax_linalg.triangular_solve, lower=True,
transpose_a=False, conjugate_a=False,
unit_diagonal=False, left_side=left_side)
X = vmap(solve, bdims)(A, B)
matmul = partial(np.matmul, precision=lax.Precision.HIGHEST)
matmul = partial(jnp.matmul, precision=lax.Precision.HIGHEST)
Y = matmul(A, X) if left_side else matmul(X, A)
onp.testing.assert_allclose(Y - B, 0, atol=1e-4)
np.testing.assert_allclose(Y - B, 0, atol=1e-4)
def testTriangularSolveGradPrecision(self):
rng = jtu.rand_default(self.rng())
a = np.tril(rng((3, 3), onp.float32))
b = rng((1, 3), onp.float32)
a = jnp.tril(rng((3, 3), np.float32))
b = rng((1, 3), np.float32)
jtu.assert_dot_precision(
lax.Precision.HIGHEST,
partial(jvp, lax_linalg.triangular_solve),
@ -1205,7 +1205,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
check_dtypes=True)
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
args_maker_triu = lambda: [onp.triu(rng((n, n), dtype))]
args_maker_triu = lambda: [np.triu(rng((n, n), dtype))]
jsp_fun_triu = lambda a: jsp.linalg.expm(a,upper_triangular=True)
self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu,
check_dtypes=True)
@ -1220,7 +1220,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
))
@jtu.skip_on_mac_linalg_bug()
def testIssue2131(self, n, dtype):
args_maker_zeros = lambda: [onp.zeros((n, n), dtype)]
args_maker_zeros = lambda: [np.zeros((n, n), dtype)]
osp_fun = lambda a: osp.linalg.expm(a)
jsp_fun = lambda a: jsp.linalg.expm(a)
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros,
@ -1248,10 +1248,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def args_maker():
b = rng(rhs_shape, dtype)
if lower:
L = onp.tril(rng(lhs_shape, dtype))
L = np.tril(rng(lhs_shape, dtype))
return [(L, lower), b]
else:
U = onp.triu(rng(lhs_shape, dtype))
U = np.triu(rng(lhs_shape, dtype))
return [(U, lower), b]
self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve,
args_maker, check_dtypes=True, tol=1e-3)

View File

@ -16,12 +16,12 @@
from absl.testing import absltest
import numpy as onp
import numpy as np
import re
from jax import api, lax, ops
from jax import core
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu
from jax.experimental import loops
@ -61,9 +61,9 @@ class LoopsTest(jtu.JaxTestCase):
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
inc_batch = onp.arange(5, dtype=np.float_)
self.assertAllClose(np.array([f_expected(inc) for inc in inc_batch],
dtype=np.float_),
inc_batch = np.arange(5, dtype=jnp.float_)
self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch],
dtype=jnp.float_),
api.vmap(f_op)(inc_batch), check_dtypes=True)
@ -86,14 +86,14 @@ class LoopsTest(jtu.JaxTestCase):
with loops.Scope() as s:
n = x.shape[0]
assert n == y.shape[0]
s.out = np.zeros(shape=[n], dtype=np.float32)
s.out = jnp.zeros(shape=[n], dtype=jnp.float32)
for i in s.range(n):
s.out = ops.index_add(s.out, i, x[i] + y[i])
return s.out
x = np.array([1., 2., 3.], dtype=np.float32)
y = np.array([4., 5., 6.], dtype=np.float32)
self.assertAllClose(np.add(x, y), add_vec(x, y), check_dtypes=True)
x = jnp.array([1., 2., 3.], dtype=jnp.float32)
y = jnp.array([4., 5., 6.], dtype=jnp.float32)
self.assertAllClose(jnp.add(x, y), add_vec(x, y), check_dtypes=True)
def test_matmul(self):
def matmul(x, y):
@ -101,16 +101,16 @@ class LoopsTest(jtu.JaxTestCase):
n, m = x.shape
m1, p = y.shape
assert m == m1
s.out = np.zeros(shape=[n, p], dtype=np.float32)
s.out = jnp.zeros(shape=[n, p], dtype=jnp.float32)
for i in s.range(n):
for j in s.range(p):
for k in s.range(m):
s.out = ops.index_add(s.out, (i, j), x[i, k] * y[k, j])
return s.out
x = np.array([[1., 2., 3.]], dtype=np.float32) # 1x3
y = np.array([[4.], [5.], [6.]], dtype=np.float32) # 3x1
self.assertAllClose(np.matmul(x, y), matmul(x, y), check_dtypes=True)
x = jnp.array([[1., 2., 3.]], dtype=jnp.float32) # 1x3
y = jnp.array([[4.], [5.], [6.]], dtype=jnp.float32) # 3x1
self.assertAllClose(jnp.matmul(x, y), matmul(x, y), check_dtypes=True)
def test_reuse_range(self):
"""Ranges can be reused, as long as not nested in each other."""
@ -142,7 +142,7 @@ class LoopsTest(jtu.JaxTestCase):
def test_example_doc(self):
"The example from the module docstring."
def f_expected():
arr = onp.zeros(5, dtype=np.float_)
arr = np.zeros(5, dtype=jnp.float_)
for i in range(arr.shape[0]):
arr[i] += 2.
if i % 2 == 0:
@ -150,7 +150,7 @@ class LoopsTest(jtu.JaxTestCase):
return arr
def f_op_jax():
arr = np.zeros(5)
arr = jnp.zeros(5)
def loop_body(i, acc_arr):
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
return lax.cond(i % 2 == 0,
@ -163,7 +163,7 @@ class LoopsTest(jtu.JaxTestCase):
def f_op_loops():
with loops.Scope() as s:
s.arr = np.zeros(5) # Must create the mutable state of the loop as `scope` fields.
s.arr = jnp.zeros(5) # Must create the mutable state of the loop as `scope` fields.
for i in s.range(s.arr.shape[0]):
s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
for _ in s.cond_range(i % 2 == 0): # Conditionals are also sugared as loops with 0 or 1 iterations
@ -288,7 +288,7 @@ class LoopsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True)
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
self.assertAllClose(16., api.vmap(f_op)(np.zeros(10), np.ones(10), np.array([4.] * 10)), check_dtypes=True)
self.assertAllClose(16., api.vmap(f_op)(jnp.zeros(10), jnp.ones(10), jnp.array([4.] * 10)), check_dtypes=True)
def test_cond(self):
def f_op(inc):
@ -377,9 +377,9 @@ class LoopsTest(jtu.JaxTestCase):
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
init_batch = onp.array([1., 2., 3.], dtype=onp.float32)
self.assertAllClose(onp.array([f_expected(init) for init in init_batch],
dtype=onp.float32),
init_batch = np.array([1., 2., 3.], dtype=np.float32)
self.assertAllClose(np.array([f_expected(init) for init in init_batch],
dtype=np.float32),
api.vmap(f_op)(init_batch), check_dtypes=True)
def test_error_while_cond_mutation(self):

View File

@ -17,11 +17,11 @@ from functools import partial
import itertools as it
from unittest import SkipTest
import numpy as onp
import numpy as np
from absl.testing import absltest, parameterized
from jax.interpreters.masking import shape_as_value, ShapeError, \
parse_spec, Poly, Mon
from jax import numpy as np, test_util as jtu, mask, vmap, jit, grad, lax, \
from jax import numpy as jnp, test_util as jtu, mask, vmap, jit, grad, lax, \
shapecheck, api
from jax.config import config
from jax.numpy.lax_numpy import _polymorphic_slice_indices
@ -60,9 +60,9 @@ class ShapesTest(jtu.JaxTestCase):
def test_Poly_equal(self):
assert constant_poly(3) == 3
assert onp.array(3, onp.int64) == constant_poly(3)
assert onp.array(3, onp.int64)[()] == constant_poly(3)
assert not onp.array(3, onp.int64) != constant_poly(3)
assert np.array(3, np.int64) == constant_poly(3)
assert np.array(3, np.int64)[()] == constant_poly(3)
assert not np.array(3, np.int64) != constant_poly(3)
assert constant_poly(4) != 3
assert 3 == constant_poly(3)
assert 4 != constant_poly(3)
@ -109,27 +109,27 @@ class ShapesTest(jtu.JaxTestCase):
def test_sum(self):
@shapecheck(['(m, n)'], '')
def sum(x):
return np.sum(x)
return jnp.sum(x)
def test_prod(self):
@shapecheck(['(m, n)'], '')
def prod(x):
return np.prod(x)
return jnp.prod(x)
def test_max(self):
@shapecheck(['(m, n)'], '')
def prod(x):
return np.max(x)
return jnp.max(x)
def test_min(self):
@shapecheck(['(m, n)'], '')
def prod(x):
return np.min(x)
return jnp.min(x)
def test_dot(self):
@shapecheck(['(m, n)', 'n'], 'm')
def matvec(A, b):
return np.dot(A, b)
return jnp.dot(A, b)
def thunk():
@shapecheck(['(m, n)', 'n'], 'm')
@ -159,12 +159,12 @@ class ShapesTest(jtu.JaxTestCase):
return api.device_put(x)
def test_broadcast_in_dim(self):
x = np.zeros(7)
x = jnp.zeros(7)
@shapecheck(['(n,)'], '(3, n, 4)')
def broadcast_in_dim(x):
return lax.broadcast_in_dim(x, shape=(3, x.shape[0], 4), broadcast_dimensions=(1,))
x = np.zeros((7, 1))
x = jnp.zeros((7, 1))
@shapecheck(['(n, 1)'], '(3, n, 4, 1)')
def broadcast_in_dim(x):
@ -181,17 +181,17 @@ class ShapesTest(jtu.JaxTestCase):
# @jit
# @grad
# def sum_square(x):
# return np.sum(x ** 2)
# return jnp.sum(x ** 2)
def test_pad(self):
@shapecheck(['n'], '2*n+1')
def p(x):
return lax.pad(x, np.array(0., x.dtype), [(1, 1, 1)])
return lax.pad(x, jnp.array(0., x.dtype), [(1, 1, 1)])
def test_numpy_pad(self):
@shapecheck(['n'], 'n+1')
def p(x):
return np.pad(x, (0, 1))
return jnp.pad(x, (0, 1))
@parameterized.named_parameters(jtu.cases_from_list(
{
@ -217,9 +217,9 @@ class ShapesTest(jtu.JaxTestCase):
dimension_numbers, lhs_perm, rhs_perm, out_perm):
valid = padding == 'VALID'
is_strided = strides[0] != 1
lhs_shape = '({}, {}, {}, {})'.format(*onp.take(['n', 'i', '2*h' if is_strided else 'h', 'w'], lhs_perm))
rhs_shape = '({}, {}, {}, {})'.format(*onp.take(['o', 'i', '2', '3'], rhs_perm))
out_shape = '({}, {}, {}, {})'.format(*onp.take([
lhs_shape = '({}, {}, {}, {})'.format(*np.take(['n', 'i', '2*h' if is_strided else 'h', 'w'], lhs_perm))
rhs_shape = '({}, {}, {}, {})'.format(*np.take(['o', 'i', '2', '3'], rhs_perm))
out_shape = '({}, {}, {}, {})'.format(*np.take([
'n', 'o', 'h+-1' if valid and not is_strided else 'h',
('w+-2' if valid else 'w') if lhs_dilation is None else '2*w+-1'], out_perm))
@ -275,14 +275,14 @@ class ShapesTest(jtu.JaxTestCase):
# https://travis-ci.org/github/google/jax/jobs/682086351
@shapecheck(['n'], 'n')
def range_like(x):
return lax.iota(np.int32, x.shape[0])
return lax.iota(jnp.int32, x.shape[0])
def test_arange(self):
raise SkipTest("not yet implemented")
# https://travis-ci.org/github/google/jax/jobs/682086351
@shapecheck(['n'], 'n')
def arange_like(x):
return np.arange(x.shape[0], dtype=np.int32)
return jnp.arange(x.shape[0], dtype=jnp.int32)
def test_expit(self):
@shapecheck(['n'], 'n')
@ -292,10 +292,10 @@ class ShapesTest(jtu.JaxTestCase):
def test_reshape(self):
@shapecheck(['n, a, b'], 'n, a*b')
def flatten(x):
return np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2]))
return jnp.reshape(x, (x.shape[0], x.shape[1] * x.shape[2]))
def test_ravel(self):
a = np.array(1)
a = jnp.array(1)
@shapecheck(['n'], '')
def thunk(n):
@ -306,23 +306,23 @@ class MaskingTest(jtu.JaxTestCase):
def test_sum(self):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return np.sum(x)
return jnp.sum(x)
ans = padded_sum([np.array([3, 1, 4, 1, 5])], dict(n=3))
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
expected = 8
self.assertAllClose(ans, expected, check_dtypes=False)
ans = padded_sum([np.array([3, 1, 4, 1, 5])], dict(n=4))
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=4))
expected = 9
self.assertAllClose(ans, expected, check_dtypes=False)
def test_sum_vmap(self):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return np.sum(x)
return jnp.sum(x)
ans = vmap(padded_sum)([np.ones((5, 10))], dict(n=np.arange(5)))
expected = onp.array([0, 1, 2, 3, 4])
ans = vmap(padded_sum)([jnp.ones((5, 10))], dict(n=jnp.arange(5)))
expected = np.array([0, 1, 2, 3, 4])
self.assertAllClose(ans, expected, check_dtypes=False)
def test_add(self):
@ -330,13 +330,13 @@ class MaskingTest(jtu.JaxTestCase):
def addvecs(x, y):
return x + y
x = np.array([3, 1, 4, 1, 5, 9])
y = np.array([2, 6, 5, 3, 5, 8])
x = jnp.array([3, 1, 4, 1, 5, 9])
y = jnp.array([2, 6, 5, 3, 5, 8])
ans = addvecs([x, y], dict(n=3))
expected = onp.array([5, 7, 9])
expected = np.array([5, 7, 9])
self.assertAllClose(ans[:3], expected, check_dtypes=False)
thunk = lambda: addvecs([np.arange(5), np.arange(6)], dict(n=3))
thunk = lambda: addvecs([jnp.arange(5), jnp.arange(6)], dict(n=3))
self.assertRaisesRegex(ShapeError, "", thunk)
def test_scan(self):
@ -345,7 +345,7 @@ class MaskingTest(jtu.JaxTestCase):
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
return out
ans = cumsum([np.array([5, 2, 9, 1, 4])], dict(n=3))
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
expected = 16
self.assertAllClose(ans, expected, check_dtypes=False)
@ -355,8 +355,8 @@ class MaskingTest(jtu.JaxTestCase):
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
return out
ans = vmap(cumsum)([np.arange(6).reshape(2, 3)], dict(n=np.array([1, 2])))
expected = onp.array([0, 7])
ans = vmap(cumsum)([jnp.arange(6).reshape(2, 3)], dict(n=jnp.array([1, 2])))
expected = np.array([0, 7])
self.assertAllClose(ans, expected, check_dtypes=False)
def test_scan_jit(self):
@ -371,17 +371,17 @@ class MaskingTest(jtu.JaxTestCase):
return cumsum(args, shape_env)
python_should_be_executing = True
ans = jit_cumsum([np.array([5, 2, 9, 1, 4])], dict(n=3))
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
expected = 16
self.assertAllClose(ans, expected, check_dtypes=False)
python_should_be_executing = False
ans = jit_cumsum([np.array([5, 2, 9, 1, 4])], dict(n=4))
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=4))
expected = 17
self.assertAllClose(ans, expected, check_dtypes=False)
python_should_be_executing = False
ans = jit_cumsum([np.array([5, 2, 9, 1, 4])], dict(n=1))
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=1))
expected = 5
self.assertAllClose(ans, expected, check_dtypes=False)
@ -390,9 +390,9 @@ class MaskingTest(jtu.JaxTestCase):
def cat(x, y, z):
return lax.concatenate([x, y, z], 0)
ans = cat([np.array([1, 9]), np.array([2, 4, 9]), np.array([3, 9])],
ans = cat([jnp.array([1, 9]), jnp.array([2, 4, 9]), jnp.array([3, 9])],
dict(n=1, m=2))
expected = onp.array([1, 2, 4, 3])
expected = np.array([1, 2, 4, 3])
self.assertAllClose(ans[:4], expected, check_dtypes=False)
def test_dot(self):
@ -400,46 +400,46 @@ class MaskingTest(jtu.JaxTestCase):
def dot(x, y):
return lax.dot(x, y)
x = onp.arange(6, dtype=onp.float32).reshape((2, 3))
y = onp.arange(12, dtype=onp.float32).reshape((3, 4))
x = np.arange(6, dtype=np.float32).reshape((2, 3))
y = np.arange(12, dtype=np.float32).reshape((3, 4))
ans = dot([x, y], dict(m=2, k=2, n=2))
expected = onp.dot(x[:2, :2], y[:2, :2])
expected = np.dot(x[:2, :2], y[:2, :2])
self.assertAllClose(ans[:2, :2], expected, check_dtypes=False)
def test_mean(self):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return np.sum(x) / shape_as_value(x.shape)[0]
return jnp.sum(x) / shape_as_value(x.shape)[0]
ans = padded_sum([np.array([3, 1, 4, 1, 5])], dict(n=3))
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
expected = 8 / 3
self.assertAllClose(ans, expected, check_dtypes=False)
def test_monomorphic(self):
@partial(mask, in_shapes=['(_, n)'], out_shape='')
def padded_sum(x):
return np.sum(x)
return jnp.sum(x)
ans = padded_sum([np.array([[3, 4], [5, 6]])], dict(n=1))
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
expected = 8
self.assertAllClose(ans, expected, check_dtypes=False)
def test_monomorphic2(self):
@partial(mask, in_shapes=['(_, n)'], out_shape='n')
def padded_sum(x):
return np.sum(x, axis=0)
return jnp.sum(x, axis=0)
ans = padded_sum([np.array([[3, 4], [5, 6]])], dict(n=2))
expected = np.array([8, 10])
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=2))
expected = jnp.array([8, 10])
self.assertAllClose(ans, expected, check_dtypes=False)
def test_monomorphic3(self):
@partial(mask, in_shapes=['(_, n)'], out_shape='_')
def padded_sum(x):
return np.sum(x, axis=1)
return jnp.sum(x, axis=1)
ans = padded_sum([np.array([[3, 4], [5, 6]])], dict(n=1))
expected = np.array([3, 5])
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
expected = jnp.array([3, 5])
self.assertAllClose(ans, expected, check_dtypes=False)
def test_rnn(self):
@ -448,14 +448,14 @@ class MaskingTest(jtu.JaxTestCase):
@partial(mask, in_shapes=['(_, _)', '(t, _)'], out_shape='_')
def rnn(W, xs):
def step(h, x):
new_h = np.dot(W, h) + np.dot(W, x)
new_h = jnp.dot(W, h) + jnp.dot(W, x)
return new_h, ()
predicted, _ = lax.scan(step, np.zeros(n), xs)
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
return predicted
rng = onp.random.RandomState(0)
W = np.eye(n)
xs = rng.randn(10, n).astype(np.float_)
rng = np.random.RandomState(0)
W = jnp.eye(n)
xs = rng.randn(10, n).astype(jnp.float_)
ans = rnn([W, xs], dict(t=4))
expected = xs[:4].sum(0)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -466,24 +466,24 @@ class MaskingTest(jtu.JaxTestCase):
@partial(mask, in_shapes=['(_, _)', '(t, _)', '_'], out_shape='')
def rnn(W, xs, target):
def step(h, x):
new_h = np.tanh(np.dot(W, h) + np.dot(W, x))
new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
return new_h, ()
predicted, _ = lax.scan(step, np.zeros(n), xs)
return np.sum((predicted - target)**2)
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
return jnp.sum((predicted - target)**2)
rng = onp.random.RandomState(0)
W = rng.randn(n, n).astype(np.float_)
xs = rng.randn(10, n).astype(np.float_)
y = rng.randn(n).astype(np.float_)
rng = np.random.RandomState(0)
W = rng.randn(n, n).astype(jnp.float_)
xs = rng.randn(10, n).astype(jnp.float_)
y = rng.randn(n).astype(jnp.float_)
ans = grad(lambda W: rnn([W, xs, y], dict(t=4)))(W)
def rnn_reference(W, xs, target):
h = np.zeros(n)
h = jnp.zeros(n)
for x in xs:
h = np.tanh(np.dot(W, h) + np.dot(W, x))
h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
predicted = h
return np.sum((predicted - target)**2)
return jnp.sum((predicted - target)**2)
expected = grad(lambda W: rnn_reference(W, xs[:4], y))(W)
@ -495,27 +495,27 @@ class MaskingTest(jtu.JaxTestCase):
@partial(mask, in_shapes=('(_, _)', '(t, _)', '_'), out_shape='')
def rnn(W, xs, target):
def step(h, x):
new_h = np.tanh(np.dot(W, h) + np.dot(W, x))
new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
return new_h, ()
predicted, _ = lax.scan(step, np.zeros(n), xs)
return np.sum((predicted - target)**2)
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
return jnp.sum((predicted - target)**2)
rng = onp.random.RandomState(0)
W = rng.randn(n, n).astype(np.float_)
seqs = rng.randn(3, 10, n).astype(np.float_)
ts = np.array([2, 5, 4])
rng = np.random.RandomState(0)
W = rng.randn(n, n).astype(jnp.float_)
seqs = rng.randn(3, 10, n).astype(jnp.float_)
ts = jnp.array([2, 5, 4])
ys = rng.randn(3, n)
ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W)
def rnn_reference(W, seqs, targets):
total_loss = np.array(0, np.float_)
total_loss = jnp.array(0, jnp.float_)
for xs, target in zip(seqs, targets):
h = np.zeros(n)
h = jnp.zeros(n)
for x in xs:
h = np.tanh(np.dot(W, h) + np.dot(W, x))
h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
predicted = h
total_loss = total_loss + np.sum((predicted - target)**2)
total_loss = total_loss + jnp.sum((predicted - target)**2)
return total_loss
seqs_ = [xs[:t] for xs, t in zip(seqs, ts)]
@ -530,7 +530,7 @@ class MaskingTest(jtu.JaxTestCase):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return np.sum(x)
return jnp.sum(x)
batched_sum = vmap(padded_sum)
@ -538,10 +538,10 @@ class MaskingTest(jtu.JaxTestCase):
def fun(x, ns):
return batched_sum([x], dict(n=ns)).sum()
x = np.array([[3, 1, 4, 1],
x = jnp.array([[3, 1, 4, 1],
[5, 9, 2, 6],
[5, 3, 5, 8]])
ns = np.array([2, 3, 2])
ns = jnp.array([2, 3, 2])
ans = fun([x, ns], dict(m=2))
expected = 3+1 + 5+9+2
self.assertAllClose(ans, expected, check_dtypes=False)
@ -553,8 +553,8 @@ class MaskingTest(jtu.JaxTestCase):
def padded_add(x):
return x + lax.iota(x.shape[0])
ans = padded_add([np.array([3, 1, 4, 1, 5])], dict(n=3))
expected = onp.array([3, 2, 6])
ans = padded_add([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
expected = np.array([3, 2, 6])
self.assertAllClose(ans[:3], expected, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
@ -577,8 +577,8 @@ class MaskingTest(jtu.JaxTestCase):
def test_slice_oob_indexing(self):
# https://github.com/google/jax/issues/2245
self.assertAllClose(np.ones(5), np.ones(5)[:10], check_dtypes=True)
self.assertAllClose(np.ones(5), np.ones(5)[-10:], check_dtypes=True)
self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10], check_dtypes=True)
self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:], check_dtypes=True)
if __name__ == '__main__':
absltest.main()

View File

@ -18,13 +18,13 @@ from functools import partial
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
import numpy as np
import numpy.random as npr
from unittest import SkipTest
from jax import api
from jax import test_util as jtu
from jax import numpy as np
from jax import numpy as jnp
from jax.config import config
config.parse_flags_with_absl()
@ -46,10 +46,10 @@ class MultiBackendTest(jtu.JaxTestCase):
raise SkipTest("Backend is not CPU or the device under test")
@partial(api.jit, backend=backend)
def fun(x, y):
return np.matmul(x, y)
return jnp.matmul(x, y)
x = npr.uniform(size=(10,10))
y = npr.uniform(size=(10,10))
z_host = onp.matmul(x, y)
z_host = np.matmul(x, y)
z = fun(x, y)
self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2)
correct_platform = backend if backend else jtu.device_under_test()
@ -67,11 +67,11 @@ class MultiBackendTest(jtu.JaxTestCase):
def fun(x, y):
@partial(api.jit, backend=inner)
def infun(x, y):
return np.matmul(x, y)
return infun(x, y) + np.ones_like(x)
return jnp.matmul(x, y)
return infun(x, y) + jnp.ones_like(x)
x = npr.uniform(size=(10,10))
y = npr.uniform(size=(10,10))
z_host = onp.matmul(x, y) + onp.ones_like(x)
z_host = np.matmul(x, y) + np.ones_like(x)
z = fun(x, y)
self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2)
correct_platform = outer if outer else jtu.device_under_test()
@ -95,8 +95,8 @@ class MultiBackendTest(jtu.JaxTestCase):
def fun(x, y):
@partial(api.jit, backend=inner)
def infun(x, y):
return np.matmul(x, y)
return infun(x, y) + np.ones_like(x)
return jnp.matmul(x, y)
return infun(x, y) + jnp.ones_like(x)
x = npr.uniform(size=(10,10))
y = npr.uniform(size=(10,10))
self.assertRaises(ValueError, lambda: fun(x, y))
@ -111,11 +111,11 @@ class MultiBackendTest(jtu.JaxTestCase):
raise SkipTest("Backend is not CPU or the device under test")
@partial(api.jit, backend=backend)
def fun(x, y):
return np.matmul(x, y)
return jnp.matmul(x, y)
x = npr.uniform(size=(10,10))
y = npr.uniform(size=(10,10))
z = fun(x, y)
w = np.sin(z)
w = jnp.sin(z)
self.assertEqual(z.device_buffer.platform(), backend)
self.assertEqual(w.device_buffer.platform(), backend)
@ -123,13 +123,13 @@ class MultiBackendTest(jtu.JaxTestCase):
def testJitCpu(self):
@partial(api.jit, backend='cpu')
def get_arr(scale):
return scale + np.ones((2, 2))
return scale + jnp.ones((2, 2))
x = get_arr(0.1)
a = x / x.shape[0]
b = x + np.ones_like(x)
c = x + np.eye(2)
b = x + jnp.ones_like(x)
c = x + jnp.eye(2)
self.assertEqual(a.device_buffer.device(), api.devices('cpu')[0])
self.assertEqual(b.device_buffer.device(), api.devices('cpu')[0])
@ -138,7 +138,7 @@ class MultiBackendTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
def test_closed_over_values_device_placement(self):
# see https://github.com/google/jax/issues/1431
def f(): return np.add(3., 4.)
def f(): return jnp.add(3., 4.)
self.assertNotEqual(api.jit(f)().device_buffer.device(),
api.devices('cpu')[0])
self.assertEqual(api.jit(f, backend='cpu')().device_buffer.device(),
@ -156,7 +156,7 @@ class MultiBackendTest(jtu.JaxTestCase):
data_on_cpu = api.device_put(1, device=cpus[0])
self.assertEqual(data_on_cpu.device_buffer.device(), cpus[0])
def my_sin(x): return np.sin(x)
def my_sin(x): return jnp.sin(x)
# jit without any device spec follows the data
result1 = api.jit(my_sin)(2)
self.assertEqual(result1.device_buffer.device(), default_dev)
@ -176,7 +176,7 @@ class MultiBackendTest(jtu.JaxTestCase):
# https://github.com/google/jax/issues/2905
cpus = api.devices("cpu")
x = api.device_put(onp.ones(2), cpus[0])
x = api.device_put(np.ones(2), cpus[0])
y = x[0]
self.assertEqual(y.device_buffer.device(), cpus[0])
@ -185,7 +185,7 @@ class MultiBackendTest(jtu.JaxTestCase):
# https://github.com/google/jax/issues/2905
cpus = api.devices("cpu")
x = api.device_put(onp.ones(2), cpus[0])
x = api.device_put(np.ones(2), cpus[0])
y = x.sum()
self.assertEqual(y.device_buffer.device(), cpus[0])

View File

@ -13,23 +13,23 @@
# limitations under the License.
import numpy as onp
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu, jit, partial
from jax.config import config
config.parse_flags_with_absl()
float_dtypes = [onp.float32, onp.float64]
float_dtypes = [np.float32, np.float64]
# implementation casts to complex64.
complex_dtypes = [onp.complex64]
complex_dtypes = [np.complex64]
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = [onp.int32, onp.int64]
int_dtypes = [np.int32, np.int64]
real_dtypes = float_dtypes + int_dtypes
all_dtypes = real_dtypes + complex_dtypes
@ -51,17 +51,17 @@ class TestPolynomial(jtu.JaxTestCase):
for leading in [0, 1, 2, 3, 5, 7, 10]
for trailing in [0, 1, 2, 3, 5, 7, 10]))
def testRoots(self, dtype, rng_factory, length, leading, trailing):
rng = rng_factory(onp.random.RandomState(0))
rng = rng_factory(np.random.RandomState(0))
def args_maker():
p = rng((length,), dtype)
return np.concatenate(
[np.zeros(leading, p.dtype), p, np.zeros(trailing, p.dtype)]),
return jnp.concatenate(
[jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)]),
# order may differ (np.sort doesn't deal with complex numbers)
np_fn = lambda arg: onp.sort(np.roots(arg))
onp_fn = lambda arg: onp.sort(onp.roots(arg))
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
# order may differ (jnp.sort doesn't deal with complex numbers)
np_fn = lambda arg: np.sort(jnp.roots(arg))
np_fn = lambda arg: np.sort(np.roots(arg))
self._CheckAgainstNumpy(np_fn, np_fn, args_maker, check_dtypes=False,
tol=3e-6)
@parameterized.named_parameters(jtu.cases_from_list(
@ -74,20 +74,20 @@ class TestPolynomial(jtu.JaxTestCase):
for length in [0, 1, 3, 10]
for trailing in [0, 1, 3, 7]))
def testRootsNostrip(self, length, dtype, rng_factory, trailing):
rng = rng_factory(onp.random.RandomState(0))
rng = rng_factory(np.random.RandomState(0))
def args_maker():
p = rng((length,), dtype)
if length != 0:
return np.concatenate([p, np.zeros(trailing, p.dtype)]),
return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
else:
# adding trailing would make input invalid (start with zeros)
return p,
# order may differ (np.sort doesn't deal with complex numbers)
np_fn = lambda arg: onp.sort(np.roots(arg, strip_zeros=False))
onp_fn = lambda arg: onp.sort(onp.roots(arg))
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
# order may differ (jnp.sort doesn't deal with complex numbers)
np_fn = lambda arg: np.sort(jnp.roots(arg, strip_zeros=False))
np_fn = lambda arg: np.sort(np.roots(arg))
self._CheckAgainstNumpy(np_fn, np_fn, args_maker,
check_dtypes=False, tol=1e-6)
@parameterized.named_parameters(jtu.cases_from_list(
@ -103,23 +103,23 @@ class TestPolynomial(jtu.JaxTestCase):
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
def testRootsJit(self, length, dtype, rng_factory, trailing):
rng = rng_factory(onp.random.RandomState(0))
rng = rng_factory(np.random.RandomState(0))
def args_maker():
p = rng((length,), dtype)
if length != 0:
return np.concatenate([p, np.zeros(trailing, p.dtype)]),
return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
else:
# adding trailing would make input invalid (start with zeros)
return p,
# order may differ (np.sort doesn't deal with complex numbers)
roots_compiled = jit(partial(np.roots, strip_zeros=False))
np_fn = lambda arg: onp.sort(roots_compiled(arg))
onp_fn = lambda arg: onp.sort(onp.roots(arg))
# order may differ (jnp.sort doesn't deal with complex numbers)
roots_compiled = jit(partial(jnp.roots, strip_zeros=False))
np_fn = lambda arg: np.sort(roots_compiled(arg))
np_fn = lambda arg: np.sort(np.roots(arg))
# Using strip_zeros=False makes the algorithm less efficient
# and leads to slightly different values compared ot numpy
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
self._CheckAgainstNumpy(np_fn, np_fn, args_maker,
check_dtypes=False, tol=1e-6)
@parameterized.named_parameters(jtu.cases_from_list(
@ -133,19 +133,19 @@ class TestPolynomial(jtu.JaxTestCase):
for zeros in [1, 2, 5]
for nonzeros in [0, 3]))
def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory):
rng = rng_factory(onp.random.RandomState(0))
rng = rng_factory(np.random.RandomState(0))
# The polynomial coefficients here start with zero and would have to
# be stripped before computing eigenvalues of the companion matrix.
# Setting strip_zeros=False skips this check,
# allowing jit transformation but yielding nan's for these inputs.
p = np.concatenate([np.zeros(zeros, dtype), rng((nonzeros,), dtype)])
p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])
if p.size == 1:
# polynomial = const has no roots
self.assertTrue(np.roots(p, strip_zeros=False).size == 0)
self.assertTrue(jnp.roots(p, strip_zeros=False).size == 0)
else:
self.assertTrue(np.any(np.isnan(np.roots(p, strip_zeros=False))))
self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p, strip_zeros=False))))
if __name__ == "__main__":

View File

@ -19,7 +19,7 @@ from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
import numpy as np
import scipy.linalg
import scipy.special
import scipy.stats
@ -28,7 +28,7 @@ from jax import api
from jax import core
from jax import grad
from jax import lax
from jax import numpy as np
from jax import numpy as jnp
from jax import random
from jax import test_util as jtu
from jax import vmap
@ -45,9 +45,9 @@ class LaxRandomTest(jtu.JaxTestCase):
nitems = len(samples)
nbins = 2 ** nbits
nexpected = nbins * (1 - ((nbins - 1) / nbins) ** nitems)
ncollisions = len(onp.unique(samples))
ncollisions = len(np.unique(samples))
sq_percent_deviation = ((ncollisions - nexpected) / nexpected) ** 2
self.assertLess(sq_percent_deviation, 1 / onp.sqrt(nexpected * fail_prob))
self.assertLess(sq_percent_deviation, 1 / np.sqrt(nexpected * fail_prob))
def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
fail_prob = 0.01 # conservative bound on statistical fail prob by Kolmo CDF
@ -55,7 +55,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def _CheckChiSquared(self, samples, pmf):
alpha = 0.01 # significance level, threshold for p-value
values, actual_freq = onp.unique(samples, return_counts=True)
values, actual_freq = np.unique(samples, return_counts=True)
expected_freq = pmf(values) * samples.size
# per scipy: "A typical rule is that all of the observed and expected
# frequencies should be at least 5."
@ -71,16 +71,16 @@ class LaxRandomTest(jtu.JaxTestCase):
f'{expected_freq[valid]}\n{actual_freq[valid]}')
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
if not FLAGS.jax_enable_x64 and np.issubdtype(dtype, onp.float64):
if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64):
raise SkipTest("can't test float64 agreement")
bits_dtype = onp.uint32 if np.finfo(dtype).bits == 32 else onp.uint64
numpy_bits = onp.array(1., dtype).view(bits_dtype)
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
numpy_bits = np.array(1., dtype).view(bits_dtype)
xla_bits = api.jit(
lambda: lax.bitcast_convert_type(onp.array(1., dtype), bits_dtype))()
lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
self.assertEqual(numpy_bits, xla_bits)
def testThreefry2x32(self):
@ -91,34 +91,34 @@ class LaxRandomTest(jtu.JaxTestCase):
return tuple([hex(x.copy()).rstrip("L") for x in result])
expected = ("0x6b200159", "0x99ba4efe")
result = random.threefry_2x32(onp.uint32([0, 0]), onp.uint32([0, 0]))
result = random.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
self.assertEqual(expected, result_to_hex(result))
expected = ("0x1cb996fc", "0xbb002be7")
result = random.threefry_2x32(onp.uint32([-1, -1]), onp.uint32([-1, -1]))
result = random.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1]))
self.assertEqual(expected, result_to_hex(result))
expected = ("0xc4923a9c", "0x483df7a0")
result = random.threefry_2x32(
onp.uint32([0x13198a2e, 0x03707344]),
onp.uint32([0x243f6a88, 0x85a308d3]))
np.uint32([0x13198a2e, 0x03707344]),
np.uint32([0x243f6a88, 0x85a308d3]))
self.assertEqual(expected, result_to_hex(result))
def testThreefry2x32Large(self):
n = 10000000
result = random.threefry_2x32(
(onp.uint32(0x13198a2e), onp.uint32(0x03707344)),
np.concatenate([
np.full((n,), 0x243f6a88, np.uint32),
np.full((n,), 0x85a308d3, np.uint32)
(np.uint32(0x13198a2e), np.uint32(0x03707344)),
jnp.concatenate([
jnp.full((n,), 0x243f6a88, jnp.uint32),
jnp.full((n,), 0x85a308d3, jnp.uint32)
]))
onp.testing.assert_equal(result[:n], onp.full((n,), 0xc4923a9c, dtype=onp.uint32))
onp.testing.assert_equal(result[n:], onp.full((n,), 0x483df7a0, dtype=onp.uint32))
np.testing.assert_equal(result[:n], np.full((n,), 0xc4923a9c, dtype=np.uint32))
np.testing.assert_equal(result[n:], np.full((n,), 0x483df7a0, dtype=np.uint32))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testRngUniform(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.uniform(key, (10000,), dtype)
@ -128,12 +128,12 @@ class LaxRandomTest(jtu.JaxTestCase):
compiled_samples = crand(key)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckCollisions(samples, np.finfo(dtype).nmant)
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.int32, onp.int64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.int32, np.int64]))
def testRngRandint(self, dtype):
lo = 5
hi = 10
@ -146,12 +146,12 @@ class LaxRandomTest(jtu.JaxTestCase):
compiled_samples = crand(key)
for samples in [uncompiled_samples, compiled_samples]:
self.assertTrue(onp.all(lo <= samples))
self.assertTrue(onp.all(samples < hi))
self.assertTrue(np.all(lo <= samples))
self.assertTrue(np.all(samples < hi))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testNormal(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.normal(key, (10000,), dtype)
@ -164,11 +164,11 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64, onp.int32, onp.int64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64, np.int32, np.int64]))
def testShuffle(self, dtype):
key = random.PRNGKey(0)
x = onp.arange(100).astype(dtype)
x = np.arange(100).astype(dtype)
rand = lambda key: random.shuffle(key, x)
crand = api.jit(rand)
@ -178,17 +178,17 @@ class LaxRandomTest(jtu.JaxTestCase):
perm2 = crand(key)
self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"dtype": onp.dtype(dtype).name, "shape": shape}
for dtype in [onp.float32, onp.float64, onp.int32, onp.int64]
"dtype": np.dtype(dtype).name, "shape": shape}
for dtype in [np.float32, np.float64, np.int32, np.int64]
for shape in [100, (10, 10), (10, 5, 2)]))
def testPermutationArray(self, dtype, shape):
key = random.PRNGKey(0)
x = np.arange(np.prod(shape)).reshape(shape).astype(dtype)
x = jnp.arange(jnp.prod(shape)).reshape(shape).astype(dtype)
rand = lambda key: random.permutation(key, x)
crand = api.jit(rand)
@ -196,10 +196,10 @@ class LaxRandomTest(jtu.JaxTestCase):
perm2 = crand(key)
self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertAllClose(onp.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
self.assertArraysAllClose(
x, np.arange(np.prod(shape)).reshape(shape).astype(dtype),
x, jnp.arange(jnp.prod(shape)).reshape(shape).astype(dtype),
check_dtypes=True)
def testPermutationInteger(self):
@ -213,8 +213,8 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertEqual(perm1.dtype, perm2.dtype)
self.assertFalse(onp.all(perm1 == onp.arange(100))) # seems unlikely!
self.assertAllClose(onp.sort(perm1), onp.arange(100), check_dtypes=False)
self.assertFalse(np.all(perm1 == np.arange(100))) # seems unlikely!
self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)
def testPermutationErrors(self):
key = random.PRNGKey(0)
@ -225,12 +225,12 @@ class LaxRandomTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_{}".format(p, dtype),
"p": p, "dtype": onp.dtype(dtype).name}
"p": p, "dtype": np.dtype(dtype).name}
for p in [0.1, 0.5, 0.9]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testBernoulli(self, p, dtype):
key = random.PRNGKey(0)
p = onp.array(p, dtype=dtype)
p = np.array(p, dtype=dtype)
rand = lambda key, p: random.bernoulli(key, p, (10000,))
crand = api.jit(rand)
@ -242,7 +242,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_{}_{}".format(p, dtype, sample_shape),
"p": p, "axis": axis, "dtype": onp.dtype(dtype).name, 'sample_shape': sample_shape}
"p": p, "axis": axis, "dtype": np.dtype(dtype).name, 'sample_shape': sample_shape}
for (p, axis) in [
([.25] * 4, -1),
([.1, .2, .3, .4], -1),
@ -250,12 +250,12 @@ class LaxRandomTest(jtu.JaxTestCase):
([[.5, .1], [.5, .9]], 0),
]
for sample_shape in [(10000,), (5000, 2)]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testCategorical(self, p, axis, dtype, sample_shape):
key = random.PRNGKey(0)
p = onp.array(p, dtype=dtype)
logits = onp.log(p) - 42 # test unnormalized
out_shape = tuple(onp.delete(logits.shape, axis))
p = np.array(p, dtype=dtype)
logits = np.log(p) - 42 # test unnormalized
out_shape = tuple(np.delete(logits.shape, axis))
shape = sample_shape + out_shape
rand = lambda key, p: random.categorical(key, logits, shape=shape, axis=axis)
crand = api.jit(rand)
@ -268,9 +268,9 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
assert samples.shape == shape
samples = np.reshape(samples, (10000,) + out_shape)
samples = jnp.reshape(samples, (10000,) + out_shape)
if len(p.shape[:-1]) > 0:
ps = onp.transpose(p, (1, 0)) if axis == 0 else p
ps = np.transpose(p, (1, 0)) if axis == 0 else p
for cat_samples, cat_p in zip(samples.transpose(), ps):
self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
else:
@ -278,15 +278,15 @@ class LaxRandomTest(jtu.JaxTestCase):
def testBernoulliShape(self):
key = random.PRNGKey(0)
x = random.bernoulli(key, onp.array([0.2, 0.3]), shape=(3, 2))
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_b={}_{}".format(a, b, dtype),
"a": a, "b": b, "dtype": onp.dtype(dtype).name}
"a": a, "b": b, "dtype": np.dtype(dtype).name}
for a in [0.2, 5.]
for b in [0.2, 5.]
for dtype in [onp.float64])) # NOTE: KS test fails with float32
for dtype in [np.float64])) # NOTE: KS test fails with float32
def testBeta(self, a, b, dtype):
if not FLAGS.jax_enable_x64:
raise SkipTest("skip test except on X64")
@ -301,8 +301,8 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testCauchy(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
@ -316,11 +316,11 @@ class LaxRandomTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_alpha={}_{}".format(alpha, dtype),
"alpha": alpha, "dtype": onp.dtype(dtype).name}
"alpha": alpha, "dtype": np.dtype(dtype).name}
for alpha in [
onp.array([0.2, 1., 5.]),
np.array([0.2, 1., 5.]),
]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testDirichlet(self, alpha, dtype):
key = random.PRNGKey(0)
rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
@ -330,14 +330,14 @@ class LaxRandomTest(jtu.JaxTestCase):
compiled_samples = crand(key, alpha)
for samples in [uncompiled_samples, compiled_samples]:
self.assertAllClose(samples.sum(-1), onp.ones(10000, dtype=dtype), check_dtypes=True)
self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype), check_dtypes=True)
alpha_sum = sum(alpha)
for i, a in enumerate(alpha):
self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testExponential(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
@ -351,9 +351,9 @@ class LaxRandomTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_{}".format(a, dtype),
"a": a, "dtype": onp.dtype(dtype).name}
"a": a, "dtype": np.dtype(dtype).name}
for a in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testGamma(self, a, dtype):
key = random.PRNGKey(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
@ -367,7 +367,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testGammaShape(self):
key = random.PRNGKey(0)
x = random.gamma(key, onp.array([0.2, 0.3]), shape=(3, 2))
x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@parameterized.named_parameters(jtu.cases_from_list(
@ -375,11 +375,11 @@ class LaxRandomTest(jtu.JaxTestCase):
for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
def testGammaGrad(self, alpha):
rng = random.PRNGKey(0)
alphas = onp.full((100,), alpha)
alphas = np.full((100,), alpha)
z = random.gamma(rng, alphas)
actual_grad = api.grad(lambda x: random.gamma(rng, x).sum())(alphas)
eps = 0.01 * alpha / (1.0 + onp.sqrt(alpha))
eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
- scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
pdf = scipy.stats.gamma.pdf(z, alpha)
@ -391,17 +391,17 @@ class LaxRandomTest(jtu.JaxTestCase):
def testGammaGradType(self):
# Regression test for https://github.com/google/jax/issues/2130
key = random.PRNGKey(0)
a = np.array(1., dtype=np.float32)
b = np.array(3., dtype=np.float32)
f = lambda x, y: random.gamma(key=key, a=x, dtype=np.float32) / y
a = jnp.array(1., dtype=jnp.float32)
b = jnp.array(3., dtype=jnp.float32)
f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
# Should not crash with a type error.
api.vjp(f, a, b)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lam={}_{}".format(lam, dtype),
"lam": lam, "dtype": onp.dtype(dtype).name}
"lam": lam, "dtype": np.dtype(dtype).name}
for lam in [0.5, 3, 9, 11, 50, 500]
for dtype in [onp.int32, onp.int64]))
for dtype in [np.int32, np.int64]))
def testPoisson(self, lam, dtype):
key = random.PRNGKey(0)
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
@ -419,19 +419,19 @@ class LaxRandomTest(jtu.JaxTestCase):
def testPoissonBatched(self):
key = random.PRNGKey(0)
lam = np.concatenate([2 * np.ones(10000), 20 * np.ones(10000)])
lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
samples = random.poisson(key, lam, shape=(20000,))
self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)
def testPoissonShape(self):
key = random.PRNGKey(0)
x = random.poisson(key, onp.array([2.0, 20.0]), shape=(3, 2))
x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
assert x.shape == (3, 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testGumbel(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
@ -444,8 +444,8 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testLaplace(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.laplace(key, (10000,), dtype)
@ -458,8 +458,8 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testLogistic(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.logistic(key, (10000,), dtype)
@ -473,9 +473,9 @@ class LaxRandomTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_b={}_{}".format(b, dtype),
"b": b, "dtype": onp.dtype(dtype).name}
"b": b, "dtype": np.dtype(dtype).name}
for b in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testPareto(self, b, dtype):
key = random.PRNGKey(0)
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
@ -489,14 +489,14 @@ class LaxRandomTest(jtu.JaxTestCase):
def testParetoShape(self):
key = random.PRNGKey(0)
x = random.pareto(key, onp.array([0.2, 0.3]), shape=(3, 2))
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_df={}_{}".format(df, dtype),
"df": df, "dtype": onp.dtype(dtype).name}
"df": df, "dtype": np.dtype(dtype).name}
for df in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
key = random.PRNGKey(0)
@ -510,29 +510,29 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}D_{}".format(dim, onp.dtype(dtype).name),
{"testcase_name": "_{}D_{}".format(dim, np.dtype(dtype).name),
"dim": dim, "dtype": dtype}
for dim in [1, 3, 5]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
@jtu.skip_on_mac_linalg_bug()
def testMultivariateNormal(self, dim, dtype):
r = onp.random.RandomState(dim)
r = np.random.RandomState(dim)
mean = r.randn(dim)
cov_factor = r.randn(dim, dim)
cov = onp.dot(cov_factor, cov_factor.T) + dim * onp.eye(dim)
cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)
key = random.PRNGKey(0)
rand = partial(random.multivariate_normal, mean=mean, cov=cov,
shape=(10000,))
crand = api.jit(rand)
uncompiled_samples = onp.asarray(rand(key), onp.float64)
compiled_samples = onp.asarray(crand(key), onp.float64)
uncompiled_samples = np.asarray(rand(key), np.float64)
compiled_samples = np.asarray(crand(key), np.float64)
inv_scale = scipy.linalg.lapack.dtrtri(onp.linalg.cholesky(cov), lower=True)[0]
inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0]
for samples in [uncompiled_samples, compiled_samples]:
centered = samples - mean
whitened = onp.einsum('nj,ij->ni', centered, inv_scale)
whitened = np.einsum('nj,ij->ni', centered, inv_scale)
# This is a quick-and-dirty multivariate normality check that tests that a
# uniform mixture of the marginals along the covariance matrix's
@ -543,25 +543,25 @@ class LaxRandomTest(jtu.JaxTestCase):
def testMultivariateNormalCovariance(self):
# test code based on https://github.com/google/jax/issues/1869
N = 100000
cov = np.array([[ 0.19, 0.00, -0.13, 0.00],
cov = jnp.array([[ 0.19, 0.00, -0.13, 0.00],
[ 0.00, 0.29, 0.00, -0.23],
[ -0.13, 0.00, 0.39, 0.00],
[ 0.00, -0.23, 0.00, 0.49]])
mean = np.zeros(4)
mean = jnp.zeros(4)
out_onp = onp.random.RandomState(0).multivariate_normal(mean, cov, N)
out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N)
key = random.PRNGKey(0)
out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))
var_onp = out_onp.var(axis=0)
var_np = out_np.var(axis=0)
var_jnp = out_jnp.var(axis=0)
self.assertAllClose(var_onp, var_jnp, rtol=1e-2, atol=1e-2,
self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2,
check_dtypes=False)
var_onp = onp.cov(out_onp, rowvar=False)
var_jnp = onp.cov(out_jnp, rowvar=False)
self.assertAllClose(var_onp, var_jnp, rtol=1e-2, atol=1e-2,
var_np = np.cov(out_np, rowvar=False)
var_jnp = np.cov(out_jnp, rowvar=False)
self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2,
check_dtypes=False)
def testIssue222(self):
@ -571,7 +571,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testFoldIn(self):
key = random.PRNGKey(0)
keys = [random.fold_in(key, i) for i in range(10)]
assert onp.unique(onp.ravel(keys)).shape == (20,)
assert np.unique(np.ravel(keys)).shape == (20,)
def testStaticShapeErrors(self):
if config.read("jax_disable_jit"):
@ -582,9 +582,9 @@ class LaxRandomTest(jtu.JaxTestCase):
key = random.PRNGKey(seed)
W = random.normal(key, (d, n)) / sigma
w = random.normal(key, (d, )) / sigma
b = 2 * np.pi * random.uniform(key, (d, ))
b = 2 * jnp.pi * random.uniform(key, (d, ))
phi = lambda x, t: np.sqrt(2.0 / d) * np.cos(np.matmul(W, x) + w*t + b)
phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos(jnp.matmul(W, x) + w*t + b)
return phi
self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*',
@ -594,21 +594,21 @@ class LaxRandomTest(jtu.JaxTestCase):
key = random.PRNGKey(0)
w = random.normal(key, ())
if FLAGS.jax_enable_x64:
self.assertEqual(onp.result_type(w), onp.float64)
self.assertEqual(np.result_type(w), np.float64)
else:
self.assertEqual(onp.result_type(w), onp.float32)
self.assertEqual(np.result_type(w), np.float32)
def testIssue1789(self):
def f(x):
return random.gamma(random.PRNGKey(0), x)
grad(lambda x: np.sum(vmap(f)(x)))(np.ones(2))
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
def testNoOpByOpUnderHash(self):
def fail(*args, **kwargs): assert False
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
try:
out = random.threefry_2x32(onp.zeros(2, onp.uint32), onp.arange(10, dtype=onp.uint32))
out = random.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
finally:
xla.apply_primitive = apply_primitive
@ -620,21 +620,21 @@ class LaxRandomTest(jtu.JaxTestCase):
if FLAGS.jax_enable_x64:
self.assertAllClose(
random.randint(k, (3, 3), 0, 8),
onp.array([[7, 2, 6],
np.array([[7, 2, 6],
[2, 1, 0],
[6, 7, 7]], dtype='int64'),
check_dtypes=True)
else:
self.assertAllClose(
random.randint(k, (3, 3), 0, 8),
onp.array([[2, 1, 3],
np.array([[2, 1, 3],
[6, 1, 5],
[6, 3, 4]], dtype='int32'),
check_dtypes=True)
self.assertAllClose(
random.split(k, 4),
onp.array([[2285895361, 1501764800],
np.array([[2285895361, 1501764800],
[1518642379, 4090693311],
[ 433833334, 4221794875],
[ 839183663, 3740430601]], dtype='uint32'),
@ -642,7 +642,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertAllClose(
random.fold_in(k, 4),
onp.array([2285895361, 433833334], dtype='uint32'),
np.array([2285895361, 433833334], dtype='uint32'),
check_dtypes=True)

View File

@ -17,9 +17,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu
from jax import random
from jax.experimental.vectorize import vectorize
@ -27,24 +25,24 @@ from jax.experimental.vectorize import vectorize
from jax.config import config
config.parse_flags_with_absl()
matmat = vectorize('(n,m),(m,k)->(n,k)')(np.dot)
matvec = vectorize('(n,m),(m)->(n)')(np.dot)
vecmat = vectorize('(m),(m,k)->(k)')(np.dot)
vecvec = vectorize('(m),(m)->()')(np.dot)
matmat = vectorize('(n,m),(m,k)->(n,k)')(jnp.dot)
matvec = vectorize('(n,m),(m)->(n)')(jnp.dot)
vecmat = vectorize('(m),(m,k)->(k)')(jnp.dot)
vecvec = vectorize('(m),(m)->()')(jnp.dot)
@vectorize('(n)->()')
def magnitude(x):
return np.dot(x, x)
return jnp.dot(x, x)
mean = vectorize('(n)->()')(np.mean)
mean = vectorize('(n)->()')(jnp.mean)
@vectorize('()->(n)')
def stack_plus_minus(x):
return np.stack([x, -x])
return jnp.stack([x, -x])
@vectorize('(n)->(),(n)')
def center(array):
bias = np.mean(array)
bias = jnp.mean(array)
debiased = array - bias
return bias, debiased
@ -60,8 +58,8 @@ class VectorizeTest(jtu.JaxTestCase):
((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
]))
def test_matmat(self, left_shape, right_shape, result_shape):
self.assertEqual(matmat(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
self.assertEqual(matmat(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
@ -73,8 +71,8 @@ class VectorizeTest(jtu.JaxTestCase):
((5, 4, 2, 3), (1, 3), (5, 4, 2)),
]))
def test_matvec(self, left_shape, right_shape, result_shape):
self.assertEqual(matvec(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
self.assertEqual(matvec(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
@ -85,8 +83,8 @@ class VectorizeTest(jtu.JaxTestCase):
((4, 2, 3), (3,), (4, 2)),
]))
def test_vecvec(self, left_shape, right_shape, result_shape):
self.assertEqual(vecvec(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
self.assertEqual(vecvec(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -100,7 +98,7 @@ class VectorizeTest(jtu.JaxTestCase):
size = 1
for x in shape:
size *= x
self.assertEqual(magnitude(np.arange(size).reshape(shape)).shape, result_shape)
self.assertEqual(magnitude(jnp.arange(size).reshape(shape)).shape, result_shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -111,10 +109,10 @@ class VectorizeTest(jtu.JaxTestCase):
((1, 2, 3, 4), (1, 2, 3)),
]))
def test_mean(self, shape, result_shape):
self.assertEqual(mean(np.zeros(shape)).shape, result_shape)
self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)
def test_mean_axis(self):
self.assertEqual(mean(np.zeros((2, 3)), axis=0).shape, (3,))
self.assertEqual(mean(jnp.zeros((2, 3)), axis=0).shape, (3,))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -124,24 +122,24 @@ class VectorizeTest(jtu.JaxTestCase):
((3,), (3,2,)),
]))
def test_stack_plus_minus(self, shape, result_shape):
self.assertEqual(stack_plus_minus(np.zeros(shape)).shape, result_shape)
self.assertEqual(stack_plus_minus(jnp.zeros(shape)).shape, result_shape)
def test_center(self):
b, a = center(np.arange(3))
b, a = center(jnp.arange(3))
self.assertEqual(a.shape, (3,))
self.assertEqual(b.shape, ())
self.assertAllClose(1.0, b, False)
X = np.arange(12).reshape((3, 4))
X = jnp.arange(12).reshape((3, 4))
b, a = center(X, axis=1)
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (3,))
self.assertAllClose(np.mean(X, axis=1), b, True)
self.assertAllClose(jnp.mean(X, axis=1), b, True)
b, a = center(X, axis=0)
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (4,))
self.assertAllClose(np.mean(X, axis=0), b, True)
self.assertAllClose(jnp.mean(X, axis=0), b, True)
if __name__ == "__main__":