2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2018-11-17 18:03:33 -08:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
import collections
|
2023-07-21 14:20:39 -04:00
|
|
|
from collections.abc import Iterator
|
2022-05-17 13:20:38 -07:00
|
|
|
import copy
|
2019-02-06 08:40:43 -05:00
|
|
|
from functools import partial
|
2020-07-01 16:52:41 -07:00
|
|
|
import inspect
|
2021-11-09 09:43:46 -08:00
|
|
|
import io
|
2018-11-17 18:03:33 -08:00
|
|
|
import itertools
|
2023-02-28 12:40:30 -08:00
|
|
|
import math
|
2023-06-16 11:02:17 -04:00
|
|
|
import platform
|
2023-07-21 14:20:39 -04:00
|
|
|
from typing import cast, Optional
|
2019-04-01 09:23:00 -07:00
|
|
|
import unittest
|
2019-04-09 18:59:42 -07:00
|
|
|
from unittest import SkipTest
|
2019-08-23 17:05:32 -07:00
|
|
|
import warnings
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
2020-05-20 01:43:48 -03:00
|
|
|
import numpy as np
|
2020-08-18 09:40:57 -07:00
|
|
|
try:
|
|
|
|
import numpy_dispatch
|
2022-08-18 11:38:31 -07:00
|
|
|
except ImportError:
|
2020-08-18 09:40:57 -07:00
|
|
|
numpy_dispatch = None
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-10-15 15:01:52 -04:00
|
|
|
import jax
|
2019-05-28 10:30:58 -04:00
|
|
|
import jax.ops
|
2018-12-17 14:42:32 -08:00
|
|
|
from jax import lax
|
2020-03-06 14:59:51 -05:00
|
|
|
from jax import numpy as jnp
|
Change JAX type promotion to prefer inexact types. (#1815)
Change the JAX type promotion table to prefer inexact types during type promotion.
NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.
This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators:
e.g.,
```
import numpy as onp
from jax import numpy as np
In [1]: onp.promote_types(onp.float32, onp.int32)
Out[1]: dtype('float64')
In [2]: onp.promote_types(onp.float16, onp.int64)
Out[2]: dtype('float64')
In [3]: np.promote_types(onp.float32, onp.int32)
Out[3]: dtype('float32')
In [4]: np.promote_types(onp.float16, onp.int64)
Out[4]: dtype('float16')
```
This change is in preparation for enabling x64 mode by default on all platforms.
2019-12-05 10:57:23 -05:00
|
|
|
from jax import tree_util
|
2019-08-31 22:08:03 -07:00
|
|
|
from jax.test_util import check_grads
|
2022-03-09 18:18:16 -08:00
|
|
|
|
2023-02-14 23:00:40 -08:00
|
|
|
from jax._src import core
|
2022-03-09 18:18:16 -08:00
|
|
|
from jax._src import dtypes
|
|
|
|
from jax._src import test_util as jtu
|
|
|
|
from jax._src.lax import lax as lax_internal
|
|
|
|
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
|
2023-08-07 19:08:41 +02:00
|
|
|
from jax._src.util import safe_zip, NumpyComplexWarning
|
2022-09-27 10:06:10 -07:00
|
|
|
from jax._src import array
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2023-04-21 11:51:22 -07:00
|
|
|
from jax import config
|
2018-12-06 18:37:59 -05:00
|
|
|
# Let absl command-line flags populate JAX's config before tests read them.
config.parse_flags_with_absl()
FLAGS = config.FLAGS

numpy_version = jtu.numpy_version()

# Families of test shapes.
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [(), *nonempty_nonscalar_array_shapes]
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0)]

# Marker "shapes" from jtu for NumPy-scalar and Python-scalar inputs.
scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]

# Unions of the families above.
array_shapes = [*nonempty_array_shapes, *empty_array_shapes]
nonzerodim_shapes = [*nonempty_nonscalar_array_shapes, *empty_array_shapes]
nonempty_shapes = [*scalar_shapes, *nonempty_array_shapes]
all_shapes = [*scalar_shapes, *array_shapes]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-07 17:01:38 -07:00
|
|
|
# Base dtype families, as reported by jtu.dtypes.
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean

# Composite families, built by concatenating the base families in order.
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = inexact_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes
|
2018-11-17 18:03:33 -08:00
|
|
|
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
|
2020-03-06 14:59:51 -05:00
|
|
|
# One dtype per Python scalar kind (bool, int, float, complex); used to
# restrict the dtypes paired with jtu.PYTHON_SCALAR_SHAPE.
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]

# Signed and unsigned integer dtypes, without uint64. uint64 is treated as
# problematic here: combining it with other integer types may promote to a
# floating-point type.
int_dtypes_no_uint64 = [d for d in (*int_dtypes, *unsigned_dtypes) if d != np.uint64]
|
|
|
|
|
2021-12-09 09:47:21 -08:00
|
|
|
def _indexer_with_default_outputs(indexer, use_defaults=True):
  """Wrap ``indexer`` so that indexing it goes through the dtype-defaults helper.

  This is the ``__getitem__`` counterpart of ``jtu.with_jax_dtype_defaults``.
  Indexing the returned object forwards to ``indexer.__getitem__``, with that
  call wrapped by ``jtu.with_jax_dtype_defaults``.

  Args:
    indexer: any object supporting ``__getitem__``.
    use_defaults: forwarded to ``jtu.with_jax_dtype_defaults``.

  Returns:
    A new object whose ``__getitem__`` is the wrapped forwarder.
  """
  with_defaults = partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults)

  class Indexer:

    @with_defaults
    def __getitem__(self, *key):
      return indexer.__getitem__(*key)

  return Indexer()
|
|
|
|
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
def _valid_dtypes_for_shape(shape, dtypes):
  """Return the subset of ``dtypes`` that can be paired with ``shape``.

  Not every (shape, dtype) combination is meaningful. A Python scalar
  (``jtu.PYTHON_SCALAR_SHAPE``) has only one dtype per category, so for it
  only members of ``python_scalar_dtypes`` are kept, in their original order.
  Any other shape gets ``dtypes`` back unchanged (the same object).
  """
  if shape is not jtu.PYTHON_SCALAR_SHAPE:
    return dtypes
  return [t for t in dtypes if t in python_scalar_dtypes]
|
|
|
|
|
|
|
|
def _shape_and_dtypes(shapes, dtypes):
  """Lazily yield every valid ``(shape, dtype)`` pair.

  Walks ``shapes`` in order. For each shape, yields it alongside every dtype
  that ``_valid_dtypes_for_shape`` allows for that shape.
  """
  for shape in shapes:
    yield from ((shape, dt) for dt in _valid_dtypes_for_shape(shape, dtypes))
|
|
|
|
|
2020-11-20 14:25:58 -08:00
|
|
|
def _compatible_shapes(shape):
  """Return ``shape`` together with all of its trailing suffixes.

  For a scalar (``np.ndim(shape) == 0`` or a member of ``scalar_shapes``) the
  result is just ``[shape]``. Otherwise it is
  ``[shape[0:], shape[1:], ..., shape[len(shape):]]``, ending with ``()``.
  Each suffix broadcasts against ``shape``.

  Returns:
    A list in both branches, so callers may iterate the result more than once.
  """
  if np.ndim(shape) == 0 or shape in scalar_shapes:
    return [shape]
  # Materialize the suffixes. A generator here would be single-use and would
  # differ in type from the scalar branch's list.
  return [shape[n:] for n in range(len(shape) + 1)]
|
|
|
|
|
2018-12-10 08:42:11 -05:00
|
|
|
# Declarative description of one NumPy op under test. The field order matches
# the positional parameters of `op_record` below.
OpRecord = collections.namedtuple(
    "OpRecord",
    "name nargs dtypes shapes rng_factory diff_modes "
    "test_name check_dtypes tolerance inexact kwargs")


def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
              test_name=None, check_dtypes=True,
              tolerance=None, inexact=False, kwargs=None):
  """Build an ``OpRecord``, filling in defaults for the optional fields.

  ``test_name`` falls back to ``name`` whenever it is falsy (``None`` or
  ``""``). All other arguments are stored verbatim.

  Returns:
    An ``OpRecord`` namedtuple.
  """
  return OpRecord(
      name=name,
      nargs=nargs,
      dtypes=dtypes,
      shapes=shapes,
      rng_factory=rng_factory,
      diff_modes=diff_modes,
      test_name=test_name or name,
      check_dtypes=check_dtypes,
      tolerance=tolerance,
      inexact=inexact,
      kwargs=kwargs,
  )
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-09-26 13:31:43 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# Index-reductions exercised by testArgMinMax / testArgMinMaxEmpty. Each takes
# one argument, draws from default_dtypes x nonempty_shapes, and has no
# differentiation modes. The nan-variants use inputs with some NaNs.
JAX_ARGMINMAX_RECORDS = [
    op_record(op_name, 1, default_dtypes, nonempty_shapes, rng_factory, [])
    for op_name, rng_factory in (
        ("argmin", jtu.rand_some_equal),
        ("argmax", jtu.rand_some_equal),
        ("nanargmin", jtu.rand_some_nan),
        ("nanargmax", jtu.rand_some_nan),
    )
]
|
|
|
|
|
2018-12-10 08:42:11 -05:00
|
|
|
def _shapes_are_broadcast_compatible(shapes):
  """Return whether ``lax.broadcast_shapes`` accepts all of ``shapes`` together.

  Scalar markers from ``scalar_shapes`` are treated as the rank-0 shape
  ``()``. A ``ValueError`` from ``lax.broadcast_shapes`` means "incompatible".
  """
  try:
    normalized = [() if s in scalar_shapes else s for s in shapes]
    lax.broadcast_shapes(*normalized)
  except ValueError:
    return False
  return True
|
2018-12-10 08:42:11 -05:00
|
|
|
|
2019-06-24 10:34:48 -04:00
|
|
|
def _shapes_are_equal_length(shapes):
|
|
|
|
return all(len(shape) == len(shapes[0]) for shape in shapes[1:])
|
|
|
|
|
2018-12-06 13:25:42 -05:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|
|
|
"""Tests for LAX-backed Numpy implementation."""
|
|
|
|
|
2020-05-20 01:43:48 -03:00
|
|
|
def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
  """Return a zero-argument ``args_maker`` that draws fresh test inputs.

  Each call draws one value per ``(shape, dtype)`` pair via
  ``rng(shape, dtype)``. A falsy dtype (e.g. ``None``) falls back to
  ``jnp.float_``.

  Args:
    rng: callable taking ``(shape, dtype)``.
    shapes: argument shapes, in order.
    dtypes: argument dtypes, parallel to ``shapes``.
    np_arrays: if True, values are returned exactly as ``rng`` produced them.
      If False, NumPy arrays and NumPy scalars are converted with
      ``jnp.asarray`` and anything else is passed through.

  NOTE(review): pairs come from ``zip``, so if ``shapes`` and ``dtypes`` differ
  in length the surplus entries are silently ignored.
  """
  def as_jax(value):
    if isinstance(value, (np.ndarray, np.generic)):
      return jnp.asarray(value)
    return value

  def args_maker():
    drawn = [rng(shape, dtype or jnp.float_)
             for shape, dtype in zip(shapes, dtypes)]
    return drawn if np_arrays else [as_jax(v) for v in drawn]

  return args_maker
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@parameterized.parameters(
    # Only dtypes that are unchanged by canonicalization under the current
    # config are tested.
    [dtype for dtype in (jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32,
                         jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                         jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
                         jnp.complex64, jnp.complex128)
     if dtype == dtypes.canonicalize_dtype(dtype)])
def testDtypeWrappers(self, dtype):
  """Calling a jnp scalar type yields a jax.Array of that dtype, with no copy."""
  value = dtype(0)
  self.assertIsInstance(value, jax.Array)
  self.assertEqual(value.dtype, np.dtype(dtype))
  self.assertArraysEqual(value, 0, check_dtypes=False)

  # The traced constructor should lower to a single convert_element_type;
  # in particular no copy primitive is emitted.
  primitives = [eqn.primitive for eqn in jax.make_jaxpr(dtype)(0).eqns]
  self.assertEqual(primitives, [lax.convert_element_type_p])
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=float_dtypes + [object],
    allow_pickle=[True, False],
)
def testLoad(self, dtype, allow_pickle):
  """An array written with jnp.save reads back identically with jnp.load."""
  if dtype == object and not allow_pickle:
    self.skipTest("dtype=object requires allow_pickle=True")
  rng = jtu.rand_default(self.rng())
  # A 1-D array of length 10. The shape is spelled as a real 1-tuple; `(10)`
  # is just the int 10 and only reads like a tuple.
  arr = rng((10,), dtype)
  with io.BytesIO() as f:
    jnp.save(f, arr)
    f.seek(0)  # Rewind so load reads what save just wrote.
    arr_out = jnp.load(f, allow_pickle=allow_pickle)
  self.assertArraysEqual(arr, arr_out)
|
|
|
|
|
2020-09-21 10:37:55 -07:00
|
|
|
def testArrayEqualExamples(self):
  """Examples mirrored from the NumPy ``array_equal`` docstring."""
  basic_cases = [
      ([1, 2], [1, 2], True),
      (np.array([1, 2]), np.array([1, 2]), True),
      ([1, 2], [1, 2, 3], False),
      ([1, 2], [1, 4], False),
  ]
  for lhs, rhs, expected in basic_cases:
    self.assertEqual(bool(jnp.array_equal(lhs, rhs)), expected)

  # NaN compares unequal to itself unless equal_nan=True.
  with_nan = np.array([1, np.nan])
  self.assertFalse(jnp.array_equal(with_nan, with_nan))
  self.assertTrue(jnp.array_equal(with_nan, with_nan, equal_nan=True))

  # With equal_nan=True, a complex value with NaN in the real part equals one
  # with NaN in the imaginary part.
  nan_real = np.array([1 + 1j])
  nan_imag = nan_real.copy()
  nan_real.real = np.nan
  nan_imag.imag = np.nan
  self.assertTrue(jnp.array_equal(nan_real, nan_imag, equal_nan=True))
|
|
|
|
|
|
|
|
def testArrayEquivExamples(self):
  """Examples mirrored from the NumPy ``array_equiv`` docstring."""
  self.assertTrue(jnp.array_equiv([1, 2], [1, 2]))
  self.assertFalse(jnp.array_equiv([1, 2], [1, 3]))

  # These cases compare a 1-D operand against 2-D ones, so they run with
  # implicit rank promotion allowed.
  with jax.numpy_rank_promotion('allow'):
    row = [1, 2]
    self.assertTrue(jnp.array_equiv(row, [[1, 2], [1, 2]]))
    self.assertFalse(jnp.array_equiv(row, [[1, 2, 1, 2], [1, 2, 1, 2]]))
    self.assertFalse(jnp.array_equiv(row, [[1, 2], [1, 3]]))
|
2020-09-21 10:37:55 -07:00
|
|
|
|
2020-08-18 09:40:57 -07:00
|
|
|
def testArrayModule(self):
  """numpy_dispatch.get_array_module resolves JAX values to jax.numpy."""
  if numpy_dispatch is None:
    raise SkipTest('requires https://github.com/seberg/numpy-dispatch')

  jnp_array = jnp.array(1.0)
  np_array = np.array(1.0)

  # A JAX array on its own, or mixed with a NumPy array, selects jax.numpy.
  for operands in [(jnp_array,), (jnp_array, np_array)]:
    self.assertIs(numpy_dispatch.get_array_module(*operands), jnp)

  # The same holds for the traced values seen inside jit and grad.
  def check_traced(x):
    self.assertIs(numpy_dispatch.get_array_module(x), jnp)
    return x

  jax.jit(check_traced)(jnp_array)
  jax.grad(check_traced)(jnp_array)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in all_shapes
     for axis in range(-len(shape), len(shape))],
    discont=[None, "pi", 2],
    period=["2pi", "pi"],
    dtype=default_dtypes,
)
def testUnwrap(self, shape, dtype, axis, discont, period):
  """jnp.unwrap matches np.unwrap across axes, discont and period values."""
  # Symbolic parameter names map to numeric constants; anything else
  # (None, plain numbers) passes through unchanged.
  named_constants = {"pi": np.pi, "2pi": 2 * np.pi}
  period = named_constants.get(period, period)
  discont = named_constants.get(discont, discont)

  rng = jtu.rand_default(self.rng())

  def np_fun(x):
    if x.dtype != dtypes.bfloat16:
      return np.unwrap(x, axis=axis, discont=discont, period=period)
    # bfloat16 inputs: compute the reference in float32, then cast back.
    out = np.unwrap(x.astype(np.float32), axis=axis, discont=discont,
                    period=period)
    return out.astype(x.dtype)

  jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period)
  if not dtypes.issubdtype(dtype, np.inexact):
    # Non-inexact inputs rely on implicit dtype promotion.
    jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                          atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2})
  self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1})
|
2022-04-22 12:00:34 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in all_shapes
     for axis in list(range(-len(shape), len(shape))) + [None]],
    dtype=all_dtypes,
)
def testCountNonzero(self, shape, dtype, axis):
  """jnp.count_nonzero matches NumPy for every axis, including axis=None."""
  rng = jtu.rand_some_zero(self.rng())

  def np_fun(x):
    return np.count_nonzero(x, axis)

  def jnp_fun(x):
    return jnp.count_nonzero(x, axis)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-12-20 15:36:37 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testNonzero(self, shape, dtype):
  """jnp.nonzero agrees with np.nonzero (NumPy's 0-d deprecation is silenced)."""
  rng = jtu.rand_some_zero(self.rng())
  silence_0d_warning = jtu.ignore_warning(
      category=DeprecationWarning,
      message="Calling nonzero on 0d arrays.*")
  np_fun = silence_0d_warning(lambda x: np.nonzero(x))
  jnp_fun = lambda x: jnp.nonzero(x)
  self._CheckAgainstNumpy(np_fun, jnp_fun, lambda: [rng(shape, dtype)],
                          check_dtypes=False)
|
2019-12-20 18:42:33 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, fill_value=fill_value)
     for shape in nonempty_array_shapes
     for fill_value in [None, -1, shape or (1,)]],
    dtype=all_dtypes,
    size=[1, 5, 10],
)
def testNonzeroSize(self, shape, dtype, size, fill_value):
  """jnp.nonzero with a static size truncates or pads like the NumPy reference."""
  rng = jtu.rand_some_zero(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    indices = np.nonzero(x)
    if len(indices[0]) >= size:
      # Enough nonzeros: keep the first `size` entries of each index array.
      return tuple(idx[:size] for idx in indices)
    # Too few: pad every index array up to `size`. A scalar fill_value (None
    # meaning 0) is shared by all index arrays; a non-scalar one supplies one
    # fill value per index array.
    if np.ndim(fill_value):
      pads = fill_value
    else:
      pads = [fill_value or 0] * len(indices)
    return tuple(
        np.concatenate([idx, np.full(size - len(idx), pad, idx.dtype)])
        for pad, idx in safe_zip(pads, indices))

  jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testFlatNonzero(self, shape, dtype):
  """jnp.flatnonzero matches NumPy eagerly, and compiles with a static size."""
  rng = jtu.rand_some_zero(self.rng())
  args_maker = lambda: [rng(shape, dtype)]
  quiet_np_flatnonzero = jtu.ignore_warning(
      category=DeprecationWarning,
      message="Calling nonzero on 0d arrays.*")(np.flatnonzero)
  self._CheckAgainstNumpy(quiet_np_flatnonzero, jnp.flatnonzero, args_maker,
                          check_dtypes=False)

  # Under jit the output size has to be given statically.
  def sized_flatnonzero(x):
    return jnp.flatnonzero(x, size=np.size(x) // 2)

  self._CompileAndCheck(sized_flatnonzero, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    shape=nonempty_array_shapes,
    dtype=all_dtypes,
    fill_value=[None, -1, 10, (-1,), (10,)],
    size=[1, 5, 10],
)
def testFlatNonzeroSize(self, shape, dtype, size, fill_value):
  """jnp.flatnonzero with a static size truncates or pads like the NumPy reference."""
  rng = jtu.rand_some_zero(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    flat = np.flatnonzero(x)
    if len(flat) >= size:
      return flat[:size]
    # Pad with fill_value (None meaning 0) up to the requested size.
    padding = np.full(size - len(flat), fill_value or 0, flat.dtype)
    return np.concatenate([flat, padding])

  jnp_fun = lambda x: jnp.flatnonzero(x, size=size, fill_value=fill_value)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testArgWhere(self, shape, dtype):
  """jnp.argwhere matches NumPy eagerly, and compiles with a static size."""
  rng = jtu.rand_some_zero(self.rng())
  args_maker = lambda: [rng(shape, dtype)]
  quiet_np_argwhere = jtu.ignore_warning(
      category=DeprecationWarning,
      message="Calling nonzero on 0d arrays.*")(np.argwhere)
  self._CheckAgainstNumpy(quiet_np_argwhere, jnp.argwhere, args_maker,
                          check_dtypes=False)

  # Under jit the output size has to be given statically. The semantics of
  # `size` / `fill_value` are covered in testArgWhereSize.
  def sized_argwhere(x):
    return jnp.argwhere(x, size=np.size(x) // 2)

  self._CompileAndCheck(sized_argwhere, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, fill_value=fill_value)
     for shape in nonempty_array_shapes
     for fill_value in [None, -1, shape or (1,)]],
    dtype=all_dtypes,
    size=[1, 5, 10],
)
def testArgWhereSize(self, shape, dtype, size, fill_value):
  """jnp.argwhere with a static size truncates or pads like the NumPy reference."""
  rng = jtu.rand_some_zero(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    coords = np.argwhere(x)
    if len(coords) >= size:
      return coords[:size]
    # One fill value per coordinate column: a scalar fill_value (None meaning
    # 0) is repeated, a non-scalar one is used as-is.
    if np.ndim(fill_value):
      pads = fill_value
    else:
      pads = coords.shape[-1] * [fill_value or 0]
    if np.ndim(x) == 0:
      # A 0-d input has no coordinate columns, so the padded result is empty.
      return np.empty((size, 0), dtype=int)
    columns = [np.concatenate([col, np.full(size - len(col), pad, col.dtype)])
               for pad, col in safe_zip(pads, coords.T)]
    return np.stack(columns).T

  jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
          shape=shape, dtype=dtype, axis=axis, rng_factory=rec.rng_factory)
     for rec in JAX_ARGMINMAX_RECORDS
     for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes)
     for axis in range(-len(shape), len(shape))],
    keepdims=[False, True],
)
def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis, keepdims):
  """jnp arg{min,max} and nanarg{min,max} match NumPy per axis and keepdims."""
  rng = rng_factory(self.rng())
  if dtype == np.complex128 and jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("complex128 reductions not supported on GPU")
  if "nan" in np_op.__name__ and dtype == jnp.bfloat16:
    raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays")

  # keepdims is only forwarded when it is set.
  extra_kwargs = {"keepdims": True} if keepdims else {}
  np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **extra_kwargs))
  jnp_fun = partial(jnp_op, axis=axis, **extra_kwargs)
  args_maker = lambda: [rng(shape, dtype)]

  try:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  except ValueError as err:
    # NumPy rejects all-NaN slices for nanarg*; skip those draws and re-raise
    # every other error.
    if str(err) != "All-NaN slice encountered":
      raise
    self.skipTest("JAX doesn't support checking for all-NaN slices")
  self._CompileAndCheck(jnp_fun, args_maker)
|
2019-03-25 17:42:08 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(name=rec.name, np_op=getattr(np, rec.name),
          jnp_op=getattr(jnp, rec.name))
     for rec in JAX_ARGMINMAX_RECORDS],
)
def testArgMinMaxEmpty(self, name, np_op, jnp_op):
  """Reducing over an empty axis raises; a non-empty reduction still matches NumPy."""
  # The expected message names the op without its "nan" prefix
  # (e.g. nanargmin -> argmin).
  name = name[3:] if name.startswith("nan") else name
  msg = f"attempt to get {name} of an empty sequence"
  # assertRaises(..., msg=msg) only sets the text reported when nothing is
  # raised; it never inspects the exception. assertRaisesRegex actually checks
  # that the ValueError carries the expected message.
  with self.assertRaisesRegex(ValueError, msg):
    jnp_op(np.array([]))
  with self.assertRaisesRegex(ValueError, msg):
    jnp_op(np.zeros((2, 0)), axis=1)

  # Reducing over the non-empty axis 0 of a (2, 0) array is well defined.
  np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0))
  jnp_fun = partial(jnp_op, axis=0)
  args_maker = lambda: [np.zeros((2, 0))]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-04-30 08:31:48 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes)
   for lhs_shape, rhs_shape, axes in [
     [(2,), (2,), (-1, -1, -1, None)],  # scalar output
     [(2, 4), (2, 4), (-1, -1, -1, 0)],  # 2D vectors
     [(3, 4), (3, 4), (-1, -1, -1, 0)],  # 3D vectors
     [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)],  # broadcasting
     [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)],  # different axes
     [(6, 1, 3), (5, 3), (-1, -1, -1, None)],  # more broadcasting
     [(6, 1, 2), (5, 3), (-1, -1, -1, None)],  # mixed 2D and 3D vectors
     [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)],  # axes/broadcasting
     [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)],  # axisc should do nothing
     [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)]  # same as before
   ]],
  lhs_dtype=number_dtypes,
  rhs_dtype=number_dtypes,
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes):
  """Compares jnp.cross with np.cross over the axis layouts listed above."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
  axisa, axisb, axisc, axis = axes

  def jnp_fun(a, b):
    return jnp.cross(a, b, axisa, axisb, axisc, axis)

  def np_fun(a, b):
    # bfloat16 operands are upcast to float32 before calling NumPy; the result
    # is then cast to the dtype JAX promotes the inputs to.
    if lhs_dtype == jnp.bfloat16:
      a = a.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      b = b.astype(np.float32)
    expected = np.cross(a, b, axisa, axisb, axisc, axis)
    return expected.astype(jnp.promote_types(lhs_dtype, rhs_dtype))

  tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15}
  # Use the looser of the two operand tolerances.
  tol = max(jtu.tolerance(dt, tol_spec) for dt in (lhs_dtype, rhs_dtype))
  with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
   for lhs_shape, rhs_shape in [
     ((3, 3), ()),
     ((), (3, 3)),
     ((4, 5), (5,)),
     ((6,), (6, 4)),
     ((3, 4), (4, 5)),
     ((4, 3, 2), (2,)),
     ((2,), (3, 2, 4)),
     ((4, 3, 2), (2, 5)),
     ((5, 2), (3, 2, 4)),
     ((2, 3, 4), (5, 4, 1))]],
  lhs_dtype=number_dtypes,
  rhs_dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
  """Compares jnp.dot with np.dot for scalar, vector, matrix and tensor operands."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  half_precision = [np.float16, jnp.bfloat16]
  if lhs_dtype in half_precision and rhs_dtype in half_precision:
    # Both operands are 16-bit floats: use a single loose tolerance.
    tol = 1e-2
  else:
    tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14,
           np.complex128: 1e-14}

  def np_dot(x, y):
    # bfloat16 inputs are computed in float32 on the NumPy side.
    if lhs_dtype == jnp.bfloat16:
      x = x.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      y = y.astype(np.float32)
    product = np.dot(x, y)
    return product.astype(jnp.promote_types(lhs_dtype, rhs_dtype))

  with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
    self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol)
    self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2023-07-21 09:31:43 -07:00
|
|
|
@jtu.sample_product(
  lhs_dtype=number_dtypes,
  rhs_dtype=number_dtypes,
)
@jax.numpy_dtype_promotion('standard')
def testMixedPrecisionDot(self, lhs_dtype, rhs_dtype):
  """Checks that jnp.dot on mixed dtypes traces to a lone dot_general.

  An optional trailing convert_element_type is tolerated. The inputs must not
  be cast explicitly before the dot.
  """
  operands = [jax.ShapeDtypeStruct((5,), dt) for dt in (lhs_dtype, rhs_dtype)]
  closed_jaxpr = jax.make_jaxpr(jnp.dot)(*operands)
  primitives = [eqn.primitive for eqn in closed_jaxpr.eqns]
  allowed_sequences = [
    [lax.dot_general_p],
    [lax.dot_general_p, lax.convert_element_type_p],
  ]
  self.assertIn(primitives, allowed_sequences)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape)
   for name, lhs_shape, rhs_shape in [
     ("vector-vector", (3,), (3,)),
     ("matrix-vector", (3, 3), (3,)),
     ("vector-matrix", (3,), (3, 3)),
     ("matrix-matrix", (3, 3), (3, 3)),
     ("vector-tensor", (3,), (5, 3, 2)),
     ("tensor-vector", (5, 3, 2), (2,)),
     ("matrix-tensor", (5, 2), (3, 2, 4)),
     ("tensor-matrix", (5, 2, 3), (3, 2)),
     ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
     ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]],
  lhs_dtype=number_dtypes,
  rhs_dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
def testMatmul(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
  """Compares jnp.matmul with np.matmul across the rank combinations above."""
  rng = jtu.rand_default(self.rng())

  def np_matmul(x, y):
    # Cast NumPy's result to the dtype JAX promotes the operands to.
    return np.matmul(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype))

  args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
  tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
         np.complex128: 1e-12}
  with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
    self._CheckAgainstNumpy(np_matmul, jnp.matmul, args_maker, tol=tol)
    self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes)
   for lhs_shape, rhs_shape, axes in [
     [(3,), (), 0],
     [(2, 3, 4), (5, 6, 7), 0],  # from issue #740
     [(2, 3, 4), (3, 4, 5, 6), 2],
     [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
     [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
     [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
   ]],
  lhs_dtype=number_dtypes,
  rhs_dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes):
  """Compares jnp.tensordot with np.tensordot for int and explicit-list axes."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  def jnp_fun(a, b):
    return jnp.tensordot(a, b, axes)

  def np_fun(a, b):
    # bfloat16 operands are computed in float32 on the NumPy side.
    if lhs_dtype == jnp.bfloat16:
      a = a.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      b = b.astype(np.float32)
    out_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    return np.tensordot(a, b, axes).astype(out_dtype)

  tol = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-12,
         np.complex64: 1e-3, np.complex128: 1e-12}
  with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, tol=tol)
|
2018-12-15 21:59:18 -08:00
|
|
|
|
2020-01-29 10:14:36 -05:00
|
|
|
def testTensordotErrors(self):
  """Checks that malformed `axes` arguments to jnp.tensordot raise TypeError."""
  a = self.rng().random((3, 2, 2))
  b = self.rng().random((2,))
  with self.assertRaisesRegex(
      TypeError, "Number of tensordot axes.*exceeds input ranks.*"):
    jnp.tensordot(a, b, axes=2)

  with self.assertRaisesRegex(
      TypeError, "tensordot requires axes lists to have equal length.*"):
    jnp.tensordot(a, b, axes=([0], [0, 1]))

  with self.assertRaisesRegex(
      TypeError, "tensordot requires both axes lists to be either ints, tuples or lists.*"):
    jnp.tensordot(a, b, axes=('bad', 'axes'))

  with self.assertRaisesRegex(
      TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*"):
    jnp.tensordot(a, b, axes='badaxes')
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  element_shape=all_shapes,
  test_shape=all_shapes,
  dtype=default_dtypes,
  invert=[False, True],
)
def testIsin(self, element_shape, test_shape, dtype, invert):
  """Compares jnp.isin with np.isin, with and without `invert`."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)]

  def jnp_fun(elements, test_elements):
    return jnp.isin(elements, test_elements, invert=invert)

  def np_fun(elements, test_elements):
    return np.isin(elements, test_elements, invert=invert)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-05-19 16:58:42 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
  dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
  shape1=all_shapes,
  shape2=all_shapes,
)
def testSetdiff1d(self, shape1, shape2, dtype1, dtype2):
  """Compares jnp.setdiff1d with np.setdiff1d for mixed shapes and dtypes."""
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape1, dtype1), rng(shape2, dtype2)]

  with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
    self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker)
|
2020-05-19 16:58:42 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
  dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
  shape1=all_shapes,
  shape2=all_shapes,
  size=[1, 5, 10],
  fill_value=[None, -1],
)
def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value):
  """Checks the static-`size` form of jnp.setdiff1d against a NumPy emulation.

  The NumPy reference truncates the result to `size` or right-pads it with
  `fill_value` (0 when `fill_value` is falsy).
  """
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)]

  def np_fun(arg1, arg2):
    expected = np.setdiff1d(arg1, arg2)
    shortfall = size - len(expected)
    if shortfall <= 0:
      return expected[:size]
    return np.pad(expected, (0, shortfall), constant_values=fill_value or 0)

  def jnp_fun(arg1, arg2):
    return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value)

  with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
