# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from array import array as make_python_array
import collections
from collections.abc import Iterator
import copy
from functools import partial, wraps
import inspect
import io
import itertools
import math
import platform
from typing import Union, cast
import unittest
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
try:
import numpy_dispatch
except ImportError:
numpy_dispatch = None
import jax
import jax.ops
from jax import lax
from jax import numpy as jnp
from jax.sharding import SingleDeviceSharding
from jax.test_util import check_grads
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace
config.parse_flags_with_absl()
numpy_version = jtu.numpy_version()
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0),]
broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes
NO_VALUE = object()
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
# uint64 is problematic because with any signed int type it promotes to float:
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
def _bitcast_uint4_to_uint8(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint4'
operand = operand.astype('uint8')
return operand[..., ::2] + (operand[..., 1::2] << 4)
def _bitcast_uint8_to_uint4(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint8'
result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4')
result[..., ::2] = (operand & 0b00001111).astype('uint4')
result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4')
return result
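
# A minimal round-trip sketch for the two bitcast helpers above (illustrative
# values, not executed as part of the test suite). Packing is little-endian:
# the low nibble of each packed byte comes from the even-indexed element.
#
#   nibbles = np.arange(4, dtype='uint8').astype('uint4')   # [0, 1, 2, 3]
#   packed = _bitcast_uint4_to_uint8(nibbles)               # [0x10, 0x32]
#   assert np.array_equal(_bitcast_uint8_to_uint4(packed), nibbles)
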
def np_view(arr, dtype):
# Implementation of np.ndarray.view() that works for int4/uint4
dtype = np.dtype(dtype)
nbits_in = dtypes.bit_width(arr.dtype)
nbits_out = dtypes.bit_width(dtype)
if nbits_in == 4:
arr = _bitcast_uint4_to_uint8(arr.view('uint4'))
if nbits_out == 4:
arr = _bitcast_uint8_to_uint4(arr.view('uint8'))
return arr.view(dtype)
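
# np_view exists because np.ndarray.view cannot reinterpret sub-byte dtypes;
# an illustrative call (values chosen to be representable in both types):
#
#   np_view(np.array([1, 2], dtype='int4'), 'uint4')   # -> uint4 array [1, 2]
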
def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
axis=None, **kwds):
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
result = np.unique(ar, return_index=return_index, return_inverse=return_inverse,
return_counts=return_counts, axis=axis, **kwds)
if jtu.numpy_version() >= (2, 0, 1) or np.ndim(ar) == 1 or not return_inverse:
return result
idx = 2 if return_index else 1
inverse_indices = result[idx]
if axis is None:
inverse_indices = inverse_indices.reshape(np.shape(ar))
elif jtu.numpy_version() == (2, 0, 0):
inverse_indices = inverse_indices.reshape(-1)
return (*result[:idx], inverse_indices, *result[idx + 1:])
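
# Illustration of the convention np_unique_backport normalizes to: for
# ar = np.array([[0, 1], [1, 0]]) with axis=None and return_inverse=True,
# numpy 1.x returns flat inverse_indices of shape (4,), whereas
# numpy >= 2.0.1 returns shape (2, 2); the wrapper always yields the latter.
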
def _indexer_with_default_outputs(indexer, use_defaults=True):
"""Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs"""
class Indexer:
@partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults)
def __getitem__(self, *args):
return indexer.__getitem__(*args)
return Indexer()
def _valid_dtypes_for_shape(shape, dtypes):
# Not all (shape, dtype) pairs are valid. In particular, Python scalars only
# have one type in each category (float, bool, etc.)
if shape is jtu.PYTHON_SCALAR_SHAPE:
return [t for t in dtypes if t in python_scalar_dtypes]
return dtypes
def _shape_and_dtypes(shapes, dtypes):
for shape in shapes:
for dtype in _valid_dtypes_for_shape(shape, dtypes):
yield (shape, dtype)
def _compatible_shapes(shape):
if np.ndim(shape) == 0 or shape in scalar_shapes:
return [shape]
return (shape[n:] for n in range(len(shape) + 1))
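
# For example, _compatible_shapes((2, 3)) yields (2, 3), (3,), and (), while
# scalar shape sentinels are passed through as a single-element list.
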
OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
"test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])
def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name=None, check_dtypes=True,
tolerance=None, inexact=False, kwargs=None):
test_name = test_name or name
return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name, check_dtypes, tolerance, inexact, kwargs)
JAX_ARGMINMAX_RECORDS = [
op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
op_record("nanargmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
]
def _shapes_are_broadcast_compatible(shapes):
try:
lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes))
except ValueError:
return False
else:
return True
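
# For example, _shapes_are_broadcast_compatible([(3,), (4, 1)]) is True (the
# shapes broadcast to (4, 3)), while [(3,), (4,)] is incompatible and gives
# False.
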
def _shapes_are_equal_length(shapes):
return all(len(shape) == len(shapes[0]) for shape in shapes[1:])
def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0.5) -> list[jax.Array]:
"""Generate multiple arrays with some overlapping values.
This is useful for tests of set-like operations.
"""
assert 0 <= overlap <= 1
sizes = [math.prod(jtu._dims_of_shape(shape)) for shape in shapes]
total_size = int(sum(sizes) * (1 - overlap)) + max(sizes) # non-strict upper-bound.
if unique:
vals = jtu.rand_unique_int(rng)((total_size,), 'int32')
else:
vals = jtu.rand_default(rng)((total_size,), 'int32')
offsets = [int(sum(sizes[:i]) * (1 - overlap)) for i in range(len(sizes))]
return [rng.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype)
for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)]
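
# Usage sketch (illustrative; assumes a np.random.RandomState as the rng):
#
#   a, b = arrays_with_overlapping_values(
#       np.random.RandomState(0), shapes=[(4,), (4,)],
#       dtypes=['int32', 'int32'], overlap=0.5)
#
# a and b are (4,)-shaped int32 arrays drawn from value pools that overlap by
# roughly 50%, which is what the set-operation tests below rely on.
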
def with_size_argument(fun):
@wraps(fun)
def wrapped(*args, size=None, fill_value=None, **kwargs):
result = fun(*args, **kwargs)
if size is None or size == len(result):
return result
elif size < len(result):
return result[:size]
else:
if fill_value is None:
fill_value = result.min() if result.size else 0
return np.pad(result, (0, size - len(result)), constant_values=fill_value)
return wrapped
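
# Sketch of the semantics this decorator adds (illustrative), mirroring the
# static `size`/`fill_value` arguments of the jnp set routines:
#
#   padded_union1d = with_size_argument(np.union1d)
#   padded_union1d([1, 2], [2, 3], size=5)                # [1, 2, 3, 1, 1]
#   padded_union1d([1, 2], [2, 3], size=5, fill_value=0)  # [1, 2, 3, 0, 0]
#   padded_union1d([1, 2], [2, 3], size=2)                # [1, 2]
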
class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""
def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
def f():
out = [rng(shape, dtype or jnp.float_)
for shape, dtype in zip(shapes, dtypes)]
if np_arrays:
return out
return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a
for a in out]
return f
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(-len(shape), len(shape)))],
dtype=all_dtypes,
)
def testUnstack(self, shape, axis, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
if jnp.asarray(x).ndim == 0:
with self.assertRaisesRegex(ValueError, "Unstack requires arrays with"):
jnp.unstack(x, axis=axis)
return
y = jnp.unstack(x, axis=axis)
if shape[axis] == 0:
self.assertEqual(y, ())
else:
self.assertArraysEqual(jnp.moveaxis(jnp.array(y), 0, axis), x)
@parameterized.parameters(
[dtype for dtype in [
jnp.bool,
jnp.uint4, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
jnp.int4, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
jnp.complex64, jnp.complex128]
if dtype == dtypes.canonicalize_dtype(dtype)])
def testDtypeWrappers(self, dtype):
arr = dtype(0)
self.assertIsInstance(arr, jax.Array)
self.assertEqual(arr.dtype, np.dtype(dtype))
self.assertArraysEqual(arr, 0, check_dtypes=False)
# No copy primitive is generated
jaxpr = jax.make_jaxpr(dtype)(0)
prims = [eqn.primitive for eqn in jaxpr.eqns]
self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated.
def testBoolDtypeAlias(self):
self.assertIs(jnp.bool, jnp.bool_)
@jtu.sample_product(
dtype=float_dtypes + [object],
allow_pickle=[True, False],
)
def testLoad(self, dtype, allow_pickle):
if dtype == object and not allow_pickle:
self.skipTest("dtype=object requires allow_pickle=True")
rng = jtu.rand_default(self.rng())
    arr = rng((10,), dtype)
with io.BytesIO() as f:
jnp.save(f, arr)
f.seek(0)
arr_out = jnp.load(f, allow_pickle=allow_pickle)
self.assertArraysEqual(arr, arr_out, allow_object_dtype=True)
def testArrayEqualExamples(self):
# examples from the array_equal() docstring.
self.assertTrue(jnp.array_equal([1, 2], [1, 2]))
self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2])))
self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3]))
self.assertFalse(jnp.array_equal([1, 2], [1, 4]))
a = np.array([1, np.nan])
self.assertFalse(jnp.array_equal(a, a))
self.assertTrue(jnp.array_equal(a, a, equal_nan=True))
a = np.array([1 + 1j])
b = a.copy()
a.real = np.nan
b.imag = np.nan
self.assertTrue(jnp.array_equal(a, b, equal_nan=True))
def testArrayEquivExamples(self):
# examples from the array_equiv() docstring.
self.assertTrue(jnp.array_equiv([1, 2], [1, 2]))
self.assertFalse(jnp.array_equiv([1, 2], [1, 3]))
with jax.numpy_rank_promotion('allow'):
self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]]))
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]]))
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]]))
def testArrayModule(self):
if numpy_dispatch is None:
raise SkipTest('requires https://github.com/seberg/numpy-dispatch')
jnp_array = jnp.array(1.0)
np_array = np.array(1.0)
module = numpy_dispatch.get_array_module(jnp_array)
self.assertIs(module, jnp)
module = numpy_dispatch.get_array_module(jnp_array, np_array)
self.assertIs(module, jnp)
def f(x):
module = numpy_dispatch.get_array_module(x)
self.assertIs(module, jnp)
return x
jax.jit(f)(jnp_array)
jax.grad(f)(jnp_array)
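
  # The duck-typing pattern exercised above, adapted from the NEP 37 example
  # that motivated this protocol (a sketch; assumes numpy_dispatch is
  # installed): library code asks its arguments for the right namespace
  # instead of hard-coding np.
  #
  #   def duckarray_stack(arrays):
  #     npx = numpy_dispatch.get_array_module(*arrays)
  #     arrays = [npx.asarray(arr) for arr in arrays]
  #     expanded = [arr[npx.newaxis, ...] for arr in arrays]
  #     return npx.concatenate(expanded, axis=0)
  #
  # With a jax.Array among the inputs, npx is jnp and the result is a
  # jax.Array; with only np.ndarrays, npx is np.
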
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(-len(shape), len(shape)))],
discont=[None, "pi", 2],
period=["2pi", "pi"],
dtype=default_dtypes,
)
def testUnwrap(self, shape, dtype, axis, discont, period):
special_vals = {"pi": np.pi, "2pi": 2 * np.pi}
period = special_vals.get(period, period)
discont = special_vals.get(discont, discont)
rng = jtu.rand_default(self.rng())
def np_fun(x):
dtype = None
if x.dtype == dtypes.bfloat16:
dtype = x.dtype
x = x.astype(np.float32)
out = np.unwrap(x, axis=axis, discont=discont, period=period)
return out if dtype is None else out.astype(dtype)
jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period)
if not dtypes.issubdtype(dtype, np.inexact):
# This case requires implicit dtype promotion
jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2})
self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1})
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(-len(shape), len(shape))) + [None]],
dtype=all_dtypes,
)
def testCountNonzero(self, shape, dtype, axis):
rng = jtu.rand_some_zero(self.rng())
np_fun = lambda x: np.count_nonzero(x, axis)
jnp_fun = lambda x: jnp.count_nonzero(x, axis)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testNonzero(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
for shape in nonempty_nonscalar_array_shapes
for fill_value in [None, -1, shape or (1,)]
],
dtype=all_dtypes,
size=[1, 5, 10],
)
def testNonzeroSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
def np_fun(x):
result = np.nonzero(x)
if size <= len(result[0]):
return tuple(arg[:size] for arg in result)
else:
fillvals = fill_value if np.ndim(fill_value) else len(result) * [fill_value or 0]
return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
for fval, arg in safe_zip(fillvals, result))
jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testFlatNonzero(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
np_fun = np.flatnonzero
jnp_fun = jnp.flatnonzero
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
# JIT compilation requires specifying the size statically:
jnp_fun = lambda x: jnp.flatnonzero(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=nonempty_nonscalar_array_shapes,
dtype=all_dtypes,
fill_value=[None, -1, 10, (-1,), (10,)],
size=[1, 5, 10],
)
def testFlatNonzeroSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
def np_fun(x):
result = np.flatnonzero(x)
if size <= len(result):
return result[:size]
else:
fill_val = fill_value or 0
return np.concatenate([result, np.full(size - len(result), fill_val, result.dtype)])
jnp_fun = lambda x: jnp.flatnonzero(x, size=size, fill_value=fill_value)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testArgWhere(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
# JIT compilation requires specifying a size statically. Full test of this
# behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
for shape in nonempty_nonscalar_array_shapes
for fill_value in [None, -1, shape or (1,)]
],
dtype=all_dtypes,
size=[1, 5, 10],
)
def testArgWhereSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
def np_fun(x):
result = np.argwhere(x)
if size <= len(result):
return result[:size]
else:
fillvals = fill_value if np.ndim(fill_value) else result.shape[-1] * [fill_value or 0]
return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack([np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
for fval, arg in safe_zip(fillvals, result.T)]).T
jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
shape=shape, dtype=dtype, axis=axis, rng_factory=rec.rng_factory)
for rec in JAX_ARGMINMAX_RECORDS
for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes)
for axis in range(-len(shape), len(shape))],
keepdims=[False, True],
)
def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis, keepdims):
rng = rng_factory(self.rng())
if dtype == np.complex128 and jtu.test_device_matches(["gpu"]):
raise unittest.SkipTest("complex128 reductions not supported on GPU")
if "nan" in np_op.__name__ and dtype == jnp.bfloat16:
raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays")
kwds = {"keepdims": True} if keepdims else {}
np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds))
jnp_fun = partial(jnp_op, axis=axis, **kwds)
args_maker = lambda: [rng(shape, dtype)]
try:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
except ValueError as e:
if str(e) == "All-NaN slice encountered":
self.skipTest("JAX doesn't support checking for all-NaN slices")
else:
raise
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(name=rec.name, np_op=getattr(np, rec.name),
jnp_op=getattr(jnp, rec.name))
for rec in JAX_ARGMINMAX_RECORDS],
)
def testArgMinMaxEmpty(self, name, np_op, jnp_op):
name = name[3:] if name.startswith("nan") else name
msg = f"attempt to get {name} of an empty sequence"
with self.assertRaisesRegex(ValueError, msg):
jnp_op(np.array([]))
with self.assertRaisesRegex(ValueError, msg):
jnp_op(np.zeros((2, 0)), axis=1)
np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0))
jnp_fun = partial(jnp_op, axis=0)
args_maker = lambda: [np.zeros((2, 0))]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes)
for lhs_shape, rhs_shape, axes in [
[(2,), (2,), (-1, -1, -1, None)], # scalar output
[(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors
[(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors
[(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting
[(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes
[(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting
[(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors
[(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting
[(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing
[(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)] # same as before
]],
lhs_dtype=number_dtypes,
rhs_dtype=number_dtypes,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
axisa, axisb, axisc, axis = axes
jnp_fun = lambda a, b: jnp.cross(a, b, axisa, axisb, axisc, axis)
# Note: 2D inputs to jnp.cross are deprecated in numpy 2.0.
@jtu.ignore_warning(category=DeprecationWarning,
message="Arrays of 2-dimensional vectors are deprecated.")
def np_fun(a, b):
a = a.astype(np.float32) if lhs_dtype == jnp.bfloat16 else a
b = b.astype(np.float32) if rhs_dtype == jnp.bfloat16 else b
out = np.cross(a, b, axisa, axisb, axisc, axis)
return out.astype(jnp.promote_types(lhs_dtype, rhs_dtype))
tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15}
tol = max(jtu.tolerance(lhs_dtype, tol_spec),
jtu.tolerance(rhs_dtype, tol_spec))
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
for lhs_shape, rhs_shape in [
((3, 3), ()),
((), (3, 3)),
((4, 5), (5,)),
((6,), (6, 4)),
((3, 4), (4, 5)),
((4, 3, 2), (2,)),
((2,), (3, 2, 4)),
((4, 3, 2), (2, 5)),
((5, 2), (3, 2, 4)),
((2, 3, 4), (5, 4, 1))]],
lhs_dtype=number_dtypes,
rhs_dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14,
np.complex128: 1e-14}
if (lhs_dtype in [np.float16, jnp.bfloat16] and
rhs_dtype in [np.float16, jnp.bfloat16]):
tol = 1e-2
def np_dot(x, y):
x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x
y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y
return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype))
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol)
self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol)
@jtu.sample_product(
lhs_dtype=number_dtypes,
rhs_dtype=number_dtypes,
)
@jax.numpy_dtype_promotion('standard')
def testMixedPrecisionDot(self, lhs_dtype, rhs_dtype):
# This test confirms that jnp.dot lowers to a single dot_general call,
# avoiding explicit type casting of inputs and outputs.
lhs = jax.ShapeDtypeStruct((5,), lhs_dtype)
rhs = jax.ShapeDtypeStruct((5,), rhs_dtype)
jaxpr = jax.make_jaxpr(jnp.dot)(lhs, rhs)
prims = [eqn.primitive for eqn in jaxpr.eqns]
self.assertIn(prims, [
[lax.dot_general_p],
[lax.dot_general_p, lax.convert_element_type_p]
])
@jtu.sample_product(
[dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape)
for name, lhs_shape, rhs_shape in [
("vector-vector", (3,), (3,)),
("matrix-vector", (3, 3), (3,)),
("vector-matrix", (3,), (3, 3)),
("matrix-matrix", (3, 3), (3, 3)),
("vector-tensor", (3,), (5, 3, 2)),
("tensor-vector", (5, 3, 2), (2,)),
("matrix-tensor", (5, 2), (3, 2, 4)),
("tensor-matrix", (5, 2, 3), (3, 2)),
("tensor-tensor", (5, 3, 4), (5, 4, 1)),
("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]],
lhs_dtype=number_dtypes,
rhs_dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
def testMatmul(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
rng = jtu.rand_default(self.rng())
def np_fun(x, y):
dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
return np.matmul(x, y).astype(dtype)
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
np.complex128: 1e-12, jnp.bfloat16: 1e-1}
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol)
self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)
@jtu.sample_product(
lhs_batch=broadcast_compatible_shapes,
rhs_batch=broadcast_compatible_shapes,
axis_size=[2, 4],
axis=range(-2, 2),
dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
# Construct vecdot-compatible shapes.
size = min(len(lhs_batch), len(rhs_batch))
axis = int(np.clip(axis, -size - 1, size))
if axis >= 0:
lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:])
rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:])
else:
laxis = axis + len(lhs_batch) + 1
lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
raxis = axis + len(rhs_batch) + 1
rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
@jtu.promote_like_jnp
def np_fn(x, y, axis=axis):
f = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.vecdot
return f(x, y, axis=axis).astype(x.dtype)
jnp_fn = partial(jnp.vecdot, axis=axis)
tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
@jtu.sample_product(
lhs_batch=broadcast_compatible_shapes,
rhs_batch=broadcast_compatible_shapes,
mat_size=[1, 2, 3],
vec_size=[2, 3, 4],
dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testMatvec(self, lhs_batch, rhs_batch, mat_size, vec_size, dtype):
rng = jtu.rand_default(self.rng())
lhs_shape = (*lhs_batch, mat_size, vec_size)
rhs_shape = (*rhs_batch, vec_size)
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
jnp_fn = jnp.matvec
@jtu.promote_like_jnp
def np_fn(x, y):
f = (np.vectorize(np.matmul, signature="(m,n),(n)->(m)")
if jtu.numpy_version() < (2, 2, 0) else np.matvec)
return f(x, y).astype(x.dtype)
tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
@jtu.sample_product(
lhs_batch=broadcast_compatible_shapes,
rhs_batch=broadcast_compatible_shapes,
mat_size=[1, 2, 3],
vec_size=[2, 3, 4],
dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testVecmat(self, lhs_batch, rhs_batch, mat_size, vec_size, dtype):
rng = jtu.rand_default(self.rng())
lhs_shape = (*lhs_batch, vec_size)
rhs_shape = (*rhs_batch, vec_size, mat_size)
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
jnp_fn = jnp.vecmat
@jtu.promote_like_jnp
def np_fn(x, y):
f = (np.vectorize(lambda x, y: np.matmul(np.conj(x), y),
signature="(m),(m,n)->(n)")
if jtu.numpy_version() < (2, 2, 0) else np.vecmat)
return f(x, y).astype(x.dtype)
tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes)
for lhs_shape, rhs_shape, axes in [
[(3,), (), 0],
[(2, 3, 4), (5, 6, 7), 0], # from issue #740
[(2, 3, 4), (3, 4, 5, 6), 2],
[(2, 3, 4), (5, 4, 3, 6), [1, 2]],
[(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
[(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
]],
lhs_dtype=number_dtypes,
rhs_dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
jnp_fun = lambda a, b: jnp.tensordot(a, b, axes)
def np_fun(a, b):
a = a if lhs_dtype != jnp.bfloat16 else a.astype(np.float32)
b = b if rhs_dtype != jnp.bfloat16 else b.astype(np.float32)
dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
return np.tensordot(a, b, axes).astype(dtype)
tol = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-12,
np.complex64: 1e-3, np.complex128: 1e-12}
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, tol=tol)
def testTensordotErrors(self):
a = self.rng().random((3, 2, 2))
b = self.rng().random((2,))
self.assertRaisesRegex(
TypeError, "Number of tensordot axes.*exceeds input ranks.*",
lambda: jnp.tensordot(a, b, axes=2))
self.assertRaisesRegex(
TypeError, "tensordot requires axes lists to have equal length.*",
lambda: jnp.tensordot(a, b, axes=([0], [0, 1])))
self.assertRaisesRegex(
TypeError, "tensordot requires both axes lists to be either ints, tuples or lists.*",
lambda: jnp.tensordot(a, b, axes=('bad', 'axes')))
self.assertRaisesRegex(
TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*",
lambda: jnp.tensordot(a, b, axes='badaxes'))
@jtu.sample_product(
element_shape=all_shapes,
test_shape=all_shapes,
dtype=default_dtypes,
invert=[False, True],
method=['auto', 'compare_all', 'binary_search', 'sort']
)
def testIsin(self, element_shape, test_shape, dtype, invert, method):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)]
jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert, method=method)
np_fun = lambda e, t: np.isin(e, t, invert=invert)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
shape1=all_shapes,
shape2=all_shapes,
overlap=[0.1, 0.5, 0.9],
)
def testSetdiff1d(self, shape1, shape2, dtype1, dtype2, overlap):
args_maker = partial(arrays_with_overlapping_values, self.rng(),
shapes=[shape1, shape2], dtypes=[dtype1, dtype2],
overlap=overlap)
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker)
@jtu.sample_product(
dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
shape1=all_shapes,
shape2=all_shapes,
size=[1, 5, 10],
fill_value=[None, -1],
overlap=[0.1, 0.5, 0.9],
)
def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap):
args_maker = partial(arrays_with_overlapping_values, self.rng(),
shapes=[shape1, shape2], dtypes=[dtype1, dtype2],
overlap=overlap)
def np_fun(arg1, arg2):
result = np.setdiff1d(arg1, arg2)
if size <= len(result):
return result[:size]
else:
return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0)
def jnp_fun(arg1, arg2):
return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value)
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
shape1=all_shapes,
shape2=all_shapes,
overlap=[0.1, 0.5, 0.9],
)
def testUnion1d(self, shape1, shape2, dtype1, dtype2, overlap):
args_maker = partial(arrays_with_overlapping_values, self.rng(),
shapes=[shape1, shape2], dtypes=[dtype1, dtype2],
overlap=overlap)
def np_fun(arg1, arg2):
dtype = jnp.promote_types(arg1.dtype, arg2.dtype)
return np.union1d(arg1, arg2).astype(dtype)
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker)
@jtu.sample_product(
dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
shape1=nonempty_shapes,
shape2=nonempty_shapes,
size=[1, 5, 10],
fill_value=[None, -1],
overlap=[0.1, 0.5, 0.9],
)
def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap):
args_maker = partial(arrays_with_overlapping_values, self.rng(),
shapes=[shape1, shape2], dtypes=[dtype1, dtype2],
overlap=overlap)
def np_fun(arg1, arg2):
dtype = jnp.promote_types(arg1.dtype, arg2.dtype)
result = np.union1d(arg1, arg2).astype(dtype)
fv = result.min() if fill_value is None else fill_value
if size <= len(result):
return result[:size]
else:
return np.concatenate([result, np.full(size - len(result), fv, result.dtype)])
def jnp_fun(arg1, arg2):
return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value)
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
shape1=all_shapes,
shape2=all_shapes,
assume_unique=[False, True],
size=[None, 2, 5],
fill_value=[None, 99],
overlap=[0.1, 0.5, 0.9],
)
def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, size, fill_value, overlap):
args_maker = partial(arrays_with_overlapping_values, self.rng(),
shapes=[shape1, shape2], dtypes=[dtype1, dtype2],
overlap=overlap)
jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique,
size=size, fill_value=fill_value)
def np_fun(ar1, ar2):
if assume_unique:
# numpy requires 1D inputs when assume_unique is True.
ar1 = np.ravel(ar1)
ar2 = np.ravel(ar2)
return with_size_argument(np.setxor1d)(ar1, ar2, assume_unique, size=size, fill_value=fill_value)
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@jtu.sample_product(
dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
shape1=nonempty_shapes,
shape2=nonempty_shapes,
assume_unique=[False, True],
return_indices=[False, True],
size=[None, 3, 5],
fill_value=[None, -1],
overlap=[0.1, 0.5, 0.9],
)
def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique,
return_indices, size, fill_value, overlap):
args_maker = partial(arrays_with_overlapping_values, self.rng(),
shapes=[shape1, shape2], dtypes=[dtype1, dtype2],
unique=assume_unique, overlap=overlap)
def jnp_fun(ar1, ar2):
return jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices,
size=size, fill_value=fill_value)
def np_fun(ar1, ar2):
result = np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices)
def correct_size(x, fill_value):
if size is None or size == len(x):
return x
elif size < len(x):
return x[:size]
else:
if fill_value is None:
fill_value = x.min()
return np.pad(x, (0, size - len(x)), constant_values=fill_value)
if return_indices:
return tuple(correct_size(r, f) for r, f in zip(result, [fill_value, ar1.size, ar2.size]))
else:
return correct_size(result, fill_value)
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, lhs_dtype=lhs_dtype,
rhs_shape=rhs_shape, rhs_dtype=rhs_dtype)
# TODO(phawkins): support integer dtypes too.
for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
if len(jtu._dims_of_shape(lhs_shape)) == 0
or len(jtu._dims_of_shape(rhs_shape)) == 0
or lhs_shape[-1] == rhs_shape[-1]],
)
@jax.default_matmul_precision("float32")
def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
def np_fun(lhs, rhs):
lhs = lhs if lhs_dtype != jnp.bfloat16 else lhs.astype(np.float32)
rhs = rhs if rhs_dtype != jnp.bfloat16 else rhs.astype(np.float32)
dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
return np.inner(lhs, rhs).astype(dtype)
jnp_fun = lambda lhs, rhs: jnp.inner(lhs, rhs)
tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13,
np.complex64: 1e-5}
tol = max(jtu.tolerance(lhs_dtype, tol_spec),
jtu.tolerance(rhs_dtype, tol_spec))
# TODO(phawkins): there are float32/float64 disagreements for some inputs.
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol)
@jtu.sample_product(
dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]],
shape=[shape for shape in one_dim_array_shapes if shape != (1,)],
deg=[1, 2, 3],
rcond=[None, -1, 10e-3, 10e-5, 10e-10],
full=[False, True],
w=[False, True],
cov=[False, True, "unscaled"],
)
@jax.default_matmul_precision("float32")
def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov):
rng = jtu.rand_default(self.rng())
tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5}
tol = jtu.tolerance(dtype, tol_spec)
_w = lambda a: abs(a) if w else None
args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)]
jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)
np_fun = jtu.ignore_warning(
message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov))
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol)
args = args_maker()
if not full:
try:
np_out = np_fun(*args)
except ValueError:
return # https://github.com/numpy/numpy/issues/22380
jnp_out = jnp_fun(*args)
self.assertAllClose(np_out, jnp_out, atol=tol, rtol=tol,
check_dtypes=False)
else:
# Don't compare the residuals because jnp.linalg.lstsq acts slightly
# differently to remain `jit`-compatible.
np_p, _, nrank, nsingular_values, nrcond = np_fun(*args)
jp_p, _, jrank, jsingular_values, jrcond = jnp_fun(*args)
self.assertAllClose(
(np_p, nrank, nsingular_values, nrcond),
(jp_p, jrank, jsingular_values, jrcond),
atol=tol, rtol=tol, check_dtypes=False)
@jtu.sample_product(
[dict(a_min=a_min, a_max=a_max)
for a_min, a_max in [(-1, None), (None, 1), (-0.9, 1),
(-np.ones(1), None),
(None, np.ones(1)),
(np.full(1, -0.9), np.ones(1))]
],
shape=all_shapes,
dtype=float_dtypes + int_dtypes + unsigned_dtypes,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testClipStaticBounds(self, shape, dtype, a_min, a_max):
if np.issubdtype(dtype, np.unsignedinteger):
a_min = None if a_min is None else abs(a_min)
a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=all_shapes,
dtype=default_dtypes + unsigned_dtypes,
)
def testClipNone(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
self.assertArraysEqual(jnp.clip(x), x)
def testClipComplexInputError(self):
rng = jtu.rand_default(self.rng())
x = rng((5,), dtype=jnp.complex64)
msg = ".*Complex values have no ordering and cannot be clipped.*"
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x)
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, max=x)
x = rng((5,), dtype=jnp.int32)
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, min=-1+5j)
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, max=jnp.array([-1+5j]))
def testClipDeprecatedArgs(self):
with self.assertDeprecationWarnsOrRaises("jax-numpy-clip-args",
"Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"):
jnp.clip(jnp.arange(4), a_min=2, a_max=3)
def testHypotComplexInputError(self):
rng = jtu.rand_default(self.rng())
x = rng((5,), dtype=jnp.complex64)
msg = "jnp.hypot is not well defined for complex-valued inputs.*"
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with self.assertRaisesRegex(ValueError, msg):
jnp.hypot(x, x)
y = jnp.ones_like(x)
with self.assertRaisesRegex(ValueError, msg):
jnp.hypot(x, y)
@jtu.sample_product(
[dict(shape=shape, dtype=dtype)
for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],
decimals=[0, 1, -2],
)
def testRoundStaticDecimals(self, shape, dtype, decimals):
rng = jtu.rand_default(self.rng())
if jnp.issubdtype(dtype, np.integer) and decimals < 0:
self.skipTest("Integer rounding with decimals < 0 not implemented")
np_fun = lambda x: np.round(x, decimals=decimals)
jnp_fun = lambda x: jnp.round(x, decimals=decimals)
args_maker = lambda: [rng(shape, dtype)]
tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2}
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=check_dtypes, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
atol=tol, rtol=tol)
@jtu.sample_product(jit=[False, True])
def testOperatorRound(self, jit):
jround = jax.jit(round, static_argnums=1) if jit else round
self.assertAllClose(round(np.float32(7.532), 1),
jround(jnp.float32(7.5), 1))
self.assertAllClose(round(np.float32(1.234), 2),
jround(jnp.float32(1.234), 2))
self.assertAllClose(round(np.float32(1.234)),
jround(jnp.float32(1.234)), check_dtypes=False)
self.assertAllClose(round(np.float32(7.532), 1),
jround(jnp.array(7.5, jnp.float32), 1))
self.assertAllClose(round(np.float32(1.234), 2),
jround(jnp.array(1.234, jnp.float32), 2))
self.assertAllClose(round(np.float32(1.234)),
jround(jnp.array(1.234, jnp.float32)),
check_dtypes=False)
def testRoundMethod(self):
# https://github.com/jax-ml/jax/issues/15190
(jnp.arange(3.) / 5.).round() # doesn't crash
@jtu.sample_product(shape=[(5,), (5, 2)])
def testOperatorReversed(self, shape):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, 'float32')]
np_fun = lambda x: np.array(list(reversed(x)))
jnp_fun = lambda x: jnp.array(list(reversed(x)))
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
[dict(mode=mode, shape=shape, dtype=dtype,
pad_width=pad_width, constant_values=constant_values)
for mode, shapes in [
('constant', all_shapes),
('wrap', nonempty_shapes),
('edge', nonempty_shapes),
]
for shape, dtype in _shape_and_dtypes(shapes, all_dtypes)
for constant_values in [
# None is used for modes other than 'constant'
None,
# constant
0, 1,
# (constant,)
(0,), (2.718,),
# ((before_const, after_const),)
(0, 2),
(-1, 3.14),
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i / 2, -3.14 * i) for i in range(len(shape))),
]
for pad_width in [
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
# ((before, after),)
((1, 2),), ((2, 0),),
# (before, after) (not in the docstring but works in numpy)
(2, 0), (0, 0),
# (pad,)
(1,), (2,),
# pad
0, 1,
]
if (pad_width != () and constant_values != () and
((mode == 'constant' and constant_values is not None) or
(mode != 'constant' and constant_values is None)))],
)
def testPad(self, shape, dtype, mode, pad_width, constant_values):
if np.issubdtype(dtype, np.unsignedinteger):
constant_values = jax.tree.map(abs, constant_values)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if constant_values is None:
np_fun = partial(np.pad, pad_width=pad_width, mode=mode)
jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode)
else:
np_fun = partial(np.pad, pad_width=pad_width, mode=mode,
constant_values=constant_values)
jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode,
constant_values=constant_values)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(mode=mode, shape=shape, dtype=dtype,
pad_width=pad_width, stat_length=stat_length)
for mode in ['maximum', 'minimum', 'mean', 'median']
for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes)
for pad_width in [
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
# ((before, after),)
((1, 2),), ((2, 0),),
# (before, after) (not in the docstring but works in numpy)
(2, 0), (0, 0),
# (pad,)
(1,), (2,),
# pad
0, 1,
]
for stat_length in [
None,
# ((before_1, after_1), ..., (before_N, after_N))
tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))),
# ((before, after),)
((1, 2),), ((2, 2),),
# (before, after) (not in the docstring but works in numpy)
(1, 1), (3, 4),
# (pad,)
(1,), (2,),
# pad
1, 2
]
if (pad_width != () and stat_length != () and
not (dtype in bool_dtypes and mode == 'mean'))],
)
def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length):
if mode == 'median' and np.issubdtype(dtype, np.complexfloating):
self.skipTest("median statistic is not supported for dtype=complex.")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = partial(np.pad, pad_width=pad_width, mode=mode, stat_length=stat_length)
jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, stat_length=stat_length)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, dtype=dtype,
pad_width=pad_width, reflect_type=reflect_type)
for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes)
for pad_width in [
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
# ((before, after),)
((1, 2),), ((2, 3),),
# (before, after) (not in the docstring but works in numpy)
(2, 1), (1, 2),
# (pad,)
(1,), (2,), (3,),
# pad
0, 5, 7, 10
]
for reflect_type in ['even', 'odd']
if (pad_width != () and
# following types lack precision when calculating odd values
(reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16]))],
mode=['symmetric', 'reflect']
)
def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = partial(np.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type)
jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE,
tol={np.float32: 1e-3, np.complex64: 1e-3})
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, dtype=dtype, pad_width=pad_width, end_values=end_values)
for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes)
for pad_width in [
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
# ((before, after),)
((1, 2),), ((2, 0),),
# (before, after) (not in the docstring but works in numpy)
(2, 0), (0, 0),
# (pad,)
(1,), (2,),
# pad
0, 1,
]
for end_values in [
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
# ((before, after),)
((1, 2),), ((2.0, 3.14),),
# (before, after) (not in the docstring but works in numpy)
(0, 0), (-8.0, 2.0),
# (end_values,)
(1,), (2,),
# end_values
0, 1, 100, 10.0, 3.5, 4.2, -5, -3
]
if (pad_width != () and end_values != () and
# following types lack precision
dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])],
)
def testPadLinearRamp(self, shape, dtype, pad_width, end_values):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = partial(np.pad, pad_width=pad_width, mode="linear_ramp",
end_values=end_values)
jnp_fun = partial(jnp.pad, pad_width=pad_width, mode="linear_ramp",
end_values=end_values)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker)
def testPadEmpty(self):
arr = np.arange(6).reshape(2, 3)
pad_width = ((2, 3), (3, 1))
np_res = np.pad(arr, pad_width=pad_width, mode="empty")
jnp_res = jnp.pad(arr, pad_width=pad_width, mode="empty")
np.testing.assert_equal(np_res.shape, jnp_res.shape)
np.testing.assert_equal(arr, np_res[2:-3, 3:-1])
np.testing.assert_equal(arr, jnp_res[2:-3, 3:-1])
np.testing.assert_equal(np_res[2:-3, 3:-1], jnp_res[2:-3, 3:-1])
def testPadKwargs(self):
modes = {
'constant': {'constant_values': 0},
'edge': {},
'linear_ramp': {'end_values': 0},
'maximum': {'stat_length': None},
'mean': {'stat_length': None},
'median': {'stat_length': None},
'minimum': {'stat_length': None},
'reflect': {'reflect_type': 'even'},
'symmetric': {'reflect_type': 'even'},
'wrap': {},
'empty': {}
}
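    # For each mode, collect every other mode's keyword arguments into
    # `not_allowed`; passing any of them should raise a ValueError.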
arr = jnp.array([1, 2, 3])
pad_width = 1
for mode in modes.keys():
allowed = modes[mode]
not_allowed = {}
for kwargs in modes.values():
if kwargs != allowed:
not_allowed.update(kwargs)
# Test if allowed keyword arguments pass
jnp.pad(arr, pad_width, mode, **allowed)
# Test if prohibited keyword arguments of other modes raise an error
match = f"unsupported keyword arguments for mode '{mode}'"
for key, value in not_allowed.items():
with self.assertRaisesRegex(ValueError, match):
jnp.pad(arr, pad_width, mode, **{key: value})
    # Test that an unsupported mode raises an error.
unsupported_modes = [1, None, "foo"]
for mode in unsupported_modes:
match = f"Unimplemented padding mode '{mode}' for np.pad."
with self.assertRaisesRegex(NotImplementedError, match):
jnp.pad(arr, pad_width, mode)
def testPadFunction(self):
def np_pad_with(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get('padder', 10)
vector[:pad_width[0]] = pad_value
vector[-pad_width[1]:] = pad_value
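    # NumPy's callable-pad protocol mutates `vector` in place; JAX arrays are
    # immutable, so the JAX variant below returns an updated copy instead.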
def jnp_pad_with(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get('padder', 10)
vector = vector.at[:pad_width[0]].set(pad_value)
vector = vector.at[-pad_width[1]:].set(pad_value)
return vector
arr = np.arange(6).reshape(2, 3)
np_res = np.pad(arr, 2, np_pad_with)
jnp_res = jnp.pad(arr, 2, jnp_pad_with)
np.testing.assert_equal(np_res, jnp_res)
arr = np.arange(24).reshape(2, 3, 4)
np_res = np.pad(arr, 1, np_pad_with, padder=100)
jnp_res = jnp.pad(arr, 1, jnp_pad_with, padder=100)
np.testing.assert_equal(np_res, jnp_res)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(arr.shape, arr.dtype)]
jnp_fun = partial(jnp.pad, pad_width=1, mode=jnp_pad_with)
self._CompileAndCheck(jnp_fun, args_maker)
def testPadWithNumpyPadWidth(self):
a = jnp.array([1, 2, 3, 4, 5])
f = jax.jit(
partial(
jnp.pad,
pad_width=np.asarray((2, 3)),
mode="constant",
constant_values=(4, 6)))
np.testing.assert_array_equal(
f(a),
np.pad(
a,
pad_width=np.asarray((2, 3)),
mode="constant",
constant_values=(4, 6)))
def testPadWeakType(self):
x = jnp.array(1.0)[None]
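    # jnp.array(1.0) is weakly typed (it comes from a Python scalar), and
    # padding should preserve that weak type for every mode.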
for mode in ['constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median',
'minimum', 'reflect', 'symmetric', 'wrap', 'empty']:
y = jnp.pad(x, 0, mode=mode)
self.assertTrue(dtypes.is_weakly_typed(y))
@jtu.sample_product(
[dict(shape=shape, dtype=dtype)
for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)],
reps=[(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)],
)
def testTile(self, shape, dtype, reps):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg: np.tile(arg, reps)
jnp_fun = lambda arg: jnp.tile(arg, reps)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testExtract(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)]
self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker)
@jtu.sample_product(shape=nonempty_array_shapes, dtype=all_dtypes)
def testExtractSize(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)]
def jnp_fun(condition, arr):
return jnp.extract(condition, arr, size=jnp.size(arr) - 1)
def np_fun(condition, arr):
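      # Emulate jnp.extract's static `size` argument with NumPy: produce a
      # fixed-length output, zero-filled where fewer elements match.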
size = jnp.size(arr) - 1
out = np.extract(condition, arr)
result = np.zeros(np.size(arr) - 1, dtype=dtype)
size = min(len(out), size)
result[:size] = out[:size]
return result
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(ncond=ncond, nfunc=nfunc)
for ncond in [1, 2, 3]
for nfunc in [ncond, ncond + 1]
],
shape=all_shapes,
dtype=all_dtypes)
def testPiecewise(self, shape, dtype, ncond, nfunc):
rng = jtu.rand_default(self.rng())
rng_bool = jtu.rand_int(self.rng(), 0, 2)
funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc]
args_maker = lambda: (rng(shape, dtype), [rng_bool(shape, bool) for i in range(ncond)])
np_fun = partial(np.piecewise, funclist=funclist)
jnp_fun = partial(jnp.piecewise, funclist=funclist)
if dtype == np.bool_:
# The `x - 1` above uses type promotion.
jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
# This is a higher-order function, so the cache miss check will fail.
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, check_cache_misses=False)
def testPiecewiseRecompile(self):
def g(x):
g.num_traces += 1
return x
g.num_traces = 0
x = jnp.arange(10.0)
for i in range(5):
jnp.piecewise(x, [x < 0], [g, 0.])
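    # g should be traced only once across all five calls: later calls hit
    # JAX's compilation cache instead of re-tracing the funclist entries.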
self.assertEqual(g.num_traces, 1)
@jtu.sample_product(
[dict(shape=shape, perm=perm)
for shape in array_shapes
for perm in [
None,
tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim)),
tuple(np.random.RandomState(0).permutation(
np.zeros(shape).ndim) - np.zeros(shape).ndim)
]
],
dtype=default_dtypes,
arg_type=["splat", "value"],
)
def testTransposeTuple(self, shape, dtype, perm, arg_type):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if arg_type == "value":
np_fun = lambda x: x.transpose(perm)
jnp_fun = lambda x: jnp.array(x).transpose(perm)
else:
np_fun = lambda x: x.transpose(*(perm or ()))
jnp_fun = lambda x: jnp.array(x).transpose(*(perm or ()))
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
shape=array_shapes,
dtype=default_dtypes,
)
def testPermuteDims(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
axes = self.rng().permutation(len(shape))
np_fun = partial(getattr(np, "permute_dims", np.transpose), axes=axes)
jnp_fun = partial(jnp.permute_dims, axes=axes)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
shape=[s for s in array_shapes if len(s) >= 2],
dtype=default_dtypes,
use_property=[True, False]
)
def testMatrixTranspose(self, shape, dtype, use_property):
if use_property:
jnp_fun = lambda x: jnp.asarray(x).mT
else:
jnp_fun = jnp.matrix_transpose
if hasattr(np, 'matrix_transpose'):
np_fun = np.matrix_transpose
else:
np_fun = lambda x: np.swapaxes(x, -1, -2)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
a_shape=one_dim_array_shapes,
trim=["f", "b", "fb"],
)
def testTrimZeros(self, a_shape, dtype, trim):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(a_shape, dtype)]
np_fun = lambda arg1: np.trim_zeros(arg1, trim)
jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
def testTrimZerosNotOneDArray(self):
# TODO: make this an error after the deprecation period.
with self.assertWarnsRegex(DeprecationWarning,
r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"):
jnp.trim_zeros(jnp.array([[0.0, 1.0, 0.0],[2.0, 4.5, 0.0]]))
@jtu.sample_product(
rank=(1, 2),
dtype=default_dtypes,
a_shape=one_dim_array_shapes,
)
@jax.default_matmul_precision("float32")
def testPoly(self, a_shape, dtype, rank):
if dtype in (np.float16, jnp.bfloat16, np.int16):
self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]):
self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.")
rng = jtu.rand_default(self.rng())
tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 }
if jtu.test_device_matches(["tpu"]):
tol[np.int32] = tol[np.float32] = 1e-1
tol = jtu.tolerance(dtype, tol)
args_maker = lambda: [rng(a_shape * rank, dtype)]
self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol, atol=tol)
@jtu.sample_product(
dtype=default_dtypes,
a_shape=one_dim_array_shapes,
b_shape=one_dim_array_shapes,
)
def testPolyAdd(self, a_shape, b_shape, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg1, arg2: np.polyadd(arg1, arg2)
jnp_fun = lambda arg1, arg2: jnp.polyadd(arg1, arg2)
args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
dtype=default_dtypes,
a_shape=one_dim_array_shapes,
b_shape=one_dim_array_shapes,
)
def testPolySub(self, a_shape, b_shape, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg1, arg2: np.polysub(arg1, arg2)
jnp_fun = lambda arg1, arg2: jnp.polysub(arg1, arg2)
args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
[dict(order=order, k=k, dtype=dtype)
for dtype in default_dtypes
for order in range(5)
for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]],
a_shape=one_dim_array_shapes,
)
def testPolyInt(self, a_shape, order, k, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg1: np.polyint(arg1, m=order, k=k)
jnp_fun = lambda arg1: jnp.polyint(arg1, m=order, k=k)
args_maker = lambda: [rng(a_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
dtype=default_dtypes,
a_shape=one_dim_array_shapes,
order=list(range(5)),
)
def testPolyDer(self, a_shape, order, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg1: np.polyder(arg1, m=order)
jnp_fun = lambda arg1: jnp.polyder(arg1, m=order)
args_maker = lambda: [rng(a_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.parameters(['int', 'np.int', 'jnp.int'])
def testIntegerPower(self, ptype):
p = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}[ptype]
jaxpr = jax.make_jaxpr(lambda x1: jnp.power(x1, p))(1)
eqns = jaxpr.jaxpr.eqns
self.assertLen(eqns, 1)
self.assertEqual(eqns[0].primitive, lax.integer_pow_p)
@jtu.sample_product(
x=[-1, 0, 1],
y=[0, 32, 64, 128],
)
def testIntegerPowerOverflow(self, x, y):
# Regression test for https://github.com/jax-ml/jax/issues/5987
args_maker = lambda: [x, y]
self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
self._CompileAndCheck(jnp.power, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in [None] + list(range(len(shape)))
],
dtype=all_dtypes,
)
def testCompress(self, shape, dtype, axis):
rng = jtu.rand_some_zero(self.rng())
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (math.prod(shape),)
else:
cond_shape = (shape[axis],)
args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)]
np_fun = partial(np.compress, axis=axis)
jnp_fun = partial(jnp.compress, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(len(shape)))
],
dtype=all_dtypes,
)
def testCompressSize(self, shape, dtype, axis):
rng = jtu.rand_default(self.rng())
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (math.prod(shape),)
else:
cond_shape = (shape[axis],)
args_maker = lambda: [rng(cond_shape, bool), rng(shape, dtype)]
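    # The NumPy reference below emulates jnp.compress's static `size`
    # argument: keep the full axis length and set trailing slots that have no
    # corresponding output to `fill_value`.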
def np_fun(condition, a, axis=axis, fill_value=1):
# assuming size = a.shape[axis]
out = np.compress(condition, a, axis=axis)
result = np.full_like(a, fill_value)
result[tuple(slice(s) for s in out.shape)] = out
return result
def jnp_fun(condition, a, axis=axis, fill_value=1):
return jnp.compress(condition, a, axis=axis,
size=a.shape[axis], fill_value=fill_value)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=[(2, 3)],
dtype=int_dtypes,
# condition entries beyond axis size must be zero.
condition=[[1], [1, 0, 0, 0, 0, 0, 0]],
axis=[None, 0, 1],
)
def testCompressMismatchedShapes(self, shape, dtype, condition, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [np.array(condition), rng(shape, dtype)]
np_fun = partial(np.compress, axis=axis)
jnp_fun = partial(jnp.compress, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in array_shapes
for axis in [None] + list(range(len(shape)))
],
dtype=all_dtypes,
)
def testCompressMethod(self, shape, dtype, axis):
rng = jtu.rand_some_zero(self.rng())
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (math.prod(shape),)
else:
cond_shape = (shape[axis],)
args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)]
np_fun = lambda condition, x: np.asarray(x).compress(condition, axis=axis)
jnp_fun = lambda condition, x: jnp.asarray(x).compress(condition, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in (None, *range(-len(base_shape)+1, len(base_shape)))
],
arg_dtypes=[
arg_dtypes
for num_arrs in [3]
for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs)
],
dtype=[None] + default_dtypes,
)
def testConcatenate(self, axis, dtype, base_shape, arg_dtypes):
rng = jtu.rand_default(self.rng())
wrapped_axis = 0 if axis is None else axis % len(base_shape)
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)]
@jtu.promote_like_jnp
def np_fun(*args, dtype=dtype):
dtype = dtype or args[0].dtype
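      # Upcast bfloat16 inputs to float32, since NumPy's casting rules don't
      # reliably handle the nonstandard bfloat16 type.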
args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32)
for x in args]
return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe')
jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype)
def args_maker():
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(4, 1), (4, 3), (4, 5, 6)]
for axis in [None] + list(range(1 - len(shape), len(shape) - 1))
],
dtype=all_dtypes,
)
def testConcatenateArray(self, shape, dtype, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda x: np.concatenate(x, axis=axis)
jnp_fun = lambda x: jnp.concatenate(x, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testConcatenateAxisNone(self):
# https://github.com/jax-ml/jax/issues/3419
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5]])
jnp.concatenate((a, b), axis=None)
def testConcatenateScalarAxisNone(self):
arrays = [np.int32(0), np.int32(1)]
self.assertArraysEqual(jnp.concatenate(arrays, axis=None),
np.concatenate(arrays, axis=None))
@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(), (4,), (3, 4), (2, 3, 4)]
for axis in (None, *range(-len(base_shape)+1, len(base_shape)))
],
dtype=default_dtypes,
)
def testConcat(self, axis, base_shape, dtype):
rng = jtu.rand_default(self.rng())
wrapped_axis = 0 if axis is None else axis % len(base_shape)
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
for size in [3, 1, 4]]
@jtu.promote_like_jnp
def np_fun(*args):
if jtu.numpy_version() >= (2, 0, 0):
return np.concat(args, axis=axis)
else:
return np.concatenate(args, axis=axis)
jnp_fun = lambda *args: jnp.concat(args, axis=axis)
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(-len(base_shape)+1, len(base_shape))],
arg_dtypes=itertools.combinations_with_replacement(default_dtypes, 2)
)
def testAppend(self, axis, base_shape, arg_dtypes):
rng = jtu.rand_default(self.rng())
wrapped_axis = axis % len(base_shape)
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)]
def np_fun(arr, values):
arr = arr.astype(np.float32) if arr.dtype == jnp.bfloat16 else arr
values = (values.astype(np.float32) if values.dtype == jnp.bfloat16
else values)
out = np.append(arr, values, axis=axis)
return out.astype(jnp.promote_types(*arg_dtypes))
jnp_fun = lambda arr, values: jnp.append(arr, values, axis=axis)
def args_maker():
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis, idx=idx)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
for idx in (range(-math.prod(shape), math.prod(shape))
if axis is None else
range(-shape[axis], shape[axis]))],
dtype=all_dtypes,
)
def testDeleteInteger(self, shape, dtype, idx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
],
dtype=all_dtypes,
slc=[slice(None), slice(1, 3), slice(1, 5, 2)],
)
def testDeleteSlice(self, shape, dtype, axis, slc):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, slc, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, slc, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
],
dtype=all_dtypes,
idx_shape=all_shapes,
)
def testDeleteIndexArray(self, shape, dtype, axis, idx_shape):
rng = jtu.rand_default(self.rng())
max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int)
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
],
dtype=all_dtypes,
idx_shape=all_shapes,
)
def testDeleteUniqueIndices(self, shape, dtype, axis, idx_shape):
rng = jtu.rand_default(self.rng())
max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
idx_size = np.zeros(idx_shape).size
if idx_size > max_idx:
self.skipTest("Too many indices to be unique")
def args_maker():
x = rng(shape, dtype)
idx = self.rng().choice(max_idx, idx_shape, replace=False)
return x, idx
np_fun = partial(np.delete, axis=axis)
jnp_fun = partial(jnp.delete, axis=axis, assume_unique_indices=True)
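    # assume_unique_indices=True lets jnp.delete skip its duplicate-index
    # handling; it is safe here because indices are drawn without replacement.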
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
],
dtype=all_dtypes,
)
def testDeleteMaskArray(self, shape, dtype, axis):
rng = jtu.rand_default(self.rng())
mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool)
if numpy_version == (1, 23, 0) and mask.shape == (1,):
# https://github.com/numpy/numpy/issues/21840
self.skipTest("test fails for numpy v1.23.0")
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, mask, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
],
dtype=all_dtypes,
)
def testInsertInteger(self, shape, dtype, axis):
x = jnp.empty(shape)
max_ind = x.size if axis is None else x.shape[axis]
rng = jtu.rand_default(self.rng())
i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind)
args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)]
np_fun = lambda *args: np.insert(*args, axis=axis)
jnp_fun = lambda *args: jnp.insert(*args, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
],
dtype=all_dtypes,
)
def testInsertSlice(self, shape, dtype, axis):
x = jnp.empty(shape)
max_ind = x.size if axis is None else x.shape[axis]
rng = jtu.rand_default(self.rng())
i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind)
slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item())
args_maker = lambda: [rng(shape, dtype), rng((), dtype)]
np_fun = lambda x, val: np.insert(x, slc, val, axis=axis)
jnp_fun = lambda x, val: jnp.insert(x, slc, val, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.parameters([
[[[1, 1], [2, 2], [3, 3]], 1, 5, None],
[[[1, 1], [2, 2], [3, 3]], 1, 5, 1],
[[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1],
[[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1],
[[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None],
[[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None],
[[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None],
[[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1]
])
def testInsertExamples(self, arr, index, values, axis):
# Test examples from the np.insert docstring
args_maker = lambda: (
np.asarray(arr), index if isinstance(index, slice) else np.array(index),
np.asarray(values), axis)
self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_array_shapes
for axis in range(-len(shape), len(shape))
],
dtype=default_dtypes,
out_dims=[0, 1, 2],
)
def testApplyAlongAxis(self, shape, dtype, axis, out_dims):
def func(x, out_dims):
if out_dims == 0:
return x.sum(dtype=x.dtype)
elif out_dims == 1:
return x * x[0]
elif out_dims == 2:
return x[:, None] + x[None, :]
else:
raise NotImplementedError(f"{out_dims=}")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arr: np.apply_along_axis(func, axis, arr, out_dims=out_dims)
jnp_fun = lambda arr: jnp.apply_along_axis(func, axis, arr, out_dims=out_dims)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
atol={dtypes.bfloat16: 2e-2})
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axes=axes)
for shape in nonempty_shapes
for axes in itertools.combinations(range(len(shape)), 2)
],
func=["sum"],
keepdims=[True, False],
# Avoid low-precision types in sum()
dtype=[dtype for dtype in default_dtypes
if dtype not in [np.float16, jnp.bfloat16]],
)
def testApplyOverAxes(self, shape, dtype, func, keepdims, axes):
f = lambda x, axis: getattr(x, func)(axis=axis, keepdims=keepdims, dtype=dtype)
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
np_fun = lambda a: np.apply_over_axes(f, a, axes)
jnp_fun = lambda a: jnp.apply_over_axes(f, a, axes)
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, dtype=dtype, axis=axis)
for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
for axis in [None] + list(range(-len(shape), max(1, len(shape))))
],
repeats=[0, 1, 2],
fixed_size=[False, True],
)
def testRepeat(self, axis, shape, dtype, repeats, fixed_size):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis)
np_fun = jtu.promote_like_jnp(np_fun)
if fixed_size:
total_repeat_length = np.repeat(np.zeros(shape), repeats, axis).shape[axis or 0]
jnp_fun = lambda arg, rep: jnp.repeat(arg, repeats=rep, axis=axis,
total_repeat_length=total_repeat_length)
jnp_args_maker = lambda: [rng(shape, dtype), repeats]
clo_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis,
total_repeat_length=total_repeat_length)
clo_fun_args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(jnp_fun, jnp_args_maker)
self._CheckAgainstNumpy(np_fun, clo_fun, clo_fun_args_maker)
else:
      # Now `repeats` is captured in a closure, so it is a compile-time constant.
jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testRepeatScalarFastPath(self):
a = jnp.array([1,2,3,4])
f = lambda a: jnp.repeat(a, repeats=2)
jaxpr = jax.make_jaxpr(f)(a)
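    # Scalar repeats should hit a fast path that lowers to a handful of shape
    # ops (broadcast/reshape) rather than a data-dependent gather.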
self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in [None] + list(range(len(shape)))],
dtype=number_dtypes,
return_index=[False, True],
return_inverse=[False, True],
return_counts=[False, True],
)
def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
extra_args = (return_index, return_inverse, return_counts)
use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False
np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults)
jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(shape=all_shapes, dtype=number_dtypes)
def testUniqueAll(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True)
else:
np_fun = np.unique_all
self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker)
@jtu.sample_product(shape=all_shapes, dtype=number_dtypes)
def testUniqueCounts(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
np_fun = lambda x: np.unique(x, return_counts=True)
else:
np_fun = np.unique_counts
self._CheckAgainstNumpy(jnp.unique_counts, np_fun, args_maker)
@jtu.sample_product(shape=all_shapes, dtype=number_dtypes)
def testUniqueInverse(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
np_fun = partial(np_unique_backport, return_inverse=True)
else:
np_fun = np.unique_inverse
self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker)
@jtu.sample_product(shape=all_shapes, dtype=number_dtypes)
def testUniqueValues(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
np_fun = np.unique
else:
np_fun = lambda *args: np.sort(np.unique_values(*args))
self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_array_shapes
for axis in [None] + list(range(len(shape)))],
dtype=number_dtypes,
size=[1, 5, 10],
fill_value=[None, 0, "slice"],
)
def testUniqueSize(self, shape, dtype, axis, size, fill_value):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True)
if fill_value == "slice":
if axis is None:
fill_value = rng((), dtype)
else:
fill_value = rng(shape[:axis] + shape[axis + 1:], dtype)
elif fill_value is not None:
fill_value = np.array(fill_value).astype(dtype)
@partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True))
def np_fun(x, fill_value=fill_value):
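      # Emulate jnp.unique's static `size` semantics: truncate when there are
      # more than `size` unique values, otherwise pad the uniques (and the
      # index/count outputs) out to `size` using `fill_value`.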
u, ind, inv, counts = np_unique_backport(x, **kwds)
axis = kwds['axis']
if axis is None:
x = x.ravel()
axis = 0
n_unique = u.shape[axis]
if size <= u.shape[axis]:
slc = (slice(None),) * axis + (slice(size),)
u, ind, counts = u[slc], ind[:size], counts[:size]
else:
extra = (0, size - n_unique)
pads = [(0, 0)] * u.ndim
pads[axis] = extra
u = np.pad(u, pads, constant_values=0)
slices = [slice(None)] * u.ndim
slices[axis] = slice(1)
if fill_value is None:
fill_value = u[tuple(slices)]
elif np.ndim(fill_value):
fill_value = lax.expand_dims(fill_value, (axis,))
slices[axis] = slice(n_unique, None)
u[tuple(slices)] = fill_value
ind = np.pad(ind, extra, constant_values=ind[0])
counts = np.pad(counts, extra, constant_values=0)
return u, ind, inv, counts
jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(dtype=inexact_dtypes)
def testUniqueNans(self, dtype):
def args_maker():
x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
if np.issubdtype(dtype, np.complexfloating):
x = [complex(i, j) for i, j in itertools.product(x, repeat=2)]
return [np.array(x, dtype=dtype)]
kwds = dict(return_index=True, return_inverse=True, return_counts=True)
jnp_fun = partial(jnp.unique, **kwds)
def np_fun(x):
dtype = x.dtype
# numpy unique fails for bfloat16 NaNs, so we cast to float64
if x.dtype == jnp.bfloat16:
x = x.astype('float64')
u, *rest = np.unique(x, **kwds)
return (u.astype(dtype), *rest)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False])
@jtu.ignore_warning(
category=RuntimeWarning, message='invalid value encountered in cast'
)
def testUniqueEqualNan(self, dtype, equal_nan):
shape = (20,)
rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: [rng(shape, dtype)]
def np_fun(x):
dtype = x.dtype
# numpy unique fails for bfloat16 NaNs, so we cast to float64
if x.dtype == jnp.bfloat16:
x = x.astype('float64')
return np.unique(x, equal_nan=equal_nan).astype(dtype)
jnp_fun = partial(jnp.unique, equal_nan=equal_nan)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(fixed_size=[False, True])
def testNonScalarRepeats(self, fixed_size):
    '''
    Following the numpy test suite from `test_repeat` at
    https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_multiarray.py
    '''
tol = 1e-5
def test_single(m, args_maker, repeats, axis):
lax_ans = jnp.repeat(m, repeats, axis)
numpy_ans = np.repeat(m, repeats, axis)
self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol)
if fixed_size:
# Calculate expected size of the repeated axis.
rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0]
jnp_fun = lambda arg, rep: jnp.repeat(
arg, repeats=rep, axis=axis, total_repeat_length=rep_length)
else:
jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
self._CompileAndCheck(jnp_fun, args_maker)
m = jnp.array([1,2,3,4,5,6])
if fixed_size:
args_maker = lambda: [m, repeats]
else:
args_maker = lambda: [m]
for repeats in [2, jnp.array([1,3,0,1,1,2]), jnp.array([1,3,2,1,1,2]), jnp.array([2])]:
test_single(m, args_maker, repeats, axis=None)
test_single(m, args_maker, repeats, axis=0)
m_rect = m.reshape((2,3))
if fixed_size:
args_maker = lambda: [m_rect, repeats]
else:
args_maker = lambda: [m_rect]
for repeats in [2, jnp.array([2,1]), jnp.array([2])]:
test_single(m_rect, args_maker, repeats, axis=0)
for repeats in [2, jnp.array([1,3,2]), jnp.array([2])]:
test_single(m_rect, args_maker, repeats, axis=1)
def testIssue2330(self):
    '''
    Make sure the return value of jnp.concatenate is a jax.ndarray and that
    calling it has no side effects on its inputs.
    '''
def attempt_sideeffect(x):
x = [x]
x = jnp.concatenate(x)
x -= 1.
return x
np_input = np.ones(1)
jnp_input = jnp.ones(1)
expected_np_input_after_call = np.ones(1)
expected_jnp_input_after_call = jnp.ones(1)
out = jnp.concatenate([np_input])
self.assertIs(type(out), array.ArrayImpl)
attempt_sideeffect(np_input)
attempt_sideeffect(jnp_input)
self.assertAllClose(np_input, expected_np_input_after_call)
self.assertAllClose(jnp_input, expected_jnp_input_after_call)
@jtu.sample_product(
mode=['full', 'same', 'valid'],
op=['convolve', 'correlate'],
dtype=number_dtypes,
xshape=one_dim_array_shapes,
yshape=one_dim_array_shapes,
)
def testConvolutions(self, xshape, yshape, dtype, mode, op):
jnp_op = getattr(jnp, op)
np_op = getattr(np, op)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None
jnp_fun = partial(jnp_op, mode=mode, precision=precision)
def np_fun(x, y):
return np_op(x, y, mode=mode).astype(dtypes.to_inexact_dtype(dtype))
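    # jnp.convolve/correlate always return an inexact dtype, so cast NumPy's
    # result to match before comparing.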
tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
np.complex128: 1e-14}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
mode=['full', 'same', 'valid'],
op=['convolve', 'correlate'],
dtype=number_dtypes,
xshape=one_dim_array_shapes,
yshape=one_dim_array_shapes,
)
@jtu.skip_on_devices("cuda", "rocm") # backends don't support all dtypes.
def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op):
jnp_op = getattr(jnp, op)
np_op = getattr(np, op)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None
jnp_fun = partial(jnp_op, mode=mode, precision=precision,
preferred_element_type=dtype)
def np_fun(x, y):
return np_op(x, y, mode=mode).astype(dtype)
tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
np.complex128: 1e-14}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in [None] + list(range(-len(shape), len(shape)))],
op=["cumsum", "cumprod"],
dtype=all_dtypes,
out_dtype=[dtype for dtype in default_dtypes if dtype != np.float16],
)
def testCumSumProd(self, axis, shape, dtype, out_dtype, op):
jnp_op = getattr(jnp, op)
np_op = getattr(np, op)
rng = jtu.rand_default(self.rng())
np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype)
np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
np_fun = jtu.ignore_warning(category=RuntimeWarning,
message="overflow encountered.*")(np_fun)
jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol_thresholds = {dtypes.bfloat16: 4e-2}
tol = max(jtu.tolerance(dtype, tol_thresholds),
jtu.tolerance(out_dtype, tol_thresholds))
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in [None] + list(range(-len(shape), len(shape)))],
op=["nancumsum", "nancumprod"],
dtype=all_dtypes,
out_dtype=default_dtypes,
)
def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op):
jnp_op = getattr(jnp, op)
np_op = getattr(np, op)
rng = jtu.rand_some_nan(self.rng())
np_fun = partial(np_op, axis=axis, dtype=out_dtype)
np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
np_fun = jtu.ignore_warning(category=RuntimeWarning,
message="overflow encountered.*")(np_fun)
jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol_thresholds = {dtypes.bfloat16: 4e-2, np.float16: 3e-3}
tol = max(jtu.tolerance(dtype, tol_thresholds),
jtu.tolerance(out_dtype, tol_thresholds))
if dtype != jnp.bfloat16:
# numpy functions do not properly handle bfloat16
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
dtype=default_dtypes,
n=[0, 4],
m=[None, 0, 1, 3, 4],
k=[*range(-4, 4), -2**100, 2**100],
)
def testEye(self, n, m, k, dtype):
np_fun = lambda: np.eye(n, M=m, k=k, dtype=dtype)
jnp_fun = lambda: jnp.eye(n, M=m, k=k, dtype=dtype)
args_maker = lambda: []
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
n=[0, 4],
m=[None, 0, 1, 3, 4],
k=range(-4, 4),
)
def testEyeDynamicK(self, n, m, k, dtype):
np_fun = lambda k: np.eye(n, M=m, k=k, dtype=dtype)
jnp_fun = lambda k: jnp.eye(n, M=m, k=k, dtype=dtype)
args_maker = lambda: [k]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
n=[0, 4],
m=[None, 0, 1, 3, 4],
k=range(-4, 4),
)
def testTri(self, m, n, k, dtype):
np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype)
jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype)
args_maker = lambda: []
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def test_tri_bug_22751(self):
with self.assertRaisesRegex(core.ConcretizationTypeError, "jax.numpy.tri"):
jax.jit(jnp.tri)(3, M=3, k=0)
@jtu.sample_product(
dtype=default_dtypes,
shape=[shape for shape in all_shapes if len(shape) >= 2],
op=["tril", "triu"],
k=list(range(-3, 3)),
)
def testTriLU(self, dtype, shape, op, k):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg: getattr(np, op)(arg, k=k)
jnp_fun = lambda arg: getattr(jnp, op)(arg, k=k)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
n=range(5),
k=range(-3, 3),
m=[None, *range(5)],
)
def testTrilIndices(self, n, k, m):
np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m)
jnp_fun = lambda n, k, m: jnp.tril_indices(n, k=k, m=m)
args_maker = lambda: [n, k, m]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
n=range(5),
k=range(-3, 3),
m=[None, *range(5)],
)
def testTriuIndices(self, n, k, m):
np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m)
jnp_fun = lambda n, k, m: jnp.triu_indices(n, k=k, m=m)
args_maker = lambda: [n, k, m]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)],
k=[-1, 0, 1],
)
def testTriuIndicesFrom(self, shape, dtype, k):
rng = jtu.rand_default(self.rng())
np_fun = lambda arr, k: np.triu_indices_from(arr, k=k)
jnp_fun = lambda arr, k: jnp.triu_indices_from(arr, k=k)
args_maker = lambda: [rng(shape, dtype), k]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)],
k=[-1, 0, 1],
)
def testTrilIndicesFrom(self, shape, dtype, k):
rng = jtu.rand_default(self.rng())
np_fun = lambda arr, k: np.tril_indices_from(arr, k=k)
jnp_fun = lambda arr, k: jnp.tril_indices_from(arr, k=k)
args_maker = lambda: [rng(shape, dtype), k]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(
n = [2, 3, 4],
k = [None, -1, 0, 1],
funcname = ['triu', 'tril']
)
def testMaskIndices(self, n, k, funcname):
kwds = {} if k is None else {'k': k}
jnp_result = jnp.mask_indices(n, getattr(jnp, funcname), **kwds)
np_result = np.mask_indices(n, getattr(np, funcname), **kwds)
self.assertArraysEqual(jnp_result, np_result, check_dtypes=False)
@jtu.sample_product(
dtype=default_dtypes,
a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)],
val_shape=[(), (1,), (2,), (1, 2), (3, 2)],
)
def testFillDiagonal(self, dtype, a_shape, val_shape):
rng = jtu.rand_default(self.rng())
def np_fun(a, val):
a_copy = a.copy()
np.fill_diagonal(a_copy, val)
return a_copy
jnp_fun = partial(jnp.fill_diagonal, inplace=False)
args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
ndim=[0, 1, 4],
n=[0, 1, 7],
)
def testDiagIndices(self, ndim, n):
np.testing.assert_equal(jtu.with_jax_dtype_defaults(np.diag_indices)(n, ndim),
jnp.diag_indices(n, ndim))
@jtu.sample_product(
dtype=default_dtypes,
shape=[(1,1), (2,2), (3,3), (4,4), (5,5)],
)
def testDiagIndicesFrom(self, dtype, shape):
rng = jtu.rand_default(self.rng())
np_fun = jtu.with_jax_dtype_defaults(np.diag_indices_from)
jnp_fun = jnp.diag_indices_from
args_maker = lambda : [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
shape=[shape for shape in all_shapes if len(shape) in (1, 2)],
k=list(range(-4, 4)),
)
def testDiag(self, shape, dtype, k):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg: np.diag(arg, k)
jnp_fun = lambda arg: jnp.diag(arg, k)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
shape=all_shapes,
k=list(range(-4, 4)),
)
def testDiagFlat(self, shape, dtype, k):
rng = jtu.rand_default(self.rng())
# numpy has inconsistencies for scalar values
# https://github.com/numpy/numpy/issues/16477
    # jax differs in that it treats scalar values as length-1 arrays
np_fun = lambda arg: np.diagflat(np.atleast_1d(arg), k)
jnp_fun = lambda arg: jnp.diagflat(arg, k)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
dtype=default_dtypes,
a1_shape=one_dim_array_shapes,
a2_shape=one_dim_array_shapes,
)
def testPolyMul(self, a1_shape, a2_shape, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg1, arg2: np.polymul(arg1, arg2)
jnp_fun_np = lambda arg1, arg2: jnp.polymul(arg1, arg2, trim_leading_zeros=True)
jnp_fun_co = lambda arg1, arg2: jnp.polymul(arg1, arg2)
args_maker = lambda: [rng(a1_shape, dtype), rng(a2_shape, dtype)]
tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13}
self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False)
@jtu.sample_product(
dtype=[dtype for dtype in default_dtypes
if dtype not in (np.float16, jnp.bfloat16)],
a_shape=one_dim_array_shapes,
b_shape=one_dim_array_shapes,
)
def testPolyDiv(self, a_shape, b_shape, dtype):
rng = jtu.rand_default(self.rng())
@jtu.ignore_warning(category=RuntimeWarning, message="divide by zero.*")
@jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
@jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")
def np_fun(arg1, arg2):
q, r = np.polydiv(arg1, arg2)
      while r.size < max(arg1.size, arg2.size):  # Pad the remainder to a common size
r = np.pad(r, (1, 0), 'constant')
return q, r
def jnp_fun(arg1, arg2):
q, r = jnp.polydiv(arg1, arg2, trim_leading_zeros=True)
      while r.size < max(arg1.size, arg2.size):  # Pad the remainder to a common size
r = jnp.pad(r, (1, 0), 'constant')
return q, r
args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)]
tol = {
dtypes.bfloat16: 2e-1,
np.float16: 2e-1,
np.float32: 5e-2,
np.float64: 5e-7
}
    jnp_compile = jnp.polydiv  # Without trim_leading_zeros (trim_zeros makes it impossible to compile with XLA)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol)
@jtu.sample_product(
[dict(shape=shape, axis1=axis1, axis2=axis2)
for shape in [shape for shape in all_shapes if len(shape) >= 2]
for axis1 in range(-len(shape), len(shape))
for axis2 in [a for a in range(-len(shape), len(shape))
if a % len(shape) != axis1 % len(shape)]
],
dtype=default_dtypes,
offset=list(range(-4, 4)),
)
def testDiagonal(self, shape, dtype, offset, axis1, axis2):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2)
jnp_fun = lambda arg: jnp.diagonal(arg, offset, axis1, axis2)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
n=list(range(4)),
)
def testIdentity(self, n, dtype):
np_fun = lambda: np.identity(n, dtype)
jnp_fun = lambda: jnp.identity(n, dtype)
args_maker = lambda: []
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=nonempty_shapes,
period=[None, 0.59],
left=[None, 0],
right=[None, 1],
# Note: skip 8-bit and 16-bit types due to insufficient precision.
dtype=jtu.dtypes.integer + jtu.dtypes.floating,
target_dtype=jtu.dtypes.inexact,
)
def testInterp(self, shape, dtype, period, left, right, target_dtype):
rng = jtu.rand_default(self.rng(), scale=10)
kwds = dict(period=period, left=left, right=right)
np_fun = partial(np.interp, **kwds)
jnp_fun = partial(jnp.interp, **kwds)
args_maker = lambda: [rng(shape, dtype), np.unique(rng((100,), dtype))[:20],
rng((20,), target_dtype)]
with jtu.strict_promotion_if_dtypes_match([dtype, target_dtype]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
rtol=3e-3, atol=1e-3)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product([
dict(x=0.5, left='extrapolate', expected=5),
dict(x=1.5, left='extrapolate', expected=15),
dict(x=3.5, left='extrapolate', expected=30),
dict(x=3.9, right='extrapolate', expected=39),
])
  def testInterpExtrapolate(self, x, expected, **kwargs):
xp = jnp.array([1.0, 2.0, 3.0])
fp = jnp.array([10.0, 20.0, 30.0])
actual = jnp.interp(x, xp, fp, **kwargs)
self.assertAlmostEqual(actual, expected)
def testInterpErrors(self):
with self.assertRaisesWithLiteralMatch(
ValueError,
'xp and fp must be one-dimensional arrays of equal size'
):
jnp.interp(0.0, jnp.arange(2.0), jnp.arange(3.0))
with self.assertRaisesWithLiteralMatch(
ValueError,
"the only valid string value of `left` is 'extrapolate', but got: 'interpolate'"
):
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), left='interpolate')
with self.assertRaisesWithLiteralMatch(
ValueError,
"the only valid string value of `right` is 'extrapolate', but got: 'interpolate'"
):
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), right='interpolate')
with self.assertRaisesWithLiteralMatch(
ValueError,
"jnp.interp: complex x values not supported."
):
jnp.interp(1j, 1j * np.arange(3.0), np.arange(3.0))
with self.assertRaisesRegex(
ValueError,
"period must be a scalar; got"
):
jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), period=np.array([1.0]))
@jtu.sample_product(
period=[None, 0.59],
left=[None, 0],
right=[None, 1],
dtype=jtu.dtypes.floating,
)
def testInterpGradNan(self, dtype, period, left, right):
kwds = dict(period=period, left=left, right=right)
jnp_fun = partial(jnp.interp, **kwds)
# Probe values of x and xp that are close to zero and close together.
x = dtype(np.exp(np.linspace(-90, -20, 1000)))
g = jax.grad(lambda z: jnp.sum(jnp_fun(z, z, jnp.ones_like(z))))(x)
np.testing.assert_equal(np.all(np.isfinite(g)), True)
@jtu.sample_product(
[dict(x1_shape=x1_shape, x2_shape=x2_shape)
for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(array_shapes, 2))
],
x1_rng_factory=[jtu.rand_some_inf_and_nan, jtu.rand_some_zero],
x2_rng_factory=[partial(jtu.rand_int, low=-1075, high=1024)],
x1_dtype=default_dtypes,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory):
x1_rng = x1_rng_factory(self.rng())
x2_rng = x2_rng_factory(self.rng())
@jtu.ignore_warning(category=RuntimeWarning, message="overflow.*")
def np_fun(x1, x2):
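      # Cast x1 to the inexact dtype JAX would produce, so np.ldexp and
      # jnp.ldexp agree on the result dtype.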
out_dtype = dtypes.to_inexact_dtype(x1.dtype)
return np.ldexp(x1.astype(out_dtype), x2)
jnp_fun = jnp.ldexp
args_maker = lambda: [x1_rng(x1_shape, x1_dtype),
x2_rng(x2_shape, np.int32)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
rng_factory=[
jtu.rand_some_inf_and_nan,
jtu.rand_some_zero,
partial(jtu.rand_not_small, offset=1e8),
],
shape=all_shapes,
dtype=default_dtypes,
)
@jtu.ignore_warning(category=RuntimeWarning, message="overflow")
def testFrexp(self, shape, dtype, rng_factory):
# integer types are converted to float64 in numpy's implementation
if (dtype not in [jnp.bfloat16, np.float16, np.float32]
and not config.enable_x64.value):
self.skipTest("Only run float64 testcase when float64 is enabled.")
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
def np_frexp(x):
mantissa, exponent = np.frexp(x)
# NumPy is inconsistent between Windows and Linux/Mac on what the
# value of exponent is if the input is infinite. Normalize to the Linux
# behavior.
exponent = np.where(np.isinf(mantissa), np.zeros_like(exponent), exponent)
return mantissa, exponent
self._CheckAgainstNumpy(np_frexp, jnp.frexp, args_maker,
check_dtypes=np.issubdtype(dtype, np.inexact))
self._CompileAndCheck(jnp.frexp, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis1=axis1, axis2=axis2)
for shape in [shape for shape in all_shapes if len(shape) >= 2]
for axis1 in range(-len(shape), len(shape))
for axis2 in range(-len(shape), len(shape))
if (axis1 % len(shape)) != (axis2 % len(shape))
],
dtype=default_dtypes,
out_dtype=[None] + number_dtypes,
offset=list(range(-4, 4)),
)
def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2):
rng = jtu.rand_default(self.rng())
def np_fun(arg):
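      # np.trace doesn't handle bfloat16 accumulation well, so the reference
      # accumulates in float32 and casts the result back to bfloat16.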
if out_dtype == jnp.bfloat16:
return np.trace(arg, offset, axis1, axis2, np.float32).astype(jnp.bfloat16)
else:
return np.trace(arg, offset, axis1, axis2, out_dtype)
jnp_fun = lambda arg: jnp.trace(arg, offset, axis1, axis2, out_dtype)
args_maker = lambda: [rng(shape, dtype)]
# TODO: Fails with uint8/uint16 output dtypes (integer overflow?)
if out_dtype not in (np.uint8, np.uint16, np.uint32):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testTraceSameAxesError(self):
a = jnp.arange(1, 13).reshape(2, 3, 2)
with self.assertRaisesRegex(ValueError, r"axis1 and axis2 can not be same"):
jnp.trace(a, axis1=1, axis2=-2)
@jtu.sample_product(
ashape=[(15,), (16,), (17,)],
vshape=[(), (5,), (5, 5)],
side=['left', 'right'],
dtype=number_dtypes,
method=['sort', 'scan', 'scan_unrolled', 'compare_all'],
use_sorter=[True, False],
)
def testSearchsorted(self, ashape, vshape, side, dtype, method, use_sorter):
rng = jtu.rand_default(self.rng())
def args_maker():
a = rng(ashape, dtype)
v = rng(vshape, dtype)
return (a, v, np.argsort(a)) if use_sorter else (np.sort(a), v)
def np_fun(a, v, sorter=None):
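      # jnp.searchsorted returns int32 indices whenever the array length fits
      # in int32 (see testSearchsortedDtype below), so cast NumPy's
      # platform-dependent result to match.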
return np.searchsorted(a, v, side=side, sorter=sorter).astype('int32')
def jnp_fun(a, v, sorter=None):
return jnp.searchsorted(a, v, side=side, method=method, sorter=sorter)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(
platform.system() == "Windows",
"Under Windows, NumPy throws if 2**32 is converted to an int32"
)
def testSearchsortedDtype(self):
# Test that for large arrays, int64 indices are used. We test this
# via abstract evaluation to avoid allocating a large array in tests.
a_int32 = core.ShapedArray((np.iinfo(np.int32).max,), np.float32)
a_int64 = core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32)
v = core.ShapedArray((), np.float32)
out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v)
self.assertEqual(out_int32.dtype, np.int32)
if config.enable_x64.value:
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
self.assertEqual(out_int64.dtype, np.int64)
elif jtu.numpy_version() < (2, 0, 0):
with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"):
with jtu.ignore_warning(category=DeprecationWarning,
message="NumPy will stop allowing conversion.*"):
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
else:
with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype.*int64"):
with self.assertRaisesRegex(OverflowError, "Python integer 2147483648 out of bounds.*"):
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
@jtu.sample_product(
dtype=inexact_dtypes,
side=['left', 'right'],
method=['sort', 'scan', 'compare_all'],
)
def testSearchsortedNans(self, dtype, side, method):
if np.issubdtype(dtype, np.complexfloating):
raise SkipTest("Known failure for complex inputs; see #9107")
x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype)
# The sign bit should not matter for 0.0 or NaN, so argsorting the above should be
# equivalent to argsorting the following:
x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5])
if jnp.issubdtype(dtype, jnp.complexfloating):
x = np.array([complex(r, c) for r, c in itertools.product(x, repeat=2)])
x_equiv = np.array([complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)])
fun = partial(jnp.searchsorted, side=side, method=method)
self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv))
self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv))
@jtu.sample_product(
xshape=[(20,), (5, 4)],
binshape=[(0,), (1,), (5,)],
right=[True, False],
reverse=[True, False],
dtype=default_dtypes,
)
def testDigitize(self, xshape, binshape, right, reverse, dtype):
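# np.digitize accepts monotonically increasing or decreasing bins; the
# `reverse` case exercises decreasing bins. With increasing bins,
# e.g. digitize(3, [1, 2, 4]) == 2 since bins[1] <= 3 < bins[2].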
order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:]
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]]
np_fun = lambda x, bins: np.digitize(x, bins, right=right).astype('int32')
jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
xshape=[(20,), (5, 4)],
binshape=[(0,), (1,), (5,)],
right=[True, False],
method=['scan', 'scan_unrolled', 'sort', 'compare_all'],
reverse=[True, False],
dtype=default_dtypes,
)
def testDigitizeMethod(self, xshape, binshape, right, method, reverse, dtype):
order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:]
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]]
np_fun = lambda x, bins: np.digitize(x, bins, right=right).astype('int32')
jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right, method=method)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtypes=[
[np.float32],
[np.float32, np.float32],
[np.float32, np.int32, np.float32],
[np.float32, np.int64, np.float32],
[np.float32, np.int32, np.float64],
],
shape=[(), (2,), (3, 4), (1, 5)],
array_input=[True, False],
)
def testColumnStack(self, shape, dtypes, array_input):
rng = jtu.rand_default(self.rng())
if array_input:
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
np_fun = jtu.promote_like_jnp(np.column_stack)
jnp_fun = jnp.column_stack
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(), (2,), (3, 4), (1, 100)]
for axis in range(-len(shape), len(shape) + 1)
],
dtypes=[
[np.float32],
[np.float32, np.float32],
[np.float32, np.int32, np.float32],
[np.float32, np.int64, np.float32],
[np.float32, np.int32, np.float64],
],
array_input=[True, False],
out_dtype=[np.float32, np.int32],
)
def testStack(self, shape, axis, dtypes, array_input, out_dtype):
rng = jtu.rand_default(self.rng())
if array_input:
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
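# np.stack's default casting='same_kind' would reject casts such as
# float -> int, so casting='unsafe' is passed to mirror jnp.stack, which
# always casts to the requested dtype.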
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))
jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
op=["hstack", "vstack", "dstack"],
dtypes=[
[np.float32],
[np.float32, np.float32],
[np.float32, np.int32, np.float32],
[np.float32, np.int64, np.float32],
[np.float32, np.int32, np.float64],
],
shape=[(), (2,), (3, 4), (1, 100), (2, 3, 4)],
array_input=[True, False],
out_dtype=[np.float32, np.int32],
)
def testHVDStack(self, shape, op, dtypes, array_input, out_dtype):
rng = jtu.rand_default(self.rng())
if array_input:
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
if op == "dstack":
np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
else:
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype,
casting='unsafe')
jnp_fun = partial(getattr(jnp, op), dtype=out_dtype)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(name=name, **kwds)
for name in ['blackman', 'bartlett', 'hamming', 'hanning', 'kaiser']
for kwds in ([dict(beta=1), dict(beta=0.5)] if name == 'kaiser' else [{}])
],
size = [0, 1, 5, 10],
)
def testWindowFunction(self, name, size, **kwds):
jnp_fun = partial(getattr(jnp, name), size, **kwds)
np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds))
args_maker = lambda: []
tol = (
5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None
)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, fill_value_shape=fill_value_shape)
for shape in array_shapes + [3, np.array(7, dtype=np.int32)]
for fill_value_shape in _compatible_shapes(shape)],
fill_value_dtype=default_dtypes,
out_dtype=[None] + default_dtypes,
)
def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype)
jnp_fun = lambda fill_value: jnp.full(shape, fill_value, dtype=out_dtype)
args_maker = lambda: [rng(fill_value_shape, fill_value_dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, dtype=dtype, axis=axis)
for shape, dtype in _shape_and_dtypes(nonempty_nonscalar_array_shapes, default_dtypes)
for axis in list(range(-len(shape), max(1, len(shape))))
],
prepend=[None, 1, 0],
append=[None, 1, 0],
n=[0, 1, 2],
)
def testDiff(self, shape, dtype, n, axis, prepend, append):
prepend = np.zeros(shape, dtype=dtype) if prepend == 0 else prepend
append = np.zeros(shape, dtype=dtype) if append == 0 else append
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
def np_fun(x, n=n, axis=axis, prepend=prepend, append=append):
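# np.diff uses np._NoValue (not None) as the sentinel for omitted
# prepend/append, and cannot handle bfloat16, hence the conversions below.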
if prepend is None:
prepend = np._NoValue
elif not np.isscalar(prepend) and prepend.dtype == jnp.bfloat16:
prepend = prepend.astype(np.float32)
if append is None:
append = np._NoValue
elif not np.isscalar(append) and append.dtype == jnp.bfloat16:
append = append.astype(np.float32)
if x.dtype == jnp.bfloat16:
return np.diff(x.astype(np.float32), n=n, axis=axis, prepend=prepend, append=append).astype(jnp.bfloat16)
else:
return np.diff(x, n=n, axis=axis, prepend=prepend, append=append)
jnp_fun = lambda x: jnp.diff(x, n=n, axis=axis, prepend=prepend, append=append)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
def testDiffBool(self):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng((10,), bool)]
self._CheckAgainstNumpy(np.diff, jnp.diff, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp.diff, args_maker)
def testDiffPrependScalar(self):
# Regression test for https://github.com/jax-ml/jax/issues/19362
x = jnp.arange(10)
result_jax = jnp.diff(x, prepend=x[0], append=x[-1])
x = np.array(x)
result_numpy = np.diff(x, prepend=x[0], append=x[-1])
self.assertArraysEqual(result_jax, result_numpy)
@jtu.sample_product(
op=["zeros", "ones"],
shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32),
np.array(4, dtype=np.int32)],
dtype=all_dtypes,
)
def testZerosOnes(self, op, shape, dtype):
np_op = getattr(np, op)
jnp_op = getattr(jnp, op)
args_maker = lambda: []
np_op = partial(np_op, shape, dtype)
jnp_op = partial(jnp_op, shape, dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testOnesWithInvalidShape(self):
with self.assertRaises(TypeError):
jnp.ones((-1, 1))
def test_full_like_committed(self):
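# device_put commits an array to a specific device; full_like should
# inherit both the committedness and the sharding of its input.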
x = jnp.array((1, 2, 3), dtype=np.int32)
self.assertFalse(x._committed)
self.assertFalse(lax.full_like(x, 1.1)._committed)
x = jax.device_put(x, jax.devices()[-1])
self.assertTrue(x._committed)
y = lax.full_like(x, 1.1)
self.assertTrue(y._committed)
self.assertEqual(x.sharding, y.sharding)
def test_zeros_like_with_explicit_device_and_jitted(self):
x = jnp.array((1, 2, 3), dtype=np.int32)
x = jax.device_put(x, jax.devices()[0])
zeros_like_with_device = partial(jnp.zeros_like, device=jax.devices()[0])
y = jax.jit(zeros_like_with_device)(x)
self.assertEqual(x.shape, y.shape)
self.assertEqual(y.sharding, SingleDeviceSharding(jax.devices()[0]))
@jtu.sample_product(
[dict(shape=shape, out_shape=out_shape, fill_value_shape=fill_value_shape)
for shape in array_shapes
for out_shape in [None] + array_shapes
for fill_value_shape in _compatible_shapes(shape if out_shape is None else out_shape)
],
in_dtype=default_dtypes,
fill_value_dtype=default_dtypes,
out_dtype=default_dtypes,
)
def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape):
rng = jtu.rand_default(self.rng())
np_fun = lambda x, fill_value: np.full_like(
x, fill_value, dtype=out_dtype, shape=out_shape)
jnp_fun = lambda x, fill_value: jnp.full_like(
x, fill_value, dtype=out_dtype, shape=out_shape)
args_maker = lambda: [rng(shape, in_dtype), rng(fill_value_shape, fill_value_dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=array_shapes,
out_shape=[None] + array_shapes,
in_dtype=default_dtypes,
func=["ones_like", "zeros_like"],
out_dtype=default_dtypes,
)
def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape)
jnp_fun = lambda x: getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape)
args_maker = lambda: [rng(shape, in_dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full],
shape=array_shapes,
dtype=default_dtypes,
)
def testArrayCreationWithDevice(self, func, shape, dtype):
device = jax.devices()[-1]
kwds = {'fill_value': 1} if func is jnp.full else {}
out = func(**kwds, shape=shape, dtype=dtype, device=device)
self.assertEqual(out.devices(), {device})
@jtu.sample_product(
func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full],
shape=array_shapes,
dtype=default_dtypes,
)
def testArrayCreationWithSharding(self, func, shape, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
kwds = {'fill_value': 1} if func is jnp.full else {}
out = func(**kwds, shape=shape, dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)
@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeLinspaceArrayWithDevice(self, func, dtype):
device = jax.devices()[-1]
output = func(dtype=dtype, device=device)
if isinstance(output, tuple):
self.assertEqual(output[0].devices(), {device})
else:
self.assertEqual(output.devices(), {device})
@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeLinspaceArrayWithSharding(self, func, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
output = func(dtype=dtype, device=sharding)
if isinstance(output, tuple):
self.assertEqual(output[0].sharding, sharding)
else:
self.assertEqual(output.sharding, sharding)
@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
shape=array_shapes,
dtype=default_dtypes,
)
def testFullLikeWithDevice(self, func, shape, dtype):
device = jax.devices()[-1]
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
kwds = {'fill_value': 1} if func is jnp.full_like else {}
with self.subTest('device from keyword'):
out = func(x, **kwds, device=device)
self.assertEqual(out.devices(), {device})
with self.subTest('device from input array'):
out2 = func(out, **kwds)
self.assertEqual(out2.devices(), out.devices())
@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
shape=array_shapes,
dtype=default_dtypes,
)
def testFullLikeWithSharding(self, func, shape, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
kwds = {'fill_value': 1} if func is jnp.full_like else {}
with self.subTest('device from keyword'):
out = func(x, **kwds, device=sharding)
self.assertEqual(out.sharding, sharding)
with self.subTest('device from input array'):
out2 = func(out, **kwds)
self.assertEqual(out2.devices(), out.devices())
def testDuckTypedLike(self):
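# jax.ShapeDtypeStruct carries only shape and dtype, so the *_like
# constructors must work from those attributes alone.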
x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32"))
self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype))
self.assertArraysEqual(jnp.ones_like(x), jnp.ones(x.shape, x.dtype))
self.assertArraysEqual(jnp.empty_like(x), jnp.empty(x.shape, x.dtype))
self.assertArraysEqual(jnp.full_like(x, 2), jnp.full(x.shape, 2, x.dtype))
@jtu.sample_product(
[dict(func=func, args=args)
for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())]
],
shape=array_shapes,
in_dtype=[np.int32, np.float32, np.complex64],
weak_type=[True, False],
out_shape=[None, (), (10,)],
out_dtype=[None, float],
)
def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype):
rng = jtu.rand_default(self.rng())
x = lax_internal._convert_element_type(rng(shape, in_dtype),
weak_type=weak_type)
fun = lambda x: getattr(jnp, func)(x, *args, dtype=out_dtype, shape=out_shape)
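# Weak typing is preserved only when no explicit output dtype is requested.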
expected_weak_type = weak_type and (out_dtype is None)
self.assertEqual(dtypes.is_weakly_typed(fun(x)), expected_weak_type)
self.assertEqual(dtypes.is_weakly_typed(jax.jit(fun)(x)), expected_weak_type)
@jtu.sample_product(
funcname=["array", "asarray"],
dtype=[int, float, None],
val=[0, 1],
input_type=[int, float, np.int32, np.float32],
)
def testArrayWeakType(self, funcname, input_type, val, dtype):
func = lambda x: getattr(jnp, funcname)(x, dtype=dtype)
fjit = jax.jit(func)
val = input_type(val)
expected_weak_type = dtype is None and input_type in set(dtypes._weak_types)
self.assertEqual(dtypes.is_weakly_typed(func(val)), expected_weak_type)
self.assertEqual(dtypes.is_weakly_typed(fjit(val)), expected_weak_type)
@jtu.sample_product(
shape=nonempty_nonscalar_array_shapes,
dtype=[int, float, complex],
weak_type=[True, False],
slc=[slice(None), slice(0), slice(3), 0, ...],
)
def testSliceWeakTypes(self, shape, dtype, weak_type, slc):
rng = jtu.rand_default(self.rng())
x = lax_internal._convert_element_type(rng(shape, dtype),
weak_type=weak_type)
op = lambda x: x[slc]
self.assertEqual(op(x).aval.weak_type, weak_type)
self.assertEqual(jax.jit(op)(x).aval.weak_type, weak_type)
@jtu.sample_product(
[dict(shape=shape, axis=axis, num_sections=num_sections)
for shape, axis, num_sections in [
((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
],
dtype=default_dtypes,
)
def testSplitStaticInt(self, shape, num_sections, axis, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.split(x, num_sections, axis=axis)
jnp_fun = lambda x: jnp.split(x, num_sections, axis=axis)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis, num_sections=num_sections)
# All testcases split the specified axis unequally
for shape, axis, num_sections in [
((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3),
((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)]
],
dtype=default_dtypes,
)
def testArraySplitStaticInt(self, shape, num_sections, axis, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.array_split(x, num_sections, axis=axis)
jnp_fun = lambda x: jnp.array_split(x, num_sections, axis=axis)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testSplitTypeError(self):
# If we pass an ndarray for indices_or_sections -> no error
self.assertEqual(3, len(jnp.split(jnp.zeros(3), jnp.array([1, 2]))))
CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected."
with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG):
# An abstract tracer for idx
jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), idx))(2.)
with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG):
# A list including an abstract tracer
jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), [2, idx]))(2.)
# A concrete tracer -> no error
jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), idx),
(2.,), (1.,))
# A tuple including a concrete tracer -> no error
jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), (1, idx.astype(np.int32))),
(2.,), (1.,))
@jtu.sample_product(
shape=[(5,), (5, 5)],
dtype=number_dtypes,
bins=[10, np.arange(-5, 6), np.array([-5, 0, 3])],
range=[None, (0, 0), (0, 10)],
weights=[True, False],
)
def testHistogramBinEdges(self, shape, dtype, bins, range, weights):
rng = jtu.rand_default(self.rng())
_weights = lambda w: abs(w) if weights else None
np_fun = lambda a, w, r: np.histogram_bin_edges(a, bins=bins, range=r,
weights=_weights(w))
jnp_fun = lambda a, w, r: jnp.histogram_bin_edges(a, bins=bins, range=r,
weights=_weights(w))
args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), range]
tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2}
# linspace() compares poorly to numpy when using bfloat16
if dtype != jnp.bfloat16:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker,
atol=tol, rtol=tol)
@jtu.sample_product(
shape=[(5,), (4, 5)],
dtype=default_dtypes,
# We only test explicit integer-valued bin edges because in other cases
# rounding errors lead to flaky tests.
bins=[np.arange(-5, 6), np.array([-5, 0, 3])],
density=[True, False],
weights=[True, False],
)
def testHistogram(self, shape, dtype, bins, density, weights):
rng = jtu.rand_default(self.rng())
_weights = lambda w: abs(w) if weights else None
def np_fun(a, w):
# Numpy can't handle bfloat16
a = a.astype('float32') if a.dtype == jnp.bfloat16 else a
w = w.astype('float32') if w.dtype == jnp.bfloat16 else w
return np.histogram(a, bins=bins, density=density, weights=_weights(w))
jnp_fun = lambda a, w: jnp.histogram(a, bins=bins, density=density,
weights=_weights(w))
args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=[(5,), (12,)],
dtype=int_dtypes,
bins=[2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]],
weights=[False, True],
density=[False, True],
range=[None, [(-1, 1), None], [(-1, 1), (-2, 2)]],
)
def testHistogram2d(self, shape, dtype, bins, weights, density, range):
rng = jtu.rand_default(self.rng())
_weights = lambda w: abs(w) if weights else None
np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")(
lambda a, b, w: np.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range))
jnp_fun = lambda a, b, w: jnp.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range)
args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)]
tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
# np.searchsorted errors on bfloat16 with
# "TypeError: invalid type promotion with custom data type"
with np.errstate(divide='ignore', invalid='ignore'):
if dtype != jnp.bfloat16:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=[(5, 3), (10, 3)],
dtype=int_dtypes,
bins=[(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]],
weights=[False, True],
density=[False, True],
range=[None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]],
)
def testHistogramdd(self, shape, dtype, bins, weights, density, range):
rng = jtu.rand_default(self.rng())
_weights = lambda w: abs(w) if weights else None
np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")(
lambda a, w: np.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range))
jnp_fun = lambda a, w: jnp.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range)
args_maker = lambda: [rng(shape, dtype), rng((shape[0],), dtype)]
tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
# np.searchsorted errors on bfloat16 with
# "TypeError: invalid type promotion with custom data type"
if dtype != jnp.bfloat16:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis, num_sections=num_sections)
for shape, axis, num_sections in [
((12, 4), 0, 4), ((12,), 1, 2),
((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]],
dtype=default_dtypes,
)
def testHVDSplit(self, shape, num_sections, axis, dtype):
rng = jtu.rand_default(self.rng())
def fn(module, axis):
if axis == 0:
return module.vsplit
elif axis == 1:
return module.hsplit
else:
assert axis == 2
return module.dsplit
np_fun = lambda x: fn(np, axis)(x, num_sections)
jnp_fun = lambda x: fn(jnp, axis)(x, num_sections)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(arg_shape=arg_shape, out_shape=out_shape)
for arg_shape, out_shape in [
(jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)),
((), (1, 1, 1)),
((7, 0), (0, 42, 101)),
((3, 4), 12),
((3, 4), (12,)),
((3, 4), -1),
((2, 1, 4), (-1,)),
((2, 2, 4), (2, 8))
]
],
dtype=default_dtypes,
order=["C", "F"],
)
def testReshape(self, arg_shape, out_shape, dtype, order):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.reshape(x, out_shape, order=order)
jnp_fun = lambda x: jnp.reshape(x, out_shape, order=order)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(arg_shape=arg_shape, out_shape=out_shape)
for arg_shape, out_shape in [
((7, 0), (0, 42, 101)),
((2, 1, 4), (-1,)),
((2, 2, 4), (2, 8))
]
],
dtype=default_dtypes,
)
def testReshapeMethod(self, arg_shape, out_shape, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.reshape(x, out_shape)
jnp_fun = lambda x: x.reshape(*out_shape)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(arg_shape=arg_shape, out_shape=out_shape)
for arg_shape, out_shape in itertools.product(all_shapes, array_shapes)],
dtype=default_dtypes,
)
def testResize(self, arg_shape, out_shape, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.resize(x, out_shape)
jnp_fun = lambda x: jnp.resize(x, out_shape)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(arg_shape=arg_shape, dim=dim)
for arg_shape in [(), (3,), (3, 4)]
for dim in (list(range(-len(arg_shape)+1, len(arg_shape)))
+ [np.array(0), np.array(-1), (0,), [np.array(0)],
(len(arg_shape), len(arg_shape) + 1)])
],
dtype=default_dtypes,
)
def testExpandDimsStaticDim(self, arg_shape, dtype, dim):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.expand_dims(x, dim)
jnp_fun = lambda x: jnp.expand_dims(x, dim)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
def testExpandDimsRepeatedAxisError(self):
x = jnp.ones((2, 3))
self.assertRaisesRegex(
ValueError, 'repeated axis.*',
lambda: jnp.expand_dims(x, [1, 1]))
self.assertRaisesRegex(
ValueError, 'repeated axis.*',
lambda: jnp.expand_dims(x, [3, -1]))
# ensure this is numpy's behavior too, so that we remain consistent
x = np.ones((2, 3))
self.assertRaisesRegex(
ValueError, 'repeated axis.*',
lambda: np.expand_dims(x, [1, 1]))
self.assertRaisesRegex(
ValueError, 'repeated axis.*',
lambda: np.expand_dims(x, [3, -1]))
@jtu.sample_product(
[dict(arg_shape=arg_shape, ax1=ax1, ax2=ax2)
for arg_shape, ax1, ax2 in [
((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
],
dtype=default_dtypes,
)
def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.swapaxes(x, ax1, ax2)
jnp_fun = lambda x: jnp.swapaxes(x, ax1, ax2)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(arg_shape=arg_shape, ax=ax)
for arg_shape, ax in [
((3, 1), None),
((3, 1), 1),
((3, 1), -1),
((3, 1), np.array(1)),
((1, 3, 1), (0, 2)),
((1, 3, 1), (0,)),
((1, 4, 1), (np.array(0),))]
],
dtype=default_dtypes,
)
def testSqueeze(self, arg_shape, dtype, ax):
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.squeeze(x, ax)
jnp_fun = lambda x: jnp.squeeze(x, ax)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testArrayFromMasked(self):
args_maker = lambda: [np.ma.array([1, 2], mask=[True, False])]
# Like np.array, jnp.array strips the mask from masked array inputs.
self._CheckAgainstNumpy(np.array, jnp.array, args_maker)
# Under JIT, masked array inputs raise an error rather than being silently converted.
with self.assertRaisesRegex(ValueError, "numpy masked arrays are not supported"):
jax.jit(jnp.asarray)(*args_maker())
@jtu.sample_product(
[dict(arg=arg, dtype=dtype, ndmin=ndmin)
for arg, dtypes in [
([True, False, True], all_dtypes),
(3., all_dtypes),
([1, 2, 3], all_dtypes),
(np.array([1, 2, 3], dtype=np.int64), all_dtypes),
([1., 2., 3.], all_dtypes),
([[1, 2], [3, 4], [5, 6]], all_dtypes),
([[1, 2.], [3, 4], [5, 6]], all_dtypes),
([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes),
([[3, np.array(2, dtype=jnp.float_), 1],
np.arange(3., dtype=jnp.float_)], all_dtypes),
]
for dtype in [None] + dtypes
for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2]
],
)
def testArray(self, arg, ndmin, dtype):
args_maker = lambda: [arg]
canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype)
if ndmin is not None:
np_fun = partial(np.array, ndmin=ndmin, dtype=canonical_dtype)
jnp_fun = partial(jnp.array, ndmin=ndmin, dtype=dtype)
else:
np_fun = partial(np.array, dtype=canonical_dtype)
jnp_fun = partial(jnp.array, dtype=dtype)
# We are testing correct canonicalization behavior here, so we turn off the
# permissive canonicalization logic in the test harness.
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
canonicalize_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(copy=[None, True, False])
def testAsarrayCopy(self, copy):
x_jax = jnp.arange(4)
x_np = np.arange(4)
x_list = [0, 1, 2, 3]
x_buf = make_python_array('l', x_list)
func = partial(jnp.asarray, copy=copy)
self.assertArraysEqual(x_jax, func(x_jax))
self.assertArraysEqual(x_jax, func(x_list), check_dtypes=False)
if copy is False and jax.default_backend() != 'cpu':
# copy=False is strict: it must raise if the input supports the buffer protocol
# but a copy is still required.
self.assertRaises(ValueError, func, x_np)
self.assertRaises(ValueError, func, x_buf)
else:
self.assertArraysEqual(x_jax, func(x_np), check_dtypes=False)
self.assertArraysEqual(x_jax, func(x_buf), check_dtypes=False)
@jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*")
def testArrayDtypeInference(self):
def _check(obj, out_dtype, weak_type):
dtype_reference = np.array(obj, dtype=out_dtype)
out = jnp.array(obj)
self.assertDtypesMatch(out, dtype_reference)
self.assertEqual(dtypes.is_weakly_typed(out), weak_type)
out_jit = jax.jit(jnp.array)(obj)
self.assertDtypesMatch(out_jit, dtype_reference)
self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type)
# Python scalars become 64-bit weak types.
_check(1, np.int64, True)
_check(1.0, np.float64, True)
_check(1.0j, np.complex128, True)
# Lists become strongly-typed defaults.
_check([1], jnp.int64, False)
_check([1.0], jnp.float64, False)
_check([1.0j], jnp.complex128, False)
# Lists of weakly-typed objects become strongly-typed defaults.
_check([jnp.array(1)], jnp.int64, False)
_check([jnp.array(1.0)], jnp.float64, False)
_check([jnp.array(1.0j)], jnp.complex128, False)
# Lists of strongly-typed objects maintain their strong type.
_check([jnp.int64(1)], np.int64, False)
_check([jnp.float64(1)], np.float64, False)
_check([jnp.complex128(1)], np.complex128, False)
# Mixed inputs use JAX-style promotion.
# (regression test for https://github.com/jax-ml/jax/issues/8945)
_check([0, np.int16(1)], np.int16, False)
_check([0.0, np.float16(1)], np.float16, False)
@jtu.sample_product(
dtype=all_dtypes,
func=["array", "copy", "copy.copy", "copy.deepcopy"],
)
def testArrayCopy(self, dtype, func):
x = jnp.ones(10, dtype=dtype)
if func == "copy.deepcopy":
copy_func = copy.deepcopy
elif func == "copy.copy":
copy_func = copy.copy
else:
copy_func = getattr(jnp, func)
x_view = jnp.asarray(x)
x_view_jit = jax.jit(jnp.asarray)(x)
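# asarray should alias its argument, while jit and the copy functions
# must each produce a fresh buffer.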
x_copy = copy_func(x)
x_copy_jit = jax.jit(copy_func)(x)
_ptr = lambda x: x.unsafe_buffer_pointer()
self.assertEqual(_ptr(x), _ptr(x_view))
self.assertNotEqual(_ptr(x), _ptr(x_view_jit))
self.assertNotEqual(_ptr(x), _ptr(x_copy))
self.assertNotEqual(_ptr(x), _ptr(x_copy_jit))
x.delete()
self.assertTrue(x_view.is_deleted())
self.assertFalse(x_view_jit.is_deleted())
self.assertFalse(x_copy.is_deleted())
self.assertFalse(x_copy_jit.is_deleted())
def testArrayCopyAutodiff(self):
f = lambda x: jnp.array(x, copy=True)
x = jnp.ones(10)
xdot = jnp.ones(10)
y, ydot = jax.jvp(f, (x,), (xdot,))
self.assertIsNot(x, y)
self.assertIsNot(xdot, ydot)
ybar = jnp.ones(10)
y, f_vjp = jax.vjp(f, x)
xbar, = f_vjp(ybar)
self.assertIsNot(x, y)
self.assertIsNot(xbar, ybar)
def testArrayCopyVmap(self):
f = lambda x: jnp.array(x, copy=True)
x = jnp.ones(10)
y = jax.vmap(f)(x)
self.assertIsNot(x, y)
def testArrayUnsupportedDtypeError(self):
with self.assertRaisesRegex(
TypeError, 'JAX only supports number, bool, and string dtypes.*'
):
jnp.array(3, [('a','<i4'),('b','<i4')])
def testArrayFromInteger(self):
int_dtype = dtypes.canonicalize_dtype(jnp.int64)
int_max = jnp.iinfo(int_dtype).max
int_min = jnp.iinfo(int_dtype).min
# Values at extremes are converted correctly.
for val in [int_min, 0, int_max]:
self.assertEqual(jnp.array(val).dtype, int_dtype)
# out of bounds leads to an OverflowError
val = int_max + 1
with self.assertRaisesRegex(OverflowError, f"Python int {val} too large to convert to {int_dtype.name}"):
jnp.array(val)
# explicit uint64 should work
if config.enable_x64.value:
self.assertEqual(np.uint64(val), jnp.array(val, dtype='uint64'))
def testArrayFromList(self):
dtype = dtypes.canonicalize_dtype('int64')
int_max = jnp.iinfo(dtype).max
int_min = jnp.iinfo(dtype).min
# Values at extremes are converted correctly.
for val in [int_min, 0, int_max]:
self.assertEqual(jnp.array([val]).dtype, dtype)
# list of values results in promoted type.
with jax.numpy_dtype_promotion('standard'):
self.assertEqual(jnp.array([0, np.float16(1)]).dtype, jnp.result_type('int64', 'float16'))
# out of bounds leads to an OverflowError
val = jnp.iinfo(jnp.int64).min - 1
with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
jnp.array([0, val])
def testArrayNone(self):
with self.assertRaisesRegex(
ValueError, 'None is not a valid value for jnp.array'
):
jnp.array([0.0, None])
def testIssue121(self):
assert not np.isscalar(jnp.array(3))
def testArrayOutputsArrays(self):
assert type(jnp.array([])) is array.ArrayImpl
assert type(jnp.array(np.array([]))) is array.ArrayImpl
class NDArrayLike:
def __array__(self, dtype=None, copy=None):
return np.array([], dtype=dtype)
assert type(jnp.array(NDArrayLike())) is array.ArrayImpl
# NOTE(mattjj): disabled b/c __array__ must produce ndarrays
# class ArrayLike:
# def __array__(self, dtype=None):
# return jnp.array([], dtype=dtype)
# assert xla.type_is_device_array(jnp.array(ArrayLike()))
def testArrayMethod(self):
class arraylike:
dtype = np.dtype('float32')
def __array__(self, dtype=None, copy=None):
return np.array(3., dtype=dtype)
a = arraylike()
ans = jnp.array(a)
self.assertEqual(ans, 3.)
def testJaxArrayOps(self):
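# Objects implementing __jax_array__ are converted to JAX arrays when used
# as operands in jnp operations.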
class arraylike:
def __jax_array__(self):
return jnp.array(3.)
self.assertArraysEqual(arraylike() * jnp.arange(10.), jnp.array(3.) * jnp.arange(10.))
def testMemoryView(self):
self.assertAllClose(
jnp.array(bytearray(b'\x2a')),
np.array(bytearray(b'\x2a'))
)
self.assertAllClose(
jnp.array(bytearray(b'\x2a\xf3'), ndmin=2),
np.array(bytearray(b'\x2a\xf3'), ndmin=2)
)
@jtu.sample_product(value=[False, 1, 1.0, np.int32(5), np.array(16)])
def testIsScalar(self, value):
self.assertTrue(jnp.isscalar(value))
@jtu.sample_product(value=[None, [1], slice(4), (), np.array([0])])
def testIsNotScalar(self, value):
self.assertFalse(jnp.isscalar(value))
@jtu.sample_product(val=[1+1j, [1+1j], jnp.pi, np.arange(2)])
def testIsComplexObj(self, val):
args_maker = lambda: [val]
self._CheckAgainstNumpy(np.iscomplexobj, jnp.iscomplexobj, args_maker)
self._CompileAndCheck(jnp.iscomplexobj, args_maker)
def testIsClose(self):
c_isclose = jax.jit(jnp.isclose)
c_isclose_nan = jax.jit(partial(jnp.isclose, equal_nan=True))
n = 2
rng = self.rng()
x = rng.randn(n, 1)
y = rng.randn(n, 1)
inf = np.asarray(n * [np.inf]).reshape([n, 1])
nan = np.asarray(n * [np.nan]).reshape([n, 1])
args = [x, y, inf, -inf, nan]
for arg0 in args:
for arg1 in args:
result_np = np.isclose(arg0, arg1)
result_jax = jnp.isclose(arg0, arg1)
result_jit = c_isclose(arg0, arg1)
self.assertTrue(jnp.all(jnp.equal(result_np, result_jax)))
self.assertTrue(jnp.all(jnp.equal(result_np, result_jit)))
result_np = np.isclose(arg0, arg1, equal_nan=True)
result_jax = jnp.isclose(arg0, arg1, equal_nan=True)
result_jit = c_isclose_nan(arg0, arg1)
self.assertTrue(jnp.all(jnp.equal(result_np, result_jax)))
self.assertTrue(jnp.all(jnp.equal(result_np, result_jit)))
self.assertEqual(np.isclose(6, 10, rtol=0.5), jnp.isclose(6, 10, rtol=0.5))
key = jax.random.key(0)
self.assertTrue(jnp.isclose(key, key))
@jtu.sample_product(
x=[1, [1], [1, 1 + 1E-4], [1, np.nan]],
y=[1, [1], [1, 1 + 1E-4], [1, np.nan]],
equal_nan=[True, False],
)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testAllClose(self, x, y, equal_nan):
jnp_fun = partial(jnp.allclose, equal_nan=equal_nan, rtol=1E-3)
np_fun = partial(np.allclose, equal_nan=equal_nan, rtol=1E-3)
args_maker = lambda: [np.array(x), np.array(y)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testZeroStridesConstantHandler(self):
raw_const = self.rng().randn(1, 2, 1, 1, 5, 1)
const = np.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))
def fun(x):
return x * const
fun = jax.jit(fun)
2018-11-17 18:03:33 -08:00
out_val = fun(3.)
self.assertAllClose(out_val, 3. * const, check_dtypes=False)
def testIsInstanceNdarrayDuringTracing(self):
arr = np.ones(3)
@jax.jit
def f(x):
self.assertIsInstance(x, jax.Array)
return jnp.sum(x)
f(arr)
def testNonArrayErrorMessage(self):
x = [1., 2.]
y = np.array([3., 4.])
def g(x, y):
return jnp.add(x, y)
def f(x, y):
return jnp.dot(x, y)
self.assertRaises(TypeError, lambda: g(x, y))
self.assertRaises(TypeError, lambda: f(x, y))
self.assertRaises(TypeError, lambda: jax.jit(g)(x, y))
self.assertRaises(TypeError, lambda: jax.jit(f)(x, y))
def testAbstractionErrorMessage(self):
@jax.jit
def f(x, n):
for _ in range(n):
x = x * x
return x
self.assertRaises(jax.errors.TracerIntegerConversionError, lambda: f(3., 3))
@jax.jit
def g(x):
if x > 0.:
return x * 2
else:
return x + 2
self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.))
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(3,), (2, 3)]
for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))] # Test negative axes and tuples
],
dtype=default_dtypes,
)
def testFlip(self, shape, dtype, axis):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_op = lambda x: jnp.flip(x, axis)
np_op = lambda x: np.flip(x, axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
shape=[(3,), (2, 3), (3, 2, 4)],
dtype=default_dtypes,
)
def testFlipud(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_op = lambda x: jnp.flipud(x)
np_op = lambda x: np.flipud(x)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
shape=[(3, 2), (2, 3), (3, 2, 4)],
dtype=default_dtypes,
)
def testFliplr(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_op = lambda x: jnp.fliplr(x)
np_op = lambda x: np.fliplr(x)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
[dict(shape=shape, axes=axes)
for shape, axes in [
[(2, 3), (0, 1)],
[(2, 3), (1, 0)],
[(4, 3, 2), (0, 2)],
[(4, 3, 2), (2, 1)],
]
],
k=range(-3, 4),
dtype=default_dtypes,
)
def testRot90(self, shape, dtype, k, axes):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_op = lambda x: jnp.rot90(x, k, axes)
np_op = lambda x: np.rot90(x, k, axes)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
# TODO(mattjj): test infix operator overrides
def testRavel(self):
rng = self.rng()
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
self._CompileAndCheck(lambda x: x.ravel(), args_maker)
@jtu.sample_product(
shape=nonempty_nonscalar_array_shapes,
order=['C', 'F'],
mode=['wrap', 'clip', 'raise'],
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testRavelMultiIndex(self, shape, order, mode):
# generate indices in each dimension with a few out of bounds.
rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1)
for dim in shape]
# generate multi_indices of different dimensions that broadcast.
args_maker = lambda: [tuple(rng(ndim * (3,), jnp.int_)
for ndim, rng in enumerate(rngs))]
def np_fun(x):
try:
return np.ravel_multi_index(x, shape, order=order, mode=mode)
except ValueError as err:
if str(err).startswith('invalid entry'):
# sentinel indicating expected error.
return -999
else:
raise
def jnp_fun(x):
try:
return jnp.ravel_multi_index(x, shape, order=order, mode=mode)
except ValueError as err:
if str(err).startswith('invalid entry'):
# sentinel indicating expected error.
return -999
else:
raise
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
if mode == 'raise':
msg = ("The error occurred because ravel_multi_index was jit-compiled "
"with mode='raise'. Use mode='wrap' or mode='clip' instead.")
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
jax.jit(jnp_fun)(*args_maker())
else:
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
ashape=((), (4,), (3, 4)),
cshapes=[
[(), (4,)],
[(3, 4), (4,), (3, 1)]
],
adtype=int_dtypes,
cdtype=default_dtypes,
mode=['wrap', 'clip', 'raise'],
)
def testChoose(self, ashape, adtype, cshapes, cdtype, mode):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]]
def np_fun(a, c):
try:
return np.choose(a, c, mode=mode)
except ValueError as err:
if mode == 'raise' and str(err).startswith('invalid entry'):
return -999 # sentinel indicating expected error.
else:
raise
def jnp_fun(a, c):
try:
return jnp.choose(a, c, mode=mode)
except ValueError as err:
if mode == 'raise' and str(err).startswith('invalid entry'):
return -999 # sentinel indicating expected error.
else:
raise
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
if mode == 'raise':
msg = ("The error occurred because jnp.choose was jit-compiled"
" with mode='raise'. Use mode='wrap' or mode='clip' instead.")
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
jax.jit(jnp_fun)(*args_maker())
else:
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=nonempty_nonscalar_array_shapes,
dtype=int_dtypes,
idx_shape=all_shapes,
)
def testUnravelIndex(self, shape, idx_shape, dtype):
size = math.prod(shape)
rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3)
def np_fun(index, shape):
# JAX's version outputs the same dtype as the input in the typical case
# where shape is weakly-typed.
out_dtype = index.dtype
# Adjust out-of-bounds behavior to match jax's documented behavior.
index = np.clip(index, -size, size - 1)
index = np.where(index < 0, index + size, index)
return [i.astype(out_dtype) for i in np.unravel_index(index, shape)]
jnp_fun = jnp.unravel_index
args_maker = lambda: [rng(idx_shape, dtype), shape]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
from_dtype=['int32', 'float32'],
to_dtype=['int32', 'float32', None],
use_method=[True, False],
)
def testAstype(self, from_dtype, to_dtype, use_method):
rng = self.rng()
args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)]
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
np_op = lambda x: np.astype(x, to_dtype)
else:
np_op = lambda x: np.asarray(x).astype(to_dtype)
if use_method:
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
else:
jnp_op = lambda x: jnp.astype(x, to_dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
from_dtype=['int32', 'float32', 'complex64'],
use_method=[True, False],
)
def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng((3, 4), from_dtype)]
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
np_op = lambda x: np.astype(x, to_dtype)
else:
np_op = lambda x: np.asarray(x).astype(to_dtype)
if use_method:
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
else:
jnp_op = lambda x: jnp.astype(x, to_dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
change_dtype=[True, False],
copy=[True, False],
)
def testAstypeCopy(self, change_dtype, copy):
dtype = 'float32' if change_dtype else 'int32'
expect_copy = change_dtype or copy
x = jnp.arange(5, dtype='int32')
y = x.astype(dtype, copy=copy)
self.assertEqual(y.dtype, dtype)
y.delete()
self.assertNotEqual(x.is_deleted(), expect_copy)
def testAstypeComplexDowncast(self):
x = jnp.array(2.0+1.5j, dtype='complex64')
with self.assertDeprecationWarnsOrRaises("jax-numpy-astype-complex-to-real",
"Casting from complex to real dtypes.*"):
x.astype('float32')
@parameterized.parameters('int2', 'int4')
def testAstypeIntN(self, dtype):
if dtype == 'int2':
self.skipTest('XLA support for int2 is incomplete.')
# Test converting from intN to int8
x = np.array([1, -2, -3, 4, -8, 7], dtype=dtype)
args_maker = lambda: [x]
np_op = lambda x: np.asarray(x).astype(jnp.int8)
jnp_op = lambda x: jnp.asarray(x).astype(jnp.int8)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
# Test converting from int8 to intN
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int8)
args_maker = lambda: [x]
np_op = lambda x: np.asarray(x).astype(dtype)
jnp_op = lambda x: jnp.asarray(x).astype(dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
shape=array_shapes,
dtype=all_dtypes,
)
def testNbytes(self, shape, dtype):
rng = jtu.rand_default(self.rng())
np_op = lambda x: np.asarray(x).nbytes
jnp_op = lambda x: jnp.asarray(x).nbytes
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
shape=array_shapes,
dtype=all_dtypes,
)
def testItemsize(self, shape, dtype):
rng = jtu.rand_default(self.rng())
np_op = lambda x: np.asarray(x).itemsize
jnp_op = lambda x: jnp.asarray(x).itemsize
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
shape=nonempty_array_shapes,
dtype=all_dtypes,
num_args=[0, 1, "all"],
use_tuple=[True, False]
)
def testItem(self, shape, dtype, num_args, use_tuple):
rng = jtu.rand_default(self.rng())
size = math.prod(shape)
if num_args == 0:
args = ()
elif num_args == 1:
args = (self.rng().randint(0, size),)
else:
args = tuple(self.rng().randint(0, s) for s in shape)
args = (args,) if use_tuple else args
np_op = lambda x: np.asarray(x).item(*args)
jnp_op = lambda x: jnp.asarray(x).item(*args)
args_maker = lambda: [rng(shape, dtype)]
if size != 1 and num_args == 0:
with self.assertRaises(ValueError):
jnp_op(*args_maker())
else:
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
@jtu.sample_product(
# Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs.
shape=[(0,), (64,), (2, 32)],
a_dtype=(jnp.int4, jnp.uint4, *all_dtypes),
dtype=((jnp.int4, jnp.uint4, *all_dtypes, None)
if config.enable_x64.value else (jnp.int4, jnp.uint4, *all_dtypes)),
)
def testView(self, shape, a_dtype, dtype):
if jtu.test_device_matches(["tpu"]):
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
# It is possible to fill bool arrays with arbitrary bits (not just 0/1
# bytes), but the behavior is implementation-defined. We therefore only test
# the well-defined case.
rng = (jtu.rand_bool if a_dtype == np.bool_ else jtu.rand_fullrange)(
self.rng()
)
args_maker = lambda: [rng(shape, a_dtype)]
np_op = lambda x: np_view(x, dtype)
jnp_op = lambda x: jnp.asarray(x).view(dtype)
# Above may produce signaling nans; ignore warnings from invalid values.
with np.errstate(invalid='ignore'):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product([
{'a_dtype': a_dtype, 'dtype': dtype}
for a_dtype in [jnp.int4, jnp.uint4, *all_dtypes]
for dtype in [jnp.int4, jnp.uint4, *all_dtypes]
if dtypes.bit_width(a_dtype) == dtypes.bit_width(dtype)
])
def testViewScalar(self, a_dtype, dtype):
if jtu.test_device_matches(["tpu"]):
if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]:
self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
rng = jtu.rand_fullrange(self.rng())
args_maker = lambda: [jnp.array(rng((), a_dtype))]
np_op = lambda x: np.asarray(x).view(dtype)
jnp_op = lambda x: jnp.asarray(x).view(dtype)
# Above may produce signaling nans; ignore warnings from invalid values.
with np.errstate(invalid='ignore'):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testPathologicalFloats(self):
args_maker = lambda: [np.array([
0b_0111_1111_1000_0000_0000_0000_0000_0000, # inf
0b_1111_1111_1000_0000_0000_0000_0000_0000, # -inf
0b_0111_1111_1100_0000_0000_0000_0000_0000, # qnan
0b_1111_1111_1100_0000_0000_0000_0000_0000, # -qnan
0b_0111_1111_1000_0000_0000_0000_0000_0001, # snan
0b_1111_1111_1000_0000_0000_0000_0000_0001, # -snan
0b_0111_1111_1000_0000_0000_1100_0000_0000, # nonstandard nan
0b_1111_1111_1000_0000_0000_1100_0000_0000, # -nonstandard nan
0b_0000_0000_0000_0000_0000_0000_0000_0000, # zero
0b_1000_0000_0000_0000_0000_0000_0000_0000, # -zero
], dtype='uint32')]
np_op = lambda x: np.asarray(x).view('float32').view('uint32')
jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32')
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
# TODO(mattjj): test other ndarray-like method overrides
def testNpMean(self):
# from https://github.com/jax-ml/jax/issues/125
x = jnp.eye(3, dtype=float) + 0.
ans = np.mean(x)
self.assertAllClose(ans, np.array(1./3), check_dtypes=False)
def testArangeOnFloats(self):
np_arange = jtu.with_jax_dtype_defaults(np.arange)
# from https://github.com/jax-ml/jax/issues/145
self.assertAllClose(np_arange(0.0, 1.0, 0.1),
jnp.arange(0.0, 1.0, 0.1))
# from https://github.com/jax-ml/jax/issues/3450
self.assertAllClose(np_arange(2.5),
jnp.arange(2.5))
self.assertAllClose(np_arange(0., 2.5),
jnp.arange(0., 2.5))
def testArangeTypes(self):
# Test that arange() output type is equal to the default types.
int_ = dtypes.canonicalize_dtype(jnp.int_)
float_ = dtypes.canonicalize_dtype(jnp.float_)
self.assertEqual(jnp.arange(10).dtype, int_)
self.assertEqual(jnp.arange(10.).dtype, float_)
self.assertEqual(jnp.arange(10, dtype='uint16').dtype, np.uint16)
self.assertEqual(jnp.arange(10, dtype='bfloat16').dtype, jnp.bfloat16)
self.assertEqual(jnp.arange(0, 10, 1).dtype, int_)
with jax.numpy_dtype_promotion('standard'):
self.assertEqual(jnp.arange(0, 10, 1.).dtype, float_)
self.assertEqual(jnp.arange(0., 10, 1).dtype, float_)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonzerodim_shapes
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
],
stable=[True, False],
dtype=all_dtypes,
)
def testSort(self, dtype, shape, axis, stable):
rng = jtu.rand_some_equal(self.rng()) if stable else jtu.rand_some_inf_and_nan(self.rng())
args_maker = lambda: [rng(shape, dtype)]
kwds = {} if axis is NO_VALUE else {'axis': axis}
def np_fun(arr):
# Note: numpy sort fails on NaN and Inf values with bfloat16
dtype = arr.dtype
if arr.dtype == jnp.bfloat16:
arr = arr.astype('float32')
# TODO(jakevdp): switch to stable=stable when supported by numpy.
result = np.sort(arr, kind='stable' if stable else None, **kwds)
with jtu.ignore_warning(category=RuntimeWarning, message='invalid value'):
return result.astype(dtype)
jnp_fun = partial(jnp.sort, stable=stable, **kwds)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testSortStableDescending(self):
# TODO(jakevdp): test directly against np.sort when descending is supported.
x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf])
x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan])
argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5])
argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6])
self.assertArraysEqual(jnp.sort(x), x_sorted)
self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0]))
self.assertArraysEqual(jnp.argsort(x), argsorted_stable)
self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testSortComplex(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker,
check_dtypes=False)
self._CompileAndCheck(jnp.sort_complex, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in (-1, *range(len(shape) - 1))
],
dtype=all_dtypes,
input_type=[np.array, tuple],
)
def testLexsort(self, dtype, shape, input_type, axis):
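# np.lexsort (and jnp.lexsort) treats the last key as the primary sort key.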
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [input_type(rng(shape, dtype))]
jnp_op = lambda x: jnp.lexsort(x, axis=axis)
np_op = jtu.with_jax_dtype_defaults(lambda x: np.lexsort(x, axis=axis))
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonzerodim_shapes
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
],
dtype=all_dtypes,
)
def testArgsort(self, dtype, shape, axis):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
kwds = {} if axis is NO_VALUE else {'axis': axis}
@jtu.with_jax_dtype_defaults
def np_fun(arr):
# Note: numpy sort fails on NaN and Inf values with bfloat16
if arr.dtype == jnp.bfloat16:
arr = arr.astype('float32')
# TODO(jakevdp): switch to stable=True when supported by numpy.
return np.argsort(arr, kind='stable', **kwds)
jnp_fun = partial(jnp.argsort, stable=True, **kwds)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
],
descending=[True, False],
dtype=all_dtypes,
)
def testArgsortUnstable(self, dtype, shape, axis, descending):
# We cannot directly compare unstable argsorts, so instead check that indexed values match.
rng = jtu.rand_some_equal(self.rng())
x = rng(shape, dtype)
kwds = {} if axis is NO_VALUE else {'axis': axis}
expected = jnp.sort(x, descending=descending, stable=False, **kwds)
indices = jnp.argsort(x, descending=descending, stable=False, **kwds)
if axis is None:
actual = jnp.ravel(x)[indices]
else:
actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis)
self.assertArraysEqual(actual, expected)
def _assertSamePartitionedArrays(self, jnp_output, np_output, axis, kth, shape):
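    # Elements on either side of the pivot may appear in any order, so compare
    # each side only after sorting it.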
# Assert that pivot point is equal:
self.assertArraysEqual(
lax.index_in_dim(jnp_output, axis=axis, index=kth),
lax.index_in_dim(np_output, axis=axis, index=kth))
# Assert remaining values are correctly partitioned:
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))
@jtu.sample_product(
[{'shape': shape, 'axis': axis, 'kth': kth}
for shape in nonzerodim_shapes
for axis in range(-len(shape), len(shape))
for kth in range(-shape[axis], shape[axis])],
dtype=default_dtypes,
)
def testPartition(self, shape, dtype, axis, kth):
rng = jtu.rand_default(self.rng())
arg = rng(shape, dtype)
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
np_output = np.partition(arg, axis=axis, kth=kth)
self._assertSamePartitionedArrays(jnp_output, np_output, axis, kth, shape)
@jtu.sample_product(
kth=range(10),
dtype=unsigned_dtypes,
)
def testPartitionUnsignedWithZeros(self, kth, dtype):
# https://github.com/jax-ml/jax/issues/22137
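    # Regression test: unsigned inputs containing both 0 and the dtype's max
    # value previously produced incorrect partitions.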
max_val = np.iinfo(dtype).max
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2]], dtype=dtype)
axis = -1
shape = arg.shape
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
np_output = np.partition(arg, axis=axis, kth=kth)
self._assertSamePartitionedArrays(jnp_output, np_output, axis, kth, shape)
@jtu.sample_product(
[{'shape': shape, 'axis': axis, 'kth': kth}
for shape in nonzerodim_shapes
for axis in range(-len(shape), len(shape))
for kth in range(-shape[axis], shape[axis])],
dtype=default_dtypes,
)
def testArgpartition(self, shape, dtype, axis, kth):
rng = jtu.rand_default(self.rng())
arg = rng(shape, dtype)
jnp_output = jnp.argpartition(arg, axis=axis, kth=kth)
np_output = np.argpartition(arg, axis=axis, kth=kth)
# Assert that all indices are present
self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False)
# Because JAX & numpy may treat duplicates differently, we must compare values
# rather than indices.
getvals = lambda x, ind: x[ind]
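    # range(arg.ndim)[axis] canonicalizes a possibly-negative axis; vmapping
    # over every other axis applies the indexing along `axis` alone.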
for ax in range(arg.ndim):
if ax != range(arg.ndim)[axis]:
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
jnp_values = getvals(arg, jnp_output)
np_values = getvals(arg, np_output)
self._assertSamePartitionedArrays(jnp_values, np_values, axis, kth, shape)
@jtu.sample_product(
kth=range(10),
dtype=unsigned_dtypes,
)
def testArgpartitionUnsignedWithZeros(self, kth, dtype):
# https://github.com/jax-ml/jax/issues/22137
max_val = np.iinfo(dtype).max
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2, 3]], dtype=dtype)
axis = -1
shape = arg.shape
jnp_output = jnp.argpartition(arg, axis=axis, kth=kth)
np_output = np.argpartition(arg, axis=axis, kth=kth)
# Assert that all indices are present
self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False)
# Because JAX & numpy may treat duplicates differently, we must compare values
# rather than indices.
getvals = lambda x, ind: x[ind]
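    # range(arg.ndim)[axis] canonicalizes a possibly-negative axis; vmapping
    # over every other axis applies the indexing along `axis` alone.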
for ax in range(arg.ndim):
if ax != range(arg.ndim)[axis]:
getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
jnp_values = getvals(arg, jnp_output)
np_values = getvals(arg, np_output)
self._assertSamePartitionedArrays(jnp_values, np_values, axis, kth, shape)
@jtu.sample_product(
[dict(shifts=shifts, axis=axis)
for shifts, axis in [
(3, None),
(1, 1),
((3,), (0,)),
((-2,), (-2,)),
((1, 2), (0, -1)),
((4, 2, 5, 5, 2, 4), None),
(100, None),
]
],
dtype=all_dtypes,
shape=[(3, 4), (3, 4, 5), (7, 4, 0)],
)
def testRoll(self, shape, dtype, shifts, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), np.array(shifts)]
jnp_op = partial(jnp.roll, axis=axis)
np_op = partial(np.roll, axis=axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
dtype=all_dtypes,
shape=[(1, 2, 3, 4)],
axis=[-3, 0, 2, 3],
start=[-4, -1, 2, 4],
)
def testRollaxis(self, shape, dtype, start, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.rollaxis, axis=axis, start=start)
np_op = partial(np.rollaxis, axis=axis, start=start)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
dtype=[np.uint8, np.bool_],
bitorder=['big', 'little'],
shape=[(1, 2, 3, 4)],
axis=[None, 0, 1, -2, -1],
)
def testPackbits(self, shape, dtype, axis, bitorder):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
np_op = partial(np.packbits, axis=axis, bitorder=bitorder)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
dtype=[np.uint8],
bitorder=['big', 'little'],
shape=[(1, 2, 3, 4)],
axis=[None, 0, 1, -2, -1],
count=[None, 20],
)
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
rng = jtu.rand_int(self.rng(), 0, 256)
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder, count=count)
np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder, count=count)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(3,), (3, 4), (3, 4, 5)]
for axis in itertools.chain(range(-len(shape), len(shape)),
[cast(Union[int, None], None)])
],
index_shape=scalar_shapes + [(3,), (2, 1, 3)],
dtype=all_dtypes,
index_dtype=int_dtypes,
mode=[None, 'wrap', 'clip'],
)
def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode):
def args_maker():
x = rng(shape, dtype)
i = rng_indices(index_shape, index_dtype)
return x, i
rng = jtu.rand_default(self.rng())
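    # With mode=None, NumPy raises on out-of-bounds indices, so keep them in
    # range; 'wrap' and 'clip' tolerate out-of-range values.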
if mode is None:
rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0])
else:
rng_indices = jtu.rand_int(self.rng(), -5, 5)
jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode)
np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testTakeEmpty(self):
np.testing.assert_array_equal(
jnp.array([], dtype=jnp.float32),
jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32)))
np.testing.assert_array_equal(
jnp.ones((2, 0, 4), dtype=jnp.float32),
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32),
axis=1))
with self.assertRaisesRegex(IndexError, "non-empty jnp.take"):
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32),
jnp.array([0], jnp.int32), axis=1)
def testTakeOptionalArgs(self):
x = jnp.arange(5.0)
ind = jnp.array([0, 2, 4, 6])
expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype)
actual = jnp.take(x, ind, unique_indices=True,
indices_are_sorted=True, fill_value=10.0)
self.assertArraysEqual(expected, actual)
@jtu.sample_product(
[dict(x_shape=x_shape, i_shape=i_shape, axis=axis)
for x_shape, i_shape in filter(
_shapes_are_equal_length,
filter(_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2)))
for axis in itertools.chain(range(len(x_shape)), [-1],
[cast(Union[int, None], None)])
],
dtype=default_dtypes,
index_dtype=int_dtypes,
)
def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis):
rng = jtu.rand_default(self.rng())
i_shape = list(i_shape)
if axis is None:
i_shape = [math.prod(i_shape)]
else:
# Test the case where the size of the axis doesn't necessarily broadcast.
i_shape[axis] *= 3
def args_maker():
x = rng(x_shape, dtype)
n = math.prod(x_shape) if axis is None else x_shape[axis]
if np.issubdtype(index_dtype, np.unsignedinteger):
index_rng = jtu.rand_int(self.rng(), 0, n)
else:
index_rng = jtu.rand_int(self.rng(), -n, n)
i = index_rng(i_shape, index_dtype)
return x, i
jnp_op = lambda x, i: jnp.take_along_axis(x, i, axis=axis)
jnp_one_hot_op = lambda x, i: jnp.take_along_axis(
x, i, axis=axis, mode='one_hot'
)
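    # mode='one_hot' exercises an alternative gather implementation; its
    # results should match the default mode.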
if hasattr(np, "take_along_axis"):
np_op = lambda x, i: np.take_along_axis(x, i, axis=axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CheckAgainstNumpy(np_op, jnp_one_hot_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
self._CompileAndCheck(jnp_one_hot_op, args_maker)
def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self):
# https://github.com/jax-ml/jax/issues/5088
h = jtu.rand_default(self.rng())((256, 256, 100), np.float32)
g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-1)
    q1 = np.take_along_axis(h, g, axis=-1)
np.testing.assert_equal(q0, q1)
def testTakeAlongAxisInputTensorHasSingletonDimension(self):
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
    q1 = np.take_along_axis(h, g, axis=-2)
np.testing.assert_equal(q0, q1)
def testTakeAlongAxisOutOfBounds(self):
x = jnp.arange(10, dtype=jnp.float32)
idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])
out = jnp.take_along_axis(x, idx, axis=0)
expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan,
jnp.nan], np.float32)
np.testing.assert_array_equal(expected_fill, out)
out = jnp.take_along_axis(x, idx, axis=0, mode="fill")
np.testing.assert_array_equal(expected_fill, out)
expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32)
out = jnp.take_along_axis(x, idx, axis=0, mode="clip")
np.testing.assert_array_equal(expected_clip, out)
def testTakeAlongAxisRequiresIntIndices(self):
x = jnp.arange(5)
idx = jnp.array([3.], jnp.float32)
with self.assertRaisesRegex(
TypeError,
"take_along_axis indices must be of integer type, got float32"):
jnp.take_along_axis(x, idx, axis=0)
def testTakeAlongAxisWithEmptyArgs(self):
# take_along_axis should allow us to gather an empty list of indices from
# an empty input axis without raising a shape error.
x = jnp.ones((4, 0, 3), dtype=jnp.int32)
np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1))
def testTakeAlongAxisOptionalArgs(self):
x = jnp.arange(5.0)
ind = jnp.array([0, 2, 4, 6])
expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype)
actual = jnp.take_along_axis(x, ind, axis=None, mode='fill', fill_value=10.0)
self.assertArraysEqual(expected, actual)
@jtu.sample_product(
dtype=inexact_dtypes,
shape=[0, 5],
n=[2, 4],
increasing=[False, True],
)
def testVander(self, shape, dtype, n, increasing):
rng = jtu.rand_default(self.rng())
def np_fun(arg):
arg = arg.astype(np.float32) if dtype == jnp.bfloat16 else arg
return np.vander(arg, N=n, increasing=increasing)
jnp_fun = lambda arg: jnp.vander(arg, N=n, increasing=increasing)
args_maker = lambda: [rng([shape], dtype)]
# np.vander seems to return float64 for all floating types. We could obey
# those semantics, but they seem like a bug.
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol={np.float32: 1e-3, np.complex64: 1e-3})
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False)
@jtu.sample_product(
shape=array_shapes,
dtype=all_dtypes,
)
def testNanToNum(self, shape, dtype):
rng = jtu.rand_some_inf_and_nan(self.rng())
dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type
def np_fun(x):
if dtype == jnp.bfloat16:
x = np.where(np.isnan(x), dtype(0), x)
x = np.where(np.isposinf(x), jnp.finfo(dtype).max, x)
x = np.where(np.isneginf(x), jnp.finfo(dtype).min, x)
return x
else:
return np.nan_to_num(x).astype(dtype)
args_maker = lambda: [rng(shape, dtype)]
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
self._CheckAgainstNumpy(np_fun, jnp.nan_to_num, args_maker,
check_dtypes=check_dtypes)
self._CompileAndCheck(jnp.nan_to_num, args_maker,
check_dtypes=check_dtypes)
@jtu.sample_product(
[dict(shapes=shapes, dtypes=dtypes)
for shapes, dtypes in (
((), ()),
(((7,),), (np.int32,)),
(((3,), (4,)), (np.int32, np.int32)),
(((3,), (1,), (4,)), (np.int32, np.int32, np.int32)),
)
],
)
def testIx_(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)
for shape, dtype in zip(shapes, dtypes)]
self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker)
self._CompileAndCheck(jnp.ix_, args_maker)
@jtu.sample_product(
dimensions=[(), (2,), (3, 0), (4, 5, 6)],
dtype=number_dtypes,
sparse=[True, False],
)
def testIndices(self, dimensions, dtype, sparse):
def args_maker(): return []
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
jnp_fun = partial(jnp.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testIndicesDefaultDtype(self):
self.assertEqual(jnp.indices((2, 3)).dtype,
dtypes.canonicalize_dtype(np.int64))
@jtu.sample_product(
shape=nonzerodim_shapes,
dtype=all_dtypes,
)
def testWhereOneArgument(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
# JIT compilation requires specifying a size statically. Full test of
# this behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shapes=filter(_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(all_shapes, 3)),
dtypes=itertools.combinations_with_replacement(all_dtypes, 3),
)
def testWhereThreeArgument(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
def np_fun(cond, x, y):
return jtu.promote_like_jnp(partial(np.where, cond))(x, y)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(np_fun, jnp.where, args_maker)
self._CompileAndCheck(jnp.where, args_maker)
def testWhereExtraCode(self):
def f(x):
return jnp.where(x > 0, x, -x)
jaxpr = jax.make_jaxpr(jax.grad(f))(3.)
# Test no comparison literal True/False in jaxpr, and hence no comparison to
# literals
self.assertNotIn('False', str(jaxpr))
self.assertNotIn('True', str(jaxpr))
def testWhereScalarPromotion(self):
x = jnp.where(jnp.array([True, False]), 3,
jnp.ones((2,), dtype=jnp.float32))
2020-05-20 01:43:48 -03:00
self.assertEqual(x.dtype, np.dtype(np.float32))
@jtu.sample_product(
[dict(n=n, shapes=shapes)
for n in range(1, 3)
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(all_shapes, 2 * n + 1))
],
# To avoid forming the full product of shapes and dtypes we always sample
    # a maximal set of dtypes.
dtypes=itertools.combinations_with_replacement(all_dtypes, 3),
)
@jax.numpy_rank_promotion('allow')
def testSelect(self, n, shapes, dtypes):
dtypes = dtypes[:n+1]
rng = jtu.rand_default(self.rng())
n = len(dtypes) - 1
def args_maker():
condlist = [rng(shape, np.bool_) for shape in shapes[:n]]
choicelist = [rng(shape, dtype)
for shape, dtype in zip(shapes[n:-1], dtypes[:n])]
default = rng(shapes[-1], dtypes[-1])
return condlist, choicelist, default
# TODO(phawkins): float32/float64 type mismatches
@jax.numpy_dtype_promotion('standard')
def np_fun(condlist, choicelist, default):
choicelist = [x if jnp.result_type(x) != jnp.bfloat16
else x.astype(np.float32) for x in choicelist]
dtype = jnp.result_type(default, *choicelist)
return np.select(condlist,
[np.asarray(x).astype(dtype) for x in choicelist],
np.asarray(default, dtype=dtype))
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(np_fun, jnp.select, args_maker,
check_dtypes=False)
self._CompileAndCheck(jnp.select, args_maker,
rtol={np.float64: 1e-7, np.complex128: 1e-7})
def testIssue330(self):
x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash
self.assertEqual(x[0, 0], 1)
def testScalarDtypePromotion(self):
orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype
jax_numpy_result = (1 + jnp.eye(1, dtype=jnp.float32)).dtype
self.assertEqual(orig_numpy_result, jax_numpy_result)
def testSymmetrizeDtypePromotion(self):
x = np.eye(3, dtype=np.float32)
orig_numpy_result = ((x + x.T) / 2).dtype
x = jnp.eye(3, dtype=jnp.float32)
jax_numpy_result = ((x + x.T) / 2).dtype
self.assertEqual(orig_numpy_result, jax_numpy_result)
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because
# introducing the convention 0 * inf = 0 leads to silently wrong results in
# some cases. See this comment for details:
# https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352
# def testIssue347(self):
# # https://github.com/jax-ml/jax/issues/347
# def test_fail(x):
# x = jnp.sqrt(jnp.sum(x ** 2, axis=1))
# ones = jnp.ones_like(x)
# x = jnp.where(x > 0.5, x, ones)
# return jnp.sum(x)
# x = jnp.array([[1, 2], [3, 4], [0, 0]], dtype=jnp.float64)
# result = jax.grad(test_fail)(x)
# assert not np.any(np.isnan(result))
def testIssue453(self):
# https://github.com/jax-ml/jax/issues/453
a = np.arange(6) + 1
ans = jnp.reshape(a, (3, 2), order='F')
expected = np.reshape(a, (3, 2), order='F')
self.assertAllClose(ans, expected)
@jtu.sample_product(
dtype=[int, float, bool, complex],
op=["atleast_1d", "atleast_2d", "atleast_3d"],
)
def testAtLeastNdLiterals(self, dtype, op):
# Fixes: https://github.com/jax-ml/jax/issues/634
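    # NumPy promotes Python scalars to 64-bit types, so cast its result to the
    # corresponding default scalar dtype before comparing with JAX.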
np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype])
jnp_fun = lambda arg: getattr(jnp, op)(arg)
args_maker = lambda: [dtype(2)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=[(0,), (5,), (10,)],
dtype=int_dtypes + bool_dtypes,
weights=[True, False],
minlength=[0, 20],
length=[None, 8],
)
def testBincount(self, shape, dtype, weights, minlength, length):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None))
def np_fun(x, *args):
x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero.
out = np.bincount(x, *args, minlength=minlength)
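      # Emulate JAX's static `length` argument by padding or truncating the
      # NumPy result.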
if length and length > out.size:
return np.pad(out, (0, length - out.size))
return out[:length]
jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
if length is not None:
self._CompileAndCheck(jnp_fun, args_maker)
def testBincountNegative(self):
# Test that jnp.bincount ignores negative values.
x_rng = jtu.rand_int(self.rng(), -100, 100)
w_rng = jtu.rand_uniform(self.rng())
shape = (1000,)
x = x_rng(shape, 'int32')
w = w_rng(shape, 'float32')
xn = np.array(x)
xn[xn < 0] = 0
wn = np.array(w)
np_result = np.bincount(xn[xn >= 0], wn[xn >= 0])
jnp_result = jnp.bincount(x, w)
self.assertAllClose(np_result, jnp_result, check_dtypes=False)
@jtu.sample_product(
input=[
3,
[3],
[np.array(3)],
[np.array([3])],
[[np.array(3)]],
[[np.array([3])]],
[3, 4, 5],
[
[np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)],
[np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3],
],
[np.array([1, 2, 3]), np.array([2, 3, 4]), 10],
[np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)],
[[np.array([1, 2, 3])], [np.array([2, 3, 4])]],
],
)
def testBlock(self, input):
args_maker = lambda: [input]
self._CheckAgainstNumpy(np.block, jnp.block, args_maker)
self._CompileAndCheck(jnp.block, args_maker)
def testLongLong(self):
self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7)))
@jtu.ignore_warning(category=UserWarning,
message="Explicitly requested dtype.*")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testArange(self):
# test cases inspired by dask tests at
# https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92
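    # Wrap np.arange so its outputs use JAX's default dtypes (e.g. 32-bit when
    # x64 is disabled).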
np_arange = jtu.with_jax_dtype_defaults(np.arange)
self.assertAllClose(jnp.arange(77),
np_arange(77))
self.assertAllClose(jnp.arange(2, 13),
np_arange(2, 13))
self.assertAllClose(jnp.arange(4, 21, 9),
np_arange(4, 21, 9))
self.assertAllClose(jnp.arange(53, 5, -3),
np_arange(53, 5, -3))
self.assertAllClose(jnp.arange(77, dtype=float),
np_arange(77, dtype=float))
self.assertAllClose(jnp.arange(2, 13, dtype=int),
np_arange(2, 13, dtype=int))
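    # A negative step with start < stop yields an empty array.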
self.assertAllClose(jnp.arange(0, 1, -0.5),
np_arange(0, 1, -0.5))
self.assertRaises(TypeError, lambda: jnp.arange())
# test that jnp.arange(N) doesn't instantiate an ndarray
self.assertNotEqual(type(jnp.arange(77)), type(np.arange(77)))
self.assertEqual(type(jnp.arange(77)), type(lax.iota(np.int32, 77)))
# test that jnp.arange(N, dtype=int32) doesn't instantiate an ndarray
self.assertNotEqual(type(jnp.arange(77, dtype=jnp.int32)),
type(np.arange(77, dtype=np.int32)))
self.assertEqual(type(jnp.arange(77, dtype=jnp.int32)),
type(lax.iota(np.int32, 77)))
def testArangeJit(self):
ans = jax.jit(lambda: jnp.arange(5))()
expected = jtu.with_jax_dtype_defaults(np.arange)(5)
self.assertAllClose(ans, expected)
@jtu.sample_product(
args=[(5,), (0, 5)],
specify_device=[True, False],
)
def testArangeJaxpr(self, args, specify_device):
device = jax.devices()[-1] if specify_device else None
kwargs = {"device": device}
jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args, **kwargs))()
    # When a device is specified, the jaxpr has two eqns:
    #   b:i32[5] = iota[dimension=0 dtype=int32 shape=(5,)]
    #   a:i32[5] = device_put[devices=[None] srcs=[None]] b
    # Otherwise there is only the iota.
num_eqs = 2 if device is not None else 1
self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p)
def testIssue830(self):
a = jnp.arange(4, dtype=jnp.complex64)
self.assertEqual(a.dtype, jnp.complex64)
def testIssue728(self):
np_eye = jtu.with_jax_dtype_defaults(np.eye)
self.assertAllClose(jnp.eye(5000), np_eye(5000))
self.assertEqual(0, np.sum(jnp.eye(1050) - np_eye(1050)))
def testIssue746(self):
jnp.arange(12).reshape(3, 4) # doesn't crash
def testIssue764(self):
x = jnp.linspace(190, 200, 4)
f = jax.grad(lambda x: jnp.sum(jnp.tanh(x)))
# Expected values computed with autograd in float64 precision.
expected = np.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171,
7.66067839e-174], np.float64)
self.assertAllClose(f(x), expected, check_dtypes=False)
# Test removed because tie_in is deprecated.
# def testIssue776(self):
# """Tests that the scatter-add transpose rule instantiates symbolic zeros."""
# def f(u):
# y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u)
# # The transpose rule for lax.tie_in returns a symbolic zero for its first
# # argument.
# return lax.tie_in(y, 7.)
# self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,)))
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because this
# is a numerical stability issue that should be solved with a custom jvp rule
# of the sigmoid function being differentiated here, not by safe_mul.
# def testIssue777(self):
# x = jnp.linspace(-200, 0, 4, dtype=np.float32)
# f = jax.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x))))
# self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32))
@jtu.sample_product(
dtype=float_dtypes,
op=("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan",
"sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp",
"log", "expm1", "log1p"),
)
def testMathSpecialFloatValues(self, op, dtype):
np_op = getattr(np, op)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value.*")(np_op)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="divide by zero.*")(np_op)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="overflow.*")(np_op)
jnp_op = getattr(jnp, op)
dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type
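    # Probe special values (nan, infinities, zero) plus magnitudes near the
    # dtype's overflow threshold.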
for x in (np.nan, -np.inf, -100., -2., -1., 0., 1., 2., 100., np.inf,
jnp.finfo(dtype).max, np.sqrt(jnp.finfo(dtype).max),
np.sqrt(jnp.finfo(dtype).max) * 2.):
x = dtype(x)
expected = np_op(x)
actual = jnp_op(x)
tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7})
self.assertAllClose(expected, actual, atol=tol,
rtol=tol)
def testIssue956(self):
self.assertRaises(TypeError, lambda: jnp.ndarray((1, 1)))
def testIssue967(self):
self.assertRaises(TypeError, lambda: jnp.zeros(1.5))
@jtu.sample_product(
shape=[(5,), (10, 5), (4, 10)],
dtype=number_dtypes,
rowvar=[True, False],
)
@jax.default_matmul_precision("float32")
def testCorrCoef(self, shape, dtype, rowvar):
rng = jtu.rand_default(self.rng())
def args_maker():
ok = False
while not ok:
x = rng(shape, dtype)
ok = not np.any(np.isclose(np.std(x), 0.0))
return (x,)
np_fun = partial(np.corrcoef, rowvar=rowvar)
np_fun = jtu.ignore_warning(
category=RuntimeWarning, message="invalid value encountered.*")(np_fun)
jnp_fun = partial(jnp.corrcoef, rowvar=rowvar)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(dtype=dtype, end_dtype=end_dtype, begin_dtype=begin_dtype,
shape=shape, begin_shape=begin_shape, end_shape=end_shape)
for dtype in number_dtypes
for end_dtype in [None] + [dtype]
for begin_dtype in [None] + [dtype]
for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]
for begin_shape in (
[None] if begin_dtype is None
else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE])
for end_shape in (
[None] if end_dtype is None
else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE])
],
)
def testEDiff1d(self, shape, dtype, end_shape, end_dtype, begin_shape,
begin_dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype),
(None if end_dtype is None else rng(end_shape, end_dtype)),
(None if begin_dtype is None else rng(begin_shape, begin_dtype))]
np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin)
jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testEDiff1dWithDtypeCast(self):
rng = jtu.rand_default(self.rng())
shape = jtu.NUMPY_SCALAR_SHAPE
dtype = jnp.float32
end_dtype = jnp.int32
args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)]
np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin)
jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shapes=[(), (5,), (5, 3)],
dtype=number_dtypes,
indexing=['xy', 'ij'],
sparse=[True, False],
)
def testMeshGrid(self, shapes, dtype, indexing, sparse):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes],
[dtype] * len(shapes))
np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse)
jnp_fun = partial(jnp.meshgrid, indexing=indexing, sparse=sparse)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testMgrid(self):
# wrap indexer for appropriate dtype defaults.
np_mgrid = _indexer_with_default_outputs(np.mgrid)
assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0)
assertAllEqual(np_mgrid[()], jnp.mgrid[()])
assertAllEqual(np_mgrid[:4], jnp.mgrid[:4])
assertAllEqual(np_mgrid[:4,], jnp.mgrid[:4,])
assertAllEqual(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])())
assertAllEqual(np_mgrid[:5, :5], jnp.mgrid[:5, :5])
assertAllEqual(np_mgrid[:3, :2], jnp.mgrid[:3, :2])
assertAllEqual(np_mgrid[1:4:2], jnp.mgrid[1:4:2])
assertAllEqual(np_mgrid[1:5:3, :5], jnp.mgrid[1:5:3, :5])
assertAllEqual(np_mgrid[:3, :2, :5], jnp.mgrid[:3, :2, :5])
assertAllEqual(np_mgrid[:3:2, :2, :5], jnp.mgrid[:3:2, :2, :5])
# Corner cases
assertAllEqual(np_mgrid[:], jnp.mgrid[:])
    # When the step length is a complex number, floating-point rounding can
    # make the jnp and np values differ slightly.
atol = 1e-6
rtol = 1e-6
self.assertAllClose(np_mgrid[-1:1:5j],
jnp.mgrid[-1:1:5j],
atol=atol,
rtol=rtol)
self.assertAllClose(np_mgrid[3:4:7j],
jnp.mgrid[3:4:7j],
atol=atol,
rtol=rtol)
self.assertAllClose(np_mgrid[1:6:8j, 2:4],
jnp.mgrid[1:6:8j, 2:4],
atol=atol,
rtol=rtol)
# Non-integer steps
self.assertAllClose(np_mgrid[0:3.5:0.5],
jnp.mgrid[0:3.5:0.5],
atol=atol,
rtol=rtol)
self.assertAllClose(np_mgrid[1.3:4.2:0.3],
jnp.mgrid[1.3:4.2:0.3],
atol=atol,
rtol=rtol)
# abstract tracer value for jnp.mgrid slice
with self.assertRaisesRegex(core.ConcretizationTypeError,
"slice start of jnp.mgrid"):
jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2)
def testOgrid(self):
# wrap indexer for appropriate dtype defaults.
np_ogrid = _indexer_with_default_outputs(np.ogrid)
def assertSequenceOfArraysEqual(xs, ys):
self.assertIsInstance(xs, (list, tuple))
self.assertIsInstance(ys, (list, tuple))
self.assertEqual(len(xs), len(ys))
for x, y in zip(xs, ys):
self.assertArraysEqual(x, y)
self.assertArraysEqual(np_ogrid[:5], jnp.ogrid[:5])
self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])())
self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2])
# List of arrays
assertSequenceOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,])
assertSequenceOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3])
assertSequenceOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3])
assertSequenceOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11])
# Corner cases
self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:])
# Complex number steps
atol = 1e-6
rtol = 1e-6
self.assertAllClose(np_ogrid[-1:1:5j],
jnp.ogrid[-1:1:5j],
atol=atol,
rtol=rtol)
# Non-integer steps
self.assertAllClose(np_ogrid[0:3.5:0.3],
jnp.ogrid[0:3.5:0.3],
atol=atol,
rtol=rtol)
self.assertAllClose(np_ogrid[1.2:4.8:0.24],
jnp.ogrid[1.2:4.8:0.24],
atol=atol,
rtol=rtol)
# abstract tracer value for ogrid slice
with self.assertRaisesRegex(core.ConcretizationTypeError,
"slice start of jnp.ogrid"):
jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2)
def testR_(self):
a = np.arange(6).reshape((2,3))
self.assertArraysEqual(np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])],
jnp.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])])
self.assertArraysEqual(np.r_['-1', a, a], jnp.r_['-1', a, a])
self.assertArraysEqual(np.r_['0,2', [1,2,3], [4,5,6]], jnp.r_['0,2', [1,2,3], [4,5,6]])
self.assertArraysEqual(np.r_['0,2,0', [1,2,3], [4,5,6]], jnp.r_['0,2,0', [1,2,3], [4,5,6]])
self.assertArraysEqual(np.r_['1,2,0', [1,2,3], [4,5,6]], jnp.r_['1,2,0', [1,2,3], [4,5,6]])
# negative 1d axis start
self.assertArraysEqual(np.r_['0,4,-1', [1,2,3], [4,5,6]], jnp.r_['0,4,-1', [1,2,3], [4,5,6]])
self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]])
# matrix directives
with jtu.ignore_warning(category=PendingDeprecationWarning):
self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]])
self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]])
# bad directive
with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
jnp.r_["asdfgh",[1,2,3]]
# abstract tracer value for r_ slice
with self.assertRaisesRegex(core.ConcretizationTypeError,
"slice start of jnp.r_"):
jax.jit(lambda a, b: jnp.r_[a:b])(0, 2)
# wrap indexer for appropriate dtype defaults.
np_r_ = _indexer_with_default_outputs(np.r_)
# Complex number steps
atol = 1e-6
rtol = 1e-6
self.assertAllClose(np_r_[-1:1:6j],
jnp.r_[-1:1:6j],
atol=atol,
rtol=rtol)
with jax.numpy_dtype_promotion('standard'): # Requires dtype promotion.
self.assertAllClose(np_r_[-1:1:6j, [0]*3, 5, 6],
jnp.r_[-1:1:6j, [0]*3, 5, 6],
atol=atol,
rtol=rtol)
# Non-integer steps
self.assertAllClose(np_r_[1.2:4.8:0.24],
jnp.r_[1.2:4.8:0.24],
atol=atol,
rtol=rtol)
def testC_(self):
a = np.arange(6).reshape((2, 3))
self.assertArraysEqual(np.c_[np.array([1,2,3]), np.array([4,5,6])],
jnp.c_[np.array([1,2,3]), np.array([4,5,6])])
self.assertArraysEqual(np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])],
jnp.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])])
self.assertArraysEqual(np.c_['-1', a, a], jnp.c_['-1', a, a])
self.assertArraysEqual(np.c_['0,2', [1,2,3], [4,5,6]], jnp.c_['0,2', [1,2,3], [4,5,6]])
self.assertArraysEqual(np.c_['0,2,0', [1,2,3], [4,5,6]], jnp.c_['0,2,0', [1,2,3], [4,5,6]])
self.assertArraysEqual(np.c_['1,2,0', [1,2,3], [4,5,6]], jnp.c_['1,2,0', [1,2,3], [4,5,6]])
# negative 1d axis start
self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]])
self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]])
# matrix directives, avoid numpy deprecation warning
with jtu.ignore_warning(category=PendingDeprecationWarning):
self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]])
self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]])
# bad directive
with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
jnp.c_["asdfgh",[1,2,3]]
# abstract tracer value for c_ slice
with self.assertRaisesRegex(core.ConcretizationTypeError,
"slice start of jnp.c_"):
jax.jit(lambda a, b: jnp.c_[a:b])(0, 2)
# wrap indexer for appropriate dtype defaults.
np_c_ = _indexer_with_default_outputs(np.c_)
# Complex number steps
atol = 1e-6
rtol = 1e-6
self.assertAllClose(np_c_[-1:1:6j],
jnp.c_[-1:1:6j],
atol=atol,
rtol=rtol)
# Non-integer steps
self.assertAllClose(np_c_[1.2:4.8:0.24],
jnp.c_[1.2:4.8:0.24],
atol=atol,
rtol=rtol)
def testS_(self):
    self.assertEqual(np.s_[1:2:20], jnp.s_[1:2:20])
def testIndex_exp(self):
    self.assertEqual(np.index_exp[5:3:2j], jnp.index_exp[5:3:2j])
@jtu.sample_product(
start_shape=[(), (2,), (2, 2)],
stop_shape=[(), (2,), (2, 2)],
num=[0, 1, 2, 5, 20],
endpoint=[True, False],
retstep=[True, False],
    # Floating-point differences between jitted and non-jitted computation,
    # combined with rounding, cause unavoidable variation in integer
    # truncation for some inputs, so we currently test only inexact 'dtype'
    # arguments.
dtype=inexact_dtypes + [None,],
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype):
rng = jtu.rand_default(self.rng())
# relax default tolerances slightly
tol = jtu.tolerance(dtype if dtype else np.float32) * 10
args_maker = self._GetArgsMaker(rng,
[start_shape, stop_shape],
[dtype, dtype])
start, stop = args_maker()
ndim = len(np.shape(start + stop))
for axis in range(-ndim, ndim):
jnp_op = lambda start, stop: jnp.linspace(
start, stop, num,
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
np_op = lambda start, stop: np.linspace(
start, stop, num,
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_op, args_maker,
check_dtypes=False, atol=tol, rtol=tol)
@jtu.sample_product(dtype=number_dtypes)
def testLinspaceEndpoints(self, dtype):
"""Regression test for Issue #3014."""
rng = jtu.rand_default(self.rng())
endpoints = rng((2,), dtype)
out = jnp.linspace(*endpoints, 10, dtype=dtype)
self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0)
def testLinspaceArrayNum(self):
"""Regression test for Issue #22405."""
rng = jtu.rand_default(self.rng())
endpoints = rng((2,), np.float32)
# The num parameter is an np.array.
out = jnp.linspace(*endpoints, np.array(10, dtype=np.int32),
dtype=np.float32)
self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0)
@jtu.sample_product(
start_shape=[(), (2,), (2, 2)],
stop_shape=[(), (2,), (2, 2)],
num=[0, 1, 2, 5, 20],
endpoint=[True, False],
base=[10.0, 2, np.e],
# skip 16-bit floats due to insufficient precision for the test.
dtype=jtu.dtypes.inexact + [None,],
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testLogspace(self, start_shape, stop_shape, num,
endpoint, base, dtype):
if (dtype in int_dtypes and
jtu.test_device_matches(["gpu", "tpu"]) and
not config.enable_x64.value):
raise unittest.SkipTest("GPUx32 truncated exponentiation"
" doesn't exactly match other platforms.")
rng = jtu.rand_default(self.rng())
# relax default tolerances slightly
tol = {np.float32: 1e-2, np.float64: 1e-6, np.complex64: 1e-3, np.complex128: 1e-6}
args_maker = self._GetArgsMaker(rng,
[start_shape, stop_shape],
[dtype, dtype])
start, stop = args_maker()
ndim = len(np.shape(start + stop))
for axis in range(-ndim, ndim):
jnp_op = lambda start, stop: jnp.logspace(
start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis)
@jtu.ignore_warning(category=RuntimeWarning,
message="overflow encountered in power")
def np_op(start, stop):
return np.logspace(start, stop, num, endpoint=endpoint,
base=base, dtype=dtype, axis=axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
check_dtypes=False, tol=tol)
if dtype in (inexact_dtypes + [None,]):
# Why do compiled and op-by-op float16 np.power numbers differ
# slightly more than expected?
atol = {np.float16: 1e-2}
self._CompileAndCheck(jnp_op, args_maker,
check_dtypes=False, atol=atol, rtol=tol)
@jtu.sample_product(
[dict(start_shape=start_shape, stop_shape=stop_shape, axis=axis)
for start_shape in [(), (2,), (2, 2)]
for stop_shape in [(), (2,), (2, 2)]
for axis in range(-max(len(start_shape), len(stop_shape)),
max(len(start_shape), len(stop_shape)))
],
num=[0, 1, 2, 5, 20],
endpoint=[True, False],
# NB: numpy's geomspace gives nonsense results on integer types
dtype=inexact_dtypes + [None,],
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testGeomspace(self, start_shape, stop_shape, num,
endpoint, dtype, axis):
rng = jtu.rand_default(self.rng())
# relax default tolerances slightly
tol = {dtypes.bfloat16: 2e-2, np.float16: 4e-3, np.float32: 2e-3,
np.float64: 1e-14, np.complex64: 2e-3, np.complex128: 1e-14}
def args_maker():
"""Test the set of inputs np.geomspace is well-defined on."""
start, stop = self._GetArgsMaker(rng,
[start_shape, stop_shape],
[dtype, dtype])()
      # np.geomspace can't handle differently-ranked tensors with negative
      # numbers!
start, stop = jnp.broadcast_arrays(start, stop)
if dtype in complex_dtypes:
return start, stop
# to avoid NaNs, non-complex start and stop cannot
# differ in sign, elementwise
start = start * jnp.sign(start) * jnp.sign(stop)
return start, stop
start, stop = args_maker()
def jnp_op(start, stop):
return jnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype,
axis=axis)
def np_op(start, stop):
start = start.astype(np.float32) if dtype == jnp.bfloat16 else start
stop = stop.astype(np.float32) if dtype == jnp.bfloat16 else stop
return np.geomspace(
start, stop, num, endpoint=endpoint,
dtype=dtype if dtype != jnp.bfloat16 else np.float32,
axis=axis).astype(dtype)
# JAX follows NumPy 2.0 semantics for complex geomspace.
if not (jtu.numpy_version() < (2, 0, 0) and dtypes.issubdtype(dtype, jnp.complexfloating)):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
check_dtypes=False, tol=tol)
if dtype in (inexact_dtypes + [None,]):
self._CompileAndCheck(jnp_op, args_maker,
check_dtypes=False, atol=tol, rtol=tol)
def testDisableNumpyRankPromotionBroadcasting(self):
with jax.numpy_rank_promotion('allow'):
jnp.ones(2) + jnp.ones((1, 2)) # works just fine
with jax.numpy_rank_promotion('raise'):
self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2)))
jnp.ones(2) + 3 # don't want to raise for scalars
with jax.numpy_rank_promotion('warn'):
with self.assertWarnsRegex(
UserWarning,
"Following NumPy automatic rank promotion for add on shapes "
r"\(2,\) \(1, 2\).*"
):
jnp.ones(2) + jnp.ones((1, 2))
jnp.ones(2) + 3 # don't want to warn for scalars
@unittest.skip("Test fails on CI, perhaps due to JIT caching")
def testDisableNumpyRankPromotionBroadcastingDecorator(self):
with jax.numpy_rank_promotion("allow"):
jnp.ones(2) + jnp.ones((1, 2)) # works just fine
with jax.numpy_rank_promotion("raise"):
self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2)))
jnp.ones(2) + 3 # don't want to raise for scalars
with jax.numpy_rank_promotion("warn"):
self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on "
r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2)))
jnp.ones(2) + 3 # don't want to warn for scalars
def testStackArrayArgument(self):
# tests https://github.com/jax-ml/jax/issues/1271
@jax.jit
def foo(x):
return jnp.stack(x)
foo(np.zeros(2)) # doesn't crash
@jax.jit
def foo(x):
return jnp.concatenate(x)
foo(np.zeros((2, 2))) # doesn't crash
def testReluGradientConstants(self):
# This is a regression test that verifies that constants associated with the
# gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the
# outermost jaxpr. This was producing some large materialized constants for
# every relu activation in a model.
def body(i, xy):
x, y = xy
y = y + jax.grad(lambda z: jnp.sum(jnp.maximum(z, 0.)))(x)
return x, y
f = lambda y: lax.fori_loop(0, 5, body, (y, y))
jaxpr = jax.make_jaxpr(f)(np.zeros((3, 4), np.float32))
self.assertFalse(
any(np.array_equal(x, np.full((3, 4), 2., dtype=np.float32))
for x in jaxpr.consts))
@jtu.sample_product(
[dict(from_shape=from_shape, to_shape=to_shape)
for from_shape, to_shape in [
[(1, 3), (4, 3)],
[(3,), (2, 1, 3)],
[(3,), (3, 3)],
[(1,), (3,)],
[(1,), 3],
]
],
)
def testBroadcastTo(self, from_shape, to_shape):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32])
np_op = lambda x: np.broadcast_to(x, to_shape)
jnp_op = lambda x: jnp.broadcast_to(x, to_shape)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testBroadcastToInvalidShape(self):
# Regression test for https://github.com/jax-ml/jax/issues/20533
x = jnp.zeros((3, 4, 5))
with self.assertRaisesRegex(
ValueError, "Cannot broadcast to shape with fewer dimensions"):
jnp.broadcast_to(x, (4, 5))
@jtu.sample_product(
[dict(shapes=shapes, broadcasted_shape=broadcasted_shape)
for shapes, broadcasted_shape in [
[[], ()],
[[()], ()],
[[(1, 3), (4, 3)], (4, 3)],
[[(3,), (2, 1, 3)], (2, 1, 3)],
[[(3,), (3, 3)], (3, 3)],
[[(1,), (3,)], (3,)],
[[(1,), 3], (3,)],
[[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
[[[1], [0, 1]], (0, 1)],
[[(1,), np.array([0, 1])], (0, 1)],
]
],
)
def testBroadcastShapes(self, shapes, broadcasted_shape):
# Test against np.broadcast_shapes once numpy 1.20 is minimum required version
np.testing.assert_equal(jnp.broadcast_shapes(*shapes), broadcasted_shape)
def testBroadcastToIssue1522(self):
self.assertRaisesRegex(
ValueError, "Incompatible shapes for broadcasting: .*",
lambda: jnp.broadcast_to(np.ones((2, 3)), (1, 3)))
def testBroadcastToIntIssue1548(self):
self.assertAllClose(jnp.broadcast_to(1, (3, 2)), np.ones((3, 2)),
check_dtypes=False)
def testBroadcastToOnScalar(self):
self.assertIsInstance(jnp.broadcast_to(10.0, ()), jax.Array)
self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray)
def testPrecision(self):
ones_1d = np.ones((2,))
ones_2d = np.ones((2, 2))
ones_3d = np.ones((2, 2, 2))
HIGHEST = lax.Precision.HIGHEST
jtu.assert_dot_precision(None, jnp.dot, ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.dot, precision=HIGHEST),
ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.dot, precision=HIGHEST),
ones_3d, ones_3d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.matmul, precision=HIGHEST),
ones_2d, ones_2d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.vdot, precision=HIGHEST),
ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.vecdot, precision=HIGHEST),
ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.tensordot, axes=2, precision=HIGHEST),
ones_2d, ones_2d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.tensordot, axes=(0, 0), precision=HIGHEST),
ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.tensordot, axes=((0,), (0,)), precision=HIGHEST),
ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.einsum, 'i,i', precision=HIGHEST),
ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.einsum, 'ij,ij', precision=HIGHEST),
ones_2d, ones_2d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.inner, precision=HIGHEST),
ones_1d, ones_1d)
@jtu.sample_product(
funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot', 'vecdot']
)
def testPreferredElementType(self, funcname):
func = getattr(jnp, funcname)
kwargs = dict(axes=0) if funcname == 'tensordot' else {}
ones_i32 = np.ones(2, dtype='int32')
ones_f32 = np.ones(2, dtype='float32')
with jax.numpy_dtype_promotion('strict'):
jtu.assert_dot_preferred_element_type('int32', func, ones_i32, ones_i32, **kwargs)
jtu.assert_dot_preferred_element_type('float32', func, ones_f32, ones_f32, **kwargs)
jtu.assert_dot_preferred_element_type('bfloat16', func, ones_f32, ones_f32, **kwargs,
preferred_element_type='bfloat16')
with jax.numpy_dtype_promotion('standard'):
jtu.assert_dot_preferred_element_type('float32', func, ones_i32, ones_f32, **kwargs)
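  # A minimal sketch of the behavior exercised above (illustrative only):
  # preferred_element_type overrides the accumulation/output dtype of a dot
  # independently of the input dtypes, e.g.
  #   jnp.dot(jnp.ones(2, 'float32'), jnp.ones(2, 'float32'),
  #           preferred_element_type='bfloat16')
  # returns a bfloat16 result rather than float32.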
@jtu.sample_product(
[dict(shape=shape, varargs=varargs, axis=axis)
for shape in [(10,), (10, 15), (10, 15, 20)]
for _num_axes in range(len(shape))
for varargs in itertools.combinations(range(1, len(shape) + 1), _num_axes)
for axis in itertools.combinations(range(len(shape)), _num_axes)
],
dtype=inexact_dtypes,
)
def testGradient(self, shape, varargs, axis, dtype):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
jnp_fun = lambda y: jnp.gradient(y, *varargs, axis=axis)
np_fun = lambda y: np.gradient(y, *varargs, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=[(5,), (5, 7), (5, 10, 3)],
dtype=inexact_dtypes,
)
def testGradientNonConstant(self, shape, dtype):
rng = jtu.rand_default(self.rng())
varargs = [(s,) for s in shape]
args = [shape] + varargs
args_maker = self._GetArgsMaker(rng, args, [dtype] * len(args))
atol = jtu.tolerance(
dtype, {np.float16: 4e-2, jax.dtypes.bfloat16: 4e-1, np.float32: 2e-5}
)
rtol = jtu.tolerance(dtype, {jax.dtypes.bfloat16: 5e-1})
self._CheckAgainstNumpy(
np.gradient,
jnp.gradient,
args_maker,
check_dtypes=False,
atol=atol,
rtol=rtol,
)
self._CompileAndCheck(jnp.gradient, args_maker)
def testZerosShapeErrors(self):
# see https://github.com/jax-ml/jax/issues/1822
self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*",
lambda: jnp.zeros(1.))
self.assertRaisesRegex(
TypeError,
r"Shapes must be 1D sequences of concrete values of integer type.*\n"
"If using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.",
lambda: jax.jit(jnp.zeros)(2))
def testTraceMethod(self):
x = self.rng().randn(3, 4).astype(jnp.float_)
self.assertAllClose(x.trace(), jnp.array(x).trace())
self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x))
@jtu.ignore_warning(category=RuntimeWarning, message="divide by zero")
def testIntegerPowersArePrecise(self):
# See https://github.com/jax-ml/jax/pull/3036
    # Checks that the squares of float32 integers have no rounding error.
    # This holds for any integer whose magnitude is below sqrt(2**24) = 4096,
    # since float32 represents every integer up to 2**24 exactly: e.g.
    # 4095.0**2 == 16769025.0 is exact, while 4097.0**2 rounds.
x = jnp.arange(-2**12, 2**12, dtype=jnp.int32)
np.testing.assert_array_equal(jnp.square(x.astype(jnp.float32)), x * x)
np.testing.assert_array_equal(x.astype(jnp.float32) ** 2, x * x)
# Similarly for cubes.
x = jnp.arange(-2**8, 2**8, dtype=jnp.int32)
np.testing.assert_array_equal(x.astype(jnp.float32) ** 3, x * x * x)
x = np.arange(10, dtype=np.float32)
for i in range(10):
self.assertAllClose(x.astype(jnp.float32) ** i, x ** i,
check_dtypes=False)
def testToBytes(self):
v = np.arange(12, dtype=np.int32).reshape(3, 4)
for order in ['C', 'F']:
self.assertEqual(jnp.asarray(v).tobytes(order), v.tobytes(order))
def testToBytesJitError(self):
v = np.arange(12, dtype=np.int32).reshape(3, 4)
f = jax.jit(lambda x: x.tobytes())
msg = r".*The tobytes\(\) method was called on traced array"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(v)
def testToList(self):
v = np.arange(12, dtype=np.int32).reshape(3, 4)
self.assertEqual(jnp.asarray(v).tolist(), v.tolist())
def testToListJitError(self):
v = np.arange(12, dtype=np.int32).reshape(3, 4)
f = jax.jit(lambda x: x.tolist())
msg = r".*The tolist\(\) method was called on traced array"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(v)
def testArangeConcretizationError(self):
msg = r"It arose in the jnp.arange argument '{}'".format
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')):
jax.jit(jnp.arange)(3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('start')):
jax.jit(lambda start: jnp.arange(start, 3))(0)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')):
jax.jit(lambda stop: jnp.arange(0, stop))(3)
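  # A minimal sketch of the standard workaround for the errors above
  # (illustrative only): marking the argument static hands jnp.arange a
  # concrete Python value instead of a tracer.
  #   jax.jit(jnp.arange, static_argnums=0)(3)                        # works
  #   jax.jit(lambda stop: jnp.arange(0, stop), static_argnums=0)(3)  # works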
@jtu.sample_product(dtype=[None] + float_dtypes)
def testArange64Bit(self, dtype):
    # Test that jnp.arange uses 64-bit arithmetic to define its range, even if the
    # output has another dtype. The issue here is that if Python scalar inputs to
    # jnp.arange are cast to float32 before the range is computed, the number of
    # elements in the output changes. It's unclear whether this was deliberate
    # behavior in the initial implementation, but it's behavior that downstream
    # users have come to rely on.
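    # Concretely (a back-of-the-envelope check of the claim above): in
    # float64, ceil((4.8 - 1.2) / 0.24) == 15, while rounding the operands
    # to float32 pushes the quotient just above 15, giving a 16th element.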
args = (1.2, 4.8, 0.24)
# Ensure that this test case leads to differing lengths if cast to float32.
self.assertLen(np.arange(*args), 15)
self.assertLen(np.arange(*map(np.float32, args)), 16)
jnp_fun = lambda: jnp.arange(*args, dtype=dtype)
np_fun = jtu.with_jax_dtype_defaults(lambda: np.arange(*args, dtype=dtype), dtype is None)
args_maker = lambda: []
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testIssue2347(self):
# https://github.com/jax-ml/jax/issues/2347
object_list = list[tuple[jnp.array, float, float, jnp.array, bool]]
self.assertRaises(TypeError, jnp.array, object_list)
np_object_list = np.array(object_list)
self.assertRaises(TypeError, jnp.array, np_object_list)
@jtu.sample_product(
[dict(shapes=shapes, dtypes=dtypes)
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(all_shapes, 2))
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes))
],
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testLogaddexpComplex(self, shapes, dtypes):
@jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
def np_op(x1, x2):
return np.log(np.exp(x1) + np.exp(x2))
rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes))
if jtu.test_device_matches(["tpu"]):
tol = {np.complex64: 1e-3, np.complex128: 1e-10}
else:
tol = {np.complex64: 1e-5, np.complex128: 1e-14}
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol)
self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol)
@jtu.sample_product(
[dict(shapes=shapes, dtypes=dtypes)
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(all_shapes, 2))
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes))
],
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testLogaddexp2Complex(self, shapes, dtypes):
@jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
def np_op(x1, x2):
return np.log2(np.exp2(x1) + np.exp2(x2))
rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes))
if jtu.test_device_matches(["tpu"]):
tol = {np.complex64: 1e-3, np.complex128: 1e-10}
else:
tol = {np.complex64: 1e-5, np.complex128: 1e-14}
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol)
self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol)
def testDefaultDtypes(self):
precision = config.default_dtype_bits.value
assert precision in ['32', '64']
self.assertEqual(jnp.bool_, np.bool_)
self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64)
self.assertEqual(jnp.uint, np.uint32 if precision == '32' else np.uint64)
self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64)
self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128)
def testFromBuffer(self):
buf = b'\x01\x02\x03'
expected = np.frombuffer(buf, dtype='uint8')
actual = jnp.frombuffer(buf, dtype='uint8')
self.assertArraysEqual(expected, actual)
def testFromFunction(self):
def f(x, y, z):
return x + 2 * y + 3 * z
shape = (3, 4, 5)
expected = np.fromfunction(f, shape=shape)
actual = jnp.fromfunction(f, shape=shape)
self.assertArraysEqual(expected, actual, check_dtypes=False)
def testFromString(self):
s = "1,2,3"
expected = np.fromstring(s, sep=',', dtype=int)
actual = jnp.fromstring(s, sep=',', dtype=int)
self.assertArraysEqual(expected, actual)
@jtu.sample_product(
a_shape=nonempty_nonscalar_array_shapes,
v_shape=nonempty_shapes,
dtype=jtu.dtypes.all,
)
def testPlace(self, a_shape, v_shape, dtype):
rng = jtu.rand_default(self.rng())
mask_rng = jtu.rand_bool(self.rng())
def args_maker():
a = rng(a_shape, dtype)
m = mask_rng(a_shape, bool)
v = rng(v_shape, dtype)
return a, m, v
def np_fun(a, m, v):
a_copy = a.copy()
np.place(a_copy, m, v)
return a_copy
jnp_fun = partial(jnp.place, inplace=False)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
a_shape=nonempty_nonscalar_array_shapes,
i_shape=all_shapes,
v_shape=all_shapes,
dtype=jtu.dtypes.all,
mode=[None, 'wrap', 'clip'],
)
def testPut(self, mode, a_shape, i_shape, v_shape, dtype):
size = math.prod(a_shape)
if math.prod(i_shape) > size:
self.skipTest("too many indices")
rng = jtu.rand_default(self.rng())
# Must test unique integers, because overlapping updates in
# JAX have implementation-defined order
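    # (For instance, hypothetically, updating index 0 twice via
    # jnp.zeros(3).at[jnp.array([0, 0])].set(jnp.array([1., 2.])) may keep
    # either value, whereas np.put applies updates sequentially.)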
idx_rng = jtu.rand_unique_int(self.rng(), size)
def args_maker():
a = rng(a_shape, dtype)
i = idx_rng(i_shape, np.int32)
v = rng(v_shape, dtype)
# put some indices out of range without duplicating indices
if mode == "clip" and i.size:
np.put(i, np.argmax(i), size + 2)
np.put(i, np.argmin(i), -2)
if mode == "wrap" and i.size:
np.put(i, 0, np.take(i, 0) + size)
return a, i, v
def np_fun(a, i, v):
a_copy = a.copy()
np.put(a_copy, i, v, mode=mode)
return a_copy
jnp_fun = partial(jnp.put, mode=mode, inplace=False)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis)
for a_shape in nonempty_array_shapes
for axis in list(range(-len(a_shape), len(a_shape)))
for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)]
for v_shape in [(), (1,), i_shape]
] + [
dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None)
for a_shape in nonempty_array_shapes
for i_shape in [(J,) for J in range(math.prod(a_shape) + 1)]
for v_shape in [(), (1,), i_shape]
],
dtype=jtu.dtypes.all,
mode=[None, "promise_in_bounds", "clip"],
)
def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode):
a_rng = jtu.rand_default(self.rng())
if axis is None:
size = math.prod(a_shape)
else:
size = a_shape[axis]
i_rng = jtu.rand_indices_unique_along_axis(self.rng())
def args_maker():
a = a_rng(a_shape, dtype)
i = i_rng(dim=size, shape=i_shape, axis=0 if axis is None else axis)
v = a_rng(v_shape, dtype)
return a, i, v
def np_fun(a, i, v):
a_copy = a.copy()
np.put_along_axis(a_copy, i, v, axis=axis)
return a_copy
jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
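  # For example (illustrative only): with a = [[10, 30, 20]] and i = [[2]],
  # jnp.put_along_axis(a, i, 99, axis=1, inplace=False) -> [[10, 30, 99]];
  # inplace=False is required because JAX arrays are immutable.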
def test_rot90_error(self):
with self.assertRaisesRegex(
ValueError,
"rot90 requires its first argument to have ndim at least two, "
"but got first argument of"):
jnp.rot90(jnp.ones(2))
@parameterized.named_parameters(
('ones', jnp.ones),
('zeros', jnp.zeros),
('empty', jnp.empty))
def test_error_hint(self, fn):
with self.assertRaisesRegex(
TypeError,
r"Did you accidentally write `jax\.numpy\..*?\(2, 3\)` "
r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"):
fn(2, 3)
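  # The hint above targets a common mistake (illustrative only):
  #   jnp.ones(2, 3)    # TypeError: 3 is parsed as the dtype argument
  #   jnp.ones((2, 3))  # intended: pass the shape as a single tuple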
@jtu.sample_product(
dtype=jtu.dtypes.all,
kind=['bool', 'signed integer', 'unsigned integer', 'integral',
'real floating', 'complex floating', 'numeric']
)
def test_isdtype(self, dtype, kind):
# Full tests also in dtypes_test.py; here we just compare against numpy
jax_result = jnp.isdtype(dtype, kind)
if jtu.numpy_version() < (2, 0, 0) or dtype == dtypes.bfloat16:
# just a smoke test
self.assertIsInstance(jax_result, bool)
else:
numpy_result = np.isdtype(dtype, kind)
self.assertEqual(jax_result, numpy_result)
@jtu.sample_product(
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
for yshape, xshape, dx, axis in [
((10,), None, 1.0, -1),
((3, 10), None, 2.0, -1),
((3, 10), None, 3.0, -0),
((10, 3), (10,), 1.0, -2),
((3, 10), (10,), 1.0, -1),
((3, 10), (3, 10), 1.0, -1),
((2, 3, 10), (3, 10), 1.0, -2),
]
],
dtype=float_dtypes + int_dtypes,
)
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_trapezoid(self, yshape, xshape, dtype, dx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
if jtu.numpy_version() >= (2, 0, 0):
np_fun = partial(np.trapezoid, dx=dx, axis=axis)
else:
np_fun = partial(np.trapz, dx=dx, axis=axis)
jnp_fun = partial(jnp.trapezoid, dx=dx, axis=axis)
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
jax.dtypes.bfloat16: 4e-2})
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
check_dtypes=False)
@jtu.sample_product(
shape=all_shapes,
dtype=default_dtypes,
op=['ndim', 'shape', 'size'],
)
def testNdimShapeSize(self, shape, dtype, op):
rng = jtu.rand_default(self.rng())
jnp_op = getattr(jnp, op)
np_op = getattr(np, op)
x = rng(shape, dtype)
expected = np_op(x)
self.assertEqual(expected, jnp_op(x)) # np.ndarray or scalar input.
self.assertEqual(expected, jnp_op(jnp.asarray(x))) # jax.Array input.
self.assertEqual(expected, jax.jit(jnp_op)(x)) # Traced input.
@jtu.sample_product(
shape=nonzerodim_shapes,
dtype=default_dtypes,
)
def testSizeAlongAxis(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
axis = self.rng().randint(-len(shape), len(shape))
np_op = partial(np.size, axis=axis)
jnp_op = partial(jnp.size, axis=axis)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
op=[jnp.ndim, jnp.shape, jnp.size],
)
def testNdimShapeSizeNonArrayInput(self, op):
msg = f"{op.__name__} requires ndarray or scalar arguments"
with self.assertWarnsRegex(DeprecationWarning, msg):
op([1, 2, 3])
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.
GradTestSpec = collections.namedtuple(
"GradTestSpec",
["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"])
def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
return GradTestSpec(
op, nargs, order, rng_factory, dtypes, name or op.__name__, tol)
GRAD_TEST_RECORDS = [
grad_test_spec(jnp.arcsinh, nargs=1, order=2,
rng_factory=jtu.rand_positive,
dtypes=[np.float64, np.complex64],
tol={np.complex64: 2e-2}),
grad_test_spec(jnp.arccosh, nargs=1, order=1,
rng_factory=jtu.rand_positive,
dtypes=[np.float64, np.complex64],
tol={np.complex64: 2e-2}),
grad_test_spec(jnp.arctanh, nargs=1, order=2,
rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
dtypes=[np.float64, np.complex64],
tol={np.complex64: 2e-2}),
grad_test_spec(jnp.logaddexp, nargs=2, order=1,
rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
dtypes=[np.float64], tol=1e-4),
grad_test_spec(jnp.logaddexp2, nargs=2, order=2,
rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
dtypes=[np.float64], tol=1e-4),
]
GradSpecialValuesTestSpec = collections.namedtuple(
"GradSpecialValuesTestSpec", ["op", "values", "order"])
GRAD_SPECIAL_VALUE_TEST_RECORDS = [
GradSpecialValuesTestSpec(jnp.arcsinh, [0., 1000.], 2),
GradSpecialValuesTestSpec(jnp.arccosh, [1000.], 2),
GradSpecialValuesTestSpec(jnp.arctanh, [0.], 2),
GradSpecialValuesTestSpec(jnp.sinc, [0.], 1),
]
class NumpyGradTests(jtu.JaxTestCase):
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.op, rng_factory=rec.rng_factory, tol=rec.tol,
order=rec.order)],
shapes=itertools.combinations_with_replacement(nonempty_shapes, rec.nargs),
dtype=rec.dtypes)
for rec in GRAD_TEST_RECORDS))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
rng = rng_factory(self.rng())
tol = jtu.join_tolerance(tol, {np.float32: 1e-1, np.float64: 1e-3,
np.complex64: 1e-1, np.complex128: 1e-3})
if jtu.test_device_matches(["tpu"]) and op == jnp.arctanh:
tol = jtu.join_tolerance(tol, {np.float32: 2e-1})
args = tuple(rng(shape, dtype) for shape in shapes)
check_grads(op, args, order, ["fwd", "rev"], tol, tol)
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.op, order=rec.order)],
special_value=rec.values
)
for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS))
def testOpGradSpecialValue(self, op, special_value, order):
check_grads(op, (special_value,), order, ["fwd", "rev"],
atol={np.float32: 3e-3})
def testSincAtZero(self):
# Some manual tests for sinc at zero, since it doesn't have well-behaved
# numerical derivatives at zero
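    # A Taylor-series sanity check for the constants asserted below (sketch):
    # sinc(x) = sin(pi x)/(pi x) = 1 - (pi x)**2/6 + (pi x)**4/120 - ...,
    # so f''(0) = -pi**2/3 and f''''(0) = pi**4/5, and odd derivatives vanish.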
def deriv(f):
return lambda x: jax.jvp(f, (x,), (1.,))[1]
def apply_all(fns, x):
for f in fns:
x = f(x)
return x
d1 = 0.
for ops in itertools.combinations_with_replacement([deriv, jax.grad], 1):
self.assertAllClose(apply_all(ops, jnp.sinc)(0.), d1)
d2 = -np.pi ** 2 / 3
for ops in itertools.combinations_with_replacement([deriv, jax.grad], 2):
self.assertAllClose(apply_all(ops, jnp.sinc)(0.), d2)
d3 = 0.
for ops in itertools.combinations_with_replacement([deriv, jax.grad], 3):
self.assertAllClose(apply_all(ops, jnp.sinc)(0.), d3)
d4 = np.pi ** 4 / 5
for ops in itertools.combinations_with_replacement([deriv, jax.grad], 4):
self.assertAllClose(apply_all(ops, jnp.sinc)(0.), d4)
def testSincGradArrayInput(self):
# tests for a bug almost introduced in #5077
jax.grad(lambda x: jnp.sinc(x).sum())(jnp.arange(10.)) # doesn't crash
def testTakeAlongAxisIssue1521(self):
# https://github.com/jax-ml/jax/issues/1521
idx = jnp.repeat(jnp.arange(3), 10).reshape((30, 1))
def f(x):
y = x * jnp.arange(3.).reshape((1, 3))
return jnp.take_along_axis(y, idx, -1).sum()
check_grads(f, (1.,), order=1)
@jtu.sample_product(
shapes=filter(_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(nonempty_shapes, 2)),
dtype=(np.complex128,),
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testGradLogaddexpComplex(self, shapes, dtype):
rng = jtu.rand_default(self.rng())
args = tuple(jnp.array(rng(shape, dtype)) for shape in shapes)
if jtu.test_device_matches(["tpu"]):
tol = 5e-2
else:
tol = 3e-2
check_grads(jnp.logaddexp, args, 1, ["fwd", "rev"], tol, tol)
@jtu.sample_product(
shapes=filter(_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(nonempty_shapes, 2)),
dtype=(np.complex128,),
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testGradLogaddexp2Complex(self, shapes, dtype):
rng = jtu.rand_default(self.rng())
args = tuple(jnp.array(rng(shape, dtype)) for shape in shapes)
if jtu.test_device_matches(["tpu"]):
tol = 5e-2
else:
tol = 3e-2
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)
@jtu.sample_product(
n=range(-4, 5),
dtype=[jnp.float32, jnp.float64],
)
def testGradLdexp(self, n, dtype):
rng = jtu.rand_default(self.rng())
x = rng((), dtype)
check_grads(lambda x: jnp.ldexp(x, n), (x,), 1)
class NumpySignaturesTest(jtu.JaxTestCase):
def testWrappedSignaturesMatch(self):
"""Test that jax.numpy function signatures match numpy."""
# NumPy functions explicitly not implemented in JAX:
skip = {'array2string',
'asanyarray',
'asarray_chkfinite',
'ascontiguousarray',
'asfortranarray',
'asmatrix',
'base_repr',
'binary_repr',
'bmat',
'broadcast',
'busday_count',
'busday_offset',
'busdaycalendar',
'common_type',
'copyto',
'datetime_as_string',
'datetime_data',
'errstate',
'flatiter',
'format_float_positional',
'format_float_scientific',
'fromregex',
'genfromtxt',
'get_include',
'getbufsize',
'geterr',
'geterrcall',
'in1d',
'info',
'is_busday',
'isfortran',
'isnat',
'loadtxt',
'matrix',
'may_share_memory',
'memmap',
'min_scalar_type',
'mintypecode',
'ndenumerate',
'ndindex',
'nditer',
'nested_iters',
'poly1d',
'putmask',
'real_if_close',
'recarray',
'record',
'require',
'row_stack',
'savetxt',
'savez_compressed',
'setbufsize',
'seterr',
'seterrcall',
'shares_memory',
'show_config',
'show_runtime',
'test',
'trapz',
'typename'}
# symbols removed in NumPy 2.0
skip |= {'add_docstring',
'add_newdoc',
'add_newdoc_ufunc',
'alltrue',
'asfarray',
'byte_bounds',
'compare_chararrays',
'cumproduct',
'deprecate',
'deprecate_with_doc',
'disp',
'fastCopyAndTranspose',
'find_common_type',
'get_array_wrap',
'geterrobj',
'issctype',
'issubclass_',
'issubsctype',
'lookfor',
'mat',
'maximum_sctype',
'msort',
'obj2sctype',
'product',
'recfromcsv',
'recfromtxt',
'round_',
'safe_eval',
'sctype2char',
'set_numeric_ops',
'set_string_function',
'seterrobj',
'sometrue',
'source',
'who'}
self.assertEmpty(skip.intersection(dir(jnp)))
names = (name for name in dir(np) if not (name.startswith('_') or name in skip))
names = (name for name in names if callable(getattr(np, name)))
names = {name for name in names if not isinstance(getattr(np, name), type)}
self.assertEmpty(names.difference(dir(jnp)))
self.assertNotEmpty(names)
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
unsupported_params = {
'argpartition': ['kind', 'order'],
'asarray': ['like'],
'broadcast_to': ['subok'],
'clip': ['kwargs', 'out'],
'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'],
'cov': ['dtype'],
'cumulative_prod': ['out'],
'cumulative_sum': ['out'],
'empty_like': ['subok', 'order'],
'einsum': ['kwargs'],
'einsum_path': ['einsum_call'],
'eye': ['order', 'like'],
'hstack': ['casting'],
'identity': ['like'],
'isin': ['kind'],
'full': ['order', 'like'],
'full_like': ['subok', 'order'],
'fromfunction': ['like'],
'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'],
'nanpercentile': ['weights'],
'nanquantile': ['weights'],
'nanstd': ['correction', 'mean'],
'nanvar': ['correction', 'mean'],
'ones': ['order', 'like'],
'ones_like': ['subok', 'order'],
'partition': ['kind', 'order'],
'percentile': ['weights'],
'quantile': ['weights'],
'row_stack': ['casting'],
'stack': ['casting'],
'std': ['mean'],
'tri': ['like'],
'trim_zeros': ['axis'],
'var': ['mean'],
'vstack': ['casting'],
'zeros_like': ['subok', 'order']
}
extra_params = {
'compress': ['size', 'fill_value'],
'einsum': ['subscripts', 'precision'],
'einsum_path': ['subscripts'],
'fill_diagonal': ['inplace'],
'load': ['args', 'kwargs'],
'take_along_axis': ['mode', 'fill_value'],
'unique': ['size', 'fill_value'],
}
mismatches = {}
for name in names:
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if name in ['histogram', 'histogram2d', 'histogramdd']:
# numpy 1.24 re-orders the density and weights arguments.
# TODO(jakevdp): migrate histogram APIs to match newer numpy versions.
continue
if name == "clip":
# JAX's support of the Array API spec for clip, and the way it handles
# backwards compatibility was introduced in
# https://github.com/jax-ml/jax/pull/20550 with a different signature
# from the one in numpy, introduced in
# https://github.com/numpy/numpy/pull/26724
# TODO(dfm): After our deprecation period for the clip arguments ends
# it should be possible to reintroduce the check.
continue
if name == "reshape":
# Similar issue to clip: we'd need logic specific to the NumPy version
# because of the change in argument name from `newshape` to `shape`.
continue
# Note: can't use inspect.getfullargspec for some functions due to numpy issue
# https://github.com/numpy/numpy/issues/12225
try:
np_params = inspect.signature(np_fun).parameters
except ValueError:
continue
jnp_params = inspect.signature(jnp_fun).parameters
extra = set(extra_params.get(name, []))
unsupported = set(unsupported_params.get(name, []))
# Checks to prevent tests from becoming out-of-date. If these fail,
# it means that extra_params or unsupported_params need to be updated.
assert extra.issubset(jnp_params), f"{name}: {extra=} is not a subset of jnp_params={set(jnp_params)}."
assert not unsupported.intersection(jnp_params), f"{name}: {unsupported=} overlaps with jnp_params={set(jnp_params)}."
# Skip functions that only have *args and **kwargs; we can't introspect these further.
var_args = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
if all(p.kind in var_args for p in jnp_params.values()):
continue
if all(p.kind in var_args for p in np_params.values()):
continue
# Remove known extra parameters.
jnp_params = {a: p for a, p in jnp_params.items() if a not in extra}
# Remove known unsupported parameters.
np_params = {a: p for a, p in np_params.items() if a not in unsupported}
# Older versions of numpy may have fewer parameters; to avoid extraneous errors on older numpy
# versions, we allow for jnp to have more parameters.
if list(jnp_params)[:len(np_params)] != list(np_params):
mismatches[name] = {'np_params': list(np_params), 'jnp_params': list(jnp_params)}
self.assertEqual(mismatches, {})
_available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all
if dtype != dtypes.bfloat16]
# TODO(jakevdp): implement missing ufuncs.
UNIMPLEMENTED_UFUNCS = {'spacing', 'matvec', 'vecmat'}
def _all_numpy_ufuncs() -> Iterator[str]:
"""Generate the names of all ufuncs in the top-level numpy namespace."""
for name in dir(np):
f = getattr(np, name)
if isinstance(f, np.ufunc) and name not in UNIMPLEMENTED_UFUNCS:
yield name
def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
"""Generate valid dtypes of inputs to the given numpy ufunc."""
func = getattr(np, name)
for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin):
args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes)
try:
with jtu.ignore_warning(
category=RuntimeWarning, message="(divide by zero|invalid value)"):
_ = func(*args)
except TypeError:
pass
else:
yield arg_dtypes
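# Illustrative usage of the helper above (a sketch, not executed here): for a
# binary ufunc such as np.add, _dtypes_for_ufunc('add') yields dtype-name
# pairs like ('float32', 'int32'), skipping combinations that the NumPy
# ufunc rejects with a TypeError.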
class NumpyUfuncTests(jtu.JaxTestCase):
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases([dict(name=name)],
arg_dtypes=_dtypes_for_ufunc(name))
for name in _all_numpy_ufuncs()
))
def testUfuncInputTypes(self, name, arg_dtypes):
if name in ['arctanh', 'atanh'] and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating):
self.skipTest("np.arctanh & jnp.arctanh have mismatched NaNs for complex input.")
jnp_op = getattr(jnp, name)
np_op = getattr(np, name)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="(divide by zero|invalid value)")(np_op)
args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes)
with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
      # The large tol accounts for cases where NumPy returns float16 but
      # jnp returns float32, e.g. np.cos(np.uint8(0)).
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)
class NumpyDocTests(jtu.JaxTestCase):
def test_lax_numpy_docstrings(self):
unimplemented = ['fromfile', 'fromiter']
aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2',
'amax', 'amin', 'around', 'bitwise_invert', 'bitwise_left_shift',
               'bitwise_not', 'bitwise_right_shift', 'conj', 'degrees', 'divide',
'get_printoptions', 'mod', 'pow', 'printoptions', 'radians', 'round_',
'set_printoptions']
skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split']
for name in dir(jnp):
if name.startswith('_') or name in unimplemented:
continue
obj = getattr(jnp, name)
if isinstance(obj, type) or not callable(obj):
# Skip docstring checks for non-functions
pass
elif hasattr(np, name) and obj is getattr(np, name):
# Some APIs are imported directly from NumPy; we don't check these.
pass
elif name in aliases:
assert "Alias of" in obj.__doc__
elif name not in skip_args_check:
# Other functions should have nontrivial docs including "Args" and "Returns".
doc = obj.__doc__
self.assertNotEmpty(doc)
self.assertIn("Args:", doc, msg=f"'Args:' not found in docstring of jnp.{name}")
self.assertIn("Returns:", doc, msg=f"'Returns:' not found in docstring of jnp.{name}")
if name not in ["frompyfunc", "isdtype", "promote_types"]:
self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}")
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())