Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

commit 512ed18d5a (parent 58aba9bcf1)

    Cleanup: convert uses of 'import numpy as onp' in tests (#3756)
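    This is a mechanical rename across the test suite: the old aliases bound
    'onp' to plain NumPy and 'np' to jax.numpy; after this change 'np' always
    means plain NumPy and 'jnp' always means jax.numpy. A minimal sketch of the
    two conventions (illustrative only, not text from the commit itself):

        # old style, as removed by this commit:
        import numpy as onp            # plain NumPy
        from jax import numpy as np    # jax.numpy

        # new style, as adopted by this commit:
        import numpy as np             # plain NumPy, host-side arrays
        from jax import numpy as jnp   # jax.numpy, traced/compiled arrays

        x = np.arange(4.0)   # ordinary NumPy ndarray
        y = jnp.sin(x)       # JAX array; works under jit/grad/vmap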
@@ -17,12 +17,12 @@ import threading

 from absl.testing import absltest
 import jax
-from jax import lax, numpy as np
+from jax import lax, numpy as jnp
 from jax.config import config
 from jax.experimental import host_callback as hcb
 from jax.lib import xla_client
 import jax.test_util as jtu
-import numpy as onp
+import numpy as np

 config.parse_flags_with_absl()
 FLAGS = config.FLAGS
@@ -34,14 +34,14 @@ class InfeedTest(jtu.JaxTestCase):
     def f(x):
       token = lax.create_token(x)
       (y,), token = lax.infeed(
-          token, shape=(jax.ShapedArray((3, 4), np.float32),))
+          token, shape=(jax.ShapedArray((3, 4), jnp.float32),))
       (z,), _ = lax.infeed(
-          token, shape=(jax.ShapedArray((3, 1, 1), np.float32),))
+          token, shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
       return x + y + z

-    x = onp.float32(1.5)
-    y = onp.reshape(onp.arange(12, dtype=onp.float32), (3, 4))  # onp.random.randn(3, 4).astype(onp.float32)
-    z = onp.random.randn(3, 1, 1).astype(onp.float32)
+    x = np.float32(1.5)
+    y = np.reshape(np.arange(12, dtype=np.float32), (3, 4))  # np.random.randn(3, 4).astype(np.float32)
+    z = np.random.randn(3, 1, 1).astype(np.float32)
     device = jax.local_devices()[0]
     device.transfer_to_infeed((y,))
     device.transfer_to_infeed((z,))
@@ -53,12 +53,12 @@ class InfeedTest(jtu.JaxTestCase):
     def f(x):
       token = lax.create_token(x)
       y, token = lax.infeed(
-          token, shape=jax.ShapedArray((3, 4), np.float32))
-      token = lax.outfeed(token, y + onp.float32(1))
+          token, shape=jax.ShapedArray((3, 4), jnp.float32))
+      token = lax.outfeed(token, y + np.float32(1))
       return lax.tie_in(token, x - 1)

-    x = onp.float32(7.5)
-    y = onp.random.randn(3, 4).astype(onp.float32)
+    x = np.float32(7.5)
+    y = np.random.randn(3, 4).astype(np.float32)
     execution = threading.Thread(target=lambda: f(x))
     execution.start()
     device = jax.local_devices()[0]
@@ -66,14 +66,14 @@ class InfeedTest(jtu.JaxTestCase):
     out, = device.transfer_from_outfeed(
         xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent())
     execution.join()
-    self.assertAllClose(out, y + onp.float32(1))
+    self.assertAllClose(out, y + np.float32(1))

   def testInfeedThenOutfeedInALoop(self):
     hcb.stop_outfeed_receiver()
     def doubler(_, token):
       y, token = lax.infeed(
-          token, shape=jax.ShapedArray((3, 4), np.float32))
-      return lax.outfeed(token, y * onp.float32(2))
+          token, shape=jax.ShapedArray((3, 4), jnp.float32))
+      return lax.outfeed(token, y * np.float32(2))

     @jax.jit
     def f(n):
@@ -86,11 +86,11 @@ class InfeedTest(jtu.JaxTestCase):
     execution = threading.Thread(target=lambda: f(n))
     execution.start()
     for _ in range(n):
-      x = onp.random.randn(3, 4).astype(onp.float32)
+      x = np.random.randn(3, 4).astype(np.float32)
       device.transfer_to_infeed((x,))
       y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,))
                                         .with_major_to_minor_layout_if_absent())
-      self.assertAllClose(y, x * onp.float32(2))
+      self.assertAllClose(y, x * np.float32(2))
     execution.join()

@@ -21,7 +21,7 @@ from unittest import SkipTest
 from absl.testing import absltest
 from absl.testing import parameterized

-import numpy as onp
+import numpy as np

 import jax
 from jax import api
@@ -74,13 +74,13 @@ LAX_GRAD_OPS = [
     grad_test_spec(lax.log1p, nargs=1, order=2, rng_factory=jtu.rand_positive,
                    dtypes=grad_inexact_dtypes),
     grad_test_spec(lax.sinh, nargs=1, order=2, rng_factory=jtu.rand_default,
-                   dtypes=grad_float_dtypes + [onp.complex64], tol=1e-5),
+                   dtypes=grad_float_dtypes + [np.complex64], tol=1e-5),
     grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_inexact_dtypes, tol=1e-5),
     grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_inexact_dtypes, tol=1e-5),
     grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default,
-                   dtypes=grad_inexact_dtypes, tol={onp.float32: 5e-1}),
+                   dtypes=grad_inexact_dtypes, tol={np.float32: 5e-1}),
     grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_inexact_dtypes),
     grad_test_spec(lax.tan, nargs=1, order=2,
@@ -122,7 +122,7 @@ LAX_GRAD_OPS = [
     grad_test_spec(lax.abs, nargs=1, order=2, rng_factory=jtu.rand_positive,
                    dtypes=grad_inexact_dtypes),
     grad_test_spec(lax.pow, nargs=2, order=2, rng_factory=jtu.rand_positive,
-                   dtypes=grad_inexact_dtypes, tol={onp.float32: 3e-1}),
+                   dtypes=grad_inexact_dtypes, tol={np.float32: 3e-1}),

     grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_inexact_dtypes),
@@ -152,13 +152,13 @@ def grad_special_values_test_spec(op, values, tol=None):
 LAX_GRAD_SPECIAL_VALUE_TESTS = [
     grad_special_values_test_spec(
       lax.sinh, [0.],
-      tol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
+      tol={np.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
     grad_special_values_test_spec(
       lax.cosh, [0.],
-      tol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
+      tol={np.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
     grad_special_values_test_spec(lax.tanh, [0., 1000.]),
-    grad_special_values_test_spec(lax.sin, [0., onp.pi, onp.pi/2., onp.pi/4.]),
-    grad_special_values_test_spec(lax.cos, [0., onp.pi, onp.pi/2., onp.pi/4.]),
+    grad_special_values_test_spec(lax.sin, [0., np.pi, np.pi/2., np.pi/4.]),
+    grad_special_values_test_spec(lax.cos, [0., np.pi, np.pi/2., np.pi/4.]),
     grad_special_values_test_spec(lax.tan, [0.]),
     grad_special_values_test_spec(lax.asin, [0.]),
     grad_special_values_test_spec(lax.acos, [0.]),
@@ -219,7 +219,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
               jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
     args = (rng((2, 3), from_dtype),)
     convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
-    convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)(
+    convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)(
         convert_element_type)
     check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

@@ -243,7 +243,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
-          dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
+          dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
           num_arrs),
        "dim": dim, "base_shape": base_shape, "dtype": dtype,
        "num_arrs": num_arrs, "rng_factory": rng_factory}
@@ -359,16 +359,16 @@ class LaxAutodiffTest(jtu.JaxTestCase):
                        padding, lhs_dil, rhs_dil, dimension_numbers,
                        perms, feature_group_count, batch_group_count,
                        rng_factory):
-    if dtype == onp.float16:
+    if dtype == np.float16:
       raise SkipTest("float16 numerical issues")  # TODO(mattjj): resolve

     rng = rng_factory(self.rng())
-    tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4}
+    tol = {dtypes.bfloat16: 1e-0, np.float16: 5e-1, np.float32: 2e-4}

     # permute shapes to match dim_spec, scale by feature_group_count
     lhs_perm, rhs_perm = perms
-    lhs_shape = list(onp.take(lhs_shape, lhs_perm))
-    rhs_shape = list(onp.take(rhs_shape, rhs_perm))
+    lhs_shape = list(np.take(lhs_shape, lhs_perm))
+    rhs_shape = list(np.take(rhs_shape, rhs_perm))

     lhs = rng(lhs_shape, dtype)
     rhs = rng(rhs_shape, dtype)
@@ -391,7 +391,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       for dtype in float_dtypes))
   def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
     rng = rng_factory(self.rng())
-    tol = {onp.float16: 1e-1, onp.float32: 1e-4}
+    tol = {np.float16: 1e-1, np.float32: 1e-4}
     lhs = rng(lhs_shape, dtype)
     rhs = rng(rhs_shape, dtype)
     dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
@@ -434,7 +434,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
-          shape, onp.dtype(dtype).name, broadcast_sizes),
+          shape, np.dtype(dtype).name, broadcast_sizes),
        "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
        "rng_factory": rng_factory}
       for shape in [(), (2, 3)]
@@ -503,11 +503,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
   def testPadGrad(self, shape, dtype, pads, rng_factory):
     rng = rng_factory(self.rng())
     operand = rng(shape, dtype)
-    pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
+    pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
     check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

     operand = rng(shape, dtype)
-    padding_value = onp.array(0., dtype)
+    padding_value = np.array(0., dtype)
     pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
     check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

@@ -515,15 +515,15 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     rev = lambda operand: lax.rev(operand, dimensions)

     dimensions = [0]
-    check_grads(rev, (onp.array([3., 2., 1.]),), 2)
+    check_grads(rev, (np.array([3., 2., 1.]),), 2)

     dimensions = [0, 1]
-    check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
-                rtol={onp.float32: 3e-3})
+    check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
+                rtol={np.float32: 3e-3})

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_predshape={}_argshapes={}".format(
-          jtu.format_shape_dtype_string(pred_shape, onp.bool_),
+          jtu.format_shape_dtype_string(pred_shape, np.bool_),
           jtu.format_shape_dtype_string(arg_shape, dtype)),
        "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype,
        "rng_factory": rng_factory}
@@ -533,7 +533,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       for rng_factory in [jtu.rand_default]))
   def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory):
     rng = rng_factory(self.rng())
-    pred = rng(pred_shape, onp.bool_)
+    pred = rng(pred_shape, np.bool_)
     on_true = rng(arg_shape, dtype)
     on_false = rng(arg_shape, dtype)
     select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
@@ -603,7 +603,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     rng = rng_factory(self.rng())
     operand = rng(shape, dtype)
     update = rng(update_shape, dtype)
-    start_indices = onp.array(start_indices)
+    start_indices = np.array(start_indices)

     dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
     check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)
@@ -639,8 +639,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "dims": dims, "rng_factory": rng_factory}
       for init_val, op, dtypes, rng_factory in [
           (0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default),
-          (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
-          (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
+          (-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
+          (np.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
           (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
       ]
       for dtype in dtypes
@@ -657,10 +657,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     rng = rng_factory(self.rng())
     if jtu.device_under_test() == "tpu" and op is lax.mul:
       raise SkipTest("unimplemented case")
-    tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
-           onp.float64: 1e-3, onp.complex64: 1e-1}
+    tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-1, np.float32: 1e-1,
+           np.float64: 1e-3, np.complex64: 1e-1}
     operand = rng(shape, dtype)
-    init_val = onp.asarray(init_val, dtype=dtype)
+    init_val = np.asarray(init_val, dtype=dtype)
     reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
     eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
            1e-1 if dtype == dtypes.bfloat16 else
@@ -670,13 +670,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_op={}_dtype={}_padding={}"
-       .format(op.__name__, onp.dtype(dtype).name, padding),
+       .format(op.__name__, np.dtype(dtype).name, padding),
        "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
        "rng_factory": rng_factory}
       for init_val, op, dtypes, rng_factory in [
           (0, lax.add, grad_float_dtypes, jtu.rand_small),
-          (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
-          (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
+          (-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
+          (np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
       ]
       for dtype in dtypes
       for padding in ["VALID", "SAME"]))
@@ -685,7 +685,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
                       message="Using reduced precision for gradient.*")
   def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
     rng = rng_factory(self.rng())
-    init_val = onp.asarray(init_val, dtype=dtype)
+    init_val = np.asarray(init_val, dtype=dtype)

     # We need this conditional and the corresponding loop logic to be in the
     # test method, rather than at the parameterized test level, because it
@@ -720,10 +720,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       tol = None
     else:
       # this test can fail if there are duplicates in operand
-      self.assertEqual(onp.unique(operand).size, operand.size,
+      self.assertEqual(np.unique(operand).size, operand.size,
                        msg="test requires operand elements to be unique.")
       eps = 1e-2
-      tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2}
+      tol = {np.float16: 1e-1, np.float32: 6e-2, np.float64: 6e-2}
     check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                 eps)

@@ -733,14 +733,14 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "op": op, "shape": shape, "dtype": dtype,
        "axis": axis, "rng_factory": rng_factory}
       for op, types in [
-          (lax.cumsum, [onp.float32, onp.float64]),
-          (lax.cumprod, [onp.float32, onp.float64]),
+          (lax.cumsum, [np.float32, np.float64]),
+          (lax.cumprod, [np.float32, np.float64]),
       ]
       for dtype in types
       for shape in [[10], [3, 4, 5]]
       for axis in range(len(shape))
       for rng_factory in [
-          jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
+          jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
           else jtu.rand_small]))
   def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
     rng = rng_factory(self.rng())
@@ -753,7 +753,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
           jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
        "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
        "is_stable": is_stable}
-      for dtype in [onp.float32]
+      for dtype in [np.float32]
       for shape in [(5,), (5, 7)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
@@ -773,8 +773,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "rng_factory": rng_factory, "shape": shape,
        "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis,
        "is_stable": is_stable}
-      for key_dtype in [onp.float32]
-      for val_dtype in [onp.float32]
+      for key_dtype in [np.float32]
+      for val_dtype in [np.float32]
       for shape in [(3,), (5, 3)]
       for axis in [len(shape) - 1]
       for is_stable in [False, True]
@@ -786,7 +786,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     # too, since we don't guarantee the same ordering of values with equal keys.
     # To avoid that case, we generate unique keys (globally in the key array).
     def args_maker():
-      flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
+      flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
       keys = self.rng().permutation(flat_keys).reshape(shape)
       values = rng(shape, val_dtype)
       return keys, values
@@ -799,12 +799,12 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       {"testcase_name": "_shape={}_k={}".format(
           jtu.format_shape_dtype_string(shape, dtype), k),
        "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
-      for dtype in [onp.float32,]
+      for dtype in [np.float32,]
       for shape in [(4,), (5, 5), (2, 1, 4)]
       for k in [1, 3]
       for rng_factory in [jtu.rand_default]))
   def testTopKGrad(self, shape, dtype, k, rng_factory):
-    flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
+    flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
     values = self.rng().permutation(flat_values).reshape(shape)
     fun = lambda vs: lax.top_k(vs, k=k)[0]
     check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
@@ -816,10 +816,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "rng_factory": rng_factory}
       for dtype in float_dtypes
       for shape, idxs, axes in [
-          [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
-          [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
-          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
-          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
+          [(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
+          [(3, 4, 5), (np.array([-1, -2]),), (0,)],
+          [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
+          [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
       ]
       for rng_factory in [jtu.rand_default]))
   def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
@@ -837,13 +837,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "rng_idx_factory": rng_idx_factory}
       for dtype in grad_float_dtypes
       for shape, idxs, dnums, slice_sizes, max_idx in [
-          ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
+          ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1,), 5),
-          ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
+          ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
           (2,), 9),
-          ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
+          ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1, 3), 3),
       ]
@@ -868,13 +868,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "rng_idx_factory": rng_idx_factory}
       for dtype in grad_float_dtypes
       for arg_shape, idxs, update_shape, dnums, max_idx in [
-          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
             update_window_dims=(), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(0,)), 4),
-          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0,)), 9),
-          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(0,)), 3),
       ]
@@ -900,13 +900,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "rng_idx_factory": rng_idx_factory}
       for dtype in grad_float_dtypes
       for arg_shape, idxs, update_shape, dnums, max_idx in [
-          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
             update_window_dims=(), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(0,)), 4),
-          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
-          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
       ]
@@ -928,10 +928,10 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     # https://github.com/google/jax/issues/1901
     def f(x):
       n = x.shape[0]
-      y = onp.arange(n, dtype=x.dtype)
-      return jax.ops.index_update(x, onp.diag_indices(n), y)
+      y = np.arange(n, dtype=x.dtype)
+      return jax.ops.index_update(x, np.diag_indices(n), y)
     rng = jtu.rand_default(self.rng())
-    check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
+    check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
                 1.)

   @parameterized.named_parameters(jtu.cases_from_list(
@@ -943,13 +943,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
       for dtype in grad_float_dtypes
       for arg_shape, idxs, update_shape, dnums in [
-          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
-          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
-          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
       ]
@@ -974,13 +974,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
       for dtype in grad_float_dtypes
       for arg_shape, idxs, update_shape, dnums in [
-          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
+          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
-          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
+          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
-          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
+          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
       ]
@@ -1013,7 +1013,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     self.assertAllClose(ans, expected)

     ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
-    expected = onp.array(0.0)
+    expected = np.array(0.0)
     self.assertAllClose(ans, expected, check_dtypes=False)

     with core.skipping_checks():
@@ -1022,18 +1022,18 @@ class LaxAutodiffTest(jtu.JaxTestCase):

   # TODO(mattjj): make this a more systematic test
   def testRemainder(self):
-    rng = onp.random.RandomState(0)
+    rng = np.random.RandomState(0)
     x = rng.uniform(-0.9, 9, size=(3, 4))
     y = rng.uniform(0.7, 1.9, size=(3, 1))
-    assert not set(onp.unique(x)) & set(onp.unique(y))
-    tol = 1e-1 if jtu.num_float_bits(onp.float64) == 32 else 1e-3
+    assert not set(np.unique(x)) & set(np.unique(y))
+    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
     check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

-    rng = onp.random.RandomState(0)
+    rng = np.random.RandomState(0)
     x = rng.uniform(-0.9, 9, size=(1, 4))
     y = rng.uniform(0.7, 1.9, size=(3, 4))
-    assert not set(onp.unique(x)) & set(onp.unique(y))
-    tol = 1e-1 if jtu.num_float_bits(onp.float64) == 32 else 1e-3
+    assert not set(np.unique(x)) & set(np.unique(y))
+    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
     check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

   def testHigherOrderGradientOfReciprocal(self):
@@ -1042,7 +1042,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
       return 1 / x
     grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
-    self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
+    self.assertAllClose(np.float32(0.0439453125), grad_fn(np.float32(4.)))


 if __name__ == '__main__':
@@ -22,7 +22,7 @@ import warnings
 from absl.testing import absltest
 from absl.testing import parameterized

-import numpy as onp
+import numpy as np

 from jax import api
 from jax import numpy as jnp
@@ -218,31 +218,31 @@ STATIC_INDEXING_GRAD_TESTS = [

 ADVANCED_INDEXING_TESTS = [
     ("One1DIntArrayIndex",
-     [IndexSpec(shape=(3,), indexer=onp.array([0, 1])),
-      IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
-      IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
-      IndexSpec(shape=(3,), indexer=onp.array([-1, 1])),
-      IndexSpec(shape=(3,), indexer=onp.array([-2, -1])),
-      IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)),
+     [IndexSpec(shape=(3,), indexer=np.array([0, 1])),
+      IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1])),
+      IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1])),
+      IndexSpec(shape=(3,), indexer=np.array([-1, 1])),
+      IndexSpec(shape=(3,), indexer=np.array([-2, -1])),
+      IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32)),
      ]),
     ("One2DIntArrayIndex",
-     [IndexSpec(shape=(3,), indexer=onp.array([[0, 0]])),
-      IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1],
+     [IndexSpec(shape=(3,), indexer=np.array([[0, 0]])),
+      IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1],
                                                  [0, 1, -1]])),
-      IndexSpec(shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1],
+      IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1],
                                                     [-1, -2, 1, 0]])),
      ]),
     ("Two1DIntArrayIndicesNoBroadcasting",
-     [IndexSpec(shape=(3, 3), indexer=[onp.array([0, 1]),
-                                       onp.array([1, 2])]),
-      IndexSpec(shape=(3, 4, 5), indexer=[onp.array([0, 2, 0, 1]),
-                                          onp.array([-1, 0, -1, 2])]),
+     [IndexSpec(shape=(3, 3), indexer=[np.array([0, 1]),
+                                       np.array([1, 2])]),
+      IndexSpec(shape=(3, 4, 5), indexer=[np.array([0, 2, 0, 1]),
+                                          np.array([-1, 0, -1, 2])]),
      ]),
     ("Two1DIntArrayIndicesWithBroadcasting",
-     [IndexSpec(shape=(3, 3), indexer=[onp.array([[0, 1]]),
-                                       onp.array([1, 2])]),
-      IndexSpec(shape=(3, 4, 5), indexer=[onp.array([[0, 2, 0, 1]]),
-                                          onp.array([-1, 0, -1, 2])]),
+     [IndexSpec(shape=(3, 3), indexer=[np.array([[0, 1]]),
+                                       np.array([1, 2])]),
+      IndexSpec(shape=(3, 4, 5), indexer=[np.array([[0, 2, 0, 1]]),
+                                          np.array([-1, 0, -1, 2])]),
      ]),
     ("ListOfPythonInts",
      [IndexSpec(shape=(3,), indexer=[0, 1, 0]),
@@ -257,42 +257,42 @@ ADVANCED_INDEXING_TESTS = [
       IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]])),
      ]),
     ("ListOfPythonIntsAndIntArrays",
-     [IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
+     [IndexSpec(shape=(3, 4, 5), indexer=[0, np.array([0, 1])]),
       IndexSpec(shape=(3, 4, 5), indexer=[0, 1,
-                                          onp.array([[2, 3, 0, 3]])]),
+                                          np.array([[2, 3, 0, 3]])]),
      ]),
     ("ListOfListsOfPythonIntsAndIntArrays",
-     [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
+     [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], np.array([0])]),
       IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]],
-                                          onp.array([[2, 3, 0, 3]])]),
+                                          np.array([[2, 3, 0, 3]])]),
      ]),
 ]

 ADVANCED_INDEXING_TESTS_NO_REPEATS = [
     ("One1DIntArrayIndex",
-     [IndexSpec(shape=(3,), indexer=onp.array([0, 1])),
-      IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 0])),
-      IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 1])),
-      IndexSpec(shape=(3,), indexer=onp.array([-1, 1])),
-      IndexSpec(shape=(3,), indexer=onp.array([-2, -1])),
-      IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)),
+     [IndexSpec(shape=(3,), indexer=np.array([0, 1])),
+      IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 0])),
+      IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 1])),
+      IndexSpec(shape=(3,), indexer=np.array([-1, 1])),
+      IndexSpec(shape=(3,), indexer=np.array([-2, -1])),
+      IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32)),
      ]),
     ("One2DIntArrayIndex",
-     [IndexSpec(shape=(3,), indexer=onp.array([[0, 1]])),
-      IndexSpec(shape=(6, 6), indexer=onp.array([[1, 2, 0],
+     [IndexSpec(shape=(3,), indexer=np.array([[0, 1]])),
+      IndexSpec(shape=(6, 6), indexer=np.array([[1, 2, 0],
                                                  [3, 4, -1]])),
      ]),
     ("Two1DIntArrayIndicesNoBroadcasting",
-     [IndexSpec(shape=(3, 3), indexer=[onp.array([0, 1]),
-                                       onp.array([1, 2])]),
-      IndexSpec(shape=(4, 5, 6), indexer=[onp.array([0, 2, 1, 3]),
-                                          onp.array([-1, 0, -2, 1])]),
+     [IndexSpec(shape=(3, 3), indexer=[np.array([0, 1]),
+                                       np.array([1, 2])]),
+      IndexSpec(shape=(4, 5, 6), indexer=[np.array([0, 2, 1, 3]),
+                                          np.array([-1, 0, -2, 1])]),
      ]),
     ("Two1DIntArrayIndicesWithBroadcasting",
-     [IndexSpec(shape=(3, 3), indexer=[onp.array([[0, 1]]),
-                                       onp.array([1, 2])]),
-      IndexSpec(shape=(4, 5, 6), indexer=[onp.array([[0, 2, -1, 1]]),
-                                          onp.array([-1, 0, -2, 2])]),
+     [IndexSpec(shape=(3, 3), indexer=[np.array([[0, 1]]),
+                                       np.array([1, 2])]),
+      IndexSpec(shape=(4, 5, 6), indexer=[np.array([[0, 2, -1, 1]]),
+                                          np.array([-1, 0, -2, 2])]),
      ]),
     ("ListOfPythonInts",
      [IndexSpec(shape=(3,), indexer=[0, 2, 1]),
@@ -307,65 +307,65 @@ ADVANCED_INDEXING_TESTS_NO_REPEATS = [
       IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]])),
      ]),
     ("ListOfPythonIntsAndIntArrays",
-     [IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
+     [IndexSpec(shape=(3, 4, 5), indexer=[0, np.array([0, 1])]),
       IndexSpec(shape=(3, 4, 5), indexer=[0, 1,
-                                          onp.array([[2, 3, 0]])]),
+                                          np.array([[2, 3, 0]])]),
      ]),
     ("ListOfListsOfPythonIntsAndIntArrays",
-     [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
+     [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], np.array([0])]),
       IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]],
-                                          onp.array([[2, 3, 0]])]),
+                                          np.array([[2, 3, 0]])]),
      ]),
 ]

 MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [
     ("SlicesAndOneIntArrayIndex",
-     [IndexSpec(shape=(2, 3), indexer=(onp.array([0, 1]), slice(1, 2))),
+     [IndexSpec(shape=(2, 3), indexer=(np.array([0, 1]), slice(1, 2))),
       IndexSpec(shape=(2, 3), indexer=(slice(0, 2),
-                                       onp.array([0, 2]))),
+                                       np.array([0, 2]))),
       IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
-                                          onp.array([0, 2]),
+                                          np.array([0, 2]),
                                           slice(None))),
       IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
-                                          onp.array([[0, 2], [1, 3]]),
+                                          np.array([[0, 2], [1, 3]]),
                                           slice(None))),
      ]),
     ("SlicesAndTwoIntArrayIndices",
      [IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
-                                          onp.array([0, 2]),
-                                          onp.array([-1, 2]))),
-      IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]),
+                                          np.array([0, 2]),
+                                          np.array([-1, 2]))),
+      IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]),
                                           Ellipsis,
-                                          onp.array([-1, 2]))),
-      IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]),
-                                          onp.array([-1, 2]),
+                                          np.array([-1, 2]))),
+      IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]),
+                                          np.array([-1, 2]),
                                           Ellipsis)),
-      IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]),
-                                          onp.array([-1, 2]),
+      IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]),
+                                          np.array([-1, 2]),
                                           slice(1, 3))),
-      IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]),
+      IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]),
                                           slice(1, 3),
-                                          onp.array([-1, 2]))),
-      IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]),
+                                          np.array([-1, 2]))),
+      IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2, -2]),
                                           slice(None, None, 2),
-                                          onp.array([-1, 2, 1]))),
+                                          np.array([-1, 2, 1]))),
      ]),
     ("NonesAndIntArrayIndices",
-     [IndexSpec(shape=(3, 4, 5), indexer=[onp.array([0, 2]),
+     [IndexSpec(shape=(3, 4, 5), indexer=[np.array([0, 2]),
                                           None,
-                                          onp.array([-1, 2])]),
-      IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]),
+                                          np.array([-1, 2])]),
+      IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2]),
                                           None,
                                           None,
-                                          onp.array([-1, 2]))),
+                                          np.array([-1, 2]))),
       IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
-                                          onp.array([0, 2]),
+                                          np.array([0, 2]),
                                           None,
                                           None,
-                                          onp.array([-1, 2]))),
+                                          np.array([-1, 2]))),
      ]),
     ("IntArrayWithInt32Type",
-     [IndexSpec(shape=(3, 4), indexer=(Ellipsis, onp.array(1, dtype=onp.int32)))
+     [IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)))
      ]),
 ]

@@ -373,16 +373,16 @@ MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [
     ("SlicesAndOneIntArrayIndex",
      [
       IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
-                                          onp.array([[0, 2], [1, 1]]),
+                                          np.array([[0, 2], [1, 1]]),
                                           slice(None))),
      ]),
     ("SlicesAndTwoIntArrayIndices",
-     [IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]),
+     [IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2, -2]),
                                           slice(None, None, 2),
-                                          onp.array([-1, 2, -1]))),
-      IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2], [2, 0]]),
+                                          np.array([-1, 2, -1]))),
+      IndexSpec(shape=(3, 4, 5), indexer=(np.array([[0, 2], [2, 0]]),
                                           Ellipsis,
-                                          onp.array([[1, 0], [1, 0]]))),
+                                          np.array([[1, 0], [1, 0]]))),
      ]),]

 class IndexingTest(jtu.JaxTestCase):
@@ -559,30 +559,30 @@ class IndexingTest(jtu.JaxTestCase):
        "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
       for name, index_specs in [
           ("One1DIntArrayIndex",
-           [IndexSpec(shape=(3,), indexer=onp.array([0, 1])),
-            IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
-            IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
-            IndexSpec(shape=(3,), indexer=onp.array([-1, 1])),
-            IndexSpec(shape=(3,), indexer=onp.array([-2, -1])),
+           [IndexSpec(shape=(3,), indexer=np.array([0, 1])),
+            IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1])),
+            IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1])),
+            IndexSpec(shape=(3,), indexer=np.array([-1, 1])),
+            IndexSpec(shape=(3,), indexer=np.array([-2, -1])),
            ]),
           ("One2DIntArrayIndex",
-           [IndexSpec(shape=(3,), indexer=onp.array([[0, 0]])),
-            IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1],
+           [IndexSpec(shape=(3,), indexer=np.array([[0, 0]])),
+            IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1],
                                                        [0, 1, -1]])),
-            IndexSpec(shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1],
+            IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1],
                                                           [-1, -2, 1, 0]])),
            ]),
           ("Two1DIntArrayIndicesNoBroadcasting",
-           [IndexSpec(shape=(3, 3), indexer=[onp.array([0, 1]),
-                                             onp.array([1, 2])]),
-            IndexSpec(shape=(3, 4, 5), indexer=[onp.array([0, 2, 0, 1]),
-                                                onp.array([-1, 0, -1, 2])]),
+           [IndexSpec(shape=(3, 3), indexer=[np.array([0, 1]),
+                                             np.array([1, 2])]),
+            IndexSpec(shape=(3, 4, 5), indexer=[np.array([0, 2, 0, 1]),
+                                                np.array([-1, 0, -1, 2])]),
            ]),
          ("Two1DIntArrayIndicesWithBroadcasting",
-           [IndexSpec(shape=(3, 3), indexer=[onp.array([[0, 1]]),
-                                             onp.array([1, 2])]),
-            IndexSpec(shape=(3, 4, 5), indexer=[onp.array([[0, 2, 0, 1]]),
-                                                onp.array([-1, 0, -1, 2])]),
+           [IndexSpec(shape=(3, 3), indexer=[np.array([[0, 1]]),
+                                             np.array([1, 2])]),
+            IndexSpec(shape=(3, 4, 5), indexer=[np.array([[0, 2, 0, 1]]),
+                                                np.array([-1, 0, -1, 2])]),
            ]),
           ("ListOfPythonInts",
            [IndexSpec(shape=(3,), indexer=[0, 1, 0]),
@@ -593,14 +593,14 @@ class IndexingTest(jtu.JaxTestCase):
            IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]),
            ]),
           ("ListOfPythonIntsAndIntArrays",
-           [IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
+           [IndexSpec(shape=(3, 4, 5), indexer=[0, np.array([0, 1])]),
             IndexSpec(shape=(3, 4, 5), indexer=[0, 1,
-                                                onp.array([[2, 3, 0, 3]])]),
+                                                np.array([[2, 3, 0, 3]])]),
            ]),
           ("ListOfListsOfPythonIntsAndIntArrays",
-           [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
+           [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], np.array([0])]),
             IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]],
-                                                onp.array([[2, 3, 0, 3]])]),
+                                                np.array([[2, 3, 0, 3]])]),
            ]),
       ]
       for shape, indexer in index_specs
@@ -623,10 +623,10 @@ class IndexingTest(jtu.JaxTestCase):
       for rng_factory in [jtu.rand_default])
   def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
     rng = rng_factory(self.rng())
-    indexer_with_dummies = [e if isinstance(e, onp.ndarray) else ()
+    indexer_with_dummies = [e if isinstance(e, np.ndarray) else ()
                             for e in indexer]
     substitutes = [(i, e) for i, e in enumerate(indexer)
-                   if not isinstance(e, onp.ndarray)]
+                   if not isinstance(e, np.ndarray)]
     args_maker = lambda: [rng(shape, dtype), indexer_with_dummies]

     def fun(x, indexer_with_dummies):
@@ -636,8 +636,8 @@ class IndexingTest(jtu.JaxTestCase):
     self._CompileAndCheck(fun, args_maker)

   def testAdvancedIndexingManually(self):
-    x = onp.random.RandomState(0).randn(3, 4, 5)
-    index_array = onp.array([0, 2, -1, 0])
+    x = np.random.RandomState(0).randn(3, 4, 5)
+    index_array = np.array([0, 2, -1, 0])

     op = lambda x, index_array: x[..., index_array, :]
     cop = api.jit(op)
@@ -671,59 +671,59 @@ class IndexingTest(jtu.JaxTestCase):

     cfoo = api.jit(foo)

-    a1 = foo(onp.arange(3))
-    a2 = cfoo(onp.arange(3))
+    a1 = foo(np.arange(3))
+    a2 = cfoo(np.arange(3))

     self.assertAllClose(a1, a2)

   def testBooleanIndexingArray1D(self):
-    idx = onp.array([True, True, False])
-    x = api.device_put(onp.arange(3))
+    idx = np.array([True, True, False])
+    x = api.device_put(np.arange(3))
     ans = x[idx]
-    expected = onp.arange(3)[idx]
+    expected = np.arange(3)[idx]
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testBooleanIndexingList1D(self):
     idx = [True, True, False]
-    x = api.device_put(onp.arange(3))
+    x = api.device_put(np.arange(3))
     ans = x[idx]
-    expected = onp.arange(3)[idx]
+    expected = np.arange(3)[idx]
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testBooleanIndexingArray2DBroadcast(self):
-    idx = onp.array([True, True, False, True])
-    x = onp.arange(8).reshape(4, 2)
+    idx = np.array([True, True, False, True])
+    x = np.arange(8).reshape(4, 2)
     ans = api.device_put(x)[idx]
     expected = x[idx]
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testBooleanIndexingList2DBroadcast(self):
     idx = [True, True, False, True]
-    x = onp.arange(8).reshape(4, 2)
+    x = np.arange(8).reshape(4, 2)
     ans = api.device_put(x)[idx]
     expected = x[idx]
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testBooleanIndexingArray2D(self):
-    idx = onp.array([[True, False],
+    idx = np.array([[True, False],
                      [False, True],
                      [False, False],
                      [True, True]])
-    x = onp.arange(8).reshape(4, 2)
+    x = np.arange(8).reshape(4, 2)
     ans = api.device_put(x)[idx]
     expected = x[idx]
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testBooleanIndexingDynamicShapeError(self):
-    x = onp.zeros(3)
-    i = onp.array([True, True, False])
+    x = np.zeros(3)
+    i = np.array([True, True, False])
     self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i))

   def testIssue187(self):
     x = jnp.ones((5, 5))
     x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

-    x = onp.arange(25).reshape((5, 5))
+    x = np.arange(25).reshape((5, 5))
     ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
     expected = x[[0, 2, 4], [0, 2, 4]]
     self.assertAllClose(ans, expected, check_dtypes=False)
@@ -734,15 +734,15 @@ class IndexingTest(jtu.JaxTestCase):
     x = jnp.ones((3, 4), jnp.float32)
     i = jnp.ones((3,), jnp.int32)
     f = lambda x, i: jnp.sum(x[i])
-    primals, tangents = api.jvp(api.grad(f), (x, i), (x, onp.zeros_like(i)))
-    expected = onp.broadcast_to(
-        onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4))
+    primals, tangents = api.jvp(api.grad(f), (x, i), (x, np.zeros_like(i)))
+    expected = np.broadcast_to(
+        np.array([0, 3, 0], dtype=np.float32)[:, None], (3, 4))
     self.assertAllClose(expected, primals)
-    self.assertAllClose(onp.zeros_like(x), tangents)
+    self.assertAllClose(np.zeros_like(x), tangents)

   def testTrivialGatherIsntGenerated(self):
     # https://github.com/google/jax/issues/1621
-    jaxpr = api.make_jaxpr(lambda x: x[:, None])(onp.arange(4))
+    jaxpr = api.make_jaxpr(lambda x: x[:, None])(np.arange(4))
     self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
     self.assertNotIn('gather', str(jaxpr))

@@ -754,7 +754,7 @@ class IndexingTest(jtu.JaxTestCase):

     with self.assertRaisesRegex(IndexError,
                                 "index .* is out of bounds for axis .* with size 0"):
-      _ = onp.ones((2, 0))[0, 0]  # The numpy error
+      _ = np.ones((2, 0))[0, 0]  # The numpy error
     with self.assertRaisesRegex(IndexError,
                                 "index is out of bounds for axis .* with size 0"):
       _ = x[0, 0]  # JAX indexing
@@ -768,7 +768,7 @@ class IndexingTest(jtu.JaxTestCase):
     mask = jnp.array([False])
     ans = x[mask]  # doesn't crash

-    expected = onp.array([-1])[onp.array([False])]
+    expected = np.array([-1])[np.array([False])]
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testFloatIndexingError(self):
@@ -807,7 +807,7 @@ def _broadcastable_shapes(shape):

 @suppress_deprecated_indexing_warnings()
 def _update_shape(shape, indexer):
-  return onp.zeros(shape)[indexer].shape
+  return np.zeros(shape)[indexer].shape


 class UpdateOps(enum.Enum):
@@ -818,14 +818,14 @@ class UpdateOps(enum.Enum):
   MAX = 4

   @suppress_deprecated_indexing_warnings()
-  def onp_fn(op, indexer, x, y):
+  def np_fn(op, indexer, x, y):
     x = x.copy()
     x[indexer] = {
       UpdateOps.UPDATE: lambda: y,
       UpdateOps.ADD: lambda: x[indexer] + y,
       UpdateOps.MUL: lambda: x[indexer] * y,
-      UpdateOps.MIN: lambda: onp.minimum(x[indexer], y),
-      UpdateOps.MAX: lambda: onp.maximum(x[indexer], y),
+      UpdateOps.MIN: lambda: np.minimum(x[indexer], y),
+      UpdateOps.MAX: lambda: np.maximum(x[indexer], y),
     }[op]()
     return x

@@ -870,12 +870,12 @@ class IndexedUpdateTest(jtu.JaxTestCase):
                                   rng_factory, indexer, sugared, op):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
-    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
+    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
     if sugared:
       jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
     else:
       jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
-    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
+    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
     self._CompileAndCheck(jax_fn, args_maker)

   @parameterized.named_parameters(jtu.cases_from_list({
@@ -897,12 +897,12 @@ class IndexedUpdateTest(jtu.JaxTestCase):
                                   rng_factory, indexer, sugared, op):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
-    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
+    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
     if sugared:
       jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
     else:
       jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
-    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
+    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
     self._CompileAndCheck(jax_fn, args_maker)

   @parameterized.named_parameters(jtu.cases_from_list({
@@ -924,12 +924,12 @@ class IndexedUpdateTest(jtu.JaxTestCase):
                                   rng_factory, indexer, sugared, op):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
-    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
+    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
     if sugared:
       jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
     else:
       jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
-    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker)
+    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker)
     self._CompileAndCheck(jax_fn, args_maker)

   @parameterized.named_parameters(jtu.cases_from_list({
@@ -959,25 +959,25 @@ class IndexedUpdateTest(jtu.JaxTestCase):
     # testAdvancedIndexing compares against NumPy, and as a result doesn't check
     # repeated indices. This test is just a simple manual check, based on
     # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
-    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
-    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])
+    data = np.array([5, 1, 7, 2, 3, 4, 1, 3])
+    segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])

-    ans = ops.index_add(onp.zeros(onp.max(segment_ids) + 1), segment_ids, data)
-    expected = onp.array([13, 2, 7, 4])
+    ans = ops.index_add(np.zeros(np.max(segment_ids) + 1), segment_ids, data)
+    expected = np.array([13, 2, 7, 4])
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testSegmentSum(self):
-    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
-    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])
+    data = np.array([5, 1, 7, 2, 3, 4, 1, 3])
+    segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])

     # test with explicit num_segments
     ans = ops.segment_sum(data, segment_ids, num_segments=4)
-    expected = onp.array([13, 2, 7, 4])
+    expected = np.array([13, 2, 7, 4])
     self.assertAllClose(ans, expected, check_dtypes=False)

     # test without explicit num_segments
     ans = ops.segment_sum(data, segment_ids)
-    expected = onp.array([13, 2, 7, 4])
+    expected = np.array([13, 2, 7, 4])
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testIndexDtypeError(self):
@@ -21,7 +21,7 @@ import itertools
 from absl.testing import absltest
 from absl.testing import parameterized

-import numpy as onp
+import numpy as np
 import scipy.special as osp_special

 from jax import api
@@ -188,12 +188,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     rng = rng_factory(self.rng())
     args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
     self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
-                            tol={onp.float32: 1e-3, onp.float64: 1e-14})
+                            tol={np.float32: 1e-3, np.float64: 1e-14})
     self._CompileAndCheck(lax_fun, args_maker)

   def testIssue980(self):
-    x = onp.full((4,), -1e20, dtype=onp.float32)
-    self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
+    x = np.full((4,), -1e20, dtype=np.float32)
+    self.assertAllClose(np.zeros((4,), dtype=np.float32),
                         lsp_special.expit(x))

   def testXlogyShouldReturnZero(self):
@ -21,7 +21,7 @@ from unittest import SkipTest
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as onp
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import api
|
||||
@ -77,14 +77,14 @@ LAX_OPS = [
|
||||
# TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
|
||||
# precision.
|
||||
op_record("expm1", 1, float_dtypes + complex_dtypes, jtu.rand_small,
|
||||
{onp.float64: 1e-8}),
|
||||
{np.float64: 1e-8}),
|
||||
op_record("log", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
|
||||
op_record("log1p", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
|
||||
# TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
|
||||
# ~float32 precision.
|
||||
# TODO(b/143135720): on GPU, tanh has only ~float32 precision.
|
||||
op_record("tanh", 1, float_dtypes + complex_dtypes, jtu.rand_small,
|
||||
{onp.float64: 1e-9, onp.complex128: 1e-7}),
|
||||
{np.float64: 1e-9, np.complex128: 1e-7}),
|
||||
op_record("sin", 1, float_dtypes + complex_dtypes, jtu.rand_default),
|
||||
op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default),
|
||||
op_record("atan2", 2, float_dtypes, jtu.rand_default),
|
||||
@ -93,35 +93,35 @@ LAX_OPS = [
|
||||
op_record("rsqrt", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
|
||||
op_record("square", 1, float_dtypes + complex_dtypes, jtu.rand_default),
|
||||
op_record("reciprocal", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
|
||||
op_record("tan", 1, float_dtypes, jtu.rand_default, {onp.float32: 3e-5}),
|
||||
op_record("tan", 1, float_dtypes, jtu.rand_default, {np.float32: 3e-5}),
|
||||
op_record("asin", 1, float_dtypes, jtu.rand_small),
|
||||
op_record("acos", 1, float_dtypes, jtu.rand_small),
|
||||
op_record("atan", 1, float_dtypes, jtu.rand_small),
|
||||
op_record("asinh", 1, float_dtypes, jtu.rand_default),
|
||||
op_record("acosh", 1, float_dtypes, jtu.rand_positive),
|
||||
# TODO(b/155331781): atanh has only ~float precision
|
||||
op_record("atanh", 1, float_dtypes, jtu.rand_small, {onp.float64: 1e-9}),
|
||||
op_record("atanh", 1, float_dtypes, jtu.rand_small, {np.float64: 1e-9}),
|
||||
op_record("sinh", 1, float_dtypes + complex_dtypes, jtu.rand_default),
|
||||
op_record("cosh", 1, float_dtypes + complex_dtypes, jtu.rand_default),
|
||||
op_record("lgamma", 1, float_dtypes, jtu.rand_positive,
|
||||
{onp.float32: 1e-3 if jtu.device_under_test() == "tpu" else 1e-5,
|
||||
onp.float64: 1e-14}),
|
||||
{np.float32: 1e-3 if jtu.device_under_test() == "tpu" else 1e-5,
|
||||
np.float64: 1e-14}),
|
||||
op_record("digamma", 1, float_dtypes, jtu.rand_positive,
|
||||
{onp.float64: 1e-14}),
|
||||
{np.float64: 1e-14}),
|
||||
op_record("betainc", 3, float_dtypes, jtu.rand_positive,
|
||||
{onp.float64: 1e-14}),
|
||||
{np.float64: 1e-14}),
|
||||
op_record("igamma", 2,
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
|
||||
jtu.rand_positive, {onp.float64: 1e-14}),
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
|
||||
jtu.rand_positive, {np.float64: 1e-14}),
|
||||
op_record("igammac", 2,
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
|
||||
jtu.rand_positive, {onp.float64: 1e-14}),
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
|
||||
jtu.rand_positive, {np.float64: 1e-14}),
|
||||
op_record("erf", 1, float_dtypes, jtu.rand_small),
|
||||
op_record("erfc", 1, float_dtypes, jtu.rand_small),
|
||||
# TODO(b/142976030): the approximation of erfinf used by XLA is only
|
||||
# accurate to float32 precision.
op_record("erf_inv", 1, float_dtypes, jtu.rand_small,
{onp.float64: 1e-9}),
{np.float64: 1e-9}),
op_record("bessel_i0e", 1, float_dtypes, jtu.rand_default),
op_record("bessel_i1e", 1, float_dtypes, jtu.rand_default),

@ -189,7 +189,7 @@ class LaxTest(jtu.JaxTestCase):
for rec in LAX_OPS))
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
if (not FLAGS.jax_enable_x64 and op_name == "nextafter"
and dtype == onp.float64):
and dtype == np.float64):
raise SkipTest("64-bit mode disabled")
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
@ -204,7 +204,7 @@ class LaxTest(jtu.JaxTestCase):
from_dtype, to_dtype),
"from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
for from_dtype, to_dtype in itertools.product(
[onp.float32, onp.int32, "float32", "int32"], repeat=2)
[np.float32, np.int32, "float32", "int32"], repeat=2)
for rng_factory in [jtu.rand_default]))
def testConvertElementType(self, from_dtype, to_dtype, rng_factory):
rng = rng_factory(self.rng())
@ -217,7 +217,7 @@ class LaxTest(jtu.JaxTestCase):
.format(from_dtype, to_dtype),
"from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
for from_dtype, to_dtype in itertools.product(
[onp.float32, onp.int32, "float32", "int32"], repeat=2)
[np.float32, np.int32, "float32", "int32"], repeat=2)
for rng_factory in [jtu.rand_default]))
def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype, rng_factory):
rng = rng_factory(self.rng())
@ -231,7 +231,7 @@ class LaxTest(jtu.JaxTestCase):
.format(from_dtype, to_dtype),
"from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
for from_dtype, to_dtype in itertools.product(
[onp.float32, onp.int32, "float32", "int32"], repeat=2)
[np.float32, np.int32, "float32", "int32"], repeat=2)
for rng_factory in [jtu.rand_default]))
def testBitcastConvertType(self, from_dtype, to_dtype, rng_factory):
rng = rng_factory(self.rng())
@ -244,7 +244,7 @@ class LaxTest(jtu.JaxTestCase):
.format(from_dtype, to_dtype),
"from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
for from_dtype, to_dtype in itertools.product(
[onp.float32, onp.int32, "float32", "int32"], repeat=2)
[np.float32, np.int32, "float32", "int32"], repeat=2)
for rng_factory in [jtu.rand_default]))
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, rng_factory):
rng = rng_factory(self.rng())
@ -298,7 +298,7 @@ class LaxTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
num_arrs),
"dim": dim, "base_shape": base_shape, "dtype": dtype,
"num_arrs": num_arrs, "rng_factory": rng_factory}
@ -317,7 +317,7 @@ class LaxTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
num_arrs),
"dim": dim, "base_shape": base_shape, "dtype": dtype,
"num_arrs": num_arrs, "rng_factory": rng_factory}
@ -421,7 +421,7 @@ class LaxTest(jtu.JaxTestCase):
for lhs_shape, rhs_shape in [
((b, i, 9, 10), (j, i, 4, 5))
for b, i, j in itertools.product([1, 2, 3], repeat=3)]
for dtype in [onp.float32] for strides in [(1, 1), (1, 2), (2, 1)]
for dtype in [np.float32] for strides in [(1, 1), (1, 2), (2, 1)]
for padding in [((0, 0), (0, 0)), ((1, 2), (2, 0))]
for lhs_dilation, rhs_dilation in itertools.product(
[(1, 1), (1, 2), (2, 2)], repeat=2)
@ -497,11 +497,11 @@ class LaxTest(jtu.JaxTestCase):
def testConv0DIsDot(self):
rng = jtu.rand_default(self.rng())
def args_maker():
return [rng((10, 5), onp.float32), rng((5, 7), onp.float32)]
return [rng((10, 5), np.float32), rng((5, 7), np.float32)]
jnp_fun = partial(lax.conv_general_dilated, window_strides=(),
padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(jnp_fun, onp.dot, args_maker, tol=.1)
self._CheckAgainstNumpy(jnp_fun, np.dot, args_maker, tol=.1)


@staticmethod
@ -514,9 +514,9 @@ class LaxTest(jtu.JaxTestCase):
rhs_dilation = rhs_dilation or one
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
dimension_numbers)
in_shape = onp.take(data.shape, dn.lhs_spec)
in_shape = np.take(data.shape, dn.lhs_spec)
in_sdims = in_shape[2:]
k_shape = onp.take(kernel.shape, dn.rhs_spec)
k_shape = np.take(kernel.shape, dn.rhs_spec)
k_sdims = k_shape[2:]
e_k_sdims = [(k-1) * r + 1 for k, r in zip(k_sdims, rhs_dilation)]
if padding == 'VALID':
@ -527,8 +527,8 @@ class LaxTest(jtu.JaxTestCase):
o_shape = [in_shape[0], k_shape[1]] + o_sdims
out_spec_inv = [x[0] for x in
sorted(enumerate(dn.out_spec), key=lambda x: x[1])]
o_layout = onp.take(onp.array(o_shape), out_spec_inv)
placeholder = onp.ones(o_layout, data.dtype)
o_layout = np.take(np.array(o_shape), out_spec_inv)
placeholder = np.ones(o_layout, data.dtype)
conv = lambda x: lax.conv_general_dilated(x, kernel, strides, padding,
one, rhs_dilation, dn)
_, g = api.vjp(conv, placeholder)
@ -538,10 +538,10 @@ class LaxTest(jtu.JaxTestCase):
def _transpose_conv_kernel(data, kernel, dimension_numbers):
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
dimension_numbers)
spatial_axes = onp.array(dn.rhs_spec)[2:]
spatial_axes = np.array(dn.rhs_spec)[2:]
for axis in spatial_axes:
kernel = onp.flip(kernel, axis)
kernel = onp.swapaxes(kernel, dn.rhs_spec[0], dn.rhs_spec[1])
kernel = np.flip(kernel, axis)
kernel = np.swapaxes(kernel, dn.rhs_spec[0], dn.rhs_spec[1])
return kernel

@parameterized.named_parameters(jtu.cases_from_list(
@ -725,9 +725,9 @@ class LaxTest(jtu.JaxTestCase):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
tol = {
onp.float16: 1e-2,
onp.float64: max(jtu.default_tolerance()[onp.dtype(onp.float64)], 1e-14),
onp.complex128: max(jtu.default_tolerance()[onp.dtype(onp.complex128)],
np.float16: 1e-2,
np.float64: max(jtu.default_tolerance()[np.dtype(np.float64)], 1e-14),
np.complex128: max(jtu.default_tolerance()[np.dtype(np.complex128)],
1e-14)
}
lax_op = partial(lax.dot, precision=lax.Precision.HIGHEST)
@ -811,7 +811,7 @@ class LaxTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
shape, onp.dtype(dtype).name, broadcast_sizes),
shape, np.dtype(dtype).name, broadcast_sizes),
"shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
"rng_factory": rng_factory}
for shape in [(), (2, 3)]
@ -863,7 +863,7 @@ class LaxTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
jtu.format_shape_dtype_string(inshape, onp.float32),
jtu.format_shape_dtype_string(inshape, np.float32),
outshape, broadcast_dimensions),
"inshape": inshape, "outshape": outshape,
"broadcast_dimensions": broadcast_dimensions, "err_msg": err_msg}
@ -880,7 +880,7 @@ class LaxTest(jtu.JaxTestCase):
]))
def testBroadcastInDimShapeCheck(self, inshape, outshape, broadcast_dimensions, err_msg):
rng = jtu.rand_default(self.rng())
x = rng(inshape, onp.float32)
x = rng(inshape, np.float32)
with self.assertRaisesRegex(TypeError, err_msg):
lax.broadcast_in_dim(x, shape=outshape, broadcast_dimensions=broadcast_dimensions)

@ -910,7 +910,7 @@ class LaxTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_dimensions={}".format(
jtu.format_shape_dtype_string(inshape, onp.float32), dimensions),
jtu.format_shape_dtype_string(inshape, np.float32), dimensions),
"inshape": inshape, "dimensions": dimensions, "error_type": error_type,
"err_msg": err_msg}
for inshape, dimensions, error_type, err_msg in [
@ -922,13 +922,13 @@ class LaxTest(jtu.JaxTestCase):
]))
def testSqueezeShapeCheck(self, inshape, dimensions, error_type, err_msg):
rng = jtu.rand_default(self.rng())
x = rng(inshape, onp.float32)
x = rng(inshape, np.float32)
with self.assertRaisesRegex(error_type, err_msg):
lax.squeeze(x, dimensions=dimensions)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_dimensions={}".format(
jtu.format_shape_dtype_string(arg_shape, onp.float32), dimensions),
jtu.format_shape_dtype_string(arg_shape, np.float32), dimensions),
"arg_shape": arg_shape, "dimensions": dimensions,
"rng_factory": rng_factory}
for arg_shape, dimensions in [
@ -942,7 +942,7 @@ class LaxTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testSqueeze(self, arg_shape, dimensions, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(arg_shape, onp.float32)]
args_maker = lambda: [rng(arg_shape, np.float32)]
op = lambda x: lax.squeeze(x, dimensions)
numpy_op = lambda x: lax_reference.squeeze(x, dimensions)
self._CompileAndCheck(op, args_maker)
@ -994,7 +994,7 @@ class LaxTest(jtu.JaxTestCase):
def testPad(self, shape, dtype, pads, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
self._CompileAndCheck(fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
@ -1014,29 +1014,29 @@ class LaxTest(jtu.JaxTestCase):
def testPadAgainstNumpy(self, shape, dtype, pads, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.pad(x, onp.array(0, dtype), pads)
numpy_op = lambda x: lax_reference.pad(x, onp.array(0, dtype), pads)
op = lambda x: lax.pad(x, np.array(0, dtype), pads)
numpy_op = lambda x: lax_reference.pad(x, np.array(0, dtype), pads)
self._CheckAgainstNumpy(op, numpy_op, args_maker)

def testReverse(self):
rev = api.jit(lambda operand: lax.rev(operand, dimensions))

dimensions = []
self.assertAllClose(onp.array([0, 1, 2, 3]), rev(onp.array([0, 1, 2, 3])),
self.assertAllClose(np.array([0, 1, 2, 3]), rev(np.array([0, 1, 2, 3])),
check_dtypes=False)

dimensions = [0]
self.assertAllClose(onp.array([3, 2, 1]), rev(onp.array([1, 2, 3])),
self.assertAllClose(np.array([3, 2, 1]), rev(np.array([1, 2, 3])),
check_dtypes=False)

dimensions = [0, 1]
self.assertAllClose(onp.array([[6, 5, 4], [3, 2, 1]]),
rev(onp.array([[1, 2, 3], [4, 5, 6]])),
self.assertAllClose(np.array([[6, 5, 4], [3, 2, 1]]),
rev(np.array([[1, 2, 3], [4, 5, 6]])),
check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_predshape={}_argshapes={}".format(
jtu.format_shape_dtype_string(pred_shape, onp.bool_),
jtu.format_shape_dtype_string(pred_shape, np.bool_),
jtu.format_shape_dtype_string(arg_shape, arg_dtype)),
"pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
"rng_factory": rng_factory}
@ -1046,14 +1046,14 @@ class LaxTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testSelect(self, pred_shape, arg_shape, arg_dtype, rng_factory):
def args_maker():
return [rng(pred_shape, onp.bool_), rng(arg_shape, arg_dtype),
return [rng(pred_shape, np.bool_), rng(arg_shape, arg_dtype),
rng(arg_shape, arg_dtype)]
rng = rng_factory(self.rng())
return self._CompileAndCheck(lax.select, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_predshape={}_argshapes={}".format(
jtu.format_shape_dtype_string(pred_shape, onp.bool_),
jtu.format_shape_dtype_string(pred_shape, np.bool_),
jtu.format_shape_dtype_string(arg_shape, arg_dtype)),
"pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
"rng_factory": rng_factory}
@ -1063,7 +1063,7 @@ class LaxTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testSelectAgainstNumpy(self, pred_shape, arg_shape, arg_dtype, rng_factory):
def args_maker():
return [rng(pred_shape, onp.bool_), rng(arg_shape, arg_dtype),
return [rng(pred_shape, np.bool_), rng(arg_shape, arg_dtype),
rng(arg_shape, arg_dtype)]
rng = rng_factory(self.rng())
return self._CheckAgainstNumpy(lax.select, lax_reference.select, args_maker)
@ -1129,16 +1129,16 @@ class LaxTest(jtu.JaxTestCase):
"shape": shape, "dtype": dtype, "start_indices": start_indices,
"size_indices": size_indices, "rng_factory": rng_factory}
for shape, start_indices, size_indices in [
[(3,), onp.array((1,)), (1,)],
[(3,), np.array((1,)), (1,)],
[(5, 3), (1, 1), (3, 1)],
[(5, 3), onp.array((1, 1)), (3, 1)],
[(7, 5, 3), onp.array((4, 1, 0)), (2, 0, 1)],
[(5, 3), np.array((1, 1)), (3, 1)],
[(7, 5, 3), np.array((4, 1, 0)), (2, 0, 1)],
]
for dtype in default_dtypes
for rng_factory in [jtu.rand_default]))
def testDynamicSlice(self, shape, dtype, start_indices, size_indices, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)]
args_maker = lambda: [rng(shape, dtype), np.array(start_indices)]
op = lambda x, starts: lax.dynamic_slice(x, starts, size_indices)
self._CompileAndCheck(op, args_maker)

@ -1158,7 +1158,7 @@ class LaxTest(jtu.JaxTestCase):
def testDynamicSliceAgainstNumpy(self, shape, dtype, start_indices,
size_indices, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)]
args_maker = lambda: [rng(shape, dtype), np.array(start_indices)]
op = lambda x, s: lax.dynamic_slice(x, s, size_indices)
numpy_op = lambda x, s: lax_reference.dynamic_slice(x, s, size_indices)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
@ -1166,8 +1166,8 @@ class LaxTest(jtu.JaxTestCase):
def testDynamicSliceInDim(self):
# Regression test for mixed type problem in dynamic_slice_in_dim.
rng = jtu.rand_default(self.rng())
x = rng((6, 7), onp.int32)
onp.testing.assert_equal(lax.dynamic_slice_in_dim(x, 2, 3), x[2:5])
x = rng((6, 7), np.int32)
np.testing.assert_equal(lax.dynamic_slice_in_dim(x, 2, 3), x[2:5])

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
@ -1188,7 +1188,7 @@ class LaxTest(jtu.JaxTestCase):

def args_maker():
return [rng(shape, dtype), rng(update_shape, dtype),
onp.array(start_indices)]
np.array(start_indices)]

self._CompileAndCheck(lax.dynamic_update_slice, args_maker)

@ -1211,7 +1211,7 @@ class LaxTest(jtu.JaxTestCase):

def args_maker():
return [rng(shape, dtype), rng(update_shape, dtype),
onp.array(start_indices)]
np.array(start_indices)]

self._CheckAgainstNumpy(lax.dynamic_update_slice,
lax_reference.dynamic_update_slice, args_maker)
@ -1263,16 +1263,16 @@ class LaxTest(jtu.JaxTestCase):
(0, lax.add, default_dtypes),
(1, lax.mul, default_dtypes),
(0, lax.max, all_dtypes), # non-monoidal
(-onp.inf, lax.max, float_dtypes),
(dtypes.iinfo(onp.int32).min, lax.max, [onp.int32]),
# (dtypes.iinfo(onp.int64).min, lax.max, [onp.int64]), # TODO fails
(dtypes.iinfo(onp.uint32).min, lax.max, [onp.uint32]),
(dtypes.iinfo(onp.uint64).min, lax.max, [onp.uint64]),
(onp.inf, lax.min, float_dtypes),
(dtypes.iinfo(onp.int32).max, lax.min, [onp.int32]),
# (dtypes.iinfo(onp.int64).max, lax.min, [onp.int64]), # TODO fails
(dtypes.iinfo(onp.uint32).max, lax.min, [onp.uint32]),
(dtypes.iinfo(onp.uint64).max, lax.min, [onp.uint64]),
(-np.inf, lax.max, float_dtypes),
(dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
# (dtypes.iinfo(np.int64).min, lax.max, [np.int64]), # TODO fails
(dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
(dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
(np.inf, lax.min, float_dtypes),
(dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
# (dtypes.iinfo(np.int64).max, lax.min, [np.int64]), # TODO fails
(dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
(dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
]
for dtype in types
for shape, dims in [
@ -1280,11 +1280,11 @@ class LaxTest(jtu.JaxTestCase):
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
]
for rng_factory in [
jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small]))
def testReduce(self, op, init_val, shape, dtype, dims, rng_factory):
rng = rng_factory(self.rng())
init_val = onp.asarray(init_val, dtype=dtype)
init_val = np.asarray(init_val, dtype=dtype)
fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims)
args_maker = lambda: [rng(shape, dtype), init_val]
self._CompileAndCheck(fun, args_maker)
@ -1297,20 +1297,20 @@ class LaxTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}_padding={}"
.format(op.__name__, onp.dtype(dtype).name, padding),
.format(op.__name__, np.dtype(dtype).name, padding),
"op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
"rng_factory": rng_factory}
for init_val, op, dtypes in [
(0, lax.add, [onp.float32]),
(-onp.inf, lax.max, [onp.float32]),
(onp.inf, lax.min, [onp.float32]),
(0, lax.add, [np.float32]),
(-np.inf, lax.max, [np.float32]),
(np.inf, lax.min, [np.float32]),
]
for dtype in dtypes
for padding in ["VALID", "SAME"]
for rng_factory in [jtu.rand_small]))
def testReduceWindow(self, op, init_val, dtype, padding, rng_factory):
rng = rng_factory(self.rng())
init_val = onp.asarray(init_val, dtype=dtype)
init_val = np.asarray(init_val, dtype=dtype)

all_configs = itertools.chain(
itertools.product(
@ -1344,27 +1344,27 @@ class LaxTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
"op": op, "onp_op": onp_op, "shape": shape, "dtype": dtype,
"op": op, "np_op": np_op, "shape": shape, "dtype": dtype,
"axis": axis, "rng_factory": rng_factory}
for op, onp_op, types in [
(lax.cumsum, onp.cumsum, default_dtypes),
(lax.cumprod, onp.cumprod, default_dtypes),
(lax.cummax, onp.maximum.accumulate, default_dtypes),
(lax.cummin, onp.minimum.accumulate, default_dtypes),
for op, np_op, types in [
(lax.cumsum, np.cumsum, default_dtypes),
(lax.cumprod, np.cumprod, default_dtypes),
(lax.cummax, np.maximum.accumulate, default_dtypes),
(lax.cummin, np.minimum.accumulate, default_dtypes),
]
for dtype in types
for shape in [[10], [3, 4, 5]]
for axis in range(len(shape))
for rng_factory in [
jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small]))
def testCumulativeReduce(self, op, onp_op, shape, dtype, axis, rng_factory):
def testCumulativeReduce(self, op, np_op, shape, dtype, axis, rng_factory):
rng = rng_factory(self.rng())
fun = partial(op, axis=axis)
onp_fun = partial(onp_op, axis=axis, dtype=dtype)
np_fun = partial(np_op, axis=axis, dtype=dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
self._CheckAgainstNumpy(fun, onp_fun, args_maker)
self._CheckAgainstNumpy(fun, np_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_isstable={}".format(
@ -1376,7 +1376,7 @@ class LaxTest(jtu.JaxTestCase):
for is_stable in [False, True]))
def testSort(self, shape, dtype, axis, is_stable):
# TODO(b/141131288): enable complex-valued sorts on TPU.
if (onp.issubdtype(dtype, onp.complexfloating) and (
if (np.issubdtype(dtype, np.complexfloating) and (
(jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 47)) or
jtu.device_under_test() == "tpu")):
raise SkipTest("Complex-valued sort not implemented")
@ -1395,7 +1395,7 @@ class LaxTest(jtu.JaxTestCase):
for is_stable in [False, True]))
def testSortAgainstNumpy(self, shape, dtype, axis, is_stable):
# TODO(b/141131288): enable complex-valued sorts on TPU.
if (onp.issubdtype(dtype, onp.complexfloating) and (
if (np.issubdtype(dtype, np.complexfloating) and (
(jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 47)) or
jtu.device_under_test() == "tpu")):
raise SkipTest("Complex-valued sort not implemented")
@ -1417,13 +1417,13 @@ class LaxTest(jtu.JaxTestCase):
"shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype,
"axis": axis, "is_stable": is_stable}
for key_dtype in float_dtypes + complex_dtypes + int_dtypes + uint_dtypes
for val_dtype in [onp.float32, onp.int32, onp.uint32]
for val_dtype in [np.float32, np.int32, np.uint32]
for shape in [(3,), (5, 3)]
for axis in [-1, len(shape) - 1]
for is_stable in [False, True]))
def testSortKeyVal(self, shape, key_dtype, val_dtype, axis, is_stable):
# TODO(b/141131288): enable complex-valued sorts on TPU.
if (onp.issubdtype(key_dtype, onp.complexfloating) and (
if (np.issubdtype(key_dtype, np.complexfloating) and (
(jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 47)) or
jtu.device_under_test() == "tpu")):
raise SkipTest("Complex-valued sort not implemented")
@ -1432,7 +1432,7 @@ class LaxTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -1449,14 +1449,14 @@ class LaxTest(jtu.JaxTestCase):
for num_keys in range(1, shape[0] + 1)))
def testSortNumKeys(self, shape, dtype, num_keys):
# TODO(b/141131288): enable complex-valued sorts on TPU.
if (onp.issubdtype(dtype, onp.complexfloating) and (
if (np.issubdtype(dtype, np.complexfloating) and (
(jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 47)) or
jtu.device_under_test() == "tpu")):
raise SkipTest("Complex-valued sort not implemented")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
lax_fun = lambda x: lax.sort(tuple(x), num_keys=num_keys)
numpy_fun = lambda x: tuple(x[:, onp.lexsort(x[:num_keys][::-1])])
numpy_fun = lambda x: tuple(x[:, np.lexsort(x[:num_keys][::-1])])
# self._CompileAndCheck(lax_fun, args_maker)
self._CheckAgainstNumpy(lax_fun, numpy_fun, args_maker)

@ -1468,12 +1468,12 @@ class LaxTest(jtu.JaxTestCase):
"shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype,
"axis": axis}
for key_dtype in float_dtypes + complex_dtypes + int_dtypes + uint_dtypes
for val_dtype in [onp.float32, onp.int32, onp.uint32]
for val_dtype in [np.float32, np.int32, np.uint32]
for shape in [(3,), (5, 3)]
for axis in [-1, len(shape) - 1]))
def testSortKeyValAgainstNumpy(self, shape, key_dtype, val_dtype, axis):
# TODO(b/141131288): enable complex-valued sorts on TPU.
if (onp.issubdtype(key_dtype, onp.complexfloating) and (
if (np.issubdtype(key_dtype, np.complexfloating) and (
(jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 47)) or
jtu.device_under_test() == "tpu")):
raise SkipTest("Complex-valued sort not implemented")
@ -1482,7 +1482,7 @@ class LaxTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -1495,17 +1495,17 @@ class LaxTest(jtu.JaxTestCase):
{"testcase_name": "_shape={}_k={}".format(
jtu.format_shape_dtype_string(shape, dtype), k),
"rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
for dtype in [onp.float32, onp.int32, onp.uint32]
for dtype in [np.float32, np.int32, np.uint32]
for shape in [(3,), (5, 3)]
for k in [1, 3]
for rng_factory in [jtu.rand_default]))
def testTopK(self, shape, dtype, k, rng_factory):
def args_maker():
flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
return [values]
def reference_top_k(x):
bcast_idxs = onp.broadcast_to(onp.arange(shape[-1], dtype=onp.int32), shape)
bcast_idxs = np.broadcast_to(np.arange(shape[-1], dtype=np.int32), shape)
sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs)
return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
op = lambda vs: lax.top_k(vs, k=k)
@ -1534,10 +1534,10 @@ class LaxTest(jtu.JaxTestCase):
def collapse_first_two(x):
return lax.collapse(x, 0, 2)

self.assertEqual((6,), collapse_first_two(onp.zeros((2, 3))).shape)
self.assertEqual((6, 4), collapse_first_two(onp.zeros((2, 3, 4))).shape)
self.assertEqual((6,), collapse_first_two(np.zeros((2, 3))).shape)
self.assertEqual((6, 4), collapse_first_two(np.zeros((2, 3, 4))).shape)
self.assertEqual((2, 3, 4),
collapse_first_two(onp.zeros((1, 2, 3, 4))).shape)
collapse_first_two(np.zeros((1, 2, 3, 4))).shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_axes={}".format(
@ -1545,10 +1545,10 @@ class LaxTest(jtu.JaxTestCase):
"shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes, "rng_factory": rng_factory}
for dtype in all_dtypes
for shape, idxs, axes in [
[(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
[(3, 4, 5), (onp.array([-1, -2]),), (0,)],
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
[(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
[(3, 4, 5), (np.array([-1, -2]),), (0,)],
[(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
[(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
]
for rng_factory in [jtu.rand_default]))
def testIndexTake(self, shape, dtype, idxs, axes, rng_factory):
@ -1567,16 +1567,16 @@ class LaxTest(jtu.JaxTestCase):
"rng_idx_factory": rng_idx_factory}
for dtype in all_dtypes
for shape, idxs, dnums, slice_sizes in [
((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
((10, 5), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
@ -1600,13 +1600,13 @@ class LaxTest(jtu.JaxTestCase):
"rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
for dtype in float_dtypes
for arg_shape, idxs, update_shape, dnums in [
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
@ -1631,13 +1631,13 @@ class LaxTest(jtu.JaxTestCase):
"rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
for dtype in float_dtypes
for arg_shape, idxs, update_shape, dnums in [
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
@ -1662,13 +1662,13 @@ class LaxTest(jtu.JaxTestCase):
"rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
for dtype in float_dtypes
for arg_shape, idxs, update_shape, dnums in [
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
@ -1693,13 +1693,13 @@ class LaxTest(jtu.JaxTestCase):
"rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
for dtype in float_dtypes
for arg_shape, idxs, update_shape, dnums in [
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
@ -1724,26 +1724,26 @@ class LaxTest(jtu.JaxTestCase):
api.jit(f)(1.) # doesn't crash

def testReshapeWithUnusualShapes(self):
ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
self.assertAllClose(ans, onp.ones((3, 1), onp.float32))
ans = lax.reshape(np.ones((3,), np.float32), (lax.add(1, 2), 1))
self.assertAllClose(ans, np.ones((3, 1), np.float32))

self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*",
lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)))
lambda: lax.reshape(np.ones(3,), (np.array([3, 1]),)))

self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*",
lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)))
lambda: lax.reshape(np.ones(3,), (1.5, 2.0)))

@jtu.skip_on_devices("tpu") # S16 not supported on TPU
def testDynamicSliceTypeErrors(self):
self.assertRaisesRegex(
TypeError,
"index arguments to dynamic_slice must be integers of the same type",
lambda: lax.dynamic_slice(onp.ones((3, 4), dtype=onp.float32),
(onp.int32(1), onp.int16(2)), (2, 2)))
lambda: lax.dynamic_slice(np.ones((3, 4), dtype=np.float32),
(np.int32(1), np.int16(2)), (2, 2)))

@jtu.skip_on_devices("tpu") # S16 not supported on TPU
def testDynamicUpdateSliceTypeErrors(self):
@ -1751,9 +1751,9 @@ class LaxTest(jtu.JaxTestCase):
TypeError,
"index arguments to dynamic_update_slice must be integers of the same "
"type",
lambda: lax.dynamic_update_slice(onp.ones((3, 4), dtype=onp.float32),
onp.zeros((2, 2), dtype=onp.float32),
(onp.int32(1), onp.int16(2))))
lambda: lax.dynamic_update_slice(np.ones((3, 4), dtype=np.float32),
np.zeros((2, 2), dtype=np.float32),
(np.int32(1), np.int16(2))))

def test_tie_in_error(self):
with core.skipping_checks():
@ -1769,16 +1769,16 @@ class LaxTest(jtu.JaxTestCase):

def test_reduction_with_repeated_axes_error(self):
with self.assertRaisesRegex(ValueError, "duplicate value in 'axes' .*"):
lax.reduce(onp.arange(3), 0, lax.add, (0, 0))
lax.reduce(np.arange(3), 0, lax.add, (0, 0))


class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
# check casting to ndarray works
asarray_result = onp.asarray(make_const())
asarray_result = np.asarray(make_const())

# check passing as an argument works (should hit constant handler)
zero = onp.array(0, expected.dtype)
zero = np.array(0, expected.dtype)
argument_result = lax.add(zero, make_const())

# check looping into a compiled computation works
@ -1799,10 +1799,10 @@ class LazyConstantTest(jtu.JaxTestCase):
"shape": shape, "dtype": dtype, "fill_value": fill_value}
for dtype in itertools.chain(default_dtypes, [None])
for shape in [(), (3,), (2, 3), (2, 3, 4), (1001, 1001)]
for fill_value in [0, 1, onp.pi]))
for fill_value in [0, 1, np.pi]))
def testFilledConstant(self, shape, fill_value, dtype):
make_const = lambda: lax.full(shape, fill_value, dtype)
expected = onp.full(shape, fill_value,
expected = np.full(shape, fill_value,
dtype or dtypes.result_type(fill_value))
self._Check(make_const, expected)

@ -1819,10 +1819,10 @@ class LazyConstantTest(jtu.JaxTestCase):
def testIotaConstant(self, dtype, shape, dimension):
make_const = lambda: lax.broadcasted_iota(dtype, shape, dimension)

arr = onp.arange(shape[dimension], dtype=dtypes.canonicalize_dtype(dtype))
arr = np.arange(shape[dimension], dtype=dtypes.canonicalize_dtype(dtype))
singleton_shape = [1] * len(shape)
singleton_shape[dimension] = shape[dimension]
expected = onp.broadcast_to(arr.reshape(singleton_shape), shape)
expected = np.broadcast_to(arr.reshape(singleton_shape), shape)

self._Check(make_const, expected)

@ -1845,13 +1845,13 @@ class LazyConstantTest(jtu.JaxTestCase):
def testDeltaConstant(self, dtype, shape, axes):
make_const = lambda: lax._delta(dtype, shape, axes)
# don't check the asarray case, just assume it's right
expected = onp.asarray(make_const())
expected = np.asarray(make_const())
self._Check(make_const, expected)

def testBroadcastInDim(self):
arr = lax.full((2, 1), 1.) + 1.
arr_onp = onp.full((2, 1), 1.) + 1.
expected = lax_reference.broadcast_in_dim(arr_onp, (2, 1, 3), (0, 2))
arr_np = np.full((2, 1), 1.) + 1.
expected = lax_reference.broadcast_in_dim(arr_np, (2, 1, 3), (0, 2))
make_const = lambda: lax.broadcast_in_dim(arr, (2, 1, 3), (0, 2))
self._Check(make_const, expected)

@ -21,7 +21,7 @@ from unittest import SkipTest
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as onp
|
||||
import numpy as np
|
||||
|
||||
from jax import api
|
||||
from jax import dtypes
|
||||
@ -75,7 +75,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
|
||||
args_slice = args_slicer(args, bdims)
|
||||
ans = api.vmap(op, bdims)(*args)
|
||||
expected = onp.stack([op(*args_slice(i)) for i in range(bdim_size)])
|
||||
expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
|
||||
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
@ -124,7 +124,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for strides in all_strides
|
||||
for rhs_dil in rhs_dils
|
||||
for lhs_dil in lhs_dils
|
||||
for dtype in [onp.float32]
|
||||
for dtype in [np.float32]
|
||||
for padding in all_pads
|
||||
for dim_nums, perms in [
|
||||
(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
|
||||
@ -146,8 +146,8 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
# permute shapes to match dim_spec, scale by feature_group_count
|
||||
lhs_perm, rhs_perm = perms
|
||||
lhs_shape = list(onp.take(lhs_shape, lhs_perm))
|
||||
rhs_shape = list(onp.take(rhs_shape, rhs_perm))
|
||||
lhs_shape = list(np.take(lhs_shape, lhs_perm))
|
||||
rhs_shape = list(np.take(rhs_shape, rhs_perm))
|
||||
|
||||
conv = partial(lax.conv_general_dilated, window_strides=strides,
|
||||
padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
|
||||
@ -164,7 +164,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
"shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
|
||||
"bdims": bdims, "rng_factory": rng_factory}
|
||||
for from_dtype, to_dtype in itertools.product(
|
||||
[onp.float32, onp.int32, "float32", "int32"], repeat=2)
|
||||
[np.float32, np.int32, "float32", "int32"], repeat=2)
|
||||
for shape in [(2, 3)]
|
||||
for bdims in all_bdims(shape)
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@ -179,7 +179,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
"shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
|
||||
"bdims": bdims, "rng_factory": rng_factory}
|
||||
for from_dtype, to_dtype in itertools.product(
|
||||
[onp.float32, onp.int32, "float32", "int32"], repeat=2)
|
||||
[np.float32, np.int32, "float32", "int32"], repeat=2)
|
||||
for shape in [(2, 3)]
|
||||
for bdims in all_bdims(shape)
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@ -226,7 +226,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
op = partial(lax.dot, precision=lax.Precision.HIGHEST)
|
||||
self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
|
||||
rng, rtol={onp.float16: 5e-2, onp.float64: 5e-14})
|
||||
rng, rtol={np.float16: 5e-2, np.float64: 5e-14})
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -280,7 +280,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
|
||||
shape, onp.dtype(dtype).name, broadcast_sizes, bdims),
|
||||
shape, np.dtype(dtype).name, broadcast_sizes, bdims),
|
||||
"shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
|
||||
"bdims": bdims, "rng_factory": rng_factory}
|
||||
for shape in [(), (2, 3)]
|
||||
@ -317,7 +317,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_dimensions={}_bdims={}".format(
|
||||
jtu.format_shape_dtype_string(arg_shape, onp.float32),
|
||||
jtu.format_shape_dtype_string(arg_shape, np.float32),
|
||||
dimensions, bdims),
|
||||
"arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims,
|
||||
"rng_factory": rng_factory}
|
||||
@ -334,7 +334,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for bdims in all_bdims(arg_shape)
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory):
|
||||
dtype = onp.float32
|
||||
dtype = np.float32
|
||||
rng = rng_factory(self.rng())
|
||||
op = lambda x: lax.squeeze(x, dimensions)
|
||||
self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)
|
||||
@ -373,12 +373,12 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for pads in [[(1, 2, 1), (0, 1, 0)]]))
|
||||
def testPad(self, shape, dtype, pads, bdims, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
|
||||
fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
|
||||
self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_predshape={}_argshapes={}_bdims={}".format(
|
||||
jtu.format_shape_dtype_string(pred_shape, onp.bool_),
|
||||
jtu.format_shape_dtype_string(pred_shape, np.bool_),
|
||||
jtu.format_shape_dtype_string(arg_shape, arg_dtype),
|
||||
bdims),
|
||||
"pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
|
||||
@ -392,7 +392,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
rng = rng_factory(self.rng())
|
||||
op = lambda c, x, y: lax.select(c < 0, x, y)
|
||||
self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,),
|
||||
(onp.bool_, arg_dtype, arg_dtype), rng)
|
||||
(np.bool_, arg_dtype, arg_dtype), rng)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -448,16 +448,16 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
(0, lax.add, default_dtypes),
|
||||
(1, lax.mul, default_dtypes),
|
||||
(0, lax.max, all_dtypes), # non-monoidal
|
||||
(-onp.inf, lax.max, float_dtypes),
|
||||
(dtypes.iinfo(onp.int32).min, lax.max, [onp.int32]),
|
||||
(dtypes.iinfo(onp.int64).min, lax.max, [onp.int64]),
|
||||
(dtypes.iinfo(onp.uint32).min, lax.max, [onp.uint32]),
|
||||
(dtypes.iinfo(onp.uint64).min, lax.max, [onp.uint64]),
|
||||
(onp.inf, lax.min, float_dtypes),
|
||||
(dtypes.iinfo(onp.int32).max, lax.min, [onp.int32]),
|
||||
(dtypes.iinfo(onp.int64).max, lax.min, [onp.int64]),
|
||||
(dtypes.iinfo(onp.uint32).max, lax.min, [onp.uint32]),
|
||||
(dtypes.iinfo(onp.uint64).max, lax.min, [onp.uint64]),
|
||||
(-np.inf, lax.max, float_dtypes),
|
||||
(dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
|
||||
(dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
|
||||
(dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
|
||||
(dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
|
||||
(np.inf, lax.min, float_dtypes),
|
||||
(dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
|
||||
(dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
|
||||
(dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
|
||||
(dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
|
||||
]
|
||||
for dtype in dtypes
|
||||
for shape, dims in [
|
||||
@ -468,7 +468,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def testReduce(self, op, init_val, shape, dtype, dims, bdims, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
init_val = onp.asarray(init_val, dtype=dtype)
|
||||
init_val = np.asarray(init_val, dtype=dtype)
|
||||
fun = lambda operand: lax.reduce(operand, init_val, op, dims)
|
||||
self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
|
||||
|
||||
@ -485,25 +485,25 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for bdims in all_bdims(shape)))
|
||||
def testArgminmax(self, op, shape, dtype, dim, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
fun = lambda operand: op(operand, dim, onp.int32)
|
||||
fun = lambda operand: op(operand, dim, np.int32)
|
||||
self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_dtype={}_padding={}"
|
||||
.format(op.__name__, onp.dtype(dtype).name, padding),
|
||||
.format(op.__name__, np.dtype(dtype).name, padding),
|
||||
"op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
|
||||
"rng_factory": rng_factory}
|
||||
for init_val, op, dtypes in [
|
||||
(0, lax.add, [onp.float32]),
|
||||
(-onp.inf, lax.max, [onp.float32]),
|
||||
(onp.inf, lax.min, [onp.float32]),
|
||||
(0, lax.add, [np.float32]),
|
||||
(-np.inf, lax.max, [np.float32]),
|
||||
(np.inf, lax.min, [np.float32]),
|
||||
]
|
||||
for dtype in dtypes
|
||||
for padding in ["VALID", "SAME"]
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def testReduceWindow(self, op, init_val, dtype, padding, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
init_val = onp.asarray(init_val, dtype=dtype)
|
||||
init_val = np.asarray(init_val, dtype=dtype)
|
||||
|
||||
all_configs = itertools.chain(
|
||||
itertools.product(
|
||||
@ -528,15 +528,15 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
"op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
|
||||
"axis": axis, "rng_factory": rng_factory}
|
||||
for op, types in [
|
||||
(lax.cumsum, [onp.float32, onp.float64]),
|
||||
(lax.cumprod, [onp.float32, onp.float64]),
|
||||
(lax.cumsum, [np.float32, np.float64]),
|
||||
(lax.cumprod, [np.float32, np.float64]),
|
||||
]
|
||||
for dtype in types
|
||||
for shape in [[10], [3, 4, 5]]
|
||||
for axis in range(len(shape))
|
||||
for bdims in all_bdims(shape)
|
||||
for rng_factory in [
|
||||
jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
|
||||
jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
|
||||
else jtu.rand_small]))
|
||||
def testCumulativeReduce(self, op, shape, dtype, axis, bdims, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
@ -544,7 +544,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
rng)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}_padding={}".format(onp.dtype(dtype).name,
|
||||
{"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name,
|
||||
padding),
|
||||
"dtype": dtype, "padding": padding, "rng_factory": rng_factory}
|
||||
for dtype in float_dtypes
|
||||
@ -589,7 +589,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
axes = range(ndims - fft_ndims, ndims)
|
||||
fft_lengths = [shape[axis] for axis in axes]
|
||||
op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
|
||||
self._CheckBatching(op, 5, bdims, [shape], [onp.complex64], rng)
|
||||
self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}"
|
||||
@ -599,16 +599,16 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
"slice_sizes": slice_sizes, "bdims": bdims}
|
||||
for dtype in all_dtypes
|
||||
for shape, idxs, dnums, slice_sizes in [
|
||||
((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
|
||||
((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1,)),
|
||||
((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
|
||||
((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
||||
(2,)),
|
||||
((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
|
||||
((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1, 3)),
|
||||
((10, 5), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
|
||||
((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
|
||||
(1, 3)),
|
||||
]
|
||||
@ -626,13 +626,13 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
"update_shape": update_shape, "dnums": dnums, "bdims": bdims}
|
||||
for dtype in float_dtypes
|
||||
for arg_shape, idxs, update_shape, dnums in [
|
||||
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]
|
||||
@ -641,10 +641,10 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
fun = partial(lax.scatter_add, dimension_numbers=dnums)
|
||||
self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape],
|
||||
[dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
|
||||
rtol={onp.float16: 5e-3})
|
||||
rtol={np.float16: 5e-3})
|
||||
|
||||
def testShapeUsesBuiltinInt(self):
|
||||
x = lax.iota(onp.int32, 3) + 1
|
||||
x = lax.iota(np.int32, 3) + 1
|
||||
self.assertIsInstance(x.shape[0], int) # not np.int64
|
||||
|
||||
def testBroadcastShapesReturnsPythonInts(self):
|
||||
@ -677,7 +677,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}"
|
||||
.format(jtu.format_shape_dtype_string(shape, onp.float32), dimension,
|
||||
.format(jtu.format_shape_dtype_string(shape, np.float32), dimension,
|
||||
arity, bdims, is_stable),
|
||||
"shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims,
|
||||
"is_stable": is_stable}
|
||||
@ -690,7 +690,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
if arity == 1:
|
||||
fun = partial(lax.sort, dimension=dimension)
|
||||
self._CheckBatching(fun, 5, bdims, (shape,) * arity, (onp.float32,) * arity,
|
||||
self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity,
|
||||
rng)
|
||||
else:
|
||||
for i in range(arity):
|
||||
@ -698,7 +698,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
dimension=dimension,
|
||||
is_stable=is_stable)[i]
|
||||
self._CheckBatching(fun, 5, bdims, (shape,) * arity,
|
||||
(onp.float32,) * arity, rng)
|
||||
(np.float32,) * arity, rng)
|
||||
|
||||
|
||||
# TODO Concatenate
|
||||
|
@ -862,7 +862,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
def testLstsq(self, lhs_shape, rhs_shape, dtype, lowrank, rcond, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
_skip_if_unsupported_type(dtype)
|
||||
onp_fun = partial(np.linalg.lstsq, rcond=rcond)
|
||||
np_fun = partial(np.linalg.lstsq, rcond=rcond)
|
||||
jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
|
||||
jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
|
||||
tol = {np.float32: 1e-6, np.float64: 1e-12,
|
||||
@ -873,7 +873,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
lhs[:, -1] = lhs[:, :-1].mean(1)
|
||||
return [lhs, rng(rhs_shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(onp_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun_numpy_resid, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
# Disabled because grad is flaky for low-rank inputs.
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as onp
|
||||
import numpy as np
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -40,17 +40,17 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
|
||||
# the bounds of the original array.
|
||||
# https://github.com/scipy/scipy/issues/2640
|
||||
assert order <= 1
|
||||
padding = [(max(-onp.floor(c.min()).astype(int) + 1, 0),
|
||||
max(onp.ceil(c.max()).astype(int) + 1 - size, 0))
|
||||
padding = [(max(-np.floor(c.min()).astype(int) + 1, 0),
|
||||
max(np.ceil(c.max()).astype(int) + 1 - size, 0))
|
||||
for c, size in zip(coordinates, input.shape)]
|
||||
shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)]
|
||||
pad_mode = {
|
||||
'nearest': 'edge', 'mirror': 'reflect', 'reflect': 'symmetric'
|
||||
}.get(mode, mode)
|
||||
if mode == 'constant':
|
||||
padded = onp.pad(input, padding, mode=pad_mode, constant_values=cval)
|
||||
padded = np.pad(input, padding, mode=pad_mode, constant_values=cval)
|
||||
else:
|
||||
padded = onp.pad(input, padding, mode=pad_mode)
|
||||
padded = np.pad(input, padding, mode=pad_mode)
|
||||
result = osp_ndimage.map_coordinates(
|
||||
padded, shifted_coords, order=order, mode=mode, cval=cval)
|
||||
return result
|
||||
@ -83,7 +83,7 @@ class NdimageTest(jtu.JaxTestCase):
|
||||
mode, cval, impl, round_, rng_factory):
|
||||
|
||||
def args_maker():
|
||||
x = onp.arange(onp.prod(shape), dtype=dtype).reshape(shape)
|
||||
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
||||
coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape]
|
||||
if round_:
|
||||
coords = [c.round().astype(int) for c in coords]
|
||||
@ -103,8 +103,8 @@ class NdimageTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0)
|
||||
|
||||
def testMapCoordinatesErrors(self):
|
||||
x = onp.arange(5.0)
|
||||
c = [onp.linspace(0, 5, num=3)]
|
||||
x = np.arange(5.0)
|
||||
c = [np.linspace(0, 5, num=3)]
|
||||
with self.assertRaisesRegex(NotImplementedError, 'requires order<=1'):
|
||||
lsp_ndimage.map_coordinates(x, c, order=2)
|
||||
with self.assertRaisesRegex(
|
||||
@ -118,13 +118,13 @@ class NdimageTest(jtu.JaxTestCase):
|
||||
lsp_ndimage.map_coordinates.__doc__)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_order={}".format(onp.dtype(dtype), order),
|
||||
{"testcase_name": "_{}_order={}".format(np.dtype(dtype), order),
|
||||
"dtype": dtype, "order": order}
|
||||
for dtype in float_dtypes + int_dtypes
|
||||
for order in [0, 1]))
|
||||
def testMapCoordinatesRoundHalf(self, dtype, order):
|
||||
x = onp.arange(-3, 3, dtype=dtype)
|
||||
c = onp.array([[.5, 1.5, 2.5, 3.5]])
|
||||
x = np.arange(-3, 3, dtype=dtype)
|
||||
c = np.array([[.5, 1.5, 2.5, 3.5]])
|
||||
def args_maker():
|
||||
return x, c
|
||||
|
||||
@ -136,9 +136,9 @@ class NdimageTest(jtu.JaxTestCase):
|
||||
# regression test for https://github.com/google/jax/issues/3024
|
||||
|
||||
def loss(delta):
|
||||
x = onp.arange(100.0)
|
||||
x = np.arange(100.0)
|
||||
border = 10
|
||||
indices = onp.arange(x.size) + delta
|
||||
indices = np.arange(x.size) + delta
|
||||
# linear interpolation of the linear function y=x should be exact
|
||||
shifted = lsp_ndimage.map_coordinates(x, [indices], order=1)
|
||||
return ((x - shifted) ** 2)[border:-border].mean()
|
||||
|
@ -17,7 +17,7 @@ from functools import partial

from absl.testing import absltest, parameterized

import numpy as onp
import numpy as np

from jax import lax
from jax import test_util as jtu
@ -56,7 +56,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-8}
    tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-8}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker)

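For reference, the comparison pattern in this hunk can be reproduced outside the test harness; a hedged sketch (inputs are illustrative, and it assumes jax.scipy.signal.convolve accepts the same precision keyword that jsp_fun receives above):

import numpy as np
import scipy.signal as osp_signal
from jax import lax
import jax.scipy.signal as jsp_signal

x = np.linspace(0., 1., 8).astype(np.float32)
y = np.ones(3, np.float32)
ref = osp_signal.convolve(x, y, mode='same')
out = jsp_signal.convolve(x, y, mode='same', precision=lax.Precision.HIGHEST)
print(np.allclose(ref, out, atol=1e-2))  # True within the float32 tolerance above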
@ -79,7 +79,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-14}
    tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-14}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker)

@ -97,7 +97,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
    args_maker = lambda: [rng(shape, dtype)]
    osp_fun = partial(osp_signal.detrend, axis=axis, type=type, bp=bp)
    jsp_fun = partial(jsp_signal.detrend, axis=axis, type=type, bp=bp)
    tol = {onp.float32: 1e-5, onp.float64: 1e-12}
    tol = {np.float32: 1e-5, np.float64: 1e-12}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)


@ -17,7 +17,7 @@ import itertools

from absl.testing import absltest, parameterized

import numpy as onp
import numpy as np
import scipy.stats as osp_stats

from jax import test_util as jtu
@ -50,15 +50,15 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = onp.floor(k)
      k = np.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = onp.clip(onp.abs(mu), a_min=0.1, a_max=None)
      loc = onp.floor(loc)
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      loc = np.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, rtol={onp.float64: 1e-14})
    self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testPoissonPmf(self, rng_factory, shapes, dtypes):
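The clipping idiom that recurs in these args_maker helpers deserves a one-line illustration (the values below are made up):

import numpy as np

mu = np.array([-2.0, 0.0, 0.05, 3.0])
# abs() folds negatives into positives; clip() then enforces a floor of 0.1,
# so the Poisson rate parameter is always strictly positive.
print(np.clip(np.abs(mu), a_min=0.1, a_max=None))  # [2.  0.1 0.1 3. ]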
@ -68,10 +68,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = onp.floor(k)
      k = np.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = onp.clip(onp.abs(mu), a_min=0.1, a_max=None)
      loc = onp.floor(loc)
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      loc = np.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
@ -86,9 +86,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = onp.floor(x)
      x = np.floor(x)
      p = expit(logit)
      loc = onp.floor(loc)
      loc = np.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
@ -103,9 +103,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = onp.floor(x)
      x = np.floor(x)
      p = expit(logit)
      loc = onp.floor(loc)
      loc = np.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
@ -125,7 +125,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker,
                          rtol={onp.float32: 2e-3, onp.float64: 1e-4})
                          rtol={np.float32: 2e-3, np.float64: 1e-4})

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testCauchyLogPdf(self, rng_factory, shapes, dtypes):
@ -136,7 +136,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
@ -153,7 +153,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):

    def args_maker():
      x, alpha = map(rng, shapes, dtypes)
      x = x / onp.sum(x, axis=-1, keepdims=True)
      x = x / np.sum(x, axis=-1, keepdims=True)
      return [x, alpha]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
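The normalization in the Dirichlet args_maker above puts each sample on the probability simplex; a quick hedged illustration (random data, not the test's rng):

import numpy as np

x = np.abs(np.random.randn(4, 3)) + 0.1    # strictly positive samples
x = x / np.sum(x, axis=-1, keepdims=True)  # each row now sums to 1
print(np.allclose(x.sum(axis=-1), 1.0))    # True: valid dirichlet.logpdf input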
@ -197,7 +197,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      scale = np.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
@ -213,11 +213,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      scale = np.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol={onp.float32: 1e-5, onp.float64: 1e-6})
                            tol={np.float32: 1e-5, np.float64: 1e-6})
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1, jtu.rand_default)
@ -281,7 +281,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
@ -298,7 +298,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
@ -315,7 +315,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
@ -332,9 +332,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    def args_maker():
      q, loc, scale = map(rng, shapes, dtypes)
      # ensure probability is between 0 and 1:
      q = onp.clip(onp.abs(q / 3), a_min=None, a_max=1)
      q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [q, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
@ -365,13 +365,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker,
                          rtol={onp.float64: 1e-14}, atol={onp.float64: 1e-14})
                          rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})


  @genNamedParametersNArgs(3, jtu.rand_default)
@ -382,7 +382,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, onp.abs(scale)]
      return [x, loc, np.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
@ -390,8 +390,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):

  def testIssue972(self):
    self.assertAllClose(
      onp.ones((4,), onp.float32),
      lsp_stats.norm.cdf(onp.full((4,), onp.inf, onp.float32)),
      np.ones((4,), np.float32),
      lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)),
      check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
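testIssue972 pins down a limit that is easy to check against SciPy directly; a standalone sketch (not part of the diff):

import numpy as np
import scipy.stats

# The normal CDF saturates at 1 as its argument goes to +inf, which is the
# behavior the test asserts for jax.scipy.stats.norm.cdf.
print(scipy.stats.norm.cdf(np.full((4,), np.inf, np.float32)))  # [1. 1. 1. 1.]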
@ -430,8 +430,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
          # [(5, 3, 2), (3, 2,), (2, 2)],
      ]
      for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
      if (mean_shape is not None or mean_dtype == onp.float32)
      and (cov_shape is not None or cov_dtype == onp.float32)
      if (mean_shape is not None or mean_dtype == np.float32)
      and (cov_shape is not None or cov_dtype == np.float32)
      for rng_factory in [jtu.rand_default]))
  def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                   mean_dtype, cov_shape, cov_dtype, rng_factory):
@ -446,7 +446,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
      else:
        factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
        factor = rng(factor_shape, cov_dtype)
        args.append(onp.matmul(factor, onp.swapaxes(factor, -1, -2)))
        args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
      return args

    self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
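The factor trick in the last hunk is the standard way to draw a random covariance; a small hedged sketch (shapes chosen for illustration only):

import numpy as np

factor = np.random.randn(3, 6)
# A matrix times its own transpose is symmetric positive semidefinite,
# so it is always a valid covariance for multivariate_normal.logpdf.
cov = np.matmul(factor, np.swapaxes(factor, -1, -2))
print(np.allclose(cov, cov.T))                    # True: symmetric
print(np.all(np.linalg.eigvalsh(cov) >= -1e-12))  # True: PSD up to round-off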