|
2021-10-08 16:48:28 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
  dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
  shape1=nonempty_nonscalar_array_shapes,
  shape2=nonempty_nonscalar_array_shapes,
)
def testUnion1d(self, shape1, shape2, dtype1, dtype2):
  """Compares jnp.union1d with np.union1d cast to JAX's promoted dtype."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)]

  def np_fun(arg1, arg2):
    target_dtype = jnp.promote_types(arg1.dtype, arg2.dtype)
    union = np.union1d(arg1, arg2)
    return union.astype(target_dtype)

  with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
    self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker)
|
2021-02-06 22:17:06 +05:30
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
  dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
  shape1=nonempty_nonscalar_array_shapes,
  shape2=nonempty_nonscalar_array_shapes,
  size=[1, 5, 10],
  fill_value=[None, -1],
)
def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value):
  """Checks the static-`size` form of jnp.union1d against a NumPy emulation.

  The NumPy reference truncates to `size` or pads with `fill_value`. When
  `fill_value` is None, it pads with the minimum of the union.
  """
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)]

  def np_fun(arg1, arg2):
    target_dtype = jnp.promote_types(arg1.dtype, arg2.dtype)
    union = np.union1d(arg1, arg2).astype(target_dtype)
    pad_value = union.min() if fill_value is None else fill_value
    missing = size - len(union)
    if missing <= 0:
      return union[:size]
    padding = np.full(missing, pad_value, union.dtype)
    return np.concatenate([union, padding])

  def jnp_fun(arg1, arg2):
    return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value)

  with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
|
2021-06-09 11:36:34 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
  dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
  shape1=all_shapes,
  shape2=all_shapes,
  assume_unique=[False, True],
)
def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique):
  """Compares jnp.setxor1d with np.setxor1d (dtypes are not compared)."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)]

  def jnp_fun(ar1, ar2):
    return jnp.setxor1d(ar1, ar2, assume_unique=assume_unique)

  def np_fun(ar1, ar2):
    # With assume_unique, flatten the inputs up front so NumPy sees the same
    # operands as the JAX implementation.
    if assume_unique:
      ar1, ar2 = np.ravel(ar1), np.ravel(ar2)
    return np.setxor1d(ar1, ar2, assume_unique)

  with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
2021-02-07 00:08:07 +05:30
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype1=[s for s in default_dtypes if s != jnp.bfloat16],
  dtype2=[s for s in default_dtypes if s != jnp.bfloat16],
  shape1=all_shapes,
  shape2=all_shapes,
  assume_unique=[False, True],
  return_indices=[False, True],
)
def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique,
                    return_indices):
  """Compares jnp.intersect1d with np.intersect1d, optionally with indices."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)]
  options = dict(assume_unique=assume_unique, return_indices=return_indices)

  def jnp_fun(ar1, ar2):
    return jnp.intersect1d(ar1, ar2, **options)

  def np_fun(ar1, ar2):
    return np.intersect1d(ar1, ar2, **options)

  with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
2020-07-13 08:32:41 +03:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(lhs_shape=lhs_shape, lhs_dtype=lhs_dtype,
        rhs_shape=rhs_shape, rhs_dtype=rhs_dtype)
   # TODO(phawkins): support integer dtypes too.
   for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
   for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
   if len(jtu._dims_of_shape(lhs_shape)) == 0
   or len(jtu._dims_of_shape(rhs_shape)) == 0
   or lhs_shape[-1] == rhs_shape[-1]],
)
@jax.default_matmul_precision("float32")
def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
  """Compares jnp.inner with np.inner for scalar or matching-last-axis inputs."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  def np_fun(lhs, rhs):
    # bfloat16 operands are computed in float32 on the NumPy side.
    if lhs_dtype == jnp.bfloat16:
      lhs = lhs.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      rhs = rhs.astype(np.float32)
    out_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    return np.inner(lhs, rhs).astype(out_dtype)

  def jnp_fun(lhs, rhs):
    return jnp.inner(lhs, rhs)

  tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13,
              np.complex64: 1e-5}
  # Use the looser of the two operand tolerances.
  tol = max(jtu.tolerance(dt, tol_spec) for dt in (lhs_dtype, rhs_dtype))
  # TODO(phawkins): there are float32/float64 disagreements for some inputs.
  with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol)
|
2021-08-29 16:48:45 +05:30
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]],
  shape=[shape for shape in one_dim_array_shapes if shape != (1,)],
  deg=[1, 2, 3],
  rcond=[None, -1, 10e-3, 10e-5, 10e-10],
  full=[False, True],
  w=[False, True],
  cov=[False, True, "unscaled"],
)
@jax.default_matmul_precision("float32")
def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov):
  """Compares jnp.polyfit with np.polyfit, optionally with weights and `full`."""
  rng = jtu.rand_default(self.rng())
  tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5}
  tol = jtu.tolerance(dtype, tol_spec)
  # The third generated array becomes non-negative weights when `w` is set.
  _w = lambda a: abs(a) if w else None
  args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)]
  jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)
  np_fun = jtu.ignore_warning(
    message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov))

  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol)

  # Draw a single set of inputs and use it for whichever comparison applies.
  args = args_maker()
  if not full:
    try:
      np_out = np_fun(*args)
    except ValueError:
      return  # https://github.com/numpy/numpy/issues/22380
    jnp_out = jnp_fun(*args)
    self.assertAllClose(np_out, jnp_out, atol=tol, rtol=tol,
                        check_dtypes=False)
  else:
    # Don't compare the residuals because jnp.linalg.lstsq acts slightly
    # differently to remain `jit`-compatible.
    np_p, _, nrank, nsingular_values, nrcond = np_fun(*args)
    jp_p, _, jrank, jsingular_values, jrcond = jnp_fun(*args)
    self.assertAllClose(
      (np_p, nrank, nsingular_values, nrcond),
      (jp_p, jrank, jsingular_values, jrcond),
      atol=tol, rtol=tol, check_dtypes=False)
|
|
|
|
|
|
|
|
@jtu.sample_product(
  [dict(a_min=a_min, a_max=a_max)
   for a_min, a_max in [(-1, None), (None, 1), (-0.9, 1),
                        (-np.ones(1), None),
                        (None, np.ones(1)),
                        (np.full(1, -0.9), np.ones(1))]
  ],
  shape=all_shapes,
  dtype=number_dtypes,
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard')  # This test explicitly exercises mixed type promotion
def testClipStaticBounds(self, shape, dtype, a_min, a_max):
  """Compares jnp.clip with np.clip for scalar and array bounds known at trace time."""
  if np.issubdtype(dtype, np.unsignedinteger):
    # For unsigned inputs the bounds are replaced by their magnitudes
    # (presumably so negative bounds are never applied to unsigned data).
    magnitude = lambda bound: None if bound is None else abs(bound)
    a_min, a_max = magnitude(a_min), magnitude(a_max)
  rng = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.clip(x, a_min=a_min, a_max=a_max)

  def jnp_fun(x):
    return jnp.clip(x, a_min=a_min, a_max=a_max)

  args_maker = lambda: [rng(shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-01 12:35:12 -04:00
|
|
|
def testClipError(self):
  """Checks that jnp.clip raises ValueError when neither a_min nor a_max is given."""
  self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*",
                         jnp.clip, jnp.zeros((3,)))
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, dtype=dtype)
   for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],
  decimals=[0, 1, -2],
)
def testRoundStaticDecimals(self, shape, dtype, decimals):
  """Compares jnp.round with np.round for a static `decimals` argument."""
  rng = jtu.rand_default(self.rng())
  is_integer = jnp.issubdtype(dtype, np.integer)
  if is_integer and decimals < 0:
    self.skipTest("Integer rounding with decimals < 0 not implemented")

  def np_fun(x):
    return np.round(x, decimals=decimals)

  def jnp_fun(x):
    return jnp.round(x, decimals=decimals)

  args_maker = lambda: [rng(shape, dtype)]
  tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2}
  # Python-scalar inputs do not have a fixed dtype, so dtypes are only compared
  # for array inputs.
  check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=check_dtypes, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
                        atol=tol, rtol=tol)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(jit=[False, True])
def testOperatorRound(self, jit):
  """Checks builtin round() on JAX float32 scalars and 0-d arrays, optionally jitted."""
  jround = jax.jit(round, static_argnums=1) if jit else round
  # Both sides receive the same unrounded value, so the JAX path really has to
  # round (7.532 -> 7.5) instead of being handed an already-rounded input.
  self.assertAllClose(round(np.float32(7.532), 1),
                      jround(jnp.float32(7.532), 1))
  self.assertAllClose(round(np.float32(1.234), 2),
                      jround(jnp.float32(1.234), 2))
  # Without ndigits the result types may differ, so only values are compared.
  self.assertAllClose(round(np.float32(1.234)),
                      jround(jnp.float32(1.234)), check_dtypes=False)
  self.assertAllClose(round(np.float32(7.532), 1),
                      jround(jnp.array(7.532, jnp.float32), 1))
  self.assertAllClose(round(np.float32(1.234), 2),
                      jround(jnp.array(1.234, jnp.float32), 2))
  self.assertAllClose(round(np.float32(1.234)),
                      jround(jnp.array(1.234, jnp.float32)),
                      check_dtypes=False)
|
|
|
|
|
2023-03-23 20:16:23 -07:00
|
|
|
def testRoundMethod(self):
  """Regression test: calling .round() on a computed float array must not crash."""
  # https://github.com/google/jax/issues/15190
  fractions = jnp.arange(3.) / 5.
  fractions.round()  # doesn't crash
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(shape=[(5,), (5, 2)])
def testOperatorReversed(self, shape):
  """Checks that an array rebuilt from reversed(x) matches NumPy's behavior."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, 'float32')]

  def np_fun(x):
    return np.array([*reversed(x)])

  def jnp_fun(x):
    return jnp.array([*reversed(x)])

  self._CompileAndCheck(jnp_fun, args_maker)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(mode=mode, shape=shape, dtype=dtype,
        pad_width=pad_width, constant_values=constant_values)
    for mode, shapes in [
      ('constant', all_shapes),
      ('wrap', nonempty_shapes),
      ('edge', nonempty_shapes),
    ]
    for shape, dtype in _shape_and_dtypes(shapes, all_dtypes)
    for constant_values in [
      # None is used for modes other than 'constant'
      None,
      # constant
      0, 1,
      # (constant,)
      (0,), (2.718,),
      # ((before_const, after_const),)
      ((0, 2),), ((-1, 3.14),),
      # ((before_1, after_1), ..., (before_N, after_N))
      tuple((i / 2, -3.14 * i) for i in range(len(shape))),
    ]
    for pad_width in [
      # ((before_1, after_1), ..., (before_N, after_N))
      tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
      # ((before, after),)
      ((1, 2),), ((2, 0),),
      # (before, after)  (not in the docstring but works in numpy)
      (2, 0), (0, 0),
      # (pad,)
      (1,), (2,),
      # pad
      0, 1,
    ]
    if (pad_width != () and constant_values != () and
        ((mode == 'constant' and constant_values is not None) or
         (mode != 'constant' and constant_values is None)))],
)
def testPad(self, shape, dtype, mode, pad_width, constant_values):
  """Compares jnp.pad with np.pad for the 'constant', 'wrap' and 'edge' modes.

  ``constant_values`` is None exactly when the mode is not 'constant'; in that
  case the keyword is not passed at all.
  """
  if np.issubdtype(dtype, np.unsignedinteger):
    # Presumably so that negative fill values are not used with unsigned
    # dtypes: take the absolute value of every leaf.
    constant_values = tree_util.tree_map(abs, constant_values)
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  pad_kwargs = dict(pad_width=pad_width, mode=mode)
  if constant_values is not None:
    pad_kwargs['constant_values'] = constant_values
  np_fun = partial(np.pad, **pad_kwargs)
  jnp_fun = partial(jnp.pad, **pad_kwargs)

  # dtypes are only compared when the input is not a Python scalar.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2019-01-09 21:26:22 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(mode=mode, shape=shape, dtype=dtype,
        pad_width=pad_width, stat_length=stat_length)
    for mode in ['maximum', 'minimum', 'mean', 'median']
    for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes)
    for pad_width in [
      # ((before_1, after_1), ..., (before_N, after_N))
      tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
      # ((before, after),)
      ((1, 2),), ((2, 0),),
      # (before, after)  (not in the docstring but works in numpy)
      (2, 0), (0, 0),
      # (pad,)
      (1,), (2,),
      # pad
      0, 1,
    ]
    for stat_length in [
      None,
      # ((before_1, after_1), ..., (before_N, after_N))
      tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))),
      # ((before, after),)
      ((1, 2),), ((2, 2),),
      # (before, after)  (not in the docstring but works in numpy)
      (1, 1), (3, 4),
      # (pad,)
      (1,), (2,),
      # pad
      1, 2
    ]
    if (pad_width != () and stat_length != () and
        not (dtype in bool_dtypes and mode == 'mean'))],
)
def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length):
  """Compares the statistic-based pad modes of jnp.pad against np.pad."""
  is_complex = np.issubdtype(dtype, np.complexfloating)
  if mode == 'median' and is_complex:
    self.skipTest("median statistic is not supported for dtype=complex.")
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  pad_kwargs = dict(pad_width=pad_width, mode=mode, stat_length=stat_length)
  np_fun = partial(np.pad, **pad_kwargs)
  jnp_fun = partial(jnp.pad, **pad_kwargs)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, dtype=dtype,
        pad_width=pad_width, reflect_type=reflect_type)
    for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes)
    for pad_width in [
      # ((before_1, after_1), ..., (before_N, after_N))
      tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
      # ((before, after),)
      ((1, 2),), ((2, 3),),
      # (before, after)  (not in the docstring but works in numpy)
      (2, 1), (1, 2),
      # (pad,)
      (1,), (2,), (3,),
      # pad
      0, 5, 7, 10
    ]
    for reflect_type in ['even', 'odd']
    if (pad_width != () and
        # following types lack precision when calculating odd values
        (reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16]))],
  mode=['symmetric', 'reflect']
)
def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type):
  """Compares the 'symmetric' and 'reflect' pad modes against np.pad."""
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  pad_kwargs = dict(pad_width=pad_width, mode=mode, reflect_type=reflect_type)
  np_fun = partial(np.pad, **pad_kwargs)
  jnp_fun = partial(jnp.pad, **pad_kwargs)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE,
                          tol={np.float32: 1e-3, np.complex64: 1e-3})
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, dtype=dtype, pad_width=pad_width, end_values=end_values)
    for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes)
    for pad_width in [
      # ((before_1, after_1), ..., (before_N, after_N))
      tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
      # ((before, after),)
      ((1, 2),), ((2, 0),),
      # (before, after)  (not in the docstring but works in numpy)
      (2, 0), (0, 0),
      # (pad,)
      (1,), (2,),
      # pad
      0, 1,
    ]
    for end_values in [
      # ((before_1, after_1), ..., (before_N, after_N))
      tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
      # ((before, after),)
      ((1, 2),), ((2.0, 3.14),),
      # (before, after)  (not in the docstring but works in numpy)
      (0, 0), (-8.0, 2.0),
      # (end_values,)
      (1,), (2,),
      # end_values
      0, 1, 100, 10.0, 3.5, 4.2, -5, -3
    ]
    if (pad_width != () and end_values != () and
        # following types lack precision
        dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])],
)
def testPadLinearRamp(self, shape, dtype, pad_width, end_values):
  """Compares mode="linear_ramp" of jnp.pad against np.pad."""
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  pad_kwargs = dict(pad_width=pad_width, mode="linear_ramp",
                    end_values=end_values)
  np_fun = partial(np.pad, **pad_kwargs)
  jnp_fun = partial(jnp.pad, **pad_kwargs)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2020-12-16 16:06:57 +09:00
|
|
|
def testPadEmpty(self):
  """Checks mode="empty": only the output shape and interior are compared.

  The padded border values are not compared between NumPy and JAX.
  """
  arr = np.arange(6).reshape(2, 3)
  pad_width = ((2, 3), (3, 1))
  np_res = np.pad(arr, pad_width=pad_width, mode="empty")
  jnp_res = jnp.pad(arr, pad_width=pad_width, mode="empty")

  # Region of the padded result that holds the original array.
  interior = (slice(2, -3), slice(3, -1))
  np.testing.assert_equal(np_res.shape, jnp_res.shape)
  np.testing.assert_equal(arr, np_res[interior])
  np.testing.assert_equal(arr, jnp_res[interior])
  np.testing.assert_equal(np_res[interior], jnp_res[interior])
|
|
|
|
|
2020-12-15 13:12:05 +09:00
|
|
|
def testPadKwargs(self):
  """Each pad mode accepts its own keyword arguments and rejects others.

  Also checks that unknown modes raise NotImplementedError.
  """
  modes = {
    'constant': {'constant_values': 0},
    'edge': {},
    'linear_ramp': {'end_values': 0},
    'maximum': {'stat_length': None},
    'mean': {'stat_length': None},
    'median': {'stat_length': None},
    'minimum': {'stat_length': None},
    'reflect': {'reflect_type': 'even'},
    'symmetric': {'reflect_type': 'even'},
    'wrap': {},
    'empty': {}
  }

  arr = jnp.array([1, 2, 3])
  pad_width = 1

  for mode, allowed in modes.items():
    # Merge the keyword sets of every mode whose set differs from this one's;
    # later entries win on duplicate keys.
    not_allowed = {key: value
                   for other in modes.values() if other != allowed
                   for key, value in other.items()}

    # The mode's own keyword arguments must be accepted.
    jnp.pad(arr, pad_width, mode, **allowed)

    # Every keyword argument that belongs to another mode must be rejected.
    match = f"unsupported keyword arguments for mode '{mode}'"
    for key, value in not_allowed.items():
      with self.assertRaisesRegex(ValueError, match):
        jnp.pad(arr, pad_width, mode, **{key: value})

  # Unsupported modes must raise NotImplementedError.
  for mode in [1, None, "foo"]:
    match = f"Unimplemented padding mode '{mode}' for np.pad."
    with self.assertRaisesRegex(NotImplementedError, match):
      jnp.pad(arr, pad_width, mode)
|
|
|
|
|
2020-12-29 21:30:29 +09:00
|
|
|
def testPadFunction(self):
  """Checks jnp.pad with a user-supplied padding function against np.pad.

  The padding helpers slice the trailing pad from an explicit start index.
  The previous ``vector[-pad_width[1]:]`` form selects the *entire* vector
  when the after-pad width is 0.
  """
  def np_pad_with(vector, pad_width, iaxis, kwargs):
    # np.pad expects the padding function to modify `vector` in place.
    # `iaxis` is part of the required signature but unused here.
    pad_value = kwargs.get('padder', 10)
    before, after = pad_width
    vector[:before] = pad_value
    vector[vector.shape[0] - after:] = pad_value

  def jnp_pad_with(vector, pad_width, iaxis, kwargs):
    # JAX arrays are immutable, so the updated vector is returned instead.
    pad_value = kwargs.get('padder', 10)
    before, after = pad_width
    vector = vector.at[:before].set(pad_value)
    vector = vector.at[vector.shape[0] - after:].set(pad_value)
    return vector

  arr = np.arange(6).reshape(2, 3)
  np_res = np.pad(arr, 2, np_pad_with)
  jnp_res = jnp.pad(arr, 2, jnp_pad_with)
  np.testing.assert_equal(np_res, jnp_res)

  # Extra keyword arguments are forwarded to the padding function as `kwargs`.
  arr = np.arange(24).reshape(2, 3, 4)
  np_res = np.pad(arr, 1, np_pad_with, padder=100)
  jnp_res = jnp.pad(arr, 1, jnp_pad_with, padder=100)
  np.testing.assert_equal(np_res, jnp_res)

  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(arr.shape, arr.dtype)]
  jnp_fun = partial(jnp.pad, pad_width=1, mode=jnp_pad_with)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2020-10-20 00:43:04 +02:00
|
|
|
def testPadWithNumpyPadWidth(self):
  """A NumPy-array ``pad_width`` must work with a jitted jnp.pad."""
  a = jnp.array([1, 2, 3, 4, 5])
  pad_kwargs = dict(pad_width=np.asarray((2, 3)), mode="constant",
                    constant_values=(4, 6))
  f = jax.jit(partial(jnp.pad, **pad_kwargs))
  np.testing.assert_array_equal(f(a), np.pad(a, **pad_kwargs))
|
|
|
|
|
2021-11-29 12:16:10 -08:00
|
|
|
def testPadWeakType(self):
  """Padding by zero must keep the input's weak type for every mode."""
  x = jnp.array(1.0)[None]
  all_modes = ('constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median',
               'minimum', 'reflect', 'symmetric', 'wrap', 'empty')
  for pad_mode in all_modes:
    padded = jnp.pad(x, 0, mode=pad_mode)
    self.assertTrue(dtypes.is_weakly_typed(padded))
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, dtype=dtype)
    for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)],
  reps=[(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)],
)
def testTile(self, shape, dtype, reps):
  """Compares jnp.tile with np.tile for several repetition tuples."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arg):
    return np.tile(arg, reps)

  def jnp_fun(arg):
    return jnp.tile(arg, reps)

  def args_maker():
    return [rng(shape, dtype)]

  # dtypes are only compared when the input is not a Python scalar.
  compare_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=compare_dtypes)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2019-01-09 21:26:22 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testExtract(self, shape, dtype):
  """Compares jnp.extract with np.extract; the condition contains zeros."""
  rng = jtu.rand_some_zero(self.rng())

  def args_maker():
    condition = rng(shape, jnp.float32)
    values = rng(shape, dtype)
    return [condition, values]

  self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker)
|
2020-05-28 11:04:15 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(ncond=ncond, nfunc=nfunc)
    for ncond in [1, 2, 3]
    for nfunc in [ncond, ncond + 1]
  ],
  shape=all_shapes,
  dtype=all_dtypes)
def testPiecewise(self, shape, dtype, ncond, nfunc):
  """Compares jnp.piecewise with np.piecewise.

  ``nfunc`` is either ``ncond`` or ``ncond + 1``; the extra entry acts as
  the "otherwise" value. The function list mixes callables and constants.
  """
  rng = jtu.rand_default(self.rng())
  rng_bool = jtu.rand_int(self.rng(), 0, 2)
  funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc]

  def args_maker():
    x = rng(shape, dtype)
    conditions = [rng_bool(shape, bool) for _ in range(ncond)]
    return (x, conditions)

  np_fun = partial(np.piecewise, funclist=funclist)
  jnp_fun = partial(jnp.piecewise, funclist=funclist)

  if dtype == np.bool_:
    # `x - 1` in funclist relies on type promotion.
    jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  # This is a higher-order function, so the cache miss check will fail.
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                        check_cache_misses=False)
|
2020-08-03 12:31:52 -07:00
|
|
|
|
2021-10-14 05:44:38 -07:00
|
|
|
def testPiecewiseRecompile(self):
  """Repeated jnp.piecewise calls with the same function trace it only once."""
  def g(x):
    g.num_traces += 1
    return x
  g.num_traces = 0

  x = jnp.arange(10.0)
  for _ in range(5):
    jnp.piecewise(x, [x < 0], [g, 0.])
  self.assertEqual(g.num_traces, 1)
|
2020-08-03 12:31:52 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, perm=perm)
    for shape in array_shapes
    for perm in [
      None,
      tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim)),
      tuple(np.random.RandomState(0).permutation(
        np.zeros(shape).ndim) - np.zeros(shape).ndim)
    ]
  ],
  dtype=default_dtypes,
  arg_type=["splat", "value"],
)
def testTransposeTuple(self, shape, dtype, perm, arg_type):
  """The ``transpose`` method accepts a permutation as one value or splatted.

  ``perm`` may be None, non-negative axes, or the same axes as negatives.
  """
  rng = jtu.rand_some_zero(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  if arg_type == "value":
    # Permutation passed as a single argument (a tuple or None).
    def np_fun(x):
      return x.transpose(perm)
    def jnp_fun(x):
      return jnp.array(x).transpose(perm)
  else:
    # Permutation splatted into positional arguments; None means no axes.
    axes = perm or ()
    def np_fun(x):
      return x.transpose(*axes)
    def jnp_fun(x):
      return jnp.array(x).transpose(*axes)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
|
|
|
|
2023-05-25 09:02:05 -07:00
|
|
|
@jtu.sample_product(
  shape=[s for s in array_shapes if len(s) >= 2],
  dtype=default_dtypes,
  use_property=[True, False]
)
def testMatrixTranspose(self, shape, dtype, use_property):
  """Checks jnp.matrix_transpose and the ``.mT`` property.

  The reference is np.matrix_transpose when available, otherwise a swap of
  the last two axes.
  """
  jnp_fun = (lambda x: jnp.asarray(x).mT) if use_property else jnp.matrix_transpose
  np_fun = getattr(np, 'matrix_transpose', None)
  if np_fun is None:
    np_fun = lambda x: np.swapaxes(x, -1, -2)

  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=default_dtypes,
  a_shape=one_dim_array_shapes,
  trim=["f", "b", "fb"],
)
def testTrimZeros(self, a_shape, dtype, trim):
  """Compares jnp.trim_zeros with np.trim_zeros for each trim mode."""
  rng = jtu.rand_some_zero(self.rng())

  def args_maker():
    return [rng(a_shape, dtype)]

  def np_fun(arg1):
    return np.trim_zeros(arg1, trim)

  def jnp_fun(arg1):
    return jnp.trim_zeros(arg1, trim)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  rank=(1, 2),
  dtype=default_dtypes,
  a_shape=one_dim_array_shapes,
)
@jax.default_matmul_precision("float32")
def testPoly(self, a_shape, dtype, rank):
  """Compares jnp.poly with np.poly.

  With rank=1 the input is a vector of roots. With rank=2 the shape is
  ``a_shape * 2``, a square matrix when ``a_shape`` is 1-D (assumed from
  one_dim_array_shapes).
  """
  if dtype in (np.float16, jnp.bfloat16, np.int16):
    self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
  if rank == 2 and jtu.device_under_test() in ("tpu", "gpu"):
    self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.")
  rng = jtu.rand_default(self.rng())

  # int32/float32 tolerances are loosened on TPU.
  low_precision_tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
  tol = jtu.tolerance(dtype, {np.int8: 1e-3,
                              np.int32: low_precision_tol,
                              np.float32: low_precision_tol,
                              np.float64: 1e-6})

  def args_maker():
    return [rng(a_shape * rank, dtype)]

  self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol, atol=tol)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=default_dtypes,
  a_shape=one_dim_array_shapes,
  b_shape=one_dim_array_shapes,
)
def testPolyAdd(self, a_shape, b_shape, dtype):
  """Compares jnp.polyadd with np.polyadd, including unequal lengths."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arg1, arg2):
    return np.polyadd(arg1, arg2)

  def jnp_fun(arg1, arg2):
    return jnp.polyadd(arg1, arg2)

  def args_maker():
    return [rng(a_shape, dtype), rng(b_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=default_dtypes,
  a_shape=one_dim_array_shapes,
  b_shape=one_dim_array_shapes,
)
def testPolySub(self, a_shape, b_shape, dtype):
  """Compares jnp.polysub with np.polysub, including unequal lengths."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arg1, arg2):
    return np.polysub(arg1, arg2)

  def jnp_fun(arg1, arg2):
    return jnp.polysub(arg1, arg2)

  def args_maker():
    return [rng(a_shape, dtype), rng(b_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(order=order, k=k, dtype=dtype)
    for dtype in default_dtypes
    for order in range(5)
    for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]],
  a_shape=one_dim_array_shapes,
)
def testPolyInt(self, a_shape, order, k, dtype):
  """Compares jnp.polyint with np.polyint for orders 0-4 and several ``k``."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arg1):
    return np.polyint(arg1, m=order, k=k)

  def jnp_fun(arg1):
    return jnp.polyint(arg1, m=order, k=k)

  def args_maker():
    return [rng(a_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=default_dtypes,
  a_shape=one_dim_array_shapes,
  order=list(range(5)),
)
def testPolyDer(self, a_shape, order, dtype):
  """Compares jnp.polyder with np.polyder for derivative orders 0-4."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arg1):
    return np.polyder(arg1, m=order)

  def jnp_fun(arg1):
    return jnp.polyder(arg1, m=order)

  def args_maker():
    return [rng(a_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@parameterized.parameters(['int', 'np.int', 'jnp.int'])
def testIntegerPower(self, ptype):
  """A Python, NumPy or JAX integer exponent lowers to one integer_pow op."""
  exponents = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}
  p = exponents[ptype]
  closed_jaxpr = jax.make_jaxpr(lambda x1: jnp.power(x1, p))(1)
  equations = closed_jaxpr.jaxpr.eqns
  self.assertLen(equations, 1)
  self.assertEqual(equations[0].primitive, lax.integer_pow_p)
|
2020-06-18 02:43:50 +10:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  x=[-1, 0, 1],
  y=[0, 32, 64, 128],
)
def testIntegerPowerOverflow(self, x, y):
  """Regression test for https://github.com/google/jax/issues/5987.

  Raises -1, 0 and 1 to large integer exponents and compares with NumPy.
  """
  def args_maker():
    return [x, y]

  self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
  self._CompileAndCheck(jnp.power, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
    for shape in all_shapes
    for axis in [None] + list(range(len(shape)))
  ],
  dtype=all_dtypes,
)
def testCompress(self, shape, dtype, axis):
  """Compares jnp.compress with np.compress for every axis (and None)."""
  rng = jtu.rand_some_zero(self.rng())

  # The condition length matches the flattened size (axis=None) or the
  # selected axis; scalar inputs get an empty condition.
  is_scalar = shape in scalar_shapes or len(shape) == 0
  if is_scalar:
    cond_shape = (0,)
  elif axis is None:
    cond_shape = (math.prod(shape),)
  else:
    cond_shape = (shape[axis],)

  def args_maker():
    return [rng(cond_shape, jnp.float32), rng(shape, dtype)]

  np_fun = partial(np.compress, axis=axis)
  jnp_fun = partial(jnp.compress, axis=axis)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
2020-05-27 18:57:00 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shape=[(2, 3)],
  dtype=int_dtypes,
  # condition entries beyond axis size must be zero.
  condition=[[1], [1, 0, 0, 0, 0, 0, 0]],
  axis=[None, 0, 1],
)
def testCompressMismatchedShapes(self, shape, dtype, condition, axis):
  """Condition lengths that differ from the selected axis length."""
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [np.array(condition), rng(shape, dtype)]

  self._CheckAgainstNumpy(partial(np.compress, axis=axis),
                          partial(jnp.compress, axis=axis),
                          args_maker)
|
2020-05-27 18:57:00 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
    for shape in array_shapes
    for axis in [None] + list(range(len(shape)))
  ],
  dtype=all_dtypes,
)
def testCompressMethod(self, shape, dtype, axis):
  """The ``Array.compress`` method should match np.compress."""
  rng = jtu.rand_some_zero(self.rng())

  is_scalar = shape in scalar_shapes or len(shape) == 0
  if is_scalar:
    cond_shape = (0,)
  elif axis is None:
    cond_shape = (math.prod(shape),)
  else:
    cond_shape = (shape[axis],)

  def args_maker():
    return [rng(cond_shape, jnp.float32), rng(shape, dtype)]

  def np_fun(condition, x):
    return np.compress(condition, x, axis=axis)

  def jnp_fun(condition, x):
    return x.compress(condition, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(base_shape=base_shape, axis=axis)
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for axis in range(-len(base_shape)+1, len(base_shape))
  ],
  arg_dtypes=[
    arg_dtypes
    for num_arrs in [3]
    for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs)
  ],
  dtype=[None] + default_dtypes,
)
def testConcatenate(self, axis, dtype, base_shape, arg_dtypes):
  """Concatenates mixed-dtype arrays, optionally with an explicit ``dtype``."""
  rng = jtu.rand_default(self.rng())
  wrapped_axis = axis % len(base_shape)
  # Each input gets a different length (3, 1, 4, ...) along the concat axis.
  sizes = itertools.cycle([3, 1, 4])
  shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
            for size, _ in zip(sizes, arg_dtypes)]

  @jtu.promote_like_jnp
  def np_fun(*args, dtype=dtype):
    dtype = dtype or args[0].dtype
    # bfloat16 inputs are upcast to float32 before calling NumPy.
    args = [x.astype(np.float32) if x.dtype == jnp.bfloat16 else x
            for x in args]
    return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe')

  def jnp_fun(*args):
    return jnp.concatenate(args, axis=axis, dtype=dtype)

  def args_maker():
    return [rng(arg_shape, arg_dtype)
            for arg_shape, arg_dtype in zip(shapes, arg_dtypes)]

  with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
    for shape in [(4, 1), (4, 3), (4, 5, 6)]
    for axis in [None] + list(range(1 - len(shape), len(shape) - 1))
  ],
  dtype=all_dtypes,
)
def testConcatenateArray(self, shape, dtype, axis):
  """Passes a single array, rather than a list, as concatenate's sequence."""
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  def np_fun(x):
    return np.concatenate(x, axis=axis)

  def jnp_fun(x):
    return jnp.concatenate(x, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2020-06-14 13:56:53 -07:00
|
|
|
def testConcatenateAxisNone(self):
  """Regression test for https://github.com/google/jax/issues/3419.

  Arrays with incompatible shapes are concatenated with axis=None; the call
  must not raise.
  """
  first = jnp.array([[1, 2], [3, 4]])
  second = jnp.array([[5]])
  jnp.concatenate((first, second), axis=None)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(base_shape=base_shape, axis=axis)
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for axis in range(-len(base_shape)+1, len(base_shape))],
  arg_dtypes=itertools.combinations_with_replacement(default_dtypes, 2)
)
def testAppend(self, axis, base_shape, arg_dtypes):
  """Compares jnp.append with np.append for mixed input dtypes."""
  rng = jtu.rand_default(self.rng())
  wrapped_axis = axis % len(base_shape)
  sizes = itertools.cycle([3, 1, 4])
  shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
            for size, _ in zip(sizes, arg_dtypes)]

  def np_fun(arr, values):
    def upcast_bf16(x):
      # bfloat16 inputs are upcast to float32 before calling NumPy.
      return x.astype(np.float32) if x.dtype == jnp.bfloat16 else x
    out = np.append(upcast_bf16(arr), upcast_bf16(values), axis=axis)
    # Cast to the dtype JAX's promotion rules give for the two inputs.
    return out.astype(jnp.promote_types(*arg_dtypes))

  def jnp_fun(arr, values):
    return jnp.append(arr, values, axis=axis)

  def args_maker():
    return [rng(arg_shape, arg_dtype)
            for arg_shape, arg_dtype in zip(shapes, arg_dtypes)]

  with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
