Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Replace np -> jnp, onp -> np in more places. (#2973)

* Replace np -> jnp, onp -> np in more places. Context: #2370
* Fix typo in random_test.py

Commit: b1bc841ae5
Parent: d59ecddfe8
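The change is a mechanical rename of import aliases throughout the tree: `jax.numpy` moves from the alias `np` to `jnp`, and classic NumPy moves from `onp` back to its conventional `np`. A minimal sketch of the naming convention before and after (the `log_loss` function is illustrative, not taken from the diff); in the hunks below, `-` lines show the old spelling and `+` lines the new one:

```python
# Old aliases (before this commit):
#   import numpy as onp            # classic NumPy
#   from jax import numpy as np    # jax.numpy under the usual NumPy name

# New aliases (after this commit):
import numpy as np       # classic NumPy keeps its conventional name
import jax.numpy as jnp  # jax.numpy gets its own distinct alias

def log_loss(preds, targets):
  # jnp for traced/compiled math, np for host-side constants
  eps = np.float32(1e-7)
  return -jnp.mean(targets * jnp.log(preds + eps))
```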
@@ -21,13 +21,13 @@ To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2).
 from absl import app
 
 import jax
-from jax import numpy as np
+from jax import numpy as jnp
 from jax import pmap
 from jax.config import config
 
 from benchmarks import benchmark
 
-import numpy as onp
+import numpy as np
 
 
 def pmap_shard_sharded_device_array_benchmark():
@@ -38,9 +38,9 @@ def pmap_shard_sharded_device_array_benchmark():
   """
 
   def get_benchmark_fn(nargs, nshards):
-    pmap_fn = pmap(lambda *args: np.sum(args))
+    pmap_fn = pmap(lambda *args: jnp.sum(args))
     shape = (nshards, 4)
-    args = [onp.random.random(shape) for _ in range(nargs)]
+    args = [np.random.random(shape) for _ in range(nargs)]
     sharded_args = pmap(lambda x: x)(args)
     assert all(isinstance(arg, jax.pxla.ShardedDeviceArray)
                for arg in sharded_args)
@@ -68,9 +68,9 @@ def pmap_shard_device_array_benchmark():
   """
 
   def get_benchmark_fn(nargs, nshards):
-    pmap_fn = pmap(lambda *args: np.sum(args))
+    pmap_fn = pmap(lambda *args: jnp.sum(args))
     shape = (nshards, 4)
-    args = [np.array(onp.random.random(shape)) for _ in range(nargs)]
+    args = [jnp.array(np.random.random(shape)) for _ in range(nargs)]
     assert all(isinstance(arg, jax.xla.DeviceArray) for arg in args)
     def benchmark_fn():
       for _ in range(10):
@@ -96,7 +96,7 @@ def pmap_shard_outputs_benchmark():
   def get_benchmark_fn(nouts, nshards):
     pmap_fn = pmap(lambda x: [x + i for i in range(nouts)])
     shape = (nshards, 4)
-    arg = onp.random.random(shape)
+    arg = np.random.random(shape)
     def benchmark_fn():
       for _ in range(100):
         pmap_fn(arg)
@@ -118,7 +118,7 @@ def sharded_device_array_indexing_benchmark():
   nshards = min(8, jax.local_device_count())
   shape = (nshards, 8, 8)
   def benchmark_fn():
-    arr = pmap(lambda x: x)(np.arange(np.prod(shape)).reshape(shape))
+    arr = pmap(lambda x: x)(jnp.arange(jnp.prod(shape)).reshape(shape))
     indices = indices_fn()
     for idx in indices:
       arr[idx]
@@ -375,7 +375,7 @@ class Tracer(object):
                     "You might have\n"
                     " import numpy as np\n"
                     "instead of\n"
-                    " import jax.numpy as np")
+                    " import jax.numpy as jnp")
 
   def __init__(self, trace):
     self._trace = trace
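The error-message hunk above is the hint users see when classic NumPy is applied to a traced JAX value, for example inside `jax.grad` or `jax.jit`. A minimal sketch of the failure mode it guards against (the function names are illustrative):

```python
import numpy as np       # classic NumPy
import jax
import jax.numpy as jnp

def f_raw_numpy(x):
  return np.sin(x)       # np.sin tries to convert the tracer to a plain array

def f_jax_numpy(x):
  return jnp.sin(x)      # traceable, works under jit/grad

# jax.grad(f_raw_numpy)(1.0)       # raises the "Tracer can't be used with raw numpy functions" error
print(jax.grad(f_jax_numpy)(1.0))  # ~0.5403
```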
@@ -107,18 +107,18 @@ class custom_jvp:
 
   For example::
 
-    import jax.numpy as np
+    import jax.numpy as jnp
 
     @jax.custom_jvp
     def f(x, y):
-      return np.sin(x) * y
+      return jnp.sin(x) * y
 
     @f.defjvp
     def f_jvp(primals, tangents):
       x, y = primals
       x_dot, y_dot = tangents
       primal_out = f(x, y)
-      tangent_out = np.cos(x) * x_dot * y - np.sin(x) * y_dot
+      tangent_out = jnp.cos(x) * x_dot * y - jnp.sin(x) * y_dot
       return primal_out, tangent_out
 
   For a more detailed introduction, see the tutorial_.
@@ -150,18 +150,18 @@ class custom_jvp:
 
     Example::
 
-      import jax.numpy as np
+      import jax.numpy as jnp
 
       @jax.custom_jvp
       def f(x, y):
-        return np.sin(x) * y
+        return jnp.sin(x) * y
 
       @f.defjvp
       def f_jvp(primals, tangents):
        x, y = primals
        x_dot, y_dot = tangents
        primal_out = f(x, y)
-       tangent_out = np.cos(x) * x_dot * y - np.sin(x) * y_dot
+       tangent_out = jnp.cos(x) * x_dot * y - jnp.sin(x) * y_dot
        return primal_out, tangent_out
     """
     self.jvp = jvp
@@ -184,10 +184,10 @@ class custom_jvp:
 
       @jax.custom_jvp
       def f(x, y):
-        return np.sin(x) * y
+        return jnp.sin(x) * y
 
-      f.defjvps(lambda x_dot, primal_out, x, y: np.cos(x) * x_dot * y,
-                lambda y_dot, primal_out, x, y: -np.sin(x) * y_dot)
+      f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
+                lambda y_dot, primal_out, x, y: -jnp.sin(x) * y_dot)
     """
     if self.nondiff_argnums:
       raise TypeError("Can't use ``defjvps`` with ``nondiff_argnums``.")
@@ -370,14 +370,14 @@ class custom_vjp:
 
   For example::
 
-    import jax.numpy as np
+    import jax.numpy as jnp
 
     @jax.custom_vjp
     def f(x, y):
-      return np.sin(x) * y
+      return jnp.sin(x) * y
 
     def f_fwd(x, y):
-      return f(x, y), (np.cos(x), np.sin(x), y)
+      return f(x, y), (jnp.cos(x), jnp.sin(x), y)
 
     def f_bwd(res, g):
       cos_x, sin_x, y = res
@@ -424,14 +424,14 @@ class custom_vjp:
 
     Example::
 
-      import jax.numpy as np
+      import jax.numpy as jnp
 
      @jax.custom_vjp
      def f(x, y):
-       return np.sin(x) * y
+       return jnp.sin(x) * y
 
      def f_fwd(x, y):
-       return f(x, y), (np.cos(x), np.sin(x), y)
+       return f(x, y), (jnp.cos(x), jnp.sin(x), y)
 
      def f_bwd(res, g):
        cos_x, sin_x, y = res
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy as onp
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
-import jax.numpy as np
+import jax.numpy as jnp
 
 from jax import core
 from jax.core import Trace, Tracer, new_master
@@ -76,14 +75,14 @@ def _contains_query(vals, query):
   if isinstance(query, tuple):
     return map(partial(_contains_query, vals), query)
 
-  if np.isnan(query):
-    if np.any(np.isnan(vals)):
+  if jnp.isnan(query):
+    if jnp.any(jnp.isnan(vals)):
       raise FoundValue('NaN')
-  elif np.isinf(query):
-    if np.any(np.isinf(vals)):
+  elif jnp.isinf(query):
+    if jnp.any(jnp.isinf(vals)):
       raise FoundValue('Found Inf')
-  elif np.isscalar(query):
-    if np.any(vals == query):
+  elif jnp.isscalar(query):
+    if jnp.any(vals == query):
       raise FoundValue(str(query))
   else:
     raise ValueError('Malformed Query: {}'.format(query))
@@ -27,7 +27,7 @@ from functools import partial
 import operator as op
 
 import jax
-import jax.numpy as np
+import jax.numpy as jnp
 from jax import core
 from jax import lax
 from jax import ops
@@ -52,12 +52,12 @@ def ravel_first_arg_(unravel, y_flat, *args):
 
 def interp_fit_dopri(y0, y1, k, dt):
   # Fit a polynomial to the results of a Runge-Kutta step.
-  dps_c_mid = np.array([
+  dps_c_mid = jnp.array([
       6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2,
       -2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2,
       -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2])
-  y_mid = y0 + dt * np.dot(dps_c_mid, k)
-  return np.array(fit_4th_order_polynomial(y0, y1, y_mid, k[0], k[-1], dt))
+  y_mid = y0 + dt * jnp.dot(dps_c_mid, k)
+  return jnp.array(fit_4th_order_polynomial(y0, y1, y_mid, k[0], k[-1], dt))
 
 def fit_4th_order_polynomial(y0, y1, y_mid, dy0, dy1, dt):
   a = -2.*dt*dy0 + 2.*dt*dy1 - 8.*y0 - 8.*y1 + 16.*y_mid
@@ -71,26 +71,26 @@ def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
   # Algorithm from:
   # E. Hairer, S. P. Norsett G. Wanner,
   # Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
-  scale = atol + np.abs(y0) * rtol
-  d0 = np.linalg.norm(y0 / scale)
-  d1 = np.linalg.norm(f0 / scale)
+  scale = atol + jnp.abs(y0) * rtol
+  d0 = jnp.linalg.norm(y0 / scale)
+  d1 = jnp.linalg.norm(f0 / scale)
 
-  h0 = np.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)
+  h0 = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)
 
   y1 = y0 + h0 * f0
   f1 = fun(y1, t0 + h0)
-  d2 = np.linalg.norm((f1 - f0) / scale) / h0
+  d2 = jnp.linalg.norm((f1 - f0) / scale) / h0
 
-  h1 = np.where((d1 <= 1e-15) & (d2 <= 1e-15),
-                np.maximum(1e-6, h0 * 1e-3),
-                (0.01 / np.max(d1 + d2)) ** (1. / (order + 1.)))
+  h1 = jnp.where((d1 <= 1e-15) & (d2 <= 1e-15),
+                 jnp.maximum(1e-6, h0 * 1e-3),
+                 (0.01 / jnp.max(d1 + d2)) ** (1. / (order + 1.)))
 
-  return np.minimum(100. * h0, h1)
+  return jnp.minimum(100. * h0, h1)
 
 def runge_kutta_step(func, y0, f0, t0, dt):
   # Dopri5 Butcher tableaux
-  alpha = np.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
-  beta = np.array([
+  alpha = jnp.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
+  beta = jnp.array([
       [1 / 5, 0, 0, 0, 0, 0, 0],
       [3 / 40, 9 / 40, 0, 0, 0, 0, 0],
       [44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
@@ -98,49 +98,49 @@ def runge_kutta_step(func, y0, f0, t0, dt):
       [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
       [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]
   ])
-  c_sol = np.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0])
-  c_error = np.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
+  c_sol = jnp.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0])
+  c_error = jnp.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
                        125 / 192 - 451 / 720, -2187 / 6784 - -12231 / 42400,
                        11 / 84 - 649 / 6300, -1. / 60.])
 
   def body_fun(i, k):
     ti = t0 + dt * alpha[i-1]
-    yi = y0 + dt * np.dot(beta[i-1, :], k)
+    yi = y0 + dt * jnp.dot(beta[i-1, :], k)
     ft = func(yi, ti)
     return ops.index_update(k, jax.ops.index[i, :], ft)
 
-  k = ops.index_update(np.zeros((7, f0.shape[0])), ops.index[0, :], f0)
+  k = ops.index_update(jnp.zeros((7, f0.shape[0])), ops.index[0, :], f0)
   k = lax.fori_loop(1, 7, body_fun, k)
 
-  y1 = dt * np.dot(c_sol, k) + y0
-  y1_error = dt * np.dot(c_error, k)
+  y1 = dt * jnp.dot(c_sol, k) + y0
+  y1_error = dt * jnp.dot(c_error, k)
   f1 = k[-1]
   return y1, f1, y1_error, k
 
 def error_ratio(error_estimate, rtol, atol, y0, y1):
-  err_tol = atol + rtol * np.maximum(np.abs(y0), np.abs(y1))
+  err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
   err_ratio = error_estimate / err_tol
-  return np.mean(err_ratio ** 2)
+  return jnp.mean(err_ratio ** 2)
 
 def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
                       dfactor=0.2, order=5.0):
   """Compute optimal Runge-Kutta stepsize."""
-  mean_error_ratio = np.max(mean_error_ratio)
-  dfactor = np.where(mean_error_ratio < 1, 1.0, dfactor)
+  mean_error_ratio = jnp.max(mean_error_ratio)
+  dfactor = jnp.where(mean_error_ratio < 1, 1.0, dfactor)
 
-  err_ratio = np.sqrt(mean_error_ratio)
-  factor = np.maximum(1.0 / ifactor,
-                      np.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
-  return np.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)
+  err_ratio = jnp.sqrt(mean_error_ratio)
+  factor = jnp.maximum(1.0 / ifactor,
+                       jnp.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
+  return jnp.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)
 
-def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):
+def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
   """Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.
 
   Args:
     func: function to evaluate the time derivative of the solution `y` at time
       `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
     y0: array or pytree of arrays representing the initial value for the state.
-    t: array of float times for evaluation, like `np.linspace(0., 10., 101)`,
+    t: array of float times for evaluation, like `jnp.linspace(0., 10., 101)`,
       in which the values must be strictly increasing.
     *args: tuple of additional arguments for `func`, which must be arrays
       scalars, or (nested) standard Python containers (tuples, lists, dicts,
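The odeint signature and docstring above describe the public API; a minimal usage sketch consistent with them (the exponential-decay ODE and the constants are illustrative, and the import path assumes the `jax.experimental.ode` module of this era of JAX):

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint  # assumed module path for this version of JAX

def decay(y, t, rate):
  # dy/dt = -rate * y
  return -rate * y

ts = jnp.linspace(0., 10., 101)  # strictly increasing evaluation times
ys = odeint(decay, jnp.array([1.0]), ts, 0.5, rtol=1.4e-8, atol=1.4e-8)
print(ys.shape)  # (101, 1): one row per time in `ts`
```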
@@ -189,20 +189,20 @@ def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
 
       new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
       old = [i + 1, y, f, t, dt, last_t, interp_coeff]
-      return map(partial(np.where, np.all(error_ratios <= 1.)), new, old)
+      return map(partial(jnp.where, jnp.all(error_ratios <= 1.)), new, old)
 
     _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
     _, _, t, _, last_t, interp_coeff = carry
     relative_output_time = (target_t - last_t) / (t - last_t)
-    y_target = np.polyval(interp_coeff, relative_output_time)
+    y_target = jnp.polyval(interp_coeff, relative_output_time)
     return carry, y_target
 
   f0 = func_(y0, ts[0])
   dt = initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0)
-  interp_coeff = np.array([y0] * 5)
+  interp_coeff = jnp.array([y0] * 5)
   init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
   _, ys = lax.scan(scan_fun, init_carry, ts[1:])
-  return np.concatenate((y0[None], ys))
+  return jnp.concatenate((y0[None], ys))
 
 def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
   ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
@@ -226,22 +226,22 @@ def _odeint_rev(func, rtol, atol, mxstep, res, g):
   def scan_fun(carry, i):
     y_bar, t0_bar, args_bar = carry
     # Compute effect of moving measurement time
-    t_bar = np.dot(func(ys[i], ts[i], *args), g[i])
+    t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i])
     t0_bar = t0_bar - t_bar
     # Run augmented system backwards to previous observation
     _, y_bar, t0_bar, args_bar = odeint(
         aug_dynamics, (ys[i], y_bar, t0_bar, args_bar),
-        np.array([-ts[i], -ts[i - 1]]),
+        jnp.array([-ts[i], -ts[i - 1]]),
         *args, rtol=rtol, atol=atol, mxstep=mxstep)
     y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
     # Add gradient from current output
     y_bar = y_bar + g[i - 1]
     return (y_bar, t0_bar, args_bar), t_bar
 
-  init_carry = (g[-1], 0., tree_map(np.zeros_like, args))
+  init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
   (y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
-      scan_fun, init_carry, np.arange(len(ts) - 1, 0, -1))
-  ts_bar = np.concatenate([np.array([t0_bar]), rev_ts_bar[::-1]])
+      scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
+  ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
   return (y_bar, ts_bar, *args_bar)
 
 _odeint.defvjp(_odeint_fwd, _odeint_rev)
@ -71,7 +71,7 @@ from collections import namedtuple
|
||||
import functools
|
||||
import operator
|
||||
|
||||
import jax.numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax.util import partial, safe_zip, safe_map, unzip2
|
||||
from jax import tree_util
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
@ -208,7 +208,7 @@ def momentum(step_size, mass):
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
v0 = np.zeros_like(x0)
|
||||
v0 = jnp.zeros_like(x0)
|
||||
return x0, v0
|
||||
def update(i, g, state):
|
||||
x, velocity = state
|
||||
@ -235,7 +235,7 @@ def nesterov(step_size, mass):
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
v0 = np.zeros_like(x0)
|
||||
v0 = jnp.zeros_like(x0)
|
||||
return x0, v0
|
||||
def update(i, g, state):
|
||||
x, velocity = state
|
||||
@ -266,14 +266,14 @@ def adagrad(step_size, momentum=0.9):
|
||||
step_size = make_schedule(step_size)
|
||||
|
||||
def init(x0):
|
||||
g_sq = np.zeros_like(x0)
|
||||
m = np.zeros_like(x0)
|
||||
g_sq = jnp.zeros_like(x0)
|
||||
m = jnp.zeros_like(x0)
|
||||
return x0, g_sq, m
|
||||
|
||||
def update(i, g, state):
|
||||
x, g_sq, m = state
|
||||
g_sq += g**2
|
||||
g_sq_inv_sqrt = np.where(g_sq > 0, 1. / np.sqrt(g_sq), 0.0)
|
||||
g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
|
||||
m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
|
||||
x = x - step_size(i) * m
|
||||
return x, g_sq, m
|
||||
@ -300,12 +300,12 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
avg_sq_grad = np.zeros_like(x0)
|
||||
avg_sq_grad = jnp.zeros_like(x0)
|
||||
return x0, avg_sq_grad
|
||||
def update(i, g, state):
|
||||
x, avg_sq_grad = state
|
||||
avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
|
||||
x = x - step_size(i) * g / np.sqrt(avg_sq_grad + eps)
|
||||
x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
|
||||
return x, avg_sq_grad
|
||||
def get_params(state):
|
||||
x, _ = state
|
||||
@ -332,13 +332,13 @@ def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
avg_sq_grad = np.zeros_like(x0)
|
||||
mom = np.zeros_like(x0)
|
||||
avg_sq_grad = jnp.zeros_like(x0)
|
||||
mom = jnp.zeros_like(x0)
|
||||
return x0, avg_sq_grad, mom
|
||||
def update(i, g, state):
|
||||
x, avg_sq_grad, mom = state
|
||||
avg_sq_grad = avg_sq_grad * gamma + g**2 * (1. - gamma)
|
||||
mom = momentum * mom + step_size(i) * g / np.sqrt(avg_sq_grad + eps)
|
||||
mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
|
||||
x = x - mom
|
||||
return x, avg_sq_grad, mom
|
||||
def get_params(state):
|
||||
@ -366,8 +366,8 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
m0 = np.zeros_like(x0)
|
||||
v0 = np.zeros_like(x0)
|
||||
m0 = jnp.zeros_like(x0)
|
||||
v0 = jnp.zeros_like(x0)
|
||||
return x0, m0, v0
|
||||
def update(i, g, state):
|
||||
x, m, v = state
|
||||
@ -375,7 +375,7 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
|
||||
v = (1 - b2) * (g ** 2) + b2 * v # Second moment estimate.
|
||||
mhat = m / (1 - b1 ** (i + 1)) # Bias correction.
|
||||
vhat = v / (1 - b2 ** (i + 1))
|
||||
x = x - step_size(i) * mhat / (np.sqrt(vhat) + eps)
|
||||
x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
|
||||
return x, m, v
|
||||
def get_params(state):
|
||||
x, m, v = state
|
||||
@ -402,13 +402,13 @@ def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
m0 = np.zeros_like(x0)
|
||||
u0 = np.zeros_like(x0)
|
||||
m0 = jnp.zeros_like(x0)
|
||||
u0 = jnp.zeros_like(x0)
|
||||
return x0, m0, u0
|
||||
def update(i, g, state):
|
||||
x, m, u = state
|
||||
m = (1 - b1) * g + b1 * m # First moment estimate.
|
||||
u = np.maximum(b2 * u, np.abs(g)) # Update exponentially weighted infinity norm.
|
||||
u = jnp.maximum(b2 * u, jnp.abs(g)) # Update exponentially weighted infinity norm.
|
||||
x = x - (step_size(i) / (1 - b1 ** (i + 1))) * m / (u + eps)
|
||||
return x, m, u
|
||||
def get_params(state):
|
||||
@ -444,14 +444,14 @@ def sm3(step_size, momentum=0.9):
|
||||
return x[tuple(idx)]
|
||||
|
||||
def init(x0):
|
||||
vs = [np.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
|
||||
return x0, np.zeros_like(x0), vs
|
||||
vs = [jnp.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
|
||||
return x0, jnp.zeros_like(x0), vs
|
||||
|
||||
def update(i, g, state):
|
||||
x, m, vs = state
|
||||
vs = [broadcast_into(g.ndim, v, i) for i, v in enumerate(vs)]
|
||||
accum = functools.reduce(np.minimum, vs) + g ** 2
|
||||
accum_inv_sqrt = np.where(accum > 0, 1. / np.sqrt(accum), 0)
|
||||
accum = functools.reduce(jnp.minimum, vs) + g ** 2
|
||||
accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
|
||||
m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
|
||||
x = x - step_size(i) * m
|
||||
vs = [accum.max(splice(range(x.ndim), j, [])) for j in range(x.ndim)]
|
||||
@ -479,7 +479,7 @@ def exponential_decay(step_size, decay_steps, decay_rate):
|
||||
def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
|
||||
if staircase:
|
||||
def schedule(i):
|
||||
return step_size / (1 + decay_rate * np.floor(i / decay_steps))
|
||||
return step_size / (1 + decay_rate * jnp.floor(i / decay_steps))
|
||||
else:
|
||||
def schedule(i):
|
||||
return step_size / (1 + decay_rate * i / decay_steps)
|
||||
@ -487,28 +487,28 @@ def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
|
||||
|
||||
def polynomial_decay(step_size, decay_steps, final_step_size, power=1.0):
|
||||
def schedule(step_num):
|
||||
step_num = np.minimum(step_num, decay_steps)
|
||||
step_num = jnp.minimum(step_num, decay_steps)
|
||||
step_mult = (1 - step_num / decay_steps) ** power
|
||||
return step_mult * (step_size - final_step_size) + final_step_size
|
||||
|
||||
return schedule
|
||||
|
||||
def piecewise_constant(boundaries, values):
|
||||
boundaries = np.array(boundaries)
|
||||
values = np.array(values)
|
||||
boundaries = jnp.array(boundaries)
|
||||
values = jnp.array(values)
|
||||
if not boundaries.ndim == values.ndim == 1:
|
||||
raise ValueError("boundaries and values must be sequences")
|
||||
if not boundaries.shape[0] == values.shape[0] - 1:
|
||||
raise ValueError("boundaries length must be one longer than values length")
|
||||
|
||||
def schedule(i):
|
||||
return values[np.sum(i > boundaries)]
|
||||
return values[jnp.sum(i > boundaries)]
|
||||
return schedule
|
||||
|
||||
def make_schedule(scalar_or_schedule):
|
||||
if callable(scalar_or_schedule):
|
||||
return scalar_or_schedule
|
||||
elif np.ndim(scalar_or_schedule) == 0:
|
||||
elif jnp.ndim(scalar_or_schedule) == 0:
|
||||
return constant(scalar_or_schedule)
|
||||
else:
|
||||
raise TypeError(type(scalar_or_schedule))
|
||||
@ -519,12 +519,12 @@ def make_schedule(scalar_or_schedule):
|
||||
def l2_norm(tree):
|
||||
"""Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
|
||||
leaves, _ = tree_flatten(tree)
|
||||
return np.sqrt(sum(np.vdot(x, x) for x in leaves))
|
||||
return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
|
||||
|
||||
def clip_grads(grad_tree, max_norm):
|
||||
"""Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
|
||||
norm = l2_norm(grad_tree)
|
||||
normalize = lambda g: np.where(norm < max_norm, g, g * (max_norm / norm))
|
||||
normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
|
||||
return tree_map(normalize, grad_tree)
|
||||
|
||||
|
||||
|
@ -22,11 +22,9 @@ import functools
|
||||
import itertools
|
||||
import operator as op
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from jax import lax
|
||||
from jax import random
|
||||
import jax.numpy as np
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
|
||||
leaky_relu, selu, gelu, normalize)
|
||||
@ -56,7 +54,7 @@ def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
|
||||
return output_shape, (W, b)
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
W, b = params
|
||||
return np.dot(inputs, W) + b
|
||||
return jnp.dot(inputs, W) + b
|
||||
return init_fun, apply_fun
|
||||
|
||||
|
||||
@ -123,7 +121,7 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
|
||||
"""Layer construction function for a batch normalization layer."""
|
||||
_beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
|
||||
_gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
|
||||
axis = (axis,) if np.isscalar(axis) else axis
|
||||
axis = (axis,) if jnp.isscalar(axis) else axis
|
||||
def init_fun(rng, input_shape):
|
||||
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
|
||||
k1, k2 = random.split(rng)
|
||||
@ -131,9 +129,9 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
|
||||
return input_shape, (beta, gamma)
|
||||
def apply_fun(params, x, **kwargs):
|
||||
beta, gamma = params
|
||||
# TODO(phawkins): np.expand_dims should accept an axis tuple.
|
||||
# TODO(phawkins): jnp.expand_dims should accept an axis tuple.
|
||||
# (https://github.com/numpy/numpy/issues/12290)
|
||||
ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
|
||||
ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
|
||||
z = normalize(x, axis, epsilon=epsilon)
|
||||
if center and scale: return gamma[ed] * z + beta[ed]
|
||||
if center: return z + beta[ed]
|
||||
@ -147,9 +145,9 @@ def elementwise(fun, **fun_kwargs):
|
||||
init_fun = lambda rng, input_shape: (input_shape, ())
|
||||
apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
|
||||
return init_fun, apply_fun
|
||||
Tanh = elementwise(np.tanh)
|
||||
Tanh = elementwise(jnp.tanh)
|
||||
Relu = elementwise(relu)
|
||||
Exp = elementwise(np.exp)
|
||||
Exp = elementwise(jnp.exp)
|
||||
LogSoftmax = elementwise(log_softmax, axis=-1)
|
||||
Softmax = elementwise(softmax, axis=-1)
|
||||
Softplus = elementwise(softplus)
|
||||
@ -185,7 +183,7 @@ def _pooling_layer(reducer, init_val, rescaler=None):
|
||||
return rescale(out, inputs, spec) if rescale else out
|
||||
return init_fun, apply_fun
|
||||
return PoolingLayer
|
||||
MaxPool = _pooling_layer(lax.max, -np.inf)
|
||||
MaxPool = _pooling_layer(lax.max, -jnp.inf)
|
||||
SumPool = _pooling_layer(lax.add, 0.)
|
||||
|
||||
|
||||
@ -199,10 +197,10 @@ def _normalize_by_window_size(dims, strides, padding):
|
||||
spatial_shape = tuple(inputs.shape[i]
|
||||
for i in range(inputs.ndim)
|
||||
if i not in non_spatial_axes)
|
||||
one = np.ones(spatial_shape, dtype=inputs.dtype)
|
||||
one = jnp.ones(spatial_shape, dtype=inputs.dtype)
|
||||
window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
|
||||
for i in sorted(non_spatial_axes):
|
||||
window_sizes = np.expand_dims(window_sizes, i)
|
||||
window_sizes = jnp.expand_dims(window_sizes, i)
|
||||
|
||||
return outputs / window_sizes
|
||||
return rescale
|
||||
@ -215,7 +213,7 @@ def Flatten():
|
||||
output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
|
||||
return output_shape, ()
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
return np.reshape(inputs, (inputs.shape[0], -1))
|
||||
return jnp.reshape(inputs, (inputs.shape[0], -1))
|
||||
return init_fun, apply_fun
|
||||
Flatten = Flatten()
|
||||
|
||||
@ -251,7 +249,7 @@ def FanInConcat(axis=-1):
|
||||
out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
|
||||
return out_shape, ()
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
return np.concatenate(inputs, axis)
|
||||
return jnp.concatenate(inputs, axis)
|
||||
return init_fun, apply_fun
|
||||
|
||||
|
||||
@ -269,7 +267,7 @@ def Dropout(rate, mode='train'):
|
||||
raise ValueError(msg)
|
||||
if mode == 'train':
|
||||
keep = random.bernoulli(rng, rate, inputs.shape)
|
||||
return np.where(keep, inputs / rate, 0)
|
||||
return jnp.where(keep, inputs / rate, 0)
|
||||
else:
|
||||
return inputs
|
||||
return init_fun, apply_fun
|
||||
|
@ -17,7 +17,7 @@ from .tree_util import tree_flatten, tree_unflatten
|
||||
from . import linear_util as lu
|
||||
from .util import safe_zip
|
||||
|
||||
import jax.numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax.api import vjp
|
||||
|
||||
zip = safe_zip
|
||||
@ -30,7 +30,7 @@ def ravel_pytree(pytree):
|
||||
return flat, unravel_pytree
|
||||
|
||||
def ravel_list(*lst):
|
||||
return np.concatenate([np.ravel(elt) for elt in lst]) if lst else np.array([])
|
||||
return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([])
|
||||
|
||||
|
||||
@lu.transformation_with_aux
|
||||
|
@ -15,13 +15,13 @@
|
||||
"""Shared neural network activations and other functions."""
|
||||
|
||||
|
||||
import numpy as onp
|
||||
import numpy as np
|
||||
|
||||
from jax import custom_jvp
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax.scipy.special import expit
|
||||
import jax.numpy as np
|
||||
import jax.numpy as jnp
|
||||
|
||||
# activations
|
||||
|
||||
@ -34,7 +34,7 @@ def relu(x):
|
||||
.. math::
|
||||
\mathrm{relu}(x) = \max(x, 0)
|
||||
"""
|
||||
return np.maximum(x, 0)
|
||||
return jnp.maximum(x, 0)
|
||||
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
|
||||
|
||||
def softplus(x):
|
||||
@ -45,7 +45,7 @@ def softplus(x):
|
||||
.. math::
|
||||
\mathrm{softplus}(x) = \log(1 + e^x)
|
||||
"""
|
||||
return np.logaddexp(x, 0)
|
||||
return jnp.logaddexp(x, 0)
|
||||
|
||||
def soft_sign(x):
|
||||
r"""Soft-sign activation function.
|
||||
@ -55,7 +55,7 @@ def soft_sign(x):
|
||||
.. math::
|
||||
\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}
|
||||
"""
|
||||
return x / (np.abs(x) + 1)
|
||||
return x / (jnp.abs(x) + 1)
|
||||
|
||||
def sigmoid(x):
|
||||
r"""Sigmoid activation function.
|
||||
@ -98,8 +98,8 @@ def elu(x, alpha=1.0):
|
||||
\alpha \exp(x - 1), & x \le 0
|
||||
\end{cases}
|
||||
"""
|
||||
safe_x = np.where(x > 0, 0., x)
|
||||
return np.where(x > 0, x, alpha * np.expm1(safe_x))
|
||||
safe_x = jnp.where(x > 0, 0., x)
|
||||
return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))
|
||||
|
||||
def leaky_relu(x, negative_slope=1e-2):
|
||||
r"""Leaky rectified linear unit activation function.
|
||||
@ -114,7 +114,7 @@ def leaky_relu(x, negative_slope=1e-2):
|
||||
|
||||
where :math:`\alpha` = :code:`negative_slope`.
|
||||
"""
|
||||
return np.where(x >= 0, x, negative_slope * x)
|
||||
return jnp.where(x >= 0, x, negative_slope * x)
|
||||
|
||||
def hard_tanh(x):
|
||||
r"""Hard :math:`\mathrm{tanh}` activation function.
|
||||
@ -128,7 +128,7 @@ def hard_tanh(x):
|
||||
1, & 1 < x
|
||||
\end{cases}
|
||||
"""
|
||||
return np.where(x > 1, 1, np.where(x < -1, -1, x))
|
||||
return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x))
|
||||
|
||||
def celu(x, alpha=1.0):
|
||||
r"""Continuously-differentiable exponential linear unit activation.
|
||||
@ -144,7 +144,7 @@ def celu(x, alpha=1.0):
|
||||
For more information, see
|
||||
`Continuously Differentiable Exponential Linear Units
|
||||
<https://arxiv.org/pdf/1704.07483.pdf>`_."""
|
||||
return np.where(x > 0, x, alpha * np.expm1(x / alpha))
|
||||
return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))
|
||||
|
||||
def selu(x):
|
||||
r"""Scaled exponential linear unit activation.
|
||||
@ -181,15 +181,15 @@ def gelu(x):
|
||||
speed. For more information, see `Gaussian Error Linear Units (GELUs)
|
||||
<https://arxiv.org/abs/1606.08415>`_, section 2.
|
||||
"""
|
||||
sqrt_2_over_pi = onp.sqrt(2 / onp.pi).astype(x.dtype)
|
||||
cdf = 0.5 * (1.0 + np.tanh(sqrt_2_over_pi * (x + 0.044715 * x**3)))
|
||||
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
|
||||
cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * x**3)))
|
||||
return x * cdf
|
||||
|
||||
def glu(x, axis=-1):
|
||||
"""Gated linear unit activation function."""
|
||||
size = x.shape[axis]
|
||||
assert size % 2 == 0, "axis size must be divisible by 2"
|
||||
x1, x2 = np.split(x, 2, axis)
|
||||
x1, x2 = jnp.split(x, 2, axis)
|
||||
return x1 * sigmoid(x2)
|
||||
|
||||
# other functions
|
||||
@ -209,7 +209,7 @@ def log_softmax(x, axis=-1):
|
||||
computed. Either an integer or a tuple of integers.
|
||||
"""
|
||||
shifted = x - x.max(axis, keepdims=True)
|
||||
return shifted - np.log(np.sum(np.exp(shifted), axis, keepdims=True))
|
||||
return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis, keepdims=True))
|
||||
|
||||
def softmax(x, axis=-1):
|
||||
r"""Softmax function.
|
||||
@ -225,35 +225,35 @@ def softmax(x, axis=-1):
|
||||
softmax output summed across these dimensions should sum to :math:`1`.
|
||||
Either an integer or a tuple of integers.
|
||||
"""
|
||||
unnormalized = np.exp(x - x.max(axis, keepdims=True))
|
||||
unnormalized = jnp.exp(x - x.max(axis, keepdims=True))
|
||||
return unnormalized / unnormalized.sum(axis, keepdims=True)
|
||||
|
||||
def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
|
||||
"""Normalizes an array by subtracting mean and dividing by sqrt(var)."""
|
||||
if mean is None:
|
||||
mean = np.mean(x, axis, keepdims=True)
|
||||
mean = jnp.mean(x, axis, keepdims=True)
|
||||
if variance is None:
|
||||
# this definition is traditionally seen as less accurate than np.var's
|
||||
# this definition is traditionally seen as less accurate than jnp.var's
|
||||
# mean((x - mean(x))**2) but may be faster and even, given typical
|
||||
# activation distributions and low-precision arithmetic, more accurate
|
||||
# when used in neural network normalization layers
|
||||
variance = np.mean(x**2, axis, keepdims=True) - mean**2
|
||||
variance = jnp.mean(x**2, axis, keepdims=True) - mean**2
|
||||
return (x - mean) * lax.rsqrt(variance + epsilon)
|
||||
|
||||
def one_hot(x, num_classes, *, dtype=np.float64):
|
||||
def one_hot(x, num_classes, *, dtype=jnp.float64):
|
||||
"""One-hot encodes the given indicies.
|
||||
|
||||
Each index in the input ``x`` is encoded as a vector of zeros of length
|
||||
``num_classes`` with the element at ``index`` set to one::
|
||||
|
||||
>>> jax.nn.one_hot(np.array([0, 1, 2]), 3)
|
||||
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
|
||||
DeviceArray([[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 1.]], dtype=float32)
|
||||
|
||||
Indicies outside the range [0, num_classes) will be encoded as zeros::
|
||||
|
||||
>>> jax.nn.one_hot(np.array([-1, 3]), 3)
|
||||
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
|
||||
DeviceArray([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
|
||||
@ -264,10 +264,10 @@ def one_hot(x, num_classes, *, dtype=np.float64):
|
||||
jax_enable_x64 is true, otherwise float32).
|
||||
"""
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
x = np.asarray(x)
|
||||
lhs = x[..., np.newaxis]
|
||||
rhs = lax.broadcast_to_rank(np.arange(num_classes, dtype=x.dtype), lhs.ndim)
|
||||
return np.array(lhs == rhs, dtype=dtype)
|
||||
x = jnp.asarray(x)
|
||||
lhs = x[..., jnp.newaxis]
|
||||
rhs = lax.broadcast_to_rank(jnp.arange(num_classes, dtype=x.dtype), lhs.ndim)
|
||||
return jnp.array(lhs == rhs, dtype=dtype)
|
||||
|
||||
def relu6(x):
|
||||
r"""Rectified Linear Unit 6 activation function.
|
||||
@ -277,7 +277,7 @@ def relu6(x):
|
||||
.. math::
|
||||
\mathrm{relu6}(x) = \min(\max(x, 0), 6)
|
||||
"""
|
||||
return np.minimum(np.maximum(x, 0), 6.)
|
||||
return jnp.minimum(jnp.maximum(x, 0), 6.)
|
||||
|
||||
def hard_sigmoid(x):
|
||||
r"""Hard Sigmoid activation function.
|
||||
|
@ -20,33 +20,33 @@ used in Keras and Sonnet.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as onp
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import ops
|
||||
from jax import random
|
||||
|
||||
def zeros(key, shape, dtype=np.float32): return np.zeros(shape, dtype)
|
||||
def ones(key, shape, dtype=np.float32): return np.ones(shape, dtype)
|
||||
def zeros(key, shape, dtype=jnp.float32): return jnp.zeros(shape, dtype)
|
||||
def ones(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype)
|
||||
|
||||
def uniform(scale=1e-2, dtype=np.float32):
|
||||
def uniform(scale=1e-2, dtype=jnp.float32):
|
||||
def init(key, shape, dtype=dtype):
|
||||
return random.uniform(key, shape, dtype) * scale
|
||||
return init
|
||||
|
||||
def normal(stddev=1e-2, dtype=np.float32):
|
||||
def normal(stddev=1e-2, dtype=jnp.float32):
|
||||
def init(key, shape, dtype=dtype):
|
||||
return random.normal(key, shape, dtype) * stddev
|
||||
return init
|
||||
|
||||
def _compute_fans(shape, in_axis=-2, out_axis=-1):
|
||||
receptive_field_size = onp.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
fan_in = shape[in_axis] * receptive_field_size
|
||||
fan_out = shape[out_axis] * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=np.float32):
|
||||
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float32):
|
||||
def init(key, shape, dtype=dtype):
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
||||
if mode == "fan_in": denominator = fan_in
|
||||
@ -55,15 +55,15 @@ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=n
|
||||
else:
|
||||
raise ValueError(
|
||||
"invalid mode for variance scaling initializer: {}".format(mode))
|
||||
variance = np.array(scale / denominator, dtype=dtype)
|
||||
variance = jnp.array(scale / denominator, dtype=dtype)
|
||||
if distribution == "truncated_normal":
|
||||
# constant is stddev of standard normal truncated to (-2, 2)
|
||||
stddev = np.sqrt(variance) / np.array(.87962566103423978, dtype)
|
||||
stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
|
||||
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
|
||||
elif distribution == "normal":
|
||||
return random.normal(key, shape, dtype) * np.sqrt(variance)
|
||||
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
return random.uniform(key, shape, dtype, -1) * onp.sqrt(3 * variance)
|
||||
return random.uniform(key, shape, dtype, -1) * np.sqrt(3 * variance)
|
||||
else:
|
||||
raise ValueError("invalid distribution for variance scaling initializer")
|
||||
return init
|
||||
@ -75,7 +75,7 @@ lecun_normal = partial(variance_scaling, 1.0, "fan_in", "truncated_normal")
|
||||
kaiming_uniform = he_uniform = partial(variance_scaling, 2.0, "fan_in", "uniform")
|
||||
kaiming_normal = he_normal = partial(variance_scaling, 2.0, "fan_in", "truncated_normal")
|
||||
|
||||
def orthogonal(scale=1.0, column_axis=-1, dtype=np.float32):
|
||||
def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
|
||||
"""
|
||||
Construct an initializer for uniformly distributed orthogonal matrices.
|
||||
|
||||
@ -85,20 +85,20 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=np.float32):
|
||||
def init(key, shape, dtype=dtype):
|
||||
if len(shape) < 2:
|
||||
raise ValueError("orthogonal initializer requires at least a 2D shape")
|
||||
n_rows, n_cols = onp.prod(shape) // shape[column_axis], shape[column_axis]
|
||||
n_rows, n_cols = np.prod(shape) // shape[column_axis], shape[column_axis]
|
||||
matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
|
||||
A = random.normal(key, matrix_shape, dtype)
|
||||
Q, R = np.linalg.qr(A)
|
||||
diag_sign = lax.broadcast_to_rank(np.sign(np.diag(R)), rank=Q.ndim)
|
||||
Q, R = jnp.linalg.qr(A)
|
||||
diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
|
||||
Q *= diag_sign # needed for a uniform distribution
|
||||
if n_rows < n_cols: Q = Q.T
|
||||
Q = np.reshape(Q, tuple(onp.delete(shape, column_axis)) + (shape[column_axis],))
|
||||
Q = np.moveaxis(Q, -1, column_axis)
|
||||
Q = jnp.reshape(Q, tuple(np.delete(shape, column_axis)) + (shape[column_axis],))
|
||||
Q = jnp.moveaxis(Q, -1, column_axis)
|
||||
return scale * Q
|
||||
return init
|
||||
|
||||
|
||||
def delta_orthogonal(scale=1.0, column_axis=-1, dtype=np.float32):
|
||||
def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
|
||||
"""
|
||||
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
|
||||
|
||||
@ -112,7 +112,7 @@ def delta_orthogonal(scale=1.0, column_axis=-1, dtype=np.float32):
|
||||
raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
|
||||
ortho_init = orthogonal(scale=scale, column_axis=column_axis, dtype=dtype)
|
||||
ortho_matrix = ortho_init(key, shape[-2:])
|
||||
W = np.zeros(shape, dtype=dtype)
|
||||
W = jnp.zeros(shape, dtype=dtype)
|
||||
if len(shape) == 3:
|
||||
k = shape[0]
|
||||
return ops.index_update(W, ops.index[(k-1)//2, ...], ortho_matrix)
|
||||
|
@ -26,7 +26,7 @@ from .. import ops as jaxops
|
||||
|
||||
def _fft_core(func_name, fft_type, a, s, axes, norm):
|
||||
# TODO(skye): implement padding/cropping based on 's'.
|
||||
full_name = "jax.np.fft." + func_name
|
||||
full_name = "jax.numpy.fft." + func_name
|
||||
if s is not None:
|
||||
raise NotImplementedError("%s only supports s=None, got %s" % (full_name, s))
|
||||
if norm is not None:
|
||||
@ -93,7 +93,7 @@ def irfftn(a, s=None, axes=None, norm=None):
|
||||
|
||||
|
||||
def _fft_core_1d(func_name, fft_type, a, s, axis, norm):
|
||||
full_name = "jax.np.fft." + func_name
|
||||
full_name = "jax.numpy.fft." + func_name
|
||||
if isinstance(axis, (list, tuple)):
|
||||
raise ValueError(
|
||||
"%s does not support multiple axes. Please use %sn. "
|
||||
@ -125,7 +125,7 @@ def irfft(a, n=None, axis=-1, norm=None):
|
||||
|
||||
|
||||
def _fft_core_2d(func_name, fft_type, a, s, axes, norm):
|
||||
full_name = "jax.np.fft." + func_name
|
||||
full_name = "jax.numpy.fft." + func_name
|
||||
if len(axes) != 2:
|
||||
raise ValueError(
|
||||
"%s only supports 2 axes. Got axes = %r."
|
||||
@ -159,12 +159,12 @@ def irfft2(a, s=None, axes=(-2,-1), norm=None):
|
||||
def fftfreq(n, d=1.0):
|
||||
if isinstance(n, list) or isinstance(n, tuple):
|
||||
raise ValueError(
|
||||
"The n argument of jax.np.fft.fftfreq only takes an int. "
|
||||
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
|
||||
"Got n = %s." % list(n))
|
||||
|
||||
elif isinstance(d, list) or isinstance(d, tuple):
|
||||
raise ValueError(
|
||||
"The d argument of jax.np.fft.fftfreq only takes a single value. "
|
||||
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
|
||||
"Got d = %s." % list(d))
|
||||
|
||||
k = np.zeros(n)
|
||||
@ -191,12 +191,12 @@ def fftfreq(n, d=1.0):
|
||||
def rfftfreq(n, d=1.0):
|
||||
if isinstance(n, list) or isinstance(n, tuple):
|
||||
raise ValueError(
|
||||
"The n argument of jax.np.fft.rfftfreq only takes an int. "
|
||||
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
|
||||
"Got n = %s." % list(n))
|
||||
|
||||
elif isinstance(d, list) or isinstance(d, tuple):
|
||||
raise ValueError(
|
||||
"The d argument of jax.np.fft.rfftfreq only takes a single value. "
|
||||
"The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
|
||||
"Got d = %s." % list(d))
|
||||
|
||||
if n % 2 == 0:
|
||||
|
@ -28,10 +28,10 @@ See tensorflow/compiler/xla/service/hlo_runner.h.
|
||||
Usage:
|
||||
|
||||
$ cat prog.py
|
||||
import jax.numpy as np
|
||||
import jax.numpy as jnp
|
||||
|
||||
def fn(x, y, z):
|
||||
return np.dot(x, y) / z
|
||||
return jnp.dot(x, y) / z
|
||||
|
||||
$ python jax_to_hlo.py \
|
||||
--fn prog.fn \
|
||||
@ -48,7 +48,7 @@ The order of elements in input_shapes determines the order of parameters in the
|
||||
resulting HLO program.
|
||||
|
||||
Values of `constants` which are lists are converted to Numpy arrays using
|
||||
np.asarray. In addition, you can specify constants using the flag
|
||||
jnp.asarray. In addition, you can specify constants using the flag
|
||||
--evaled_constants; values there that are strings are first evaluated using
|
||||
ast.literal_eval. --evaled_constants is primarily useful for genrules; Skylark
|
||||
doesn't support floating-point types, so genrules need to deal in strings.
|
||||
@ -71,7 +71,7 @@ import functools
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import jax.api
|
||||
import jax.numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax.lib import xla_client
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
@ -119,7 +119,7 @@ def jax_to_hlo(fn, input_shapes, constants=None):
|
||||
raise ValueError('Shape %s has a non-default layout, but only '
|
||||
'the default layout is allowed.' % str(shape))
|
||||
|
||||
args.append(np.zeros(shape.dimensions(), dtype=shape.numpy_dtype()))
|
||||
args.append(jnp.zeros(shape.dimensions(), dtype=shape.numpy_dtype()))
|
||||
|
||||
# Curry `constants` into the function.
|
||||
fn_curried = functools.partial(fn, **constants)
|
||||
@ -153,14 +153,14 @@ def main(argv):
|
||||
constants = {}
|
||||
for k, v in literal_eval(FLAGS.constants).items():
|
||||
if isinstance(v, list):
|
||||
v = np.asarray(v)
|
||||
v = jnp.asarray(v)
|
||||
constants[k] = v
|
||||
|
||||
for k, v in literal_eval(FLAGS.evaled_constants).items():
|
||||
if isinstance(v, str):
|
||||
v = literal_eval(v)
|
||||
if isinstance(v, list):
|
||||
v = np.asarray(v)
|
||||
v = jnp.asarray(v)
|
||||
if k in constants:
|
||||
raise ValueError(
|
||||
'Argument appears in both --constants and --evaled_constants: %s' % k)
|
||||
|
@ -180,7 +180,7 @@ class APITest(jtu.JaxTestCase):
|
||||
jtu.check_raises(lambda: grad(f)(np.zeros(3)), Exception,
|
||||
"Tracer can't be used with raw numpy functions. "
|
||||
"You might have\n import numpy as np\ninstead of\n"
|
||||
" import jax.numpy as np")
|
||||
" import jax.numpy as jnp")
|
||||
|
||||
def test_binop_mismatch(self):
|
||||
def f(x, y):
|
||||
|
@ -18,12 +18,12 @@ import gc
|
||||
import itertools as it
|
||||
import operator
|
||||
|
||||
import numpy as onp
|
||||
import numpy as np
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import core
|
||||
from jax import numpy as np
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.api import jvp, linearize, vjp, jit
|
||||
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
|
||||
@ -35,24 +35,24 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
_ = pe.PartialVal.unknown(UnshapedArray(onp.float32))
|
||||
__ = pe.PartialVal.unknown(ShapedArray((), onp.float32))
|
||||
_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
|
||||
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))
|
||||
|
||||
def call(f, *args):
|
||||
return jit(f)(*args)
|
||||
|
||||
def simple_fun(x, y):
|
||||
return np.sin(x * y)
|
||||
return jnp.sin(x * y)
|
||||
|
||||
def simple_fun_fanout(x, y):
|
||||
return np.sin(x * y) * x
|
||||
return jnp.sin(x * y) * x
|
||||
|
||||
def fun_with_call(x):
|
||||
return call(np.sin, x)
|
||||
return call(jnp.sin, x)
|
||||
|
||||
def fun_with_nested_calls(x):
|
||||
def f(y):
|
||||
y2 = np.sin(y) + 1.0 + (2.0 * x)
|
||||
y2 = jnp.sin(y) + 1.0 + (2.0 * x)
|
||||
|
||||
@jit
|
||||
def g(z):
|
||||
@ -73,7 +73,7 @@ def fun_with_nested_calls_2(x):
|
||||
q = call(lambda x: y, x)
|
||||
q = q + call(lambda: y)
|
||||
q = q + call(lambda y: w + y, y)
|
||||
q = call(lambda w: call(np.sin, x) * y, 1.0) + q
|
||||
q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q
|
||||
return q
|
||||
p, t = jvp(baz, (x + 1.0,), (y,))
|
||||
return t + (x * p)
|
||||
@ -87,22 +87,22 @@ def fun_call_jitted(x):
|
||||
return call(g, x)
|
||||
|
||||
def fun_with_two_calls(x):
|
||||
return call(np.sin, x) + call(np.cos, x)
|
||||
return call(jnp.sin, x) + call(jnp.cos, x)
|
||||
|
||||
def fun_with_call_closure(x):
|
||||
def foo(y, z):
|
||||
return (x * x) * np.sin(y) * z
|
||||
return (x * x) * jnp.sin(y) * z
|
||||
|
||||
return call(foo, x, np.cos(x)) + x
|
||||
return call(foo, x, jnp.cos(x)) + x
|
||||
|
||||
def product_io_fun(x, y):
|
||||
xa = x['a']
|
||||
xb = x['b']
|
||||
y1, (y2, y3) = y
|
||||
return np.sin(xa + y2), [xb, (y1, y3)]
|
||||
return jnp.sin(xa + y2), [xb, (y1, y3)]
|
||||
|
||||
|
||||
R = onp.random.randn
|
||||
R = np.random.randn
|
||||
CallSpec = namedtuple('CallSpec', ['fun', 'args'])
|
||||
test_specs_base = [
|
||||
CallSpec(simple_fun, (R(3, 2), R(3, 2))),
|
||||
@ -173,12 +173,12 @@ class CoreTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters(test_specs)
|
||||
def test_jvp(self, f, args):
|
||||
jtu.check_jvp(f, partial(jvp, f), args, rtol={onp.float32: 3e-2})
|
||||
jtu.check_jvp(f, partial(jvp, f), args, rtol={np.float32: 3e-2})
|
||||
|
||||
def test_jvp_zeros(self):
|
||||
def foo(x):
|
||||
def bar(y):
|
||||
return np.sin(x * y)
|
||||
return jnp.sin(x * y)
|
||||
return jvp(bar, (3 * x,), (2 * x,))
|
||||
|
||||
jtu.check_eq(jit(foo)(0.5), foo(0.5))
|
||||
@ -186,18 +186,18 @@ class CoreTest(jtu.JaxTestCase):
|
||||
@parameterized.parameters(test_specs)
|
||||
def test_jvp_linearized(self, f, args):
|
||||
jtu.check_jvp(f, partial(jvp_unlinearized, f), args,
|
||||
rtol={onp.float32: 3e-2})
|
||||
rtol={np.float32: 3e-2})
|
||||
|
||||
@parameterized.parameters(test_specs)
|
||||
def test_vjp(self, f, args):
|
||||
jtu.check_vjp(f, partial(vjp, f), args,
|
||||
rtol={onp.float32: 3e-1, onp.float64: 1e-5},
|
||||
atol={onp.float32: 1e-2, onp.float64: 1e-5})
|
||||
rtol={np.float32: 3e-1, np.float64: 1e-5},
|
||||
atol={np.float32: 1e-2, np.float64: 1e-5})
|
||||
|
||||
def test_jvp_closure(self):
|
||||
def foo(x):
|
||||
def bar(y):
|
||||
return np.multiply(x, y)
|
||||
return jnp.multiply(x, y)
|
||||
return jvp(bar, (3.0,), (1.0,))[1]
|
||||
ans = jvp(foo, (1.0,), (2.0,))
|
||||
assert ans == (1.0, 2.0), ans
|
||||
@ -220,15 +220,15 @@ class CoreTest(jtu.JaxTestCase):
|
||||
foo2 = jit(foo)
|
||||
foo3 = jit(foo2)
|
||||
|
||||
x1, y1 = onp.array(1.0), onp.array(2.0)
|
||||
x1, y1 = np.array(1.0), np.array(2.0)
|
||||
assert foo(x1) == y1
|
||||
assert foo2(x1) == y1
|
||||
assert foo3(x1) == y1
|
||||
|
||||
x2, y2 = onp.array([1.0, 2.0]), onp.array([3.0, 4.0])
|
||||
assert onp.all(foo(x2) == y2)
|
||||
assert onp.all(foo2(x2) == y2)
|
||||
assert onp.all(foo3(x2) == y2)
|
||||
x2, y2 = np.array([1.0, 2.0]), np.array([3.0, 4.0])
|
||||
assert np.all(foo(x2) == y2)
|
||||
assert np.all(foo2(x2) == y2)
|
||||
assert np.all(foo3(x2) == y2)
|
||||
|
||||
def test_product_jit(self):
|
||||
def foo(x, tup):
|
||||
@ -247,7 +247,7 @@ class CoreTest(jtu.JaxTestCase):
|
||||
assert foo3(*args) == foo(*args)
|
||||
|
||||
def test_jvp_2(self):
|
||||
d_sin = fwd_deriv(np.sin)
|
||||
d_sin = fwd_deriv(jnp.sin)
|
||||
d2_sin = fwd_deriv(d_sin)
|
||||
d3_sin = fwd_deriv(d2_sin)
|
||||
|
||||
@ -262,7 +262,7 @@ class CoreTest(jtu.JaxTestCase):
|
||||
return x.sum()
|
||||
|
||||
fn = partial(linearize, f)
|
||||
params = np.zeros([])
|
||||
params = jnp.zeros([])
|
||||
|
||||
debug = gc.get_debug()
|
||||
try:
|
||||
|
@ -17,11 +17,9 @@
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as onp
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax import numpy as np
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -36,18 +34,18 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
config.update("jax_debug_nans", self.cfg)
|
||||
|
||||
def testSingleResultPrimitiveNoNaN(self):
|
||||
A = np.array([[1., 2.], [2., 3.]])
|
||||
B = np.tanh(A)
|
||||
A = jnp.array([[1., 2.], [2., 3.]])
|
||||
B = jnp.tanh(A)
|
||||
|
||||
def testMultipleResultPrimitiveNoNaN(self):
|
||||
A = np.array([[1., 2.], [2., 3.]])
|
||||
D, V = np.linalg.eig(A)
|
||||
A = jnp.array([[1., 2.], [2., 3.]])
|
||||
D, V = jnp.linalg.eig(A)
|
||||
|
||||
def testJitComputationNoNaN(self):
|
||||
A = np.array([[1., 2.], [2., 3.]])
|
||||
B = jax.jit(np.tanh)(A)
|
||||
A = jnp.array([[1., 2.], [2., 3.]])
|
||||
B = jax.jit(jnp.tanh)(A)
|
||||
|
||||
def testSingleResultPrimitiveNaN(self):
|
||||
A = np.array(0.)
|
||||
A = jnp.array(0.)
|
||||
with self.assertRaises(FloatingPointError):
|
||||
B = 0. / A
|
||||
|
@ -21,12 +21,12 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as onp
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import numpy as np
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.interpreters import xla
|
||||
|
||||
@ -34,41 +34,41 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
bool_dtypes = [onp.dtype('bool')]
|
||||
bool_dtypes = [np.dtype('bool')]
|
||||
|
||||
signed_dtypes = [onp.dtype('int8'), onp.dtype('int16'), onp.dtype('int32'),
|
||||
onp.dtype('int64')]
|
||||
signed_dtypes = [np.dtype('int8'), np.dtype('int16'), np.dtype('int32'),
|
||||
np.dtype('int64')]
|
||||
|
||||
unsigned_dtypes = [onp.dtype('uint8'), onp.dtype('uint16'), onp.dtype('uint32'),
|
||||
onp.dtype('uint64')]
|
||||
unsigned_dtypes = [np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
|
||||
np.dtype('uint64')]
|
||||
|
||||
onp_float_dtypes = [onp.dtype('float16'), onp.dtype('float32'),
|
||||
onp.dtype('float64')]
|
||||
np_float_dtypes = [np.dtype('float16'), np.dtype('float32'),
np.dtype('float64')]

float_dtypes = [onp.dtype(dtypes.bfloat16)] + onp_float_dtypes
float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes

complex_dtypes = [onp.dtype('complex64'), onp.dtype('complex128')]
complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')]


all_dtypes = (bool_dtypes + signed_dtypes + unsigned_dtypes + float_dtypes +
complex_dtypes)

scalar_types = [np.bool_, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.bfloat16, np.float16, np.float32, np.float64,
np.complex64, np.complex128]
scalar_types = [jnp.bool_, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
jnp.complex64, jnp.complex128]

class DtypesTest(jtu.JaxTestCase):

@parameterized.named_parameters(
{"testcase_name": "_type={}".format(type.__name__), "type": type,
"dtype": dtype}
for type, dtype in [(bool, np.bool_), (int, np.int_), (float, np.float_),
(complex, np.complex_)])
for type, dtype in [(bool, jnp.bool_), (int, jnp.int_), (float, jnp.float_),
(complex, jnp.complex_)])
def testDefaultTypes(self, type, dtype):
for f in [np.array, jax.jit(np.array), jax.jit(lambda x: x)]:
for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]:
y = f(type(0))
self.assertTrue(isinstance(y, np.ndarray), msg=(f, y))
self.assertTrue(isinstance(y, jnp.ndarray), msg=(f, y))
self.assertEqual(y.dtype, dtypes.canonicalize_dtype(dtype), msg=(f, y))

@parameterized.named_parameters(
@ -78,51 +78,51 @@ class DtypesTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # F16 not supported on TPU
def testBinaryPromotion(self, swap, jit):
testcases = [
(np.array(1.), 0., np.float_),
(np.array(1.), np.array(0.), np.float_),
(np.array(1.), np.array(0., dtype=np.float16), np.float_),
(np.array(1.), np.array(0., dtype=np.float32), np.float_),
(np.array(1.), np.array(0., dtype=np.float64), np.float64),
(np.array(1., dtype=np.float16), 0., np.float16),
(np.array(1., dtype=np.float32), 0., np.float32),
(np.array(1., dtype=np.float64), 0., np.float64),
(np.array(1., dtype=np.float16), np.array(0., dtype=np.float16), np.float16),
(np.array(1., dtype=np.float16), np.array(0., dtype=np.float32), np.float32),
(np.array(1., dtype=np.float16), np.array(0., dtype=np.float64), np.float64),
(np.array(1., dtype=np.float32), np.array(0., dtype=np.float32), np.float32),
(np.array(1., dtype=np.float32), np.array(0., dtype=np.float64), np.float64),
(np.array(1., dtype=np.float64), np.array(0., dtype=np.float64), np.float64),
(np.array([1.]), 0., np.float_),
(np.array([1.]), np.array(0.), np.float_),
(np.array([1.]), np.array(0., dtype=np.float16), np.float_),
(np.array([1.]), np.array(0., dtype=np.float32), np.float_),
(np.array([1.]), np.array(0., dtype=np.float64), np.float64),
(np.array([1.], dtype=np.float32), np.array(0., dtype=np.float16), np.float32),
(np.array([1.], dtype=np.float16), np.array(0., dtype=np.float32), np.float32),
(np.array([1.], dtype=np.float16), 0., np.float16),
(jnp.array(1.), 0., jnp.float_),
(jnp.array(1.), jnp.array(0.), jnp.float_),
(jnp.array(1.), jnp.array(0., dtype=jnp.float16), jnp.float_),
(jnp.array(1.), jnp.array(0., dtype=jnp.float32), jnp.float_),
(jnp.array(1.), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array(1., dtype=jnp.float16), 0., jnp.float16),
(jnp.array(1., dtype=jnp.float32), 0., jnp.float32),
(jnp.array(1., dtype=jnp.float64), 0., jnp.float64),
(jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float16), jnp.float16),
(jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
(jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float32), jnp.float32),
(jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array(1., dtype=jnp.float64), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array([1.]), 0., jnp.float_),
(jnp.array([1.]), jnp.array(0.), jnp.float_),
(jnp.array([1.]), jnp.array(0., dtype=jnp.float16), jnp.float_),
(jnp.array([1.]), jnp.array(0., dtype=jnp.float32), jnp.float_),
(jnp.array([1.]), jnp.array(0., dtype=jnp.float64), jnp.float64),
(jnp.array([1.], dtype=jnp.float32), jnp.array(0., dtype=jnp.float16), jnp.float32),
(jnp.array([1.], dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
(jnp.array([1.], dtype=jnp.float16), 0., jnp.float16),
]
op = jax.jit(operator.add) if jit else operator.add
for x, y, dtype in testcases:
x, y = (y, x) if swap else (x, y)
z = x + y
self.assertTrue(isinstance(z, np.ndarray), msg=(x, y, z))
self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))

def testPromoteDtypes(self):
for t1 in all_dtypes:
self.assertEqual(t1, dtypes.promote_types(t1, t1))

self.assertEqual(t1, dtypes.promote_types(t1, onp.bool_))
self.assertEqual(onp.dtype(onp.complex128),
dtypes.promote_types(t1, onp.complex128))
self.assertEqual(t1, dtypes.promote_types(t1, np.bool_))
self.assertEqual(np.dtype(np.complex128),
dtypes.promote_types(t1, np.complex128))

for t2 in all_dtypes:
# Symmetry
self.assertEqual(dtypes.promote_types(t1, t2),
dtypes.promote_types(t2, t1))

self.assertEqual(onp.dtype(onp.float32),
dtypes.promote_types(onp.float16, dtypes.bfloat16))
self.assertEqual(np.dtype(np.float32),
dtypes.promote_types(np.float16, dtypes.bfloat16))

# Promotions of non-inexact types against inexact types always prefer
# the inexact types.
@ -132,47 +132,47 @@ class DtypesTest(jtu.JaxTestCase):

# Promotions between exact types, or between inexact types, match NumPy.
for groups in [bool_dtypes + signed_dtypes + unsigned_dtypes,
onp_float_dtypes + complex_dtypes]:
np_float_dtypes + complex_dtypes]:
for t1, t2 in itertools.combinations(groups, 2):
self.assertEqual(onp.promote_types(t1, t2),
self.assertEqual(np.promote_types(t1, t2),
dtypes.promote_types(t1, t2))

def testScalarInstantiation(self):
for t in [np.bool_, np.int32, np.bfloat16, np.float32, np.complex64]:
for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]:
a = t(1)
self.assertEqual(a.dtype, np.dtype(t))
self.assertEqual(a.dtype, jnp.dtype(t))
self.assertIsInstance(a, xla.DeviceArray)
self.assertEqual(0, np.ndim(a))
self.assertEqual(0, jnp.ndim(a))

def testIsSubdtype(self):
for t in scalar_types:
self.assertTrue(dtypes.issubdtype(t, t))
self.assertTrue(dtypes.issubdtype(onp.dtype(t).type, t))
self.assertTrue(dtypes.issubdtype(t, onp.dtype(t).type))
if t != np.bfloat16:
for category in [onp.generic, np.inexact, np.integer, np.signedinteger,
np.unsignedinteger, np.floating, np.complexfloating]:
self.assertTrue(dtypes.issubdtype(np.dtype(t).type, t))
self.assertTrue(dtypes.issubdtype(t, np.dtype(t).type))
if t != jnp.bfloat16:
for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger,
jnp.unsignedinteger, jnp.floating, jnp.complexfloating]:
self.assertEqual(dtypes.issubdtype(t, category),
onp.issubdtype(onp.dtype(t).type, category))
np.issubdtype(np.dtype(t).type, category))
self.assertEqual(dtypes.issubdtype(t, category),
onp.issubdtype(onp.dtype(t).type, category))
np.issubdtype(np.dtype(t).type, category))

def testArrayCasts(self):
for t in [np.bool_, np.int32, np.bfloat16, np.float32, np.complex64]:
a = onp.array([1, 2.5, -3.7])
self.assertEqual(a.astype(t).dtype, np.dtype(t))
self.assertEqual(np.array(a).astype(t).dtype, np.dtype(t))
for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]:
a = np.array([1, 2.5, -3.7])
self.assertEqual(a.astype(t).dtype, jnp.dtype(t))
self.assertEqual(jnp.array(a).astype(t).dtype, jnp.dtype(t))

def testEnumPromotion(self):
class AnEnum(enum.IntEnum):
A = 42
B = 101
onp.testing.assert_equal(onp.array(42), onp.array(AnEnum.A))
np.testing.assert_equal(np.array(42), np.array(AnEnum.A))
with core.skipping_checks():
# Passing AnEnum.A to np.array fails the type check in bind
onp.testing.assert_equal(np.array(42), np.array(AnEnum.A))
onp.testing.assert_equal(onp.int32(101), onp.int32(AnEnum.B))
onp.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
# Passing AnEnum.A to jnp.array fails the type check in bind
np.testing.assert_equal(jnp.array(42), jnp.array(AnEnum.A))
np.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
np.testing.assert_equal(jnp.int32(101), jnp.int32(AnEnum.B))

if __name__ == "__main__":
absltest.main()
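The dtype tests above all rely on the convention this commit establishes: plain NumPy is imported as np and jax.numpy as jnp. A minimal sketch, not part of the commit, of the promotion behaviour they exercise (assuming a default configuration with jax_enable_x64 off, so JAX canonicalizes 64-bit floats down to 32-bit; the printed dtypes are illustrative):

import numpy as np
import jax.numpy as jnp
from jax import dtypes

# Default dtypes are canonicalized by JAX: with x64 disabled, float -> float32.
print(np.array(1.0).dtype)    # float64 (host NumPy)
print(jnp.array(1.0).dtype)   # float32 (assumption: jax_enable_x64 is off)

# dtypes.promote_types follows NumPy for standard dtypes...
print(dtypes.promote_types(np.float16, np.float32))       # float32
# ...and also knows about bfloat16, which the tests above check explicitly.
print(dtypes.promote_types(np.float16, dtypes.bfloat16))  # float32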
@ -16,26 +16,26 @@
import itertools
import unittest

import numpy as onp
import numpy as np

from absl.testing import absltest
from absl.testing import parameterized

from jax import lax
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu

from jax.config import config
config.parse_flags_with_absl()


float_dtypes = [onp.float32, onp.float64]
# TODO(b/144573940): onp.complex128 isn't supported by XLA, and the JAX
float_dtypes = [np.float32, np.float64]
# TODO(b/144573940): np.complex128 isn't supported by XLA, and the JAX
# implementation casts to complex64.
complex_dtypes = [onp.complex64]
complex_dtypes = [np.complex64]
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = [onp.int32, onp.int64]
bool_dtypes = [onp.bool_]
int_dtypes = [np.int32, np.int64]
bool_dtypes = [np.bool_]
real_dtypes = float_dtypes + int_dtypes + bool_dtypes
all_dtypes = real_dtypes + complex_dtypes

@ -84,7 +84,7 @@ def _zero_for_irfft(z, axes):
else:
parts = [lax.slice_in_dim(z.real, 0, 1, axis=axis).real,
lax.slice_in_dim(z.real, 1, size, axis=axis)]
return np.concatenate(parts, axis=axis)
return jnp.concatenate(parts, axis=axis)


class FftTest(jtu.JaxTestCase):
@ -103,19 +103,19 @@ class FftTest(jtu.JaxTestCase):
def testFftn(self, inverse, real, shape, dtype, axes, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_op = _get_fftn_func(jnp.fft, inverse, real)
np_op = _get_fftn_func(np.fft, inverse, real)
onp_op = _get_fftn_func(onp.fft, inverse, real)
np_fn = lambda a: np_op(a, axes=axes)
onp_fn = lambda a: onp_op(a, axes=axes) if axes is None or axes else a
jnp_fn = lambda a: jnp_op(a, axes=axes)
np_fn = lambda a: np_op(a, axes=axes) if axes is None or axes else a
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
# Test gradient for differentiable types.
if dtype in (float_dtypes if real and not inverse else inexact_dtypes):
# TODO(skye): can we be more precise?
tol = 0.15
jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol)
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
@ -129,20 +129,20 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
func = _get_fftn_func(np.fft, inverse, real)
func = _get_fftn_func(jnp.fft, inverse, real)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} only supports 1D, 2D, and 3D FFTs. "
"jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
"Got axes None with input rank 4.".format(name),
lambda: func(rng([2, 3, 4, 5], dtype=onp.float64), axes=None))
lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} does not support repeated axes. Got axes \\[1, 1\\].".format(name),
lambda: func(rng([2, 3], dtype=onp.float64), axes=[1, 1]))
"jax.numpy.fft.{} does not support repeated axes. Got axes \\[1, 1\\].".format(name),
lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1]))
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[2]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2]))
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[-3]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3]))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_shape={}_axis={}".format(
@ -163,14 +163,14 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
jnp_op = getattr(jnp.fft, name)
np_op = getattr(np.fft, name)
onp_op = getattr(onp.fft, name)
jnp_fn = lambda a: jnp_op(a, axis=axis)
np_fn = lambda a: np_op(a, axis=axis)
onp_fn = lambda a: onp_op(a, axis=axis)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_op, np_op, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
@ -184,26 +184,26 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
func = getattr(np.fft, name)
func = getattr(jnp.fft, name)

self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} does not support multiple axes. "
"Please use jax.np.fft.{}n. "
"jax.numpy.fft.{} does not support multiple axes. "
"Please use jax.numpy.fft.{}n. "
"Got axis = \\[1, 1\\].".format(name, name),
lambda: func(rng([2, 3], dtype=onp.float64), axis=[1, 1])
lambda: func(rng([2, 3], dtype=np.float64), axis=[1, 1])
)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} does not support multiple axes. "
"Please use jax.np.fft.{}n. "
"jax.numpy.fft.{} does not support multiple axes. "
"Please use jax.numpy.fft.{}n. "
"Got axis = \\(1, 1\\).".format(name, name),
lambda: func(rng([2, 3], dtype=onp.float64), axis=(1, 1))
lambda: func(rng([2, 3], dtype=np.float64), axis=(1, 1))
)
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axis=[2]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[2]))
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axis=[-3]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[-3]))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format(
@ -224,12 +224,12 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
jnp_op = getattr(jnp.fft, name)
np_op = getattr(np.fft, name)
onp_op = getattr(onp.fft, name)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_op, np_op, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_op, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
@ -243,24 +243,24 @@ class FftTest(jtu.JaxTestCase):
name = 'r' + name
if inverse:
name = 'i' + name
func = getattr(np.fft, name)
func = getattr(jnp.fft, name)

self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} only supports 2 axes. "
"jax.numpy.fft.{} only supports 2 axes. "
"Got axes = \\[0\\].".format(name),
lambda: func(rng([2, 3], dtype=onp.float64), axes=[0])
lambda: func(rng([2, 3], dtype=np.float64), axes=[0])
)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} only supports 2 axes. "
"jax.numpy.fft.{} only supports 2 axes. "
"Got axes = \\(0, 1, 2\\).".format(name),
lambda: func(rng([2, 3, 3], dtype=onp.float64), axes=(0, 1, 2))
lambda: func(rng([2, 3, 3], dtype=np.float64), axes=(0, 1, 2))
)
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[2, 3]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2, 3]))
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[-3, -4]))
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4]))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_size={}_d={}".format(
@ -273,18 +273,18 @@ class FftTest(jtu.JaxTestCase):
def testFftfreq(self, size, d, dtype, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: (rng([size], dtype),)
jnp_op = jnp.fft.fftfreq
np_op = np.fft.fftfreq
onp_op = onp.fft.fftfreq
jnp_fn = lambda a: jnp_op(size, d=d)
np_fn = lambda a: np_op(size, d=d)
onp_fn = lambda a: onp_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
# Test gradient for differentiable types.
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol)
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}".format(n),
@ -292,16 +292,16 @@ class FftTest(jtu.JaxTestCase):
for n in [[0,1,2]]))
def testFftfreqErrors(self, n):
name = 'fftfreq'
func = np.fft.fftfreq
func = jnp.fft.fftfreq
self.assertRaisesRegex(
ValueError,
"The n argument of jax.np.fft.{} only takes an int. "
"The n argument of jax.numpy.fft.{} only takes an int. "
"Got n = \\[0, 1, 2\\].".format(name),
lambda: func(n=n)
)
self.assertRaisesRegex(
ValueError,
"The d argument of jax.np.fft.{} only takes a single value. "
"The d argument of jax.numpy.fft.{} only takes a single value. "
"Got d = \\[0, 1, 2\\].".format(name),
lambda: func(n=10, d=n)
)
@ -317,18 +317,18 @@ class FftTest(jtu.JaxTestCase):
def testRfftfreq(self, size, d, dtype, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: (rng([size], dtype),)
jnp_op = jnp.fft.rfftfreq
np_op = np.fft.rfftfreq
onp_op = onp.fft.rfftfreq
jnp_fn = lambda a: jnp_op(size, d=d)
np_fn = lambda a: np_op(size, d=d)
onp_fn = lambda a: onp_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=True)
# Test gradient for differentiable types.
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol)
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}".format(n),
@ -336,16 +336,16 @@ class FftTest(jtu.JaxTestCase):
for n in [[0, 1, 2]]))
def testRfftfreqErrors(self, n):
name = 'rfftfreq'
func = np.fft.rfftfreq
func = jnp.fft.rfftfreq
self.assertRaisesRegex(
ValueError,
"The n argument of jax.np.fft.{} only takes an int. "
"The n argument of jax.numpy.fft.{} only takes an int. "
"Got n = \\[0, 1, 2\\].".format(name),
lambda: func(n=n)
)
self.assertRaisesRegex(
ValueError,
"The d argument of jax.np.fft.{} only takes a single value. "
"The d argument of jax.numpy.fft.{} only takes a single value. "
"Got d = \\[0, 1, 2\\].".format(name),
lambda: func(n=10, d=n)
)
@ -361,9 +361,9 @@ class FftTest(jtu.JaxTestCase):
def testFftshift(self, shape, dtype, rng_factory, axes):
rng = rng_factory(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes)
np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
onp_fn = lambda arg: onp.fft.fftshift(arg, axes=axes)
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "dtype={}_axes={}".format(
@ -376,9 +376,9 @@ class FftTest(jtu.JaxTestCase):
def testIfftshift(self, shape, dtype, rng_factory, axes):
rng = rng_factory(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
onp_fn = lambda arg: onp.fft.ifftshift(arg, axes=axes)
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)

if __name__ == "__main__":
absltest.main()
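The FFT tests repeatedly pass check_dtypes=False because, as the comments note, NumPy promotes FFT results to complex128 aggressively while jax.numpy keeps single precision. A rough standalone sketch of that comparison, not part of the commit (shapes and tolerances are illustrative, and it assumes the default 32-bit JAX configuration):

import numpy as np
import jax.numpy as jnp

x = np.random.RandomState(0).randn(8, 8).astype(np.float32)

ref = np.fft.fftn(x, axes=(0, 1))    # NumPy returns complex128 here
out = jnp.fft.fftn(x, axes=(0, 1))   # JAX returns complex64 by default

# Values agree to single precision even though the dtypes differ,
# which is exactly why the tests above disable dtype checking.
np.testing.assert_allclose(np.asarray(out), ref, rtol=1e-4, atol=1e-4)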
@ -16,14 +16,12 @@ from functools import partial
import itertools
import unittest

import numpy as onp

from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import lax
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu

from jax.config import config
@ -41,9 +39,9 @@ class VectorizeTest(jtu.JaxTestCase):
((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
]))
def test_matmat(self, left_shape, right_shape, result_shape):
matmat = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,k)')
self.assertEqual(matmat(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
self.assertEqual(matmat(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
@ -55,9 +53,9 @@ class VectorizeTest(jtu.JaxTestCase):
((5, 4, 2, 3), (1, 3), (5, 4, 2)),
]))
def test_matvec(self, left_shape, right_shape, result_shape):
matvec = np.vectorize(np.dot, signature='(n,m),(m)->(n)')
self.assertEqual(matvec(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
matvec = jnp.vectorize(jnp.dot, signature='(n,m),(m)->(n)')
self.assertEqual(matvec(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
@ -68,9 +66,9 @@ class VectorizeTest(jtu.JaxTestCase):
((4, 2, 3), (3,), (4, 2)),
]))
def test_vecmat(self, left_shape, right_shape, result_shape):
vecvec = np.vectorize(np.dot, signature='(m),(m)->()')
self.assertEqual(vecvec(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
vecvec = jnp.vectorize(jnp.dot, signature='(m),(m)->()')
self.assertEqual(vecvec(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -84,11 +82,11 @@ class VectorizeTest(jtu.JaxTestCase):
size = 1
for x in shape:
size *= x
inputs = np.arange(size).reshape(shape)
inputs = jnp.arange(size).reshape(shape)

@partial(np.vectorize, signature='(n)->()')
@partial(jnp.vectorize, signature='(n)->()')
def magnitude(x):
return np.dot(x, x)
return jnp.dot(x, x)

self.assertEqual(magnitude(inputs).shape, result_shape)

@ -101,8 +99,8 @@ class VectorizeTest(jtu.JaxTestCase):
((1, 2, 3, 4), (1, 2, 3)),
]))
def test_mean(self, shape, result_shape):
mean = np.vectorize(np.mean, signature='(n)->()')
self.assertEqual(mean(np.zeros(shape)).shape, result_shape)
mean = jnp.vectorize(jnp.mean, signature='(n)->()')
self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -113,104 +111,104 @@ class VectorizeTest(jtu.JaxTestCase):
]))
def test_stack_plus_minus(self, shape, result_shape):

@partial(np.vectorize, signature='()->(n)')
@partial(jnp.vectorize, signature='()->(n)')
def stack_plus_minus(x):
return np.stack([x, -x])
return jnp.stack([x, -x])

self.assertEqual(stack_plus_minus(np.zeros(shape)).shape, result_shape)
self.assertEqual(stack_plus_minus(jnp.zeros(shape)).shape, result_shape)

def test_center(self):

@partial(np.vectorize, signature='(n)->(),(n)')
@partial(jnp.vectorize, signature='(n)->(),(n)')
def center(array):
bias = np.mean(array)
bias = jnp.mean(array)
debiased = array - bias
return bias, debiased

b, a = center(np.arange(3))
b, a = center(jnp.arange(3))
self.assertEqual(a.shape, (3,))
self.assertEqual(b.shape, ())
self.assertAllClose(1.0, b, check_dtypes=False)

b, a = center(np.arange(6).reshape(2, 3))
b, a = center(jnp.arange(6).reshape(2, 3))
self.assertEqual(a.shape, (2, 3))
self.assertEqual(b.shape, (2,))
self.assertAllClose(np.array([1.0, 4.0]), b, check_dtypes=False)
self.assertAllClose(jnp.array([1.0, 4.0]), b, check_dtypes=False)

def test_exclude_first(self):

@partial(np.vectorize, excluded={0})
@partial(jnp.vectorize, excluded={0})
def f(x, y):
assert x == 'foo'
assert y.ndim == 0
return y

x = np.arange(3)
x = jnp.arange(3)
self.assertAllClose(x, f('foo', x), check_dtypes=True)
self.assertAllClose(x, jax.jit(f, 0)('foo', x), check_dtypes=True)

def test_exclude_second(self):

@partial(np.vectorize, excluded={1})
@partial(jnp.vectorize, excluded={1})
def f(x, y):
assert x.ndim == 0
assert y == 'foo'
return x

x = np.arange(3)
x = jnp.arange(3)
self.assertAllClose(x, f(x, 'foo'), check_dtypes=True)
self.assertAllClose(x, jax.jit(f, 1)(x, 'foo'), check_dtypes=True)

def test_exclude_errors(self):
with self.assertRaisesRegex(
TypeError, "jax.numpy.vectorize can only exclude"):
np.vectorize(lambda x: x, excluded={'foo'})
jnp.vectorize(lambda x: x, excluded={'foo'})

with self.assertRaisesRegex(
ValueError, r"excluded=\{-1\} contains negative numbers"):
np.vectorize(lambda x: x, excluded={-1})
jnp.vectorize(lambda x: x, excluded={-1})

f = np.vectorize(lambda x: x, excluded={1})
f = jnp.vectorize(lambda x: x, excluded={1})
with self.assertRaisesRegex(
ValueError, r"excluded=\{1\} is invalid for 1 argument\(s\)"):
f(1.0)

def test_bad_inputs(self):
matmat = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,k)')
matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
with self.assertRaisesRegex(
TypeError, "wrong number of positional arguments"):
matmat(np.zeros((3, 2)))
matmat(jnp.zeros((3, 2)))
with self.assertRaisesRegex(
ValueError,
r"input with shape \(2,\) does not have enough dimensions"):
matmat(np.zeros((2,)), np.zeros((2, 2)))
matmat(jnp.zeros((2,)), jnp.zeros((2, 2)))
with self.assertRaisesRegex(
ValueError, r"inconsistent size for core dimension 'm'"):
matmat(np.zeros((2, 3)), np.zeros((4, 5)))
matmat(jnp.zeros((2, 3)), jnp.zeros((4, 5)))

def test_wrong_output_type(self):
f = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,k),()')
f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k),()')
with self.assertRaisesRegex(
TypeError, "output must be a tuple"):
f(np.zeros((2, 2)), np.zeros((2, 2)))
f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))

def test_wrong_num_outputs(self):
f = np.vectorize(lambda *args: args, signature='(),()->(),(),()')
f = jnp.vectorize(lambda *args: args, signature='(),()->(),(),()')
with self.assertRaisesRegex(
TypeError, "wrong number of output arguments"):
f(1, 2)

def test_wrong_output_shape(self):
f = np.vectorize(np.dot, signature='(n,m),(m,k)->(n)')
f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n)')
with self.assertRaisesRegex(
ValueError, r"output shape \(2, 2\) does not match"):
f(np.zeros((2, 2)), np.zeros((2, 2)))
f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))

def test_inconsistent_output_size(self):
f = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,n)')
f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,n)')
with self.assertRaisesRegex(
ValueError, r"inconsistent size for core dimension 'n'"):
f(np.zeros((2, 3)), np.zeros((3, 4)))
f(jnp.zeros((2, 3)), jnp.zeros((3, 4)))


if __name__ == "__main__":
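These tests exercise jnp.vectorize, which broadcasts a function with a core-dimension signature over any leading batch dimensions. A small usage sketch, not part of the commit, mirroring test_matvec above:

import jax.numpy as jnp

# Matrix-vector product vectorized over leading batch dimensions.
matvec = jnp.vectorize(jnp.dot, signature='(n,m),(m)->(n)')

a = jnp.ones((7, 4, 2, 3))   # batch shape (7, 4), core shape (2, 3)
b = jnp.ones((3,))           # core shape (3,), broadcast against the batch
print(matvec(a, b).shape)    # (7, 4, 2)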
@ -19,7 +19,7 @@ import itertools
import unittest
import sys

import numpy as onp
import numpy as np
import scipy as osp

from absl.testing import absltest
@ -30,7 +30,7 @@ import jax.lib
from jax import jit, grad, jvp, vmap
from jax import lax
from jax import lax_linalg
from jax import numpy as np
from jax import numpy as jnp
from jax import scipy as jsp
from jax import test_util as jtu
from jax.lib import xla_bridge
@ -40,16 +40,16 @@ from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

T = lambda x: onp.swapaxes(x, -1, -2)
T = lambda x: np.swapaxes(x, -1, -2)


float_types = [onp.float32, onp.float64]
complex_types = [onp.complex64, onp.complex128]
float_types = [np.float32, np.float64]
complex_types = [np.complex64, np.complex128]

def _skip_if_unsupported_type(dtype):
dtype = onp.dtype(dtype)
dtype = np.dtype(dtype)
if (not FLAGS.jax_enable_x64 and
dtype in (onp.dtype('float64'), onp.dtype('complex128'))):
dtype in (np.dtype('float64'), np.dtype('complex128'))):
raise unittest.SkipTest("--jax_enable_x64 is not set")


@ -69,25 +69,25 @@ class NumpyLinalgTest(jtu.JaxTestCase):
def args_maker():
factor_shape = shape[:-1] + (2 * shape[-1],)
a = rng(factor_shape, dtype)
return [onp.matmul(a, np.conj(T(a)))]
return [np.matmul(a, jnp.conj(T(a)))]

if (np.issubdtype(dtype, np.complexfloating) and
if (jnp.issubdtype(dtype, jnp.complexfloating) and
jtu.device_under_test() == "tpu"):
self.skipTest("Unimplemented case for complex Cholesky decomposition.")

self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker,
self._CheckAgainstNumpy(np.linalg.cholesky, jnp.linalg.cholesky, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.cholesky, args_maker, check_dtypes=True)

if np.finfo(dtype).bits == 64:
jtu.check_grads(np.linalg.cholesky, args_maker(), order=2)
if jnp.finfo(dtype).bits == 64:
jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)

def testCholeskyGradPrecision(self):
rng = jtu.rand_default(self.rng())
a = rng((3, 3), onp.float32)
a = onp.dot(a, a.T)
a = rng((3, 3), np.float32)
a = np.dot(a, a.T)
jtu.assert_dot_precision(
lax.Precision.HIGHEST, partial(jvp, np.linalg.cholesky), (a,), (a,))
lax.Precision.HIGHEST, partial(jvp, jnp.linalg.cholesky), (a,), (a,))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
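The Cholesky test above builds a Hermitian positive-definite input as a @ conj(a).T inside args_maker before comparing jnp.linalg.cholesky against np.linalg.cholesky. A standalone sketch of that recipe, not part of the commit (the shapes, the small diagonal shift, and the tolerance are illustrative only):

import numpy as np
import jax.numpy as jnp

rng = np.random.RandomState(0)
factor = rng.randn(4, 8).astype(np.float32)
spd = factor @ factor.T + 1e-3 * np.eye(4, dtype=np.float32)  # positive definite

l_ref = np.linalg.cholesky(spd)
l_jax = jnp.linalg.cholesky(spd)
np.testing.assert_allclose(np.asarray(l_jax), l_ref, rtol=1e-3, atol=1e-3)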
@ -101,14 +101,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype)]

self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker,
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True,
rtol={onp.float64: 1e-13, onp.complex128: 1e-13})
self._CompileAndCheck(jnp.linalg.det, args_maker, check_dtypes=True,
rtol={np.float64: 1e-13, np.complex128: 1e-13})

def testDetOfSingularMatrix(self):
x = np.array([[-1., 3./2], [2./3, -1.]], dtype=onp.float32)
self.assertAllClose(onp.float32(0), jsp.linalg.det(x), check_dtypes=True)
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
self.assertAllClose(np.float32(0), jsp.linalg.det(x), check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -123,30 +123,30 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
a = rng(shape, dtype)
jtu.check_grads(np.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
jtu.check_grads(jnp.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
# make sure there are no NaNs when a matrix is zero
if len(shape) == 2:
pass
jtu.check_grads(
np.linalg.det, (np.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
jnp.linalg.det, (jnp.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
else:
a[0] = 0
jtu.check_grads(np.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

def testDetGradOfSingularMatrixCorank1(self):
# Rank 2 matrix with nonzero gradient
a = np.array([[ 50, -30, 45],
a = jnp.array([[ 50, -30, 45],
[-30, 90, -81],
[ 45, -81, 81]], dtype=np.float32)
jtu.check_grads(np.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
[ 45, -81, 81]], dtype=jnp.float32)
jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

@jtu.skip_on_devices("tpu") # TODO(mattjj,pfau): nan on tpu, investigate
def testDetGradOfSingularMatrixCorank2(self):
# Rank 1 matrix with zero gradient
b = np.array([[ 36, -42, 18],
b = jnp.array([[ 36, -42, 18],
[-42, 49, -21],
[ 18, -21, 9]], dtype=np.float32)
jtu.check_grads(np.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1)
[ 18, -21, 9]], dtype=jnp.float32)
jtu.check_grads(jnp.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -178,16 +178,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng(b_shape + Q, dtype), # = a
rng(b_shape, dtype)] # = b
a, b = args_maker()
result = np.linalg.tensorsolve(*args_maker())
result = jnp.linalg.tensorsolve(*args_maker())
self.assertEqual(result.shape, Q)

self._CheckAgainstNumpy(onp.linalg.tensorsolve,
np.linalg.tensorsolve, args_maker,
self._CheckAgainstNumpy(np.linalg.tensorsolve,
jnp.linalg.tensorsolve, args_maker,
check_dtypes=True,
tol={onp.float32: 1e-2, onp.float64: 1e-3})
self._CompileAndCheck(np.linalg.tensorsolve,
tol={np.float32: 1e-2, np.float64: 1e-3})
self._CompileAndCheck(jnp.linalg.tensorsolve,
args_maker, check_dtypes=True,
rtol={onp.float64: 1e-13})
rtol={np.float64: 1e-13})

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -204,9 +204,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]

self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.slogdet, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -221,13 +221,13 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
a = rng(shape, dtype)
jtu.check_grads(np.linalg.slogdet, (a,), 2, atol=1e-1, rtol=1e-1)
jtu.check_grads(jnp.linalg.slogdet, (a,), 2, atol=1e-1, rtol=1e-1)

def testIssue1213(self):
for n in range(5):
mat = np.array([onp.diag(onp.ones([5], dtype=onp.float32))*(-.01)] * 2)
mat = jnp.array([np.diag(np.ones([5], dtype=np.float32))*(-.01)] * 2)
args_maker = lambda: [mat]
self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
check_dtypes=True, tol=1e-3)

@parameterized.named_parameters(jtu.cases_from_list(
@ -249,14 +249,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):

# Norm, adjusted for dimension and type.
def norm(x):
norm = onp.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * np.finfo(dtype).eps)
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * jnp.finfo(dtype).eps)

a, = args_maker()
w, v = np.linalg.eig(a)
self.assertTrue(onp.all(norm(onp.matmul(a, v) - w[..., None, :] * v) < 100))
w, v = jnp.linalg.eig(a)
self.assertTrue(np.all(norm(np.matmul(a, v) - w[..., None, :] * v) < 100))

self._CompileAndCheck(partial(np.linalg.eig), args_maker,
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
check_dtypes=True, rtol=1e-3)

@parameterized.named_parameters(jtu.cases_from_list(
@ -276,15 +276,15 @@ class NumpyLinalgTest(jtu.JaxTestCase):
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
w1, _ = np.linalg.eig(a)
w2 = np.linalg.eigvals(a)
w1, _ = jnp.linalg.eig(a)
w2 = jnp.linalg.eigvals(a)
self.assertAllClose(w1, w2, check_dtypes=True)

@jtu.skip_on_devices("gpu", "tpu")
def testEigvalsInf(self):
# https://github.com/google/jax/issues/2661
x = np.array([[np.inf]], np.float64)
self.assertTrue(np.all(np.isnan(np.linalg.eigvals(x))))
x = jnp.array([[jnp.inf]], jnp.float64)
self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -299,9 +299,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
shape = (10,) + shape
args = rng(shape, dtype)
ws, vs = vmap(np.linalg.eig)(args)
self.assertTrue(onp.all(onp.linalg.norm(
onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
ws, vs = vmap(jnp.linalg.eig)(args)
self.assertTrue(np.all(np.linalg.norm(
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_n={}_lower={}".format(
@ -316,7 +316,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
tol = 30
if jtu.device_under_test() == "tpu":
if np.issubdtype(dtype, onp.complexfloating):
if jnp.issubdtype(dtype, np.complexfloating):
raise unittest.SkipTest("No complex eigh on TPU")
# TODO(phawkins): this tolerance is unpleasantly high.
tol = 1500
@ -326,17 +326,17 @@ class NumpyLinalgTest(jtu.JaxTestCase):

# Norm, adjusted for dimension and type.
def norm(x):
norm = onp.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * np.finfo(dtype).eps)
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * jnp.finfo(dtype).eps)

a, = args_maker()
a = (a + onp.conj(a.T)) / 2
w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
a = (a + np.conj(a.T)) / 2
w, v = jnp.linalg.eigh(np.tril(a) if lower else np.triu(a),
UPLO=uplo, symmetrize_input=False)
self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
self.assertTrue(norm(onp.matmul(a, v) - w * v) < tol)
self.assertTrue(norm(np.eye(n) - np.matmul(np.conj(T(v)), v)) < 5)
self.assertTrue(norm(np.matmul(a, v) - w * v) < tol)

self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
check_dtypes=True, rtol=1e-3)

@parameterized.named_parameters(jtu.cases_from_list(
@ -350,14 +350,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
if jtu.device_under_test() == "tpu":
if np.issubdtype(dtype, np.complexfloating):
if jnp.issubdtype(dtype, jnp.complexfloating):
raise unittest.SkipTest("No complex eigh on TPU")
n = shape[-1]
def args_maker():
a = rng((n, n), dtype)
a = (a + onp.conj(a.T)) / 2
a = (a + np.conj(a.T)) / 2
return [a]
self._CheckAgainstNumpy(onp.linalg.eigvalsh, np.linalg.eigvalsh, args_maker,
self._CheckAgainstNumpy(np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker,
check_dtypes=True, tol=1e-3)

@parameterized.named_parameters(jtu.cases_from_list(
@ -374,16 +374,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.skipTest("Test fails with numeric errors.")
uplo = "L" if lower else "U"
a = rng(shape, dtype)
a = (a + onp.conj(T(a))) / 2
ones = onp.ones((a.shape[-1], a.shape[-1]), dtype=dtype)
a *= onp.tril(ones) if lower else onp.triu(ones)
a = (a + np.conj(T(a))) / 2
ones = np.ones((a.shape[-1], a.shape[-1]), dtype=dtype)
a *= np.tril(ones) if lower else np.triu(ones)
# Gradient checks will fail without symmetrization as the eigh jvp rule
# is only correct for tangents in the symmetric subspace, whereas the
# checker checks against unconstrained (co)tangents.
if dtype not in complex_types:
f = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)
f = partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)
else: # only check eigenvalue grads for complex matrices
f = lambda a: partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
f = lambda a: partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
jtu.check_grads(f, (a,), 2, rtol=1e-1)

@parameterized.named_parameters(jtu.cases_from_list(
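The comment in the hunk above is the important constraint: the eigh JVP rule is only correct for tangents in the symmetric subspace, so differentiation goes through jnp.linalg.eigh with symmetrize_input=True (or an explicitly symmetrized matrix). A minimal sketch, not part of the commit, of what that looks like from the user side (the scalar objective is arbitrary, chosen only for illustration):

import jax
import jax.numpy as jnp
import numpy as np

a = np.random.RandomState(0).randn(3, 3).astype(np.float32)

def smallest_eigenvalue(m):
    # symmetrize_input=True symmetrizes m first, which is what makes the
    # gradient below well-defined for an arbitrary (non-symmetric) m.
    w, _ = jnp.linalg.eigh(m, symmetrize_input=True)
    return w[0]

print(jax.grad(smallest_eigenvalue)(a).shape)  # (3, 3)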
@ -409,32 +409,32 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# eigenvectors. You only ever want to optimize eigenvector directions, not coordinates!
uplo = "L" if lower else "U"
a = rng(shape, dtype)
a = (a + onp.conj(a.T)) / 2
a = onp.tril(a) if lower else onp.triu(a)
a = (a + np.conj(a.T)) / 2
a = np.tril(a) if lower else np.triu(a)
a_dot = eps * rng(shape, dtype)
a_dot = (a_dot + onp.conj(a_dot.T)) / 2
a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)
a_dot = (a_dot + np.conj(a_dot.T)) / 2
a_dot = np.tril(a_dot) if lower else np.triu(a_dot)
# evaluate eigenvector gradient and groundtruth eigensystem for perturbed input matrix
f = partial(np.linalg.eigh, UPLO=uplo)
f = partial(jnp.linalg.eigh, UPLO=uplo)
(w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
new_a = a + a_dot
new_w, new_v = f(new_a)
new_a = (new_a + onp.conj(new_a.T)) / 2
new_a = (new_a + np.conj(new_a.T)) / 2
# Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
RTOL=1e-2
assert onp.max(
onp.abs((onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
assert np.max(
np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
# Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
assert onp.max(
onp.linalg.norm(onp.abs(new_w*(v+dv) - onp.dot(new_a, (v+dv))), axis=0) /
onp.linalg.norm(onp.abs(new_w*(v+dv)), axis=0)
assert np.max(
np.linalg.norm(np.abs(new_w*(v+dv) - np.dot(new_a, (v+dv))), axis=0) /
np.linalg.norm(np.abs(new_w*(v+dv)), axis=0)
) < RTOL

def testEighGradPrecision(self):
rng = jtu.rand_default(self.rng())
a = rng((3, 3), onp.float32)
a = rng((3, 3), np.float32)
jtu.assert_dot_precision(
lax.Precision.HIGHEST, partial(jvp, np.linalg.eigh), (a,), (a,))
lax.Precision.HIGHEST, partial(jvp, jnp.linalg.eigh), (a,), (a,))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -447,14 +447,14 @@ class NumpyLinalgTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
if (jtu.device_under_test() == "tpu" and
np.issubdtype(dtype, onp.complexfloating)):
jnp.issubdtype(dtype, np.complexfloating)):
raise unittest.SkipTest("No complex eigh on TPU")
shape = (10,) + shape
args = rng(shape, dtype)
args = (args + onp.conj(T(args))) / 2
args = (args + np.conj(T(args))) / 2
ws, vs = vmap(jsp.linalg.eigh)(args)
self.assertTrue(onp.all(onp.linalg.norm(
onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
self.assertTrue(np.all(np.linalg.norm(
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format(
@ -469,11 +469,11 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for keepdims in [False, True]
for ord in (
[None] if axis is None and len(shape) > 2
else [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf]
else [None, 0, 1, 2, 3, -1, -2, -3, jnp.inf, -jnp.inf]
if (axis is None and len(shape) == 1) or
isinstance(axis, int) or
(isinstance(axis, tuple) and len(axis) == 1)
else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
else [None, 'fro', 1, 2, -1, -2, jnp.inf, -jnp.inf, 'nuc'])
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default])) # type: ignore
def testNorm(self, shape, dtype, ord, axis, keepdims, rng_factory):
@ -485,9 +485,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
raise unittest.SkipTest("No adequate SVD implementation available")

args_maker = lambda: [rng(shape, dtype)]
onp_fn = partial(onp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
np_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
self._CheckAgainstNumpy(np_fn, np_fn, args_maker,
check_dtypes=False, tol=1e-3)
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)

@ -512,39 +512,39 @@ class NumpyLinalgTest(jtu.JaxTestCase):

# Norm, adjusted for dimension and type.
def norm(x):
norm = onp.linalg.norm(x, axis=(-2, -1))
return norm / (max(m, n) * np.finfo(dtype).eps)
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / (max(m, n) * jnp.finfo(dtype).eps)

a, = args_maker()
out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
if compute_uv:
# Check the reconstructed matrices
if full_matrices:
k = min(m, n)
if m < n:
self.assertTrue(onp.all(
norm(a - onp.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :])) < 50))
self.assertTrue(np.all(
norm(a - np.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :])) < 50))
else:
self.assertTrue(onp.all(
norm(a - onp.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2])) < 350))
self.assertTrue(np.all(
norm(a - np.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2])) < 350))
else:
self.assertTrue(onp.all(
norm(a - onp.matmul(out[1][..., None, :] * out[0], out[2])) < 350))
self.assertTrue(np.all(
norm(a - np.matmul(out[1][..., None, :] * out[0], out[2])) < 350))

# Check the unitary properties of the singular vector matrices.
self.assertTrue(onp.all(norm(onp.eye(out[0].shape[-1]) - onp.matmul(onp.conj(T(out[0])), out[0])) < 15))
self.assertTrue(np.all(norm(np.eye(out[0].shape[-1]) - np.matmul(np.conj(T(out[0])), out[0])) < 15))
if m >= n:
self.assertTrue(onp.all(norm(onp.eye(out[2].shape[-1]) - onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
self.assertTrue(np.all(norm(np.eye(out[2].shape[-1]) - np.matmul(np.conj(T(out[2])), out[2])) < 10))
else:
self.assertTrue(onp.all(norm(onp.eye(out[2].shape[-2]) - onp.matmul(out[2], onp.conj(T(out[2])))) < 20))
self.assertTrue(np.all(norm(np.eye(out[2].shape[-2]) - np.matmul(out[2], np.conj(T(out[2])))) < 20))

else:
self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out), atol=1e-4, rtol=1e-4))
self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False), np.asarray(out), atol=1e-4, rtol=1e-4))

self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
args_maker, check_dtypes=True)
if not (compute_uv and full_matrices):
svd = partial(np.linalg.svd, full_matrices=full_matrices,
svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
compute_uv=compute_uv)
# TODO(phawkins): these tolerances seem very loose.
jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=5e-2, atol=2e-1)
@ -561,7 +561,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
def testQr(self, shape, dtype, full_matrices, rng_factory):
rng = rng_factory(self.rng())
_skip_if_unsupported_type(dtype)
if (np.issubdtype(dtype, onp.complexfloating) and
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
raise unittest.SkipTest("No complex QR implementation")
m, n = shape[-2:]
@ -572,42 +572,42 @@ class NumpyLinalgTest(jtu.JaxTestCase):
mode, k = "reduced", min(m, n)

a = rng(shape, dtype)
lq, lr = np.linalg.qr(a, mode=mode)
lq, lr = jnp.linalg.qr(a, mode=mode)

# onp.linalg.qr doesn't support batch dimensions. But it seems like an
# np.linalg.qr doesn't support batch dimensions. But it seems like an
# inevitable extension so we support it in our version.
nq = onp.zeros(shape[:-2] + (m, k), dtype)
nr = onp.zeros(shape[:-2] + (k, n), dtype)
for index in onp.ndindex(*shape[:-2]):
nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)
nq = np.zeros(shape[:-2] + (m, k), dtype)
nr = np.zeros(shape[:-2] + (k, n), dtype)
for index in np.ndindex(*shape[:-2]):
nq[index], nr[index] = np.linalg.qr(a[index], mode=mode)

max_rank = max(m, n)

# Norm, adjusted for dimension and type.
def norm(x):
n = onp.linalg.norm(x, axis=(-2, -1))
return n / (max_rank * np.finfo(dtype).eps)
n = np.linalg.norm(x, axis=(-2, -1))
return n / (max_rank * jnp.finfo(dtype).eps)

def compare_orthogonal(q1, q2):
# Q is unique up to sign, so normalize the sign first.
sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
sum_of_ratios = np.sum(np.divide(q1, q2), axis=-2, keepdims=True)
phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
q1 *= phases
self.assertTrue(onp.all(norm(q1 - q2) < 30))
self.assertTrue(np.all(norm(q1 - q2) < 30))

# Check a ~= qr
self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))
self.assertTrue(np.all(norm(a - np.matmul(lq, lr)) < 30))

# Compare the first 'k' vectors of Q; the remainder form an arbitrary
# orthonormal basis for the null space.
compare_orthogonal(nq[..., :k], lq[..., :k])

# Check that q is close to unitary.
self.assertTrue(onp.all(
norm(onp.eye(k) -onp.matmul(onp.conj(T(lq)), lq)) < 5))
self.assertTrue(np.all(
norm(np.eye(k) -np.matmul(np.conj(T(lq)), lq)) < 5))

if not full_matrices and m >= n:
jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,), atol=3e-3)
jtu.check_jvp(jnp.linalg.qr, partial(jvp, jnp.linalg.qr), (a,), atol=3e-3)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
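The QR hunk above loops over the batch when building its NumPy reference because, at the time of this commit, np.linalg.qr had no batch support, whereas jnp.linalg.qr handles leading batch dimensions directly. A sketch of the two sides of that comparison, not part of the commit (shapes and tolerance are illustrative):

import numpy as np
import jax.numpy as jnp

a = np.random.RandomState(0).randn(5, 4, 3).astype(np.float32)  # batch of 5 matrices

# JAX: batch dimensions are handled natively.
lq, lr = jnp.linalg.qr(a, mode="reduced")   # lq: (5, 4, 3), lr: (5, 3, 3)

# NumPy reference: explicit loop over the batch, as the test above does.
nq = np.zeros((5, 4, 3), np.float32)
nr = np.zeros((5, 3, 3), np.float32)
for i in range(5):
    nq[i], nr[i] = np.linalg.qr(a[i], mode="reduced")

# Q and R are only unique up to per-column signs, but both factorizations
# reconstruct the same input, so the products agree.
np.testing.assert_allclose(np.asarray(lq @ lr), nq @ nr, rtol=1e-4, atol=1e-4)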
@ -619,16 +619,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testQrBatching(self, shape, dtype, rng_factory):
rng = rng_factory(self.rng())
args = rng(shape, np.float32)
args = rng(shape, jnp.float32)
qs, rs = vmap(jsp.linalg.qr)(args)
self.assertTrue(onp.all(onp.linalg.norm(args - onp.matmul(qs, rs)) < 1e-3))
self.assertTrue(np.all(np.linalg.norm(args - np.matmul(qs, rs)) < 1e-3))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_pnorm={}".format(jtu.format_shape_dtype_string(shape, dtype), pnorm),
"shape": shape, "pnorm": pnorm, "dtype": dtype}
for shape in [(1, 1), (4, 4), (2, 3, 5), (5, 5, 5), (20, 20), (5, 10)]
for pnorm in [np.inf, -np.inf, 1, -1, 2, -2, 'fro']
for pnorm in [jnp.inf, -jnp.inf, 1, -1, 2, -2, 'fro']
for dtype in float_types + complex_types))
@jtu.skip_on_devices("tpu") # SVD is not implemented on the TPU backend
@jtu.skip_on_devices("gpu") # TODO(#2203): numerical errors
@ -648,12 +648,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):

args_maker = args_gen(pnorm)
if pnorm not in [2, -2] and len(set(shape[-2:])) != 1:
with self.assertRaises(onp.linalg.LinAlgError):
np.linalg.cond(*args_maker())
with self.assertRaises(np.linalg.LinAlgError):
jnp.linalg.cond(*args_maker())
else:
self._CheckAgainstNumpy(onp.linalg.cond, np.linalg.cond, args_maker,
self._CheckAgainstNumpy(np.linalg.cond, jnp.linalg.cond, args_maker,
check_dtypes=False, tol=1e-3)
partial_norm = partial(np.linalg.cond, p=pnorm)
partial_norm = partial(jnp.linalg.cond, p=pnorm)
self._CompileAndCheck(partial_norm, lambda: [gen_mat()],
check_dtypes=False, rtol=1e-03, atol=1e-03)

@ -674,16 +674,16 @@ class NumpyLinalgTest(jtu.JaxTestCase):
while not invertible:
a = rng(shape, dtype)
try:
onp.linalg.inv(a)
np.linalg.inv(a)
invertible = True
except onp.linalg.LinAlgError:
except np.linalg.LinAlgError:
pass
return a

args_maker = lambda: [tensor_maker(), int(onp.floor(len(shape) / 2))]
self._CheckAgainstNumpy(onp.linalg.tensorinv, np.linalg.tensorinv, args_maker,
args_maker = lambda: [tensor_maker(), int(np.floor(len(shape) / 2))]
self._CheckAgainstNumpy(np.linalg.tensorinv, jnp.linalg.tensorinv, args_maker,
check_dtypes=False, tol=1e-3)
partial_inv = partial(np.linalg.tensorinv, ind=int(onp.floor(len(shape) / 2)))
partial_inv = partial(jnp.linalg.tensorinv, ind=int(np.floor(len(shape) / 2)))
self._CompileAndCheck(partial_inv, lambda: [tensor_maker()], check_dtypes=False, rtol=1e-03, atol=1e-03)

@parameterized.named_parameters(jtu.cases_from_list(
@ -707,9 +707,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker,
self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.solve, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -730,15 +730,15 @@ class NumpyLinalgTest(jtu.JaxTestCase):
while not invertible:
a = rng(shape, dtype)
try:
onp.linalg.inv(a)
np.linalg.inv(a)
invertible = True
except onp.linalg.LinAlgError:
except np.linalg.LinAlgError:
pass
return [a]

self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker,
self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.inv, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -754,26 +754,26 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]

self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
self._CheckAgainstNumpy(np.linalg.pinv, jnp.linalg.pinv, args_maker,
check_dtypes=True, tol=1e-2)
self._CompileAndCheck(np.linalg.pinv, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp.linalg.pinv, args_maker, check_dtypes=True)
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
jtu.check_grads(np.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)

@jtu.skip_on_devices("tpu") # SVD is not implemented on the TPU backend
def testPinvGradIssue2792(self):
def f(p):
a = np.array([[0., 0.],[-p, 1.]], np.float32) * 1 / (1 + p**2)
return np.linalg.pinv(a)
j = jax.jacobian(f)(np.float32(2.))
self.assertAllClose(np.array([[0., -1.], [ 0., 0.]], np.float32), j,
a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2)
return jnp.linalg.pinv(a)
j = jax.jacobian(f)(jnp.float32(2.))
self.assertAllClose(jnp.array([[0., -1.], [ 0., 0.]], jnp.float32), j,
check_dtypes=True)

expected = np.array([[[[-1., 0.], [ 0., 0.]], [[0., -1.], [0., 0.]]],
expected = jnp.array([[[[-1., 0.], [ 0., 0.]], [[0., -1.], [0., 0.]]],
[[[0., 0.], [-1., 0.]], [[0., 0.], [0., -1.]]]],
dtype=np.float32)
dtype=jnp.float32)
self.assertAllClose(
expected, jax.jacobian(np.linalg.pinv)(np.eye(2, dtype=np.float32)),
expected, jax.jacobian(jnp.linalg.pinv)(jnp.eye(2, dtype=jnp.float32)),
check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
@ -791,10 +791,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
self._CheckAgainstNumpy(partial(onp.linalg.matrix_power, n=n),
partial(np.linalg.matrix_power, n=n),
self._CheckAgainstNumpy(partial(np.linalg.matrix_power, n=n),
partial(jnp.linalg.matrix_power, n=n),
args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(partial(np.linalg.matrix_power, n=n), args_maker,
self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker,
check_dtypes=True, rtol=1e-3)

@parameterized.named_parameters(jtu.cases_from_list(
@ -811,9 +811,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
n = shape[-1]
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
a, = args_maker()
|
||||
self._CheckAgainstNumpy(onp.linalg.matrix_rank, np.linalg.matrix_rank,
|
||||
self._CheckAgainstNumpy(np.linalg.matrix_rank, jnp.linalg.matrix_rank,
|
||||
args_maker, check_dtypes=False, tol=1e-3)
|
||||
self._CompileAndCheck(np.linalg.matrix_rank, args_maker,
|
||||
self._CompileAndCheck(jnp.linalg.matrix_rank, args_maker,
|
||||
check_dtypes=False, rtol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -832,12 +832,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
_skip_if_unsupported_type(dtype)
|
||||
args_maker = lambda: [[rng(shape, dtype) for shape in shapes]]
|
||||
|
||||
onp_fun = onp.linalg.multi_dot
|
||||
jnp_fun = partial(np.linalg.multi_dot, precision=lax.Precision.HIGHEST)
|
||||
tol = {onp.float32: 1e-4, onp.float64: 1e-10,
|
||||
onp.complex64: 1e-4, onp.complex128: 1e-10}
|
||||
np_fun = np.linalg.multi_dot
|
||||
jnp_fun = partial(jnp.linalg.multi_dot, precision=lax.Precision.HIGHEST)
|
||||
tol = {np.float32: 1e-4, np.float64: 1e-10,
|
||||
np.complex64: 1e-4, np.complex128: 1e-10}
|
||||
|
||||
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
@ -846,39 +846,39 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu") # TODO(phawkins): No complex eigh implementation on TPU.
|
||||
def testIssue669(self):
|
||||
def test(x):
|
||||
val, vec = np.linalg.eigh(x)
|
||||
return np.real(np.sum(val))
|
||||
val, vec = jnp.linalg.eigh(x)
|
||||
return jnp.real(jnp.sum(val))
|
||||
|
||||
grad_test_jc = jit(grad(jit(test)))
|
||||
xc = onp.eye(3, dtype=onp.complex)
|
||||
xc = np.eye(3, dtype=np.complex)
|
||||
self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testIssue1151(self):
A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32)
b = np.array(onp.random.randn(100, 3), dtype=np.float32)
x = np.linalg.solve(A, b)
self.assertAllClose(vmap(np.dot)(A, x), b, atol=1e-3, rtol=1e-2,
A = jnp.array(np.random.randn(100, 3, 3), dtype=jnp.float32)
b = jnp.array(np.random.randn(100, 3), dtype=jnp.float32)
x = jnp.linalg.solve(A, b)
self.assertAllClose(vmap(jnp.dot)(A, x), b, atol=1e-3, rtol=1e-2,
check_dtypes=True)
jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A, b)
jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A, b)
jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A[0], b[0])
jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A[0], b[0])
jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A, b)
jac1 = jax.jacobian(jnp.linalg.solve, argnums=1)(A, b)
jac0 = jax.jacobian(jnp.linalg.solve, argnums=0)(A[0], b[0])
jac1 = jax.jacobian(jnp.linalg.solve, argnums=1)(A[0], b[0])

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testIssue1383(self):
seed = jax.random.PRNGKey(0)
tmp = jax.random.uniform(seed, (2,2))
a = np.dot(tmp, tmp.T)
a = jnp.dot(tmp, tmp.T)

def f(inp):
val, vec = np.linalg.eigh(inp)
return np.dot(np.dot(vec, inp), vec.T)
val, vec = jnp.linalg.eigh(inp)
return jnp.dot(jnp.dot(vec, inp), vec.T)

grad_func = jax.jacfwd(f)
hess_func = jax.jacfwd(grad_func)
cube_func = jax.jacfwd(hess_func)
self.assertFalse(onp.any(onp.isnan(cube_func(a))))
self.assertFalse(np.any(np.isnan(cube_func(a))))


class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
@ -890,8 +890,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
(1,),
|
||||
(7, -2),
|
||||
(3, 4, 5),
|
||||
(onp.ones((3, 4), dtype=np.float_), 5,
|
||||
onp.random.randn(5, 2).astype(np.float_)),
|
||||
(np.ones((3, 4), dtype=jnp.float_), 5,
|
||||
np.random.randn(5, 2).astype(jnp.float_)),
|
||||
])))
|
||||
def testBlockDiag(self, args):
|
||||
args_maker = lambda: args
|
||||
@ -914,15 +914,15 @@ class ScipyLinalgTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
x, = args_maker()
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True,
rtol={onp.float32: 1e-3, onp.float64: 1e-12,
onp.complex64: 1e-3, onp.complex128: 1e-12})
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True,
rtol={np.float32: 1e-3, np.float64: 1e-12,
np.complex64: 1e-3, np.complex128: 1e-12})
self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

def testLuOfSingularMatrix(self):
x = np.array([[-1., 3./2], [2./3, -1.]], dtype=onp.float32)
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True)
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)), check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -945,18 +945,18 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype, "rng_factory": rng_factory}
|
||||
for shape in [(4, 5), (6, 5)]
|
||||
for dtype in [np.float32]
|
||||
for dtype in [jnp.float32]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testLuBatching(self, shape, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
args = [rng(shape, np.float32) for _ in range(10)]
|
||||
args = [rng(shape, jnp.float32) for _ in range(10)]
|
||||
expected = list(osp.linalg.lu(x) for x in args)
|
||||
ps = onp.stack([out[0] for out in expected])
|
||||
ls = onp.stack([out[1] for out in expected])
|
||||
us = onp.stack([out[2] for out in expected])
|
||||
ps = np.stack([out[0] for out in expected])
|
||||
ls = np.stack([out[1] for out in expected])
|
||||
us = np.stack([out[2] for out in expected])
|
||||
|
||||
actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
|
||||
actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
|
||||
self.assertAllClose(ps, actual_ps, check_dtypes=True)
|
||||
self.assertAllClose(ls, actual_ls, check_dtypes=True)
|
||||
self.assertAllClose(us, actual_us, check_dtypes=True)
|
||||
@ -976,11 +976,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
x, = args_maker()
|
||||
lu, piv = jsp.linalg.lu_factor(x)
|
||||
l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
|
||||
u = onp.triu(lu)
|
||||
l = np.tril(lu, -1) + np.eye(n, dtype=dtype)
|
||||
u = np.triu(lu)
|
||||
for i in range(n):
|
||||
x[[i, piv[i]],] = x[[piv[i], i],]
|
||||
self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3,
|
||||
self.assertAllClose(x, np.matmul(l, u), check_dtypes=True, rtol=1e-3,
|
||||
atol=1e-3)
|
||||
self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)
|
||||
|
||||
@ -1039,7 +1039,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
if (sym_pos and np.issubdtype(dtype, onp.complexfloating) and
|
||||
if (sym_pos and jnp.issubdtype(dtype, np.complexfloating) and
|
||||
jtu.device_under_test() == "tpu"):
|
||||
raise unittest.SkipTest(
|
||||
"Complex Cholesky decomposition not implemented on TPU")
|
||||
@ -1049,8 +1049,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
a = rng(lhs_shape, dtype)
|
||||
if sym_pos:
|
||||
a = onp.matmul(a, onp.conj(T(a)))
|
||||
a = onp.tril(a) if lower else onp.triu(a)
|
||||
a = np.matmul(a, np.conj(T(a)))
|
||||
a = np.tril(a) if lower else np.triu(a)
|
||||
return [a, rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
|
||||
@ -1081,22 +1081,22 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
_skip_if_unsupported_type(dtype)
|
||||
rng = rng_factory(self.rng())
|
||||
k = rng(lhs_shape, dtype)
|
||||
l = onp.linalg.cholesky(onp.matmul(k, T(k))
|
||||
+ lhs_shape[-1] * onp.eye(lhs_shape[-1]))
|
||||
l = np.linalg.cholesky(np.matmul(k, T(k))
|
||||
+ lhs_shape[-1] * np.eye(lhs_shape[-1]))
|
||||
l = l.astype(k.dtype)
|
||||
b = rng(rhs_shape, dtype)
|
||||
|
||||
if unit_diagonal:
|
||||
a = onp.tril(l, -1) + onp.eye(lhs_shape[-1], dtype=dtype)
|
||||
a = np.tril(l, -1) + np.eye(lhs_shape[-1], dtype=dtype)
|
||||
else:
|
||||
a = l
|
||||
a = a if lower else T(a)
|
||||
|
||||
inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
|
||||
inv = np.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
|
||||
if len(lhs_shape) == len(rhs_shape):
|
||||
onp_ans = onp.matmul(inv, b)
|
||||
np_ans = np.matmul(inv, b)
|
||||
else:
|
||||
onp_ans = onp.einsum("...ij,...j->...i", inv, b)
|
||||
np_ans = np.einsum("...ij,...j->...i", inv, b)
|
||||
|
||||
# The standard scipy.linalg.solve_triangular doesn't support broadcasting.
|
||||
# But it seems like an inevitable extension so we support it.
|
||||
@ -1104,8 +1104,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower,
|
||||
unit_diagonal=unit_diagonal)
|
||||
|
||||
self.assertAllClose(onp_ans, ans, check_dtypes=True,
|
||||
rtol={onp.float32: 1e-4, onp.float64: 1e-11})
|
||||
self.assertAllClose(np_ans, ans, check_dtypes=True,
|
||||
rtol={np.float32: 1e-4, np.float64: 1e-11})
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -1122,7 +1122,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
for dtype in float_types + complex_types
|
||||
for transpose_a in [False, True]
|
||||
for conjugate_a in (
|
||||
[False] if np.issubdtype(dtype, np.floating) else [False, True])
|
||||
[False] if jnp.issubdtype(dtype, jnp.floating) else [False, True])
|
||||
for left_side, a_shape, b_shape in [
|
||||
(False, (4, 4), (4,)),
|
||||
(False, (4, 4), (1, 4,)),
|
||||
@ -1141,7 +1141,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
# Test lax_linalg.triangular_solve instead of scipy.linalg.solve_triangular
|
||||
# because it exposes more options.
|
||||
A = np.tril(rng(a_shape, dtype) + 5 * onp.eye(a_shape[-1], dtype=dtype))
|
||||
A = jnp.tril(rng(a_shape, dtype) + 5 * np.eye(a_shape[-1], dtype=dtype))
|
||||
A = A if lower else T(A)
|
||||
B = rng(b_shape, dtype)
|
||||
f = partial(lax_linalg.triangular_solve, lower=lower,
|
||||
@ -1165,21 +1165,21 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
]))
|
||||
def testTriangularSolveBatching(self, left_side, a_shape, b_shape, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
A = np.tril(rng(a_shape, onp.float32)
|
||||
+ 5 * onp.eye(a_shape[-1], dtype=onp.float32))
|
||||
B = rng(b_shape, onp.float32)
|
||||
A = jnp.tril(rng(a_shape, np.float32)
|
||||
+ 5 * np.eye(a_shape[-1], dtype=np.float32))
|
||||
B = rng(b_shape, np.float32)
|
||||
solve = partial(lax_linalg.triangular_solve, lower=True,
|
||||
transpose_a=False, conjugate_a=False,
|
||||
unit_diagonal=False, left_side=left_side)
|
||||
X = vmap(solve, bdims)(A, B)
|
||||
matmul = partial(np.matmul, precision=lax.Precision.HIGHEST)
|
||||
matmul = partial(jnp.matmul, precision=lax.Precision.HIGHEST)
|
||||
Y = matmul(A, X) if left_side else matmul(X, A)
|
||||
onp.testing.assert_allclose(Y - B, 0, atol=1e-4)
|
||||
np.testing.assert_allclose(Y - B, 0, atol=1e-4)
|
||||
|
||||
def testTriangularSolveGradPrecision(self):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
a = np.tril(rng((3, 3), onp.float32))
|
||||
b = rng((1, 3), onp.float32)
|
||||
a = jnp.tril(rng((3, 3), np.float32))
|
||||
b = rng((1, 3), np.float32)
|
||||
jtu.assert_dot_precision(
|
||||
lax.Precision.HIGHEST,
|
||||
partial(jvp, lax_linalg.triangular_solve),
|
||||
@ -1205,7 +1205,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
check_dtypes=True)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
args_maker_triu = lambda: [onp.triu(rng((n, n), dtype))]
|
||||
args_maker_triu = lambda: [np.triu(rng((n, n), dtype))]
|
||||
jsp_fun_triu = lambda a: jsp.linalg.expm(a,upper_triangular=True)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu,
|
||||
check_dtypes=True)
|
||||
@ -1220,7 +1220,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
))
|
||||
@jtu.skip_on_mac_linalg_bug()
|
||||
def testIssue2131(self, n, dtype):
|
||||
args_maker_zeros = lambda: [onp.zeros((n, n), dtype)]
|
||||
args_maker_zeros = lambda: [np.zeros((n, n), dtype)]
|
||||
osp_fun = lambda a: osp.linalg.expm(a)
|
||||
jsp_fun = lambda a: jsp.linalg.expm(a)
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker_zeros,
|
||||
@ -1248,10 +1248,10 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
b = rng(rhs_shape, dtype)
|
||||
if lower:
|
||||
L = onp.tril(rng(lhs_shape, dtype))
|
||||
L = np.tril(rng(lhs_shape, dtype))
|
||||
return [(L, lower), b]
|
||||
else:
|
||||
U = onp.triu(rng(lhs_shape, dtype))
|
||||
U = np.triu(rng(lhs_shape, dtype))
|
||||
return [(U, lower), b]
|
||||
self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve,
|
||||
args_maker, check_dtypes=True, tol=1e-3)
|
@ -16,12 +16,12 @@

from absl.testing import absltest
import numpy as onp
import numpy as np
import re

from jax import api, lax, ops
from jax import core
from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu
from jax.experimental import loops

@ -61,9 +61,9 @@ class LoopsTest(jtu.JaxTestCase):
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
self.assertAllClose(5., api.grad(f_op)(2.), check_dtypes=True)
inc_batch = onp.arange(5, dtype=np.float_)
self.assertAllClose(np.array([f_expected(inc) for inc in inc_batch],
dtype=np.float_),
inc_batch = np.arange(5, dtype=jnp.float_)
self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch],
dtype=jnp.float_),
api.vmap(f_op)(inc_batch), check_dtypes=True)

@ -86,14 +86,14 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
with loops.Scope() as s:
|
||||
n = x.shape[0]
|
||||
assert n == y.shape[0]
|
||||
s.out = np.zeros(shape=[n], dtype=np.float32)
|
||||
s.out = jnp.zeros(shape=[n], dtype=jnp.float32)
|
||||
for i in s.range(n):
|
||||
s.out = ops.index_add(s.out, i, x[i] + y[i])
|
||||
return s.out
|
||||
|
||||
x = np.array([1., 2., 3.], dtype=np.float32)
|
||||
y = np.array([4., 5., 6.], dtype=np.float32)
|
||||
self.assertAllClose(np.add(x, y), add_vec(x, y), check_dtypes=True)
|
||||
x = jnp.array([1., 2., 3.], dtype=jnp.float32)
|
||||
y = jnp.array([4., 5., 6.], dtype=jnp.float32)
|
||||
self.assertAllClose(jnp.add(x, y), add_vec(x, y), check_dtypes=True)
|
||||
|
||||
def test_matmul(self):
|
||||
def matmul(x, y):
|
||||
@ -101,16 +101,16 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
n, m = x.shape
|
||||
m1, p = y.shape
|
||||
assert m == m1
|
||||
s.out = np.zeros(shape=[n, p], dtype=np.float32)
|
||||
s.out = jnp.zeros(shape=[n, p], dtype=jnp.float32)
|
||||
for i in s.range(n):
|
||||
for j in s.range(p):
|
||||
for k in s.range(m):
|
||||
s.out = ops.index_add(s.out, (i, j), x[i, k] * y[k, j])
|
||||
return s.out
|
||||
|
||||
x = np.array([[1., 2., 3.]], dtype=np.float32) # 1x3
|
||||
y = np.array([[4.], [5.], [6.]], dtype=np.float32) # 3x1
|
||||
self.assertAllClose(np.matmul(x, y), matmul(x, y), check_dtypes=True)
|
||||
x = jnp.array([[1., 2., 3.]], dtype=jnp.float32) # 1x3
|
||||
y = jnp.array([[4.], [5.], [6.]], dtype=jnp.float32) # 3x1
|
||||
self.assertAllClose(jnp.matmul(x, y), matmul(x, y), check_dtypes=True)
|
||||
|
||||
def test_reuse_range(self):
|
||||
"""Ranges can be reused, as long as not nested in each other."""
|
||||
@ -142,7 +142,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
def test_example_doc(self):
|
||||
"The example from the module docstring."
|
||||
def f_expected():
|
||||
arr = onp.zeros(5, dtype=np.float_)
|
||||
arr = np.zeros(5, dtype=jnp.float_)
|
||||
for i in range(arr.shape[0]):
|
||||
arr[i] += 2.
|
||||
if i % 2 == 0:
|
||||
@ -150,7 +150,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
return arr
|
||||
|
||||
def f_op_jax():
|
||||
arr = np.zeros(5)
|
||||
arr = jnp.zeros(5)
|
||||
def loop_body(i, acc_arr):
|
||||
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
|
||||
return lax.cond(i % 2 == 0,
|
||||
@ -163,7 +163,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
|
||||
def f_op_loops():
|
||||
with loops.Scope() as s:
|
||||
s.arr = np.zeros(5) # Must create the mutable state of the loop as `scope` fields.
|
||||
s.arr = jnp.zeros(5) # Must create the mutable state of the loop as `scope` fields.
|
||||
for i in s.range(s.arr.shape[0]):
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
|
||||
for _ in s.cond_range(i % 2 == 0): # Conditionals are also sugared as loops with 0 or 1 iterations
|
||||
@ -288,7 +288,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
|
||||
self.assertAllClose(16., api.jit(f_op)(0, 4, 4.), check_dtypes=True)
|
||||
with self.assertRaisesRegex(TypeError, "Abstract tracer value encountered where concrete value is expected"):
|
||||
self.assertAllClose(16., api.vmap(f_op)(np.zeros(10), np.ones(10), np.array([4.] * 10)), check_dtypes=True)
|
||||
self.assertAllClose(16., api.vmap(f_op)(jnp.zeros(10), jnp.ones(10), jnp.array([4.] * 10)), check_dtypes=True)
|
||||
|
||||
def test_cond(self):
|
||||
def f_op(inc):
|
||||
@ -377,9 +377,9 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
|
||||
self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
|
||||
init_batch = onp.array([1., 2., 3.], dtype=onp.float32)
|
||||
self.assertAllClose(onp.array([f_expected(init) for init in init_batch],
|
||||
dtype=onp.float32),
|
||||
init_batch = np.array([1., 2., 3.], dtype=np.float32)
|
||||
self.assertAllClose(np.array([f_expected(init) for init in init_batch],
|
||||
dtype=np.float32),
|
||||
api.vmap(f_op)(init_batch), check_dtypes=True)
|
||||
|
||||
def test_error_while_cond_mutation(self):
|
||||
|
@ -17,11 +17,11 @@ from functools import partial
import itertools as it
from unittest import SkipTest

import numpy as onp
import numpy as np
from absl.testing import absltest, parameterized
from jax.interpreters.masking import shape_as_value, ShapeError, \
parse_spec, Poly, Mon
from jax import numpy as np, test_util as jtu, mask, vmap, jit, grad, lax, \
from jax import numpy as jnp, test_util as jtu, mask, vmap, jit, grad, lax, \
shapecheck, api
from jax.config import config
from jax.numpy.lax_numpy import _polymorphic_slice_indices
@ -60,9 +60,9 @@ class ShapesTest(jtu.JaxTestCase):

def test_Poly_equal(self):
assert constant_poly(3) == 3
assert onp.array(3, onp.int64) == constant_poly(3)
assert onp.array(3, onp.int64)[()] == constant_poly(3)
assert not onp.array(3, onp.int64) != constant_poly(3)
assert np.array(3, np.int64) == constant_poly(3)
assert np.array(3, np.int64)[()] == constant_poly(3)
assert not np.array(3, np.int64) != constant_poly(3)
assert constant_poly(4) != 3
assert 3 == constant_poly(3)
assert 4 != constant_poly(3)
@ -109,27 +109,27 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
def test_sum(self):
|
||||
@shapecheck(['(m, n)'], '')
|
||||
def sum(x):
|
||||
return np.sum(x)
|
||||
return jnp.sum(x)
|
||||
|
||||
def test_prod(self):
|
||||
@shapecheck(['(m, n)'], '')
|
||||
def prod(x):
|
||||
return np.prod(x)
|
||||
return jnp.prod(x)
|
||||
|
||||
def test_max(self):
|
||||
@shapecheck(['(m, n)'], '')
|
||||
def prod(x):
|
||||
return np.max(x)
|
||||
return jnp.max(x)
|
||||
|
||||
def test_min(self):
|
||||
@shapecheck(['(m, n)'], '')
|
||||
def prod(x):
|
||||
return np.min(x)
|
||||
return jnp.min(x)
|
||||
|
||||
def test_dot(self):
|
||||
@shapecheck(['(m, n)', 'n'], 'm')
|
||||
def matvec(A, b):
|
||||
return np.dot(A, b)
|
||||
return jnp.dot(A, b)
|
||||
|
||||
def thunk():
|
||||
@shapecheck(['(m, n)', 'n'], 'm')
|
||||
@ -159,12 +159,12 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
return api.device_put(x)
|
||||
|
||||
def test_broadcast_in_dim(self):
|
||||
x = np.zeros(7)
|
||||
x = jnp.zeros(7)
|
||||
|
||||
@shapecheck(['(n,)'], '(3, n, 4)')
|
||||
def broadcast_in_dim(x):
|
||||
return lax.broadcast_in_dim(x, shape=(3, x.shape[0], 4), broadcast_dimensions=(1,))
|
||||
x = np.zeros((7, 1))
|
||||
x = jnp.zeros((7, 1))
|
||||
|
||||
@shapecheck(['(n, 1)'], '(3, n, 4, 1)')
|
||||
def broadcast_in_dim(x):
|
||||
@ -181,17 +181,17 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
# @jit
|
||||
# @grad
|
||||
# def sum_square(x):
|
||||
# return np.sum(x ** 2)
|
||||
# return jnp.sum(x ** 2)
|
||||
|
||||
def test_pad(self):
|
||||
@shapecheck(['n'], '2*n+1')
|
||||
def p(x):
|
||||
return lax.pad(x, np.array(0., x.dtype), [(1, 1, 1)])
|
||||
return lax.pad(x, jnp.array(0., x.dtype), [(1, 1, 1)])
|
||||
|
||||
def test_numpy_pad(self):
|
||||
@shapecheck(['n'], 'n+1')
|
||||
def p(x):
|
||||
return np.pad(x, (0, 1))
|
||||
return jnp.pad(x, (0, 1))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{
|
||||
@ -217,9 +217,9 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
dimension_numbers, lhs_perm, rhs_perm, out_perm):
|
||||
valid = padding == 'VALID'
|
||||
is_strided = strides[0] != 1
|
||||
lhs_shape = '({}, {}, {}, {})'.format(*onp.take(['n', 'i', '2*h' if is_strided else 'h', 'w'], lhs_perm))
|
||||
rhs_shape = '({}, {}, {}, {})'.format(*onp.take(['o', 'i', '2', '3'], rhs_perm))
|
||||
out_shape = '({}, {}, {}, {})'.format(*onp.take([
|
||||
lhs_shape = '({}, {}, {}, {})'.format(*np.take(['n', 'i', '2*h' if is_strided else 'h', 'w'], lhs_perm))
|
||||
rhs_shape = '({}, {}, {}, {})'.format(*np.take(['o', 'i', '2', '3'], rhs_perm))
|
||||
out_shape = '({}, {}, {}, {})'.format(*np.take([
|
||||
'n', 'o', 'h+-1' if valid and not is_strided else 'h',
|
||||
('w+-2' if valid else 'w') if lhs_dilation is None else '2*w+-1'], out_perm))
|
||||
|
||||
@ -275,14 +275,14 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
# https://travis-ci.org/github/google/jax/jobs/682086351
|
||||
@shapecheck(['n'], 'n')
|
||||
def range_like(x):
|
||||
return lax.iota(np.int32, x.shape[0])
|
||||
return lax.iota(jnp.int32, x.shape[0])
|
||||
|
||||
def test_arange(self):
|
||||
raise SkipTest("not yet implemented")
|
||||
# https://travis-ci.org/github/google/jax/jobs/682086351
|
||||
@shapecheck(['n'], 'n')
|
||||
def arange_like(x):
|
||||
return np.arange(x.shape[0], dtype=np.int32)
|
||||
return jnp.arange(x.shape[0], dtype=jnp.int32)
|
||||
|
||||
def test_expit(self):
|
||||
@shapecheck(['n'], 'n')
|
||||
@ -292,10 +292,10 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
def test_reshape(self):
|
||||
@shapecheck(['n, a, b'], 'n, a*b')
|
||||
def flatten(x):
|
||||
return np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2]))
|
||||
return jnp.reshape(x, (x.shape[0], x.shape[1] * x.shape[2]))
|
||||
|
||||
def test_ravel(self):
|
||||
a = np.array(1)
|
||||
a = jnp.array(1)
|
||||
|
||||
@shapecheck(['n'], '')
|
||||
def thunk(n):
|
||||
@ -306,23 +306,23 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
def test_sum(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return np.sum(x)
|
||||
return jnp.sum(x)
|
||||
|
||||
ans = padded_sum([np.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
expected = 8
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = padded_sum([np.array([3, 1, 4, 1, 5])], dict(n=4))
|
||||
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=4))
|
||||
expected = 9
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_sum_vmap(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return np.sum(x)
|
||||
return jnp.sum(x)
|
||||
|
||||
ans = vmap(padded_sum)([np.ones((5, 10))], dict(n=np.arange(5)))
|
||||
expected = onp.array([0, 1, 2, 3, 4])
|
||||
ans = vmap(padded_sum)([jnp.ones((5, 10))], dict(n=jnp.arange(5)))
|
||||
expected = np.array([0, 1, 2, 3, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_add(self):
|
||||
@ -330,13 +330,13 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
def addvecs(x, y):
|
||||
return x + y
|
||||
|
||||
x = np.array([3, 1, 4, 1, 5, 9])
|
||||
y = np.array([2, 6, 5, 3, 5, 8])
|
||||
x = jnp.array([3, 1, 4, 1, 5, 9])
|
||||
y = jnp.array([2, 6, 5, 3, 5, 8])
|
||||
ans = addvecs([x, y], dict(n=3))
|
||||
expected = onp.array([5, 7, 9])
|
||||
expected = np.array([5, 7, 9])
|
||||
self.assertAllClose(ans[:3], expected, check_dtypes=False)
|
||||
|
||||
thunk = lambda: addvecs([np.arange(5), np.arange(6)], dict(n=3))
|
||||
thunk = lambda: addvecs([jnp.arange(5), jnp.arange(6)], dict(n=3))
|
||||
self.assertRaisesRegex(ShapeError, "", thunk)
|
||||
|
||||
def test_scan(self):
|
||||
@ -345,7 +345,7 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
|
||||
return out
|
||||
|
||||
ans = cumsum([np.array([5, 2, 9, 1, 4])], dict(n=3))
|
||||
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
|
||||
expected = 16
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ -355,8 +355,8 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
|
||||
return out
|
||||
|
||||
ans = vmap(cumsum)([np.arange(6).reshape(2, 3)], dict(n=np.array([1, 2])))
|
||||
expected = onp.array([0, 7])
|
||||
ans = vmap(cumsum)([jnp.arange(6).reshape(2, 3)], dict(n=jnp.array([1, 2])))
|
||||
expected = np.array([0, 7])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_scan_jit(self):
|
||||
@ -371,17 +371,17 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
return cumsum(args, shape_env)
|
||||
|
||||
python_should_be_executing = True
|
||||
ans = jit_cumsum([np.array([5, 2, 9, 1, 4])], dict(n=3))
|
||||
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
|
||||
expected = 16
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
python_should_be_executing = False
|
||||
ans = jit_cumsum([np.array([5, 2, 9, 1, 4])], dict(n=4))
|
||||
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=4))
|
||||
expected = 17
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
python_should_be_executing = False
|
||||
ans = jit_cumsum([np.array([5, 2, 9, 1, 4])], dict(n=1))
|
||||
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=1))
|
||||
expected = 5
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ -390,9 +390,9 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
def cat(x, y, z):
|
||||
return lax.concatenate([x, y, z], 0)
|
||||
|
||||
ans = cat([np.array([1, 9]), np.array([2, 4, 9]), np.array([3, 9])],
|
||||
ans = cat([jnp.array([1, 9]), jnp.array([2, 4, 9]), jnp.array([3, 9])],
|
||||
dict(n=1, m=2))
|
||||
expected = onp.array([1, 2, 4, 3])
|
||||
expected = np.array([1, 2, 4, 3])
|
||||
self.assertAllClose(ans[:4], expected, check_dtypes=False)
|
||||
|
||||
def test_dot(self):
|
||||
@ -400,46 +400,46 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
def dot(x, y):
|
||||
return lax.dot(x, y)
|
||||
|
||||
x = onp.arange(6, dtype=onp.float32).reshape((2, 3))
|
||||
y = onp.arange(12, dtype=onp.float32).reshape((3, 4))
|
||||
x = np.arange(6, dtype=np.float32).reshape((2, 3))
|
||||
y = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
ans = dot([x, y], dict(m=2, k=2, n=2))
|
||||
expected = onp.dot(x[:2, :2], y[:2, :2])
|
||||
expected = np.dot(x[:2, :2], y[:2, :2])
|
||||
self.assertAllClose(ans[:2, :2], expected, check_dtypes=False)
|
||||
|
||||
def test_mean(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return np.sum(x) / shape_as_value(x.shape)[0]
|
||||
return jnp.sum(x) / shape_as_value(x.shape)[0]
|
||||
|
||||
ans = padded_sum([np.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
expected = 8 / 3
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_monomorphic(self):
|
||||
@partial(mask, in_shapes=['(_, n)'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return np.sum(x)
|
||||
return jnp.sum(x)
|
||||
|
||||
ans = padded_sum([np.array([[3, 4], [5, 6]])], dict(n=1))
|
||||
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
|
||||
expected = 8
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_monomorphic2(self):
|
||||
@partial(mask, in_shapes=['(_, n)'], out_shape='n')
|
||||
def padded_sum(x):
|
||||
return np.sum(x, axis=0)
|
||||
return jnp.sum(x, axis=0)
|
||||
|
||||
ans = padded_sum([np.array([[3, 4], [5, 6]])], dict(n=2))
|
||||
expected = np.array([8, 10])
|
||||
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=2))
|
||||
expected = jnp.array([8, 10])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_monomorphic3(self):
|
||||
@partial(mask, in_shapes=['(_, n)'], out_shape='_')
|
||||
def padded_sum(x):
|
||||
return np.sum(x, axis=1)
|
||||
return jnp.sum(x, axis=1)
|
||||
|
||||
ans = padded_sum([np.array([[3, 4], [5, 6]])], dict(n=1))
|
||||
expected = np.array([3, 5])
|
||||
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
|
||||
expected = jnp.array([3, 5])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_rnn(self):
|
||||
@ -448,14 +448,14 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
@partial(mask, in_shapes=['(_, _)', '(t, _)'], out_shape='_')
|
||||
def rnn(W, xs):
|
||||
def step(h, x):
|
||||
new_h = np.dot(W, h) + np.dot(W, x)
|
||||
new_h = jnp.dot(W, h) + jnp.dot(W, x)
|
||||
return new_h, ()
|
||||
predicted, _ = lax.scan(step, np.zeros(n), xs)
|
||||
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
|
||||
return predicted
|
||||
|
||||
rng = onp.random.RandomState(0)
|
||||
W = np.eye(n)
|
||||
xs = rng.randn(10, n).astype(np.float_)
|
||||
rng = np.random.RandomState(0)
|
||||
W = jnp.eye(n)
|
||||
xs = rng.randn(10, n).astype(jnp.float_)
|
||||
ans = rnn([W, xs], dict(t=4))
|
||||
expected = xs[:4].sum(0)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
@ -466,24 +466,24 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
@partial(mask, in_shapes=['(_, _)', '(t, _)', '_'], out_shape='')
|
||||
def rnn(W, xs, target):
|
||||
def step(h, x):
|
||||
new_h = np.tanh(np.dot(W, h) + np.dot(W, x))
|
||||
new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
|
||||
return new_h, ()
|
||||
predicted, _ = lax.scan(step, np.zeros(n), xs)
|
||||
return np.sum((predicted - target)**2)
|
||||
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
|
||||
return jnp.sum((predicted - target)**2)
|
||||
|
||||
rng = onp.random.RandomState(0)
|
||||
W = rng.randn(n, n).astype(np.float_)
|
||||
xs = rng.randn(10, n).astype(np.float_)
|
||||
y = rng.randn(n).astype(np.float_)
|
||||
rng = np.random.RandomState(0)
|
||||
W = rng.randn(n, n).astype(jnp.float_)
|
||||
xs = rng.randn(10, n).astype(jnp.float_)
|
||||
y = rng.randn(n).astype(jnp.float_)
|
||||
|
||||
ans = grad(lambda W: rnn([W, xs, y], dict(t=4)))(W)
|
||||
|
||||
def rnn_reference(W, xs, target):
|
||||
h = np.zeros(n)
|
||||
h = jnp.zeros(n)
|
||||
for x in xs:
|
||||
h = np.tanh(np.dot(W, h) + np.dot(W, x))
|
||||
h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
|
||||
predicted = h
|
||||
return np.sum((predicted - target)**2)
|
||||
return jnp.sum((predicted - target)**2)
|
||||
|
||||
expected = grad(lambda W: rnn_reference(W, xs[:4], y))(W)
|
||||
|
||||
@ -495,27 +495,27 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
@partial(mask, in_shapes=('(_, _)', '(t, _)', '_'), out_shape='')
|
||||
def rnn(W, xs, target):
|
||||
def step(h, x):
|
||||
new_h = np.tanh(np.dot(W, h) + np.dot(W, x))
|
||||
new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
|
||||
return new_h, ()
|
||||
predicted, _ = lax.scan(step, np.zeros(n), xs)
|
||||
return np.sum((predicted - target)**2)
|
||||
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
|
||||
return jnp.sum((predicted - target)**2)
|
||||
|
||||
rng = onp.random.RandomState(0)
|
||||
W = rng.randn(n, n).astype(np.float_)
|
||||
seqs = rng.randn(3, 10, n).astype(np.float_)
|
||||
ts = np.array([2, 5, 4])
|
||||
rng = np.random.RandomState(0)
|
||||
W = rng.randn(n, n).astype(jnp.float_)
|
||||
seqs = rng.randn(3, 10, n).astype(jnp.float_)
|
||||
ts = jnp.array([2, 5, 4])
|
||||
ys = rng.randn(3, n)
|
||||
|
||||
ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W)
|
||||
|
||||
def rnn_reference(W, seqs, targets):
|
||||
total_loss = np.array(0, np.float_)
|
||||
total_loss = jnp.array(0, jnp.float_)
|
||||
for xs, target in zip(seqs, targets):
|
||||
h = np.zeros(n)
|
||||
h = jnp.zeros(n)
|
||||
for x in xs:
|
||||
h = np.tanh(np.dot(W, h) + np.dot(W, x))
|
||||
h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
|
||||
predicted = h
|
||||
total_loss = total_loss + np.sum((predicted - target)**2)
|
||||
total_loss = total_loss + jnp.sum((predicted - target)**2)
|
||||
return total_loss
|
||||
|
||||
seqs_ = [xs[:t] for xs, t in zip(seqs, ts)]
|
||||
@ -530,7 +530,7 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return np.sum(x)
|
||||
return jnp.sum(x)
|
||||
|
||||
batched_sum = vmap(padded_sum)
|
||||
|
||||
@ -538,10 +538,10 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
def fun(x, ns):
|
||||
return batched_sum([x], dict(n=ns)).sum()
|
||||
|
||||
x = np.array([[3, 1, 4, 1],
|
||||
x = jnp.array([[3, 1, 4, 1],
|
||||
[5, 9, 2, 6],
|
||||
[5, 3, 5, 8]])
|
||||
ns = np.array([2, 3, 2])
|
||||
ns = jnp.array([2, 3, 2])
|
||||
ans = fun([x, ns], dict(m=2))
|
||||
expected = 3+1 + 5+9+2
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
@ -553,8 +553,8 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
def padded_add(x):
|
||||
return x + lax.iota(x.shape[0])
|
||||
|
||||
ans = padded_add([np.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
expected = onp.array([3, 2, 6])
|
||||
ans = padded_add([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
expected = np.array([3, 2, 6])
|
||||
self.assertAllClose(ans[:3], expected, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -577,8 +577,8 @@ class MaskingTest(jtu.JaxTestCase):

def test_slice_oob_indexing(self):
# https://github.com/google/jax/issues/2245
self.assertAllClose(np.ones(5), np.ones(5)[:10], check_dtypes=True)
self.assertAllClose(np.ones(5), np.ones(5)[-10:], check_dtypes=True)
self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10], check_dtypes=True)
self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:], check_dtypes=True)

if __name__ == '__main__':
absltest.main()
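
Editorial aside (not part of the commit): a minimal, hypothetical sketch of the comparison pattern these tests rely on under the new spelling, checking a jax.numpy result against the host-NumPy reference.

import numpy as np
import jax.numpy as jnp

x = np.random.RandomState(0).randn(4, 4).astype(np.float32)
# jnp.matmul runs through JAX/XLA; np.matmul is the host-NumPy reference.
np.testing.assert_allclose(np.matmul(x, x.T), jnp.matmul(x, x.T), rtol=1e-5)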
@ -18,13 +18,13 @@ from functools import partial
from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp
import numpy as np
import numpy.random as npr
from unittest import SkipTest

from jax import api
from jax import test_util as jtu
from jax import numpy as np
from jax import numpy as jnp

from jax.config import config
config.parse_flags_with_absl()
@ -46,10 +46,10 @@ class MultiBackendTest(jtu.JaxTestCase):
raise SkipTest("Backend is not CPU or the device under test")
@partial(api.jit, backend=backend)
def fun(x, y):
return np.matmul(x, y)
return jnp.matmul(x, y)
x = npr.uniform(size=(10,10))
y = npr.uniform(size=(10,10))
z_host = onp.matmul(x, y)
z_host = np.matmul(x, y)
z = fun(x, y)
self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2)
correct_platform = backend if backend else jtu.device_under_test()
@ -67,11 +67,11 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
def fun(x, y):
|
||||
@partial(api.jit, backend=inner)
|
||||
def infun(x, y):
|
||||
return np.matmul(x, y)
|
||||
return infun(x, y) + np.ones_like(x)
|
||||
return jnp.matmul(x, y)
|
||||
return infun(x, y) + jnp.ones_like(x)
|
||||
x = npr.uniform(size=(10,10))
|
||||
y = npr.uniform(size=(10,10))
|
||||
z_host = onp.matmul(x, y) + onp.ones_like(x)
|
||||
z_host = np.matmul(x, y) + np.ones_like(x)
|
||||
z = fun(x, y)
|
||||
self.assertAllClose(z, z_host, check_dtypes=True, rtol=1e-2)
|
||||
correct_platform = outer if outer else jtu.device_under_test()
|
||||
@ -95,8 +95,8 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
def fun(x, y):
|
||||
@partial(api.jit, backend=inner)
|
||||
def infun(x, y):
|
||||
return np.matmul(x, y)
|
||||
return infun(x, y) + np.ones_like(x)
|
||||
return jnp.matmul(x, y)
|
||||
return infun(x, y) + jnp.ones_like(x)
|
||||
x = npr.uniform(size=(10,10))
|
||||
y = npr.uniform(size=(10,10))
|
||||
self.assertRaises(ValueError, lambda: fun(x, y))
|
||||
@ -111,11 +111,11 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
raise SkipTest("Backend is not CPU or the device under test")
|
||||
@partial(api.jit, backend=backend)
|
||||
def fun(x, y):
|
||||
return np.matmul(x, y)
|
||||
return jnp.matmul(x, y)
|
||||
x = npr.uniform(size=(10,10))
|
||||
y = npr.uniform(size=(10,10))
|
||||
z = fun(x, y)
|
||||
w = np.sin(z)
|
||||
w = jnp.sin(z)
|
||||
self.assertEqual(z.device_buffer.platform(), backend)
|
||||
self.assertEqual(w.device_buffer.platform(), backend)
|
||||
|
||||
@ -123,13 +123,13 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
def testJitCpu(self):
|
||||
@partial(api.jit, backend='cpu')
|
||||
def get_arr(scale):
|
||||
return scale + np.ones((2, 2))
|
||||
return scale + jnp.ones((2, 2))
|
||||
|
||||
x = get_arr(0.1)
|
||||
|
||||
a = x / x.shape[0]
|
||||
b = x + np.ones_like(x)
|
||||
c = x + np.eye(2)
|
||||
b = x + jnp.ones_like(x)
|
||||
c = x + jnp.eye(2)
|
||||
|
||||
self.assertEqual(a.device_buffer.device(), api.devices('cpu')[0])
|
||||
self.assertEqual(b.device_buffer.device(), api.devices('cpu')[0])
|
||||
@ -138,7 +138,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
|
||||
def test_closed_over_values_device_placement(self):
|
||||
# see https://github.com/google/jax/issues/1431
|
||||
def f(): return np.add(3., 4.)
|
||||
def f(): return jnp.add(3., 4.)
|
||||
self.assertNotEqual(api.jit(f)().device_buffer.device(),
|
||||
api.devices('cpu')[0])
|
||||
self.assertEqual(api.jit(f, backend='cpu')().device_buffer.device(),
|
||||
@ -156,7 +156,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
data_on_cpu = api.device_put(1, device=cpus[0])
|
||||
self.assertEqual(data_on_cpu.device_buffer.device(), cpus[0])
|
||||
|
||||
def my_sin(x): return np.sin(x)
|
||||
def my_sin(x): return jnp.sin(x)
|
||||
# jit without any device spec follows the data
|
||||
result1 = api.jit(my_sin)(2)
|
||||
self.assertEqual(result1.device_buffer.device(), default_dev)
|
||||
@ -176,7 +176,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
# https://github.com/google/jax/issues/2905
|
||||
cpus = api.devices("cpu")
|
||||
|
||||
x = api.device_put(onp.ones(2), cpus[0])
|
||||
x = api.device_put(np.ones(2), cpus[0])
|
||||
y = x[0]
|
||||
self.assertEqual(y.device_buffer.device(), cpus[0])
|
||||
|
||||
@ -185,7 +185,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
# https://github.com/google/jax/issues/2905
|
||||
cpus = api.devices("cpu")
|
||||
|
||||
x = api.device_put(onp.ones(2), cpus[0])
|
||||
x = api.device_put(np.ones(2), cpus[0])
|
||||
y = x.sum()
|
||||
self.assertEqual(y.device_buffer.device(), cpus[0])
|
||||
|
||||
|
@ -13,23 +13,23 @@
# limitations under the License.


import numpy as onp
import numpy as np

from absl.testing import absltest
from absl.testing import parameterized

from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu, jit, partial

from jax.config import config
config.parse_flags_with_absl()


float_dtypes = [onp.float32, onp.float64]
float_dtypes = [np.float32, np.float64]
# implementation casts to complex64.
complex_dtypes = [onp.complex64]
complex_dtypes = [np.complex64]
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = [onp.int32, onp.int64]
int_dtypes = [np.int32, np.int64]
real_dtypes = float_dtypes + int_dtypes
all_dtypes = real_dtypes + complex_dtypes

@ -51,17 +51,17 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
for leading in [0, 1, 2, 3, 5, 7, 10]
|
||||
for trailing in [0, 1, 2, 3, 5, 7, 10]))
|
||||
def testRoots(self, dtype, rng_factory, length, leading, trailing):
|
||||
rng = rng_factory(onp.random.RandomState(0))
|
||||
rng = rng_factory(np.random.RandomState(0))
|
||||
|
||||
def args_maker():
|
||||
p = rng((length,), dtype)
|
||||
return np.concatenate(
|
||||
[np.zeros(leading, p.dtype), p, np.zeros(trailing, p.dtype)]),
|
||||
return jnp.concatenate(
|
||||
[jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)]),
|
||||
|
||||
# order may differ (np.sort doesn't deal with complex numbers)
|
||||
np_fn = lambda arg: onp.sort(np.roots(arg))
|
||||
onp_fn = lambda arg: onp.sort(onp.roots(arg))
|
||||
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False,
|
||||
# order may differ (jnp.sort doesn't deal with complex numbers)
|
||||
np_fn = lambda arg: np.sort(jnp.roots(arg))
|
||||
np_fn = lambda arg: np.sort(np.roots(arg))
|
||||
self._CheckAgainstNumpy(np_fn, np_fn, args_maker, check_dtypes=False,
|
||||
tol=3e-6)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -74,20 +74,20 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
for length in [0, 1, 3, 10]
|
||||
for trailing in [0, 1, 3, 7]))
|
||||
def testRootsNostrip(self, length, dtype, rng_factory, trailing):
|
||||
rng = rng_factory(onp.random.RandomState(0))
|
||||
rng = rng_factory(np.random.RandomState(0))
|
||||
|
||||
def args_maker():
|
||||
p = rng((length,), dtype)
|
||||
if length != 0:
|
||||
return np.concatenate([p, np.zeros(trailing, p.dtype)]),
|
||||
return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
|
||||
else:
|
||||
# adding trailing would make input invalid (start with zeros)
|
||||
return p,
|
||||
|
||||
# order may differ (np.sort doesn't deal with complex numbers)
|
||||
np_fn = lambda arg: onp.sort(np.roots(arg, strip_zeros=False))
|
||||
onp_fn = lambda arg: onp.sort(onp.roots(arg))
|
||||
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
|
||||
# order may differ (jnp.sort doesn't deal with complex numbers)
|
||||
np_fn = lambda arg: np.sort(jnp.roots(arg, strip_zeros=False))
|
||||
np_fn = lambda arg: np.sort(np.roots(arg))
|
||||
self._CheckAgainstNumpy(np_fn, np_fn, args_maker,
|
||||
check_dtypes=False, tol=1e-6)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -103,23 +103,23 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
# for GPU/TPU.
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testRootsJit(self, length, dtype, rng_factory, trailing):
|
||||
rng = rng_factory(onp.random.RandomState(0))
|
||||
rng = rng_factory(np.random.RandomState(0))
|
||||
|
||||
def args_maker():
|
||||
p = rng((length,), dtype)
|
||||
if length != 0:
|
||||
return np.concatenate([p, np.zeros(trailing, p.dtype)]),
|
||||
return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
|
||||
else:
|
||||
# adding trailing would make input invalid (start with zeros)
|
||||
return p,
|
||||
|
||||
# order may differ (np.sort doesn't deal with complex numbers)
|
||||
roots_compiled = jit(partial(np.roots, strip_zeros=False))
|
||||
np_fn = lambda arg: onp.sort(roots_compiled(arg))
|
||||
onp_fn = lambda arg: onp.sort(onp.roots(arg))
|
||||
# order may differ (jnp.sort doesn't deal with complex numbers)
|
||||
roots_compiled = jit(partial(jnp.roots, strip_zeros=False))
|
||||
np_fn = lambda arg: np.sort(roots_compiled(arg))
|
||||
np_fn = lambda arg: np.sort(np.roots(arg))
|
||||
# Using strip_zeros=False makes the algorithm less efficient
# and leads to slightly different values compared to numpy
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
|
||||
self._CheckAgainstNumpy(np_fn, np_fn, args_maker,
|
||||
check_dtypes=False, tol=1e-6)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -133,19 +133,19 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
for zeros in [1, 2, 5]
|
||||
for nonzeros in [0, 3]))
|
||||
def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory):
|
||||
rng = rng_factory(onp.random.RandomState(0))
|
||||
rng = rng_factory(np.random.RandomState(0))
|
||||
|
||||
# The polynomial coefficients here start with zero and would have to
|
||||
# be stripped before computing eigenvalues of the companion matrix.
|
||||
# Setting strip_zeros=False skips this check,
|
||||
# allowing jit transformation but yielding nan's for these inputs.
|
||||
p = np.concatenate([np.zeros(zeros, dtype), rng((nonzeros,), dtype)])
|
||||
p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])
|
||||
|
||||
if p.size == 1:
|
||||
# polynomial = const has no roots
|
||||
self.assertTrue(np.roots(p, strip_zeros=False).size == 0)
|
||||
self.assertTrue(jnp.roots(p, strip_zeros=False).size == 0)
|
||||
else:
|
||||
self.assertTrue(np.any(np.isnan(np.roots(p, strip_zeros=False))))
|
||||
self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p, strip_zeros=False))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -19,7 +19,7 @@ from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp
import numpy as np
import scipy.linalg
import scipy.special
import scipy.stats
@ -28,7 +28,7 @@ from jax import api
from jax import core
from jax import grad
from jax import lax
from jax import numpy as np
from jax import numpy as jnp
from jax import random
from jax import test_util as jtu
from jax import vmap
@ -45,9 +45,9 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
nitems = len(samples)
|
||||
nbins = 2 ** nbits
|
||||
nexpected = nbins * (1 - ((nbins - 1) / nbins) ** nitems)
|
||||
ncollisions = len(onp.unique(samples))
|
||||
ncollisions = len(np.unique(samples))
|
||||
sq_percent_deviation = ((ncollisions - nexpected) / nexpected) ** 2
|
||||
self.assertLess(sq_percent_deviation, 1 / onp.sqrt(nexpected * fail_prob))
|
||||
self.assertLess(sq_percent_deviation, 1 / np.sqrt(nexpected * fail_prob))
|
||||
|
||||
def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
|
||||
fail_prob = 0.01 # conservative bound on statistical fail prob by Kolmo CDF
|
||||
@ -55,7 +55,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
def _CheckChiSquared(self, samples, pmf):
|
||||
alpha = 0.01 # significance level, threshold for p-value
|
||||
values, actual_freq = onp.unique(samples, return_counts=True)
|
||||
values, actual_freq = np.unique(samples, return_counts=True)
|
||||
expected_freq = pmf(values) * samples.size
|
||||
# per scipy: "A typical rule is that all of the observed and expected
|
||||
# frequencies should be at least 5."
|
||||
@ -71,16 +71,16 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
f'{expected_freq[valid]}\n{actual_freq[valid]}')
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
|
||||
for dtype in [np.float32, np.float64]))
|
||||
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
|
||||
if not FLAGS.jax_enable_x64 and np.issubdtype(dtype, onp.float64):
|
||||
if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64):
|
||||
raise SkipTest("can't test float64 agreement")
|
||||
|
||||
bits_dtype = onp.uint32 if np.finfo(dtype).bits == 32 else onp.uint64
|
||||
numpy_bits = onp.array(1., dtype).view(bits_dtype)
|
||||
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
|
||||
numpy_bits = np.array(1., dtype).view(bits_dtype)
|
||||
xla_bits = api.jit(
|
||||
lambda: lax.bitcast_convert_type(onp.array(1., dtype), bits_dtype))()
|
||||
lambda: lax.bitcast_convert_type(np.array(1., dtype), bits_dtype))()
|
||||
self.assertEqual(numpy_bits, xla_bits)
|
||||
|
||||
def testThreefry2x32(self):
|
||||
@ -91,34 +91,34 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
return tuple([hex(x.copy()).rstrip("L") for x in result])
|
||||
|
||||
expected = ("0x6b200159", "0x99ba4efe")
|
||||
result = random.threefry_2x32(onp.uint32([0, 0]), onp.uint32([0, 0]))
|
||||
result = random.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
|
||||
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
expected = ("0x1cb996fc", "0xbb002be7")
|
||||
result = random.threefry_2x32(onp.uint32([-1, -1]), onp.uint32([-1, -1]))
|
||||
result = random.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1]))
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
expected = ("0xc4923a9c", "0x483df7a0")
|
||||
result = random.threefry_2x32(
|
||||
onp.uint32([0x13198a2e, 0x03707344]),
|
||||
onp.uint32([0x243f6a88, 0x85a308d3]))
|
||||
np.uint32([0x13198a2e, 0x03707344]),
|
||||
np.uint32([0x243f6a88, 0x85a308d3]))
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
def testThreefry2x32Large(self):
|
||||
n = 10000000
|
||||
result = random.threefry_2x32(
|
||||
(onp.uint32(0x13198a2e), onp.uint32(0x03707344)),
|
||||
np.concatenate([
|
||||
np.full((n,), 0x243f6a88, np.uint32),
|
||||
np.full((n,), 0x85a308d3, np.uint32)
|
||||
(np.uint32(0x13198a2e), np.uint32(0x03707344)),
|
||||
jnp.concatenate([
|
||||
jnp.full((n,), 0x243f6a88, jnp.uint32),
|
||||
jnp.full((n,), 0x85a308d3, jnp.uint32)
|
||||
]))
|
||||
onp.testing.assert_equal(result[:n], onp.full((n,), 0xc4923a9c, dtype=onp.uint32))
|
||||
onp.testing.assert_equal(result[n:], onp.full((n,), 0x483df7a0, dtype=onp.uint32))
|
||||
np.testing.assert_equal(result[:n], np.full((n,), 0xc4923a9c, dtype=np.uint32))
|
||||
np.testing.assert_equal(result[n:], np.full((n,), 0x483df7a0, dtype=np.uint32))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
||||
for dtype in [onp.float32, onp.float64]))
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
|
||||
for dtype in [np.float32, np.float64]))
|
||||
def testRngUniform(self, dtype):
|
||||
key = random.PRNGKey(0)
|
||||
rand = lambda key: random.uniform(key, (10000,), dtype)
|
||||
@ -128,12 +128,12 @@ class LaxRandomTest(jtu.JaxTestCase):
compiled_samples = crand(key)

for samples in [uncompiled_samples, compiled_samples]:
self._CheckCollisions(samples, np.finfo(dtype).nmant)
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.int32, onp.int64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.int32, np.int64]))
def testRngRandint(self, dtype):
lo = 5
hi = 10
@ -146,12 +146,12 @@ class LaxRandomTest(jtu.JaxTestCase):
compiled_samples = crand(key)

for samples in [uncompiled_samples, compiled_samples]:
self.assertTrue(onp.all(lo <= samples))
self.assertTrue(onp.all(samples < hi))
self.assertTrue(np.all(lo <= samples))
self.assertTrue(np.all(samples < hi))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testNormal(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.normal(key, (10000,), dtype)
@ -164,11 +164,11 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64, onp.int32, onp.int64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64, np.int32, np.int64]))
def testShuffle(self, dtype):
key = random.PRNGKey(0)
x = onp.arange(100).astype(dtype)
x = np.arange(100).astype(dtype)
rand = lambda key: random.shuffle(key, x)
crand = api.jit(rand)

@ -178,17 +178,17 @@ class LaxRandomTest(jtu.JaxTestCase):
perm2 = crand(key)

self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1), x, check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"dtype": onp.dtype(dtype).name, "shape": shape}
for dtype in [onp.float32, onp.float64, onp.int32, onp.int64]
"dtype": np.dtype(dtype).name, "shape": shape}
for dtype in [np.float32, np.float64, np.int32, np.int64]
for shape in [100, (10, 10), (10, 5, 2)]))
def testPermutationArray(self, dtype, shape):
key = random.PRNGKey(0)
x = np.arange(np.prod(shape)).reshape(shape).astype(dtype)
x = jnp.arange(jnp.prod(shape)).reshape(shape).astype(dtype)
rand = lambda key: random.permutation(key, x)
crand = api.jit(rand)

@ -196,10 +196,10 @@ class LaxRandomTest(jtu.JaxTestCase):
perm2 = crand(key)

self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertAllClose(onp.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
self.assertArraysAllClose(
x, np.arange(np.prod(shape)).reshape(shape).astype(dtype),
x, jnp.arange(jnp.prod(shape)).reshape(shape).astype(dtype),
check_dtypes=True)

def testPermutationInteger(self):
@ -213,8 +213,8 @@ class LaxRandomTest(jtu.JaxTestCase):

self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertEqual(perm1.dtype, perm2.dtype)
self.assertFalse(onp.all(perm1 == onp.arange(100))) # seems unlikely!
self.assertAllClose(onp.sort(perm1), onp.arange(100), check_dtypes=False)
self.assertFalse(np.all(perm1 == np.arange(100))) # seems unlikely!
self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)

def testPermutationErrors(self):
key = random.PRNGKey(0)
@ -225,12 +225,12 @@ class LaxRandomTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_{}".format(p, dtype),
"p": p, "dtype": onp.dtype(dtype).name}
"p": p, "dtype": np.dtype(dtype).name}
for p in [0.1, 0.5, 0.9]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testBernoulli(self, p, dtype):
key = random.PRNGKey(0)
p = onp.array(p, dtype=dtype)
p = np.array(p, dtype=dtype)
rand = lambda key, p: random.bernoulli(key, p, (10000,))
crand = api.jit(rand)

@ -242,7 +242,7 @@ class LaxRandomTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_{}_{}".format(p, dtype, sample_shape),
"p": p, "axis": axis, "dtype": onp.dtype(dtype).name, 'sample_shape': sample_shape}
"p": p, "axis": axis, "dtype": np.dtype(dtype).name, 'sample_shape': sample_shape}
for (p, axis) in [
([.25] * 4, -1),
([.1, .2, .3, .4], -1),
@ -250,12 +250,12 @@ class LaxRandomTest(jtu.JaxTestCase):
([[.5, .1], [.5, .9]], 0),
]
for sample_shape in [(10000,), (5000, 2)]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testCategorical(self, p, axis, dtype, sample_shape):
key = random.PRNGKey(0)
p = onp.array(p, dtype=dtype)
logits = onp.log(p) - 42 # test unnormalized
out_shape = tuple(onp.delete(logits.shape, axis))
p = np.array(p, dtype=dtype)
logits = np.log(p) - 42 # test unnormalized
out_shape = tuple(np.delete(logits.shape, axis))
shape = sample_shape + out_shape
rand = lambda key, p: random.categorical(key, logits, shape=shape, axis=axis)
crand = api.jit(rand)
@ -268,9 +268,9 @@ class LaxRandomTest(jtu.JaxTestCase):

for samples in [uncompiled_samples, compiled_samples]:
assert samples.shape == shape
samples = np.reshape(samples, (10000,) + out_shape)
samples = jnp.reshape(samples, (10000,) + out_shape)
if len(p.shape[:-1]) > 0:
ps = onp.transpose(p, (1, 0)) if axis == 0 else p
ps = np.transpose(p, (1, 0)) if axis == 0 else p
for cat_samples, cat_p in zip(samples.transpose(), ps):
self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
else:
@ -278,15 +278,15 @@ class LaxRandomTest(jtu.JaxTestCase):

def testBernoulliShape(self):
key = random.PRNGKey(0)
x = random.bernoulli(key, onp.array([0.2, 0.3]), shape=(3, 2))
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_b={}_{}".format(a, b, dtype),
"a": a, "b": b, "dtype": onp.dtype(dtype).name}
"a": a, "b": b, "dtype": np.dtype(dtype).name}
for a in [0.2, 5.]
for b in [0.2, 5.]
for dtype in [onp.float64])) # NOTE: KS test fails with float32
for dtype in [np.float64])) # NOTE: KS test fails with float32
def testBeta(self, a, b, dtype):
if not FLAGS.jax_enable_x64:
raise SkipTest("skip test except on X64")
@ -301,8 +301,8 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testCauchy(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
@ -316,11 +316,11 @@ class LaxRandomTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_alpha={}_{}".format(alpha, dtype),
"alpha": alpha, "dtype": onp.dtype(dtype).name}
"alpha": alpha, "dtype": np.dtype(dtype).name}
for alpha in [
onp.array([0.2, 1., 5.]),
np.array([0.2, 1., 5.]),
]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testDirichlet(self, alpha, dtype):
key = random.PRNGKey(0)
rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
@ -330,14 +330,14 @@ class LaxRandomTest(jtu.JaxTestCase):
compiled_samples = crand(key, alpha)

for samples in [uncompiled_samples, compiled_samples]:
self.assertAllClose(samples.sum(-1), onp.ones(10000, dtype=dtype), check_dtypes=True)
self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype), check_dtypes=True)
alpha_sum = sum(alpha)
for i, a in enumerate(alpha):
self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testExponential(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
@ -351,9 +351,9 @@ class LaxRandomTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_{}".format(a, dtype),
"a": a, "dtype": onp.dtype(dtype).name}
"a": a, "dtype": np.dtype(dtype).name}
for a in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testGamma(self, a, dtype):
key = random.PRNGKey(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
@ -367,7 +367,7 @@ class LaxRandomTest(jtu.JaxTestCase):

def testGammaShape(self):
key = random.PRNGKey(0)
x = random.gamma(key, onp.array([0.2, 0.3]), shape=(3, 2))
x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
@ -375,11 +375,11 @@ class LaxRandomTest(jtu.JaxTestCase):
for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
def testGammaGrad(self, alpha):
rng = random.PRNGKey(0)
alphas = onp.full((100,), alpha)
alphas = np.full((100,), alpha)
z = random.gamma(rng, alphas)
actual_grad = api.grad(lambda x: random.gamma(rng, x).sum())(alphas)

eps = 0.01 * alpha / (1.0 + onp.sqrt(alpha))
eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
- scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
pdf = scipy.stats.gamma.pdf(z, alpha)
@ -391,17 +391,17 @@ class LaxRandomTest(jtu.JaxTestCase):
def testGammaGradType(self):
# Regression test for https://github.com/google/jax/issues/2130
key = random.PRNGKey(0)
a = np.array(1., dtype=np.float32)
b = np.array(3., dtype=np.float32)
f = lambda x, y: random.gamma(key=key, a=x, dtype=np.float32) / y
a = jnp.array(1., dtype=jnp.float32)
b = jnp.array(3., dtype=jnp.float32)
f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
# Should not crash with a type error.
api.vjp(f, a, b)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lam={}_{}".format(lam, dtype),
"lam": lam, "dtype": onp.dtype(dtype).name}
"lam": lam, "dtype": np.dtype(dtype).name}
for lam in [0.5, 3, 9, 11, 50, 500]
for dtype in [onp.int32, onp.int64]))
for dtype in [np.int32, np.int64]))
def testPoisson(self, lam, dtype):
key = random.PRNGKey(0)
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
@ -419,19 +419,19 @@ class LaxRandomTest(jtu.JaxTestCase):

def testPoissonBatched(self):
key = random.PRNGKey(0)
lam = np.concatenate([2 * np.ones(10000), 20 * np.ones(10000)])
lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
samples = random.poisson(key, lam, shape=(20000,))
self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)

def testPoissonShape(self):
key = random.PRNGKey(0)
x = random.poisson(key, onp.array([2.0, 20.0]), shape=(3, 2))
x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testGumbel(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.gumbel(key, (10000,), dtype)
@ -444,8 +444,8 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testLaplace(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.laplace(key, (10000,), dtype)
@ -458,8 +458,8 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def testLogistic(self, dtype):
key = random.PRNGKey(0)
rand = lambda key: random.logistic(key, (10000,), dtype)
@ -473,9 +473,9 @@ class LaxRandomTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_b={}_{}".format(b, dtype),
"b": b, "dtype": onp.dtype(dtype).name}
"b": b, "dtype": np.dtype(dtype).name}
for b in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
def testPareto(self, b, dtype):
key = random.PRNGKey(0)
rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
@ -489,14 +489,14 @@ class LaxRandomTest(jtu.JaxTestCase):

def testParetoShape(self):
key = random.PRNGKey(0)
x = random.pareto(key, onp.array([0.2, 0.3]), shape=(3, 2))
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_df={}_{}".format(df, dtype),
"df": df, "dtype": onp.dtype(dtype).name}
"df": df, "dtype": np.dtype(dtype).name}
for df in [0.1, 1., 10.]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
def testT(self, df, dtype):
key = random.PRNGKey(0)
@ -510,29 +510,29 @@ class LaxRandomTest(jtu.JaxTestCase):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}D_{}".format(dim, onp.dtype(dtype).name),
{"testcase_name": "_{}D_{}".format(dim, np.dtype(dtype).name),
"dim": dim, "dtype": dtype}
for dim in [1, 3, 5]
for dtype in [onp.float32, onp.float64]))
for dtype in [np.float32, np.float64]))
@jtu.skip_on_mac_linalg_bug()
def testMultivariateNormal(self, dim, dtype):
r = onp.random.RandomState(dim)
r = np.random.RandomState(dim)
mean = r.randn(dim)
cov_factor = r.randn(dim, dim)
cov = onp.dot(cov_factor, cov_factor.T) + dim * onp.eye(dim)
cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

key = random.PRNGKey(0)
rand = partial(random.multivariate_normal, mean=mean, cov=cov,
shape=(10000,))
crand = api.jit(rand)

uncompiled_samples = onp.asarray(rand(key), onp.float64)
compiled_samples = onp.asarray(crand(key), onp.float64)
uncompiled_samples = np.asarray(rand(key), np.float64)
compiled_samples = np.asarray(crand(key), np.float64)

inv_scale = scipy.linalg.lapack.dtrtri(onp.linalg.cholesky(cov), lower=True)[0]
inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0]
for samples in [uncompiled_samples, compiled_samples]:
centered = samples - mean
whitened = onp.einsum('nj,ij->ni', centered, inv_scale)
whitened = np.einsum('nj,ij->ni', centered, inv_scale)

# This is a quick-and-dirty multivariate normality check that tests that a
# uniform mixture of the marginals along the covariance matrix's
@ -543,25 +543,25 @@ class LaxRandomTest(jtu.JaxTestCase):
def testMultivariateNormalCovariance(self):
# test code based on https://github.com/google/jax/issues/1869
N = 100000
cov = np.array([[ 0.19, 0.00, -0.13, 0.00],
cov = jnp.array([[ 0.19, 0.00, -0.13, 0.00],
[ 0.00, 0.29, 0.00, -0.23],
[ -0.13, 0.00, 0.39, 0.00],
[ 0.00, -0.23, 0.00, 0.49]])
mean = np.zeros(4)
mean = jnp.zeros(4)

out_onp = onp.random.RandomState(0).multivariate_normal(mean, cov, N)
out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N)

key = random.PRNGKey(0)
out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))

var_onp = out_onp.var(axis=0)
var_np = out_np.var(axis=0)
var_jnp = out_jnp.var(axis=0)
self.assertAllClose(var_onp, var_jnp, rtol=1e-2, atol=1e-2,
self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2,
check_dtypes=False)

var_onp = onp.cov(out_onp, rowvar=False)
var_jnp = onp.cov(out_jnp, rowvar=False)
self.assertAllClose(var_onp, var_jnp, rtol=1e-2, atol=1e-2,
var_np = np.cov(out_np, rowvar=False)
var_jnp = np.cov(out_jnp, rowvar=False)
self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2,
check_dtypes=False)

def testIssue222(self):
@ -571,7 +571,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testFoldIn(self):
key = random.PRNGKey(0)
keys = [random.fold_in(key, i) for i in range(10)]
assert onp.unique(onp.ravel(keys)).shape == (20,)
assert np.unique(np.ravel(keys)).shape == (20,)

def testStaticShapeErrors(self):
if config.read("jax_disable_jit"):
@ -582,9 +582,9 @@ class LaxRandomTest(jtu.JaxTestCase):
key = random.PRNGKey(seed)
W = random.normal(key, (d, n)) / sigma
w = random.normal(key, (d, )) / sigma
b = 2 * np.pi * random.uniform(key, (d, ))
b = 2 * jnp.pi * random.uniform(key, (d, ))

phi = lambda x, t: np.sqrt(2.0 / d) * np.cos(np.matmul(W, x) + w*t + b)
phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos(jnp.matmul(W, x) + w*t + b)
return phi

self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*',
@ -594,21 +594,21 @@ class LaxRandomTest(jtu.JaxTestCase):
key = random.PRNGKey(0)
w = random.normal(key, ())
if FLAGS.jax_enable_x64:
self.assertEqual(onp.result_type(w), onp.float64)
self.assertEqual(np.result_type(w), np.float64)
else:
self.assertEqual(onp.result_type(w), onp.float32)
self.assertEqual(np.result_type(w), np.float32)

def testIssue1789(self):
def f(x):
return random.gamma(random.PRNGKey(0), x)

grad(lambda x: np.sum(vmap(f)(x)))(np.ones(2))
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))

def testNoOpByOpUnderHash(self):
def fail(*args, **kwargs): assert False
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
try:
out = random.threefry_2x32(onp.zeros(2, onp.uint32), onp.arange(10, dtype=onp.uint32))
out = random.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
finally:
xla.apply_primitive = apply_primitive

@ -620,21 +620,21 @@ class LaxRandomTest(jtu.JaxTestCase):
if FLAGS.jax_enable_x64:
self.assertAllClose(
random.randint(k, (3, 3), 0, 8),
onp.array([[7, 2, 6],
np.array([[7, 2, 6],
[2, 1, 0],
[6, 7, 7]], dtype='int64'),
check_dtypes=True)
else:
self.assertAllClose(
random.randint(k, (3, 3), 0, 8),
onp.array([[2, 1, 3],
np.array([[2, 1, 3],
[6, 1, 5],
[6, 3, 4]], dtype='int32'),
check_dtypes=True)

self.assertAllClose(
random.split(k, 4),
onp.array([[2285895361, 1501764800],
np.array([[2285895361, 1501764800],
[1518642379, 4090693311],
[ 433833334, 4221794875],
[ 839183663, 3740430601]], dtype='uint32'),
@ -642,7 +642,7 @@ class LaxRandomTest(jtu.JaxTestCase):

self.assertAllClose(
random.fold_in(k, 4),
onp.array([2285895361, 433833334], dtype='uint32'),
np.array([2285895361, 433833334], dtype='uint32'),
check_dtypes=True)

@ -17,9 +17,7 @@
from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp

from jax import numpy as np
from jax import numpy as jnp
from jax import test_util as jtu
from jax import random
from jax.experimental.vectorize import vectorize
@ -27,24 +25,24 @@ from jax.experimental.vectorize import vectorize
from jax.config import config
config.parse_flags_with_absl()

matmat = vectorize('(n,m),(m,k)->(n,k)')(np.dot)
matvec = vectorize('(n,m),(m)->(n)')(np.dot)
vecmat = vectorize('(m),(m,k)->(k)')(np.dot)
vecvec = vectorize('(m),(m)->()')(np.dot)
matmat = vectorize('(n,m),(m,k)->(n,k)')(jnp.dot)
matvec = vectorize('(n,m),(m)->(n)')(jnp.dot)
vecmat = vectorize('(m),(m,k)->(k)')(jnp.dot)
vecvec = vectorize('(m),(m)->()')(jnp.dot)

@vectorize('(n)->()')
def magnitude(x):
return np.dot(x, x)
return jnp.dot(x, x)

mean = vectorize('(n)->()')(np.mean)
mean = vectorize('(n)->()')(jnp.mean)

@vectorize('()->(n)')
def stack_plus_minus(x):
return np.stack([x, -x])
return jnp.stack([x, -x])

@vectorize('(n)->(),(n)')
def center(array):
bias = np.mean(array)
bias = jnp.mean(array)
debiased = array - bias
return bias, debiased

@ -60,8 +58,8 @@ class VectorizeTest(jtu.JaxTestCase):
((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
]))
def test_matmat(self, left_shape, right_shape, result_shape):
self.assertEqual(matmat(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
self.assertEqual(matmat(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
@ -73,8 +71,8 @@ class VectorizeTest(jtu.JaxTestCase):
((5, 4, 2, 3), (1, 3), (5, 4, 2)),
]))
def test_matvec(self, left_shape, right_shape, result_shape):
self.assertEqual(matvec(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
self.assertEqual(matvec(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
@ -85,8 +83,8 @@ class VectorizeTest(jtu.JaxTestCase):
((4, 2, 3), (3,), (4, 2)),
]))
def test_vecvec(self, left_shape, right_shape, result_shape):
self.assertEqual(vecvec(np.zeros(left_shape),
np.zeros(right_shape)).shape, result_shape)
self.assertEqual(vecvec(jnp.zeros(left_shape),
jnp.zeros(right_shape)).shape, result_shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -100,7 +98,7 @@ class VectorizeTest(jtu.JaxTestCase):
size = 1
for x in shape:
size *= x
self.assertEqual(magnitude(np.arange(size).reshape(shape)).shape, result_shape)
self.assertEqual(magnitude(jnp.arange(size).reshape(shape)).shape, result_shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -111,10 +109,10 @@ class VectorizeTest(jtu.JaxTestCase):
((1, 2, 3, 4), (1, 2, 3)),
]))
def test_mean(self, shape, result_shape):
self.assertEqual(mean(np.zeros(shape)).shape, result_shape)
self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)

def test_mean_axis(self):
self.assertEqual(mean(np.zeros((2, 3)), axis=0).shape, (3,))
self.assertEqual(mean(jnp.zeros((2, 3)), axis=0).shape, (3,))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape),
@ -124,24 +122,24 @@ class VectorizeTest(jtu.JaxTestCase):
((3,), (3,2,)),
]))
def test_stack_plus_minus(self, shape, result_shape):
self.assertEqual(stack_plus_minus(np.zeros(shape)).shape, result_shape)
self.assertEqual(stack_plus_minus(jnp.zeros(shape)).shape, result_shape)

def test_center(self):
b, a = center(np.arange(3))
b, a = center(jnp.arange(3))
self.assertEqual(a.shape, (3,))
self.assertEqual(b.shape, ())
self.assertAllClose(1.0, b, False)

X = np.arange(12).reshape((3, 4))
X = jnp.arange(12).reshape((3, 4))
b, a = center(X, axis=1)
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (3,))
self.assertAllClose(np.mean(X, axis=1), b, True)
self.assertAllClose(jnp.mean(X, axis=1), b, True)

b, a = center(X, axis=0)
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (4,))
self.assertAllClose(np.mean(X, axis=0), b, True)
self.assertAllClose(jnp.mean(X, axis=0), b, True)


if __name__ == "__main__":