|
2018-12-30 17:49:11 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis, idx=idx)
   for shape in nonempty_nonscalar_array_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))
   for idx in (range(-math.prod(shape), math.prod(shape))
               if axis is None else
               range(-shape[axis], shape[axis]))],
  dtype=all_dtypes,
)
def testDeleteInteger(self, shape, dtype, idx, axis):
  """jnp.delete with a single integer index matches np.delete.

  With ``axis=None`` the index addresses the flattened array; otherwise it
  addresses positions along ``axis``. Negative indices are included.
  """
  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(operand):
    return np.delete(operand, idx, axis=axis)

  def jnp_fun(operand):
    return jnp.delete(operand, idx, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_nonscalar_array_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))
  ],
  dtype=all_dtypes,
  slc=[slice(None), slice(1, 3), slice(1, 5, 2)],
)
def testDeleteSlice(self, shape, dtype, axis, slc):
  """jnp.delete with a slice object (including a stepped slice) matches NumPy."""
  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(operand):
    return np.delete(operand, slc, axis=axis)

  def jnp_fun(operand):
    return jnp.delete(operand, slc, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_nonscalar_array_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))
  ],
  dtype=all_dtypes,
  idx_shape=all_shapes,
)
def testDeleteIndexArray(self, shape, dtype, axis, idx_shape):
  """jnp.delete with a random integer index array matches np.delete.

  The indices are drawn between -max_idx and max_idx, so they may be
  negative and may repeat.
  """
  rand = jtu.rand_default(self.rng())
  # Number of addressable positions: every element when flattening,
  # otherwise the length of ``axis``.
  max_idx = math.prod(shape) if axis is None else shape[axis]
  idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int)

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(operand):
    return np.delete(operand, idx, axis=axis)

  def jnp_fun(operand):
    return jnp.delete(operand, idx, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_nonscalar_array_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))
  ],
  dtype=all_dtypes,
  idx_shape=all_shapes,
)
def testDeleteUniqueIndices(self, shape, dtype, axis, idx_shape):
  """jnp.delete(..., assume_unique_indices=True) matches NumPy on distinct indices."""
  rand = jtu.rand_default(self.rng())
  max_idx = math.prod(shape) if axis is None else shape[axis]
  # Sampling without replacement needs at least as many slots as indices.
  if math.prod(idx_shape) > max_idx:
    self.skipTest("Too many indices to be unique")

  def args_maker():
    operand = rand(shape, dtype)
    unique_idx = self.rng().choice(max_idx, idx_shape, replace=False)
    return operand, unique_idx

  np_fun = partial(np.delete, axis=axis)
  jnp_fun = partial(jnp.delete, axis=axis, assume_unique_indices=True)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_nonscalar_array_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))
  ],
  dtype=all_dtypes,
)
def testDeleteMaskArray(self, shape, dtype, axis):
  """jnp.delete with a random boolean mask matches np.delete."""
  rand = jtu.rand_default(self.rng())
  # The mask has one entry per addressable position.
  mask_size = math.prod(shape) if axis is None else shape[axis]
  mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool)
  if numpy_version == (1, 23, 0) and mask.shape == (1,):
    # https://github.com/numpy/numpy/issues/21840
    self.skipTest("test fails for numpy v1.23.0")

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(operand):
    return np.delete(operand, mask, axis=axis)

  def jnp_fun(operand):
    return jnp.delete(operand, mask, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_nonscalar_array_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))
  ],
  dtype=all_dtypes,
)
def testInsertInteger(self, shape, dtype, axis):
  """jnp.insert of a scalar at an integer index passed as an argument matches NumPy."""
  # Index range spans the flattened size, or the length of ``axis``.
  max_ind = math.prod(shape) if axis is None else shape[axis]
  rand = jtu.rand_default(self.rng())
  index_rand = jtu.rand_int(self.rng(), -max_ind, max_ind)

  def args_maker():
    return [rand(shape, dtype), index_rand((), np.int32), rand((), dtype)]

  def np_fun(arr, index, value):
    return np.insert(arr, index, value, axis=axis)

  def jnp_fun(arr, index, value):
    return jnp.insert(arr, index, value, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_nonscalar_array_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))
  ],
  dtype=all_dtypes,
)
def testInsertSlice(self, shape, dtype, axis):
  """jnp.insert with a random (possibly empty or negative) slice matches NumPy."""
  max_ind = math.prod(shape) if axis is None else shape[axis]
  rand = jtu.rand_default(self.rng())
  index_rand = jtu.rand_int(self.rng(), -max_ind, max_ind)
  # The slice bounds are drawn once at setup, so the slice is static.
  start = index_rand((), jnp.int32).item()
  stop = index_rand((), jnp.int32).item()
  slc = slice(start, stop)

  def args_maker():
    return [rand(shape, dtype), rand((), dtype)]

  def np_fun(arr, val):
    return np.insert(arr, slc, val, axis=axis)

  def jnp_fun(arr, val):
    return jnp.insert(arr, slc, val, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
|
|
|
@parameterized.parameters([
  [[[1, 1], [2, 2], [3, 3]], 1, 5, None],
  [[[1, 1], [2, 2], [3, 3]], 1, 5, 1],
  [[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1],
  [[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1],
  [[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None],
  [[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None],
  [[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None],
  [[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1]
])
def testInsertExamples(self, arr, index, values, axis):
  """Replay the examples from the np.insert docstring against jnp.insert."""
  def args_maker():
    # Slices are passed through as-is; every other index becomes an array.
    idx = index if isinstance(index, slice) else np.array(index)
    return (np.asarray(arr), idx, np.asarray(values), axis)

  self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_array_shapes
   for axis in range(-len(shape), len(shape))
  ],
  dtype=default_dtypes,
  out_dims=[0, 1, 2],
)
def testApplyAlongAxis(self, shape, dtype, axis, out_dims):
  """jnp.apply_along_axis matches NumPy for slice functions with 0-, 1- or 2-D output."""
  def func(x, out_dims):
    # Each branch returns a result with ``out_dims`` dimensions for a 1-D slice.
    if out_dims == 0:
      return x.sum(dtype=x.dtype)
    if out_dims == 1:
      return x * x[0]
    if out_dims == 2:
      return x[:, None] + x[None, :]
    raise NotImplementedError(f"{out_dims=}")

  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(arr):
    return np.apply_along_axis(func, axis, arr, out_dims=out_dims)

  def jnp_fun(arr):
    return jnp.apply_along_axis(func, axis, arr, out_dims=out_dims)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          atol={dtypes.bfloat16: 2e-2})
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axes=axes)
   for shape in nonempty_shapes
   for axes in itertools.combinations(range(len(shape)), 2)
  ],
  func=["sum"],
  keepdims=[True, False],
  # Avoid low-precision types in sum()
  dtype=[dtype for dtype in default_dtypes
         if dtype not in [np.float16, jnp.bfloat16]],
)
def testApplyOverAxes(self, shape, dtype, func, keepdims, axes):
  """jnp.apply_over_axes with an array-method reduction matches NumPy."""
  def reducer(x, axis):
    # ``func`` names an array method, e.g. "sum".
    method = getattr(x, func)
    return method(axis=axis, keepdims=keepdims, dtype=dtype)

  rand = jtu.rand_default(self.rng())

  def args_maker():
    return (rand(shape, dtype),)

  def np_fun(a):
    return np.apply_over_axes(reducer, a, axes)

  def jnp_fun(a):
    return jnp.apply_over_axes(reducer, a, axes)

  self._CompileAndCheck(jnp_fun, args_maker)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, dtype=dtype, axis=axis)
   for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
   for axis in [None] + list(range(-len(shape), max(1, len(shape))))
  ],
  repeats=[0, 1, 2],
  fixed_size=[False, True],
)
def testRepeat(self, axis, shape, dtype, repeats, fixed_size):
  """jnp.repeat with a scalar ``repeats`` matches NumPy.

  With ``fixed_size`` a static ``total_repeat_length`` is supplied, and
  ``repeats`` is also exercised as a runtime argument under compilation.
  """
  rand = jtu.rand_default(self.rng())
  np_fun = jtu.promote_like_jnp(
      lambda arg: np.repeat(arg, repeats=repeats, axis=axis))

  if not fixed_size:
    # Now repeats is in a closure, so a constant.
    const_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis)
    make_args = lambda: [rand(shape, dtype)]
    self._CheckAgainstNumpy(np_fun, const_fun, make_args)
    self._CompileAndCheck(const_fun, make_args)
    return

  # Output length along the repeated axis (axis 0 of the flattened result
  # when axis is None), computed eagerly with NumPy.
  total_repeat_length = np.repeat(np.zeros(shape), repeats, axis).shape[axis or 0]

  def traced_fun(arg, rep):
    return jnp.repeat(arg, repeats=rep, axis=axis,
                      total_repeat_length=total_repeat_length)

  def closed_fun(arg):
    return jnp.repeat(arg, repeats=repeats, axis=axis,
                      total_repeat_length=total_repeat_length)

  self._CompileAndCheck(traced_fun, lambda: [rand(shape, dtype), repeats])
  self._CheckAgainstNumpy(np_fun, closed_fun, lambda: [rand(shape, dtype)])
|
2018-12-07 11:29:03 -05:00
|
|
|
|
2020-10-05 11:28:19 +00:00
|
|
|
def testRepeatScalarFastPath(self):
  """A scalar ``repeats`` must trace to a short jaxpr (at most 6 equations)."""
  operand = jnp.array([1,2,3,4])
  traced = jax.make_jaxpr(lambda a: jnp.repeat(a, repeats=2))(operand)
  self.assertLessEqual(len(traced.jaxpr.eqns), 6)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in all_shapes
   for axis in [None] + list(range(len(shape)))],
  dtype=number_dtypes,
  return_index=[False, True],
  return_inverse=[False, True],
  return_counts=[False, True],
)
def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts):
  """jnp.unique matches NumPy for every combination of optional outputs."""
  rand = jtu.rand_some_equal(self.rng())
  args_maker = lambda: [rand(shape, dtype)]
  extra_args = (return_index, return_inverse, return_counts)
  if any(extra_args):
    # False for the unique values, then one True per requested extra output;
    # handed to jtu.with_jax_dtype_defaults.
    use_defaults = (False,) + tuple(True for flag in extra_args if flag)
  else:
    use_defaults = False

  def np_unique(x):
    return np.unique(x, *extra_args, axis=axis)

  def jnp_fun(x):
    return jnp.unique(x, *extra_args, axis=axis)

  np_fun = jtu.with_jax_dtype_defaults(np_unique, use_defaults)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
2020-04-20 21:04:32 -04:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_array_shapes
   for axis in [None] + list(range(len(shape)))],
  dtype=number_dtypes,
  size=[1, 5, 10],
  fill_value=[None, -1.0, "slice"],
)
def testUniqueSize(self, shape, dtype, axis, size, fill_value):
  """jnp.unique with a static ``size`` truncates or pads like the NumPy reference.

  The reference computes np.unique and then either truncates the per-unique
  outputs to ``size`` or pads them. Padding uses ``fill_value``; when
  ``fill_value`` is None it uses the first unique slice instead.
  """
  rand = jtu.rand_some_equal(self.rng())
  args_maker = lambda: [rand(shape, dtype)]
  kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True)

  if fill_value == "slice":
    # A concrete fill value shaped like one slice along ``axis``
    # (a scalar when the input is flattened).
    slice_shape = () if axis is None else shape[:axis] + shape[axis + 1:]
    fill_value = rand(slice_shape, dtype)

  @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True))
  def np_fun(x, fill_value=fill_value):
    u, ind, inv, counts = np.unique(x, **kwds)
    # np.unique flattens when axis is None, so the unique values lie along 0.
    ax = 0 if kwds['axis'] is None else kwds['axis']
    n_unique = u.shape[ax]

    if size <= n_unique:
      # Keep only the first ``size`` unique entries; ``inv`` is per-element.
      u = u[(slice(None),) * ax + (slice(size),)]
      return u, ind[:size], inv, counts[:size]

    # Pad every per-unique output up to ``size`` entries.
    extra = (0, size - n_unique)
    pads = [(0, 0)] * u.ndim
    pads[ax] = extra
    u = np.pad(u, pads, constant_values=0)
    selector = [slice(None)] * u.ndim
    selector[ax] = slice(1)
    if fill_value is None:
      fill_value = u[tuple(selector)]
    elif np.ndim(fill_value):
      fill_value = lax.expand_dims(fill_value, (ax,))
    selector[ax] = slice(n_unique, None)
    u[tuple(selector)] = fill_value
    ind = np.pad(ind, extra, constant_values=ind[0])
    counts = np.pad(counts, extra, constant_values=0)
    return u, ind, inv, counts

  def jnp_fun(x):
    return jnp.unique(x, size=size, fill_value=fill_value, **kwds)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(dtype=inexact_dtypes)
def testUniqueNans(self, dtype):
  """jnp.unique treats signed zeros and NaNs the same way NumPy does.

  Complex dtypes use every (real, imag) pairing of the same special values.
  """
  if numpy_version == (1, 23, 0) and dtype == np.float16:
    # https://github.com/numpy/numpy/issues/21838
    self.skipTest("Known failure on numpy 1.23.0")

  def args_maker():
    vals = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
    if np.issubdtype(dtype, np.complexfloating):
      vals = [complex(real, imag)
              for real, imag in itertools.product(vals, repeat=2)]
    return [np.array(vals, dtype=dtype)]

  kwds = dict(return_index=True, return_inverse=True, return_counts=True)
  jnp_fun = partial(jnp.unique, **kwds)

  def np_fun(x):
    in_dtype = x.dtype
    # numpy unique fails for bfloat16 NaNs, so we cast to float64
    if in_dtype == jnp.bfloat16:
      x = x.astype('float64')
    uniq, *rest = np.unique(x, **kwds)
    return (uniq.astype(in_dtype), *rest)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(fixed_size=[False, True])
def testNonScalarRepeats(self, fixed_size):
  """Port of NumPy's ``test_repeat`` cases with array-valued ``repeats``.

  Source: https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_multiarray.py
  With ``fixed_size``, ``repeats`` is passed as a runtime argument together
  with a static ``total_repeat_length``.
  """
  tol = 1e-5

  def test_single(m, args_maker, repeats, axis):
    lax_ans = jnp.repeat(m, repeats, axis)
    numpy_ans = np.repeat(m, repeats, axis)
    self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol)

    if fixed_size:
      # Expected length of the repeated axis (axis 0 of the flattened
      # output when axis is None).
      rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0]
      jnp_fun = lambda arg, rep: jnp.repeat(
          arg, repeats=rep, axis=axis, total_repeat_length=rep_length)
    else:
      jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
    self._CompileAndCheck(jnp_fun, args_maker)

  def make_args_maker(arr, repeats):
    # With fixed_size, jnp_fun takes ``repeats`` as its second argument.
    if fixed_size:
      return lambda: [arr, repeats]
    return lambda: [arr]

  m = jnp.array([1,2,3,4,5,6])
  for repeats in [2, jnp.array([1,3,0,1,1,2]), jnp.array([1,3,2,1,1,2]), jnp.array([2])]:
    flat_args = make_args_maker(m, repeats)
    test_single(m, flat_args, repeats, axis=None)
    test_single(m, flat_args, repeats, axis=0)

  m_rect = m.reshape((2,3))
  for repeats in [2, jnp.array([2,1]), jnp.array([2])]:
    test_single(m_rect, make_args_maker(m_rect, repeats), repeats, axis=0)

  for repeats in [2, jnp.array([1,3,2]), jnp.array([2])]:
    test_single(m_rect, make_args_maker(m_rect, repeats), repeats, axis=1)
|
2019-09-14 21:24:28 +01:00
|
|
|
|
2020-02-29 00:06:38 +00:00
|
|
|
def testIssue2330(self):
  """jnp.concatenate returns a JAX array and is side-effect safe.

  Augmented assignment on the concatenated result must not modify the
  caller's original NumPy or JAX input.
  """
  def attempt_sideeffect(x):
    joined = jnp.concatenate([x])
    joined -= 1.
    return joined

  np_input = np.ones(1)
  jnp_input = jnp.ones(1)
  expected_np_input_after_call = np.ones(1)
  expected_jnp_input_after_call = jnp.ones(1)

  result = jnp.concatenate([np_input])
  self.assertIs(type(result), array.ArrayImpl)

  for operand in (np_input, jnp_input):
    attempt_sideeffect(operand)

  self.assertAllClose(np_input, expected_np_input_after_call)
  self.assertAllClose(jnp_input, expected_jnp_input_after_call)
|
2020-02-29 00:06:38 +00:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  mode=['full', 'same', 'valid'],
  op=['convolve', 'correlate'],
  dtype=number_dtypes,
  xshape=one_dim_array_shapes,
  yshape=one_dim_array_shapes,
)
def testConvolutions(self, xshape, yshape, dtype, mode, op):
  """jnp.convolve / jnp.correlate match NumPy, compared in the inexact dtype."""
  jnp_op, np_op = getattr(jnp, op), getattr(np, op)
  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(xshape, dtype), rand(yshape, dtype)]

  # Request the highest matmul precision on TPU; presumably so results stay
  # within the tolerances below.
  precision = lax.Precision.HIGHEST if jtu.device_under_test() == "tpu" else None
  jnp_fun = partial(jnp_op, mode=mode, precision=precision)
  result_dtype = dtypes.to_inexact_dtype(dtype)

  def np_fun(x, y):
    return np_op(x, y, mode=mode).astype(result_dtype)

  tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
         np.complex128: 1e-14}
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
|
|
|
@jtu.sample_product(
  mode=['full', 'same', 'valid'],
  op=['convolve', 'correlate'],
  dtype=number_dtypes,
  xshape=one_dim_array_shapes,
  yshape=one_dim_array_shapes,
)
@jtu.skip_on_devices("gpu", "tpu", "rocm")  # backends don't support all dtypes.
def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op):
  """With preferred_element_type=dtype, convolve/correlate keep the input dtype."""
  jnp_op, np_op = getattr(jnp, op), getattr(np, op)
  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(xshape, dtype), rand(yshape, dtype)]

  precision = lax.Precision.HIGHEST if jtu.device_under_test() == "tpu" else None
  jnp_fun = partial(jnp_op, mode=mode, precision=precision,
                    preferred_element_type=dtype)

  def np_fun(x, y):
    # The reference result is cast back to the input dtype.
    return np_op(x, y, mode=mode).astype(dtype)

  tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14,
         np.complex128: 1e-14}
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-04-09 22:50:10 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in all_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))],
  op=["cumsum", "cumprod"],
  dtype=all_dtypes,
  out_dtype=[dtype for dtype in default_dtypes if dtype != np.float16],
)
def testCumSumProd(self, axis, shape, dtype, out_dtype, op):
  """jnp.cumsum / jnp.cumprod with an explicit output dtype match NumPy."""
  jnp_op = getattr(jnp, op)
  np_op = getattr(np, op)
  rand = jtu.rand_default(self.rng())

  # Suppress complex-cast and overflow warnings emitted by the NumPy side.
  np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype)
  np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
  np_fun = jtu.ignore_warning(category=RuntimeWarning,
                              message="overflow encountered.*")(np_fun)
  jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(
      lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype))

  def args_maker():
    return [rand(shape, dtype)]

  # Use the looser of the input-dtype and output-dtype tolerances.
  thresholds = {dtypes.bfloat16: 4e-2}
  tol = max(jtu.tolerance(dtype, thresholds),
            jtu.tolerance(out_dtype, thresholds))
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2019-01-31 18:56:06 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in all_shapes
   for axis in [None] + list(range(-len(shape), len(shape)))],
  op=["nancumsum", "nancumprod"],
  dtype=all_dtypes,
  out_dtype=default_dtypes,
)
def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op):
  """jnp.nancumsum / jnp.nancumprod on NaN-containing input match NumPy."""
  jnp_op = getattr(jnp, op)
  np_op = getattr(np, op)
  rand = jtu.rand_some_nan(self.rng())

  np_fun = partial(np_op, axis=axis, dtype=out_dtype)
  np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
  np_fun = jtu.ignore_warning(category=RuntimeWarning,
                              message="overflow encountered.*")(np_fun)
  jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(
      partial(jnp_op, axis=axis, dtype=out_dtype))

  def args_maker():
    return [rand(shape, dtype)]

  thresholds = {dtypes.bfloat16: 4e-2, np.float16: 3e-3}
  tol = max(jtu.tolerance(dtype, thresholds),
            jtu.tolerance(out_dtype, thresholds))
  # numpy functions do not properly handle bfloat16, so only compare
  # against NumPy for other input dtypes.
  if dtype != jnp.bfloat16:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
                            tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
   for yshape, xshape, dx, axis in [
     ((10,), None, 1.0, -1),
     ((3, 10), None, 2.0, -1),
     ((3, 10), None, 3.0, -0),
     ((10, 3), (10,), 1.0, -2),
     ((3, 10), (10,), 1.0, -1),
     ((3, 10), (3, 10), 1.0, -1),
     ((2, 3, 10), (3, 10), 1.0, -2),
   ]
  ],
  dtype=default_dtypes,
)
@jtu.skip_on_devices("tpu")  # TODO(jakevdp): fix and reenable this test.
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testTrapz(self, yshape, xshape, dtype, dx, axis):
  """jnp.trapz matches np.trapz with uniform spacing ``dx`` or sample points ``x``."""
  rand = jtu.rand_default(self.rng())

  def args_maker():
    y = rand(yshape, dtype)
    # Without sample points, integration falls back to the spacing ``dx``.
    x = None if xshape is None else rand(xshape, dtype)
    return [y, x]

  np_fun = partial(np.trapz, dx=dx, axis=axis)
  jnp_fun = partial(jnp.trapz, dx=dx, axis=axis)
  tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
                              dtypes.bfloat16: 4e-2})
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
                          check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
                        check_dtypes=False)
|
2020-05-12 14:12:03 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=default_dtypes,
  n=[0, 4],
  m=[None, 0, 1, 3, 4],
  k=list(range(-4, 4)),
)
def testTri(self, m, n, k, dtype):
  """jnp.tri matches np.tri over row/column counts and diagonal offsets."""
  def np_fun():
    return np.tri(n, M=m, k=k, dtype=dtype)

  def jnp_fun():
    return jnp.tri(n, M=m, k=k, dtype=dtype)

  def args_maker():
    return []

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-12-12 12:05:49 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=default_dtypes,
  shape=[shape for shape in all_shapes if len(shape) >= 2],
  op=["tril", "triu"],
  k=list(range(-3, 3)),
)
def testTriLU(self, dtype, shape, op, k):
  """jnp.tril / jnp.triu match NumPy for arrays of rank >= 2."""
  rand = jtu.rand_default(self.rng())
  np_op, jnp_op = getattr(np, op), getattr(jnp, op)

  def np_fun(arg):
    return np_op(arg, k=k)

  def jnp_fun(arg):
    return jnp_op(arg, k=k)

  def args_maker():
    return [rand(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-12-12 12:05:49 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  n=range(5),
  k=range(-3, 3),
  m=[None, *range(5)],
)
def testTrilIndices(self, n, k, m):
  """jnp.tril_indices matches NumPy; n, k and m are passed as arguments."""
  def np_fun(rows, offset, cols):
    return np.tril_indices(rows, k=offset, m=cols)

  def jnp_fun(rows, offset, cols):
    return jnp.tril_indices(rows, k=offset, m=cols)

  self._CheckAgainstNumpy(np_fun, jnp_fun, lambda: [n, k, m])
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    n=range(5),
    k=range(-3, 3),
    m=[None, *range(5)],
)
def testTriuIndices(self, n, k, m):
  """Compares jnp.triu_indices with np.triu_indices (eager comparison only)."""
  def np_fun(rows, offset, cols):
    return np.triu_indices(rows, k=offset, m=cols)

  def jnp_fun(rows, offset, cols):
    return jnp.triu_indices(rows, k=offset, m=cols)

  # The sampled n, k, m are fed in positionally as the "arguments".
  self._CheckAgainstNumpy(np_fun, jnp_fun, lambda: [n, k, m])
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=default_dtypes,
    shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)],
    k=[-1, 0, 1],
)
def testTriuIndicesFrom(self, shape, dtype, k):
  """Compares jnp.triu_indices_from with np.triu_indices_from."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arr, offset):
    return np.triu_indices_from(arr, k=offset)

  def jnp_fun(arr, offset):
    return jnp.triu_indices_from(arr, k=offset)

  def args_maker():
    # The diagonal offset travels as a positional argument next to the array.
    return [rng(shape, dtype), k]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=default_dtypes,
    shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)],
    k=[-1, 0, 1],
)
def testTrilIndicesFrom(self, shape, dtype, k):
  """Compares jnp.tril_indices_from with np.tril_indices_from."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arr, offset):
    return np.tril_indices_from(arr, k=offset)

  def jnp_fun(arr, offset):
    return jnp.tril_indices_from(arr, k=offset)

  def args_maker():
    # The diagonal offset travels as a positional argument next to the array.
    return [rng(shape, dtype), k]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    ndim=[0, 1, 4],
    n=[0, 1, 7],
)
def testDiagIndices(self, ndim, n):
  """jnp.diag_indices(n, ndim) equals the (JAX-dtype-defaulted) NumPy result."""
  expected = jtu.with_jax_dtype_defaults(np.diag_indices)(n, ndim)
  actual = jnp.diag_indices(n, ndim)
  np.testing.assert_equal(expected, actual)
|
2019-12-20 16:25:15 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=default_dtypes,
    shape=[(1,1), (2,2), (3,3), (4,4), (5,5)],
)
def testDiagIndicesFrom(self, dtype, shape):
  """Compares jnp.diag_indices_from with np.diag_indices_from on square inputs."""
  rng = jtu.rand_default(self.rng())
  # The NumPy reference is wrapped with jtu.with_jax_dtype_defaults
  # (presumably to align its default dtypes with JAX's).
  reference = jtu.with_jax_dtype_defaults(np.diag_indices_from)
  under_test = jnp.diag_indices_from

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(reference, under_test, args_maker)
  self._CompileAndCheck(under_test, args_maker)
|
2019-12-20 16:25:15 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=default_dtypes,
    shape=[shape for shape in all_shapes if len(shape) in (1, 2)],
    k=list(range(-4, 4)),
)
def testDiag(self, shape, dtype, k):
  """Compares jnp.diag with np.diag for 1-D and 2-D inputs at offset k."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arg):
    return np.diag(arg, k)

  def jnp_fun(arg):
    return jnp.diag(arg, k)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-12-12 17:54:27 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=default_dtypes,
    shape=all_shapes,
    k=list(range(-4, 4)),
)
def testDiagFlat(self, shape, dtype, k):
  """Compares jnp.diagflat with np.diagflat, including scalar inputs."""
  rng = jtu.rand_default(self.rng())

  def np_fun(arg):
    # numpy has inconsistencies for scalar values
    # (https://github.com/numpy/numpy/issues/16477); jax treats a scalar as a
    # length-1 array, so promote the NumPy input the same way.
    return np.diagflat(np.atleast_1d(arg), k)

  def jnp_fun(arg):
    return jnp.diagflat(arg, k)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=default_dtypes,
    a1_shape=one_dim_array_shapes,
    a2_shape=one_dim_array_shapes,
)
def testPolyMul(self, a1_shape, a2_shape, dtype):
  """Compares jnp.polymul with np.polymul.

  The eager NumPy comparison uses ``trim_leading_zeros=True``; the compiled
  check uses the default call (no trimming).
  """
  rng = jtu.rand_default(self.rng())

  def np_reference(arg1, arg2):
    return np.polymul(arg1, arg2)

  def jnp_trimmed(arg1, arg2):
    return jnp.polymul(arg1, arg2, trim_leading_zeros=True)

  def jnp_untrimmed(arg1, arg2):
    return jnp.polymul(arg1, arg2)

  def args_maker():
    return [rng(a1_shape, dtype), rng(a2_shape, dtype)]

  tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13}
  self._CheckAgainstNumpy(np_reference, jnp_trimmed, args_maker,
                          check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_untrimmed, args_maker, check_dtypes=False)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=[dtype for dtype in default_dtypes
           if dtype not in (np.float16, jnp.bfloat16)],
    a_shape=one_dim_array_shapes,
    b_shape=one_dim_array_shapes,
)
def testPolyDiv(self, a_shape, b_shape, dtype):
  """Compares jnp.polydiv (quotient and residual) against np.polydiv."""
  rng = jtu.rand_default(self.rng())

  def left_pad_to(residual, width, pad):
    # Prepend zeros one at a time until the residual has `width` elements, so
    # the two libraries' residuals have a common length.
    while residual.size < width:
      residual = pad(residual, (1, 0), 'constant')
    return residual

  @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero.*")
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
  @jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")
  def np_fun(arg1, arg2):
    quotient, residual = np.polydiv(arg1, arg2)
    return quotient, left_pad_to(residual, max(arg1.size, arg2.size), np.pad)

  def jnp_fun(arg1, arg2):
    quotient, residual = jnp.polydiv(arg1, arg2, trim_leading_zeros=True)
    return quotient, left_pad_to(residual, max(arg1.size, arg2.size), jnp.pad)

  def args_maker():
    return [rng(a_shape, dtype), rng(b_shape, dtype)]

  tol = {
      dtypes.bfloat16: 2e-1,
      np.float16: 2e-1,
      np.float32: 5e-2,
      np.float64: 5e-7
  }

  # Compile plain polydiv: with trim_leading_zeros (trim_zeros) it is unable
  # to be compiled by XLA.
  jnp_compile = jnp.polydiv
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, axis1=axis1, axis2=axis2)
     for shape in [shape for shape in all_shapes if len(shape) >= 2]
     for axis1 in range(-len(shape), len(shape))
     for axis2 in [a for a in range(-len(shape), len(shape))
                   if a % len(shape) != axis1 % len(shape)]
    ],
    dtype=default_dtypes,
    offset=list(range(-4, 4)),
)
def testDiagonal(self, shape, dtype, offset, axis1, axis2):
  """Compares jnp.diagonal with np.diagonal over distinct axis pairs and offsets."""
  rng = jtu.rand_default(self.rng())
  diagonal_args = (offset, axis1, axis2)

  def np_fun(arg):
    return np.diagonal(arg, *diagonal_args)

  def jnp_fun(arg):
    return jnp.diagonal(arg, *diagonal_args)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-12-12 17:54:27 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=default_dtypes,
    n=list(range(4)),
)
def testIdentity(self, n, dtype):
  """Compares jnp.identity(n, dtype) with np.identity(n, dtype)."""
  make_np = partial(np.identity, n, dtype)
  make_jnp = partial(jnp.identity, n, dtype)
  no_args = lambda: []
  self._CheckAgainstNumpy(make_np, make_jnp, no_args)
  self._CompileAndCheck(make_jnp, no_args)
|
2018-12-20 15:36:37 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    shape=nonempty_shapes,
    period=[None, 0.59],
    left=[None, 0],
    right=[None, 1],
    # Note: skip 8-bit and 16-bit types due to insufficient precision.
    dtype=jtu.dtypes.integer + jtu.dtypes.floating,
    target_dtype=jtu.dtypes.inexact,
)
def testInterp(self, shape, dtype, period, left, right, target_dtype):
  """Compares jnp.interp with np.interp for period/left/right variants."""
  rng = jtu.rand_default(self.rng(), scale=10)
  interp_kwargs = {'period': period, 'left': left, 'right': right}
  np_fun = partial(np.interp, **interp_kwargs)
  jnp_fun = partial(jnp.interp, **interp_kwargs)

  def args_maker():
    x = rng(shape, dtype)
    # Sample points: np.unique of 100 draws (sorted, de-duplicated), first 20.
    xp = np.unique(rng((100,), dtype))[:20]
    fp = rng((20,), target_dtype)
    return [x, xp, fp]

  with jtu.strict_promotion_if_dtypes_match([dtype, target_dtype]):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                            rtol=3e-3, atol=1e-3)
    self._CompileAndCheck(jnp_fun, args_maker)
|
2020-08-04 12:39:04 -07:00
|
|
|
|
2023-03-19 17:21:32 -07:00
|
|
|
@jtu.sample_product([
    dict(x=0.5, left='extrapolate', expected=5),
    dict(x=1.5, left='extrapolate', expected=15),
    dict(x=3.5, left='extrapolate', expected=30),
    dict(x=3.9, right='extrapolate', expected=39),
])
def testInterpExtrapoate(self, x, expected, **kwargs):
  """Checks jnp.interp's 'extrapolate' mode on a fixed linear table (y = 10 * x)."""
  # kwargs holds either left='extrapolate' or right='extrapolate'.
  sample_x = jnp.array([1.0, 2.0, 3.0])
  sample_y = jnp.array([10.0, 20.0, 30.0])
  result = jnp.interp(x, sample_x, sample_y, **kwargs)
  self.assertAlmostEqual(result, expected)
|
|
|
|
|
|
|
|
def testInterpErrors(self):
  """jnp.interp rejects malformed inputs with ValueError and specific messages."""
  literal_cases = [
      ('xp and fp must be one-dimensional arrays of equal size',
       lambda: jnp.interp(0.0, jnp.arange(2.0), jnp.arange(3.0))),
      ("the only valid string value of `left` is 'extrapolate', but got: 'interpolate'",
       lambda: jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), left='interpolate')),
      ("the only valid string value of `right` is 'extrapolate', but got: 'interpolate'",
       lambda: jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), right='interpolate')),
      ("jnp.interp: complex x values not supported.",
       lambda: jnp.interp(1j, 1j * np.arange(3.0), np.arange(3.0))),
  ]
  for message, call in literal_cases:
    with self.assertRaisesWithLiteralMatch(ValueError, message):
      call()

  # This message embeds the offending value, so only its prefix is matched.
  with self.assertRaisesRegex(ValueError, "period must be a scalar; got"):
    jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), period=np.array([1.0]))
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    period=[None, 0.59],
    left=[None, 0],
    right=[None, 1],
    dtype=jtu.dtypes.floating,
)
def testInterpGradNan(self, dtype, period, left, right):
  """The gradient of jnp.interp stays finite for tiny, tightly spaced sample points."""
  interp = partial(jnp.interp, period=period, left=left, right=right)
  # Probe values of x and xp that are close to zero and close together.
  x = dtype(np.exp(np.linspace(-90, -20, 1000)))

  def loss(z):
    return jnp.sum(interp(z, z, jnp.ones_like(z)))

  grad = jax.grad(loss)(x)
  np.testing.assert_equal(np.all(np.isfinite(grad)), True)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(x1_shape=x1_shape, x2_shape=x2_shape)
     for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible,
                                      itertools.combinations_with_replacement(array_shapes, 2))
    ],
    x1_rng_factory=[jtu.rand_some_inf_and_nan, jtu.rand_some_zero],
    x2_rng_factory=[partial(jtu.rand_int, low=-1075, high=1024)],
    x1_dtype=default_dtypes,
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory):
  """Compares jnp.ldexp with np.ldexp on broadcast-compatible shape pairs."""
  x1_rng = x1_rng_factory(self.rng())
  x2_rng = x2_rng_factory(self.rng())

  @jtu.ignore_warning(category=RuntimeWarning, message="overflow.*")
  def np_fun(x1, x2):
    # x1 is cast to the inexact counterpart of its dtype before np.ldexp.
    return np.ldexp(x1.astype(dtypes.to_inexact_dtype(x1.dtype)), x2)

  def args_maker():
    return [x1_rng(x1_shape, x1_dtype), x2_rng(x2_shape, np.int32)]

  self._CheckAgainstNumpy(np_fun, jnp.ldexp, args_maker)
  self._CompileAndCheck(jnp.ldexp, args_maker)
|
2020-04-01 12:29:48 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    rng_factory=[
        jtu.rand_some_inf_and_nan,
        jtu.rand_some_zero,
        partial(jtu.rand_not_small, offset=1e8),
    ],
    shape=all_shapes,
    dtype=default_dtypes,
)
def testFrexp(self, shape, dtype, rng_factory):
  """Compares jnp.frexp with np.frexp (mantissa and exponent)."""
  # integer types are converted to float64 in numpy's implementation
  is_narrow_float = dtype in [jnp.bfloat16, np.float16, np.float32]
  if not is_narrow_float and not config.x64_enabled:
    self.skipTest("Only run float64 testcase when float64 is enabled.")
  rng = rng_factory(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  def np_frexp(x):
    mantissa, exponent = np.frexp(x)
    # NumPy is inconsistent between Windows and Linux/Mac on the exponent it
    # reports for infinite inputs; normalize to the Linux behavior (zero).
    exponent = np.where(np.isinf(mantissa), np.zeros_like(exponent), exponent)
    return mantissa, exponent

  self._CheckAgainstNumpy(np_frexp, jnp.frexp, args_maker,
                          check_dtypes=np.issubdtype(dtype, np.inexact))
  self._CompileAndCheck(jnp.frexp, args_maker)
|
2020-04-01 12:29:48 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, axis1=axis1, axis2=axis2)
     for shape in [shape for shape in all_shapes if len(shape) >= 2]
     for axis1 in range(-len(shape), len(shape))
     for axis2 in range(-len(shape), len(shape))
     if (axis1 % len(shape)) != (axis2 % len(shape))
    ],
    dtype=default_dtypes,
    out_dtype=[None] + number_dtypes,
    offset=list(range(-4, 4)),
)
def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2):
  """Compares jnp.trace with np.trace over axis pairs, offsets and output dtypes."""
  if out_dtype == np.uint16:
    raise unittest.SkipTest("TPU compiler crashes (Google bug b/258450318)")
  rng = jtu.rand_default(self.rng())
  trace_args = (offset, axis1, axis2)

  def np_fun(arg):
    if out_dtype == jnp.bfloat16:
      # bfloat16 output is emulated: trace in float32, then cast the result.
      return np.trace(arg, *trace_args, np.float32).astype(jnp.bfloat16)
    return np.trace(arg, *trace_args, out_dtype)

  def jnp_fun(arg):
    return jnp.trace(arg, *trace_args, out_dtype)

  def args_maker():
    return [rng(shape, dtype)]

  # TODO: the NumPy comparison fails with uint8/uint16 output dtypes (integer
  # overflow?); uint32 outputs are excluded from it as well.
  if out_dtype not in (np.uint8, np.uint16, np.uint32):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-12-13 08:44:27 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    ashape=[(15,), (16,), (17,)],
    vshape=[(), (5,), (5, 5)],
    side=['left', 'right'],
    dtype=number_dtypes,
    method=['sort', 'scan', 'compare_all'],
)
def testSearchsorted(self, ashape, vshape, side, dtype, method):
  """Compares every jnp.searchsorted `method` with np.searchsorted."""
  rng = jtu.rand_default(self.rng())

  def args_maker():
    # The searched array must be sorted; the query values need not be.
    return [np.sort(rng(ashape, dtype)), rng(vshape, dtype)]

  def np_fun(a, v):
    # Cast NumPy's indices to int32 so dtypes match jnp.searchsorted here.
    return np.searchsorted(a, v, side=side).astype('int32')

  jnp_fun = partial(jnp.searchsorted, side=side, method=method)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-05-06 16:58:09 -07:00
|
|
|
|
2023-06-16 11:02:17 -04:00
|
|
|
@unittest.skipIf(
    platform.system() == "Windows",
    "Under Windows, NumPy throws if 2**32 is converted to an int32"
)
def testSearchsortedDtype(self):
  """jnp.searchsorted returns int32 indices up to int32-max length, int64 beyond (with x64).

  Uses abstract evaluation (jax.eval_shape on core.ShapedArray) so no large
  array is ever allocated.
  """
  int32_max = np.iinfo(np.int32).max
  small_haystack = core.ShapedArray((int32_max,), np.float32)
  big_haystack = core.ShapedArray((int32_max + 1,), np.float32)
  needle = core.ShapedArray((), np.float32)

  small_out = jax.eval_shape(jnp.searchsorted, small_haystack, needle)
  self.assertEqual(small_out.dtype, np.int32)

  if not config.x64_enabled:
    # Without x64, requesting int64 indices is expected to warn.
    with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"):
      with jtu.ignore_warning(category=DeprecationWarning,
                              message="NumPy will stop allowing conversion.*"):
        jax.eval_shape(jnp.searchsorted, big_haystack, needle)
    return

  big_out = jax.eval_shape(jnp.searchsorted, big_haystack, needle)
  self.assertEqual(big_out.dtype, np.int64)
|
2022-05-31 13:51:01 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=inexact_dtypes,
    side=['left', 'right'],
    method=['sort', 'scan', 'compare_all'],
)
def testSearchsortedNans(self, dtype, side, method):
  """jnp.searchsorted ignores the sign bit of zeros and NaNs, eagerly and under jit."""
  if np.issubdtype(dtype, np.complexfloating):
    raise SkipTest("Known failure for complex inputs; see #9107")
  x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype)
  # The sign bit should not matter for 0.0 or NaN, so searching within x should
  # behave exactly like searching within these rank-equivalent integers:
  x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5])

  # NOTE: currently unreachable -- complex dtypes are skipped above (#9107).
  if jnp.issubdtype(dtype, jnp.complexfloating):
    x = np.array([complex(r, c) for r, c in itertools.product(x, repeat=2)])
    x_equiv = np.array([complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)])

  fun = partial(jnp.searchsorted, side=side, method=method)
  self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv))
  self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv))
|
2022-01-06 09:19:28 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    xshape=[(20,), (5, 4)],
    binshape=[(1,), (5,)],
    right=[True, False],
    reverse=[True, False],
    dtype=default_dtypes,
)
def testDigitize(self, xshape, binshape, right, reverse, dtype):
  """Compares jnp.digitize with np.digitize for ascending and descending bins."""
  order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:]
  rng = jtu.rand_default(self.rng())

  def args_maker():
    x = rng(xshape, dtype)
    bins = jnp.sort(rng(binshape, dtype))[order]
    return [x, bins]

  def np_fun(x, bins):
    return np.digitize(x, bins, right=right).astype('int32')

  def jnp_fun(x, bins):
    return jnp.digitize(x, bins, right=right)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-05-09 05:36:09 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtypes=[
        [np.float32],
        [np.float32, np.float32],
        [np.float32, np.int32, np.float32],
        [np.float32, np.int64, np.float32],
        [np.float32, np.int32, np.float64],
    ],
    shape=[(), (2,), (3, 4), (1, 5)],
    array_input=[True, False],
)
def testColumnStack(self, shape, dtypes, array_input):
  """Compares jnp.column_stack with np.column_stack (JAX-style promotion).

  Inputs are passed either as one stacked ndarray (array_input=True) or as a
  Python list of arrays with possibly mixed dtypes.
  """
  rng = jtu.rand_default(self.rng())
  if array_input:
    args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
  else:
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
  np_fun = jtu.promote_like_jnp(np.column_stack)
  jnp_fun = jnp.column_stack
  with jtu.strict_promotion_if_dtypes_match(dtypes):
    # NumPy reference first, JAX implementation second, as in the other tests.
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
|
2020-06-04 16:34:16 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in [(), (2,), (3, 4), (1, 100)]
     for axis in range(-len(shape), len(shape) + 1)
    ],
    dtypes=[
        [np.float32],
        [np.float32, np.float32],
        [np.float32, np.int32, np.float32],
        [np.float32, np.int64, np.float32],
        [np.float32, np.int32, np.float64],
    ],
    array_input=[True, False],
    out_dtype=[np.float32, np.int32],
)
def testStack(self, shape, axis, dtypes, array_input, out_dtype):
  """Compares jnp.stack(..., dtype=out_dtype) with np.stack along every axis."""
  rng = jtu.rand_default(self.rng())
  if array_input:
    args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
  else:
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]

  if numpy_version < (1, 24):
    # np.stack has no dtype/casting keywords before NumPy 1.24: cast afterwards.
    np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype))
  else:
    np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))

  jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
  with jtu.strict_promotion_if_dtypes_match(dtypes):
    # NumPy reference first, JAX implementation second, as in the other tests.
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
|
2019-02-06 08:40:43 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    op=["hstack", "vstack", "dstack"],
    dtypes=[
        [np.float32],
        [np.float32, np.float32],
        [np.float32, np.int32, np.float32],
        [np.float32, np.int64, np.float32],
        [np.float32, np.int32, np.float64],
    ],
    shape=[(), (2,), (3, 4), (1, 100), (2, 3, 4)],
    array_input=[True, False],
    out_dtype=[np.float32, np.int32],
)
def testHVDStack(self, shape, op, dtypes, array_input, out_dtype):
  """Compares jnp.{h,v,d}stack(..., dtype=out_dtype) with the NumPy equivalents."""
  rng = jtu.rand_default(self.rng())
  if array_input:
    args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
  else:
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]

  if numpy_version < (1, 24) or op == "dstack":
    # No dtype/casting keywords on this NumPy function here: cast afterwards.
    np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
  else:
    np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype,
                     casting='unsafe')

  jnp_fun = partial(getattr(jnp, op), dtype=out_dtype)
  with jtu.strict_promotion_if_dtypes_match(dtypes):
    # NumPy reference first, JAX implementation second, as in the other tests.
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)
|
2019-02-06 08:40:43 -05:00
|
|
|
|
2022-10-27 15:08:16 -07:00
|
|
|
@jtu.sample_product(
    [dict(name=name, **kwds)
     for name in ['blackman', 'bartlett', 'hamming', 'hanning', 'kaiser']
     for kwds in ([dict(beta=1), dict(beta=0.5)] if name == 'kaiser' else [{}])
    ],
    size=[0, 1, 5, 10],
)
def testWindowFunction(self, name, size, **kwds):
  """Compares jnp window functions with NumPy's; `kwds` carries kaiser's `beta`."""
  jnp_fun = partial(getattr(jnp, name), size, **kwds)
  np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds))
  args_maker = lambda: []
  # NumPy reference first, JAX implementation second, as in the other tests.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, fill_value_shape=fill_value_shape)
     for shape in array_shapes + [3, np.array(7, dtype=np.int32)]
     for fill_value_shape in _compatible_shapes(shape)],
    fill_value_dtype=default_dtypes,
    out_dtype=[None] + default_dtypes,
)
def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype):
  """Compares jnp.full with np.full, including int and 0-d array `shape` values."""
  rng = jtu.rand_default(self.rng())

  def np_fun(fill_value):
    return np.full(shape, fill_value, dtype=out_dtype)

  def jnp_fun(fill_value):
    return jnp.full(shape, fill_value, dtype=out_dtype)

  def args_maker():
    return [rng(fill_value_shape, fill_value_dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, dtype=dtype, axis=axis)
     for shape, dtype in _shape_and_dtypes(nonempty_nonscalar_array_shapes, default_dtypes)
     for axis in list(range(-len(shape), max(1, len(shape))))
    ],
    prepend=[None, 1, 0],
    append=[None, 1, 0],
    n=[0, 1, 2],
)
def testDiff(self, shape, dtype, n, axis, prepend, append):
  """Compares jnp.diff with np.diff, with scalar/array/absent prepend and append."""
  # A sampled 0 means "an all-zeros array of the input's shape and dtype".
  prepend = np.zeros(shape, dtype=dtype) if prepend == 0 else prepend
  append = np.zeros(shape, dtype=dtype) if append == 0 else append
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  def to_numpy_arg(value):
    # None becomes np._NoValue (np.diff's "not passed" default); bfloat16
    # arrays are upcast to float32 for the NumPy reference.
    if value is None:
      return np._NoValue
    if not np.isscalar(value) and value.dtype == jnp.bfloat16:
      return value.astype(np.float32)
    return value

  def np_fun(x, n=n, axis=axis, prepend=prepend, append=append):
    prepend = to_numpy_arg(prepend)
    append = to_numpy_arg(append)
    if x.dtype != jnp.bfloat16:
      return np.diff(x, n=n, axis=axis, prepend=prepend, append=append)
    upcast = np.diff(x.astype(np.float32), n=n, axis=axis, prepend=prepend, append=append)
    return upcast.astype(jnp.bfloat16)

  def jnp_fun(x):
    return jnp.diff(x, n=n, axis=axis, prepend=prepend, append=append)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    op=["zeros", "ones"],
    shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32),
           np.array(4, dtype=np.int32)],
    dtype=all_dtypes,
)
def testZerosOnes(self, op, shape, dtype):
  """Compares jnp.zeros/jnp.ones with NumPy for int, tuple and array shapes."""
  np_op = partial(getattr(np, op), shape, dtype)
  jnp_op = partial(getattr(jnp, op), shape, dtype)
  no_args = lambda: []
  self._CheckAgainstNumpy(np_op, jnp_op, no_args)
  self._CompileAndCheck(jnp_op, no_args)
|
2019-12-17 17:20:51 -05:00
|
|
|
|
2020-03-11 09:57:04 -04:00
|
|
|
def testOnesWithInvalidShape(self):
  """A negative dimension makes jnp.ones raise TypeError."""
  self.assertRaises(TypeError, jnp.ones, (-1, 1))
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, out_shape=out_shape, fill_value_shape=fill_value_shape)
     for shape in array_shapes
     for out_shape in [None] + array_shapes
     for fill_value_shape in _compatible_shapes(shape if out_shape is None else out_shape)
    ],
    in_dtype=default_dtypes,
    fill_value_dtype=default_dtypes,
    out_dtype=default_dtypes,
)
def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape):
  """Compares jnp.full_like with np.full_like, with dtype and optional shape overrides."""
  rng = jtu.rand_default(self.rng())
  like_kwargs = dict(dtype=out_dtype, shape=out_shape)

  def np_fun(x, fill_value):
    return np.full_like(x, fill_value, **like_kwargs)

  def jnp_fun(x, fill_value):
    return jnp.full_like(x, fill_value, **like_kwargs)

  def args_maker():
    return [rng(shape, in_dtype), rng(fill_value_shape, fill_value_dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-12-20 10:36:32 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    shape=array_shapes,
    out_shape=[None] + array_shapes,
    in_dtype=default_dtypes,
    func=["ones_like", "zeros_like"],
    out_dtype=default_dtypes,
)
def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
  rng = jtu.rand_default(self.rng())
  # The same keyword overrides are forwarded to both implementations.
  overrides = dict(dtype=out_dtype, shape=out_shape)
  np_fun = lambda x: getattr(np, func)(x, **overrides)
  jnp_fun = lambda x: getattr(jnp, func)(x, **overrides)
  args_maker = lambda: [rng(shape, in_dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2023-06-06 00:59:51 -07:00
|
|
|
def testDuckTypedLike(self):
  # The *_like constructors must accept a jax.ShapeDtypeStruct in place of an
  # array and agree with the plain constructor for that shape and dtype.
  x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32"))
  shape, dtype = x.shape, x.dtype
  self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(shape, dtype))
  self.assertArraysEqual(jnp.ones_like(x), jnp.ones(shape, dtype))
  self.assertArraysEqual(jnp.empty_like(x), jnp.empty(shape, dtype))
  self.assertArraysEqual(jnp.full_like(x, 2), jnp.full(shape, 2, dtype))
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(func=func, args=args)
     for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())]
    ],
    shape=array_shapes,
    in_dtype=[np.int32, np.float32, np.complex64],
    weak_type=[True, False],
    out_shape=[None, (), (10,)],
    out_dtype=[None, float],
)
def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype):
  rng = jtu.rand_default(self.rng())
  x = lax_internal._convert_element_type(rng(shape, in_dtype),
                                         weak_type=weak_type)
  like_fn = getattr(jnp, func)
  fun = lambda x: like_fn(x, *args, dtype=out_dtype, shape=out_shape)
  # The result inherits weak typing from the input only when no explicit
  # output dtype is requested.
  expected_weak_type = weak_type and out_dtype is None
  # Check eager execution first, then the jit-compiled version.
  for f in (fun, jax.jit(fun)):
    self.assertEqual(dtypes.is_weakly_typed(f(x)), expected_weak_type)
|
2021-02-08 13:37:25 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    funcname=["array", "asarray"],
    dtype=[int, float, None],
    val=[0, 1],
    input_type=[int, float, np.int32, np.float32],
)
def testArrayWeakType(self, funcname, input_type, val, dtype):
  constructor = getattr(jnp, funcname)
  func = lambda x: constructor(x, dtype=dtype)
  fjit = jax.jit(func)
  val = input_type(val)
  # A weak type is expected only when no dtype is given and the input's type
  # is one of dtypes._weak_types.
  expected_weak_type = dtype is None and input_type in set(dtypes._weak_types)
  self.assertEqual(dtypes.is_weakly_typed(func(val)), expected_weak_type)
  self.assertEqual(dtypes.is_weakly_typed(fjit(val)), expected_weak_type)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    shape=nonempty_nonscalar_array_shapes,
    dtype=[int, float, complex],
    weak_type=[True, False],
    slc=[slice(None), slice(0), slice(3), 0, ...],
)
def testSliceWeakTypes(self, shape, dtype, weak_type, slc):
  rng = jtu.rand_default(self.rng())
  x = lax_internal._convert_element_type(rng(shape, dtype),
                                         weak_type=weak_type)
  op = lambda x: x[slc]
  # Indexing must preserve the input's weak-type flag, eagerly and under jit.
  for f in (op, jax.jit(op)):
    self.assertEqual(f(x).aval.weak_type, weak_type)
|
2021-02-08 13:37:25 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, axis=axis, num_sections=num_sections)
     for shape, axis, num_sections in [
         ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
         ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
    ],
    dtype=default_dtypes,
)
def testSplitStaticInt(self, shape, num_sections, axis, dtype):
  rng = jtu.rand_default(self.rng())

  def splitter(module):
    # Equal-size split into `num_sections` pieces along `axis`.
    return lambda x: module.split(x, num_sections, axis=axis)

  np_fun, jnp_fun = splitter(np), splitter(jnp)
  args_maker = lambda: [rng(shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
Add np.{hsplit,vsplit,dsplit,deg2rad,rad2deg,degrees,radians,hypot,reciprocal,product}.
Forward np.{issubsctype,array_str,array_repr} to numpy.
2019-02-04 08:47:31 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, axis=axis, num_sections=num_sections)
     # All testcases split the specified axis unequally
     for shape, axis, num_sections in [
         ((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3),
         ((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)]
    ],
    dtype=default_dtypes,
)
def testArraySplitStaticInt(self, shape, num_sections, axis, dtype):
  rng = jtu.rand_default(self.rng())

  def splitter(module):
    # array_split tolerates section counts that do not divide the axis length.
    return lambda x: module.array_split(x, num_sections, axis=axis)

  np_fun, jnp_fun = splitter(np), splitter(jnp)
  args_maker = lambda: [rng(shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2020-04-22 10:25:06 +03:00
|
|
|
def testSplitTypeError(self):
  # An ndarray is accepted for indices_or_sections.
  self.assertEqual(3, len(jnp.split(jnp.zeros(3), jnp.array([1, 2]))))

  CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected."
  split_at = lambda idx: jnp.split(jnp.zeros((12, 2)), idx)
  split_at_list = lambda idx: jnp.split(jnp.zeros((12, 2)), [2, idx])

  # Abstract tracers (from jit), bare or inside a list, cannot be split points.
  for fun in (split_at, split_at_list):
    with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG):
      jax.jit(fun)(2.)

  # Concrete tracers (from jvp), bare or inside a tuple, are accepted.
  jax.jvp(split_at, (2.,), (1.,))
  jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), (1, idx)),
          (2.,), (1.,))
|
2020-04-22 10:25:06 +03:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    shape=[(5,), (5, 5)],
    dtype=number_dtypes,
    bins=[10, np.arange(-5, 6), np.array([-5, 0, 3])],
    range=[None, (0, 0), (0, 10)],
    weights=[True, False],
)
def testHistogramBinEdges(self, shape, dtype, bins, range, weights):
  # NOTE: `range` shadows the builtin; the name is fixed by the decorator's
  # keyword arguments.
  rng = jtu.rand_default(self.rng())
  _weights = lambda w: abs(w) if weights else None

  def make_fun(module):
    return lambda a, w, r: module.histogram_bin_edges(
        a, bins=bins, range=r, weights=_weights(w))

  np_fun, jnp_fun = make_fun(np), make_fun(jnp)
  args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), range]
  tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2}
  # linspace() compares poorly to numpy when using bfloat16
  if dtype != jnp.bfloat16:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    shape=[(5,), (5, 5)],
    dtype=default_dtypes,
    # We only test explicit integer-valued bin edges because in other cases
    # rounding errors lead to flaky tests.
    bins=[np.arange(-5, 6), np.array([-5, 0, 3])],
    density=[True, False],
    weights=[True, False],
)
def testHistogram(self, shape, dtype, bins, density, weights):
  rng = jtu.rand_default(self.rng())
  _weights = lambda w: abs(w) if weights else None

  def make_fun(module):
    return lambda a, w: module.histogram(a, bins=bins, density=density,
                                         weights=_weights(w))

  np_fun, jnp_fun = make_fun(np), make_fun(jnp)
  args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
  tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
  # np.searchsorted errors on bfloat16 with
  # "TypeError: invalid type promotion with custom data type"
  if dtype != jnp.bfloat16:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                            tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-05-16 10:23:26 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    shape=[(5,), (12,)],
    dtype=int_dtypes,
    bins=[2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]],
    weights=[False, True],
    density=[False, True],
    range=[None, [(-1, 1), None], [(-1, 1), (-2, 2)]],
)
def testHistogram2d(self, shape, dtype, bins, weights, density, range):
  rng = jtu.rand_default(self.rng())
  _weights = lambda w: abs(w) if weights else None
  # Keyword arguments shared by the NumPy and JAX calls.
  hist_kwargs = lambda w: dict(bins=bins, weights=_weights(w),
                               density=density, range=range)
  np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")(
      lambda a, b, w: np.histogram2d(a, b, **hist_kwargs(w)))
  jnp_fun = lambda a, b, w: jnp.histogram2d(a, b, **hist_kwargs(w))
  args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)]
  tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
  # np.searchsorted errors on bfloat16 with
  # "TypeError: invalid type promotion with custom data type"
  # NOTE(review): dtype is sampled from int_dtypes, so this guard presumably
  # never skips here; confirm before relying on it.
  with np.errstate(divide='ignore', invalid='ignore'):
    if dtype != jnp.bfloat16:
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                              tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    shape=[(5, 3), (10, 3)],
    dtype=int_dtypes,
    bins=[(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]],
    weights=[False, True],
    density=[False, True],
    range=[None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]],
)
def testHistogramdd(self, shape, dtype, bins, weights, density, range):
  rng = jtu.rand_default(self.rng())
  _weights = lambda w: abs(w) if weights else None
  # Keyword arguments shared by the NumPy and JAX calls.
  hist_kwargs = lambda w: dict(bins=bins, weights=_weights(w),
                               density=density, range=range)
  np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")(
      lambda a, w: np.histogramdd(a, **hist_kwargs(w)))
  jnp_fun = lambda a, w: jnp.histogramdd(a, **hist_kwargs(w))
  # One weight per sample (row) of the (N, D) input.
  args_maker = lambda: [rng(shape, dtype), rng((shape[0],), dtype)]
  tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
  # np.searchsorted errors on bfloat16 with
  # "TypeError: invalid type promotion with custom data type"
  # NOTE(review): dtype is sampled from int_dtypes, so this guard presumably
  # never skips here; confirm before relying on it.
  if dtype != jnp.bfloat16:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                            tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(shape=shape, axis=axis, num_sections=num_sections)
     for shape, axis, num_sections in [
         ((12, 4), 0, 4), ((12,), 1, 2),
         ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]],
    dtype=default_dtypes,
)
def testHVDSplit(self, shape, num_sections, axis, dtype):
  rng = jtu.rand_default(self.rng())

  def splitter(module):
    # axis 0 -> vsplit, axis 1 -> hsplit, axis 2 -> dsplit.
    if axis == 0:
      return module.vsplit
    if axis == 1:
      return module.hsplit
    assert axis == 2
    return module.dsplit

  np_fun = lambda x: splitter(np)(x, num_sections)
  jnp_fun = lambda x: splitter(jnp)(x, num_sections)
  args_maker = lambda: [rng(shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(arg_shape=arg_shape, out_shape=out_shape)
     for arg_shape, out_shape in [
         (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)),
         ((), (1, 1, 1)),
         ((7, 0), (0, 42, 101)),
         ((3, 4), 12),
         ((3, 4), (12,)),
         ((3, 4), -1),
         ((2, 1, 4), (-1,)),
         ((2, 2, 4), (2, 8))
     ]
    ],
    dtype=default_dtypes,
    order=["C", "F"],
)
def testReshape(self, arg_shape, out_shape, dtype, order):
  rng = jtu.rand_default(self.rng())

  def reshaper(module):
    # Covers scalar, int, tuple and -1 (inferred) target shapes, in C/F order.
    return lambda x: module.reshape(x, out_shape, order=order)

  np_fun, jnp_fun = reshaper(np), reshaper(jnp)
  args_maker = lambda: [rng(arg_shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(arg_shape=arg_shape, out_shape=out_shape)
     for arg_shape, out_shape in [
         ((7, 0), (0, 42, 101)),
         ((2, 1, 4), (-1,)),
         ((2, 2, 4), (2, 8))
     ]
    ],
    dtype=default_dtypes,
)
def testReshapeMethod(self, arg_shape, out_shape, dtype):
  rng = jtu.rand_default(self.rng())
  np_fun = lambda x: np.reshape(x, out_shape)
  # Exercise the array method with the shape passed as separate varargs.
  jnp_fun = lambda x: x.reshape(*out_shape)
  args_maker = lambda: [rng(arg_shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2019-05-21 21:37:52 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(arg_shape=arg_shape, out_shape=out_shape)
     for arg_shape, out_shape in itertools.product(all_shapes, array_shapes)],
    dtype=default_dtypes,
)
def testResize(self, arg_shape, out_shape, dtype):
  rng = jtu.rand_default(self.rng())

  def resizer(module):
    return lambda x: module.resize(x, out_shape)

  np_fun, jnp_fun = resizer(np), resizer(jnp)
  args_maker = lambda: [rng(arg_shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(arg_shape=arg_shape, dim=dim)
     for arg_shape in [(), (3,), (3, 4)]
     for dim in (list(range(-len(arg_shape)+1, len(arg_shape)))
                 + [np.array(0), np.array(-1), (0,), [np.array(0)],
                    (len(arg_shape), len(arg_shape) + 1)])
    ],
    dtype=default_dtypes,
)
def testExpandDimsStaticDim(self, arg_shape, dtype, dim):
  rng = jtu.rand_default(self.rng())

  def expander(module):
    # `dim` may be an int, a 0-d array, or a sequence of axes.
    return lambda x: module.expand_dims(x, dim)

  np_fun, jnp_fun = expander(np), expander(jnp)
  args_maker = lambda: [rng(arg_shape, dtype)]
  self._CompileAndCheck(jnp_fun, args_maker)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
|
2022-02-17 11:26:35 -08:00
|
|
|
def testExpandDimsRepeatedAxisError(self):
  # Both jnp and np must reject duplicate axes, including a negative axis that
  # normalizes to an already-listed one ([3, -1] on a 4-d result), so that JAX
  # stays consistent with NumPy.
  for module in (jnp, np):
    x = module.ones((2, 3))
    for axes in ([1, 1], [3, -1]):
      with self.assertRaisesRegex(ValueError, 'repeated axis.*'):
        module.expand_dims(x, axes)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(arg_shape=arg_shape, ax1=ax1, ax2=ax2)
     for arg_shape, ax1, ax2 in [
         ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
         ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
    ],
    dtype=default_dtypes,
)
def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2):
  rng = jtu.rand_default(self.rng())

  def swapper(module):
    return lambda x: module.swapaxes(x, ax1, ax2)

  np_fun, jnp_fun = swapper(np), swapper(jnp)
  args_maker = lambda: [rng(arg_shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    [dict(arg_shape=arg_shape, ax=ax)
     for arg_shape, ax in [
         ((3, 1), None),
         ((3, 1), 1),
         ((3, 1), -1),
         ((3, 1), np.array(1)),
         ((1, 3, 1), (0, 2)),
         ((1, 3, 1), (0,)),
         ((1, 4, 1), (np.array(0),))]
    ],
    dtype=default_dtypes,
)
def testSqueeze(self, arg_shape, dtype, ax):
  rng = jtu.rand_default(self.rng())

  def squeezer(module):
    # `ax` may be None (all unit axes), an int, a 0-d array, or a tuple.
    return lambda x: module.squeeze(x, ax)

  np_fun, jnp_fun = squeezer(np), squeezer(jnp)
  args_maker = lambda: [rng(arg_shape, dtype)]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-12-27 15:42:49 -08:00
|
|
|
def testArrayFromMasked(self):
  def args_maker():
    return [np.ma.array([1, 2], mask=[True, False])]

  # Outside of jit, jnp.array drops the mask exactly like np.array does.
  self._CheckAgainstNumpy(np.array, jnp.array, args_maker)
  # Passing a masked array into a jitted function is rejected.
  masked_input, = args_maker()
  with self.assertRaisesRegex(ValueError, "numpy masked arrays are not supported"):
    jax.jit(jnp.asarray)(masked_input)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    # NOTE: the per-argument dtype list is named `arg_dtypes` so it does not
    # shadow the `dtypes` module used in the test body below.
    [dict(arg=arg, dtype=dtype, ndmin=ndmin)
     for arg, arg_dtypes in [
         ([True, False, True], all_dtypes),
         (3., all_dtypes),
         ([1, 2, 3], all_dtypes),
         (np.array([1, 2, 3], dtype=np.int64), all_dtypes),
         ([1., 2., 3.], all_dtypes),
         ([[1, 2], [3, 4], [5, 6]], all_dtypes),
         ([[1, 2.], [3, 4], [5, 6]], all_dtypes),
         ([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes),
         ([[3, np.array(2, dtype=jnp.float_), 1],
           np.arange(3., dtype=jnp.float_)], all_dtypes),
     ]
     for dtype in [None] + arg_dtypes
     for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2]
    ],
)
def testArray(self, arg, ndmin, dtype):
  args_maker = lambda: [arg]
  # NumPy gets the dtype JAX is expected to produce; JAX gets the raw request.
  canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype)
  # Only forward ndmin when the test case specifies one.
  ndmin_kwargs = {} if ndmin is None else {"ndmin": ndmin}
  np_fun = partial(np.array, dtype=canonical_dtype, **ndmin_kwargs)
  jnp_fun = partial(jnp.array, dtype=dtype, **ndmin_kwargs)

  # We are testing correct canonicalization behavior here, so we turn off the
  # permissive canonicalization logic in the test harness.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          canonicalize_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2021-11-12 15:34:45 -08:00
|
|
|
@jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*")
def testArrayDtypeInference(self):
  def _check(obj, out_dtype, weak_type):
    dtype_reference = np.array(obj, dtype=out_dtype)
    # The inferred dtype and weak-type flag must match eagerly and under jit.
    for build in (jnp.array, jax.jit(jnp.array)):
      out = build(obj)
      self.assertDtypesMatch(out, dtype_reference)
      self.assertEqual(dtypes.is_weakly_typed(out), weak_type)

  cases = [
      # Python scalars become 64-bit weak types.
      (1, np.int64, True),
      (1.0, np.float64, True),
      (1.0j, np.complex128, True),
      # Lists become strongly-typed defaults.
      ([1], jnp.int64, False),
      ([1.0], jnp.float64, False),
      ([1.0j], jnp.complex128, False),
      # Lists of weakly-typed objects become strongly-typed defaults.
      ([jnp.array(1)], jnp.int64, False),
      ([jnp.array(1.0)], jnp.float64, False),
      ([jnp.array(1.0j)], jnp.complex128, False),
      # Lists of strongly-typed objects maintain their strong type.
      ([jnp.int64(1)], np.int64, False),
      ([jnp.float64(1)], np.float64, False),
      ([jnp.complex128(1)], np.complex128, False),
      # Mixed inputs use JAX-style promotion.
      # (regression test for https://github.com/google/jax/issues/8945)
      ([0, np.int16(1)], np.int16, False),
      ([0.0, np.float16(1)], np.float16, False),
  ]
  for obj, out_dtype, weak_type in cases:
    _check(obj, out_dtype, weak_type)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
    dtype=all_dtypes,
    func=["array", "copy", "copy.copy", "copy.deepcopy"],
)
def testArrayCopy(self, dtype, func):
  x = jnp.ones(10, dtype=dtype)
  stdlib_copiers = {"copy.deepcopy": copy.deepcopy, "copy.copy": copy.copy}
  copy_func = stdlib_copiers[func] if func in stdlib_copiers else getattr(jnp, func)

  x_view = jnp.asarray(x)
  x_view_jit = jax.jit(jnp.asarray)(x)
  x_copy = copy_func(x)
  x_copy_jit = jax.jit(copy_func)(x)
  independent = (x_view_jit, x_copy, x_copy_jit)

  def _ptr(arr):
    return arr.unsafe_buffer_pointer()

  # Eager asarray aliases x's buffer; every other result owns a new buffer.
  self.assertEqual(_ptr(x), _ptr(x_view))
  for other in independent:
    self.assertNotEqual(_ptr(x), _ptr(other))

  x.delete()

  # Deleting x takes its alias with it but leaves the independent buffers alive.
  self.assertTrue(x_view.is_deleted())
  for other in independent:
    self.assertFalse(other.is_deleted())
|
|
|
|
|
2022-01-21 08:27:10 -08:00
|
|
|
def testArrayCopyAutodiff(self):
  def f(x):
    return jnp.array(x, copy=True)

  # Forward mode: neither the primal nor the tangent output is the input object.
  x, xdot = jnp.ones(10), jnp.ones(10)
  y, ydot = jax.jvp(f, (x,), (xdot,))
  self.assertIsNot(x, y)
  self.assertIsNot(xdot, ydot)

  # Reverse mode: the cotangent returned is not the cotangent passed in.
  ybar = jnp.ones(10)
  y, f_vjp = jax.vjp(f, x)
  (xbar,) = f_vjp(ybar)
  self.assertIsNot(x, y)
  self.assertIsNot(xbar, ybar)
|
|
|
|
|
|
|
|
def testArrayCopyVmap(self):
  # A copying jnp.array under vmap must not hand back the input object.
  x = jnp.ones(10)
  y = jax.vmap(lambda v: jnp.array(v, copy=True))(x)
  self.assertIsNot(x, y)
|
|
|
|
|
2020-04-29 14:14:49 -04:00
|
|
|
def testArrayUnsupportedDtypeError(self):
  # Structured (record) dtypes are rejected.
  structured_dtype = [('a','<i4'),('b','<i4')]
  with self.assertRaisesRegex(TypeError,
                              "JAX only supports number and bool dtypes.*"):
    jnp.array(3, structured_dtype)
|
|
|
|
|
2021-03-29 09:26:19 -07:00
|
|
|
def testArrayFromInteger(self):
  # int64 canonicalizes to the default integer width for the current x64 mode.
  int_dtype = dtypes.canonicalize_dtype(jnp.int64)
  info = jnp.iinfo(int_dtype)
  int_max, int_min = info.max, info.min

  # Values at extremes are converted correctly.
  for val in (int_min, 0, int_max):
    self.assertEqual(jnp.array(val).dtype, int_dtype)

  # One past the maximum leads to an OverflowError.
  val = int_max + 1
  with self.assertRaisesRegex(OverflowError, f"Python int {val} too large to convert to {int_dtype.name}"):
    jnp.array(val)

  # explicit uint64 should work
  if config.x64_enabled:
    self.assertEqual(np.uint64(val), jnp.array(val, dtype='uint64'))
|
2021-03-29 13:22:51 -07:00
|
|
|
|
2022-01-27 14:28:14 -08:00
|
|
|
def testArrayFromList(self):
  """Lists of Python scalars convert, promote, and detect overflow."""
  dtype = dtypes.canonicalize_dtype('int64')
  int_max = jnp.iinfo(dtype).max
  int_min = jnp.iinfo(dtype).min

  # Single-element lists at the representable extremes keep the dtype.
  for value in (int_min, 0, int_max):
    self.assertEqual(jnp.array([value]).dtype, dtype)

  # A heterogeneous list promotes under the standard promotion rules.
  with jax.numpy_dtype_promotion('standard'):
    actual_dtype = jnp.array([0, np.float16(1)]).dtype
    self.assertEqual(actual_dtype, jnp.result_type('int64', 'float16'))

  # One below the int64 minimum is not representable at all.
  too_small = jnp.iinfo(jnp.int64).min - 1
  with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
    jnp.array([0, too_small])
|
2021-03-29 09:26:19 -07:00
|
|
|
|
2019-05-20 11:49:09 -07:00
|
|
|
def testIssue121(self):
|
2020-05-20 01:43:48 -03:00
|
|
|
assert not np.isscalar(jnp.array(3))
|
2019-05-20 11:49:09 -07:00
|
|
|
|
2023-08-18 16:50:36 -04:00
|
|
|
def testArrayOutputsArrays(self):
|
2023-03-15 17:08:21 -07:00
|
|
|
assert type(jnp.array([])) is array.ArrayImpl
|
|
|
|
assert type(jnp.array(np.array([]))) is array.ArrayImpl
|
2020-05-12 23:08:14 -07:00
|
|
|
|
|
|
|
class NDArrayLike:
|
2020-10-20 00:43:04 +02:00
|
|
|
def __array__(self, dtype=None):
|
|
|
|
return np.array([], dtype=dtype)
|
2023-03-15 17:08:21 -07:00
|
|
|
assert type(jnp.array(NDArrayLike())) is array.ArrayImpl
|
2020-05-12 23:08:14 -07:00
|
|
|
|
2020-06-06 21:44:14 -07:00
|
|
|
# NOTE(mattjj): disabled b/c __array__ must produce ndarrays
|
2023-08-18 16:50:36 -04:00
|
|
|
# class ArrayLike:
|
2020-06-06 21:44:14 -07:00
|
|
|
# def __array__(self, dtype=None):
|
|
|
|
# return jnp.array([], dtype=dtype)
|
2023-08-18 16:50:36 -04:00
|
|
|
# assert xla.type_is_device_array(jnp.array(ArrayLike()))
|
2020-05-12 23:08:14 -07:00
|
|
|
|
2019-02-21 07:34:27 -08:00
|
|
|
def testArrayMethod(self):
|
2022-05-12 19:13:00 +01:00
|
|
|
class arraylike:
|
2021-12-15 09:07:27 -08:00
|
|
|
dtype = np.dtype('float32')
|
2019-02-21 07:34:27 -08:00
|
|
|
def __array__(self, dtype=None):
|
2020-06-06 21:44:14 -07:00
|
|
|
return np.array(3., dtype=dtype)
|
2018-12-05 08:22:27 -08:00
|
|
|
a = arraylike()
|
2020-03-06 14:59:51 -05:00
|
|
|
ans = jnp.array(a)
|
2021-12-15 09:07:27 -08:00
|
|
|
self.assertEqual(ans, 3.)
|
2018-12-05 08:22:27 -08:00
|
|
|
|
2022-02-11 12:44:55 -08:00
|
|
|
def testJaxArrayOps(self):
  """Objects defining __jax_array__ participate in jnp arithmetic."""
  class arraylike:
    def __jax_array__(self):
      return jnp.array(3.)

  actual = arraylike() * jnp.arange(10.)
  expected = jnp.array(3.) * jnp.arange(10.)
  self.assertArraysEqual(actual, expected)
|
2022-02-11 12:44:55 -08:00
|
|
|
|
2019-07-21 15:05:17 -04:00
|
|
|
def testMemoryView(self):
  """jnp.array accepts bytearray buffers exactly as np.array does."""
  cases = (
      (b'\x2a', {}),
      (b'\x2a\xf3', {'ndmin': 2}),
  )
  for payload, kwargs in cases:
    self.assertAllClose(
        jnp.array(bytearray(payload), **kwargs),
        np.array(bytearray(payload), **kwargs)
    )
|
2019-07-08 11:51:49 +01:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(val=[1+1j, [1+1j], jnp.pi, np.arange(2)])
def testIsComplexObj(self, val):
  """jnp.iscomplexobj agrees with NumPy for scalars, lists and arrays."""
  def args_maker():
    return [val]
  self._CheckAgainstNumpy(np.iscomplexobj, jnp.iscomplexobj, args_maker)
  self._CompileAndCheck(jnp.iscomplexobj, args_maker)
|
|
|
|
|
2020-03-25 09:59:43 +00:00
|
|
|
def testIsClose(self):
  """jnp.isclose matches np.isclose eagerly and under jit, incl. inf/nan."""
  jit_isclose = jax.jit(jnp.isclose)
  jit_isclose_nan = jax.jit(partial(jnp.isclose, equal_nan=True))

  n = 2
  rng = self.rng()
  x = rng.randn(n, 1)
  y = rng.randn(n, 1)
  inf = np.asarray(n * [np.inf]).reshape([n, 1])
  nan = np.asarray(n * [np.nan]).reshape([n, 1])
  candidates = [x, y, inf, -inf, nan]

  def check(expected, *actuals):
    # Every actual result must agree elementwise with the NumPy reference.
    for actual in actuals:
      self.assertTrue(jnp.all(jnp.equal(expected, actual)))

  for lhs, rhs in itertools.product(candidates, repeat=2):
    check(np.isclose(lhs, rhs),
          jnp.isclose(lhs, rhs),
          jit_isclose(lhs, rhs))
    check(np.isclose(lhs, rhs, equal_nan=True),
          jnp.isclose(lhs, rhs, equal_nan=True),
          jit_isclose_nan(lhs, rhs))
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  x=[1, [1], [1, 1 + 1E-4], [1, np.nan]],
  y=[1, [1], [1, 1 + 1E-4], [1, np.nan]],
  equal_nan=[True, False],
)
@jax.numpy_dtype_promotion('standard')  # This test explicitly exercises mixed type promotion
def testAllClose(self, x, y, equal_nan):
  """jnp.allclose matches np.allclose with a loose relative tolerance."""
  options = dict(equal_nan=equal_nan, rtol=1E-3)
  jnp_fun = partial(jnp.allclose, **options)
  np_fun = partial(np.allclose, **options)

  def args_maker():
    return [np.array(x), np.array(y)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def testZeroStridesConstantHandler(self):
  """A np.broadcast_to constant can be closed over by a jitted function."""
  raw_const = self.rng().randn(1, 2, 1, 1, 5, 1)
  const = np.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

  @jax.jit
  def scale(x):
    return x * const

  self.assertAllClose(scale(3.), 3. * const, check_dtypes=False)
|
|
|
|
|
|
|
|
def testIsInstanceNdarrayDuringTracing(self):
  """Values seen while tracing under jit are instances of jax.Array."""
  def summed(x):
    self.assertIsInstance(x, jax.Array)
    return jnp.sum(x)

  jax.jit(summed)(np.ones(3))
|
|
|
|
|
|
|
|
def testNonArrayErrorMessage(self):
  """Passing a Python list to jnp ops raises TypeError, eagerly and under jit."""
  x = [1., 2.]
  y = np.array([3., 4.])

  def g(x, y):
    return jnp.add(x, y)

  def f(x, y):
    return jnp.dot(x, y)

  for fn in (g, f, jax.jit(g), jax.jit(f)):
    self.assertRaises(TypeError, fn, x, y)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def testAbstractionErrorMessage(self):
  """Using traced values as Python ints or bools inside jit raises errors."""
  @jax.jit
  def repeated_square(x, n):
    # range(n) needs a concrete int, but n is traced here.
    for _ in range(n):
      x = x * x
    return x

  @jax.jit
  def branchy(x):
    # Python control flow on a traced comparison cannot be concretized.
    return x * 2 if x > 0. else x + 2

  self.assertRaises(jax.errors.TracerIntegerConversionError,
                    lambda: repeated_square(3., 3))
  self.assertRaises(jax.errors.ConcretizationTypeError, lambda: branchy(3.))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in [(3,), (2, 3)]
   for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))]  # Test negative axes and tuples
  ],
  dtype=default_dtypes,
)
def testFlip(self, shape, dtype, axis):
  """jnp.flip matches np.flip for int, negative, tuple and None axes."""
  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])

  def np_op(x):
    return np.flip(x, axis)

  def jnp_op(x):
    return jnp.flip(x, axis)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
2018-12-11 08:54:35 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shape=[(3,), (2, 3), (3, 2, 4)],
  dtype=default_dtypes,
)
def testFlipud(self, shape, dtype):
  """jnp.flipud matches np.flipud."""
  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])
  self._CheckAgainstNumpy(np.flipud, jnp.flipud, args_maker)
  self._CompileAndCheck(jnp.flipud, args_maker)
|
2019-01-31 12:57:43 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shape=[(3, 2), (2, 3), (3, 2, 4)],
  dtype=default_dtypes,
)
def testFliplr(self, shape, dtype):
  """jnp.fliplr matches np.fliplr."""
  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])
  self._CheckAgainstNumpy(np.fliplr, jnp.fliplr, args_maker)
  self._CompileAndCheck(jnp.fliplr, args_maker)
|
2019-01-31 12:57:43 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axes=axes)
   for shape, axes in [
    [(2, 3), (0, 1)],
    [(2, 3), (1, 0)],
    [(4, 3, 2), (0, 2)],
    [(4, 3, 2), (2, 1)],
   ]
  ],
  k=range(-3, 4),
  dtype=default_dtypes,
)
def testRot90(self, shape, dtype, k, axes):
  """jnp.rot90 matches np.rot90 for positive and negative k."""
  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])

  def np_op(x):
    return np.rot90(x, k, axes)

  def jnp_op(x):
    return jnp.rot90(x, k, axes)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
2018-12-11 08:54:35 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# TODO(mattjj): test infix operator overrides
|
|
|
|
|
2018-12-13 11:52:41 -08:00
|
|
|
def testRavel(self):
  """The ravel() method compiles and agrees with eager execution."""
  rng = self.rng()

  def args_maker():
    return [rng.randn(3, 4).astype("float32")]

  def ravel(arr):
    return arr.ravel()

  self._CompileAndCheck(ravel, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shape=nonempty_nonscalar_array_shapes,
  order=['C', 'F'],
  mode=['wrap', 'clip', 'raise'],
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testRavelMultiIndex(self, shape, order, mode):
  """jnp.ravel_multi_index matches NumPy, including out-of-bounds indices."""
  # One generator per dimension; each occasionally steps one past either edge.
  rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1) for dim in shape]
  # Index arrays of increasing rank, which broadcast against one another.
  args_maker = lambda: [tuple(rng(ndim * (3,), jnp.int_)
                              for ndim, rng in enumerate(rngs))]

  def with_sentinel(ravel):
    # Map the expected "invalid entry" ValueError onto a comparable sentinel.
    def fun(x):
      try:
        return ravel(x, shape, order=order, mode=mode)
      except ValueError as err:
        if not str(err).startswith('invalid entry'):
          raise
        return -999
    return fun

  np_fun = with_sentinel(np.ravel_multi_index)
  jnp_fun = with_sentinel(jnp.ravel_multi_index)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)

  if mode != 'raise':
    self._CompileAndCheck(jnp_fun, args_maker)
    return
  msg = ("The error occurred because ravel_multi_index was jit-compiled "
         "with mode='raise'. Use mode='wrap' or mode='clip' instead.")
  with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
    jax.jit(jnp_fun)(*args_maker())
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  ashape=((), (4,), (3, 4)),
  cshapes=[
    [(), (4,)],
    [(3, 4), (4,), (3, 1)]
  ],
  adtype=int_dtypes,
  cdtype=default_dtypes,
  mode=['wrap', 'clip', 'raise'],
)
def testChoose(self, ashape, adtype, cshapes, cdtype, mode):
  """jnp.choose matches np.choose for every mode; 'raise' is rejected under jit."""
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]]

  def with_sentinel(choose):
    # In 'raise' mode, map the expected "invalid entry" error to a sentinel.
    def fun(a, c):
      try:
        return choose(a, c, mode=mode)
      except ValueError as err:
        if mode == 'raise' and str(err).startswith('invalid entry'):
          return -999
        raise
    return fun

  np_fun = with_sentinel(np.choose)
  jnp_fun = with_sentinel(jnp.choose)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)

  if mode != 'raise':
    self._CompileAndCheck(jnp_fun, args_maker)
    return
  msg = ("The error occurred because jnp.choose was jit-compiled"
         " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
  with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
    jax.jit(jnp_fun)(*args_maker())
|
2020-09-17 12:42:22 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shape=nonempty_nonscalar_array_shapes,
  dtype=int_dtypes,
  idx_shape=all_shapes,
)
def testUnravelIndex(self, shape, idx_shape, dtype):
  """jnp.unravel_index matches NumPy after applying JAX's clipping rules."""
  size = math.prod(shape)
  bound = (2 * size) // 3
  rng = jtu.rand_int(self.rng(), low=-bound, high=bound)

  def np_fun(index, shape):
    # JAX keeps the index dtype in the typical case where shape is weakly typed.
    out_dtype = index.dtype
    # Mirror JAX's documented out-of-bounds behavior: clip, then wrap negatives.
    clipped = np.clip(index, -size, size - 1)
    wrapped = np.where(clipped < 0, clipped + size, clipped)
    return [coord.astype(out_dtype) for coord in np.unravel_index(wrapped, shape)]

  args_maker = lambda: [rng(idx_shape, dtype), shape]
  self._CheckAgainstNumpy(np_fun, jnp.unravel_index, args_maker)
  self._CompileAndCheck(jnp.unravel_index, args_maker)
|
2020-05-06 16:05:49 -04:00
|
|
|
|
2018-12-19 09:28:08 -05:00
|
|
|
def testAstype(self):
  """Casting to int32 with astype agrees with NumPy."""
  rng = self.rng()

  def args_maker():
    return [rng.randn(3, 4).astype("float32")]

  def np_op(x):
    return np.asarray(x).astype(jnp.int32)

  def jnp_op(x):
    return jnp.asarray(x).astype(jnp.int32)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
2018-12-19 09:28:08 -05:00
|
|
|
|
2021-11-30 13:51:38 -08:00
|
|
|
def testAstypeNone(self):
  """astype(None) matches NumPy evaluated under JAX's default dtypes."""
  rng = self.rng()

  def args_maker():
    return [rng.randn(3, 4).astype("int32")]

  def jnp_op(x):
    return jnp.asarray(x).astype(None)

  np_op = jtu.with_jax_dtype_defaults(lambda x: np.asarray(x).astype(None))
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shape=array_shapes,
  dtype=all_dtypes,
)
def testNbytes(self, shape, dtype):
  """The nbytes attribute matches NumPy's for every shape and dtype."""
  rng = jtu.rand_default(self.rng())

  def np_op(x):
    return np.asarray(x).nbytes

  def jnp_op(x):
    return jnp.asarray(x).nbytes

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shape=array_shapes,
  dtype=all_dtypes,
)
def testItemsize(self, shape, dtype):
  """The itemsize attribute matches NumPy's for every shape and dtype."""
  rng = jtu.rand_default(self.rng())

  def np_op(x):
    return np.asarray(x).itemsize

  def jnp_op(x):
    return jnp.asarray(x).itemsize

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  # The final dimension must be a multiple of 16 so that every dtype pair is compatible.
  shape=[(0,), (32,), (2, 16)],
  a_dtype=all_dtypes,
  dtype=(*all_dtypes, None) if config.x64_enabled else all_dtypes,
)
def testView(self, shape, a_dtype, dtype):
  """arr.view(dtype) reinterprets the underlying bits the same way NumPy does."""
  on_tpu = jtu.device_under_test() == 'tpu'
  if on_tpu and any(jnp.dtype(t).itemsize in [1, 2] for t in (a_dtype, dtype)):
    self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")

  # Bool arrays may hold arbitrary bit patterns, but the resulting behavior is
  # implementation-defined, so only well-formed 0/1 booleans are generated.
  make_rng = jtu.rand_bool if a_dtype == np.bool_ else jtu.rand_fullrange
  rng = make_rng(self.rng())

  args_maker = lambda: [rng(shape, a_dtype)]
  np_op = lambda x: np.asarray(x).view(dtype)
  jnp_op = lambda x: jnp.asarray(x).view(dtype)
  # Reinterpretation can yield signaling NaNs; silence invalid-value warnings.
  with np.errstate(invalid='ignore'):
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
    self._CompileAndCheck(jnp_op, args_maker)
|
2020-05-13 12:48:16 -07:00
|
|
|
|
2023-02-17 10:54:37 -08:00
|
|
|
@jtu.sample_product([
  {'a_dtype': a_dtype, 'dtype': dtype}
  for a_dtype in all_dtypes
  for dtype in all_dtypes
  if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize
])
def testViewScalar(self, a_dtype, dtype):
  """view() on 0-d arrays reinterprets bits between same-width dtypes."""
  on_tpu = jtu.device_under_test() == 'tpu'
  if on_tpu and any(jnp.dtype(t).itemsize in [1, 2] for t in (a_dtype, dtype)):
    self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")

  rng = jtu.rand_fullrange(self.rng())
  args_maker = lambda: [jnp.array(rng((), a_dtype))]
  np_op = lambda x: np.asarray(x).view(dtype)
  jnp_op = lambda x: jnp.asarray(x).view(dtype)
  # Reinterpretation can yield signaling NaNs; silence invalid-value warnings.
  with np.errstate(invalid='ignore'):
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
    self._CompileAndCheck(jnp_op, args_maker)
|
|
|
|
|
2020-05-21 06:40:24 -07:00
|
|
|
def testPathologicalFloats(self):
  """Special float32 bit patterns survive a uint32 -> float32 -> uint32 view round trip."""
  bit_patterns = [
      0b_0111_1111_1000_0000_0000_0000_0000_0000,  # inf
      0b_1111_1111_1000_0000_0000_0000_0000_0000,  # -inf
      0b_0111_1111_1100_0000_0000_0000_0000_0000,  # qnan
      0b_1111_1111_1100_0000_0000_0000_0000_0000,  # -qnan
      0b_0111_1111_1000_0000_0000_0000_0000_0001,  # snan
      0b_1111_1111_1000_0000_0000_0000_0000_0001,  # -snan
      0b_0111_1111_1000_0000_0000_1100_0000_0000,  # nonstandard nan
      0b_1111_1111_1000_0000_0000_1100_0000_0000,  # -nonstandard nan
      0b_0000_0000_0000_0000_0000_0000_0000_0000,  # zero
      0b_1000_0000_0000_0000_0000_0000_0000_0000,  # -zero
  ]

  def args_maker():
    return [np.array(bit_patterns, dtype='uint32')]

  def np_op(x):
    return np.asarray(x).view('float32').view('uint32')

  def jnp_op(x):
    return jnp.asarray(x).view('float32').view('uint32')

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
2020-05-21 06:40:24 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# TODO(mattjj): test other ndarray-like method overrides
|
|
|
|
|
2020-05-20 01:43:48 -03:00
|
|
|
def testNpMean(self):
  """np.mean accepts a jnp array (regression test for google/jax#125)."""
  # from https://github.com/google/jax/issues/125
  identity = jnp.eye(3, dtype=float) + 0.
  self.assertAllClose(np.mean(identity), np.array(1./3), check_dtypes=False)
|
2018-12-17 14:26:28 -08:00
|
|
|
|
2018-12-19 09:07:04 -08:00
|
|
|
def testArangeOnFloats(self):
  """jnp.arange with float arguments matches np.arange under JAX defaults."""
  np_arange = jtu.with_jax_dtype_defaults(np.arange)
  cases = [
      (0.0, 1.0, 0.1),  # from https://github.com/google/jax/issues/145
      (2.5,),           # from https://github.com/google/jax/issues/3450
      (0., 2.5),
  ]
  for args in cases:
    self.assertAllClose(np_arange(*args), jnp.arange(*args))
|
2018-12-19 09:07:04 -08:00
|
|
|
|
2021-11-01 11:44:14 -07:00
|
|
|
def testArangeTypes(self):
  """jnp.arange produces the default int/float dtypes, or the requested one."""
  int_ = dtypes.canonicalize_dtype(jnp.int_)
  float_ = dtypes.canonicalize_dtype(jnp.float_)

  cases = [
      ((10,), {}, int_),
      ((10.,), {}, float_),
      ((10,), {'dtype': 'uint16'}, np.uint16),
      ((10,), {'dtype': 'bfloat16'}, jnp.bfloat16),
      ((0, 10, 1), {}, int_),
  ]
  for args, kwargs, expected in cases:
    self.assertEqual(jnp.arange(*args, **kwargs).dtype, expected)

  # Mixing int and float bounds promotes to the default float type.
  with jax.numpy_dtype_promotion('standard'):
    for args in [(0, 10, 1.), (0., 10, 1)]:
      self.assertEqual(jnp.arange(*args).dtype, float_)
|
2021-11-01 11:44:14 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonzerodim_shapes
   for axis in (None, *range(len(shape)))
  ],
  dtype=all_dtypes,
)
def testSort(self, dtype, shape, axis):
  """jnp.sort matches np.sort; axis=None means the default axis here."""
  rng = jtu.rand_some_equal(self.rng())
  args_maker = lambda: [rng(shape, dtype)]
  if axis is None:
    jnp_fun, np_fun = jnp.sort, np.sort
  else:
    jnp_fun = partial(jnp.sort, axis=axis)
    np_fun = partial(np.sort, axis=axis)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2019-01-13 09:01:01 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in one_dim_array_shapes
   for axis in [None]
  ],
  dtype=all_dtypes,
)
def testSortComplex(self, dtype, shape, axis):
  """jnp.sort_complex matches np.sort_complex on 1-D inputs (dtypes unchecked)."""
  rng = jtu.rand_some_equal(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker,
                          check_dtypes=False)
  self._CompileAndCheck(jnp.sort_complex, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonempty_nonscalar_array_shapes
   for axis in (-1, *range(len(shape) - 1))
  ],
  dtype=all_dtypes,
  input_type=[np.array, tuple],
)
def testLexsort(self, dtype, shape, input_type, axis):
  """jnp.lexsort matches np.lexsort for array and tuple key inputs."""
  rng = jtu.rand_some_equal(self.rng())

  def args_maker():
    return [input_type(rng(shape, dtype))]

  def jnp_op(keys):
    return jnp.lexsort(keys, axis=axis)

  np_op = jtu.with_jax_dtype_defaults(lambda keys: np.lexsort(keys, axis=axis))
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, axis=axis)
   for shape in nonzerodim_shapes
   for axis in (None, *range(len(shape)))
  ],
  dtype=all_dtypes,
)
def testArgsort(self, dtype, shape, axis):
  """jnp.argsort matches np.argsort (with JAX default dtypes)."""
  rng = jtu.rand_some_equal(self.rng())
  args_maker = lambda: [rng(shape, dtype)]
  np_argsort = jtu.with_jax_dtype_defaults(np.argsort)
  if axis is None:
    jnp_fun, np_fun = jnp.argsort, np_argsort
  else:
    jnp_fun = partial(jnp.argsort, axis=axis)
    np_fun = partial(np_argsort, axis=axis)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2023-01-30 13:50:25 -08:00
|
|
|
@jtu.sample_product(
  [{'shape': shape, 'axis': axis, 'kth': kth}
   for shape in nonzerodim_shapes
   for axis in range(-len(shape), len(shape))
   for kth in range(-shape[axis], shape[axis])],
  dtype=default_dtypes,
)
def testPartition(self, shape, dtype, axis, kth):
  """jnp.partition produces a valid partition that agrees with np.partition.

  The element at ``kth`` must match exactly. The elements on either side must
  match as multisets, so they are compared after sorting.
  """
  rng = jtu.rand_default(self.rng())
  arg = rng(shape, dtype)
  # kth is passed through unchanged so negative values are exercised.
  jnp_output = jnp.partition(arg, axis=axis, kth=kth)
  np_output = np.partition(arg, axis=axis, kth=kth)

  # Normalize kth before slicing. For kth == -1, a raw kth + 1 == 0 would make
  # the "upper" slice start at 0 and span the whole axis instead of being empty.
  size = shape[axis]
  pivot = kth + size if kth < 0 else kth

  # Assert that pivot point is equal:
  self.assertArraysEqual(
      lax.index_in_dim(jnp_output, axis=axis, index=pivot),
      lax.index_in_dim(np_output, axis=axis, index=pivot))

  def sorted_segment(values, start, stop):
    segment = lax.slice_in_dim(values, start_index=start, limit_index=stop, axis=axis)
    return lax.sort(segment, dimension=axis)

  # Assert remaining values are correctly partitioned:
  for start, stop in ((0, pivot), (pivot + 1, size)):
    self.assertArraysEqual(sorted_segment(jnp_output, start, stop),
                           sorted_segment(np_output, start, stop))
|
|
|
|
|
2023-02-08 14:41:39 -08:00
|
|
|
@jtu.sample_product(
  [{'shape': shape, 'axis': axis, 'kth': kth}
   for shape in nonzerodim_shapes
   for axis in range(-len(shape), len(shape))
   for kth in range(-shape[axis], shape[axis])],
  dtype=default_dtypes,
)
def testArgpartition(self, shape, dtype, axis, kth):
  """jnp.argpartition yields a valid partition that agrees with NumPy's.

  JAX and NumPy may order duplicate elements differently, so the indices are
  compared through the values they select rather than directly.
  """
  rng = jtu.rand_default(self.rng())
  arg = rng(shape, dtype)

  # kth is passed through unchanged so negative values are exercised.
  jnp_output = jnp.argpartition(arg, axis=axis, kth=kth)
  np_output = np.argpartition(arg, axis=axis, kth=kth)

  # Assert that all indices are present
  self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False)

  # Gather values along `axis` by vmapping x[ind] over every other dimension.
  # Wrapping in ascending `ax` order keeps the inner axis numbers valid, because
  # each outer vmap removes a higher-numbered axis first.
  getvals = lambda x, ind: x[ind]
  for ax in range(arg.ndim):
    if ax != range(arg.ndim)[axis]:
      getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
  jnp_values = getvals(arg, jnp_output)
  np_values = getvals(arg, np_output)

  # Normalize kth before slicing. For kth == -1, a raw kth + 1 == 0 would make
  # the "upper" slice start at 0 and span the whole axis instead of being empty.
  size = shape[axis]
  pivot = kth + size if kth < 0 else kth

  # Assert that pivot point is equal:
  self.assertArraysEqual(
      lax.index_in_dim(jnp_values, axis=axis, index=pivot),
      lax.index_in_dim(np_values, axis=axis, index=pivot))

  def sorted_segment(values, start, stop):
    segment = lax.slice_in_dim(values, start_index=start, limit_index=stop, axis=axis)
    return lax.sort(segment, dimension=axis)

  # Assert remaining values are correctly partitioned:
  for start, stop in ((0, pivot), (pivot + 1, size)):
    self.assertArraysEqual(sorted_segment(jnp_values, start, stop),
                           sorted_segment(np_values, start, stop))
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shifts=shifts, axis=axis)
   for shifts, axis in [
    (3, None),
    (1, 1),
    ((3,), (0,)),
    ((-2,), (-2,)),
    ((1, 2), (0, -1)),
    ((4, 2, 5, 5, 2, 4), None),
    (100, None),
   ]
  ],
  dtype=all_dtypes,
  shape=[(3, 4), (3, 4, 5), (7, 4, 0)],
)
def testRoll(self, shape, dtype, shifts, axis):
  """jnp.roll matches np.roll with array-valued shifts, including empty inputs."""
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype), np.array(shifts)]

  np_op = partial(np.roll, axis=axis)
  jnp_op = partial(jnp.roll, axis=axis)
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
2019-02-18 15:52:32 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=all_dtypes,
  shape=[(1, 2, 3, 4)],
  axis=[-3, 0, 2, 3],
  start=[-4, -1, 2, 4],
)
def testRollaxis(self, shape, dtype, start, axis):
  """jnp.rollaxis matches np.rollaxis for positive and negative positions."""
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  options = dict(axis=axis, start=start)
  np_op = partial(np.rollaxis, **options)
  jnp_op = partial(jnp.rollaxis, **options)
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
2020-04-12 07:37:02 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
dtype=[np.uint8, np.bool_],
|
|
|
|
bitorder=['big', 'little'],
|
|
|
|
shape=[(1, 2, 3, 4)],
|
|
|
|
axis=[None, 0, 1, -2, -1],
|
|
|
|
)
|
2020-12-02 17:00:25 -08:00
|
|
|
def testPackbits(self, shape, dtype, axis, bitorder):
|
|
|
|
rng = jtu.rand_some_zero(self.rng())
|
2020-04-13 11:57:18 -07:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
|
|
|
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
|
2020-05-20 01:43:48 -03:00
|
|
|
np_op = partial(np.packbits, axis=axis, bitorder=bitorder)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(jnp_op, args_maker)
|
2020-04-13 11:57:18 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
dtype=[np.uint8],
|
|
|
|
bitorder=['big', 'little'],
|
|
|
|
shape=[(1, 2, 3, 4)],
|
|
|
|
axis=[None, 0, 1, -2, -1],
|
|
|
|
count=[None, 20],
|
|
|
|
)
|
2020-05-04 23:00:20 -04:00
|
|
|
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
|
|
|
|
rng = jtu.rand_int(self.rng(), 0, 256)
|
2020-04-13 11:57:18 -07:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
2023-08-10 14:34:11 -07:00
|
|
|
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder, count=count)
|
|
|
|
np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder, count=count)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(jnp_op, args_maker)
|
2020-04-13 11:57:18 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
[dict(shape=shape, axis=axis)
|
2019-02-01 19:32:09 -05:00
|
|
|
for shape in [(3,), (3, 4), (3, 4, 5)]
|
2020-01-18 08:26:23 -05:00
|
|
|
for axis in itertools.chain(range(-len(shape), len(shape)),
|
|
|
|
[cast(Optional[int], None)])
|
2022-10-05 01:52:41 +00:00
|
|
|
],
|
|
|
|
index_shape=scalar_shapes + [(3,), (2, 1, 3)],
|
|
|
|
dtype=all_dtypes,
|
|
|
|
index_dtype=int_dtypes,
|
|
|
|
mode=[None, 'wrap', 'clip'],
|
|
|
|
)
|
2020-05-04 23:00:20 -04:00
|
|
|
def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode):
|
2019-02-01 19:32:09 -05:00
|
|
|
def args_maker():
|
|
|
|
x = rng(shape, dtype)
|
|
|
|
i = rng_indices(index_shape, index_dtype)
|
|
|
|
return x, i
|
|
|
|
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = jtu.rand_default(self.rng())
|
2021-02-17 15:18:19 -08:00
|
|
|
if mode is None:
|
|
|
|
rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0])
|
|
|
|
else:
|
|
|
|
rng_indices = jtu.rand_int(self.rng(), -5, 5)
|
2020-03-06 14:59:51 -05:00
|
|
|
jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode)
|
2020-05-20 01:43:48 -03:00
|
|
|
np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(jnp_op, args_maker)
|
2019-02-01 19:32:09 -05:00
|
|
|
|
2020-07-14 12:02:26 -04:00
|
|
|
def testTakeEmpty(self):
|
|
|
|
np.testing.assert_array_equal(
|
|
|
|
jnp.array([], dtype=jnp.float32),
|
|
|
|
jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32)))
|
|
|
|
|
2021-05-24 11:59:41 -04:00
|
|
|
np.testing.assert_array_equal(
|
|
|
|
jnp.ones((2, 0, 4), dtype=jnp.float32),
|
|
|
|
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32),
|
|
|
|
axis=1))
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(IndexError, "non-empty jnp.take"):
|
|
|
|
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32),
|
|
|
|
jnp.array([0], jnp.int32), axis=1)
|
2020-07-14 12:02:26 -04:00
|
|
|
|
2022-09-29 09:33:38 -07:00
|
|
|
def testTakeOptionalArgs(self):
|
|
|
|
x = jnp.arange(5.0)
|
|
|
|
ind = jnp.array([0, 2, 4, 6])
|
2022-12-01 13:56:42 -08:00
|
|
|
expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype)
|
2022-09-29 09:33:38 -07:00
|
|
|
actual = jnp.take(x, ind, unique_indices=True,
|
|
|
|
indices_are_sorted=True, fill_value=10.0)
|
|
|
|
self.assertArraysEqual(expected, actual)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
[dict(x_shape=x_shape, i_shape=i_shape, axis=axis)
|
2019-06-24 10:34:48 -04:00
|
|
|
for x_shape, i_shape in filter(
|
|
|
|
_shapes_are_equal_length,
|
|
|
|
filter(_shapes_are_broadcast_compatible,
|
2020-06-29 16:22:05 -07:00
|
|
|
itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2)))
|
2020-01-18 08:26:23 -05:00
|
|
|
for axis in itertools.chain(range(len(x_shape)), [-1],
|
|
|
|
[cast(Optional[int], None)])
|
2022-10-05 01:52:41 +00:00
|
|
|
],
|
|
|
|
dtype=default_dtypes,
|
|
|
|
index_dtype=int_dtypes,
|
|
|
|
)
|
2020-12-03 08:40:23 -05:00
|
|
|
def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
|
|
|
|
2023-04-13 11:48:11 -07:00
|
|
|
i_shape = list(i_shape)
|
2019-06-24 10:34:48 -04:00
|
|
|
if axis is None:
|
2023-04-13 11:48:11 -07:00
|
|
|
i_shape = [math.prod(i_shape)]
|
2019-06-24 10:34:48 -04:00
|
|
|
else:
|
|
|
|
# Test the case where the size of the axis doesn't necessarily broadcast.
|
|
|
|
i_shape[axis] *= 3
|
2019-01-13 12:26:37 -08:00
|
|
|
def args_maker():
|
2019-06-24 10:34:48 -04:00
|
|
|
x = rng(x_shape, dtype)
|
2023-04-13 11:48:11 -07:00
|
|
|
n = math.prod(x_shape) if axis is None else x_shape[axis]
|
2020-12-03 08:40:23 -05:00
|
|
|
if np.issubdtype(index_dtype, np.unsignedinteger):
|
|
|
|
index_rng = jtu.rand_int(self.rng(), 0, n)
|
|
|
|
else:
|
|
|
|
index_rng = jtu.rand_int(self.rng(), -n, n)
|
|
|
|
i = index_rng(i_shape, index_dtype)
|
2019-01-13 12:26:37 -08:00
|
|
|
return x, i
|
|
|
|
|
2020-03-06 14:59:51 -05:00
|
|
|
jnp_op = lambda x, i: jnp.take_along_axis(x, i, axis=axis)
|
2019-01-14 12:56:41 -08:00
|
|
|
|
2020-05-20 01:43:48 -03:00
|
|
|
if hasattr(np, "take_along_axis"):
|
|
|
|
np_op = lambda x, i: np.take_along_axis(x, i, axis=axis)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(jnp_op, args_maker)
|
2019-01-13 12:26:37 -08:00
|
|
|
|
2020-12-03 08:40:23 -05:00
|
|
|
def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self):
|
|
|
|
# https://github.com/google/jax/issues/5088
|
|
|
|
h = jtu.rand_default(self.rng())((256, 256, 100), np.float32)
|
|
|
|
g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8)
|
|
|
|
q0 = jnp.take_along_axis(h, g, axis=-1)
|
|
|
|
q1 = np.take_along_axis( h, g, axis=-1)
|
|
|
|
np.testing.assert_equal(q0, q1)
|
|
|
|
|
2022-04-19 16:05:29 -04:00
|
|
|
def testTakeAlongAxisOutOfBounds(self):
|
|
|
|
x = jnp.arange(10, dtype=jnp.float32)
|
|
|
|
idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])
|
|
|
|
out = jnp.take_along_axis(x, idx, axis=0)
|
|
|
|
expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan,
|
|
|
|
jnp.nan], np.float32)
|
2022-04-21 08:51:35 -07:00
|
|
|
np.testing.assert_array_equal(expected_fill, out)
|
2022-04-19 16:05:29 -04:00
|
|
|
out = jnp.take_along_axis(x, idx, axis=0, mode="fill")
|
|
|
|
np.testing.assert_array_equal(expected_fill, out)
|
|
|
|
|
2022-04-21 08:51:35 -07:00
|
|
|
expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32)
|
|
|
|
out = jnp.take_along_axis(x, idx, axis=0, mode="clip")
|
|
|
|
np.testing.assert_array_equal(expected_clip, out)
|
|
|
|
|
2022-04-20 10:04:37 -07:00
|
|
|
def testTakeAlongAxisRequiresIntIndices(self):
|
|
|
|
x = jnp.arange(5)
|
|
|
|
idx = jnp.array([3.], jnp.float32)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
"take_along_axis indices must be of integer type, got float32"):
|
|
|
|
jnp.take_along_axis(x, idx, axis=0)
|
|
|
|
|
2022-11-09 18:57:28 -08:00
|
|
|
def testTakeAlongAxisWithEmptyArgs(self):
|
|
|
|
# take_along_axis should allow us to gather an empty list of indices from
|
|
|
|
# an empty input axis without raising a shape error.
|
|
|
|
x = jnp.ones((4, 0, 3), dtype=jnp.int32)
|
|
|
|
np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1))
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
dtype=inexact_dtypes,
|
|
|
|
shape=[0, 5],
|
|
|
|
n=[2, 4],
|
|
|
|
increasing=[False, True],
|
|
|
|
)
|
2020-12-02 17:00:25 -08:00
|
|
|
def testVander(self, shape, dtype, n, increasing):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2020-05-20 01:43:48 -03:00
|
|
|
def np_fun(arg):
|
|
|
|
arg = arg.astype(np.float32) if dtype == jnp.bfloat16 else arg
|
|
|
|
return np.vander(arg, N=n, increasing=increasing)
|
2020-03-06 14:59:51 -05:00
|
|
|
jnp_fun = lambda arg: jnp.vander(arg, N=n, increasing=increasing)
|
2019-02-05 08:58:57 -05:00
|
|
|
args_maker = lambda: [rng([shape], dtype)]
|
|
|
|
# np.vander seems to return float64 for all floating types. We could obey
|
|
|
|
# those semantics, but they seem like a bug.
|
2020-05-20 01:43:48 -03:00
|
|
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
|
2023-09-05 18:48:18 -07:00
|
|
|
tol={np.float32: 1e-3, np.complex64: 1e-3})
|
2020-03-06 14:59:51 -05:00
|
|
|
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False)
|
2019-02-05 08:58:57 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
shape=array_shapes,
|
2023-01-30 10:36:36 -08:00
|
|
|
dtype=all_dtypes,
|
2022-10-05 01:52:41 +00:00
|
|
|
)
|
2020-12-02 17:00:25 -08:00
|
|
|
def testNanToNum(self, shape, dtype):
|
|
|
|
rng = jtu.rand_some_inf_and_nan(self.rng())
|
2020-05-20 01:43:48 -03:00
|
|
|
dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type
|
|
|
|
def np_fun(x):
|
2020-03-06 14:59:51 -05:00
|
|
|
if dtype == jnp.bfloat16:
|
2020-05-20 01:43:48 -03:00
|
|
|
x = np.where(np.isnan(x), dtype(0), x)
|
|
|
|
x = np.where(np.isposinf(x), jnp.finfo(dtype).max, x)
|
|
|
|
x = np.where(np.isneginf(x), jnp.finfo(dtype).min, x)
|
2019-11-20 22:43:46 -05:00
|
|
|
return x
|
|
|
|
else:
|
2020-05-20 01:43:48 -03:00
|
|
|
return np.nan_to_num(x).astype(dtype)
|
2019-11-20 22:43:46 -05:00
|
|
|
|
2019-05-07 16:29:45 -04:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
|
2020-05-20 01:43:48 -03:00
|
|
|
self._CheckAgainstNumpy(np_fun, jnp.nan_to_num, args_maker,
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
check_dtypes=check_dtypes)
|
2020-03-06 14:59:51 -05:00
|
|
|
self._CompileAndCheck(jnp.nan_to_num, args_maker,
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
check_dtypes=check_dtypes)
|
2019-05-07 16:29:45 -04:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
[dict(shapes=shapes, dtypes=dtypes)
|
2019-06-17 17:08:27 -04:00
|
|
|
for shapes, dtypes in (
|
|
|
|
((), ()),
|
2020-05-20 01:43:48 -03:00
|
|
|
(((7,),), (np.int32,)),
|
|
|
|
(((3,), (4,)), (np.int32, np.int32)),
|
|
|
|
(((3,), (1,), (4,)), (np.int32, np.int32, np.int32)),
|
2022-10-05 01:52:41 +00:00
|
|
|
)
|
|
|
|
],
|
|
|
|
)
|
2020-12-02 17:00:25 -08:00
|
|
|
def testIx_(self, shapes, dtypes):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-06-17 17:08:27 -04:00
|
|
|
args_maker = lambda: [rng(shape, dtype)
|
|
|
|
for shape, dtype in zip(shapes, dtypes)]
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker)
|
|
|
|
self._CompileAndCheck(jnp.ix_, args_maker)
|
2019-06-17 17:08:27 -04:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
dimensions=[(), (2,), (3, 0), (4, 5, 6)],
|
|
|
|
dtype=number_dtypes,
|
|
|
|
sparse=[True, False],
|
|
|
|
)
|
2020-05-11 10:22:49 -07:00
|
|
|
def testIndices(self, dimensions, dtype, sparse):
|
2022-11-09 18:57:28 -08:00
|
|
|
if jtu.device_under_test() == "tpu" and dtype in (np.int16, np.uint16):
|
|
|
|
raise unittest.SkipTest("Compilation failure on TPU ")
|
2020-05-11 10:22:49 -07:00
|
|
|
def args_maker(): return []
|
2021-06-10 12:12:13 -04:00
|
|
|
np_fun = partial(np.indices, dimensions=dimensions,
|
|
|
|
dtype=dtype, sparse=sparse)
|
2020-05-11 10:22:49 -07:00
|
|
|
jnp_fun = partial(jnp.indices, dimensions=dimensions,
|
|
|
|
dtype=dtype, sparse=sparse)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
self._CompileAndCheck(jnp_fun, args_maker)
|
2020-05-11 10:22:49 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
shape=all_shapes,
|
|
|
|
dtype=all_dtypes,
|
|
|
|
)
|
2019-12-20 18:42:33 -05:00
|
|
|
def testWhereOneArgument(self, shape, dtype):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = jtu.rand_some_zero(self.rng())
|
2020-05-20 01:43:48 -03:00
|
|
|
np_fun = lambda x: np.where(x)
|
|
|
|
np_fun = jtu.ignore_warning(
|
2020-04-12 15:35:35 -04:00
|
|
|
category=DeprecationWarning,
|
2020-05-20 01:43:48 -03:00
|
|
|
message="Calling nonzero on 0d arrays.*")(np_fun)
|
2020-03-06 14:59:51 -05:00
|
|
|
jnp_fun = lambda x: jnp.where(x)
|
2019-12-20 18:42:33 -05:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
2020-05-20 01:43:48 -03:00
|
|
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
2019-12-20 18:42:33 -05:00
|
|
|
|
2021-06-08 14:04:04 -07:00
|
|
|
# JIT compilation requires specifying a size statically. Full test of
|
|
|
|
# this behavior is in testNonzeroSize().
|
|
|
|
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)
|
|
|
|
self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
shapes=filter(_shapes_are_broadcast_compatible,
|
|
|
|
itertools.combinations_with_replacement(all_shapes, 3)),
|
|
|
|
dtypes=itertools.combinations_with_replacement(all_dtypes, 3),
|
|
|
|
)
|
2020-12-02 17:00:25 -08:00
|
|
|
def testWhereThreeArgument(self, shapes, dtypes):
|
|
|
|
rng = jtu.rand_default(self.rng())
|
|
|
|
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
2020-05-20 01:43:48 -03:00
|
|
|
def np_fun(cond, x, y):
|
2022-10-06 10:20:26 -07:00
|
|
|
return jtu.promote_like_jnp(partial(np.where, cond))(x, y)
|
2022-06-16 13:59:53 -07:00
|
|
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
2022-06-14 11:20:37 -07:00
|
|
|
self._CheckAgainstNumpy(np_fun, jnp.where, args_maker)
|
|
|
|
self._CompileAndCheck(jnp.where, args_maker)
|
Change JAX type promotion to prefer inexact types. (#1815)
Change the JAX type promotion table to prefer inexact types during type promotion.
NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.
This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators:
e.g.,
```
import numpy as onp
from jax import numpy as np
In [1]: onp.promote_types(onp.float32, onp.int32)
Out[1]: dtype('float64')
In [2]: onp.promote_types(onp.float16, onp.int64)
Out[2]: dtype('float64')
In [3]: np.promote_types(onp.float32, onp.int32)
Out[3]: dtype('float32')
In [4]: np.promote_types(onp.float16, onp.int64)
Out[4]: dtype('float16')
```
This change is in preparation for enabling x64 mode by default on all platforms.
2019-12-05 10:57:23 -05:00
|
|
|
|
|
|
|
def testWhereScalarPromotion(self):
|
2020-03-06 14:59:51 -05:00
|
|
|
x = jnp.where(jnp.array([True, False]), 3,
|
|
|
|
jnp.ones((2,), dtype=jnp.float32))
|
2020-05-20 01:43:48 -03:00
|
|
|
self.assertEqual(x.dtype, np.dtype(np.float32))
|
Change JAX type promotion to prefer inexact types. (#1815)
Change the JAX type promotion table to prefer inexact types during type promotion.
NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.
This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators:
e.g.,
```
import numpy as onp
from jax import numpy as np
In [1]: onp.promote_types(onp.float32, onp.int32)
Out[1]: dtype('float64')
In [2]: onp.promote_types(onp.float16, onp.int64)
Out[2]: dtype('float64')
In [3]: np.promote_types(onp.float32, onp.int32)
Out[3]: dtype('float32')
In [4]: np.promote_types(onp.float16, onp.int64)
Out[4]: dtype('float16')
```
This change is in preparation for enabling x64 mode by default on all platforms.
2019-12-05 10:57:23 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
[dict(n=n, shapes=shapes)
|
|
|
|
for n in range(1, 3)
|
|
|
|
for shapes in filter(
|
2019-06-24 09:27:01 -04:00
|
|
|
_shapes_are_broadcast_compatible,
|
2022-10-05 01:52:41 +00:00
|
|
|
itertools.combinations_with_replacement(all_shapes, 2 * n + 1))
|
|
|
|
],
|
|
|
|
# To avoid forming the full product of shapes and dtypes we always sample
|
|
|
|
# maximal set of dtypes.
|
|
|
|
dtypes=itertools.combinations_with_replacement(all_dtypes, 3),
|
|
|
|
)
|
|
|
|
def testSelect(self, n, shapes, dtypes):
|
|
|
|
dtypes = dtypes[:n+1]
|
2020-12-02 17:00:25 -08:00
|
|
|
rng = jtu.rand_default(self.rng())
|
2019-06-24 09:27:01 -04:00
|
|
|
n = len(dtypes) - 1
|
|
|
|
def args_maker():
|
2020-05-20 01:43:48 -03:00
|
|
|
condlist = [rng(shape, np.bool_) for shape in shapes[:n]]
|
2019-06-24 09:27:01 -04:00
|
|
|
choicelist = [rng(shape, dtype)
|
|
|
|
for shape, dtype in zip(shapes[n:-1], dtypes[:n])]
|
|
|
|
default = rng(shapes[-1], dtypes[-1])
|
|
|
|
return condlist, choicelist, default
|
2019-10-22 19:53:59 -04:00
|
|
|
# TODO(phawkins): float32/float64 type mismatches
|
2022-10-05 01:52:41 +00:00
|
|
|
@jax.numpy_dtype_promotion('standard')
|
2020-05-20 01:43:48 -03:00
|
|
|
def np_fun(condlist, choicelist, default):
|
2020-03-06 14:59:51 -05:00
|
|
|
choicelist = [x if jnp.result_type(x) != jnp.bfloat16
|
2020-05-20 01:43:48 -03:00
|
|
|
else x.astype(np.float32) for x in choicelist]
|
2020-03-06 14:59:51 -05:00
|
|
|
dtype = jnp.result_type(default, *choicelist)
|
2020-05-20 01:43:48 -03:00
|
|
|
return np.select(condlist,
|
|
|
|
[np.asarray(x, dtype=dtype) for x in choicelist],
|
|
|
|
np.asarray(default, dtype=dtype))
|
2022-06-16 13:59:53 -07:00
|
|
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
2022-06-14 11:20:37 -07:00
|
|
|
self._CheckAgainstNumpy(np_fun, jnp.select, args_maker,
|
|
|
|
check_dtypes=False)
|
|
|
|
self._CompileAndCheck(jnp.select, args_maker,
|
|
|
|
rtol={np.float64: 1e-7, np.complex128: 1e-7})
|
2019-06-24 09:27:01 -04:00
|
|
|
|
2019-02-06 09:23:34 -08:00
|
|
|
def testIssue330(self):
|
2020-03-06 14:59:51 -05:00
|
|
|
x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash
|
2019-02-06 09:23:34 -08:00
|
|
|
self.assertEqual(x[0, 0], 1)
|
|
|
|
|
2019-02-12 19:56:00 -08:00
|
|
|
def testScalarDtypePromotion(self):
|
2020-05-20 01:43:48 -03:00
|
|
|
orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype
|
2020-03-06 14:59:51 -05:00
|
|
|
jax_numpy_result = (1 + jnp.eye(1, dtype=jnp.float32)).dtype
|
2019-02-12 19:56:00 -08:00
|
|
|
self.assertEqual(orig_numpy_result, jax_numpy_result)
|
|
|
|
|
2019-02-13 08:52:42 -08:00
|
|
|
def testSymmetrizeDtypePromotion(self):
|
2020-05-20 01:43:48 -03:00
|
|
|
x = np.eye(3, dtype=np.float32)
|
2019-02-13 08:52:42 -08:00
|
|
|
orig_numpy_result = ((x + x.T) / 2).dtype
|
|
|
|
|
2020-03-06 14:59:51 -05:00
|
|
|
x = jnp.eye(3, dtype=jnp.float32)
|
2019-02-13 08:52:42 -08:00
|
|
|
jax_numpy_result = ((x + x.T) / 2).dtype
|
|
|
|
self.assertEqual(orig_numpy_result, jax_numpy_result)
|
|
|
|
|
2020-03-17 22:07:53 -07:00
|
|
|
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because
|
|
|
|
# introducing the convention 0 * inf = 0 leads to silently wrong results in
|
|
|
|
# some cases. See this comment for details:
|
|
|
|
# https://github.com/google/jax/issues/1052#issuecomment-514083352
|
|
|
|
# def testIssue347(self):
|
|
|
|
# # https://github.com/google/jax/issues/347
|
|
|
|
# def test_fail(x):
|
|
|
|
# x = jnp.sqrt(jnp.sum(x ** 2, axis=1))
|
|
|
|
# ones = jnp.ones_like(x)
|
|
|
|
# x = jnp.where(x > 0.5, x, ones)
|
|
|
|
# return jnp.sum(x)
|
|
|
|
# x = jnp.array([[1, 2], [3, 4], [0, 0]], dtype=jnp.float64)
|
2021-09-13 16:00:22 -04:00
|
|
|
# result = jax.grad(test_fail)(x)
|
2020-05-20 01:43:48 -03:00
|
|
|
# assert not np.any(np.isnan(result))
|
2019-02-15 07:04:57 -08:00
|
|
|
|
2019-02-27 07:42:26 -08:00
|
|
|
def testIssue453(self):
|
|
|
|
# https://github.com/google/jax/issues/453
|
2020-05-20 01:43:48 -03:00
|
|
|
a = np.arange(6) + 1
|
2020-03-06 14:59:51 -05:00
|
|
|
ans = jnp.reshape(a, (3, 2), order='F')
|
2020-05-20 01:43:48 -03:00
|
|
|
expected = np.reshape(a, (3, 2), order='F')
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-02-27 07:42:26 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
dtype=[int, float, bool, complex],
|
|
|
|
op=["atleast_1d", "atleast_2d", "atleast_3d"],
|
|
|
|
)
|
2021-12-09 09:47:21 -08:00
|
|
|
def testAtLeastNdLiterals(self, dtype, op):
|
2019-04-29 02:57:02 -05:00
|
|
|
# Fixes: https://github.com/google/jax/issues/634
|
2021-12-09 09:47:21 -08:00
|
|
|
np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype])
|
2020-03-06 14:59:51 -05:00
|
|
|
jnp_fun = lambda arg: getattr(jnp, op)(arg)
|
2021-12-09 09:47:21 -08:00
|
|
|
args_maker = lambda: [dtype(2)]
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
|
|
|
self._CompileAndCheck(jnp_fun, args_maker)
|
2019-04-29 02:57:02 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
shape=[(0,), (5,), (10,)],
|
|
|
|
dtype=int_dtypes,
|
|
|
|
weights=[True, False],
|
|
|
|
minlength=[0, 20],
|
|
|
|
length=[None, 8],
|
|
|
|
)
|
2020-12-02 17:00:25 -08:00
|
|
|
def testBincount(self, shape, dtype, weights, minlength, length):
|
2021-10-15 12:31:17 -07:00
|
|
|
rng = jtu.rand_default(self.rng())
|
2020-05-07 13:17:43 -07:00
|
|
|
args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None))
|
|
|
|
|
2021-10-15 12:31:17 -07:00
|
|
|
def np_fun(x, *args):
|
|
|
|
x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero.
|
|
|
|
out = np.bincount(x, *args, minlength=minlength)
|
|
|
|
if length and length > out.size:
|
|
|
|
return np.pad(out, (0, length - out.size))
|
|
|
|
return out[:length]
|
2020-05-07 13:17:43 -07:00
|
|
|
jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)
|
|
|
|
|
2021-10-15 12:31:17 -07:00
|
|
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
2020-05-07 13:17:43 -07:00
|
|
|
if length is not None:
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(jnp_fun, args_maker)
|
2020-05-07 13:17:43 -07:00
|
|
|
|
|
|
|
def testBincountNegative(self):
|
|
|
|
# Test that jnp.bincount ignores negative values.
|
|
|
|
x_rng = jtu.rand_int(self.rng(), -100, 100)
|
|
|
|
w_rng = jtu.rand_uniform(self.rng())
|
|
|
|
shape = (1000,)
|
|
|
|
x = x_rng(shape, 'int32')
|
|
|
|
w = w_rng(shape, 'float32')
|
|
|
|
|
2020-05-20 01:43:48 -03:00
|
|
|
xn = np.array(x)
|
2020-05-07 13:17:43 -07:00
|
|
|
xn[xn < 0] = 0
|
2020-05-20 01:43:48 -03:00
|
|
|
wn = np.array(w)
|
|
|
|
np_result = np.bincount(xn[xn >= 0], wn[xn >= 0])
|
2020-05-07 13:17:43 -07:00
|
|
|
jnp_result = jnp.bincount(x, w)
|
2020-05-20 01:43:48 -03:00
|
|
|
self.assertAllClose(np_result, jnp_result, check_dtypes=False)
|
2020-05-07 13:17:43 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
|
|
|
|
input=[
|
|
|
|
3,
|
|
|
|
[3],
|
|
|
|
[np.array(3)],
|
|
|
|
[np.array([3])],
|
|
|
|
[[np.array(3)]],
|
|
|
|
[[np.array([3])]],
|
|
|
|
[3, 4, 5],
|
|
|
|
[
|
|
|
|
[np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)],
|
|
|
|
[np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3],
|
|
|
|
],
|
|
|
|
[np.array([1, 2, 3]), np.array([2, 3, 4]), 10],
|
|
|
|
[np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)],
|
|
|
|
[[np.array([1, 2, 3])], [np.array([2, 3, 4])]],
|
|
|
|
],
|
|
|
|
)
|
2020-01-29 11:55:53 -05:00
|
|
|
def testBlock(self, input):
|
|
|
|
args_maker = lambda: [input]
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CheckAgainstNumpy(np.block, jnp.block, args_maker)
|
|
|
|
self._CompileAndCheck(jnp.block, args_maker)
|
2019-04-29 02:57:02 -05:00
|
|
|
|
2019-04-01 15:49:12 -07:00
|
|
|
def testLongLong(self):
|
2021-09-13 16:00:22 -04:00
|
|
|
self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7)))
|
2019-04-01 15:49:12 -07:00
|
|
|
|
2020-12-08 13:03:30 -08:00
|
|
|
@jtu.ignore_warning(category=UserWarning,
|
|
|
|
message="Explicitly requested dtype.*")
|
2022-06-14 11:20:37 -07:00
|
|
|
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
|
2019-04-06 12:52:47 -07:00
|
|
|
def testArange(self):
|
|
|
|
# test cases inspired by dask tests at
|
2021-06-18 08:55:08 +03:00
|
|
|
# https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92
|
2021-12-09 09:47:21 -08:00
|
|
|
np_arange = jtu.with_jax_dtype_defaults(np.arange)
|
2020-03-06 14:59:51 -05:00
|
|
|
self.assertAllClose(jnp.arange(77),
|
2021-12-09 09:47:21 -08:00
|
|
|
np_arange(77))
|
2020-03-06 14:59:51 -05:00
|
|
|
self.assertAllClose(jnp.arange(2, 13),
|
2021-12-09 09:47:21 -08:00
|
|
|
np_arange(2, 13))
|
2020-03-06 14:59:51 -05:00
|
|
|
self.assertAllClose(jnp.arange(4, 21, 9),
|
2021-12-09 09:47:21 -08:00
|
|
|
np_arange(4, 21, 9))
|
2020-03-06 14:59:51 -05:00
|
|
|
self.assertAllClose(jnp.arange(53, 5, -3),
|
2021-12-09 09:47:21 -08:00
|
|
|
np_arange(53, 5, -3))
|
2020-03-06 14:59:51 -05:00
|
|
|
self.assertAllClose(jnp.arange(77, dtype=float),
|
2021-12-09 09:47:21 -08:00
|
|
|
np_arange(77, dtype=float))
|
2020-03-06 14:59:51 -05:00
|
|
|
self.assertAllClose(jnp.arange(2, 13, dtype=int),
|
2021-12-09 09:47:21 -08:00
|
|
|
np_arange(2, 13, dtype=int))
|
2020-03-06 14:59:51 -05:00
|
|
|
self.assertAllClose(jnp.arange(0, 1, -0.5),
|
2021-12-09 09:47:21 -08:00
|
|
|
np_arange(0, 1, -0.5))
|
2019-04-06 12:52:47 -07:00
|
|
|
|
2020-03-06 14:59:51 -05:00
|
|
|
self.assertRaises(TypeError, lambda: jnp.arange())
|
2019-04-06 12:52:47 -07:00
|
|
|
|
2020-03-06 14:59:51 -05:00
|
|
|
# test that jnp.arange(N) doesn't instantiate an ndarray
|
2020-05-20 01:43:48 -03:00
|
|
|
self.assertNotEqual(type(jnp.arange(77)), type(np.arange(77)))
|
|
|
|
self.assertEqual(type(jnp.arange(77)), type(lax.iota(np.int32, 77)))
|
2019-04-06 12:52:47 -07:00
|
|
|
|
2020-03-06 14:59:51 -05:00
|
|
|
# test that jnp.arange(N, dtype=int32) doesn't instantiate an ndarray
|
2020-03-09 10:08:56 +01:00
|
|
|
self.assertNotEqual(type(jnp.arange(77, dtype=jnp.int32)),
|
2021-12-09 09:47:21 -08:00
|
|
|
type(np.arange(77, dtype=np.int32)))
|
2020-03-09 10:08:56 +01:00
|
|
|
self.assertEqual(type(jnp.arange(77, dtype=jnp.int32)),
|
2021-12-09 09:47:21 -08:00
|
|
|
type(lax.iota(np.int32, 77)))
|
2019-06-25 13:02:09 -04:00
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def testArangeJit(self):
|
2021-09-13 16:00:22 -04:00
|
|
|
ans = jax.jit(lambda: jnp.arange(5))()
|
2021-12-09 09:47:21 -08:00
|
|
|
expected = jtu.with_jax_dtype_defaults(np.arange)(5)
|
2020-07-30 12:59:36 -07:00
|
|
|
self.assertAllClose(ans, expected)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(args=[(5,), (0, 5)])
|
2022-04-08 23:33:47 +01:00
|
|
|
def testArangeJaxpr(self, args):
|
|
|
|
jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))()
|
|
|
|
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
|
|
|
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p)
|
|
|
|
|
2019-06-09 20:18:18 -07:00
|
|
|
def testIssue830(self):
|
2020-03-06 14:59:51 -05:00
|
|
|
a = jnp.arange(4, dtype=jnp.complex64)
|
|
|
|
self.assertEqual(a.dtype, jnp.complex64)
|
2019-06-09 20:18:18 -07:00
|
|
|
|
2019-05-17 12:48:46 -07:00
|
|
|
def testIssue728(self):
|
2022-12-01 13:56:42 -08:00
|
|
|
np_eye = jtu.with_jax_dtype_defaults(np.eye)
|
|
|
|
self.assertAllClose(jnp.eye(5000), np_eye(5000))
|
|
|
|
self.assertEqual(0, np.sum(jnp.eye(1050) - np_eye(1050)))
|
2019-05-17 12:48:46 -07:00
|
|
|
|
2019-05-21 21:37:52 -07:00
|
|
|
def testIssue746(self):
|
2020-03-06 14:59:51 -05:00
|
|
|
jnp.arange(12).reshape(3, 4) # doesn't crash
|
2019-05-21 21:37:52 -07:00
|
|
|
|
2019-05-24 11:07:08 -04:00
|
|
|
def testIssue764(self):
  # Gradient of sum(tanh) evaluated at inputs between 190 and 200.
  grad_fn = jax.grad(lambda v: jnp.sum(jnp.tanh(v)))
  inputs = jnp.linspace(190, 200, 4)
  # Expected values computed with autograd in float64 precision.
  want = np.array(
      [3.71669453e-165, 4.72999108e-168, 6.01954653e-171, 7.66067839e-174],
      np.float64)
  self.assertAllClose(grad_fn(inputs), want, check_dtypes=False)
|
|
|
|
|
2019-05-28 10:30:58 -04:00
|
|
|
def testIssue776(self):
  """Tests that the scatter-add transpose rule instantiates symbolic zeros."""
  scatter_idx = np.array([2, 4, 5])

  def f(u):
    base = jnp.ones_like(u, shape=10)
    y = base.at[scatter_idx].add(u)
    # The transpose rule for lax.tie_in returns a symbolic zero for its first
    # argument.
    return lax.tie_in(y, 7.)

  grads = jax.grad(f)(np.ones(3,))
  self.assertAllClose(np.zeros(3,), grads)
|
2019-05-28 10:30:58 -04:00
|
|
|
|
2020-03-17 22:07:53 -07:00
|
|
|
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because this
|
|
|
|
# is a numerical stability issue that should be solved with a custom jvp rule
|
|
|
|
# of the sigmoid function being differentiated here, not by safe_mul.
|
|
|
|
# def testIssue777(self):
|
2020-05-20 01:43:48 -03:00
|
|
|
# x = jnp.linspace(-200, 0, 4, dtype=np.float32)
|
2021-09-13 16:00:22 -04:00
|
|
|
# f = jax.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x))))
|
2020-06-01 17:19:23 -04:00
|
|
|
# self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32))
|
2019-07-08 09:29:53 -04:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  dtype=float_dtypes,
  op=("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan",
      "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp",
      "log", "expm1", "log1p"),
)
def testMathSpecialFloatValues(self, op, dtype):
  # NumPy emits RuntimeWarnings on these inputs; wrap it to ignore them, in
  # the same nesting order as each filter is applied.
  np_op = getattr(np, op)
  for pattern in ("invalid value.*", "divide by zero.*", "overflow.*"):
    np_op = jtu.ignore_warning(category=RuntimeWarning,
                               message=pattern)(np_op)

  jnp_op = getattr(jnp, op)
  dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type
  fmax = jnp.finfo(dtype).max
  special_values = (np.nan, -np.inf, -100., -2., -1., 0., 1., 2., 100.,
                    np.inf, fmax, np.sqrt(fmax), np.sqrt(fmax) * 2.)
  for raw in special_values:
    on_tpu_trig = (op in ("sin", "cos", "tan")
                   and jtu.device_under_test() == "tpu")
    if on_tpu_trig:
      continue  # TODO(b/132196789): fix and reenable.
    x = dtype(raw)
    want = np_op(x)
    got = jnp_op(x)
    tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7})
    self.assertAllClose(want, got, atol=tol, rtol=tol)
|
2019-05-29 12:51:24 -04:00
|
|
|
|
2019-07-01 14:55:39 -04:00
|
|
|
def testIssue956(self):
  # jnp.ndarray is not a constructor; calling it must raise TypeError.
  with self.assertRaises(TypeError):
    jnp.ndarray((1, 1))
|
2019-07-01 14:55:39 -04:00
|
|
|
|
2019-07-23 16:18:10 -04:00
|
|
|
def testIssue967(self):
  # A non-integer shape passed to jnp.zeros must raise TypeError.
  with self.assertRaises(TypeError):
    jnp.zeros(1.5)
|
2019-02-05 08:58:57 -05:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shape=[(5,), (10, 5), (4, 10)],
  dtype=number_dtypes,
  rowvar=[True, False],
)
@jax.default_matmul_precision("float32")
def testCorrCoef(self, shape, dtype, rowvar):
  rng = jtu.rand_default(self.rng())

  def args_maker():
    # Resample until the overall standard deviation is not close to zero.
    while True:
      x = rng(shape, dtype)
      if not np.any(np.isclose(np.std(x), 0.0)):
        return (x,)

  silence_invalid = jtu.ignore_warning(
      category=RuntimeWarning, message="invalid value encountered.*")
  np_fun = silence_invalid(partial(np.corrcoef, rowvar=rowvar))
  jnp_fun = partial(jnp.corrcoef, rowvar=rowvar)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2019-07-28 15:17:23 -04:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(dtype=dtype, end_dtype=end_dtype, begin_dtype=begin_dtype,
        shape=shape, begin_shape=begin_shape, end_shape=end_shape)
   for dtype in number_dtypes
   for end_dtype in [None, dtype]
   for begin_dtype in [None, dtype]
   for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]
   for begin_shape in (
       [None] if begin_dtype is None
       else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE])
   for end_shape in (
       [None] if end_dtype is None
       else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE])
  ],
)
def testEDiff1d(self, shape, dtype, end_shape, end_dtype, begin_shape,
                begin_dtype):
  rng = jtu.rand_default(self.rng())

  def optional_arg(s, dt):
    # A None dtype means "omit this argument" (pass None through).
    return None if dt is None else rng(s, dt)

  def args_maker():
    return [rng(shape, dtype),
            optional_arg(end_shape, end_dtype),
            optional_arg(begin_shape, begin_dtype)]

  def np_fun(x, to_end, to_begin):
    return np.ediff1d(x, to_end, to_begin)

  def jnp_fun(x, to_end, to_begin):
    return jnp.ediff1d(x, to_end, to_begin)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-05-04 16:55:47 -04:00
|
|
|
|
|
|
|
def testEDiff1dWithDtypeCast(self):
  # to_end is int32 while the input and to_begin are float32.
  rng = jtu.rand_default(self.rng())
  shape = jtu.NUMPY_SCALAR_SHAPE
  x_dtype, end_dtype = jnp.float32, jnp.int32

  def args_maker():
    return [rng(shape, x_dtype), rng(shape, end_dtype), rng(shape, x_dtype)]

  def np_fun(x, to_end, to_begin):
    return np.ediff1d(x, to_end, to_begin)

  def jnp_fun(x, to_end, to_begin):
    return jnp.ediff1d(x, to_end, to_begin)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-05-04 16:55:47 -04:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  shapes=[(), (5,), (5, 3)],
  dtype=number_dtypes,
  indexing=['xy', 'ij'],
  sparse=[True, False],
)
def testMeshGrid(self, shapes, dtype, indexing, sparse):
  rng = jtu.rand_default(self.rng())
  # Each entry of `shapes` is the length of one 1-D coordinate vector.
  arg_shapes = [(length,) for length in shapes]
  arg_dtypes = [dtype for _ in shapes]
  args_maker = self._GetArgsMaker(rng, arg_shapes, arg_dtypes)
  grid_kwargs = dict(indexing=indexing, sparse=sparse)
  np_fun = partial(np.meshgrid, **grid_kwargs)
  jnp_fun = partial(jnp.meshgrid, **grid_kwargs)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2019-08-15 20:25:32 -04:00
|
|
|
|
2021-03-27 16:56:32 +09:00
|
|
|
def testMgrid(self):
  # wrap indexer for appropriate dtype defaults.
  np_mgrid = _indexer_with_default_outputs(np.mgrid)
  assert_exact = partial(self.assertAllClose, atol=0, rtol=0)

  # Integer-step keys must match exactly. np.s_ yields the same key object
  # that the literal subscript syntax would pass to __getitem__.
  exact_keys = [
      np.s_[()], np.s_[:4], np.s_[:4,],
      np.s_[:5, :5], np.s_[:3, :2], np.s_[1:4:2], np.s_[1:5:3, :5],
      np.s_[:3, :2, :5], np.s_[:3:2, :2, :5],
      np.s_[:],  # corner case
  ]
  for key in exact_keys:
    assert_exact(np_mgrid[key], jnp.mgrid[key])
  assert_exact(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])())

  # Complex steps (point counts) and non-integer steps go through float
  # arithmetic, so jnp and np might differ slightly.
  approx_keys = [
      np.s_[-1:1:5j], np.s_[3:4:7j], np.s_[1:6:8j, 2:4],
      np.s_[0:3.5:0.5], np.s_[1.3:4.2:0.3],
  ]
  for key in approx_keys:
    self.assertAllClose(np_mgrid[key], jnp.mgrid[key], atol=1e-6, rtol=1e-6)

  # abstract tracer value for jnp.mgrid slice
  with self.assertRaisesRegex(core.ConcretizationTypeError,
                              "slice start of jnp.mgrid"):
    jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2)
|
2021-03-27 16:56:32 +09:00
|
|
|
|
2021-04-05 10:35:45 +09:00
|
|
|
def testOgrid(self):
  # wrap indexer for appropriate dtype defaults.
  np_ogrid = _indexer_with_default_outputs(np.ogrid)

  def assertListOfArraysEqual(xs, ys):
    self.assertIsInstance(xs, list)
    self.assertIsInstance(ys, list)
    self.assertEqual(len(xs), len(ys))
    for x, y in zip(xs, ys):
      self.assertArraysEqual(x, y)

  # Single-slice keys produce one array (np.s_[:] is the corner case).
  for key in (np.s_[:5], np.s_[1:7:2], np.s_[:]):
    self.assertArraysEqual(np_ogrid[key], jnp.ogrid[key])
  self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])())

  # Tuple keys produce a list of arrays.
  for key in (np.s_[:5,], np.s_[0:5, 1:3], np.s_[1:3:2, 2:9:3],
              np.s_[:5, :9, :11]):
    assertListOfArraysEqual(np_ogrid[key], jnp.ogrid[key])

  # Complex number steps and non-integer steps are compared approximately.
  for key in (np.s_[-1:1:5j], np.s_[0:3.5:0.3], np.s_[1.2:4.8:0.24]):
    self.assertAllClose(np_ogrid[key], jnp.ogrid[key], atol=1e-6, rtol=1e-6)

  # abstract tracer value for ogrid slice
  with self.assertRaisesRegex(core.ConcretizationTypeError,
                              "slice start of jnp.ogrid"):
    jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2)
|
2021-04-05 10:35:45 +09:00
|
|
|
|
2021-05-01 01:05:22 +02:00
|
|
|
def testR_(self):
  a = np.arange(6).reshape((2, 3))
  # Each key is the tuple that `np.r_[k0, k1, ...]` passes to __getitem__.
  exact_keys = [
      (np.array([1, 2, 3]), 0, 0, np.array([4, 5, 6])),
      ('-1', a, a),
      ('0,2', [1, 2, 3], [4, 5, 6]),
      ('0,2,0', [1, 2, 3], [4, 5, 6]),
      ('1,2,0', [1, 2, 3], [4, 5, 6]),
      # negative 1d axis start
      ('0,4,-1', [1, 2, 3], [4, 5, 6]),
      ('0,4,-2', [1, 2, 3], [4, 5, 6]),
  ]
  for key in exact_keys:
    self.assertArraysEqual(np.r_[key], jnp.r_[key])

  # matrix directives
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
    for directive in ('r', 'c'):
      key = (directive, [1, 2, 3], [4, 5, 6])
      self.assertArraysEqual(np.r_[key], jnp.r_[key])

  # bad directive
  with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
    jnp.r_["asdfgh", [1, 2, 3]]
  # abstract tracer value for r_ slice
  with self.assertRaisesRegex(core.ConcretizationTypeError,
                              "slice start of jnp.r_"):
    jax.jit(lambda a, b: jnp.r_[a:b])(0, 2)

  # wrap indexer for appropriate dtype defaults.
  np_r_ = _indexer_with_default_outputs(np.r_)
  assert_close = partial(self.assertAllClose, atol=1e-6, rtol=1e-6)

  # Complex number steps
  assert_close(np_r_[-1:1:6j], jnp.r_[-1:1:6j])
  with jax.numpy_dtype_promotion('standard'):  # Requires dtype promotion.
    assert_close(np_r_[-1:1:6j, [0]*3, 5, 6], jnp.r_[-1:1:6j, [0]*3, 5, 6])
  # Non-integer steps
  assert_close(np_r_[1.2:4.8:0.24], jnp.r_[1.2:4.8:0.24])
|
|
|
|
|
|
|
|
def testC_(self):
  a = np.arange(6).reshape((2, 3))
  # Each key is the tuple that `np.c_[k0, k1, ...]` passes to __getitem__.
  exact_keys = [
      (np.array([1, 2, 3]), np.array([4, 5, 6])),
      (np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])),
      ('-1', a, a),
      ('0,2', [1, 2, 3], [4, 5, 6]),
      ('0,2,0', [1, 2, 3], [4, 5, 6]),
      ('1,2,0', [1, 2, 3], [4, 5, 6]),
      # negative 1d axis start
      ('0,4,-1', [1, 2, 3], [4, 5, 6]),
      ('0,4,-2', [1, 2, 3], [4, 5, 6]),
  ]
  for key in exact_keys:
    self.assertArraysEqual(np.c_[key], jnp.c_[key])

  # matrix directives, avoid numpy deprecation warning
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
    for directive in ('r', 'c'):
      key = (directive, [1, 2, 3], [4, 5, 6])
      self.assertArraysEqual(np.c_[key], jnp.c_[key])

  # bad directive
  with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
    jnp.c_["asdfgh", [1, 2, 3]]
  # abstract tracer value for c_ slice
  with self.assertRaisesRegex(core.ConcretizationTypeError,
                              "slice start of jnp.c_"):
    jax.jit(lambda a, b: jnp.c_[a:b])(0, 2)

  # wrap indexer for appropriate dtype defaults.
  np_c_ = _indexer_with_default_outputs(np.c_)
  assert_close = partial(self.assertAllClose, atol=1e-6, rtol=1e-6)
  # Complex number steps
  assert_close(np_c_[-1:1:6j], jnp.c_[-1:1:6j])
  # Non-integer steps
  assert_close(np_c_[1.2:4.8:0.24], jnp.c_[1.2:4.8:0.24])
|
|
|
|
|
2021-07-09 02:35:28 +05:30
|
|
|
def testS_(self):
  # jnp.s_ must build the same slice object as np.s_.
  expected = np.s_[1:2:20]
  self.assertEqual(expected, jnp.s_[1:2:20])
|
|
|
|
|
|
|
|
def testIndex_exp(self):
  # jnp.index_exp must build the same index tuple as np.index_exp.
  expected = np.index_exp[5:3:2j]
  self.assertEqual(expected, jnp.index_exp[5:3:2j])
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  start_shape=[(), (2,), (2, 2)],
  stop_shape=[(), (2,), (2, 2)],
  num=[0, 1, 2, 5, 20],
  endpoint=[True, False],
  retstep=[True, False],
  # floating-point compute between jitted platforms and non-jit + rounding
  # cause unavoidable variation in integer truncation for some inputs, so
  # we currently only test inexact 'dtype' arguments.
  dtype=inexact_dtypes + [None,],
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype):
  rng = jtu.rand_default(self.rng())
  # relax default tolerances slightly
  tol = jtu.tolerance(dtype if dtype else np.float32) * 10
  args_maker = self._GetArgsMaker(rng, [start_shape, stop_shape],
                                  [dtype, dtype])
  start, stop = args_maker()
  ndim = len(np.shape(start + stop))

  def make_ops(axis):
    # Builds the (numpy, jax) pair for one axis; binding `axis` here avoids
    # closing over the loop variable.
    kw = dict(endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
    return (lambda start, stop: np.linspace(start, stop, num, **kw),
            lambda start, stop: jnp.linspace(start, stop, num, **kw))

  for axis in range(-ndim, ndim):
    np_op, jnp_op = make_ops(axis)
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
                            check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_op, args_maker,
                          check_dtypes=False, atol=tol, rtol=tol)
|
2019-11-09 00:16:18 -08:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(dtype=number_dtypes)
def testLinspaceEndpoints(self, dtype):
  """Regression test for Issue #3014."""
  rng = jtu.rand_default(self.rng())
  endpoints = rng((2,), dtype)
  lo, hi = endpoints
  out = jnp.linspace(lo, hi, 10, dtype=dtype)
  # The first and last samples must reproduce the endpoints exactly.
  ends = out[np.array([0, -1])]
  self.assertAllClose(ends, endpoints, rtol=0, atol=0)
|
2020-05-08 21:01:57 -07:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  start_shape=[(), (2,), (2, 2)],
  stop_shape=[(), (2,), (2, 2)],
  num=[0, 1, 2, 5, 20],
  endpoint=[True, False],
  base=[10.0, 2, np.e],
  # skip 16-bit floats due to insufficient precision for the test.
  dtype=jtu.dtypes.inexact + [None,],
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLogspace(self, start_shape, stop_shape, num,
                 endpoint, base, dtype):
  accel_int32 = (dtype in int_dtypes
                 and jtu.device_under_test() in ("gpu", "tpu")
                 and not config.x64_enabled)
  if accel_int32:
    raise unittest.SkipTest("GPUx32 truncated exponentiation"
                            " doesn't exactly match other platforms.")
  rng = jtu.rand_default(self.rng())
  # relax default tolerances slightly
  tol = {np.float32: 1e-2, np.float64: 1e-6, np.complex64: 1e-3,
         np.complex128: 1e-6}
  args_maker = self._GetArgsMaker(rng, [start_shape, stop_shape],
                                  [dtype, dtype])
  start, stop = args_maker()
  ndim = len(np.shape(start + stop))

  def make_ops(axis):
    # Builds the (numpy, jax) pair for one axis.
    kw = dict(endpoint=endpoint, base=base, dtype=dtype, axis=axis)

    @jtu.ignore_warning(category=RuntimeWarning,
                        message="overflow encountered in power")
    def np_op(start, stop):
      return np.logspace(start, stop, num, **kw)

    def jnp_op(start, stop):
      return jnp.logspace(start, stop, num, **kw)

    return np_op, jnp_op

  for axis in range(-ndim, ndim):
    np_op, jnp_op = make_ops(axis)
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
                            check_dtypes=False, tol=tol)
    if dtype in (inexact_dtypes + [None,]):
      # Why do compiled and op-by-op float16 np.power numbers differ
      # slightly more than expected?
      atol = {np.float16: 1e-2}
      self._CompileAndCheck(jnp_op, args_maker,
                            check_dtypes=False, atol=atol, rtol=tol)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(start_shape=start_shape, stop_shape=stop_shape, axis=axis)
   for start_shape in [(), (2,), (2, 2)]
   for stop_shape in [(), (2,), (2, 2)]
   for axis in range(-max(len(start_shape), len(stop_shape)),
                     max(len(start_shape), len(stop_shape)))
  ],
  num=[0, 1, 2, 5, 20],
  endpoint=[True, False],
  # NB: numpy's geomspace gives nonsense results on integer types
  dtype=inexact_dtypes + [None,],
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testGeomspace(self, start_shape, stop_shape, num,
                  endpoint, dtype, axis):
  rng = jtu.rand_default(self.rng())
  # relax default tolerances slightly
  tol = {dtypes.bfloat16: 2e-2, np.float16: 4e-3, np.float32: 2e-3,
         np.float64: 1e-14, np.complex64: 2e-3, np.complex128: 1e-14}

  def args_maker():
    """Test the set of inputs np.geomspace is well-defined on."""
    start, stop = self._GetArgsMaker(rng,
                                     [start_shape, stop_shape],
                                     [dtype, dtype])()
    # np.geomspace can't handle differently ranked tensors
    # w. negative numbers!
    start, stop = jnp.broadcast_arrays(start, stop)
    if dtype not in complex_dtypes:
      # to avoid NaNs, non-complex start and stop cannot
      # differ in sign, elementwise
      start = start * jnp.sign(start) * jnp.sign(stop)
    return start, stop

  args_maker()  # result unused; the call still draws from rng
  is_bf16 = dtype == jnp.bfloat16

  def jnp_op(start, stop):
    return jnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype,
                         axis=axis)

  def np_op(start, stop):
    # For bfloat16, compute the NumPy reference in float32 and cast back.
    if is_bf16:
      start, stop = start.astype(np.float32), stop.astype(np.float32)
    ref_dtype = np.float32 if is_bf16 else dtype
    out = np.geomspace(start, stop, num, endpoint=endpoint,
                       dtype=ref_dtype, axis=axis)
    return out.astype(dtype)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
                          check_dtypes=False, tol=tol)
  if dtype in (inexact_dtypes + [None,]):
    self._CompileAndCheck(jnp_op, args_maker,
                          check_dtypes=False, atol=tol, rtol=tol)
|
2019-11-09 00:16:18 -08:00
|
|
|
|
2019-08-23 17:05:32 -07:00
|
|
|
def testDisableNumpyRankPromotionBroadcasting(self):
  # Exercises each jax_numpy_rank_promotion mode via the global flag and
  # restores the previous flag value after each mode, even on failure.

  # Read the previous value *before* any `try`, so a failure while reading
  # cannot leave `prev_flag` unbound when a `finally` clause runs.
  prev_flag = config._read('jax_numpy_rank_promotion')

  def run_with_mode(mode, check):
    try:
      FLAGS.jax_numpy_rank_promotion = mode
      check()
    finally:
      FLAGS.jax_numpy_rank_promotion = prev_flag

  def check_allow():
    jnp.ones(2) + jnp.ones((1, 2))  # works just fine

  def check_raise():
    self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2)))
    jnp.ones(2) + 3  # don't want to raise for scalars

  def check_warn():
    self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on "
                          r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2)))
    jnp.ones(2) + 3  # don't want to warn for scalars

  run_with_mode("allow", check_allow)
  run_with_mode("raise", check_raise)
  run_with_mode("warn", check_warn)
|
|
|
|
|
2021-08-10 06:48:55 -07:00
|
|
|
@unittest.skip("Test fails on CI, perhaps due to JIT caching")
def testDisableNumpyRankPromotionBroadcastingDecorator(self):
  # Adds a rank-1 array to a rank-2 array, which needs implicit rank promotion.
  mismatched_add = lambda: jnp.ones(2) + jnp.ones((1, 2))

  with jax.numpy_rank_promotion("allow"):
    mismatched_add()  # works just fine

  with jax.numpy_rank_promotion("raise"):
    self.assertRaises(ValueError, mismatched_add)
    jnp.ones(2) + 3  # don't want to raise for scalars

  with jax.numpy_rank_promotion("warn"):
    self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on "
                          r"shapes \(2,\) \(1, 2\).*", mismatched_add)
    jnp.ones(2) + 3  # don't want to warn for scalars
|
2021-03-23 20:58:52 -07:00
|
|
|
|
2019-09-02 07:55:25 -07:00
|
|
|
def testStackArrayArgument(self):
  # tests https://github.com/google/jax/issues/1271
  # stack/concatenate given a single array (not a sequence) under jit.
  @jax.jit
  def stacked(x):
    return jnp.stack(x)

  @jax.jit
  def concatenated(x):
    return jnp.concatenate(x)

  stacked(np.zeros(2))  # doesn't crash
  concatenated(np.zeros((2, 2)))  # doesn't crash
|
2019-09-02 07:55:25 -07:00
|
|
|
|
2019-10-15 15:01:52 -04:00
|
|
|
def testReluGradientConstants(self):
  # This is a regression test that verifies that constants associated with the
  # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the
  # outermost jaxpr. This was producing some large materialized constants for
  # every relu activation in a model.
  relu_sum_grad = jax.grad(lambda z: jnp.sum(jnp.maximum(z, 0.)))

  def body(i, carry):
    x, acc = carry
    return x, acc + relu_sum_grad(x)

  looped = lambda y: lax.fori_loop(0, 5, body, (y, y))
  closed_jaxpr = jax.make_jaxpr(looped)(np.zeros((3, 4), np.float32))
  suspicious = np.full((3, 4), 2., dtype=np.float32)
  self.assertFalse(
      any(np.array_equal(c, suspicious) for c in closed_jaxpr.consts))
|
2019-10-15 15:01:52 -04:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(from_shape=from_shape, to_shape=to_shape)
   for from_shape, to_shape in [
       [(1, 3), (4, 3)],
       [(3,), (2, 1, 3)],
       [(3,), (3, 3)],
       [(1,), (3,)],
       [(1,), 3],
   ]
  ],
)
def testBroadcastTo(self, from_shape, to_shape):
  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32])

  def np_op(x):
    return np.broadcast_to(x, to_shape)

  def jnp_op(x):
    return jnp.broadcast_to(x, to_shape)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
|
2019-10-17 23:23:08 +00:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shapes=shapes, broadcasted_shape=broadcasted_shape)
   for shapes, broadcasted_shape in [
       [[], ()],
       [[()], ()],
       [[(1, 3), (4, 3)], (4, 3)],
       [[(3,), (2, 1, 3)], (2, 1, 3)],
       [[(3,), (3, 3)], (3, 3)],
       [[(1,), (3,)], (3,)],
       [[(1,), 3], (3,)],
       [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
       [[[1], [0, 1]], (0, 1)],
       [[(1,), np.array([0, 1])], (0, 1)],
   ]
  ],
)
def testBroadcastShapes(self, shapes, broadcasted_shape):
  # Check jnp.broadcast_shapes against the hand-written expectation and
  # against NumPy's own np.broadcast_shapes (available since NumPy 1.20).
  np.testing.assert_equal(jnp.broadcast_shapes(*shapes), broadcasted_shape)
  np.testing.assert_equal(np.broadcast_shapes(*shapes), broadcasted_shape)
|
|
|
|
|
2019-10-17 23:23:08 +00:00
|
|
|
def testBroadcastToIssue1522(self):
  """Broadcasting shape (2, 3) to (1, 3) raises ValueError (issue 1522)."""
  with self.assertRaisesRegex(ValueError, "Incompatible shapes for broadcasting: .*"):
    jnp.broadcast_to(np.ones((2, 3)), (1, 3))
|
2019-08-23 17:05:32 -07:00
|
|
|
|
2019-10-21 15:11:51 -07:00
|
|
|
def testBroadcastToIntIssue1548(self):
  """A Python int can be broadcast to a 2-D shape (issue 1548)."""
  actual = jnp.broadcast_to(1, (3, 2))
  expected = np.ones((3, 2))
  self.assertAllClose(actual, expected, check_dtypes=False)
|
|
|
|
|
2019-11-27 03:17:08 +00:00
|
|
|
def testBroadcastToOnScalar(self):
  """Broadcasting a float to shape () yields each library's array type."""
  for broadcast_fn, array_type in ((jnp.broadcast_to, jax.Array),
                                   (np.broadcast_to, np.ndarray)):
    self.assertIsInstance(broadcast_fn(10.0, ()), array_type)
|
2019-11-27 03:17:08 +00:00
|
|
|
|
2019-11-21 15:30:02 -08:00
|
|
|
def testPrecision(self):
  """Dot-like jnp functions propagate `precision` to the dot they emit."""
  vec = np.ones((2,))
  mat = np.ones((2, 2))
  cube = np.ones((2, 2, 2))
  highest = lax.Precision.HIGHEST

  # Without an explicit precision, none is expected on the dot.
  jtu.assert_dot_precision(None, jnp.dot, vec, vec)

  # Each entry pairs a function with precision=HIGHEST bound and its operands;
  # the order matches the sequence of checks in the original test.
  cases = [
    (partial(jnp.dot, precision=highest), (vec, vec)),
    (partial(jnp.dot, precision=highest), (cube, cube)),
    (partial(jnp.matmul, precision=highest), (mat, mat)),
    (partial(jnp.vdot, precision=highest), (vec, vec)),
    (partial(jnp.tensordot, axes=2, precision=highest), (mat, mat)),
    (partial(jnp.tensordot, axes=(0, 0), precision=highest), (vec, vec)),
    (partial(jnp.tensordot, axes=((0,), (0,)), precision=highest), (vec, vec)),
    (partial(jnp.einsum, 'i,i', precision=highest), (vec, vec)),
    (partial(jnp.einsum, 'ij,ij', precision=highest), (mat, mat)),
    (partial(jnp.inner, precision=highest), (vec, vec)),
  ]
  for fn, operands in cases:
    jtu.assert_dot_precision(highest, fn, *operands)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shape=shape, varargs=spacings, axis=axes)
   for shape in [(10,), (10, 15), (10, 15, 20)]
   for n_axes in range(len(shape))
   for spacings in itertools.combinations(range(1, len(shape) + 1), n_axes)
   for axes in itertools.combinations(range(len(shape)), n_axes)
  ],
  dtype=inexact_dtypes,
)
def testGradient(self, shape, varargs, axis, dtype):
  """jnp.gradient matches np.gradient for integer spacings along `axis`."""
  args_maker = self._GetArgsMaker(jtu.rand_default(self.rng()), [shape], [dtype])

  def jnp_fun(y):
    return jnp.gradient(y, *varargs, axis=axis)

  def np_fun(y):
    return np.gradient(y, *varargs, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
|
2020-01-07 12:34:34 +01:00
|
|
|
|
2020-01-06 20:57:19 -08:00
|
|
|
def testZerosShapeErrors(self):
  """jnp.zeros rejects float shapes and traced shapes with a TypeError.

  See https://github.com/google/jax/issues/1822
  """
  float_shape_msg = "Shapes must be 1D sequences of concrete values of integer type.*"
  with self.assertRaisesRegex(TypeError, float_shape_msg):
    jnp.zeros(1.)

  # Under jit the shape is a tracer; the error should suggest static_argnums.
  traced_shape_msg = (
      r"Shapes must be 1D sequences of concrete values of integer type.*\n"
      "If using `jit`, try using `static_argnums` or applying `jit` to "
      "smaller subfunctions.")
  with self.assertRaisesRegex(TypeError, traced_shape_msg):
    jax.jit(jnp.zeros)(2)
|
2020-01-06 20:57:19 -08:00
|
|
|
|
2020-01-29 16:23:27 -05:00
|
|
|
def testTraceMethod(self):
  """The .trace() method agrees across numpy, jax.Array, and jitted code."""
  x = self.rng().randn(3, 4).astype(jnp.float_)
  expected = x.trace()
  self.assertAllClose(expected, jnp.array(x).trace())
  self.assertAllClose(expected, jax.jit(lambda y: y.trace())(x))
|
2020-01-29 16:23:27 -05:00
|
|
|
|
2020-05-12 17:19:09 -04:00
|
|
|
def testIntegerPowersArePrecise(self):
  """Integer powers of float32 integer values carry no rounding error.

  See https://github.com/google/jax/pull/3036
  """
  # Squares of float32 integers must be exact; this should hold for all
  # integers whose magnitude is below sqrt(2**24).
  x = jnp.arange(-2**12, 2**12, dtype=jnp.int32)
  np.testing.assert_array_equal(jnp.square(x.astype(jnp.float32)), x * x)
  np.testing.assert_array_equal(x.astype(jnp.float32) ** 2, x * x)

  # Similarly for cubes.
  x = jnp.arange(-2**8, 2**8, dtype=jnp.int32)
  np.testing.assert_array_equal(x.astype(jnp.float32) ** 3, x * x * x)

  # Compare jnp's power against numpy for small integer exponents. The
  # left-hand side must be a JAX array; with a numpy array on both sides,
  # numpy would compute both results and JAX would not be tested at all.
  x_np = np.arange(10, dtype=np.float32)
  x_jnp = jnp.asarray(x_np)
  for i in range(10):
    self.assertAllClose(x_jnp ** i, x_np ** i, check_dtypes=False)
|
|
|
|
|
2020-06-28 11:26:48 -04:00
|
|
|
def testToBytes(self):
  """jax.Array.tobytes matches numpy in both C and Fortran order."""
  v = np.arange(12, dtype=np.int32).reshape(3, 4)
  jv = jnp.asarray(v)
  for order in ('C', 'F'):
    self.assertEqual(jv.tobytes(order), v.tobytes(order))
|
|
|
|
|
2023-03-17 09:42:49 -07:00
|
|
|
def testToBytesJitError(self):
  """Calling tobytes() on a traced array raises ConcretizationTypeError."""
  v = np.arange(12, dtype=np.int32).reshape(3, 4)
  expected_msg = r".*The tobytes\(\) method was called on traced array"

  @jax.jit
  def f(x):
    return x.tobytes()

  with self.assertRaisesRegex(core.ConcretizationTypeError, expected_msg):
    f(v)
|
|
|
|
|
2020-06-28 11:26:48 -04:00
|
|
|
def testToList(self):
  """jax.Array.tolist returns the same nested list as numpy."""
  v = np.arange(12, dtype=np.int32).reshape(3, 4)
  expected = v.tolist()
  self.assertEqual(jnp.asarray(v).tolist(), expected)
|
|
|
|
|
2023-03-17 09:42:49 -07:00
|
|
|
def testToListJitError(self):
  """Calling tolist() on a traced array raises ConcretizationTypeError."""
  v = np.arange(12, dtype=np.int32).reshape(3, 4)
  expected_msg = r".*The tolist\(\) method was called on traced array"

  @jax.jit
  def f(x):
    return x.tolist()

  with self.assertRaisesRegex(core.ConcretizationTypeError, expected_msg):
    f(v)
|
|
|
|
|
2020-07-03 20:54:25 -07:00
|
|
|
def testArangeConcretizationError(self):
  """Traced jnp.arange bounds produce errors naming the offending argument."""
  msg = r"It arose in the jnp.arange argument '{}'".format
  # (argument named in the error, jitted callable, argument value)
  cases = [
    ('stop', jax.jit(jnp.arange), 3),
    ('start', jax.jit(lambda start: jnp.arange(start, 3)), 0),
    ('stop', jax.jit(lambda stop: jnp.arange(0, stop)), 3),
  ]
  for arg_name, jitted, value in cases:
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg(arg_name)):
      jitted(value)
|
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(dtype=[None] + float_dtypes)
def testArange64Bit(self, dtype):
  """jnp.arange sizes its range with 64-bit arithmetic on Python scalars."""
  # The length of the range is computed in 64-bit arithmetic even when the
  # output dtype is narrower. If the Python scalar bounds were cast to float32
  # first, the number of output elements would change. Whether or not that was
  # intended originally, downstream users now rely on it.
  bounds = (1.2, 4.8, 0.24)

  # Sanity check: these bounds must give different lengths in float64 and
  # float32; otherwise this test would prove nothing.
  self.assertLen(np.arange(*bounds), 15)
  self.assertLen(np.arange(*(np.float32(b) for b in bounds)), 16)

  def jnp_fun():
    return jnp.arange(*bounds, dtype=dtype)

  np_fun = jtu.with_jax_dtype_defaults(
      lambda: np.arange(*bounds, dtype=dtype), dtype is None)

  def args_maker():
    return []

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2020-09-17 18:33:30 +03:00
|
|
|
def testIssue2347(self):
  """jnp.array rejects a `list[tuple[...]]` alias object and its numpy wrapping.

  See https://github.com/google/jax/issues/2347
  """
  object_list = list[tuple[jnp.array, float, float, jnp.array, bool]]
  with self.assertRaises(TypeError):
    jnp.array(object_list)

  np_object_list = np.array(object_list)
  with self.assertRaises(TypeError):
    jnp.array(np_object_list)
|
2020-05-13 00:04:53 +09:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shapes=shape_pair, dtypes=dtype_pair)
   for shape_pair in filter(
     _shapes_are_broadcast_compatible,
     itertools.combinations_with_replacement(all_shapes, 2))
   for dtype_pair in itertools.product(
     *(_valid_dtypes_for_shape(s, complex_dtypes) for s in shape_pair))
  ],
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLogaddexpComplex(self, shapes, dtypes):
  """jnp.logaddexp on complex inputs agrees with log(exp(x1) + exp(x2))."""
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
  def np_op(x1, x2):
    return np.log(np.exp(x1) + np.exp(x2))

  sampler = jtu.rand_some_nan(self.rng())

  def args_maker():
    return tuple(sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))

  on_tpu = jtu.device_under_test() == 'tpu'
  tol = ({np.complex64: 1e-3, np.complex128: 1e-10} if on_tpu
         else {np.complex64: 1e-5, np.complex128: 1e-14})

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol)
    self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol)
|
2021-07-03 18:09:58 +02:00
|
|
|
|
2022-10-05 01:52:41 +00:00
|
|
|
@jtu.sample_product(
  [dict(shapes=shape_pair, dtypes=dtype_pair)
   for shape_pair in filter(
     _shapes_are_broadcast_compatible,
     itertools.combinations_with_replacement(all_shapes, 2))
   for dtype_pair in itertools.product(
     *(_valid_dtypes_for_shape(s, complex_dtypes) for s in shape_pair))
  ],
)
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLogaddexp2Complex(self, shapes, dtypes):
  """jnp.logaddexp2 on complex inputs agrees with log2(exp2(x1) + exp2(x2))."""
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
  def np_op(x1, x2):
    return np.log2(np.exp2(x1) + np.exp2(x2))

  sampler = jtu.rand_some_nan(self.rng())

  def args_maker():
    return tuple(sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))

  on_tpu = jtu.device_under_test() == 'tpu'
  tol = ({np.complex64: 1e-3, np.complex128: 1e-10} if on_tpu
         else {np.complex64: 1e-5, np.complex128: 1e-14})

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol)
    self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol)
|
2021-07-03 18:09:58 +02:00
|
|
|
|
2021-12-07 08:23:37 -08:00
|
|
|
def testDefaultDtypes(self):
  """jnp's default scalar type aliases follow `jax_default_dtype_bits`."""
  precision = config.jax_default_dtype_bits
  # Use a unittest assertion instead of a bare `assert`: it still runs under
  # `python -O` and reports the unexpected value on failure.
  self.assertIn(precision, ['32', '64'])
  narrow = precision == '32'
  self.assertEqual(jnp.bool_, np.bool_)
  self.assertEqual(jnp.int_, np.int32 if narrow else np.int64)
  self.assertEqual(jnp.uint, np.uint32 if narrow else np.uint64)
  self.assertEqual(jnp.float_, np.float32 if narrow else np.float64)
  self.assertEqual(jnp.complex_, np.complex64 if narrow else np.complex128)
|
|
|
|
|
2022-03-29 10:52:47 -07:00
|
|
|
def testFromBuffer(self):
  """jnp.frombuffer decodes a byte string to the same uint8 array as numpy."""
  buf = b'\x01\x02\x03'
  self.assertArraysEqual(np.frombuffer(buf, dtype='uint8'),
                         jnp.frombuffer(buf, dtype='uint8'))
|
|
|
|
|
|
|
|
def testFromFunction(self):
  """jnp.fromfunction builds the same values as numpy (dtype not compared)."""
  def weighted_index_sum(x, y, z):
    return x + 2 * y + 3 * z

  grid_shape = (3, 4, 5)
  expected = np.fromfunction(weighted_index_sum, shape=grid_shape)
  actual = jnp.fromfunction(weighted_index_sum, shape=grid_shape)
  self.assertArraysEqual(expected, actual, check_dtypes=False)
|
2022-03-29 10:52:47 -07:00
|
|
|
|
|
|
|
def testFromString(self):
  """jnp.fromstring parses comma-separated integers like numpy."""
  text = "1,2,3"
  parse_kwargs = dict(sep=',', dtype=int)
  self.assertArraysEqual(np.fromstring(text, **parse_kwargs),
                         jnp.fromstring(text, **parse_kwargs))
|
|
|
|
|
2023-08-02 13:55:34 -07:00
|
|
|
@jtu.sample_product(
  a_shape=nonempty_nonscalar_array_shapes,
  v_shape=nonempty_shapes,
  dtype=jtu.dtypes.all,
)
def testPlace(self, a_shape, v_shape, dtype):
  """jnp.place with inplace=False matches np.place applied to a copy."""
  value_rng = jtu.rand_default(self.rng())
  where_rng = jtu.rand_bool(self.rng())

  def args_maker():
    # Sampling order (array, mask, values) is kept fixed so draws are reproducible.
    return (value_rng(a_shape, dtype),
            where_rng(a_shape, bool),
            value_rng(v_shape, dtype))

  def np_fun(a, m, v):
    result = a.copy()
    np.place(result, m, v)
    return result

  jnp_fun = partial(jnp.place, inplace=False)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2023-07-26 08:54:54 -07:00
|
|
|
@jtu.sample_product(
  a_shape=nonempty_nonscalar_array_shapes,
  i_shape=all_shapes,
  v_shape=all_shapes,
  dtype=jtu.dtypes.all,
  mode=[None, 'wrap', 'clip'],
)
def testPut(self, mode, a_shape, i_shape, v_shape, dtype):
  """jnp.put with inplace=False matches np.put on a copy for each mode."""
  size = math.prod(a_shape)
  if math.prod(i_shape) > size:
    self.skipTest("too many indices")

  value_rng = jtu.rand_default(self.rng())
  # Indices are drawn without repetition: when several updates hit the same
  # slot, the order in which JAX applies them is implementation-defined.
  index_rng = jtu.rand_unique_int(self.rng(), size)

  def args_maker():
    a = value_rng(a_shape, dtype)
    i = index_rng(i_shape, np.int32)
    v = value_rng(v_shape, dtype)
    if i.size:
      # Move some indices out of range without introducing duplicates.
      if mode == "clip":
        np.put(i, np.argmax(i), size + 2)
        np.put(i, np.argmin(i), -2)
      elif mode == "wrap":
        np.put(i, 0, np.take(i, 0) + size)
    return a, i, v

  def np_fun(a, i, v):
    result = a.copy()
    np.put(result, i, v, mode=mode)
    return result

  jnp_fun = partial(jnp.put, mode=mode, inplace=False)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
|
|
|
|
2022-03-29 10:52:47 -07:00
|
|
|
|
2019-08-31 22:08:03 -07:00
|
|
|
# Most gradient tests live at the lax level (see lax_test.py); the ones here
# cover particular compound jax.numpy ops of interest.

# Record describing one parameterized gradient test.
GradTestSpec = collections.namedtuple(
    "GradTestSpec", "op nargs order rng_factory dtypes name tol")


def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
  """Build a GradTestSpec, using `op.__name__` when `name` is falsy."""
  display_name = name if name else op.__name__
  return GradTestSpec(op=op, nargs=nargs, order=order, rng_factory=rng_factory,
                      dtypes=dtypes, name=display_name, tol=tol)
|
2019-08-31 22:08:03 -07:00
|
|
|
|
|
|
|
# Gradient test cases consumed by NumpyGradTests.testOpGrad.
GRAD_TEST_RECORDS = [
    grad_test_spec(
        op=jnp.arcsinh, nargs=1, order=2,
        rng_factory=jtu.rand_positive,
        dtypes=[np.float64, np.complex64],
        tol={np.complex64: 2e-2}),
    grad_test_spec(
        op=jnp.arccosh, nargs=1, order=1,
        rng_factory=jtu.rand_positive,
        dtypes=[np.float64, np.complex64],
        tol={np.complex64: 2e-2}),
    # arctanh is sampled strictly inside (-1, 1).
    grad_test_spec(
        op=jnp.arctanh, nargs=1, order=2,
        rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
        dtypes=[np.float64, np.complex64],
        tol={np.complex64: 2e-2}),
    grad_test_spec(
        op=jnp.logaddexp, nargs=2, order=1,
        rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
        dtypes=[np.float64],
        tol=1e-4),
    grad_test_spec(
        op=jnp.logaddexp2, nargs=2, order=2,
        rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
        dtypes=[np.float64],
        tol=1e-4),
]
|
|
|
|
|
|
|
|
# Record for gradient checks at specific input values.
GradSpecialValuesTestSpec = collections.namedtuple(
    "GradSpecialValuesTestSpec", "op values order")

# Cases consumed by NumpyGradTests.testOpGradSpecialValue.
GRAD_SPECIAL_VALUE_TEST_RECORDS = [
    GradSpecialValuesTestSpec(op=jnp.arcsinh, values=[0., 1000.], order=2),
    GradSpecialValuesTestSpec(op=jnp.arccosh, values=[1000.], order=2),
    GradSpecialValuesTestSpec(op=jnp.arctanh, values=[0.], order=2),
    GradSpecialValuesTestSpec(op=jnp.sinc, values=[0.], order=1),
]
|
|
|
|
|
2022-02-14 09:22:05 -08:00
|
|
|
|
2019-08-31 22:08:03 -07:00
|
|
|
class NumpyGradTests(jtu.JaxTestCase):
  """Gradient checks for selected compound jax.numpy operations."""

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
        [dict(op=rec.op, rng_factory=rec.rng_factory, tol=rec.tol,
              order=rec.order)],
        shapes=itertools.combinations_with_replacement(nonempty_shapes, rec.nargs),
        dtype=rec.dtypes)
      for rec in GRAD_TEST_RECORDS))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises mixed type promotion
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    """Forward- and reverse-mode gradient check for each GRAD_TEST_RECORDS entry."""
    rng = rng_factory(self.rng())
    default_tol = {np.float32: 1e-1, np.float64: 1e-3,
                   np.complex64: 1e-1, np.complex128: 1e-3}
    tol = jtu.join_tolerance(tol, default_tol)
    if jtu.device_under_test() == 'tpu' and op == jnp.arctanh:
      # arctanh needs extra float32 slack on TPU.
      tol = jtu.join_tolerance(tol, {np.float32: 2e-1})

    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
        [dict(op=rec.op, order=rec.order)],
        special_value=rec.values
      )
      for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS))
  def testOpGradSpecialValue(self, op, special_value, order):
    """Gradient check at each special input value."""
    check_grads(op, (special_value,), order, ["fwd", "rev"],
                atol={np.float32: 3e-3})

  def testSincAtZero(self):
    """Check derivatives of sinc at 0 (orders 1-4) against closed-form values.

    Numerical derivatives are badly behaved at zero, so the expected values
    are written out explicitly. Every mix of jvp- and grad-based
    differentiation is checked for each order.
    """
    def deriv(f):
      return lambda x: jax.jvp(f, (x,), (1.,))[1]

    def apply_all(fns, x):
      for f in fns:
        x = f(x)
      return x

    expected_by_order = {
        1: 0.,
        2: -np.pi ** 2 / 3,
        3: 0.,
        4: np.pi ** 4 / 5,
    }
    for n, expected in expected_by_order.items():
      for ops in itertools.combinations_with_replacement([deriv, jax.grad], n):
        self.assertAllClose(apply_all(ops, jnp.sinc)(0.), expected)

  def testSincGradArrayInput(self):
    # Regression test for a bug almost introduced in #5077.
    summed_sinc = lambda x: jnp.sinc(x).sum()
    jax.grad(summed_sinc)(jnp.arange(10.))  # doesn't crash

  def testTakeAlongAxisIssue1521(self):
    # https://github.com/google/jax/issues/1521
    indices = jnp.repeat(jnp.arange(3), 10).reshape((30, 1))

    def f(x):
      scaled = x * jnp.arange(3.).reshape((1, 3))
      return jnp.take_along_axis(scaled, indices, -1).sum()

    check_grads(f, (1.,), order=1)

  def _check_complex_grad(self, fun, shapes, dtype):
    """First-order fwd/rev gradient check of a binary `fun` on complex inputs."""
    rng = jtu.rand_default(self.rng())
    args = tuple(jnp.array(rng(shape, dtype)) for shape in shapes)
    tol = 5e-2 if jtu.device_under_test() == "tpu" else 3e-2
    check_grads(fun, args, 1, ["fwd", "rev"], tol, tol)

  @jtu.sample_product(
    shapes=filter(_shapes_are_broadcast_compatible,
                  itertools.combinations_with_replacement(nonempty_shapes, 2)),
    dtype=(np.complex128,),
  )
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testGradLogaddexpComplex(self, shapes, dtype):
    self._check_complex_grad(jnp.logaddexp, shapes, dtype)

  @jtu.sample_product(
    shapes=filter(_shapes_are_broadcast_compatible,
                  itertools.combinations_with_replacement(nonempty_shapes, 2)),
    dtype=(np.complex128,),
  )
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testGradLogaddexp2Complex(self, shapes, dtype):
    self._check_complex_grad(jnp.logaddexp2, shapes, dtype)
|
2019-08-31 22:08:03 -07:00
|
|
|
|
2022-02-14 09:22:05 -08:00
|
|
|
|
2020-07-17 14:15:52 -04:00
|
|
|
class NumpySignaturesTest(jtu.JaxTestCase):

  def testWrappedSignaturesMatch(self):
    """Test that jax.numpy function signatures match numpy."""
    jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)}
    func_pairs = {name: (fun, fun.__np_wrapped__) for name, fun in jnp_funcs.items()
                  if getattr(fun, '__np_wrapped__', None) is not None}
    # unittest assertions (unlike bare `assert`) are not stripped by `python -O`.
    self.assertGreater(len(func_pairs), 0)

    # TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
    unsupported_params = {
      'argpartition': ['kind', 'order'],
      'asarray': ['like'],
      'broadcast_to': ['subok'],
      'clip': ['kwargs'],
      'copy': ['subok'],
      'corrcoef': ['ddof', 'bias', 'dtype'],
      'cov': ['dtype'],
      'empty_like': ['subok', 'order'],
      'einsum': ['kwargs'],
      'einsum_path': ['einsum_call'],
      'eye': ['order', 'like'],
      'hstack': ['casting'],
      'identity': ['like'],
      'in1d': ['kind'],
      'isin': ['kind'],
      'full': ['order', 'like'],
      'full_like': ['subok', 'order'],
      'fromfunction': ['like'],
      'histogram': ['normed'],
      'histogram2d': ['normed'],
      'histogramdd': ['normed'],
      'nanstd': ['mean'],
      'nanvar': ['mean'],
      'ones': ['order', 'like'],
      'ones_like': ['subok', 'order'],
      'partition': ['kind', 'order'],
      'row_stack': ['casting'],
      'stack': ['casting'],
      'std': ['mean'],
      'tri': ['like'],
      'unique': ['equal_nan'],
      'var': ['mean'],
      'vstack': ['casting'],
      'zeros_like': ['subok', 'order']
    }

    extra_params = {
      'einsum': ['subscripts', 'precision'],
      'einsum_path': ['subscripts'],
      'take_along_axis': ['mode'],
    }

    # Parameter kinds for *args / **kwargs; loop-invariant, so built once.
    var_args = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    mismatches = {}

    for name, (jnp_fun, np_fun) in func_pairs.items():
      if numpy_version >= (1, 24) and name in ['histogram', 'histogram2d', 'histogramdd']:
        # numpy 1.24 re-orders the density and weights arguments.
        # TODO(jakevdp): migrate histogram APIs to match newer numpy versions.
        continue
      # Note: can't use inspect.getfullargspec due to numpy issue
      # https://github.com/numpy/numpy/issues/12225
      try:
        np_params = inspect.signature(np_fun).parameters
      except ValueError:
        # Some functions cannot be inspected
        continue
      jnp_params = inspect.signature(jnp_fun).parameters
      extra = set(extra_params.get(name, []))
      unsupported = set(unsupported_params.get(name, []))

      # Checks to prevent tests from becoming out-of-date. If these fail,
      # it means that extra_params or unsupported_params need to be updated.
      self.assertTrue(
          extra.issubset(jnp_params),
          f"{name}: {extra=} is not a subset of jnp_params={set(jnp_params)}.")
      self.assertFalse(
          unsupported.intersection(jnp_params),
          f"{name}: {unsupported=} overlaps with jnp_params={set(jnp_params)}.")

      # Skip functions that only have *args and **kwargs; we can't introspect these further.
      if all(p.kind in var_args for p in jnp_params.values()):
        continue
      if all(p.kind in var_args for p in np_params.values()):
        continue

      # Remove known extra parameters.
      jnp_params = {a: p for a, p in jnp_params.items() if a not in extra}

      # Remove known unsupported parameters.
      np_params = {a: p for a, p in np_params.items() if a not in unsupported}

      # Older versions of numpy may have fewer parameters; to avoid extraneous errors on older numpy
      # versions, we allow for jnp to have more parameters.
      if list(jnp_params)[:len(np_params)] != list(np_params):
        mismatches[name] = {'np_params': list(np_params), 'jnp_params': list(jnp_params)}

    self.assertEqual(mismatches, {})
|
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
# Names of the test-suite dtypes (jtu.dtypes.all) other than bfloat16; these
# are the candidate input dtypes that _dtypes_for_ufunc tries on numpy ufuncs.
_available_numpy_dtypes: list[str] = [
    scalar_type.__name__
    for scalar_type in jtu.dtypes.all
    if scalar_type != dtypes.bfloat16
]
|
2020-09-08 13:30:57 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _all_numpy_ufuncs() -> Iterator[str]:
|
|
|
|
"""Generate the names of all ufuncs in the top-level numpy namespace."""
|
|
|
|
for name in dir(np):
|
|
|
|
f = getattr(np, name)
|
|
|
|
if isinstance(f, np.ufunc):
|
2023-02-02 12:30:41 -05:00
|
|
|
# jnp.spacing is not implemented.
|
|
|
|
if f.__name__ != "spacing":
|
|
|
|
yield name
|
2020-09-08 13:30:57 -07:00
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
  """Yield every tuple of input dtype names that numpy's ufunc `name` accepts.

  Each combination of `_available_numpy_dtypes`, with one entry per ufunc
  input, is tried on length-1 arrays of ones. Combinations for which numpy
  raises TypeError are skipped; "divide by zero" RuntimeWarnings are silenced
  during the probe.
  """
  ufunc = getattr(np, name)
  for candidate in itertools.product(_available_numpy_dtypes, repeat=ufunc.nin):
    try:
      with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning)
        ufunc(*[np.ones(1, dtype=dt) for dt in candidate])
    except TypeError:
      continue
    yield candidate
|
|
|
|
|
|
|
|
|
|
|
|
class NumpyUfuncTests(jtu.JaxTestCase):

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases([dict(name=name)],
                                   arg_dtypes=_dtypes_for_ufunc(name))
      for name in _all_numpy_ufuncs()
  ))
  def testUfuncInputTypes(self, name, arg_dtypes):
    """jnp.<name> agrees with np.<name> on each numpy-accepted input dtype combo."""
    if name == 'arctanh' and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating):
      self.skipTest("np.arctanh & jnp.arctanh have mismatched NaNs for complex input.")

    jnp_op = getattr(jnp, name)
    np_op = getattr(np, name)
    np_op = jtu.ignore_warning(category=RuntimeWarning,
                               message="divide by zero.*")(np_op)
    args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes)

    with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
      try:
        jnp_op(*args_maker())
      except NotImplementedError:
        # The probed function is the jax.numpy one, so name it in the skip reason.
        self.skipTest(f"jnp.{name} is not yet implemented.")

      # large tol comes from the fact that numpy returns float16 in places
      # that jnp returns float32. e.g. np.cos(np.uint8(0))
      self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)
|
2020-09-08 13:30:57 -07:00
|
|
|
|
2022-02-14 09:22:05 -08:00
|
|
|
|
2021-03-04 14:16:43 -08:00
|
|
|
class NumpyDocTests(jtu.JaxTestCase):
  """Tests for jnp's NumPy-docstring wrapping and the numpydoc parser."""

  def test_lax_numpy_docstrings(self):
    """Checks that every wrapped jnp function carries a wrapped docstring."""
    # Test that docstring wrapping & transformation didn't fail.

    # Functions that have their own docstrings & don't wrap numpy.
    known_exceptions = {'fromfile', 'fromiter', 'frompyfunc', 'vectorize'}

    for name in dir(jnp):
      if name in known_exceptions or name.startswith('_'):
        continue

      # We only check signatures of functions.
      obj = getattr(jnp, name)
      if isinstance(obj, type) or not callable(obj):
        continue

      # Some jnp functions are imported from numpy or jax.dtypes directly.
      if any(obj is getattr(mod, obj.__name__, None) for mod in [np, dtypes]):
        continue

      # Plain attribute access: a remaining jnp callable without
      # __np_wrapped__ raises AttributeError here and fails the test.
      wrapped_fun = obj.__np_wrapped__
      if wrapped_fun is None:
        continue

      # If the wrapped function has a docstring, obj should too
      if wrapped_fun.__doc__ and not obj.__doc__:
        raise Exception(f"jnp.{name} does not contain wrapped docstring.")

      if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__:
        raise Exception(f"jnp.{name} does not have a wrapped docstring.")

  @parameterized.named_parameters(
    {"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False])
  def test_wrapped_function_parameters(self, jit):
    """Checks that _wraps drops skipped/unknown params and versionadded notes."""
    def orig(x):
      """Example Docstring

      Parameters
      ----------
      x : array_like
        Input Data

        .. versionadded:: 1.8.0
      out : array_like, optional
        Output to overwrite
      other_arg : Any
        not used

      Returns
      -------
      x : input
      """
      return x

    def wrapped(x, out=None):
      return x

    if jit:
      wrapped = jax.jit(wrapped)

    wrapped = _wraps(orig, skip_params=['out'])(wrapped)
    doc = wrapped.__doc__

    self.assertStartsWith(doc, "Example Docstring")
    self.assertIn("Original docstring below", doc)
    self.assertIn("Parameters", doc)
    self.assertIn("Returns", doc)
    # 'out' is in skip_params; 'other_arg' is not a parameter of `wrapped`.
    self.assertNotIn('out', doc)
    self.assertNotIn('other_arg', doc)
    self.assertNotIn('versionadded', doc)

  def test_parse_numpydoc(self):
    """Checks _parse_numpydoc on every callable in NumPy's top-level namespace."""
    # Unit test ensuring that _parse_numpydoc correctly parses docstrings for all
    # functions in NumPy's top-level namespace.
    section_titles = {'Attributes', 'Examples', 'Notes',
                      'Parameters', 'Raises', 'References',
                      'Returns', 'See also', 'See Also', 'Warnings', 'Warns'}
    # Each title as it appears in a numpydoc docstring: title, newline,
    # dashed underline of the same length.
    headings = [title + '\n' + '-'*len(title) for title in section_titles]

    for name in dir(np):
      if name.startswith('_'):
        continue
      obj = getattr(np, name)
      if isinstance(obj, type):
        continue
      if not callable(obj):
        continue
      if 'built-in function' in repr(obj):
        continue
      parsed = _parse_numpydoc(obj.__doc__)

      # Check that no docstring is handled gracefully.
      if not obj.__doc__:
        self.assertEqual(parsed, ParsedDoc(obj.__doc__))
        continue

      # Check that no unexpected section names are found.
      extra_keys = parsed.sections.keys() - section_titles
      if extra_keys:
        raise ValueError(f"Extra section headers found in np.{name}: {extra_keys}")

      # Check that every docstring has a summary.
      if not parsed.summary:
        raise ValueError(f"No summary found for np.{name}")

      # Check that no expected headings are missed. Use a unittest assertion
      # rather than `assert` so the check still runs under `python -O` and the
      # failure names the offending function.
      for heading in headings:
        self.assertNotIn(heading, parsed.front_matter,
                         msg=f"Unparsed section heading found in np.{name}")
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
if __name__ == "__main__":
  # Script entry point: run this module's tests through absltest, using the
  # test loader provided by jax's test utilities (jtu).
  absltest.main(
      testLoader=jtu.JaxTestLoader(),
  )
|