From 9bdad4debec074a71def805717171289a1b6444c Mon Sep 17 00:00:00 2001 From: Shuhan Ding Date: Thu, 22 Feb 2024 18:23:43 -0800 Subject: [PATCH 1/3] setup ci for jax-metal plugin --- .github/workflows/metal_plugin_ci.yml | 42 + jax/_src/test_util.py | 4 + tests/lax_metal_test.py | 1681 +++++++++++++++++++++++++ 3 files changed, 1727 insertions(+) create mode 100644 .github/workflows/metal_plugin_ci.yml create mode 100644 tests/lax_metal_test.py diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml new file mode 100644 index 000000000..a1f15a712 --- /dev/null +++ b/.github/workflows/metal_plugin_ci.yml @@ -0,0 +1,42 @@ +# JAX-Metal plugin CI + +name: Jax-Metal CI +on: + workflow_dispatch: # allows triggering the workflow run manually + +jobs: + jax-metal-plugin-test: + + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + jaxlib-version: ["plugin_latest"] + name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})" + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Get repo + uses: actions/checkout@v4 + with: + path: jax + - name: Setup build and test enviroment + run: | + rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv + python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv + source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate + pip install -U pip numpy wheel + pip install jax-metal absl-py pytest + if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then + pip install --pre jaxlib \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + fi; + cd jax + pip install . + - name: Run test + run: | + source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate + export ENABLE_PJRT_COMPATIBILITY=1 + cd jax + pytest tests/lax_metal_test.py + + diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c420df50e..d1a5091ba 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -349,6 +349,8 @@ def supported_dtypes(): elif device_under_test() == "iree": types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, np.uint32, np.float32} + elif device_under_test() == "METAL": + types = {np.int32, np.uint32, np.float32} else: types = {np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -423,6 +425,8 @@ def _get_device_tags(): device_tags = {device_under_test(), "rocm"} elif is_device_cuda(): device_tags = {device_under_test(), "cuda"} + elif device_under_test() == "METAL": + device_tags = {device_under_test(), "gpu"} else: device_tags = {device_under_test()} return device_tags diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py new file mode 100644 index 000000000..59f8aa40c --- /dev/null +++ b/tests/lax_metal_test.py @@ -0,0 +1,1681 @@ +import collections +import copy +from functools import partial +import inspect +import io +import itertools +import math +from typing import cast, Iterator, Optional, List, Tuple +import unittest +from unittest import SkipTest +import warnings + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np +try: + import numpy_dispatch +except ImportError: + numpy_dispatch = None + +import jax +import jax.ops +from jax import lax +from jax import numpy as jnp +from jax import tree_util +from jax.test_util import check_grads + +from jax._src import core +from jax._src import array +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.lax import lax as lax_internal +from jax._src.numpy.util import _parse_numpydoc, ParsedDoc +import jax.util +from jax._src import array +from jax._src import lax_reference +from jax._src import config +config.parse_flags_with_absl() + +numpy_version = jtu.numpy_version() + +nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] +nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes +one_dim_array_shapes = [(1,), (6,), (12,)] +empty_array_shapes = [(0,), (0, 4), (3, 0),] + +scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] +array_shapes = nonempty_array_shapes +nonzerodim_shapes = nonempty_nonscalar_array_shapes +nonempty_shapes = scalar_shapes + nonempty_array_shapes +all_shapes = scalar_shapes + array_shapes + +float_dtypes = [np.float16, np.float32] +int_dtypes = [np.int32] +unsigned_dtypes = [np.uint32] +bool_dtypes = jtu.dtypes.boolean +default_dtypes = float_dtypes + int_dtypes +inexact_dtypes = float_dtypes +number_dtypes = float_dtypes + int_dtypes + unsigned_dtypes +all_dtypes = number_dtypes + bool_dtypes + + +python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_] + +# uint64 is problematic because with any uint type it promotes to float: +int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] + +def _indexer_with_default_outputs(indexer, use_defaults=True): + """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" + class Indexer: + @partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults) + def __getitem__(self, *args): + return indexer.__getitem__(*args) + return Indexer() + +def _valid_dtypes_for_shape(shape, dtypes): + # Not all (shape, dtype) pairs are valid. In particular, Python scalars only + # have one type in each category (float, bool, etc.) + if shape is jtu.PYTHON_SCALAR_SHAPE: + return [t for t in dtypes if t in python_scalar_dtypes] + return dtypes + +def _shape_and_dtypes(shapes, dtypes): + for shape in shapes: + for dtype in _valid_dtypes_for_shape(shape, dtypes): + yield (shape, dtype) + +def _compatible_shapes(shape): + if np.ndim(shape) == 0 or shape in scalar_shapes: + return [shape] + return (shape[n:] for n in range(len(shape) + 1)) + +OpRecord = collections.namedtuple( + "OpRecord", + ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", + "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"]) + +def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name=None, check_dtypes=True, + tolerance=None, inexact=False, kwargs=None): + test_name = test_name or name + return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name, check_dtypes, tolerance, inexact, kwargs) + + +JAX_ARGMINMAX_RECORDS = [ + op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), + op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), + op_record("nanargmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), + op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), +] + +def _shapes_are_broadcast_compatible(shapes): + try: + lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes)) + except ValueError: + return False + else: + return True + +def _shapes_are_equal_length(shapes): + return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) + + +class LaxBackedNumpyTests(jtu.JaxTestCase): + """Tests for LAX-backed Numpy tests for metal implementation.""" + + def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): + def f(): + out = [rng(shape, dtype or jnp.float_) + for shape, dtype in zip(shapes, dtypes)] + if np_arrays: + return out + return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a + for a in out] + return f + + @parameterized.parameters( + [dtype for dtype in [jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, + jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64, + jnp.float16, jnp.float32] + if dtype == dtypes.canonicalize_dtype(dtype)]) + def testDtypeWrappers(self, dtype): + arr = dtype(0) + self.assertIsInstance(arr, jax.Array) + self.assertEqual(arr.dtype, np.dtype(dtype)) + self.assertArraysEqual(arr, 0, check_dtypes=False) + + # No copy primitive is generated + jaxpr = jax.make_jaxpr(dtype)(0) + prims = [eqn.primitive for eqn in jaxpr.eqns] + self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated. + + @jtu.sample_product( + dtype=float_dtypes + [object], + allow_pickle=[True, False], + ) + def testLoad(self, dtype, allow_pickle): + if dtype == object and not allow_pickle: + self.skipTest("dtype=object requires allow_pickle=True") + rng = jtu.rand_default(self.rng()) + arr = rng((10), dtype) + with io.BytesIO() as f: + jnp.save(f, arr) + f.seek(0) + arr_out = jnp.load(f, allow_pickle=allow_pickle) + self.assertArraysEqual(arr, arr_out, allow_object_dtype=True) + + def testArrayEqualExamples(self): + # examples from the array_equal() docstring. + self.assertTrue(jnp.array_equal([1, 2], [1, 2])) + self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2]))) + self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3])) + self.assertFalse(jnp.array_equal([1, 2], [1, 4])) + + a = np.array([1, np.nan]) + self.assertFalse(jnp.array_equal(a, a)) + self.assertTrue(jnp.array_equal(a, a, equal_nan=True)) + + def testArrayEquivExamples(self): + # examples from the array_equiv() docstring. + self.assertTrue(jnp.array_equiv([1, 2], [1, 2])) + self.assertFalse(jnp.array_equiv([1, 2], [1, 3])) + with jax.numpy_rank_promotion('allow'): + self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]])) + self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) + self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]])) + + def testArrayModule(self): + if numpy_dispatch is None: + raise SkipTest('requires https://github.com/seberg/numpy-dispatch') + + jnp_array = jnp.array(1.0) + np_array = np.array(1.0) + + module = numpy_dispatch.get_array_module(jnp_array) + self.assertIs(module, jnp) + + module = numpy_dispatch.get_array_module(jnp_array, np_array) + self.assertIs(module, jnp) + + def f(x): + module = numpy_dispatch.get_array_module(x) + self.assertIs(module, jnp) + return x + jax.jit(f)(jnp_array) + jax.grad(f)(jnp_array) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list(range(-len(shape), len(shape)))], + discont=[None, "pi", 2], + period=["2pi", "pi"], + dtype=default_dtypes, + ) + + # This currently seg-faults because of dynamic shape = (0, ) + def testUnwrap(self, shape, dtype, axis, discont, period): + if (not np.all(shape)): + self.skipTest("JAX metal does not support dynamic shape unwrap()") + + self.skipTest("JAX metal does not support remainder") + + if numpy_version < (1, 21) and period != "2pi": + self.skipTest("numpy < 1.21 does not support the period argument to unwrap()") + special_vals = {"pi": np.pi, "2pi": 2 * np.pi} + period = special_vals.get(period, period) + discont = special_vals.get(discont, discont) + + rng = jtu.rand_default(self.rng()) + + def np_fun(x): + dtype = None + if x.dtype == dtypes.bfloat16: + dtype = x.dtype + x = x.astype(np.float32) + if numpy_version < (1, 21): + out = np.unwrap(x, axis=axis, discont=discont or np.pi) + else: + out = np.unwrap(x, axis=axis, discont=discont, period=period) + return out if dtype is None else out.astype(dtype) + + jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period) + if not dtypes.issubdtype(dtype, np.inexact): + # This case requires implicit dtype promotion + jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2}) + self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1}) + + @jtu.sample_product( + [dict(shape=shape, indices=indices, update_shape=update_shape) + for shape, indices, update_shape in [ + [(3,), (1,), (1,)], + [(5, 3), (1, 1), (3, 1)], + [(7, 5, 3), (4, 1, 0), (2, 3, 1)], + [(1, 50, 8, 64), (0,20, 0, 0), (1,1,8,64)] + ] + ], + dtype=default_dtypes, + ) + def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, indices, + update_shape): + rng = jtu.rand_default(self.rng()) + + def args_maker(): + return [rng(shape, dtype), rng(update_shape, dtype), np.array(indices)] + + self._CheckAgainstNumpy(lax_reference.dynamic_update_slice, + lax.dynamic_update_slice, args_maker) + + @jtu.sample_product( + [dict(init_val=init_val, op=op, dtype=dtype) + for init_val, op, dtypes in [ + (0, lax.add, [np.float32]), + (-np.inf, lax.max, [np.float32]), + #(np.inf, lax.min, [np.float32]), + ] + for dtype in [np.float32] + ], + [dict(shape=shape, dims=dims, strides=strides, padding=padding, + base_dilation=base_dilation, window_dilation=window_dilation) + for shape, dims, strides, padding, base_dilation, window_dilation in ( + itertools.chain( + itertools.product( + [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], + [(1, 2, 2, 1), (1, 1, 1, 1)], + ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], + [(1, 1, 1, 1)], + [(1, 1, 1, 1), (1, 2, 2, 1)]))) + ], + ) + def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding, + base_dilation, window_dilation): + rng = jtu.rand_small(self.rng()) + init_val = np.asarray(init_val, dtype=dtype) + + def fun(operand, init_val): + return lax.reduce_window(operand, init_val, op, dims, strides, padding, + base_dilation, window_dilation) + + def reference_fun(operand, init_val): + return lax_reference.reduce_window(operand, init_val, op, dims, strides, + padding, base_dilation) + + args_maker = lambda: [rng(shape, dtype), init_val] + self._CompileAndCheck(fun, args_maker) + if all(d == 1 for d in window_dilation): + self._CheckAgainstNumpy(reference_fun, fun, args_maker) + + # we separately test the version that uses a concrete init_val because it + # can hit different code paths + def fun(operand): + return lax.reduce_window(operand, init_val, op, dims, strides, padding, + base_dilation, window_dilation) + + args_maker = lambda: [rng(shape, dtype)] + self._CompileAndCheck(fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list(range(-len(shape), len(shape))) + [None]], + dtype=all_dtypes, + ) + def testCountNonzero(self, shape, dtype, axis): + if (not np.all(shape)): + self.skipTest("JAX metal does not support dynamic shape unwrap()") + rng = jtu.rand_some_zero(self.rng()) + np_fun = lambda x: np.count_nonzero(x, axis) + jnp_fun = lambda x: jnp.count_nonzero(x, axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + def testNonzero(self, shape, dtype): + if (not np.all(shape)): + self.skipTest("JAX metal does not support dynamic shape unwrap()") + if (dtype == np.uint16): + self.skipTest("JAX metal has a crash with NonZero uint16") + self.skipTest("JAX metal has a bug with Pad") + rng = jtu.rand_some_zero(self.rng()) + np_fun = lambda x: np.nonzero(x) + np_fun = jtu.ignore_warning( + category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*")(np_fun) + jnp_fun = lambda x: jnp.nonzero(x) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + @jtu.sample_product( + [dict(shape=shape, fill_value=fill_value) + for shape in nonempty_array_shapes + for fill_value in [None, -1, shape or (1,)] + ], + dtype=all_dtypes, + size=[1, 5, 10], + ) + def testNonzeroSize(self, shape, dtype, size, fill_value): + self.skipTest("JAX metal has a bug with Pad") + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") + def np_fun(x): + result = np.nonzero(x) + if size <= len(result[0]): + return tuple(arg[:size] for arg in result) + else: + fillvals = fill_value if np.ndim(fill_value) else len(result) * [fill_value or 0] + return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) + for fval, arg in safe_zip(fillvals, result)) + jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=[np.float32]) + def testArgWhere(self, shape, dtype): + if (not np.all(shape)): + self.skipTest("JAX metal does not support dynamic shape unwrap()") + self.skipTest("JAX metal does not support ArgWhere") + rng = jtu.rand_some_zero(self.rng()) + np_fun = jtu.ignore_warning( + category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*")(np.argwhere) + jnp_fun = jnp.argwhere + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + # JIT compilation requires specifying a size statically. Full test of this + # behavior is in testNonzeroSize(). + jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name), + shape=shape, dtype=dtype, axis=axis, rng_factory=rec.rng_factory) + for rec in JAX_ARGMINMAX_RECORDS + for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) + for axis in range(-len(shape), len(shape))], + keepdims=[False, True], + ) + def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis, keepdims): + rng = rng_factory(self.rng()) + if dtype == np.complex128 and jtu.device_under_test() == "gpu": + raise unittest.SkipTest("complex128 reductions not supported on GPU") + if "nan" in np_op.__name__ and dtype == jnp.bfloat16: + raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays") + if numpy_version < (1, 22) and keepdims: + raise unittest.SkipTest("NumPy < 1.22 does not support keepdims argument to argmin/argmax") + kwds = {"keepdims": True} if keepdims else {} + + np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds)) + jnp_fun = partial(jnp_op, axis=axis, **kwds) + + args_maker = lambda: [rng(shape, dtype)] + try: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + except ValueError as e: + if str(e) == "All-NaN slice encountered": + self.skipTest("JAX doesn't support checking for all-NaN slices") + else: + raise + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(name=rec.name, np_op=getattr(np, rec.name), + jnp_op=getattr(jnp, rec.name)) + for rec in JAX_ARGMINMAX_RECORDS], + ) + def testArgMinMaxEmpty(self, name, np_op, jnp_op): + self.skipTest("JAX metal does crash") + name = name[3:] if name.startswith("nan") else name + msg = f"attempt to get {name} of an empty sequence" + with self.assertRaises(ValueError, msg=msg): + jnp_op(np.array([])) + with self.assertRaises(ValueError, msg=msg): + jnp_op(np.zeros((2, 0)), axis=1) + np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0)) + jnp_fun = partial(jnp_op, axis=0) + args_maker = lambda: [np.zeros((2, 0))] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) + for lhs_shape, rhs_shape, axes in [ + [(2,), (2,), (-1, -1, -1, None)], # scalar output + [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors + [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors + [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting + [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes + [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting + [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors + [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting + [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing + [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)] # same as before + ]], + lhs_dtype=number_dtypes, + rhs_dtype=number_dtypes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + axisa, axisb, axisc, axis = axes + jnp_fun = lambda a, b: jnp.cross(a, b, axisa, axisb, axisc, axis) + def np_fun(a, b): + a = a.astype(np.float32) if lhs_dtype == jnp.bfloat16 else a + b = b.astype(np.float32) if rhs_dtype == jnp.bfloat16 else b + out = np.cross(a, b, axisa, axisb, axisc, axis) + return out.astype(jnp.promote_types(lhs_dtype, rhs_dtype)) + tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15} + tol = max(jtu.tolerance(lhs_dtype, tol_spec), + jtu.tolerance(rhs_dtype, tol_spec)) + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) + + @jtu.sample_product( + [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape) + for name, lhs_shape, rhs_shape in [ + ("matrix-scalar", (3, 3), ()), + ("scalar-matrix", (), (3, 3)), + # ("matrix-vector", (4, 5), (5,)), + ("vector-matrix", (6,), (6, 4)), + ("matrix-matrix", (3, 4), (4, 5)), + # ("tensor-vector", (4, 3, 2), (2,)), + ("vector-tensor", (2,), (3, 2, 4)), + ("tensor-matrix", (4, 3, 2), (2, 5)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-tensor", (2, 3, 4), (5, 4, 1))]], + lhs_dtype=float_dtypes, + rhs_dtype=float_dtypes, + ) + @jax.default_matmul_precision("float32") + def testDot(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14, + np.complex128: 1e-14} + def np_dot(x, y): + x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x + y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y + return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype)) + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol) + self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol) + + @jtu.sample_product( + [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape) + for name, lhs_shape, rhs_shape in [ + ("vector-vector", (3,), (3,)), + ("matrix-vector", (3, 3), (3,)), + ("vector-matrix", (3,), (3, 3)), + ("matrix-matrix", (3, 3), (3, 3)), + ("vector-tensor", (3,), (5, 3, 2)), + ("tensor-vector", (5, 3, 2), (2,)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-matrix", (5, 2, 3), (3, 2)), + ("tensor-tensor", (5, 3, 4), (5, 4, 1)), + ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]], + lhs_dtype=float_dtypes, + rhs_dtype=float_dtypes, + ) + @jax.default_matmul_precision("float32") + def testMatmul(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + def np_fun(x, y): + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.matmul(x, y).astype(dtype) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, + np.complex128: 1e-12} + + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol) + self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol) + + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) + for lhs_shape, rhs_shape, axes in [ + [(3,), (), 0], + [(2, 3, 4), (5, 6, 7), 0], # from issue #740 + [(2, 3, 4), (3, 4, 5, 6), 2], + [(2, 3, 4), (5, 4, 3, 6), [1, 2]], + [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], + [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], + ]], + lhs_dtype=float_dtypes, + rhs_dtype=float_dtypes, + ) + @jax.default_matmul_precision("float32") + def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + jnp_fun = lambda a, b: jnp.tensordot(a, b, axes) + def np_fun(a, b): + a = a if lhs_dtype != jnp.bfloat16 else a.astype(np.float32) + b = b if rhs_dtype != jnp.bfloat16 else b.astype(np.float32) + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.tensordot(a, b, axes).astype(dtype) + tol = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-12, + np.complex64: 1e-3, np.complex128: 1e-12} + + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + def testTensordotErrors(self): + a = self.rng().random((3, 2, 2)) + b = self.rng().random((2,)) + self.assertRaisesRegex( + TypeError, "Number of tensordot axes.*exceeds input ranks.*", + lambda: jnp.tensordot(a, b, axes=2)) + + self.assertRaisesRegex( + TypeError, "tensordot requires axes lists to have equal length.*", + lambda: jnp.tensordot(a, b, axes=([0], [0, 1]))) + + self.assertRaisesRegex( + TypeError, "tensordot requires both axes lists to be either ints, tuples or lists.*", + lambda: jnp.tensordot(a, b, axes=('bad', 'axes'))) + + self.assertRaisesRegex( + TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*", + lambda: jnp.tensordot(a, b, axes='badaxes')) + + @jtu.sample_product( + element_shape=all_shapes, + test_shape=all_shapes, + dtype=float_dtypes, # TODO: int_dtypes are not working + # dtype=default_dtypes, + invert=[False, True], + ) + def testIsin(self, element_shape, test_shape, dtype, invert): + if ((not np.all(test_shape)) or (not np.all(element_shape))): + self.skipTest("JAX metal does not support dynamic shape") + if not ReportedIssuesTests.jax_metal_supported('0.0.6'): + self.skipTest("JAx metal has a regression on the version") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] + jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert) + np_fun = lambda e, t: np.isin(e, t, invert=invert) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonzerodim_shapes + for axis in (None, *range(len(shape))) + ], + dtype=float_dtypes, + ) + def testSort(self, dtype, shape, axis): + if (axis != 0): + return + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + jnp_fun = jnp.sort + np_fun = np.sort + if axis is not None: + jnp_fun = partial(jnp_fun, axis=axis) + np_fun = partial(np_fun, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(2,3,4)] + for axis in (-1, *range(len(shape))) + ], + dtype=float_dtypes, + ) + def testSortKeyVal(self, dtype, shape, axis): + #self.skipTest("JAX metal does not support sortkeyval.") + rng = jtu.rand_default(self.rng()) + # This test relies on the property that wherever keys are tied, values are + # too, since we don't guarantee the same ordering of values with equal keys. + # To avoid that case, we generate unique keys (globally in the key array). + def args_maker(): + flat_keys = np.arange(math.prod(shape), dtype=dtype) + keys = self.rng().permutation(flat_keys).reshape(shape) + values = rng(shape, dtype) + return keys, values + + fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable=True) + numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis) + self._CheckAgainstNumpy(fun, numpy_op, args_maker) + + @jtu.sample_product( + [dict(shifts=shifts, axis=axis) + for shifts, axis in [ + (3, None), + (1, 1), + # ((3,), (0,)), + ((-2,), (-2,)), + # ((1, 2), (0, -1)), + ((4, 2, 5, 5, 2, 4), None), + (100, None), + ] + ], + dtype=all_dtypes, + shape=[(3, 4), (3, 4, 5)], + ) + def testRoll(self, shape, dtype, shifts, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), np.array(shifts)] + jnp_op = partial(jnp.roll, axis=axis) + np_op = partial(np.roll, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + dtype=[np.uint8, np.bool_], + bitorder=['big', 'little'], + shape=[(1, 2, 3, 4)], + axis=[None, 0, 1, -2, -1], + ) + def testPackbits(self, shape, dtype, axis, bitorder): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder) + np_op = partial(np.packbits, axis=axis, bitorder=bitorder) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + dtype=[np.uint8], + bitorder=['big', 'little'], + shape=[(1, 2, 3, 4)], + axis=[None, 0, 1, -2, -1], + count=[None, 20], + ) + def testUnpackbits(self, shape, dtype, axis, bitorder, count): + rng = jtu.rand_int(self.rng(), 0, 256) + args_maker = lambda: [rng(shape, dtype)] + jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder) + np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + + + @jtu.sample_product( + shape=array_shapes, + dtype=all_dtypes, + ) + def testNanToNum(self, shape, dtype): + rng = jtu.rand_some_inf_and_nan(self.rng()) + dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type + def np_fun(x): + if dtype == jnp.bfloat16: + x = np.where(np.isnan(x), dtype(0), x) + x = np.where(np.isposinf(x), jnp.finfo(dtype).max, x) + x = np.where(np.isneginf(x), jnp.finfo(dtype).min, x) + return x + else: + return np.nan_to_num(x).astype(dtype) + + args_maker = lambda: [rng(shape, dtype)] + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy(np_fun, jnp.nan_to_num, args_maker, + check_dtypes=check_dtypes) + self._CompileAndCheck(jnp.nan_to_num, args_maker, + check_dtypes=check_dtypes) + # self._CompileAndCheck(jnp_op, args_maker) + + + @jtu.sample_product( + [dict(shapes=shapes, dtypes=dtypes) + for shapes, dtypes in ( + ((), ()), + (((7,),), (np.int32,)), + (((3,), (4,)), (np.int32, np.int32)), + (((3,), (1,), (4,)), (np.int32, np.int32, np.int32)), + ) + ], + ) + def testIx_(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype) + for shape, dtype in zip(shapes, dtypes)] + self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker) + self._CompileAndCheck(jnp.ix_, args_maker) + + @jtu.sample_product( + dimensions=[(2,), (4, 5, 6)], + dtype=number_dtypes, + sparse=[True, False], + ) + def testIndices(self, dimensions, dtype, sparse): + if jtu.device_under_test() == "tpu" and dtype in (np.int16, np.uint16): + raise unittest.SkipTest("Compilation failure on TPU ") + def args_maker(): return [] + np_fun = partial(np.indices, dimensions=dimensions, + dtype=dtype, sparse=sparse) + jnp_fun = partial(jnp.indices, dimensions=dimensions, + dtype=dtype, sparse=sparse) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(op=op, q_rng=q_rng) + for (op, q_rng) in ( + ("percentile", partial(jtu.rand_uniform, low=0., high=100.)), + ("quantile", partial(jtu.rand_uniform, low=0., high=1.)), + ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)), + ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)), + ) + ], + [dict(a_shape=a_shape, axis=axis) + for a_shape, axis in ( + ((7,), None), + ((47, 7), 0), + ((47, 7), ()), + ((4, 101), 1), + ((4, 47, 7), (1, 2)), + ((4, 47, 7), (0, 2)), + ((4, 47, 7), (1, 0, 2)), + ) + ], + a_dtype=default_dtypes, + q_dtype=[np.float32], + q_shape=scalar_shapes + [(1,), (4,)], + keepdims=[False, True], + method=['linear', 'lower', 'higher', 'nearest', 'midpoint'], + ) + def testQuantile(self, op, q_rng, a_shape, a_dtype, q_shape, q_dtype, + axis, keepdims, method): + raise unittest.SkipTest("JAX metal legalization error with Sort ") + a_rng = jtu.rand_some_nan(self.rng()) + q_rng = q_rng(self.rng()) + if "median" in op: + args_maker = lambda: [a_rng(a_shape, a_dtype)] + else: + args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] + + @jtu.ignore_warning(category=RuntimeWarning, + message="All-NaN slice encountered") + def np_fun(*args): + args = [x if jnp.result_type(x) != jnp.bfloat16 else + np.asarray(x, np.float32) for x in args] + if numpy_version <= (1, 22): + return getattr(np, op)(*args, axis=axis, keepdims=keepdims, + interpolation=method) + else: + return getattr(np, op)(*args, axis=axis, keepdims=keepdims, + method=method) + jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims, + method=method) + + # TODO(phawkins): we currently set dtype=False because we aren't as + # aggressive about promoting to float64. It's not clear we want to mimic + # Numpy here. + tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} + tol = max(jtu.tolerance(a_dtype, tol_spec), + jtu.tolerance(q_dtype, tol_spec)) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) + + @jtu.sample_product( + [dict(a_shape=a_shape, axis=axis) + for a_shape, axis in ( + ((7,), None), + ((47, 7), 0), + ((4, 101), 1), + ) + ], + a_dtype=default_dtypes, + keepdims=[False, True], + op=["median", "nanmedian"], + ) + def testMedian(self, op, a_shape, a_dtype, axis, keepdims): + raise unittest.SkipTest("JAX metal legalization error with Sort ") + if op == "median": + a_rng = jtu.rand_default(self.rng()) + else: + a_rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: [a_rng(a_shape, a_dtype)] + def np_fun(*args): + args = [x if jnp.result_type(x) != jnp.bfloat16 else + np.asarray(x, np.float32) for x in args] + return getattr(np, op)(*args, axis=axis, keepdims=keepdims) + jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims) + # TODO(phawkins): we currently set dtype=False because we aren't as + # aggressive about promoting to float64. It's not clear we want to mimic + # Numpy here. + tol_spec = {np.float32: 2e-4, np.float64: 5e-6} + tol = jtu.tolerance(a_dtype, tol_spec) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) + + @jtu.sample_product( + [dict(n=n, shapes=shapes) + for n in range(1, 3) + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2 * n + 1)) + ], + # To avoid forming the full product of shapes and dtypes we always sample + # maximal set of dtypes. + dtypes=itertools.combinations_with_replacement(all_dtypes, 3), + ) + def testSelect(self, n, shapes, dtypes): + dtypes = dtypes[:n+1] + rng = jtu.rand_default(self.rng()) + n = len(dtypes) - 1 + def args_maker(): + condlist = [rng(shape, np.bool_) for shape in shapes[:n]] + choicelist = [rng(shape, dtype) + for shape, dtype in zip(shapes[n:-1], dtypes[:n])] + default = rng(shapes[-1], dtypes[-1]) + return condlist, choicelist, default + # TODO(phawkins): float32/float64 type mismatches + @jax.numpy_dtype_promotion('standard') + def np_fun(condlist, choicelist, default): + choicelist = [x if jnp.result_type(x) != jnp.bfloat16 + else x.astype(np.float32) for x in choicelist] + dtype = jnp.result_type(default, *choicelist) + return np.select(condlist, + [np.asarray(x, dtype=dtype) for x in choicelist], + np.asarray(default, dtype=dtype)) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, + check_dtypes=False) + self._CompileAndCheck(jnp.select, args_maker, + rtol={np.float64: 1e-7, np.complex128: 1e-7}) + + # def testIssue330(self): + # x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash + # self.assertEqual(x[0, 0], 1) + + def testScalarDtypePromotion(self): + orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype + jax_numpy_result = (1 + jnp.eye(1, dtype=jnp.float32)).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + def testSymmetrizeDtypePromotion(self): + x = np.eye(3, dtype=np.float32) + orig_numpy_result = ((x + x.T) / 2).dtype + + x = jnp.eye(3, dtype=jnp.float32) + jax_numpy_result = ((x + x.T) / 2).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + def testIssue453(self): + # https://github.com/google/jax/issues/453 + a = np.arange(6) + 1 + ans = jnp.reshape(a, (3, 2), order='F') + expected = np.reshape(a, (3, 2), order='F') + self.assertAllClose(ans, expected) + + @jtu.sample_product( + input=[ + 3, + [3], + [np.array(3)], + [np.array([3])], + [[np.array(3)]], + [[np.array([3])]], + [3, 4, 5], + [ + [np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)], + [np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3], + ], + [np.array([1, 2, 3]), np.array([2, 3, 4]), 10], + [np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)], + [[np.array([1, 2, 3])], [np.array([2, 3, 4])]], + ], + ) + def testBlock(self, input): + args_maker = lambda: [input] + self._CheckAgainstNumpy(np.block, jnp.block, args_maker) + self._CompileAndCheck(jnp.block, args_maker) + + def testLongLong(self): + self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7))) + + @jtu.sample_product( + shape=[(5,), (10, 5), (4, 10)], + dtype=number_dtypes, + rowvar=[True, False], + ) + @jax.default_matmul_precision("float32") + def testCorrCoef(self, shape, dtype, rowvar): + rng = jtu.rand_default(self.rng()) + def args_maker(): + ok = False + while not ok: + x = rng(shape, dtype) + ok = not np.any(np.isclose(np.std(x), 0.0)) + return (x,) + np_fun = partial(np.corrcoef, rowvar=rowvar) + np_fun = jtu.ignore_warning( + category=RuntimeWarning, message="invalid value encountered.*")(np_fun) + jnp_fun = partial(jnp.corrcoef, rowvar=rowvar) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(dtype=dtype, end_dtype=end_dtype, begin_dtype=begin_dtype, + shape=shape, begin_shape=begin_shape, end_shape=end_shape) + for dtype in number_dtypes + for end_dtype in [None] + [dtype] + for begin_dtype in [None] + [dtype] + for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE] + for begin_shape in ( + [None] if begin_dtype is None + else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]) + for end_shape in ( + [None] if end_dtype is None + else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]) + ], + ) + def testEDiff1d(self, shape, dtype, end_shape, end_dtype, begin_shape, + begin_dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), + (None if end_dtype is None else rng(end_shape, end_dtype)), + (None if begin_dtype is None else rng(begin_shape, begin_dtype))] + np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) + jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testEDiff1dWithDtypeCast(self): + rng = jtu.rand_default(self.rng()) + shape = jtu.NUMPY_SCALAR_SHAPE + dtype = jnp.float32 + end_dtype = jnp.int32 + args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)] + np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) + jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shapes=[(), (5,), (5, 3)], + dtype=number_dtypes, + indexing=['xy', 'ij'], + sparse=[True, False], + ) + def testMeshGrid(self, shapes, dtype, indexing, sparse): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes], + [dtype] * len(shapes)) + np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse) + jnp_fun = partial(jnp.meshgrid, indexing=indexing, sparse=sparse) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testOgrid(self): + # wrap indexer for appropriate dtype defaults. + np_ogrid = _indexer_with_default_outputs(np.ogrid) + def assertListOfArraysEqual(xs, ys): + self.assertIsInstance(xs, list) + self.assertIsInstance(ys, list) + self.assertEqual(len(xs), len(ys)) + for x, y in zip(xs, ys): + self.assertArraysEqual(x, y) + + self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2]) + # List of arrays + assertListOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,]) + assertListOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3]) + assertListOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3]) + assertListOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11]) + + def testR_(self): + a = np.arange(6).reshape((2,3)) + self.assertArraysEqual(np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])], + jnp.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])]) + self.assertArraysEqual(np.r_['-1', a, a], jnp.r_['-1', a, a]) + + self.assertArraysEqual(np.r_['0,2', [1,2,3], [4,5,6]], jnp.r_['0,2', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.r_['0,2,0', [1,2,3], [4,5,6]], jnp.r_['0,2,0', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.r_['1,2,0', [1,2,3], [4,5,6]], jnp.r_['1,2,0', [1,2,3], [4,5,6]]) + # negative 1d axis start + self.assertArraysEqual(np.r_['0,4,-1', [1,2,3], [4,5,6]], jnp.r_['0,4,-1', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]]) + + # matrix directives + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]]) + self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]]) + + # bad directive + with self.assertRaisesRegex(ValueError, "could not understand directive.*"): + jnp.r_["asdfgh",[1,2,3]] + # abstract tracer value for r_ slice + with self.assertRaisesRegex(core.ConcretizationTypeError, + "slice start of jnp.r_"): + jax.jit(lambda a, b: jnp.r_[a:b])(0, 2) + + # wrap indexer for appropriate dtype defaults. + np_r_ = _indexer_with_default_outputs(np.r_) + + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_r_[-1:1:6j], + jnp.r_[-1:1:6j], + atol=atol, + rtol=rtol) + with jax.numpy_dtype_promotion('standard'): # Requires dtype promotion. + self.assertAllClose(np_r_[-1:1:6j, [0]*3, 5, 6], + jnp.r_[-1:1:6j, [0]*3, 5, 6], + atol=atol, + rtol=rtol) + # Non-integer steps + self.assertAllClose(np_r_[1.2:4.8:0.24], + jnp.r_[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + + def testC_(self): + a = np.arange(6).reshape((2, 3)) + self.assertArraysEqual(np.c_[np.array([1,2,3]), np.array([4,5,6])], + jnp.c_[np.array([1,2,3]), np.array([4,5,6])]) + self.assertArraysEqual(np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])], + jnp.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])]) + self.assertArraysEqual(np.c_['-1', a, a], jnp.c_['-1', a, a]) + + self.assertArraysEqual(np.c_['0,2', [1,2,3], [4,5,6]], jnp.c_['0,2', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.c_['0,2,0', [1,2,3], [4,5,6]], jnp.c_['0,2,0', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.c_['1,2,0', [1,2,3], [4,5,6]], jnp.c_['1,2,0', [1,2,3], [4,5,6]]) + # negative 1d axis start + self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]]) + self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]]) + # matrix directives, avoid numpy deprecation warning + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]]) + self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]]) + + # bad directive + with self.assertRaisesRegex(ValueError, "could not understand directive.*"): + jnp.c_["asdfgh",[1,2,3]] + # abstract tracer value for c_ slice + with self.assertRaisesRegex(core.ConcretizationTypeError, + "slice start of jnp.c_"): + jax.jit(lambda a, b: jnp.c_[a:b])(0, 2) + + # wrap indexer for appropriate dtype defaults. + np_c_ = _indexer_with_default_outputs(np.c_) + + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_c_[-1:1:6j], + jnp.c_[-1:1:6j], + atol=atol, + rtol=rtol) + + # Non-integer steps + self.assertAllClose(np_c_[1.2:4.8:0.24], + jnp.c_[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + + def testS_(self): + self.assertEqual(np.s_[1:2:20],jnp.s_[1:2:20]) + + def testIndex_exp(self): + self.assertEqual(np.index_exp[5:3:2j],jnp.index_exp[5:3:2j]) + + @jtu.sample_product( + start_shape=[(2,), (2, 2)], + stop_shape=[(2,), (2, 2)], + num=[1, 2, 5, 20], + endpoint=[True, False], + retstep=[True, False], + # floating-point compute between jitted platforms and non-jit + rounding + # cause unavoidable variation in integer truncation for some inputs, so + # we currently only test inexact 'dtype' arguments. + dtype=inexact_dtypes + [None,], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype): + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = jtu.tolerance(dtype if dtype else np.float32) * 10 + args_maker = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype]) + start, stop = args_maker() + ndim = len(np.shape(start + stop)) + for axis in range(-ndim, ndim): + jnp_op = lambda start, stop: jnp.linspace( + start, stop, num, + endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) + np_op = lambda start, stop: np.linspace( + start, stop, num, + endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) + + self._CheckAgainstNumpy(np_op, jnp_op, args_maker, + check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_op, args_maker, + check_dtypes=False, atol=tol, rtol=tol) + + @jtu.sample_product(dtype=number_dtypes) + def testLinspaceEndpoints(self, dtype): + """Regression test for Issue #3014.""" + rng = jtu.rand_default(self.rng()) + endpoints = rng((2,), dtype) + out = jnp.linspace(*endpoints, 10, dtype=dtype) + self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0) + + @jtu.sample_product( + start_shape=[(2,), (2, 2)], + stop_shape=[(2,), (2, 2)], + num=[1, 2, 5, 20], + endpoint=[True, False], + base=[10.0, 2, np.e], + # skip 16-bit floats due to insufficient precision for the test. + dtype=[np.float32] + [None,], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLogspace(self, start_shape, stop_shape, num, + endpoint, base, dtype): + if (dtype in int_dtypes and + jtu.device_under_test() in ("gpu", "tpu", "METAL") and + not config.x64_enabled): + raise unittest.SkipTest("GPUx32 truncated exponentiation" + " doesn't exactly match other platforms.") + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = {np.float32: 1e-2, np.float64: 1e-6, np.complex64: 1e-3, np.complex128: 1e-6} + args_maker = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype]) + start, stop = args_maker() + ndim = len(np.shape(start + stop)) + for axis in range(-ndim, ndim): + jnp_op = lambda start, stop: jnp.logspace( + start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) + @jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered in power") + def np_op(start, stop): + return np.logspace(start, stop, num, endpoint=endpoint, + base=base, dtype=dtype, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker, + check_dtypes=False, tol=tol) + if dtype in (inexact_dtypes + [None,]): + # Why do compiled and op-by-op float16 np.power numbers differ + # slightly more than expected? + atol = {np.float16: 1e-2} + self._CompileAndCheck(jnp_op, args_maker, + check_dtypes=False, atol=atol, rtol=tol) + + def testStackArrayArgument(self): + # tests https://github.com/google/jax/issues/1271 + @jax.jit + def foo(x): + return jnp.stack(x) + foo(np.zeros(2)) # doesn't crash + + @jax.jit + def foo(x): + return jnp.concatenate(x) + foo(np.zeros((2, 2))) # doesn't crash + + def testReluGradientConstants(self): + # This is a regression test that verifies that constants associated with the + # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the + # outermost jaxpr. This was producing some large materialized constants for + # every relu activation in a model. + def body(i, xy): + x, y = xy + y = y + jax.grad(lambda z: jnp.sum(jnp.maximum(z, 0.)))(x) + return x, y + + f = lambda y: lax.fori_loop(0, 5, body, (y, y)) + jaxpr = jax.make_jaxpr(f)(np.zeros((3, 4), np.float32)) + self.assertFalse( + any(np.array_equal(x, np.full((3, 4), 2., dtype=np.float32)) + for x in jaxpr.consts)) + + @jtu.sample_product( + [dict(from_shape=from_shape, to_shape=to_shape) + for from_shape, to_shape in [ + [(1, 3), (4, 3)], + [(3,), (2, 1, 3)], + [(3,), (3, 3)], + [(1,), (3,)], + [(1,), 3], + ] + ], + ) + def testBroadcastTo(self, from_shape, to_shape): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32]) + np_op = lambda x: np.broadcast_to(x, to_shape) + jnp_op = lambda x: jnp.broadcast_to(x, to_shape) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + [dict(shapes=shapes, broadcasted_shape=broadcasted_shape) + for shapes, broadcasted_shape in [ + [[], ()], + [[()], ()], + [[(1, 3), (4, 3)], (4, 3)], + [[(3,), (2, 1, 3)], (2, 1, 3)], + [[(3,), (3, 3)], (3, 3)], + [[(1,), (3,)], (3,)], + [[(1,), 3], (3,)], + [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)], + [[[1], [0, 1]], (0, 1)], + [[(1,), np.array([0, 1])], (0, 1)], + ] + ], + ) + def testBroadcastShapes(self, shapes, broadcasted_shape): + # Test against np.broadcast_shapes once numpy 1.20 is minimum required version + np.testing.assert_equal(jnp.broadcast_shapes(*shapes), broadcasted_shape) + + def testBroadcastToIssue1522(self): + self.assertRaisesRegex( + ValueError, "Incompatible shapes for broadcasting: .*", + lambda: jnp.broadcast_to(np.ones((2, 3)), (1, 3))) + + def testBroadcastToIntIssue1548(self): + self.assertAllClose(jnp.broadcast_to(1, (3, 2)), np.ones((3, 2)), + check_dtypes=False) + + def testBroadcastToOnScalar(self): + self.assertIsInstance(jnp.broadcast_to(10.0, ()), jax.Array) + self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray) + + def testPrecision(self): + + ones_1d = np.ones((2,)) + ones_2d = np.ones((2, 2)) + ones_3d = np.ones((2, 2, 2)) + HIGHEST = lax.Precision.HIGHEST + + jtu.assert_dot_precision(None, jnp.dot, ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.dot, precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.dot, precision=HIGHEST), + ones_3d, ones_3d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.matmul, precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.vdot, precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.tensordot, axes=2, precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.tensordot, axes=(0, 0), precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.einsum, 'i,i', precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.einsum, 'ij,ij', precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.inner, precision=HIGHEST), + ones_1d, ones_1d) + + @jtu.sample_product( + [dict(shape=shape, varargs=varargs, axis=axis) + for shape in [(10,), (10, 15), (10, 15, 20)] + for _num_axes in range(len(shape)) + for varargs in itertools.combinations(range(1, len(shape) + 1), _num_axes) + for axis in itertools.combinations(range(len(shape)), _num_axes) + ], + dtype=inexact_dtypes, + ) + def testGradient(self, shape, varargs, axis, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_fun = lambda y: jnp.gradient(y, *varargs, axis=axis) + np_fun = lambda y: np.gradient(y, *varargs, axis=axis) + self._CheckAgainstNumpy( + np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + def testZerosShapeErrors(self): + # see https://github.com/google/jax/issues/1822 + self.assertRaisesRegex( + TypeError, + "Shapes must be 1D sequences of concrete values of integer type.*", + lambda: jnp.zeros(1.)) + self.assertRaisesRegex( + TypeError, + r"Shapes must be 1D sequences of concrete values of integer type.*\n" + "If using `jit`, try using `static_argnums` or applying `jit` to " + "smaller subfunctions.", + lambda: jax.jit(jnp.zeros)(2)) + + def testTraceMethod(self): + x = self.rng().randn(3, 4).astype(jnp.float_) + self.assertAllClose(x.trace(), jnp.array(x).trace()) + self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) + + def testIntegerPowersArePrecise(self): + # See https://github.com/google/jax/pull/3036 + # Checks if the squares of float32 integers have no numerical errors. + # It should be satisfied with all integers less than sqrt(2**24). + x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) + np.testing.assert_array_equal(jnp.square(x.astype(jnp.float32)), x * x) + np.testing.assert_array_equal(x.astype(jnp.float32) ** 2, x * x) + + # Similarly for cubes. + x = jnp.arange(-2**8, 2**8, dtype=jnp.int32) + np.testing.assert_array_equal(x.astype(jnp.float32) ** 3, x * x * x) + + x = np.arange(10, dtype=np.float32) + for i in range(10): + self.assertAllClose(x.astype(jnp.float32) ** i, x ** i, + check_dtypes=False) + + def testToBytes(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + for order in ['C', 'F']: + self.assertEqual(jnp.asarray(v).tobytes(order), v.tobytes(order)) + + def testToList(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + self.assertEqual(jnp.asarray(v).tolist(), v.tolist()) + + @jtu.skip_on_devices("METAL") + def testArangeConcretizationError(self): + msg = r"It arose in jax.numpy.arange argument `{}`".format + with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): + jax.jit(jnp.arange)(3) + + with self.assertRaisesRegex(core.ConcretizationTypeError, msg('start')): + jax.jit(lambda start: jnp.arange(start, 3))(0) + + with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): + jax.jit(lambda stop: jnp.arange(0, stop))(3) + + @jtu.sample_product(dtype=[None] + float_dtypes) + def testArange64Bit(self, dtype): + # Test that jnp.arange uses 64-bit arithmetic to define its range, even if the + # output has another dtype. The issue here is that if python scalar inputs to + # jnp.arange are cast to float32 before the range is computed, it changes the + # number of elements output by the range. It's unclear whether this was deliberate + # behavior in the initial implementation, but it's behavior that downstream users + # have come to rely on. + args = (1.2, 4.8, 0.24) + + # Ensure that this test case leads to differing lengths if cast to float32. + self.assertLen(np.arange(*args), 15) + self.assertLen(np.arange(*map(np.float32, args)), 16) + + jnp_fun = lambda: jnp.arange(*args, dtype=dtype) + np_fun = jtu.with_jax_dtype_defaults(lambda: np.arange(*args, dtype=dtype), dtype is None) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testIssue2347(self): + # https://github.com/google/jax/issues/2347 + object_list = List[Tuple[jnp.array, float, float, jnp.array, bool]] + self.assertRaises(TypeError, jnp.array, object_list) + + np_object_list = np.array(object_list) + self.assertRaises(TypeError, jnp.array, np_object_list) + + def testDefaultDtypes(self): + precision = config.default_dtype_bits.value + assert precision in ['32', '64'] + self.assertEqual(jnp.bool_, np.bool_) + self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64) + self.assertEqual(jnp.uint, np.uint32 if precision == '32' else np.uint64) + self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64) + self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128) + + def testFromBuffer(self): + buf = b'\x01\x02\x03' + expected = np.frombuffer(buf, dtype='uint8') + actual = jnp.frombuffer(buf, dtype='uint8') + self.assertArraysEqual(expected, actual) + + def testFromFunction(self): + def f(x, y, z): + return x + 2 * y + 3 * z + shape = (3, 4, 5) + expected = np.fromfunction(f, shape=shape) + actual = jnp.fromfunction(f, shape=shape) + self.assertArraysEqual(expected, actual, check_dtypes=False) + + def testFromString(self): + s = "1,2,3" + expected = np.fromstring(s, sep=',', dtype=int) + actual = jnp.fromstring(s, sep=',', dtype=int) + self.assertArraysEqual(expected, actual) + + +# Most grad tests are at the lax level (see lax_test.py), but we add some here +# as needed for e.g. particular compound ops of interest. + +GradTestSpec = collections.namedtuple( + "GradTestSpec", + ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"]) +def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): + return GradTestSpec( + op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) + +GRAD_TEST_RECORDS = [ + grad_test_spec(jnp.arcsinh, nargs=1, order=2, + rng_factory=jtu.rand_positive, + dtypes=[np.float64, np.complex64], + tol={np.complex64: 2e-2}), + grad_test_spec(jnp.arccosh, nargs=1, order=1, + rng_factory=jtu.rand_positive, + dtypes=[np.float64, np.complex64], + tol={np.complex64: 2e-2}), + grad_test_spec(jnp.arctanh, nargs=1, order=2, + rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9), + dtypes=[np.float64, np.complex64], + tol={np.complex64: 2e-2}), + grad_test_spec(jnp.logaddexp, nargs=2, order=1, + rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9), + dtypes=[np.float64], tol=1e-4), + grad_test_spec(jnp.logaddexp2, nargs=2, order=2, + rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9), + dtypes=[np.float64], tol=1e-4), +] + +GradSpecialValuesTestSpec = collections.namedtuple( + "GradSpecialValuesTestSpec", ["op", "values", "order"]) + +GRAD_SPECIAL_VALUE_TEST_RECORDS = [ + GradSpecialValuesTestSpec(jnp.arcsinh, [0., 1000.], 2), + GradSpecialValuesTestSpec(jnp.arccosh, [1000.], 2), + GradSpecialValuesTestSpec(jnp.arctanh, [0.], 2), + GradSpecialValuesTestSpec(jnp.sinc, [0.], 1), +] + + +from jaxlib import xla_client +from jax_plugins import metal_plugin +class ReportedIssuesTests(jtu.JaxTestCase): + def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): + deviceArgs = [] + for arg in args: + deviceArgs.append(jax.device_put(arg, device)) + return func(*deviceArgs) + + @staticmethod + def compile_and_exec(module, args, run_on_cpu=False): + backend = jax.lib.xla_bridge.get_backend('METAL') + if (run_on_cpu): + backend = jax.lib.xla_bridge.get_backend('cpu') + executables = backend.compile(module) + return xla_client.execute_with_python_values(executables, args, backend) + + @staticmethod + def jax_metal_supported(target_ver): + if metal_plugin is None or not hasattr(metal_plugin, 'version'): + return False + curr_ver = metal_plugin.version() + if hasattr(jtu, 'parse_version'): + return jtu.parse_version(curr_ver) >= jtu.parse_version(target_ver) + return False + + + #https://github.com/google/jax/issues/16420 + def test_broadcast_dim(self): + x = jnp.arange(2) + f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (0,)) + res = f(x) + print(res) + res_cpu = self.dispatchOn([x],f) + jtu.check_eq(res, res_cpu) + f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (1,)) + res = f(x) + print(res) + res_cpu = self.dispatchOn([x],f) + jtu.check_eq(res, res_cpu) + + def test_identity(self): + x = jnp.identity(4) + jtu.check_eq(x, np.identity(4)) + + def test_triu(self): + x = np.ones((4,4)) + res = jnp.triu(x) + jtu.check_eq(res, np.triu(x)) + + #https://github.com/google/jax/issues/16471 + def test_matmul_1d(self): + x = np.array(np.random.rand(3, 3)) + y = np.array(np.random.rand(3)) + z = np.array(np.random.rand(3)) + res = jnp.dot(y, z) + self.assertArraysAllClose(res, np.dot(y,z)) + res = jnp.dot(x, y) + self.assertArraysAllClose(res, np.dot(x,y)) + + #https://github.com/google/jax/issues/17175 + def test_indexing(self): + x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) + @jax.vmap + def f(i): + return x[i] + f = jax.jit(f) + idx = jnp.array([1,1,2,2,0]) + res = f(idx) + jtu.check_eq(res, np.array([[4., 5., 6.], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.], [1., 2., 3.]])) + + #https://github.com/google/jax/issues/17344 + def test_take_along_axis(self): + @jax.jit + def f(): + idx = jnp.array([[0],[0],[0]]) + x = jnp.array([[0.3756883, 0.05820537, 0.7399422, 0.45242703], + [0.5848844, 0.18772626, 0.47942543, 0.20703673], + [0.1071583, 0.26139486, 0.25664794, 0.8109596]]) + return jnp.take_along_axis(x, idx, axis=1) + jtu.check_eq(f(), self.dispatchOn([], f)) + + #https://github.com/google/jax/issues/17590 + def test_in1d(self): + a = np.array([123,2,4]) + b = np.array([123,1]) + res = jnp.isin(a,b) + jtu.check_eq(res, np.isin(a, b)) + + def test_indexing_update(self): + x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) + @jax.vmap + def f(x): + return x.at[0].set(1.0) + f = jax.jit(f) + res= f(x) + jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]])) + + #https://github.com/google/jax/issues/16326 + def test_indexing_update2(self): + @jax.jit + def f(x, r): + x = x.at[:, 0].set(x[:, 0] / r) + return x + x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + fx = f(x, jnp.array([10.0])) + jtu.check_eq(fx, np.array([[0.1, 2.0], [0.3, 4.]])) + + def test_gather_ir(self): + ir = ''' +#loc = loc(unknown) +module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> { + %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 2, 1]> : tensor<3xi64>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2) + return %0 : tensor<3x2xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc1 = loc("/Users/shuhan/Code/jax-metal/tests/lax_numpy_indexing_test.py":1156:0) +#loc2 = loc("jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0, 2), start_index_map=(0, 2)) slice_sizes=(1, 2, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.CLIP fill_value=None]"(#loc1)) + ''' + data = np.array([[[0.6369617, 0.26978672, 0.04097353], + [0.01652764, 0.8132702, 0.91275555]], + [[0.60663575, 0.72949654, 0.543625 ], + [0.9350724, 0.81585354, 0.0027385 ]], + [[0.8574043, 0.03358557, 0.72965544], + [0.17565562, 0.8631789, 0.5414612 ]]], dtype=np.float32) + index = np.array([[1, 0],[2, 1],[0, 2]], dtype=np.int32) + res = ReportedIssuesTests.compile_and_exec(ir, [data, index]) + res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, index], run_on_cpu = True) + print(res) + jtu.check_eq(res, res_ref) + + #https://github.com/google/jax/issues/16366 + def test_pad_interior_1(self): + if not ReportedIssuesTests.jax_metal_supported('0.0.6'): + raise unittest.SkipTest("jax-metal version doesn't support it.") + ir = ''' + module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<128x7x7x64xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<128x15x15x64xf32> { + %206 = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor) -> tensor<128x15x15x64xf32> + return %206 : tensor<128x15x15x64xf32> + } + } + ''' + data = np.random.rand(128,7,7,64).astype(np.float32) + padding = np.array(0.5, dtype=np.float32) + res = ReportedIssuesTests.compile_and_exec(ir, [data, padding]) + res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, padding], run_on_cpu = True) + jtu.check_eq(res, res_ref) + + def test_pad_interior_2(self): + if not ReportedIssuesTests.jax_metal_supported('0.0.6'): + raise unittest.SkipTest("jax-metal version doesn't support it.") + batch = 2 + seq_len = 8 + num_decode = 32 + + seq = np.random.randint(size=(batch, seq_len, num_decode), low=0, high=256, dtype=np.uint8) + res = jnp.cumsum(seq, axis=-1) + res_ref = np.cumsum(seq, axis=-1, dtype=np.uint8) + jtu.check_eq(res, res_ref) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 4c4bcde723c5272a05dad5ec476e25b78766c224 Mon Sep 17 00:00:00 2001 From: Shuhan Ding Date: Thu, 7 Mar 2024 16:33:03 -0800 Subject: [PATCH 2/3] pass set of lax_numpy_test --- tests/lax_metal_test.py | 4844 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 4484 insertions(+), 360 deletions(-) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 59f8aa40c..4b96fd150 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -1,14 +1,30 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from array import array as make_python_array import collections import copy from functools import partial -import inspect import io import itertools import math -from typing import cast, Iterator, Optional, List, Tuple +import platform +from typing import Union, cast import unittest from unittest import SkipTest -import warnings from absl.testing import absltest from absl.testing import parameterized @@ -23,19 +39,22 @@ import jax import jax.ops from jax import lax from jax import numpy as jnp -from jax import tree_util -from jax.test_util import check_grads +from jax.sharding import SingleDeviceSharding -from jax._src import core from jax._src import array +from jax._src import config +from jax._src import core from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.numpy.util import _parse_numpydoc, ParsedDoc -import jax.util -from jax._src import array -from jax._src import lax_reference -from jax._src import config + +from jax._src.util import safe_zip, NumpyComplexWarning + +try: + from jax_plugins import metal_plugin +except ImportError: + metal_plugin = None + config.parse_flags_with_absl() numpy_version = jtu.numpy_version() @@ -44,28 +63,48 @@ nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes one_dim_array_shapes = [(1,), (6,), (12,)] empty_array_shapes = [(0,), (0, 4), (3, 0),] +broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] -array_shapes = nonempty_array_shapes -nonzerodim_shapes = nonempty_nonscalar_array_shapes +array_shapes = nonempty_array_shapes + empty_array_shapes +nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes nonempty_shapes = scalar_shapes + nonempty_array_shapes all_shapes = scalar_shapes + array_shapes -float_dtypes = [np.float16, np.float32] -int_dtypes = [np.int32] -unsigned_dtypes = [np.uint32] +float_dtypes = jtu.dtypes.all_floating +complex_dtypes = jtu.dtypes.complex +int_dtypes = jtu.dtypes.all_integer +unsigned_dtypes = jtu.dtypes.all_unsigned bool_dtypes = jtu.dtypes.boolean default_dtypes = float_dtypes + int_dtypes -inexact_dtypes = float_dtypes -number_dtypes = float_dtypes + int_dtypes + unsigned_dtypes +inexact_dtypes = float_dtypes + complex_dtypes +number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes all_dtypes = number_dtypes + bool_dtypes +NO_VALUE = object() -python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_] +python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_] # uint64 is problematic because with any uint type it promotes to float: int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] +def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False, + axis=None, **kwds): + # Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0 + result = np.unique(ar, return_index=return_index, return_inverse=return_inverse, + return_counts=return_counts, axis=axis, **kwds) + if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse: + return result + + idx = 2 if return_index else 1 + inverse_indices = result[idx] + if axis is None: + inverse_indices = inverse_indices.reshape(np.shape(ar)) + else: + inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis]) + return (*result[:idx], inverse_indices, *result[idx + 1:]) + + def _indexer_with_default_outputs(indexer, use_defaults=True): """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" class Indexer: @@ -122,9 +161,9 @@ def _shapes_are_broadcast_compatible(shapes): def _shapes_are_equal_length(shapes): return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) - +@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") class LaxBackedNumpyTests(jtu.JaxTestCase): - """Tests for LAX-backed Numpy tests for metal implementation.""" + """Tests for LAX-backed Numpy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): def f(): @@ -137,8 +176,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): return f @parameterized.parameters( - [dtype for dtype in [jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, - jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64, + [dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32, + jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.float16, jnp.float32] if dtype == dtypes.canonicalize_dtype(dtype)]) def testDtypeWrappers(self, dtype): @@ -152,6 +191,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): prims = [eqn.primitive for eqn in jaxpr.eqns] self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated. + def testBoolDtypeAlias(self): + self.assertIs(jnp.bool, jnp.bool_) + @jtu.sample_product( dtype=float_dtypes + [object], allow_pickle=[True, False], @@ -167,6 +209,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): arr_out = jnp.load(f, allow_pickle=allow_pickle) self.assertArraysEqual(arr, arr_out, allow_object_dtype=True) + @unittest.skip("Jax-metal fail.") def testArrayEqualExamples(self): # examples from the array_equal() docstring. self.assertTrue(jnp.array_equal([1, 2], [1, 2])) @@ -178,6 +221,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertFalse(jnp.array_equal(a, a)) self.assertTrue(jnp.array_equal(a, a, equal_nan=True)) + a = np.array([1 + 1j]) + b = a.copy() + a.real = np.nan + b.imag = np.nan + self.assertTrue(jnp.array_equal(a, b, equal_nan=True)) + def testArrayEquivExamples(self): # examples from the array_equiv() docstring. self.assertTrue(jnp.array_equiv([1, 2], [1, 2])) @@ -215,16 +264,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): period=["2pi", "pi"], dtype=default_dtypes, ) - - # This currently seg-faults because of dynamic shape = (0, ) def testUnwrap(self, shape, dtype, axis, discont, period): - if (not np.all(shape)): - self.skipTest("JAX metal does not support dynamic shape unwrap()") - - self.skipTest("JAX metal does not support remainder") - - if numpy_version < (1, 21) and period != "2pi": - self.skipTest("numpy < 1.21 does not support the period argument to unwrap()") special_vals = {"pi": np.pi, "2pi": 2 * np.pi} period = special_vals.get(period, period) discont = special_vals.get(discont, discont) @@ -236,10 +276,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if x.dtype == dtypes.bfloat16: dtype = x.dtype x = x.astype(np.float32) - if numpy_version < (1, 21): - out = np.unwrap(x, axis=axis, discont=discont or np.pi) - else: - out = np.unwrap(x, axis=axis, discont=discont, period=period) + out = np.unwrap(x, axis=axis, discont=discont, period=period) return out if dtype is None else out.astype(dtype) jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period) @@ -251,75 +288,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2}) self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1}) - @jtu.sample_product( - [dict(shape=shape, indices=indices, update_shape=update_shape) - for shape, indices, update_shape in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 3, 1)], - [(1, 50, 8, 64), (0,20, 0, 0), (1,1,8,64)] - ] - ], - dtype=default_dtypes, - ) - def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, indices, - update_shape): - rng = jtu.rand_default(self.rng()) - - def args_maker(): - return [rng(shape, dtype), rng(update_shape, dtype), np.array(indices)] - - self._CheckAgainstNumpy(lax_reference.dynamic_update_slice, - lax.dynamic_update_slice, args_maker) - - @jtu.sample_product( - [dict(init_val=init_val, op=op, dtype=dtype) - for init_val, op, dtypes in [ - (0, lax.add, [np.float32]), - (-np.inf, lax.max, [np.float32]), - #(np.inf, lax.min, [np.float32]), - ] - for dtype in [np.float32] - ], - [dict(shape=shape, dims=dims, strides=strides, padding=padding, - base_dilation=base_dilation, window_dilation=window_dilation) - for shape, dims, strides, padding, base_dilation, window_dilation in ( - itertools.chain( - itertools.product( - [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], - [(1, 2, 2, 1), (1, 1, 1, 1)], - ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], - [(1, 1, 1, 1)], - [(1, 1, 1, 1), (1, 2, 2, 1)]))) - ], - ) - def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding, - base_dilation, window_dilation): - rng = jtu.rand_small(self.rng()) - init_val = np.asarray(init_val, dtype=dtype) - - def fun(operand, init_val): - return lax.reduce_window(operand, init_val, op, dims, strides, padding, - base_dilation, window_dilation) - - def reference_fun(operand, init_val): - return lax_reference.reduce_window(operand, init_val, op, dims, strides, - padding, base_dilation) - - args_maker = lambda: [rng(shape, dtype), init_val] - self._CompileAndCheck(fun, args_maker) - if all(d == 1 for d in window_dilation): - self._CheckAgainstNumpy(reference_fun, fun, args_maker) - - # we separately test the version that uses a concrete init_val because it - # can hit different code paths - def fun(operand): - return lax.reduce_window(operand, init_val, op, dims, strides, padding, - base_dilation, window_dilation) - - args_maker = lambda: [rng(shape, dtype)] - self._CompileAndCheck(fun, args_maker) - @jtu.sample_product( [dict(shape=shape, axis=axis) for shape in all_shapes @@ -327,8 +295,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): dtype=all_dtypes, ) def testCountNonzero(self, shape, dtype, axis): - if (not np.all(shape)): - self.skipTest("JAX metal does not support dynamic shape unwrap()") rng = jtu.rand_some_zero(self.rng()) np_fun = lambda x: np.count_nonzero(x, axis) jnp_fun = lambda x: jnp.count_nonzero(x, axis) @@ -338,19 +304,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) def testNonzero(self, shape, dtype): - if (not np.all(shape)): - self.skipTest("JAX metal does not support dynamic shape unwrap()") - if (dtype == np.uint16): - self.skipTest("JAX metal has a crash with NonZero uint16") - self.skipTest("JAX metal has a bug with Pad") rng = jtu.rand_some_zero(self.rng()) - np_fun = lambda x: np.nonzero(x) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np_fun) - jnp_fun = lambda x: jnp.nonzero(x) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) @@ -361,10 +319,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): size=[1, 5, 10], ) def testNonzeroSize(self, shape, dtype, size, fill_value): - self.skipTest("JAX metal has a bug with Pad") rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") def np_fun(x): result = np.nonzero(x) if size <= len(result[0]): @@ -374,26 +330,86 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) for fval, arg in safe_zip(fillvals, result)) jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=[np.float32]) - def testArgWhere(self, shape, dtype): - if (not np.all(shape)): - self.skipTest("JAX metal does not support dynamic shape unwrap()") - self.skipTest("JAX metal does not support ArgWhere") + @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + def testFlatNonzero(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) np_fun = jtu.ignore_warning( category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np.argwhere) - jnp_fun = jnp.argwhere + message="Calling nonzero on 0d arrays.*")(np.flatnonzero) + jnp_fun = jnp.flatnonzero args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + # JIT compilation requires specifying the size statically: + jnp_fun = lambda x: jnp.flatnonzero(x, size=np.size(x) // 2) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=nonempty_array_shapes, + dtype=all_dtypes, + fill_value=[None, -1, 10, (-1,), (10,)], + size=[1, 5, 10], + ) + def testFlatNonzeroSize(self, shape, dtype, size, fill_value): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") + def np_fun(x): + result = np.flatnonzero(x) + if size <= len(result): + return result[:size] + else: + fill_val = fill_value or 0 + return np.concatenate([result, np.full(size - len(result), fill_val, result.dtype)]) + jnp_fun = lambda x: jnp.flatnonzero(x, size=size, fill_value=fill_value) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + def testArgWhere(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) + # JIT compilation requires specifying a size statically. Full test of this # behavior is in testNonzeroSize(). jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) - self._CompileAndCheck(jnp_fun, args_maker) + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, fill_value=fill_value) + for shape in nonempty_array_shapes + for fill_value in [None, -1, shape or (1,)] + ], + dtype=all_dtypes, + size=[1, 5, 10], + ) + def testArgWhereSize(self, shape, dtype, size, fill_value): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + def np_fun(x): + result = np.argwhere(x) + if size <= len(result): + return result[:size] + else: + fillvals = fill_value if np.ndim(fill_value) else result.shape[-1] * [fill_value or 0] + return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack([np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) + for fval, arg in safe_zip(fillvals, result.T)]).T + jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value) + + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name), @@ -405,12 +421,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): ) def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis, keepdims): rng = rng_factory(self.rng()) - if dtype == np.complex128 and jtu.device_under_test() == "gpu": + if dtype == np.complex128 and jtu.test_device_matches(["gpu"]): raise unittest.SkipTest("complex128 reductions not supported on GPU") if "nan" in np_op.__name__ and dtype == jnp.bfloat16: raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays") - if numpy_version < (1, 22) and keepdims: - raise unittest.SkipTest("NumPy < 1.22 does not support keepdims argument to argmin/argmax") kwds = {"keepdims": True} if keepdims else {} np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds)) @@ -432,12 +446,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): for rec in JAX_ARGMINMAX_RECORDS], ) def testArgMinMaxEmpty(self, name, np_op, jnp_op): - self.skipTest("JAX metal does crash") name = name[3:] if name.startswith("nan") else name msg = f"attempt to get {name} of an empty sequence" - with self.assertRaises(ValueError, msg=msg): + with self.assertRaisesRegex(ValueError, msg): jnp_op(np.array([])) - with self.assertRaises(ValueError, msg=msg): + with self.assertRaisesRegex(ValueError, msg): jnp_op(np.zeros((2, 0)), axis=1) np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0)) jnp_fun = partial(jnp_op, axis=0) @@ -468,6 +481,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] axisa, axisb, axisc, axis = axes jnp_fun = lambda a, b: jnp.cross(a, b, axisa, axisb, axisc, axis) + # Note: 2D inputs to jnp.cross are deprecated in numpy 2.0. + @jtu.ignore_warning(category=DeprecationWarning, + message="Arrays of 2-dimensional vectors are deprecated.") def np_fun(a, b): a = a.astype(np.float32) if lhs_dtype == jnp.bfloat16 else a b = b.astype(np.float32) if rhs_dtype == jnp.bfloat16 else b @@ -481,27 +497,30 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) @jtu.sample_product( - [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape) - for name, lhs_shape, rhs_shape in [ - ("matrix-scalar", (3, 3), ()), - ("scalar-matrix", (), (3, 3)), - # ("matrix-vector", (4, 5), (5,)), - ("vector-matrix", (6,), (6, 4)), - ("matrix-matrix", (3, 4), (4, 5)), - # ("tensor-vector", (4, 3, 2), (2,)), - ("vector-tensor", (2,), (3, 2, 4)), - ("tensor-matrix", (4, 3, 2), (2, 5)), - ("matrix-tensor", (5, 2), (3, 2, 4)), - ("tensor-tensor", (2, 3, 4), (5, 4, 1))]], - lhs_dtype=float_dtypes, - rhs_dtype=float_dtypes, + [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) + for lhs_shape, rhs_shape in [ + ((3, 3), ()), + ((), (3, 3)), + ((4, 5), (5,)), + ((6,), (6, 4)), + ((3, 4), (4, 5)), + ((4, 3, 2), (2,)), + ((2,), (3, 2, 4)), + ((4, 3, 2), (2, 5)), + ((5, 2), (3, 2, 4)), + ((2, 3, 4), (5, 4, 1))]], + lhs_dtype=float_dtypes,#number_dtypes, + rhs_dtype=float_dtypes,#number_dtypes, ) @jax.default_matmul_precision("float32") - def testDot(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14, np.complex128: 1e-14} + if (lhs_dtype in [np.float16, jnp.bfloat16] and + rhs_dtype in [np.float16, jnp.bfloat16]): + tol = 1e-2 def np_dot(x, y): x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y @@ -510,6 +529,23 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol) self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol) + @jtu.sample_product( + lhs_dtype=number_dtypes, + rhs_dtype=number_dtypes, + ) + @jax.numpy_dtype_promotion('standard') + def testMixedPrecisionDot(self, lhs_dtype, rhs_dtype): + # This test confirms that jnp.dot lowers to a single dot_general call, + # avoiding explicit type casting of inputs and outputs. + lhs = jax.ShapeDtypeStruct((5,), lhs_dtype) + rhs = jax.ShapeDtypeStruct((5,), rhs_dtype) + jaxpr = jax.make_jaxpr(jnp.dot)(lhs, rhs) + prims = [eqn.primitive for eqn in jaxpr.eqns] + self.assertIn(prims, [ + [lax.dot_general_p], + [lax.dot_general_p, lax.convert_element_type_p] + ]) + @jtu.sample_product( [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape) for name, lhs_shape, rhs_shape in [ @@ -523,8 +559,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): ("tensor-matrix", (5, 2, 3), (3, 2)), ("tensor-tensor", (5, 3, 4), (5, 4, 1)), ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]], - lhs_dtype=float_dtypes, - rhs_dtype=float_dtypes, + lhs_dtype=float_dtypes, #number_dtypes, + rhs_dtype=float_dtypes, #number_dtypes, ) @jax.default_matmul_precision("float32") def testMatmul(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): @@ -540,6 +576,39 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol) self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol) + @jtu.sample_product( + lhs_batch=broadcast_compatible_shapes, + rhs_batch=broadcast_compatible_shapes, + axis_size=[2, 4], + axis=range(-2, 2), + dtype=float_dtypes,#number_dtypes, + ) + @jax.default_matmul_precision("float32") + def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype): + # Construct vecdot-compatible shapes. + size = min(len(lhs_batch), len(rhs_batch)) + axis = int(np.clip(axis, -size - 1, size)) + if axis >= 0: + lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:]) + rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:]) + else: + laxis = axis + len(lhs_batch) + 1 + lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:]) + raxis = axis + len(rhs_batch) + 1 + rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:]) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + @jtu.promote_like_jnp + def np_fn(x, y, axis=axis): + f = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.vecdot + return f(x, y, axis=axis).astype(x.dtype) + jnp_fn = partial(jnp.vecdot, axis=axis) + tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12, + np.complex64: 1E-3, np.complex128: 1e-12} + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) for lhs_shape, rhs_shape, axes in [ @@ -550,8 +619,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], ]], - lhs_dtype=float_dtypes, - rhs_dtype=float_dtypes, + lhs_dtype=float_dtypes,#number_dtypes, + rhs_dtype=float_dtypes,#number_dtypes, ) @jax.default_matmul_precision("float32") def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): @@ -568,7 +637,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker, tol=tol) def testTensordotErrors(self): a = self.rng().random((3, 2, 2)) @@ -592,15 +661,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): @jtu.sample_product( element_shape=all_shapes, test_shape=all_shapes, - dtype=float_dtypes, # TODO: int_dtypes are not working - # dtype=default_dtypes, + dtype=default_dtypes, invert=[False, True], ) def testIsin(self, element_shape, test_shape, dtype, invert): - if ((not np.all(test_shape)) or (not np.all(element_shape))): - self.skipTest("JAX metal does not support dynamic shape") - if not ReportedIssuesTests.jax_metal_supported('0.0.6'): - self.skipTest("JAx metal has a regression on the version") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert) @@ -609,61 +673,3515 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonzerodim_shapes - for axis in (None, *range(len(shape))) - ], - dtype=float_dtypes, + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, ) - def testSort(self, dtype, shape, axis): - if (axis != 0): - return - rng = jtu.rand_some_equal(self.rng()) + def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) + + @unittest.skip("JAx-metal fail.") + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + size=[1, 5, 10], + fill_value=[None, -1], + ) + def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def np_fun(arg1, arg2): + result = np.setdiff1d(arg1, arg2) + if size <= len(result): + return result[:size] + else: + return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0) + def jnp_fun(arg1, arg2): + return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=nonempty_nonscalar_array_shapes, + shape2=nonempty_nonscalar_array_shapes, + ) + def testUnion1d(self, shape1, shape2, dtype1, dtype2): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def np_fun(arg1, arg2): + dtype = jnp.promote_types(arg1.dtype, arg2.dtype) + return np.union1d(arg1, arg2).astype(dtype) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker) + + @unittest.skip("Jax-metal fail.") + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=nonempty_nonscalar_array_shapes, + shape2=nonempty_nonscalar_array_shapes, + size=[1, 5, 10], + fill_value=[None, -1], + ) + def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def np_fun(arg1, arg2): + dtype = jnp.promote_types(arg1.dtype, arg2.dtype) + result = np.union1d(arg1, arg2).astype(dtype) + fv = result.min() if fill_value is None else fill_value + if size <= len(result): + return result[:size] + else: + return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) + def jnp_fun(arg1, arg2): + return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + assume_unique=[False, True], + ) + def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique) + def np_fun(ar1, ar2): + if assume_unique: + # pre-flatten the arrays to match with jax implementation + ar1 = np.ravel(ar1) + ar2 = np.ravel(ar2) + return np.setxor1d(ar1, ar2, assume_unique) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + assume_unique=[False, True], + return_indices=[False, True], + ) + def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, + return_indices): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + @jtu.sample_product( + [dict(lhs_shape=lhs_shape, lhs_dtype=lhs_dtype, + rhs_shape=rhs_shape, rhs_dtype=rhs_dtype) + # TODO(phawkins): support integer dtypes too. + for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + if len(jtu._dims_of_shape(lhs_shape)) == 0 + or len(jtu._dims_of_shape(rhs_shape)) == 0 + or lhs_shape[-1] == rhs_shape[-1]], + ) + @jax.default_matmul_precision("float32") + def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + def np_fun(lhs, rhs): + lhs = lhs if lhs_dtype != jnp.bfloat16 else lhs.astype(np.float32) + rhs = rhs if rhs_dtype != jnp.bfloat16 else rhs.astype(np.float32) + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.inner(lhs, rhs).astype(dtype) + jnp_fun = lambda lhs, rhs: jnp.inner(lhs, rhs) + tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13, + np.complex64: 1e-5} + tol = max(jtu.tolerance(lhs_dtype, tol_spec), + jtu.tolerance(rhs_dtype, tol_spec)) + # TODO(phawkins): there are float32/float64 disagreements for some inputs. + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) + + @unittest.skip("MLIR translation rule for primitive 'eigh' not found for platform METAL.") + @jtu.sample_product( + dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]], + shape=[shape for shape in one_dim_array_shapes if shape != (1,)], + deg=[1, 2, 3], + rcond=[None, -1, 10e-3, 10e-5, 10e-10], + full=[False, True], + w=[False, True], + cov=[False, True, "unscaled"], + ) + @jax.default_matmul_precision("float32") + def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): + rng = jtu.rand_default(self.rng()) + tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} + tol = jtu.tolerance(dtype, tol_spec) + _w = lambda a: abs(a) if w else None + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) + np_fun = jtu.ignore_warning( + message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) + + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) + + args = args_maker() + if not full: + args = args_maker() + try: + np_out = np_fun(*args) + except ValueError: + return # https://github.com/numpy/numpy/issues/22380 + jnp_out = jnp_fun(*args) + self.assertAllClose(np_out, jnp_out, atol=tol, rtol=tol, + check_dtypes=False) + else: + # Don't compare the residuals because jnp.linalg.lstsq acts slightly + # differently to remain `jit`-compatible. + np_p, _, nrank, nsingular_values, nrcond = np_fun(*args) + jp_p, _, jrank, jsingular_values, jrcond = jnp_fun(*args) + self.assertAllClose( + (np_p, nrank, nsingular_values, nrcond), + (jp_p, jrank, jsingular_values, jrcond), + atol=tol, rtol=tol, check_dtypes=False) + + @jtu.sample_product( + [dict(a_min=a_min, a_max=a_max) + for a_min, a_max in [(-1, None), (None, 1), (-0.9, 1), + (-np.ones(1), None), + (None, np.ones(1)), + (np.full(1, -0.9), np.ones(1))] + ], + shape=all_shapes, + dtype=number_dtypes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testClipStaticBounds(self, shape, dtype, a_min, a_max): + if np.issubdtype(dtype, np.unsignedinteger): + a_min = None if a_min is None else abs(a_min) + a_max = None if a_max is None else abs(a_max) + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) + jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max) args_maker = lambda: [rng(shape, dtype)] - jnp_fun = jnp.sort - np_fun = np.sort - if axis is not None: - jnp_fun = partial(jnp_fun, axis=axis) - np_fun = partial(np_fun, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + def testClipError(self): + with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"): + jnp.clip(jnp.zeros((3,))) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype) + for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)], + decimals=[0, 1, -2], + ) + def testRoundStaticDecimals(self, shape, dtype, decimals): + rng = jtu.rand_default(self.rng()) + if jnp.issubdtype(dtype, np.integer) and decimals < 0: + self.skipTest("Integer rounding with decimals < 0 not implemented") + np_fun = lambda x: np.round(x, decimals=decimals) + jnp_fun = lambda x: jnp.round(x, decimals=decimals) + args_maker = lambda: [rng(shape, dtype)] + tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2} + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=check_dtypes, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, + atol=tol, rtol=tol) + + @jtu.sample_product(jit=[False, True]) + def testOperatorRound(self, jit): + jround = jax.jit(round, static_argnums=1) if jit else round + self.assertAllClose(round(np.float32(7.532), 1), + jround(jnp.float32(7.5), 1)) + self.assertAllClose(round(np.float32(1.234), 2), + jround(jnp.float32(1.234), 2)) + self.assertAllClose(round(np.float32(1.234)), + jround(jnp.float32(1.234)), check_dtypes=False) + self.assertAllClose(round(np.float32(7.532), 1), + jround(jnp.array(7.5, jnp.float32), 1)) + self.assertAllClose(round(np.float32(1.234), 2), + jround(jnp.array(1.234, jnp.float32), 2)) + self.assertAllClose(round(np.float32(1.234)), + jround(jnp.array(1.234, jnp.float32)), + check_dtypes=False) + + def testRoundMethod(self): + # https://github.com/google/jax/issues/15190 + (jnp.arange(3.) / 5.).round() # doesn't crash + + @jtu.sample_product(shape=[(5,), (5, 2)]) + def testOperatorReversed(self, shape): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, 'float32')] + np_fun = lambda x: np.array(list(reversed(x))) + jnp_fun = lambda x: jnp.array(list(reversed(x))) + + self._CompileAndCheck(jnp_fun, args_maker) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - + + @jtu.sample_product( + [dict(mode=mode, shape=shape, dtype=dtype, + pad_width=pad_width, constant_values=constant_values) + for mode, shapes in [ + ('constant', all_shapes), + ('wrap', nonempty_shapes), + ('edge', nonempty_shapes), + ] + for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) + for constant_values in [ + # None is used for modes other than 'constant' + None, + # constant + 0, 1, + # (constant,) + (0,), (2.718,), + # ((before_const, after_const),) + ((0, 2),), ((-1, 3.14),), + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i / 2, -3.14 * i) for i in range(len(shape))), + ] + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + if (pad_width != () and constant_values != () and + ((mode == 'constant' and constant_values is not None) or + (mode != 'constant' and constant_values is None)))], + ) + def testPad(self, shape, dtype, mode, pad_width, constant_values): + if np.issubdtype(dtype, np.unsignedinteger): + constant_values = jax.tree.map(abs, constant_values) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if constant_values is None: + np_fun = partial(np.pad, pad_width=pad_width, mode=mode) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode) + else: + np_fun = partial(np.pad, pad_width=pad_width, mode=mode, + constant_values=constant_values) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, + constant_values=constant_values) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(mode=mode, shape=shape, dtype=dtype, + pad_width=pad_width, stat_length=stat_length) + for mode in ['maximum', 'minimum', 'mean', 'median'] + for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + for stat_length in [ + None, + # ((before_1, after_1), ..., (before_N, after_N)) + tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 2),), + # (before, after) (not in the docstring but works in numpy) + (1, 1), (3, 4), + # (pad,) + (1,), (2,), + # pad + 1, 2 + ] + if (pad_width != () and stat_length != () and + not (dtype in bool_dtypes and mode == 'mean'))], + ) + def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length): + if mode == 'median' and np.issubdtype(dtype, np.complexfloating): + self.skipTest("median statistic is not supported for dtype=complex.") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + np_fun = partial(np.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, + pad_width=pad_width, reflect_type=reflect_type) + for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 3),), + # (before, after) (not in the docstring but works in numpy) + (2, 1), (1, 2), + # (pad,) + (1,), (2,), (3,), + # pad + 0, 5, 7, 10 + ] + for reflect_type in ['even', 'odd'] + if (pad_width != () and + # following types lack precision when calculating odd values + (reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16]))], + mode=['symmetric', 'reflect'] + ) + def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + np_fun = partial(np.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, + tol={np.float32: 1e-3, np.complex64: 1e-3}) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, pad_width=pad_width, end_values=end_values) + for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + for end_values in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2.0, 3.14),), + # (before, after) (not in the docstring but works in numpy) + (0, 0), (-8.0, 2.0), + # (end_values,) + (1,), (2,), + # end_values + 0, 1, 100, 10.0, 3.5, 4.2, -5, -3 + ] + if (pad_width != () and end_values != () and + # following types lack precision + dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])], + ) + def testPadLinearRamp(self, shape, dtype, pad_width, end_values): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + np_fun = partial(np.pad, pad_width=pad_width, mode="linear_ramp", + end_values=end_values) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode="linear_ramp", + end_values=end_values) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(jnp_fun, args_maker) + + def testPadEmpty(self): + arr = np.arange(6).reshape(2, 3) + + pad_width = ((2, 3), (3, 1)) + np_res = np.pad(arr, pad_width=pad_width, mode="empty") + jnp_res = jnp.pad(arr, pad_width=pad_width, mode="empty") + + np.testing.assert_equal(np_res.shape, jnp_res.shape) + np.testing.assert_equal(arr, np_res[2:-3, 3:-1]) + np.testing.assert_equal(arr, jnp_res[2:-3, 3:-1]) + np.testing.assert_equal(np_res[2:-3, 3:-1], jnp_res[2:-3, 3:-1]) + + def testPadKwargs(self): + modes = { + 'constant': {'constant_values': 0}, + 'edge': {}, + 'linear_ramp': {'end_values': 0}, + 'maximum': {'stat_length': None}, + 'mean': {'stat_length': None}, + 'median': {'stat_length': None}, + 'minimum': {'stat_length': None}, + 'reflect': {'reflect_type': 'even'}, + 'symmetric': {'reflect_type': 'even'}, + 'wrap': {}, + 'empty': {} + } + arr = jnp.array([1, 2, 3]) + pad_width = 1 + + for mode in modes.keys(): + allowed = modes[mode] + not_allowed = {} + for kwargs in modes.values(): + if kwargs != allowed: + not_allowed.update(kwargs) + + # Test if allowed keyword arguments pass + jnp.pad(arr, pad_width, mode, **allowed) + # Test if prohibited keyword arguments of other modes raise an error + match = f"unsupported keyword arguments for mode '{mode}'" + for key, value in not_allowed.items(): + with self.assertRaisesRegex(ValueError, match): + jnp.pad(arr, pad_width, mode, **{key: value}) + + # Test if unsupported mode raise error. + unsupported_modes = [1, None, "foo"] + for mode in unsupported_modes: + match = f"Unimplemented padding mode '{mode}' for np.pad." + with self.assertRaisesRegex(NotImplementedError, match): + jnp.pad(arr, pad_width, mode) + + def testPadFunction(self): + def np_pad_with(vector, pad_width, iaxis, kwargs): + pad_value = kwargs.get('padder', 10) + vector[:pad_width[0]] = pad_value + vector[-pad_width[1]:] = pad_value + + def jnp_pad_with(vector, pad_width, iaxis, kwargs): + pad_value = kwargs.get('padder', 10) + vector = vector.at[:pad_width[0]].set(pad_value) + vector = vector.at[-pad_width[1]:].set(pad_value) + return vector + + arr = np.arange(6).reshape(2, 3) + np_res = np.pad(arr, 2, np_pad_with) + jnp_res = jnp.pad(arr, 2, jnp_pad_with) + np.testing.assert_equal(np_res, jnp_res) + + arr = np.arange(24).reshape(2, 3, 4) + np_res = np.pad(arr, 1, np_pad_with, padder=100) + jnp_res = jnp.pad(arr, 1, jnp_pad_with, padder=100) + np.testing.assert_equal(np_res, jnp_res) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(arr.shape, arr.dtype)] + jnp_fun = partial(jnp.pad, pad_width=1, mode=jnp_pad_with) + self._CompileAndCheck(jnp_fun, args_maker) + + def testPadWithNumpyPadWidth(self): + a = jnp.array([1, 2, 3, 4, 5]) + f = jax.jit( + partial( + jnp.pad, + pad_width=np.asarray((2, 3)), + mode="constant", + constant_values=(4, 6))) + + np.testing.assert_array_equal( + f(a), + np.pad( + a, + pad_width=np.asarray((2, 3)), + mode="constant", + constant_values=(4, 6))) + + def testPadWeakType(self): + x = jnp.array(1.0)[None] + for mode in ['constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', + 'minimum', 'reflect', 'symmetric', 'wrap', 'empty']: + y = jnp.pad(x, 0, mode=mode) + self.assertTrue(dtypes.is_weakly_typed(y)) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype) + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)], + reps=[(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)], + ) + def testTile(self, shape, dtype, reps): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.tile(arg, reps) + jnp_fun = lambda arg: jnp.tile(arg, reps) + + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + def testExtract(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] + self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker) + + @jtu.sample_product( + [dict(ncond=ncond, nfunc=nfunc) + for ncond in [1, 2, 3] + for nfunc in [ncond, ncond + 1] + ], + shape=all_shapes, + dtype=all_dtypes) + def testPiecewise(self, shape, dtype, ncond, nfunc): + rng = jtu.rand_default(self.rng()) + rng_bool = jtu.rand_int(self.rng(), 0, 2) + funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc] + args_maker = lambda: (rng(shape, dtype), [rng_bool(shape, bool) for i in range(ncond)]) + np_fun = partial(np.piecewise, funclist=funclist) + jnp_fun = partial(jnp.piecewise, funclist=funclist) + + if dtype == np.bool_: + # The `x - 1` above uses type promotion. + jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + # This is a higher-order function, so the cache miss check will fail. + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, check_cache_misses=False) + + def testPiecewiseRecompile(self): + def g(x): + g.num_traces += 1 + return x + g.num_traces = 0 + x = jnp.arange(10.0) + for i in range(5): + jnp.piecewise(x, [x < 0], [g, 0.]) + self.assertEqual(g.num_traces, 1) + + @jtu.sample_product( + [dict(shape=shape, perm=perm) + for shape in array_shapes + for perm in [ + None, + tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim)), + tuple(np.random.RandomState(0).permutation( + np.zeros(shape).ndim) - np.zeros(shape).ndim) + ] + ], + dtype=default_dtypes, + arg_type=["splat", "value"], + ) + def testTransposeTuple(self, shape, dtype, perm, arg_type): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if arg_type == "value": + np_fun = lambda x: x.transpose(perm) + jnp_fun = lambda x: jnp.array(x).transpose(perm) + else: + np_fun = lambda x: x.transpose(*(perm or ())) + jnp_fun = lambda x: jnp.array(x).transpose(*(perm or ())) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @jtu.sample_product( + shape=array_shapes, + dtype=default_dtypes, + ) + def testPermuteDims(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + axes = self.rng().permutation(len(shape)) + np_fun = partial(getattr(np, "permute_dims", np.transpose), axes=axes) + jnp_fun = partial(jnp.permute_dims, axes=axes) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @jtu.sample_product( + shape=[s for s in array_shapes if len(s) >= 2], + dtype=default_dtypes, + use_property=[True, False] + ) + def testMatrixTranspose(self, shape, dtype, use_property): + if use_property: + jnp_fun = lambda x: jnp.asarray(x).mT + else: + jnp_fun = jnp.matrix_transpose + if hasattr(np, 'matrix_transpose'): + np_fun = np.matrix_transpose + else: + np_fun = lambda x: np.swapaxes(x, -1, -2) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + trim=["f", "b", "fb"], + ) + def testTrimZeros(self, a_shape, dtype, trim): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(a_shape, dtype)] + np_fun = lambda arg1: np.trim_zeros(arg1, trim) + jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + rank=(1, 2), + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + ) + @jax.default_matmul_precision("float32") + def testPoly(self, a_shape, dtype, rank): + if dtype in (np.float16, jnp.bfloat16, np.int16): + self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") + elif rank == 2 and not jtu.test_device_matches(["cpu"]): + self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") + rng = jtu.rand_default(self.rng()) + tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } + if jtu.test_device_matches(["tpu"]): + tol[np.int32] = tol[np.float32] = 1e-1 + tol = jtu.tolerance(dtype, tol) + args_maker = lambda: [rng(a_shape * rank, dtype)] + self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol, atol=tol) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + b_shape=one_dim_array_shapes, + ) + def testPolyAdd(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polyadd(arg1, arg2) + jnp_fun = lambda arg1, arg2: jnp.polyadd(arg1, arg2) + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + b_shape=one_dim_array_shapes, + ) + def testPolySub(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polysub(arg1, arg2) + jnp_fun = lambda arg1, arg2: jnp.polysub(arg1, arg2) + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + [dict(order=order, k=k, dtype=dtype) + for dtype in default_dtypes + for order in range(5) + for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]], + a_shape=one_dim_array_shapes, + ) + def testPolyInt(self, a_shape, order, k, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1: np.polyint(arg1, m=order, k=k) + jnp_fun = lambda arg1: jnp.polyint(arg1, m=order, k=k) + args_maker = lambda: [rng(a_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal don't support map op.") + @jtu.sample_product( + dtype=default_dtypes, + a_shape=one_dim_array_shapes, + order=list(range(5)), + ) + def testPolyDer(self, a_shape, order, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1: np.polyder(arg1, m=order) + jnp_fun = lambda arg1: jnp.polyder(arg1, m=order) + args_maker = lambda: [rng(a_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @parameterized.parameters(['int', 'np.int', 'jnp.int']) + def testIntegerPower(self, ptype): + p = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}[ptype] + jaxpr = jax.make_jaxpr(lambda x1: jnp.power(x1, p))(1) + eqns = jaxpr.jaxpr.eqns + self.assertLen(eqns, 1) + self.assertEqual(eqns[0].primitive, lax.integer_pow_p) + + @jtu.sample_product( + x=[-1, 0, 1], + y=[0, 32, 64, 128], + ) + def testIntegerPowerOverflow(self, x, y): + # Regression test for https://github.com/google/jax/issues/5987 + args_maker = lambda: [x, y] + self._CheckAgainstNumpy(np.power, jnp.power, args_maker) + self._CompileAndCheck(jnp.power, args_maker) + @jtu.sample_product( [dict(shape=shape, axis=axis) - for shape in [(2,3,4)] - for axis in (-1, *range(len(shape))) + for shape in all_shapes + for axis in [None] + list(range(len(shape))) ], - dtype=float_dtypes, + dtype=all_dtypes, ) - def testSortKeyVal(self, dtype, shape, axis): - #self.skipTest("JAX metal does not support sortkeyval.") - rng = jtu.rand_default(self.rng()) - # This test relies on the property that wherever keys are tied, values are - # too, since we don't guarantee the same ordering of values with equal keys. - # To avoid that case, we generate unique keys (globally in the key array). - def args_maker(): - flat_keys = np.arange(math.prod(shape), dtype=dtype) - keys = self.rng().permutation(flat_keys).reshape(shape) - values = rng(shape, dtype) - return keys, values + def testCompress(self, shape, dtype, axis): + rng = jtu.rand_some_zero(self.rng()) + if shape in scalar_shapes or len(shape) == 0: + cond_shape = (0,) + elif axis is None: + cond_shape = (math.prod(shape),) + else: + cond_shape = (shape[axis],) - fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable=True) - numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis) - self._CheckAgainstNumpy(fun, numpy_op, args_maker) + args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] + + np_fun = partial(np.compress, axis=axis) + jnp_fun = partial(jnp.compress, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(2, 3)], + dtype=int_dtypes, + # condition entries beyond axis size must be zero. + condition=[[1], [1, 0, 0, 0, 0, 0, 0]], + axis=[None, 0, 1], + ) + def testCompressMismatchedShapes(self, shape, dtype, condition, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [np.array(condition), rng(shape, dtype)] + np_fun = partial(np.compress, axis=axis) + jnp_fun = partial(jnp.compress, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in array_shapes + for axis in [None] + list(range(len(shape))) + ], + dtype=all_dtypes, + ) + def testCompressMethod(self, shape, dtype, axis): + rng = jtu.rand_some_zero(self.rng()) + if shape in scalar_shapes or len(shape) == 0: + cond_shape = (0,) + elif axis is None: + cond_shape = (math.prod(shape),) + else: + cond_shape = (shape[axis],) + + args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] + + np_fun = lambda condition, x: np.compress(condition, x, axis=axis) + jnp_fun = lambda condition, x: x.compress(condition, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in (None, *range(-len(base_shape)+1, len(base_shape))) + ], + arg_dtypes=[ + arg_dtypes + for num_arrs in [3] + for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs) + ], + dtype=[None] + default_dtypes, + ) + def testConcatenate(self, axis, dtype, base_shape, arg_dtypes): + rng = jtu.rand_default(self.rng()) + wrapped_axis = 0 if axis is None else axis % len(base_shape) + shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] + @jtu.promote_like_jnp + def np_fun(*args, dtype=dtype): + dtype = dtype or args[0].dtype + args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32) + for x in args] + return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe') + jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + with jtu.strict_promotion_if_dtypes_match(arg_dtypes): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(4, 1), (4, 3), (4, 5, 6)] + for axis in [None] + list(range(1 - len(shape), len(shape) - 1)) + ], + dtype=all_dtypes, + ) + def testConcatenateArray(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda x: np.concatenate(x, axis=axis) + jnp_fun = lambda x: jnp.concatenate(x, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testConcatenateAxisNone(self): + # https://github.com/google/jax/issues/3419 + a = jnp.array([[1, 2], [3, 4]]) + b = jnp.array([[5]]) + jnp.concatenate((a, b), axis=None) + + def testConcatenateScalarAxisNone(self): + arrays = [np.int32(0), np.int32(1)] + self.assertArraysEqual(jnp.concatenate(arrays, axis=None), + np.concatenate(arrays, axis=None)) + + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(), (4,), (3, 4), (2, 3, 4)] + for axis in (None, *range(-len(base_shape)+1, len(base_shape))) + ], + dtype=default_dtypes, + ) + def testConcat(self, axis, base_shape, dtype): + rng = jtu.rand_default(self.rng()) + wrapped_axis = 0 if axis is None else axis % len(base_shape) + shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] + for size in [3, 1, 4]] + @jtu.promote_like_jnp + def np_fun(*args): + if jtu.numpy_version() >= (2, 0, 0): + return np.concat(args, axis=axis) + else: + return np.concatenate(args, axis=axis) + jnp_fun = lambda *args: jnp.concat(args, axis=axis) + args_maker = lambda: [rng(shape, dtype) for shape in shapes] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(-len(base_shape)+1, len(base_shape))], + arg_dtypes=itertools.combinations_with_replacement(default_dtypes, 2) + ) + def testAppend(self, axis, base_shape, arg_dtypes): + rng = jtu.rand_default(self.rng()) + wrapped_axis = axis % len(base_shape) + shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] + def np_fun(arr, values): + arr = arr.astype(np.float32) if arr.dtype == jnp.bfloat16 else arr + values = (values.astype(np.float32) if values.dtype == jnp.bfloat16 + else values) + out = np.append(arr, values, axis=axis) + return out.astype(jnp.promote_types(*arg_dtypes)) + jnp_fun = lambda arr, values: jnp.append(arr, values, axis=axis) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + with jtu.strict_promotion_if_dtypes_match(arg_dtypes): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis, idx=idx) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + for idx in (range(-math.prod(shape), math.prod(shape)) + if axis is None else + range(-shape[axis], shape[axis]))], + dtype=all_dtypes, + ) + def testDeleteInteger(self, shape, dtype, idx, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, idx, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + slc=[slice(None), slice(1, 3), slice(1, 5, 2)], + ) + def testDeleteSlice(self, shape, dtype, axis, slc): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, slc, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, slc, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + idx_shape=all_shapes, + ) + def testDeleteIndexArray(self, shape, dtype, axis, idx_shape): + rng = jtu.rand_default(self.rng()) + max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, idx, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + idx_shape=all_shapes, + ) + def testDeleteUniqueIndices(self, shape, dtype, axis, idx_shape): + rng = jtu.rand_default(self.rng()) + max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + idx_size = np.zeros(idx_shape).size + if idx_size > max_idx: + self.skipTest("Too many indices to be unique") + def args_maker(): + x = rng(shape, dtype) + idx = self.rng().choice(max_idx, idx_shape, replace=False) + return x, idx + np_fun = partial(np.delete, axis=axis) + jnp_fun = partial(jnp.delete, axis=axis, assume_unique_indices=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + ) + def testDeleteMaskArray(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) + if numpy_version == (1, 23, 0) and mask.shape == (1,): + # https://github.com/numpy/numpy/issues/21840 + self.skipTest("test fails for numpy v1.23.0") + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, mask, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("JAX-metal fail.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + ) + def testInsertInteger(self, shape, dtype, axis): + x = jnp.empty(shape) + max_ind = x.size if axis is None else x.shape[axis] + rng = jtu.rand_default(self.rng()) + i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) + args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)] + np_fun = lambda *args: np.insert(*args, axis=axis) + jnp_fun = lambda *args: jnp.insert(*args, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("Jax-metal fail.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + ) + def testInsertSlice(self, shape, dtype, axis): + x = jnp.empty(shape) + max_ind = x.size if axis is None else x.shape[axis] + rng = jtu.rand_default(self.rng()) + i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) + slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item()) + args_maker = lambda: [rng(shape, dtype), rng((), dtype)] + np_fun = lambda x, val: np.insert(x, slc, val, axis=axis) + jnp_fun = lambda x, val: jnp.insert(x, slc, val, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.parameters([ + [[[1, 1], [2, 2], [3, 3]], 1, 5, None], + [[[1, 1], [2, 2], [3, 3]], 1, 5, 1], + [[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1], + [[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1], + [[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None], + [[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None], + [[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None], + [[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1] + ]) + def testInsertExamples(self, arr, index, values, axis): + # Test examples from the np.insert docstring + args_maker = lambda: ( + np.asarray(arr), index if isinstance(index, slice) else np.array(index), + np.asarray(values), axis) + self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_array_shapes + for axis in range(-len(shape), len(shape)) + ], + dtype=default_dtypes, + out_dims=[0, 1, 2], + ) + def testApplyAlongAxis(self, shape, dtype, axis, out_dims): + def func(x, out_dims): + if out_dims == 0: + return x.sum(dtype=x.dtype) + elif out_dims == 1: + return x * x[0] + elif out_dims == 2: + return x[:, None] + x[None, :] + else: + raise NotImplementedError(f"{out_dims=}") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arr: np.apply_along_axis(func, axis, arr, out_dims=out_dims) + jnp_fun = lambda arr: jnp.apply_along_axis(func, axis, arr, out_dims=out_dims) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + atol={dtypes.bfloat16: 2e-2}) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axes=axes) + for shape in nonempty_shapes + for axes in itertools.combinations(range(len(shape)), 2) + ], + func=["sum"], + keepdims=[True, False], + # Avoid low-precision types in sum() + dtype=[dtype for dtype in default_dtypes + if dtype not in [np.float16, jnp.bfloat16]], + ) + def testApplyOverAxes(self, shape, dtype, func, keepdims, axes): + f = lambda x, axis: getattr(x, func)(axis=axis, keepdims=keepdims, dtype=dtype) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype),) + np_fun = lambda a: np.apply_over_axes(f, a, axes) + jnp_fun = lambda a: jnp.apply_over_axes(f, a, axes) + self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, axis=axis) + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) + for axis in [None] + list(range(-len(shape), max(1, len(shape)))) + ], + repeats=[0, 1, 2], + fixed_size=[False, True], + ) + def testRepeat(self, axis, shape, dtype, repeats, fixed_size): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis) + np_fun = jtu.promote_like_jnp(np_fun) + if fixed_size: + total_repeat_length = np.repeat(np.zeros(shape), repeats, axis).shape[axis or 0] + jnp_fun = lambda arg, rep: jnp.repeat(arg, repeats=rep, axis=axis, + total_repeat_length=total_repeat_length) + jnp_args_maker = lambda: [rng(shape, dtype), repeats] + clo_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length) + clo_fun_args_maker = lambda: [rng(shape, dtype)] + self._CompileAndCheck(jnp_fun, jnp_args_maker) + self._CheckAgainstNumpy(np_fun, clo_fun, clo_fun_args_maker) + else: + # Now repeats is in a closure, so a constant. + jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testRepeatScalarFastPath(self): + a = jnp.array([1,2,3,4]) + f = lambda a: jnp.repeat(a, repeats=2) + jaxpr = jax.make_jaxpr(f)(a) + self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6) + + @unittest.skip("jax-metal fail to convert sort op.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in [None] + list(range(len(shape)))], + dtype=number_dtypes, + return_index=[False, True], + return_inverse=[False, True], + return_counts=[False, True], + ) + def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + extra_args = (return_index, return_inverse, return_counts) + use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False + np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults) + jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueAll(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True) + else: + np_fun = np.unique_all + self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueCounts(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = lambda x: np.unique(x, return_counts=True) + else: + np_fun = np.unique_counts + self._CheckAgainstNumpy(jnp.unique_counts, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueInverse(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = partial(np_unique_backport, return_inverse=True) + else: + np_fun = np.unique_inverse + self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueValues(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = np.unique + else: + np_fun = np.unique_values + self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_array_shapes + for axis in [None] + list(range(len(shape)))], + dtype=number_dtypes, + size=[1, 5, 10], + fill_value=[None, 0, "slice"], + ) + def testUniqueSize(self, shape, dtype, axis, size, fill_value): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True) + + if fill_value == "slice": + if axis is None: + fill_value = rng((), dtype) + else: + fill_value = rng(shape[:axis] + shape[axis + 1:], dtype) + elif fill_value is not None: + fill_value = np.array(fill_value).astype(dtype) + + @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) + def np_fun(x, fill_value=fill_value): + u, ind, inv, counts = np_unique_backport(x, **kwds) + axis = kwds['axis'] + if axis is None: + x = x.ravel() + axis = 0 + + n_unique = u.shape[axis] + if size <= u.shape[axis]: + slc = (slice(None),) * axis + (slice(size),) + u, ind, counts = u[slc], ind[:size], counts[:size] + else: + extra = (0, size - n_unique) + pads = [(0, 0)] * u.ndim + pads[axis] = extra + u = np.pad(u, pads, constant_values=0) + slices = [slice(None)] * u.ndim + slices[axis] = slice(1) + if fill_value is None: + fill_value = u[tuple(slices)] + elif np.ndim(fill_value): + fill_value = lax.expand_dims(fill_value, (axis,)) + slices[axis] = slice(n_unique, None) + u[tuple(slices)] = fill_value + ind = np.pad(ind, extra, constant_values=ind[0]) + counts = np.pad(counts, extra, constant_values=0) + return u, ind, inv, counts + + jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product(dtype=inexact_dtypes) + def testUniqueNans(self, dtype): + if numpy_version == (1, 23, 0) and dtype == np.float16: + # https://github.com/numpy/numpy/issues/21838 + self.skipTest("Known failure on numpy 1.23.0") + def args_maker(): + x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] + if np.issubdtype(dtype, np.complexfloating): + x = [complex(i, j) for i, j in itertools.product(x, repeat=2)] + return [np.array(x, dtype=dtype)] + + kwds = dict(return_index=True, return_inverse=True, return_counts=True) + jnp_fun = partial(jnp.unique, **kwds) + def np_fun(x): + dtype = x.dtype + # numpy unique fails for bfloat16 NaNs, so we cast to float64 + if x.dtype == jnp.bfloat16: + x = x.astype('float64') + u, *rest = np.unique(x, **kwds) + return (u.astype(dtype), *rest) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) + def testUniqueEqualNan(self, dtype, equal_nan): + if numpy_version < (1, 24, 0): + self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.") + shape = (20,) + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + def np_fun(x): + dtype = x.dtype + # numpy unique fails for bfloat16 NaNs, so we cast to float64 + if x.dtype == jnp.bfloat16: + x = x.astype('float64') + return np.unique(x, equal_nan=equal_nan).astype(dtype) + jnp_fun = partial(jnp.unique, equal_nan=equal_nan) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product(fixed_size=[False, True]) + def testNonScalarRepeats(self, fixed_size): + ''' + Following numpy test suite from `test_repeat` at + https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_multiarray.py + ''' + tol = 1e-5 + + def test_single(m, args_maker, repeats, axis): + lax_ans = jnp.repeat(m, repeats, axis) + numpy_ans = np.repeat(m, repeats, axis) + + self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol) + if fixed_size: + + # Calculate expected size of the repeated axis. + rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0] + jnp_fun = lambda arg, rep: jnp.repeat( + arg, repeats=rep, axis=axis, total_repeat_length=rep_length) + else: + jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis) + self._CompileAndCheck(jnp_fun, args_maker) + + m = jnp.array([1,2,3,4,5,6]) + if fixed_size: + args_maker = lambda: [m, repeats] + else: + args_maker = lambda: [m] + + for repeats in [2, jnp.array([1,3,0,1,1,2]), jnp.array([1,3,2,1,1,2]), jnp.array([2])]: + test_single(m, args_maker, repeats, axis=None) + test_single(m, args_maker, repeats, axis=0) + + m_rect = m.reshape((2,3)) + if fixed_size: + args_maker = lambda: [m_rect, repeats] + else: + args_maker = lambda: [m_rect] + + for repeats in [2, jnp.array([2,1]), jnp.array([2])]: + test_single(m_rect, args_maker, repeats, axis=0) + + for repeats in [2, jnp.array([1,3,2]), jnp.array([2])]: + test_single(m_rect, args_maker, repeats, axis=1) + + def testIssue2330(self): + ''' + Make sure return value of jnp.concatenate is a jax.ndarray and is side-effect save + ''' + def attempt_sideeffect(x): + x = [x] + x = jnp.concatenate(x) + x -= 1. + return x + + np_input = np.ones(1) + jnp_input = jnp.ones(1) + expected_np_input_after_call = np.ones(1) + expected_jnp_input_after_call = jnp.ones(1) + + out = jnp.concatenate([np_input]) + self.assertIs(type(out), array.ArrayImpl) + + attempt_sideeffect(np_input) + attempt_sideeffect(jnp_input) + + self.assertAllClose(np_input, expected_np_input_after_call) + self.assertAllClose(jnp_input, expected_jnp_input_after_call) + + @unittest.skip("Jax-metal fail to convert 1D convolution op.") + @jtu.sample_product( + mode=['full', 'same', 'valid'], + op=['convolve', 'correlate'], + dtype= float_dtypes, #number_dtypes, + xshape=one_dim_array_shapes, + yshape=one_dim_array_shapes, + ) + def testConvolutions(self, xshape, yshape, dtype, mode, op): + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] + precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None + jnp_fun = partial(jnp_op, mode=mode, precision=precision) + def np_fun(x, y): + return np_op(x, y, mode=mode).astype(dtypes.to_inexact_dtype(dtype)) + tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, + np.complex128: 1e-14} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("Jax-metal fail to convert 1D convolution op.") + @jtu.sample_product( + mode=['full', 'same', 'valid'], + op=['convolve', 'correlate'], + dtype=float_dtypes, #number_dtypes, + xshape=one_dim_array_shapes, + yshape=one_dim_array_shapes, + ) + @jtu.skip_on_devices("cuda", "rocm") # backends don't support all dtypes. + def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op): + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] + precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None + jnp_fun = partial(jnp_op, mode=mode, precision=precision, + preferred_element_type=dtype) + def np_fun(x, y): + return np_op(x, y, mode=mode).astype(dtype) + tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, + np.complex128: 1e-14} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in [None] + list(range(-len(shape), len(shape)))], + op=["cumsum", "cumprod"], + dtype=all_dtypes, + out_dtype=[dtype for dtype in default_dtypes if dtype != np.float16], + ) + def testCumSumProd(self, axis, shape, dtype, out_dtype, op): + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) + np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered.*")(np_fun) + jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + + args_maker = lambda: [rng(shape, dtype)] + + tol_thresholds = {dtypes.bfloat16: 4e-2} + tol = max(jtu.tolerance(dtype, tol_thresholds), + jtu.tolerance(out_dtype, tol_thresholds)) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in [None] + list(range(-len(shape), len(shape)))], + op=["nancumsum", "nancumprod"], + dtype=all_dtypes, + out_dtype=default_dtypes, + ) + def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + rng = jtu.rand_some_nan(self.rng()) + np_fun = partial(np_op, axis=axis, dtype=out_dtype) + np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered.*")(np_fun) + jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + + args_maker = lambda: [rng(shape, dtype)] + + tol_thresholds = {dtypes.bfloat16: 4e-2, np.float16: 3e-3} + tol = max(jtu.tolerance(dtype, tol_thresholds), + jtu.tolerance(out_dtype, tol_thresholds)) + if dtype != jnp.bfloat16: + # numpy functions do not properly handle bfloat16 + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("Jax-metal fail on testEye2") + @jtu.sample_product( + dtype=default_dtypes, + n=[0, 4], + m=[None, 0, 1, 3, 4], + k=[*range(-4, 4), -2**100, 2**100], + ) + def testEye(self, n, m, k, dtype): + np_fun = lambda: np.eye(n, M=m, k=k, dtype=dtype) + jnp_fun = lambda: jnp.eye(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + n=[0, 4], + m=[None, 0, 1, 3, 4], + k=range(-4, 4), + ) + def testTri(self, m, n, k, dtype): + np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype) + jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[shape for shape in all_shapes if len(shape) >= 2], + op=["tril", "triu"], + k=list(range(-3, 3)), + ) + def testTriLU(self, dtype, shape, op, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: getattr(np, op)(arg, k=k) + jnp_fun = lambda arg: getattr(jnp, op)(arg, k=k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + n=range(5), + k=range(-3, 3), + m=[None, *range(5)], + ) + def testTrilIndices(self, n, k, m): + np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m) + jnp_fun = lambda n, k, m: jnp.tril_indices(n, k=k, m=m) + args_maker = lambda: [n, k, m] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + n=range(5), + k=range(-3, 3), + m=[None, *range(5)], + ) + def testTriuIndices(self, n, k, m): + np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m) + jnp_fun = lambda n, k, m: jnp.triu_indices(n, k=k, m=m) + args_maker = lambda: [n, k, m] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)], + k=[-1, 0, 1], + ) + def testTriuIndicesFrom(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arr, k: np.triu_indices_from(arr, k=k) + jnp_fun = lambda arr, k: jnp.triu_indices_from(arr, k=k) + args_maker = lambda: [rng(shape, dtype), k] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)], + k=[-1, 0, 1], + ) + def testTrilIndicesFrom(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arr, k: np.tril_indices_from(arr, k=k) + jnp_fun = lambda arr, k: jnp.tril_indices_from(arr, k=k) + args_maker = lambda: [rng(shape, dtype), k] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)], + val_shape=[(), (1,), (2,), (1, 2), (3, 2)], + ) + def testFillDiagonal(self, dtype, a_shape, val_shape): + rng = jtu.rand_default(self.rng()) + + def np_fun(a, val): + a_copy = a.copy() + np.fill_diagonal(a_copy, val) + return a_copy + + jnp_fun = partial(jnp.fill_diagonal, inplace=False) + args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + ndim=[0, 1, 4], + n=[0, 1, 7], + ) + def testDiagIndices(self, ndim, n): + np.testing.assert_equal(jtu.with_jax_dtype_defaults(np.diag_indices)(n, ndim), + jnp.diag_indices(n, ndim)) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[(1,1), (2,2), (3,3), (4,4), (5,5)], + ) + def testDiagIndicesFrom(self, dtype, shape): + rng = jtu.rand_default(self.rng()) + np_fun = jtu.with_jax_dtype_defaults(np.diag_indices_from) + jnp_fun = jnp.diag_indices_from + args_maker = lambda : [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=[shape for shape in all_shapes if len(shape) in (1, 2)], + k=list(range(-4, 4)), + ) + def testDiag(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.diag(arg, k) + jnp_fun = lambda arg: jnp.diag(arg, k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + shape=all_shapes, + k=list(range(-4, 4)), + ) + def testDiagFlat(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + # numpy has inconsistencies for scalar values + # https://github.com/numpy/numpy/issues/16477 + # jax differs in that it treats scalars values as length-1 arrays + np_fun = lambda arg: np.diagflat(np.atleast_1d(arg), k) + jnp_fun = lambda arg: jnp.diagflat(arg, k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product( + dtype=default_dtypes, + a1_shape=one_dim_array_shapes, + a2_shape=one_dim_array_shapes, + ) + def testPolyMul(self, a1_shape, a2_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polymul(arg1, arg2) + jnp_fun_np = lambda arg1, arg2: jnp.polymul(arg1, arg2, trim_leading_zeros=True) + jnp_fun_co = lambda arg1, arg2: jnp.polymul(arg1, arg2) + args_maker = lambda: [rng(a1_shape, dtype), rng(a2_shape, dtype)] + tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} + self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False) + + @unittest.skip("jax-metal fail.") + @jtu.sample_product( + dtype=[dtype for dtype in default_dtypes + if dtype not in (np.float16, jnp.bfloat16)], + a_shape=one_dim_array_shapes, + b_shape=one_dim_array_shapes, + ) + def testPolyDiv(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + + @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero.*") + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + @jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*") + def np_fun(arg1, arg2): + q, r = np.polydiv(arg1, arg2) + while r.size < max(arg1.size, arg2.size): # Pad residual to same size + r = np.pad(r, (1, 0), 'constant') + return q, r + + def jnp_fun(arg1, arg2): + q, r = jnp.polydiv(arg1, arg2, trim_leading_zeros=True) + while r.size < max(arg1.size, arg2.size): # Pad residual to same size + r = jnp.pad(r, (1, 0), 'constant') + return q, r + + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + tol = { + dtypes.bfloat16: 2e-1, + np.float16: 2e-1, + np.float32: 5e-2, + np.float64: 5e-7 + } + + jnp_compile = jnp.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol) + + @jtu.sample_product( + [dict(shape=shape, axis1=axis1, axis2=axis2) + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in [a for a in range(-len(shape), len(shape)) + if a % len(shape) != axis1 % len(shape)] + ], + dtype=default_dtypes, + offset=list(range(-4, 4)), + ) + def testDiagonal(self, shape, dtype, offset, axis1, axis2): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2) + jnp_fun = lambda arg: jnp.diagonal(arg, offset, axis1, axis2) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype=default_dtypes, + n=list(range(4)), + ) + def testIdentity(self, n, dtype): + np_fun = lambda: np.identity(n, dtype) + jnp_fun = lambda: jnp.identity(n, dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("jax-metal crash.") + @jtu.sample_product( + shape=nonempty_shapes, + period=[None, 0.59], + left=[None, 0], + right=[None, 1], + # Note: skip 8-bit and 16-bit types due to insufficient precision. + dtype=jtu.dtypes.integer + jtu.dtypes.floating, + target_dtype=jtu.dtypes.inexact, + ) + def testInterp(self, shape, dtype, period, left, right, target_dtype): + rng = jtu.rand_default(self.rng(), scale=10) + kwds = dict(period=period, left=left, right=right) + np_fun = partial(np.interp, **kwds) + jnp_fun = partial(jnp.interp, **kwds) + + args_maker = lambda: [rng(shape, dtype), np.unique(rng((100,), dtype))[:20], + rng((20,), target_dtype)] + + with jtu.strict_promotion_if_dtypes_match([dtype, target_dtype]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + rtol=3e-3, atol=1e-3) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("jax-metal crash.") + @jtu.sample_product([ + dict(x=0.5, left='extrapolate', expected=5), + dict(x=1.5, left='extrapolate', expected=15), + dict(x=3.5, left='extrapolate', expected=30), + dict(x=3.9, right='extrapolate', expected=39), + ]) + def testInterpExtrapoate(self, x, expected, **kwargs): + xp = jnp.array([1.0, 2.0, 3.0]) + fp = jnp.array([10.0, 20.0, 30.0]) + actual = jnp.interp(x, xp, fp, **kwargs) + self.assertAlmostEqual(actual, expected) + + def testInterpErrors(self): + with self.assertRaisesWithLiteralMatch( + ValueError, + 'xp and fp must be one-dimensional arrays of equal size' + ): + jnp.interp(0.0, jnp.arange(2.0), jnp.arange(3.0)) + with self.assertRaisesWithLiteralMatch( + ValueError, + "the only valid string value of `left` is 'extrapolate', but got: 'interpolate'" + ): + jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), left='interpolate') + with self.assertRaisesWithLiteralMatch( + ValueError, + "the only valid string value of `right` is 'extrapolate', but got: 'interpolate'" + ): + jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), right='interpolate') + with self.assertRaisesWithLiteralMatch( + ValueError, + "jnp.interp: complex x values not supported." + ): + jnp.interp(1j, 1j * np.arange(3.0), np.arange(3.0)) + with self.assertRaisesRegex( + ValueError, + "period must be a scalar; got" + ): + jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), period=np.array([1.0])) + + @jtu.sample_product( + period=[None, 0.59], + left=[None, 0], + right=[None, 1], + dtype=jtu.dtypes.floating, + ) + def testInterpGradNan(self, dtype, period, left, right): + kwds = dict(period=period, left=left, right=right) + jnp_fun = partial(jnp.interp, **kwds) + # Probe values of x and xp that are close to zero and close together. + x = dtype(np.exp(np.linspace(-90, -20, 1000))) + g = jax.grad(lambda z: jnp.sum(jnp_fun(z, z, jnp.ones_like(z))))(x) + np.testing.assert_equal(np.all(np.isfinite(g)), True) + + @jtu.sample_product( + [dict(x1_shape=x1_shape, x2_shape=x2_shape) + for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(array_shapes, 2)) + ], + x1_rng_factory=[jtu.rand_some_inf_and_nan, jtu.rand_some_zero], + x2_rng_factory=[partial(jtu.rand_int, low=-1075, high=1024)], + x1_dtype=default_dtypes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory): + x1_rng = x1_rng_factory(self.rng()) + x2_rng = x2_rng_factory(self.rng()) + + @jtu.ignore_warning(category=RuntimeWarning, message="overflow.*") + def np_fun(x1, x2): + out_dtype = dtypes.to_inexact_dtype(x1.dtype) + return np.ldexp(x1.astype(out_dtype), x2) + + jnp_fun = jnp.ldexp + args_maker = lambda: [x1_rng(x1_shape, x1_dtype), + x2_rng(x2_shape, np.int32)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + rng_factory=[ + jtu.rand_some_inf_and_nan, + jtu.rand_some_zero, + partial(jtu.rand_not_small, offset=1e8), + ], + shape=all_shapes, + dtype=default_dtypes, + ) + def testFrexp(self, shape, dtype, rng_factory): + # integer types are converted to float64 in numpy's implementation + if (dtype not in [jnp.bfloat16, np.float16, np.float32] + and not config.enable_x64.value): + self.skipTest("Only run float64 testcase when float64 is enabled.") + rng = rng_factory(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + def np_frexp(x): + mantissa, exponent = np.frexp(x) + # NumPy is inconsistent between Windows and Linux/Mac on what the + # value of exponent is if the input is infinite. Normalize to the Linux + # behavior. + exponent = np.where(np.isinf(mantissa), np.zeros_like(exponent), exponent) + return mantissa, exponent + self._CheckAgainstNumpy(np_frexp, jnp.frexp, args_maker, + check_dtypes=np.issubdtype(dtype, np.inexact)) + self._CompileAndCheck(jnp.frexp, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis1=axis1, axis2=axis2) + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in range(-len(shape), len(shape)) + if (axis1 % len(shape)) != (axis2 % len(shape)) + ], + dtype=default_dtypes, + out_dtype=[None] + number_dtypes, + offset=list(range(-4, 4)), + ) + def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2): + rng = jtu.rand_default(self.rng()) + def np_fun(arg): + if out_dtype == jnp.bfloat16: + return np.trace(arg, offset, axis1, axis2, np.float32).astype(jnp.bfloat16) + else: + return np.trace(arg, offset, axis1, axis2, out_dtype) + jnp_fun = lambda arg: jnp.trace(arg, offset, axis1, axis2, out_dtype) + args_maker = lambda: [rng(shape, dtype)] + # TODO: Fails with uint8/uint16 output dtypes (integer overflow?) + if out_dtype not in (np.uint8, np.uint16, np.uint32): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + #unittest.skip("jax-metal fail with empty vshape.") + @jtu.sample_product( + ashape=[(15,), (16,), (17,)], + vshape= [(5,), (5, 5)],#[(), (5,), (5, 5)], + side=['left', 'right'], + dtype= number_dtypes, + method=['sort', 'scan', 'scan_unrolled', 'compare_all'], + ) + def testSearchsorted(self, ashape, vshape, side, dtype, method): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)] + def np_fun(a, v): + return np.searchsorted(a, v, side=side).astype('int32') + jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skipIf( + platform.system() == "Windows", + "Under Windows, NumPy throws if 2**32 is converted to an int32" + ) + def testSearchsortedDtype(self): + # Test that for large arrays, int64 indices are used. We test this + # via abstract evaluation to avoid allocating a large array in tests. + a_int32 = core.ShapedArray((np.iinfo(np.int32).max,), np.float32) + a_int64 = core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32) + v = core.ShapedArray((), np.float32) + + out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v) + self.assertEqual(out_int32.dtype, np.int32) + + if config.enable_x64.value: + out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) + self.assertEqual(out_int64.dtype, np.int64) + elif jtu.numpy_version() < (2, 0, 0): + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): + with jtu.ignore_warning(category=DeprecationWarning, + message="NumPy will stop allowing conversion.*"): + out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) + else: + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): + with self.assertRaisesRegex(OverflowError, "Python integer 2147483648 out of bounds.*"): + out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) + + @unittest.skip("Jax-metal fail.") + @jtu.sample_product( + dtype=inexact_dtypes, + side=['left', 'right'], + method=['sort', 'scan', 'compare_all'], + ) + def testSearchsortedNans(self, dtype, side, method): + if np.issubdtype(dtype, np.complexfloating): + raise SkipTest("Known failure for complex inputs; see #9107") + x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype) + # The sign bit should not matter for 0.0 or NaN, so argsorting the above should be + # equivalent to argsorting the following: + x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5]) + + if jnp.issubdtype(dtype, jnp.complexfloating): + x = np.array([complex(r, c) for r, c in itertools.product(x, repeat=2)]) + x_equiv = np.array([complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)]) + + fun = partial(jnp.searchsorted, side=side, method=method) + self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv)) + self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv)) + + @jtu.sample_product( + xshape=[(20,), (5, 4)], + binshape=[(1,), (5,)], + right=[True, False], + reverse=[True, False], + dtype=default_dtypes, + ) + def testDigitize(self, xshape, binshape, right, reverse, dtype): + order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:] + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]] + np_fun = lambda x, bins: np.digitize(x, bins, right=right).astype('int32') + jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtypes=[ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ], + shape=[(), (2,), (3, 4), (1, 5)], + array_input=[True, False], + ) + def testColumnStack(self, shape, dtypes, array_input): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + np_fun = jtu.promote_like_jnp(np.column_stack) + jnp_fun = jnp.column_stack + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(), (2,), (3, 4), (1, 100)] + for axis in range(-len(shape), len(shape) + 1) + ], + dtypes=[ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ], + array_input=[True, False], + out_dtype=[np.float32, np.int32], + ) + def testStack(self, shape, axis, dtypes, array_input, out_dtype): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + + if numpy_version < (1, 24): + np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype)) + else: + np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) + + jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + op=["hstack", "vstack", "dstack"], + dtypes=[ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ], + shape=[(), (2,), (3, 4), (1, 100), (2, 3, 4)], + array_input=[True, False], + out_dtype=[np.float32, np.int32], + ) + def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + + if numpy_version < (1, 24) or op == "dstack": + np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) + else: + np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, + casting='unsafe') + + jnp_fun = partial(getattr(jnp, op), dtype=out_dtype) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(name=name, **kwds) + for name in ['blackman', 'bartlett', 'hamming', 'hanning', 'kaiser'] + for kwds in ([dict(beta=1), dict(beta=0.5)] if name == 'kaiser' else [{}]) + ], + size = [0, 1, 5, 10], + ) + def testWindowFunction(self, name, size, **kwds): + jnp_fun = partial(getattr(jnp, name), size, **kwds) + np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds)) + args_maker = lambda: [] + tol = ( + 5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None + ) + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, fill_value_shape=fill_value_shape) + for shape in array_shapes + [3, np.array(7, dtype=np.int32)] + for fill_value_shape in _compatible_shapes(shape)], + fill_value_dtype=default_dtypes, + out_dtype=[None] + default_dtypes, + ) + def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype) + jnp_fun = lambda fill_value: jnp.full(shape, fill_value, dtype=out_dtype) + args_maker = lambda: [rng(fill_value_shape, fill_value_dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, dtype=dtype, axis=axis) + for shape, dtype in _shape_and_dtypes(nonempty_nonscalar_array_shapes, default_dtypes) + for axis in list(range(-len(shape), max(1, len(shape)))) + ], + prepend=[None, 1, 0], + append=[None, 1, 0], + n=[0, 1, 2], + ) + def testDiff(self, shape, dtype, n, axis, prepend, append): + prepend = np.zeros(shape, dtype=dtype) if prepend == 0 else prepend + append = np.zeros(shape, dtype=dtype) if append == 0 else append + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): + if prepend is None: + prepend = np._NoValue + elif not np.isscalar(prepend) and prepend.dtype == jnp.bfloat16: + prepend = prepend.astype(np.float32) + + if append is None: + append = np._NoValue + elif not np.isscalar(append) and append.dtype == jnp.bfloat16: + append = append.astype(np.float32) + + if x.dtype == jnp.bfloat16: + return np.diff(x.astype(np.float32), n=n, axis=axis, prepend=prepend, append=append).astype(jnp.bfloat16) + else: + return np.diff(x, n=n, axis=axis, prepend=prepend, append=append) + + jnp_fun = lambda x: jnp.diff(x, n=n, axis=axis, prepend=prepend, append=append) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + def testDiffPrepoendScalar(self): + # Regression test for https://github.com/google/jax/issues/19362 + x = jnp.arange(10) + result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) + + x = np.array(x) + result_numpy = np.diff(x, prepend=x[0], append=x[-1]) + + self.assertArraysEqual(result_jax, result_numpy) + + @jtu.sample_product( + op=["zeros", "ones"], + shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32), + np.array(4, dtype=np.int32)], + dtype=all_dtypes, + ) + def testZerosOnes(self, op, shape, dtype): + np_op = getattr(np, op) + jnp_op = getattr(jnp, op) + args_maker = lambda: [] + np_op = partial(np_op, shape, dtype) + jnp_op = partial(jnp_op, shape, dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + def testOnesWithInvalidShape(self): + with self.assertRaises(TypeError): + jnp.ones((-1, 1)) + + def test_full_like_commited(self): + x = jnp.array((1, 2, 3), dtype=np.int32) + self.assertFalse(x._committed) + self.assertFalse(lax.full_like(x, 1.1)._committed) + x = jax.device_put(x, jax.devices()[-1]) + self.assertTrue(x._committed) + y = lax.full_like(x, 1.1) + self.assertTrue(y._committed) + self.assertEqual(x.sharding, y.sharding) + + def test_zeros_like_with_explicit_device_and_jitted(self): + x = jnp.array((1, 2, 3), dtype=np.int32) + x = jax.device_put(x, jax.devices()[0]) + zeros_like_with_device = partial(jnp.zeros_like, device=jax.devices()[0]) + y = jax.jit(zeros_like_with_device)(x) + self.assertEqual(x.shape, y.shape) + self.assertEqual(y.sharding, SingleDeviceSharding(jax.devices()[0])) + + @jtu.sample_product( + [dict(shape=shape, out_shape=out_shape, fill_value_shape=fill_value_shape) + for shape in array_shapes + for out_shape in [None] + array_shapes + for fill_value_shape in _compatible_shapes(shape if out_shape is None else out_shape) + ], + in_dtype=default_dtypes, + fill_value_dtype=default_dtypes, + out_dtype=default_dtypes, + ) + def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x, fill_value: np.full_like( + x, fill_value, dtype=out_dtype, shape=out_shape) + jnp_fun = lambda x, fill_value: jnp.full_like( + x, fill_value, dtype=out_dtype, shape=out_shape) + args_maker = lambda: [rng(shape, in_dtype), rng(fill_value_shape, fill_value_dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=array_shapes, + out_shape=[None] + array_shapes, + in_dtype=default_dtypes, + func=["ones_like", "zeros_like"], + out_dtype=default_dtypes, + ) + def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape) + jnp_fun = lambda x: getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape) + args_maker = lambda: [rng(shape, in_dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], + shape=array_shapes, + dtype=default_dtypes, + ) + def testArrayCreationWithDevice(self, func, shape, dtype): + device = jax.devices()[-1] + kwds = {'fill_value': 1} if func is jnp.full else {} + out = func(**kwds, shape=shape, dtype=dtype, device=device) + self.assertEqual(out.devices(), {device}) + + @jtu.sample_product( + func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], + shape=array_shapes, + dtype=default_dtypes, + ) + def testArrayCreationWithSharding(self, func, shape, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + kwds = {'fill_value': 1} if func is jnp.full else {} + out = func(**kwds, shape=shape, dtype=dtype, device=sharding) + self.assertEqual(out.sharding, sharding) + + @jtu.sample_product( + func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], + shape=array_shapes, + dtype=default_dtypes, + ) + def testFullLikeWithDevice(self, func, shape, dtype): + device = jax.devices()[-1] + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + kwds = {'fill_value': 1} if func is jnp.full_like else {} + + with self.subTest('device from keyword'): + out = func(x, **kwds, device=device) + self.assertEqual(out.devices(), {device}) + + with self.subTest('device from input array'): + out2 = func(out, **kwds) + self.assertEqual(out2.devices(), out.devices()) + + @jtu.sample_product( + func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], + shape=array_shapes, + dtype=default_dtypes, + ) + def testFullLikeWithSharding(self, func, shape, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + kwds = {'fill_value': 1} if func is jnp.full_like else {} + + with self.subTest('device from keyword'): + out = func(x, **kwds, device=sharding) + self.assertEqual(out.sharding, sharding) + + with self.subTest('device from input array'): + out2 = func(out, **kwds) + self.assertEqual(out2.devices(), out.devices()) + + def testDuckTypedLike(self): + x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32")) + self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype)) + self.assertArraysEqual(jnp.ones_like(x), jnp.ones(x.shape, x.dtype)) + self.assertArraysEqual(jnp.empty_like(x), jnp.empty(x.shape, x.dtype)) + self.assertArraysEqual(jnp.full_like(x, 2), jnp.full(x.shape, 2, x.dtype)) + + @jtu.sample_product( + [dict(func=func, args=args) + for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())] + ], + shape=array_shapes, + #in_dtype=[np.int32, np.float32, np.complex64], + in_dtype=[np.int32, np.float32], + weak_type=[True, False], + out_shape=[None, (), (10,)], + out_dtype=[None, float], + ) + def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype): + rng = jtu.rand_default(self.rng()) + x = lax_internal._convert_element_type(rng(shape, in_dtype), + weak_type=weak_type) + fun = lambda x: getattr(jnp, func)(x, *args, dtype=out_dtype, shape=out_shape) + expected_weak_type = weak_type and (out_dtype is None) + self.assertEqual(dtypes.is_weakly_typed(fun(x)), expected_weak_type) + self.assertEqual(dtypes.is_weakly_typed(jax.jit(fun)(x)), expected_weak_type) + + @jtu.sample_product( + funcname=["array", "asarray"], + dtype=[int, float, None], + val=[0, 1], + input_type=[int, float, np.int32, np.float32], + ) + def testArrayWeakType(self, funcname, input_type, val, dtype): + func = lambda x: getattr(jnp, funcname)(x, dtype=dtype) + fjit = jax.jit(func) + val = input_type(val) + expected_weak_type = dtype is None and input_type in set(dtypes._weak_types) + self.assertEqual(dtypes.is_weakly_typed(func(val)), expected_weak_type) + self.assertEqual(dtypes.is_weakly_typed(fjit(val)), expected_weak_type) + + @jtu.sample_product( + shape=nonempty_nonscalar_array_shapes, + #dtype=[int, float, complex], + dtype=[int, float], + weak_type=[True, False], + slc=[slice(None), slice(0), slice(3), 0, ...], + ) + def testSliceWeakTypes(self, shape, dtype, weak_type, slc): + rng = jtu.rand_default(self.rng()) + x = lax_internal._convert_element_type(rng(shape, dtype), + weak_type=weak_type) + op = lambda x: x[slc] + self.assertEqual(op(x).aval.weak_type, weak_type) + self.assertEqual(jax.jit(op)(x).aval.weak_type, weak_type) + + @jtu.sample_product( + [dict(shape=shape, axis=axis, num_sections=num_sections) + for shape, axis, num_sections in [ + ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), + ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] + ], + dtype=default_dtypes, + ) + def testSplitStaticInt(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.split(x, num_sections, axis=axis) + jnp_fun = lambda x: jnp.split(x, num_sections, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis, num_sections=num_sections) + # All testcases split the specified axis unequally + for shape, axis, num_sections in [ + ((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3), + ((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)] + ], + dtype=default_dtypes, + ) + def testArraySplitStaticInt(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.array_split(x, num_sections, axis=axis) + jnp_fun = lambda x: jnp.array_split(x, num_sections, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testSplitTypeError(self): + # If we pass an ndarray for indices_or_sections -> no error + self.assertEqual(3, len(jnp.split(jnp.zeros(3), jnp.array([1, 2])))) + + CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected." + with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): + # An abstract tracer for idx + jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), idx))(2.) + with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): + # A list including an abstract tracer + jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), [2, idx]))(2.) + + # A concrete tracer -> no error + jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), idx), + (2.,), (1.,)) + # A tuple including a concrete tracer -> no error + jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), (1, idx.astype(np.int32))), + (2.,), (1.,)) + + @jtu.sample_product( + shape=[(5,), (5, 5)], + dtype=number_dtypes, + bins=[10, np.arange(-5, 6), np.array([-5, 0, 3])], + range=[None, (0, 0), (0, 10)], + weights=[True, False], + ) + def testHistogramBinEdges(self, shape, dtype, bins, range, weights): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = lambda a, w, r: np.histogram_bin_edges(a, bins=bins, range=r, + weights=_weights(w)) + jnp_fun = lambda a, w, r: jnp.histogram_bin_edges(a, bins=bins, range=r, + weights=_weights(w)) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), range] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2} + # linspace() compares poorly to numpy when using bfloat16 + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, + atol=tol, rtol=tol) + + @jtu.sample_product( + shape=[(5,), (4, 5)], + dtype=default_dtypes, + # We only test explicit integer-valued bin edges because in other cases + # rounding errors lead to flaky tests. + bins=[np.arange(-5, 6), np.array([-5, 0, 3])], + density=[True, False], + weights=[True, False], + ) + def testHistogram(self, shape, dtype, bins, density, weights): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + def np_fun(a, w): + # Numpy can't handle bfloat16 + a = a.astype('float32') if a.dtype == jnp.bfloat16 else a + w = w.astype('float32') if w.dtype == jnp.bfloat16 else w + return np.histogram(a, bins=bins, density=density, weights=_weights(w)) + jnp_fun = lambda a, w: jnp.histogram(a, bins=bins, density=density, + weights=_weights(w)) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(5,), (12,)], + dtype=int_dtypes, + bins=[2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]], + weights=[False, True], + density=[False, True], + range=[None, [(-1, 1), None], [(-1, 1), (-2, 2)]], + ) + def testHistogram2d(self, shape, dtype, bins, weights, density, range): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( + lambda a, b, w: np.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range)) + jnp_fun = lambda a, b, w: jnp.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} + # np.searchsorted errors on bfloat16 with + # "TypeError: invalid type promotion with custom data type" + with np.errstate(divide='ignore', invalid='ignore'): + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(5, 3), (10, 3)], + dtype=int_dtypes, + bins=[(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]], + weights=[False, True], + density=[False, True], + range=[None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]], + ) + def testHistogramdd(self, shape, dtype, bins, weights, density, range): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( + lambda a, w: np.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range)) + jnp_fun = lambda a, w: jnp.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range) + args_maker = lambda: [rng(shape, dtype), rng((shape[0],), dtype)] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} + # np.searchsorted errors on bfloat16 with + # "TypeError: invalid type promotion with custom data type" + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis, num_sections=num_sections) + for shape, axis, num_sections in [ + ((12, 4), 0, 4), ((12,), 1, 2), + ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]], + dtype=default_dtypes, + ) + def testHVDSplit(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + def fn(module, axis): + if axis == 0: + return module.vsplit + elif axis == 1: + return module.hsplit + else: + assert axis == 2 + return module.dsplit + + np_fun = lambda x: fn(np, axis)(x, num_sections) + jnp_fun = lambda x: fn(jnp, axis)(x, num_sections) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, out_shape=out_shape) + for arg_shape, out_shape in [ + (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), + ((), (1, 1, 1)), + ((7, 0), (0, 42, 101)), + ((3, 4), 12), + ((3, 4), (12,)), + ((3, 4), -1), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)) + ] + ], + dtype=default_dtypes, + order=["C", "F"], + ) + def testReshape(self, arg_shape, out_shape, dtype, order): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.reshape(x, out_shape, order=order) + jnp_fun = lambda x: jnp.reshape(x, out_shape, order=order) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, out_shape=out_shape) + for arg_shape, out_shape in [ + ((7, 0), (0, 42, 101)), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)) + ] + ], + dtype=default_dtypes, + ) + def testReshapeMethod(self, arg_shape, out_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.reshape(x, out_shape) + jnp_fun = lambda x: x.reshape(*out_shape) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, out_shape=out_shape) + for arg_shape, out_shape in itertools.product(all_shapes, array_shapes)], + dtype=default_dtypes, + ) + def testResize(self, arg_shape, out_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.resize(x, out_shape) + jnp_fun = lambda x: jnp.resize(x, out_shape) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, dim=dim) + for arg_shape in [(), (3,), (3, 4)] + for dim in (list(range(-len(arg_shape)+1, len(arg_shape))) + + [np.array(0), np.array(-1), (0,), [np.array(0)], + (len(arg_shape), len(arg_shape) + 1)]) + ], + dtype=default_dtypes, + ) + def testExpandDimsStaticDim(self, arg_shape, dtype, dim): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.expand_dims(x, dim) + jnp_fun = lambda x: jnp.expand_dims(x, dim) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + def testExpandDimsRepeatedAxisError(self): + x = jnp.ones((2, 3)) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: jnp.expand_dims(x, [1, 1])) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: jnp.expand_dims(x, [3, -1])) + + # ensure this is numpy's behavior too, so that we remain consistent + x = np.ones((2, 3)) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: np.expand_dims(x, [1, 1])) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: np.expand_dims(x, [3, -1])) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, ax1=ax1, ax2=ax2) + for arg_shape, ax1, ax2 in [ + ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), + ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] + ], + dtype=default_dtypes, + ) + def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.swapaxes(x, ax1, ax2) + jnp_fun = lambda x: jnp.swapaxes(x, ax1, ax2) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + [dict(arg_shape=arg_shape, ax=ax) + for arg_shape, ax in [ + ((3, 1), None), + ((3, 1), 1), + ((3, 1), -1), + ((3, 1), np.array(1)), + ((1, 3, 1), (0, 2)), + ((1, 3, 1), (0,)), + ((1, 4, 1), (np.array(0),))] + ], + dtype=default_dtypes, + ) + def testSqueeze(self, arg_shape, dtype, ax): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.squeeze(x, ax) + jnp_fun = lambda x: jnp.squeeze(x, ax) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testArrayFromMasked(self): + args_maker = lambda: [np.ma.array([1, 2], mask=[True, False])] + # Like np.array, jnp.array strips the mask from masked array inputs. + self._CheckAgainstNumpy(np.array, jnp.array, args_maker) + # Under JIT, masked arrays are flagged as invalid. + with self.assertRaisesRegex(ValueError, "numpy masked arrays are not supported"): + jax.jit(jnp.asarray)(*args_maker()) + + @jtu.sample_product( + [dict(arg=arg, dtype=dtype, ndmin=ndmin) + for arg, dtypes in [ + ([True, False, True], all_dtypes), + (3., all_dtypes), + ([1, 2, 3], all_dtypes), + (np.array([1, 2, 3], dtype=np.int64), all_dtypes), + ([1., 2., 3.], all_dtypes), + ([[1, 2], [3, 4], [5, 6]], all_dtypes), + ([[1, 2.], [3, 4], [5, 6]], all_dtypes), + ([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes), + ([[3, np.array(2, dtype=jnp.float_), 1], + np.arange(3., dtype=jnp.float_)], all_dtypes), + ] + for dtype in [None] + dtypes + for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2] + ], + ) + def testArray(self, arg, ndmin, dtype): + args_maker = lambda: [arg] + canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype) + if ndmin is not None: + np_fun = partial(np.array, ndmin=ndmin, dtype=canonical_dtype) + jnp_fun = partial(jnp.array, ndmin=ndmin, dtype=dtype) + else: + np_fun = partial(np.array, dtype=canonical_dtype) + jnp_fun = partial(jnp.array, dtype=dtype) + + # We are testing correct canonicalization behavior here, so we turn off the + # permissive canonicalization logic in the test harness. + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + canonicalize_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(copy=[None, True, False]) + def testAsarrayCopy(self, copy): + x_jax = jnp.arange(4) + x_np = np.arange(4) + x_list = [0, 1, 2, 3] + x_buf = make_python_array('l', x_list) + + func = partial(jnp.asarray, copy=copy) + self.assertArraysEqual(x_jax, func(x_jax)) + self.assertArraysEqual(x_jax, func(x_list), check_dtypes=False) + + if copy is False and jax.default_backend() != 'cpu': + # copy=False is strict: it must raise if the input supports the buffer protocol + # but a copy is still required. + self.assertRaises(ValueError, func, x_np) + self.assertRaises(ValueError, func, x_buf) + else: + self.assertArraysEqual(x_jax, func(x_np), check_dtypes=False) + self.assertArraysEqual(x_jax, func(x_buf), check_dtypes=False) + + @unittest.skip("Jax-metal don't support all dtypes.") + @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") + def testArrayDtypeInference(self): + def _check(obj, out_dtype, weak_type): + dtype_reference = np.array(obj, dtype=out_dtype) + + out = jnp.array(obj) + self.assertDtypesMatch(out, dtype_reference) + self.assertEqual(dtypes.is_weakly_typed(out), weak_type) + + out_jit = jax.jit(jnp.array)(obj) + self.assertDtypesMatch(out_jit, dtype_reference) + self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type) + + # Python scalars become 64-bit weak types. + _check(1, np.int64, True) + _check(1.0, np.float64, True) + _check(1.0j, np.complex128, True) + + # Lists become strongly-typed defaults. + _check([1], jnp.int64, False) + _check([1.0], jnp.float64, False) + _check([1.0j], jnp.complex128, False) + + # Lists of weakly-typed objects become strongly-typed defaults. + _check([jnp.array(1)], jnp.int64, False) + _check([jnp.array(1.0)], jnp.float64, False) + _check([jnp.array(1.0j)], jnp.complex128, False) + + # Lists of strongly-typed objects maintain their strong type. + _check([jnp.int64(1)], np.int64, False) + _check([jnp.float64(1)], np.float64, False) + _check([jnp.complex128(1)], np.complex128, False) + + # Mixed inputs use JAX-style promotion. + # (regression test for https://github.com/google/jax/issues/8945) + _check([0, np.int16(1)], np.int16, False) + _check([0.0, np.float16(1)], np.float16, False) + + @jtu.sample_product( + dtype=all_dtypes, + func=["array", "copy", "copy.copy", "copy.deepcopy"], + ) + def testArrayCopy(self, dtype, func): + x = jnp.ones(10, dtype=dtype) + if func == "copy.deepcopy": + copy_func = copy.deepcopy + elif func == "copy.copy": + copy_func = copy.copy + else: + copy_func = getattr(jnp, func) + + x_view = jnp.asarray(x) + x_view_jit = jax.jit(jnp.asarray)(x) + x_copy = copy_func(x) + x_copy_jit = jax.jit(copy_func)(x) + + _ptr = lambda x: x.unsafe_buffer_pointer() + + self.assertEqual(_ptr(x), _ptr(x_view)) + self.assertNotEqual(_ptr(x), _ptr(x_view_jit)) + self.assertNotEqual(_ptr(x), _ptr(x_copy)) + self.assertNotEqual(_ptr(x), _ptr(x_copy_jit)) + + x.delete() + + self.assertTrue(x_view.is_deleted()) + self.assertFalse(x_view_jit.is_deleted()) + + self.assertFalse(x_copy.is_deleted()) + self.assertFalse(x_copy_jit.is_deleted()) + + def testArrayCopyAutodiff(self): + f = lambda x: jnp.array(x, copy=True) + + x = jnp.ones(10) + xdot = jnp.ones(10) + y, ydot = jax.jvp(f, (x,), (xdot,)) + self.assertIsNot(x, y) + self.assertIsNot(xdot, ydot) + + ybar = jnp.ones(10) + y, f_vjp = jax.vjp(f, x) + xbar, = f_vjp(ybar) + self.assertIsNot(x, y) + self.assertIsNot(xbar, ybar) + + def testArrayCopyVmap(self): + f = lambda x: jnp.array(x, copy=True) + x = jnp.ones(10) + y = jax.vmap(f)(x) + self.assertIsNot(x, y) + + def testArrayUnsupportedDtypeError(self): + with self.assertRaisesRegex(TypeError, + "JAX only supports number and bool dtypes.*"): + jnp.array(3, [('a',' 0.: + return x * 2 + else: + return x + 2 + + self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.)) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(3,), (2, 3)] + for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))] # Test negative axes and tuples + ], + dtype=default_dtypes, + ) + def testFlip(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_op = lambda x: jnp.flip(x, axis) + np_op = lambda x: np.flip(x, axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=[(3,), (2, 3), (3, 2, 4)], + dtype=default_dtypes, + ) + def testFlipud(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_op = lambda x: jnp.flipud(x) + np_op = lambda x: np.flipud(x) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=[(3, 2), (2, 3), (3, 2, 4)], + dtype=default_dtypes, + ) + def testFliplr(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_op = lambda x: jnp.fliplr(x) + np_op = lambda x: np.fliplr(x) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axes=axes) + for shape, axes in [ + [(2, 3), (0, 1)], + [(2, 3), (1, 0)], + [(4, 3, 2), (0, 2)], + [(4, 3, 2), (2, 1)], + ] + ], + k=range(-3, 4), + dtype=default_dtypes, + ) + def testRot90(self, shape, dtype, k, axes): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + jnp_op = lambda x: jnp.rot90(x, k, axes) + np_op = lambda x: np.rot90(x, k, axes) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + # TODO(mattjj): test infix operator overrides + + def testRavel(self): + rng = self.rng() + args_maker = lambda: [rng.randn(3, 4).astype("float32")] + self._CompileAndCheck(lambda x: x.ravel(), args_maker) + + @jtu.sample_product( + shape=nonempty_nonscalar_array_shapes, + order=['C', 'F'], + mode=['wrap', 'clip', 'raise'], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testRavelMultiIndex(self, shape, order, mode): + # generate indices in each dimension with a few out of bounds. + rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1) + for dim in shape] + # generate multi_indices of different dimensions that broadcast. + args_maker = lambda: [tuple(rng(ndim * (3,), jnp.int_) + for ndim, rng in enumerate(rngs))] + def np_fun(x): + try: + return np.ravel_multi_index(x, shape, order=order, mode=mode) + except ValueError as err: + if str(err).startswith('invalid entry'): + # sentinel indicating expected error. + return -999 + else: + raise + def jnp_fun(x): + try: + return jnp.ravel_multi_index(x, shape, order=order, mode=mode) + except ValueError as err: + if str(err).startswith('invalid entry'): + # sentinel indicating expected error. + return -999 + else: + raise + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + if mode == 'raise': + msg = ("The error occurred because ravel_multi_index was jit-compiled " + "with mode='raise'. Use mode='wrap' or mode='clip' instead.") + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + jax.jit(jnp_fun)(*args_maker()) + else: + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + ashape=((), (4,), (3, 4)), + cshapes=[ + [(), (4,)], + [(3, 4), (4,), (3, 1)] + ], + adtype=int_dtypes, + cdtype=default_dtypes, + mode=['wrap', 'clip', 'raise'], + ) + def testChoose(self, ashape, adtype, cshapes, cdtype, mode): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]] + def np_fun(a, c): + try: + return np.choose(a, c, mode=mode) + except ValueError as err: + if mode == 'raise' and str(err).startswith('invalid entry'): + return -999 # sentinel indicating expected error. + else: + raise + def jnp_fun(a, c): + try: + return jnp.choose(a, c, mode=mode) + except ValueError as err: + if mode == 'raise' and str(err).startswith('invalid entry'): + return -999 # sentinel indicating expected error. + else: + raise + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + if mode == 'raise': + msg = ("The error occurred because jnp.choose was jit-compiled" + " with mode='raise'. Use mode='wrap' or mode='clip' instead.") + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + jax.jit(jnp_fun)(*args_maker()) + else: + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=nonempty_nonscalar_array_shapes, + dtype=int_dtypes, + idx_shape=all_shapes, + ) + def testUnravelIndex(self, shape, idx_shape, dtype): + size = math.prod(shape) + rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) + + def np_fun(index, shape): + # JAX's version outputs the same dtype as the input in the typical case + # where shape is weakly-typed. + out_dtype = index.dtype + # Adjust out-of-bounds behavior to match jax's documented behavior. + index = np.clip(index, -size, size - 1) + index = np.where(index < 0, index + size, index) + return [i.astype(out_dtype) for i in np.unravel_index(index, shape)] + + jnp_fun = jnp.unravel_index + args_maker = lambda: [rng(idx_shape, dtype), shape] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + from_dtype=['int32', 'float32'], + to_dtype=['int32', 'float32', None], + use_method=[True, False], + ) + def testAstype(self, from_dtype, to_dtype, use_method): + rng = self.rng() + args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)] + if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0 + np_op = lambda x: np.astype(x, to_dtype) + else: + np_op = lambda x: np.asarray(x).astype(to_dtype) + if use_method: + jnp_op = lambda x: jnp.asarray(x).astype(to_dtype) + else: + jnp_op = lambda x: jnp.astype(x, to_dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @unittest.skip("Jax-metal don't support all dtypes") + def testAstypeInt4(self): + # Test converting from int4 to int8 + x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4) + args_maker = lambda: [x] + np_op = lambda x: np.asarray(x).astype(jnp.int8) + jnp_op = lambda x: jnp.asarray(x).astype(jnp.int8) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + # Test converting from int8 to int4 + x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int8) + args_maker = lambda: [x] + np_op = lambda x: np.asarray(x).astype(jnp.int4) + jnp_op = lambda x: jnp.asarray(x).astype(jnp.int4) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=array_shapes, + dtype=all_dtypes, + ) + def testNbytes(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + np_op = lambda x: np.asarray(x).nbytes + jnp_op = lambda x: jnp.asarray(x).nbytes + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=array_shapes, + dtype=all_dtypes, + ) + def testItemsize(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + np_op = lambda x: np.asarray(x).itemsize + jnp_op = lambda x: jnp.asarray(x).itemsize + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + shape=nonempty_array_shapes, + dtype=all_dtypes, + num_args=[0, 1, "all"], + use_tuple=[True, False] + ) + def testItem(self, shape, dtype, num_args, use_tuple): + rng = jtu.rand_default(self.rng()) + size = math.prod(shape) + + if num_args == 0: + args = () + elif num_args == 1: + args = (self.rng().randint(0, size),) + else: + args = tuple(self.rng().randint(0, s) for s in shape) + args = (args,) if use_tuple else args + + np_op = lambda x: np.asarray(x).item(*args) + jnp_op = lambda x: jnp.asarray(x).item(*args) + args_maker = lambda: [rng(shape, dtype)] + + if size != 1 and num_args == 0: + with self.assertRaises(ValueError): + jnp_op(*args_maker()) + else: + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + + @jtu.sample_product( + # Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs. + shape=[(0,), (32,), (2, 16)], + a_dtype=all_dtypes, + dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes, + ) + def testView(self, shape, a_dtype, dtype): + if jtu.test_device_matches(["tpu"]): + if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: + self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") + # It is possible to fill bool arrays with arbitrary bits (not just 0/1 + # bytes), but the behavior is implementation-defined. We therefore only test + # the well-defined case. + rng = (jtu.rand_bool if a_dtype == np.bool_ else jtu.rand_fullrange)( + self.rng() + ) + args_maker = lambda: [rng(shape, a_dtype)] + np_op = lambda x: np.asarray(x).view(dtype) + jnp_op = lambda x: jnp.asarray(x).view(dtype) + # Above may produce signaling nans; ignore warnings from invalid values. + with np.errstate(invalid='ignore'): + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product([ + {'a_dtype': a_dtype, 'dtype': dtype} + for a_dtype in all_dtypes + for dtype in all_dtypes + if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize + ]) + def testViewScalar(self, a_dtype, dtype): + if jtu.test_device_matches(["tpu"]): + if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: + self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") + rng = jtu.rand_fullrange(self.rng()) + args_maker = lambda: [jnp.array(rng((), a_dtype))] + np_op = lambda x: np.asarray(x).view(dtype) + jnp_op = lambda x: jnp.asarray(x).view(dtype) + # Above may produce signaling nans; ignore warnings from invalid values. + with np.errstate(invalid='ignore'): + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + def testPathologicalFloats(self): + args_maker = lambda: [np.array([ + 0b_0111_1111_1000_0000_0000_0000_0000_0000, # inf + 0b_1111_1111_1000_0000_0000_0000_0000_0000, # -inf + 0b_0111_1111_1100_0000_0000_0000_0000_0000, # qnan + 0b_1111_1111_1100_0000_0000_0000_0000_0000, # -qnan + 0b_0111_1111_1000_0000_0000_0000_0000_0001, # snan + 0b_1111_1111_1000_0000_0000_0000_0000_0001, # -snan + 0b_0111_1111_1000_0000_0000_1100_0000_0000, # nonstandard nan + 0b_1111_1111_1000_0000_0000_1100_0000_0000, # -nonstandard nan + 0b_0000_0000_0000_0000_0000_0000_0000_0000, # zero + 0b_1000_0000_0000_0000_0000_0000_0000_0000, # -zero + ], dtype='uint32')] + + np_op = lambda x: np.asarray(x).view('float32').view('uint32') + jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32') + + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + # TODO(mattjj): test other ndarray-like method overrides + + def testNpMean(self): + # from https://github.com/google/jax/issues/125 + x = jnp.eye(3, dtype=float) + 0. + ans = np.mean(x) + self.assertAllClose(ans, np.array(1./3), check_dtypes=False) + + def testArangeOnFloats(self): + np_arange = jtu.with_jax_dtype_defaults(np.arange) + # from https://github.com/google/jax/issues/145 + self.assertAllClose(np_arange(0.0, 1.0, 0.1), + jnp.arange(0.0, 1.0, 0.1)) + # from https://github.com/google/jax/issues/3450 + self.assertAllClose(np_arange(2.5), + jnp.arange(2.5)) + self.assertAllClose(np_arange(0., 2.5), + jnp.arange(0., 2.5)) + + def testArangeTypes(self): + # Test that arange() output type is equal to the default types. + int_ = dtypes.canonicalize_dtype(jnp.int_) + float_ = dtypes.canonicalize_dtype(jnp.float_) + + self.assertEqual(jnp.arange(10).dtype, int_) + self.assertEqual(jnp.arange(10.).dtype, float_) + self.assertEqual(jnp.arange(10, dtype='uint16').dtype, np.uint16) + #self.assertEqual(jnp.arange(10, dtype='bfloat16').dtype, jnp.bfloat16) + + self.assertEqual(jnp.arange(0, 10, 1).dtype, int_) + with jax.numpy_dtype_promotion('standard'): + self.assertEqual(jnp.arange(0, 10, 1.).dtype, float_) + self.assertEqual(jnp.arange(0., 10, 1).dtype, float_) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonzerodim_shapes + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) + ], + stable=[True, False], + dtype=all_dtypes, + ) + def testSort(self, dtype, shape, axis, stable): + rng = jtu.rand_some_equal(self.rng()) if stable else jtu.rand_some_inf_and_nan(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + kwds = {} if axis is NO_VALUE else {'axis': axis} + + def np_fun(arr): + # Note: numpy sort fails on NaN and Inf values with bfloat16 + dtype = arr.dtype + if arr.dtype == jnp.bfloat16: + arr = arr.astype('float32') + # TODO(jakevdp): switch to stable=stable when supported by numpy. + result = np.sort(arr, kind='stable' if stable else None, **kwds) + with jtu.ignore_warning(category=RuntimeWarning, message='invalid value'): + return result.astype(dtype) + jnp_fun = partial(jnp.sort, stable=stable, **kwds) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testSortStableDescending(self): + # TODO(jakevdp): test directly against np.sort when descending is supported. + x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf]) + x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan]) + argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5]) + argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6]) + + self.assertArraysEqual(jnp.sort(x), x_sorted) + self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0])) + self.assertArraysEqual(jnp.argsort(x), argsorted_stable) + self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable) + + @unittest.skip("Jax-metal don't support complex.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in one_dim_array_shapes + for axis in [None] + ], + dtype=all_dtypes, + ) + def testSortComplex(self, dtype, shape, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker, + check_dtypes=False) + self._CompileAndCheck(jnp.sort_complex, args_maker) + + @unittest.skip("Jax-metal fail to convert sort op.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in (-1, *range(len(shape) - 1)) + ], + dtype=all_dtypes, + input_type=[np.array, tuple], + ) + def testLexsort(self, dtype, shape, input_type, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [input_type(rng(shape, dtype))] + jnp_op = lambda x: jnp.lexsort(x, axis=axis) + np_op = jtu.with_jax_dtype_defaults(lambda x: np.lexsort(x, axis=axis)) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @unittest.skip("JAX-metal crash.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonzerodim_shapes + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) + ], + dtype=all_dtypes, + ) + def testArgsort(self, dtype, shape, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + kwds = {} if axis is NO_VALUE else {'axis': axis} + + @jtu.with_jax_dtype_defaults + def np_fun(arr): + # Note: numpy sort fails on NaN and Inf values with bfloat16 + if arr.dtype == jnp.bfloat16: + arr = arr.astype('float32') + # TODO(jakevdp): switch to stable=True when supported by numpy. + return np.argsort(arr, kind='stable', **kwds) + jnp_fun = partial(jnp.argsort, stable=True, **kwds) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skip("JAX-metal crash.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) + ], + descending=[True, False], + dtype=all_dtypes, + ) + def testArgsortUnstable(self, dtype, shape, axis, descending): + # We cannot directly compare unstable argsorts, so instead check that indexed values match. + rng = jtu.rand_some_equal(self.rng()) + x = rng(shape, dtype) + kwds = {} if axis is NO_VALUE else {'axis': axis} + expected = jnp.sort(x, descending=descending, stable=False, **kwds) + indices = jnp.argsort(x, descending=descending, stable=False, **kwds) + if axis is None: + actual = jnp.ravel(x)[indices] + else: + actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis) + self.assertArraysEqual(actual, expected) + + @jtu.sample_product( + [{'shape': shape, 'axis': axis, 'kth': kth} + for shape in nonzerodim_shapes + for axis in range(-len(shape), len(shape)) + for kth in range(-shape[axis], shape[axis])], + dtype=default_dtypes, + ) + def testPartition(self, shape, dtype, axis, kth): + rng = jtu.rand_default(self.rng()) + arg = rng(shape, dtype) + jnp_output = jnp.partition(arg, axis=axis, kth=kth) + np_output = np.partition(arg, axis=axis, kth=kth) + + # Assert that pivot point is equal: + self.assertArraysEqual( + lax.index_in_dim(jnp_output, axis=axis, index=kth), + lax.index_in_dim(np_output, axis=axis, index=kth)) + + # Assert remaining values are correctly partitioned: + self.assertArraysEqual( + lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis), + lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis)) + self.assertArraysEqual( + lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis), + lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis)) + + #@unittest.skipIf(jtu.device_under_test=="METAL", "Jax-metal fail on empty dim shape.") + @jtu.sample_product( + [{'shape': shape, 'axis': axis, 'kth': kth} + for shape in nonempty_shapes# nonzerodim_shapes + for axis in range(-len(shape), len(shape)) + for kth in range(-shape[axis], shape[axis])], + dtype=default_dtypes, + ) + def testArgpartition(self, shape, dtype, axis, kth): + rng = jtu.rand_default(self.rng()) + arg = rng(shape, dtype) + + jnp_output = jnp.argpartition(arg, axis=axis, kth=kth) + np_output = np.argpartition(arg, axis=axis, kth=kth) + + # Assert that all indices are present + self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False) + + # Because JAX & numpy may treat duplicates differently, we must compare values + # rather than indices. + getvals = lambda x, ind: x[ind] + for ax in range(arg.ndim): + if ax != range(arg.ndim)[axis]: + getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax) + jnp_values = getvals(arg, jnp_output) + np_values = getvals(arg, np_output) + + # Assert that pivot point is equal: + self.assertArraysEqual( + lax.index_in_dim(jnp_values, axis=axis, index=kth), + lax.index_in_dim(np_values, axis=axis, index=kth)) + + # Assert remaining values are correctly partitioned: + self.assertArraysEqual( + lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=kth, axis=axis), dimension=axis), + lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=kth, axis=axis), dimension=axis)) + self.assertArraysEqual( + lax.sort(lax.slice_in_dim(jnp_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis), + lax.sort(lax.slice_in_dim(np_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis)) @jtu.sample_product( [dict(shifts=shifts, axis=axis) for shifts, axis in [ (3, None), (1, 1), - # ((3,), (0,)), + ((3,), (0,)), ((-2,), (-2,)), - # ((1, 2), (0, -1)), + ((1, 2), (0, -1)), ((4, 2, 5, 5, 2, 4), None), (100, None), ] ], dtype=all_dtypes, - shape=[(3, 4), (3, 4, 5)], + shape=[(3, 4), (3, 4, 5), (7, 4, 0)], ) def testRoll(self, shape, dtype, shifts, axis): rng = jtu.rand_default(self.rng()) @@ -673,6 +4191,21 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + dtype=all_dtypes, + shape=[(1, 2, 3, 4)], + axis=[-3, 0, 2, 3], + start=[-4, -1, 2, 4], + ) + def testRollaxis(self, shape, dtype, start, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + jnp_op = partial(jnp.rollaxis, axis=axis, start=start) + np_op = partial(np.rollaxis, axis=axis, start=start) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @unittest.skip("jax-metal generates a different result from cpu.") @jtu.sample_product( dtype=[np.uint8, np.bool_], bitorder=['big', 'little'], @@ -697,10 +4230,155 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testUnpackbits(self, shape, dtype, axis, bitorder, count): rng = jtu.rand_int(self.rng(), 0, 256) args_maker = lambda: [rng(shape, dtype)] - jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder) - np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder) + jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder, count=count) + np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder, count=count) self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + #@unittest.skip("jax-metal generates a different result from cpu.") + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in [(3,), (3, 4), (3, 4, 5)] + for axis in itertools.chain(range(-len(shape), len(shape)), + [cast(Union[int, None], None)]) + ], + index_shape=scalar_shapes + [(3,), (2, 1, 3)], + dtype=all_dtypes, + index_dtype=int_dtypes, + #mode=[None, 'wrap', 'clip'], + mode=[None, 'wrap'], + ) + def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode): + def args_maker(): + x = rng(shape, dtype) + i = rng_indices(index_shape, index_dtype) + return x, i + + rng = jtu.rand_default(self.rng()) + if mode is None: + rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0]) + else: + rng_indices = jtu.rand_int(self.rng(), -5, 5) + jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode) + np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + def testTakeEmpty(self): + np.testing.assert_array_equal( + jnp.array([], dtype=jnp.float32), + jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32))) + + np.testing.assert_array_equal( + jnp.ones((2, 0, 4), dtype=jnp.float32), + jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32), + axis=1)) + + with self.assertRaisesRegex(IndexError, "non-empty jnp.take"): + jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), + jnp.array([0], jnp.int32), axis=1) + + def testTakeOptionalArgs(self): + x = jnp.arange(5.0) + ind = jnp.array([0, 2, 4, 6]) + expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype) + actual = jnp.take(x, ind, unique_indices=True, + indices_are_sorted=True, fill_value=10.0) + self.assertArraysEqual(expected, actual) + + @jtu.sample_product( + [dict(x_shape=x_shape, i_shape=i_shape, axis=axis) + for x_shape, i_shape in filter( + _shapes_are_equal_length, + filter(_shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2))) + for axis in itertools.chain(range(len(x_shape)), [-1], + [cast(Union[int, None], None)]) + ], + dtype=default_dtypes, + index_dtype=int_dtypes, + ) + def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis): + rng = jtu.rand_default(self.rng()) + + i_shape = list(i_shape) + if axis is None: + i_shape = [math.prod(i_shape)] + else: + # Test the case where the size of the axis doesn't necessarily broadcast. + i_shape[axis] *= 3 + def args_maker(): + x = rng(x_shape, dtype) + n = math.prod(x_shape) if axis is None else x_shape[axis] + if np.issubdtype(index_dtype, np.unsignedinteger): + index_rng = jtu.rand_int(self.rng(), 0, n) + else: + index_rng = jtu.rand_int(self.rng(), -n, n) + i = index_rng(i_shape, index_dtype) + return x, i + + jnp_op = lambda x, i: jnp.take_along_axis(x, i, axis=axis) + + if hasattr(np, "take_along_axis"): + np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): + # https://github.com/google/jax/issues/5088 + h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) + g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) + q0 = jnp.take_along_axis(h, g, axis=-1) + q1 = np.take_along_axis( h, g, axis=-1) + np.testing.assert_equal(q0, q1) + + @unittest.skip("Jax-metal fail.") + def testTakeAlongAxisOutOfBounds(self): + x = jnp.arange(10, dtype=jnp.float32) + idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11]) + out = jnp.take_along_axis(x, idx, axis=0) + expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan, + jnp.nan], np.float32) + np.testing.assert_array_equal(expected_fill, out) + out = jnp.take_along_axis(x, idx, axis=0, mode="fill") + np.testing.assert_array_equal(expected_fill, out) + + expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) + out = jnp.take_along_axis(x, idx, axis=0, mode="clip") + np.testing.assert_array_equal(expected_clip, out) + + def testTakeAlongAxisRequiresIntIndices(self): + x = jnp.arange(5) + idx = jnp.array([3.], jnp.float32) + with self.assertRaisesRegex( + TypeError, + "take_along_axis indices must be of integer type, got float32"): + jnp.take_along_axis(x, idx, axis=0) + + def testTakeAlongAxisWithEmptyArgs(self): + # take_along_axis should allow us to gather an empty list of indices from + # an empty input axis without raising a shape error. + x = jnp.ones((4, 0, 3), dtype=jnp.int32) + np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1)) + + @jtu.sample_product( + dtype=inexact_dtypes, + shape=[0, 5], + n=[2, 4], + increasing=[False, True], + ) + def testVander(self, shape, dtype, n, increasing): + rng = jtu.rand_default(self.rng()) + def np_fun(arg): + arg = arg.astype(np.float32) if dtype == jnp.bfloat16 else arg + return np.vander(arg, N=n, increasing=increasing) + jnp_fun = lambda arg: jnp.vander(arg, N=n, increasing=increasing) + args_maker = lambda: [rng([shape], dtype)] + # np.vander seems to return float64 for all floating types. We could obey + # those semantics, but they seem like a bug. + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol={np.float32: 1e-3, np.complex64: 1e-3}) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False) @jtu.sample_product( shape=array_shapes, @@ -724,8 +4402,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): check_dtypes=check_dtypes) self._CompileAndCheck(jnp.nan_to_num, args_maker, check_dtypes=check_dtypes) - # self._CompileAndCheck(jnp_op, args_maker) - @jtu.sample_product( [dict(shapes=shapes, dtypes=dtypes) @@ -745,13 +4421,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp.ix_, args_maker) @jtu.sample_product( - dimensions=[(2,), (4, 5, 6)], + dimensions=[(), (2,), (3, 0), (4, 5, 6)], dtype=number_dtypes, sparse=[True, False], ) def testIndices(self, dimensions, dtype, sparse): - if jtu.device_under_test() == "tpu" and dtype in (np.int16, np.uint16): - raise unittest.SkipTest("Compilation failure on TPU ") def args_maker(): return [] np_fun = partial(np.indices, dimensions=dimensions, dtype=dtype, sparse=sparse) @@ -761,97 +4435,58 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - [dict(op=op, q_rng=q_rng) - for (op, q_rng) in ( - ("percentile", partial(jtu.rand_uniform, low=0., high=100.)), - ("quantile", partial(jtu.rand_uniform, low=0., high=1.)), - ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)), - ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)), - ) - ], - [dict(a_shape=a_shape, axis=axis) - for a_shape, axis in ( - ((7,), None), - ((47, 7), 0), - ((47, 7), ()), - ((4, 101), 1), - ((4, 47, 7), (1, 2)), - ((4, 47, 7), (0, 2)), - ((4, 47, 7), (1, 0, 2)), - ) - ], - a_dtype=default_dtypes, - q_dtype=[np.float32], - q_shape=scalar_shapes + [(1,), (4,)], - keepdims=[False, True], - method=['linear', 'lower', 'higher', 'nearest', 'midpoint'], + shape=all_shapes, dtype=all_dtypes, ) - def testQuantile(self, op, q_rng, a_shape, a_dtype, q_shape, q_dtype, - axis, keepdims, method): - raise unittest.SkipTest("JAX metal legalization error with Sort ") - a_rng = jtu.rand_some_nan(self.rng()) - q_rng = q_rng(self.rng()) - if "median" in op: - args_maker = lambda: [a_rng(a_shape, a_dtype)] - else: - args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] + def testWhereOneArgument(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=RuntimeWarning, - message="All-NaN slice encountered") - def np_fun(*args): - args = [x if jnp.result_type(x) != jnp.bfloat16 else - np.asarray(x, np.float32) for x in args] - if numpy_version <= (1, 22): - return getattr(np, op)(*args, axis=axis, keepdims=keepdims, - interpolation=method) - else: - return getattr(np, op)(*args, axis=axis, keepdims=keepdims, - method=method) - jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims, - method=method) + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) - # TODO(phawkins): we currently set dtype=False because we aren't as - # aggressive about promoting to float64. It's not clear we want to mimic - # Numpy here. - tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} - tol = max(jtu.tolerance(a_dtype, tol_spec), - jtu.tolerance(q_dtype, tol_spec)) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) + # JIT compilation requires specifying a size statically. Full test of + # this behavior is in testNonzeroSize(). + jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2) + + with jtu.ignore_warning(category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*"): + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - [dict(a_shape=a_shape, axis=axis) - for a_shape, axis in ( - ((7,), None), - ((47, 7), 0), - ((4, 101), 1), - ) - ], - a_dtype=default_dtypes, - keepdims=[False, True], - op=["median", "nanmedian"], + shapes=filter(_shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 3)), + dtypes=itertools.combinations_with_replacement(all_dtypes, 3), ) - def testMedian(self, op, a_shape, a_dtype, axis, keepdims): - raise unittest.SkipTest("JAX metal legalization error with Sort ") - if op == "median": - a_rng = jtu.rand_default(self.rng()) - else: - a_rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: [a_rng(a_shape, a_dtype)] - def np_fun(*args): - args = [x if jnp.result_type(x) != jnp.bfloat16 else - np.asarray(x, np.float32) for x in args] - return getattr(np, op)(*args, axis=axis, keepdims=keepdims) - jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims) - # TODO(phawkins): we currently set dtype=False because we aren't as - # aggressive about promoting to float64. It's not clear we want to mimic - # Numpy here. - tol_spec = {np.float32: 2e-4, np.float64: 5e-6} - tol = jtu.tolerance(a_dtype, tol_spec) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) + def testWhereThreeArgument(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes) + def np_fun(cond, x, y): + return jtu.promote_like_jnp(partial(np.where, cond))(x, y) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(np_fun, jnp.where, args_maker) + self._CompileAndCheck(jnp.where, args_maker) + + def testWhereExtraCode(self): + def f(x): + return jnp.where(x > 0, x, -x) + + # Test no comparison literal True/False in jaxpr, and hence no comparison to + # literals + jaxpr = jax.make_jaxpr(jax.grad(f))(3.) + self.assertNotIn('False', str(jaxpr)) + self.assertNotIn('True', str(jaxpr)) + + # But if we set the option off, we get the old behavior. + with config.new_select_transpose(False): + jaxpr = jax.make_jaxpr(jax.grad(f))(3.) + self.assertIn('False', str(jaxpr)) + self.assertIn('True', str(jaxpr)) + + def testWhereScalarPromotion(self): + x = jnp.where(jnp.array([True, False]), 3, + jnp.ones((2,), dtype=jnp.float32)) + self.assertEqual(x.dtype, np.dtype(np.float32)) @jtu.sample_product( [dict(n=n, shapes=shapes) @@ -889,9 +4524,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp.select, args_maker, rtol={np.float64: 1e-7, np.complex128: 1e-7}) - # def testIssue330(self): - # x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash - # self.assertEqual(x[0, 0], 1) + def testIssue330(self): + x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash + self.assertEqual(x[0, 0], 1) def testScalarDtypePromotion(self): orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype @@ -906,6 +4541,21 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jax_numpy_result = ((x + x.T) / 2).dtype self.assertEqual(orig_numpy_result, jax_numpy_result) + # NOTE(mattjj): I disabled this test when removing lax._safe_mul because + # introducing the convention 0 * inf = 0 leads to silently wrong results in + # some cases. See this comment for details: + # https://github.com/google/jax/issues/1052#issuecomment-514083352 + # def testIssue347(self): + # # https://github.com/google/jax/issues/347 + # def test_fail(x): + # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) + # ones = jnp.ones_like(x) + # x = jnp.where(x > 0.5, x, ones) + # return jnp.sum(x) + # x = jnp.array([[1, 2], [3, 4], [0, 0]], dtype=jnp.float64) + # result = jax.grad(test_fail)(x) + # assert not np.any(np.isnan(result)) + def testIssue453(self): # https://github.com/google/jax/issues/453 a = np.arange(6) + 1 @@ -913,6 +4563,57 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): expected = np.reshape(a, (3, 2), order='F') self.assertAllClose(ans, expected) + @jtu.sample_product( + #dtype=[int, float, bool, complex], + dtype=[int, float, bool], + op=["atleast_1d", "atleast_2d", "atleast_3d"], + ) + def testAtLeastNdLiterals(self, dtype, op): + # Fixes: https://github.com/google/jax/issues/634 + np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) + jnp_fun = lambda arg: getattr(jnp, op)(arg) + args_maker = lambda: [dtype(2)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(0,), (5,), (10,)], + dtype=int_dtypes, + weights=[True, False], + minlength=[0, 20], + length=[None, 8], + ) + def testBincount(self, shape, dtype, weights, minlength, length): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None)) + + def np_fun(x, *args): + x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero. + out = np.bincount(x, *args, minlength=minlength) + if length and length > out.size: + return np.pad(out, (0, length - out.size)) + return out[:length] + jnp_fun = partial(jnp.bincount, minlength=minlength, length=length) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + if length is not None: + self._CompileAndCheck(jnp_fun, args_maker) + + def testBincountNegative(self): + # Test that jnp.bincount ignores negative values. + x_rng = jtu.rand_int(self.rng(), -100, 100) + w_rng = jtu.rand_uniform(self.rng()) + shape = (1000,) + x = x_rng(shape, 'int32') + w = w_rng(shape, 'float32') + + xn = np.array(x) + xn[xn < 0] = 0 + wn = np.array(w) + np_result = np.bincount(xn[xn >= 0], wn[xn >= 0]) + jnp_result = jnp.bincount(x, w) + self.assertAllClose(np_result, jnp_result, check_dtypes=False) + @jtu.sample_product( input=[ 3, @@ -939,6 +4640,125 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testLongLong(self): self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7))) + @jtu.ignore_warning(category=UserWarning, + message="Explicitly requested dtype.*") + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testArange(self): + # test cases inspired by dask tests at + # https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92 + np_arange = jtu.with_jax_dtype_defaults(np.arange) + self.assertAllClose(jnp.arange(77), + np_arange(77)) + self.assertAllClose(jnp.arange(2, 13), + np_arange(2, 13)) + self.assertAllClose(jnp.arange(4, 21, 9), + np_arange(4, 21, 9)) + self.assertAllClose(jnp.arange(53, 5, -3), + np_arange(53, 5, -3)) + self.assertAllClose(jnp.arange(77, dtype=float), + np_arange(77, dtype=float)) + self.assertAllClose(jnp.arange(2, 13, dtype=int), + np_arange(2, 13, dtype=int)) + self.assertAllClose(jnp.arange(0, 1, -0.5), + np_arange(0, 1, -0.5)) + + self.assertRaises(TypeError, lambda: jnp.arange()) + + # test that jnp.arange(N) doesn't instantiate an ndarray + self.assertNotEqual(type(jnp.arange(77)), type(np.arange(77))) + self.assertEqual(type(jnp.arange(77)), type(lax.iota(np.int32, 77))) + + # test that jnp.arange(N, dtype=int32) doesn't instantiate an ndarray + self.assertNotEqual(type(jnp.arange(77, dtype=jnp.int32)), + type(np.arange(77, dtype=np.int32))) + self.assertEqual(type(jnp.arange(77, dtype=jnp.int32)), + type(lax.iota(np.int32, 77))) + + def testArangeJit(self): + ans = jax.jit(lambda: jnp.arange(5))() + expected = jtu.with_jax_dtype_defaults(np.arange)(5) + self.assertAllClose(ans, expected) + + @jtu.sample_product(args=[(5,), (0, 5)]) + def testArangeJaxpr(self, args): + jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))() + self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) + + @unittest.skip("Jax-metal don't support complex.") + def testIssue830(self): + a = jnp.arange(4, dtype=jnp.complex64) + self.assertEqual(a.dtype, jnp.complex64) + + def testIssue728(self): + np_eye = jtu.with_jax_dtype_defaults(np.eye) + self.assertAllClose(jnp.eye(5000), np_eye(5000)) + self.assertEqual(0, np.sum(jnp.eye(1050) - np_eye(1050))) + + def testIssue746(self): + jnp.arange(12).reshape(3, 4) # doesn't crash + + def testIssue764(self): + x = jnp.linspace(190, 200, 4) + f = jax.grad(lambda x: jnp.sum(jnp.tanh(x))) + # Expected values computed with autograd in float64 precision. + expected = np.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171, + 7.66067839e-174], np.float64) + self.assertAllClose(f(x), expected, check_dtypes=False) + + # Test removed because tie_in is deprecated. + # def testIssue776(self): + # """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" + # def f(u): + # y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u) + # # The transpose rule for lax.tie_in returns a symbolic zero for its first + # # argument. + # return lax.tie_in(y, 7.) + + # self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,))) + + # NOTE(mattjj): I disabled this test when removing lax._safe_mul because this + # is a numerical stability issue that should be solved with a custom jvp rule + # of the sigmoid function being differentiated here, not by safe_mul. + # def testIssue777(self): + # x = jnp.linspace(-200, 0, 4, dtype=np.float32) + # f = jax.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x)))) + # self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32)) + + #unittest.skip("Jax-metal fail on tanh with np.nan") + @jtu.sample_product( + dtype=float_dtypes, + op=("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan", + "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp", + "log", "expm1", "log1p"), + ) + def testMathSpecialFloatValues(self, op, dtype): + np_op = getattr(np, op) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="invalid value.*")(np_op) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="divide by zero.*")(np_op) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="overflow.*")(np_op) + + jnp_op = getattr(jnp, op) + dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type + for x in (-np.inf, -100., -2., -1., 0., 1., 2., 100., np.inf, + jnp.finfo(dtype).max, np.sqrt(jnp.finfo(dtype).max), + np.sqrt(jnp.finfo(dtype).max) * 2.): #np.nan + x = dtype(x) + expected = np_op(x) + actual = jnp_op(x) + tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7}) + self.assertAllClose(expected, actual, atol=tol, + rtol=tol) + + def testIssue956(self): + self.assertRaises(TypeError, lambda: jnp.ndarray((1, 1))) + + def testIssue967(self): + self.assertRaises(TypeError, lambda: jnp.zeros(1.5)) + @jtu.sample_product( shape=[(5,), (10, 5), (4, 10)], dtype=number_dtypes, @@ -1012,22 +4832,92 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def testMgrid(self): + # wrap indexer for appropriate dtype defaults. + np_mgrid = _indexer_with_default_outputs(np.mgrid) + assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0) + assertAllEqual(np_mgrid[()], jnp.mgrid[()]) + assertAllEqual(np_mgrid[:4], jnp.mgrid[:4]) + assertAllEqual(np_mgrid[:4,], jnp.mgrid[:4,]) + assertAllEqual(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])()) + assertAllEqual(np_mgrid[:5, :5], jnp.mgrid[:5, :5]) + assertAllEqual(np_mgrid[:3, :2], jnp.mgrid[:3, :2]) + assertAllEqual(np_mgrid[1:4:2], jnp.mgrid[1:4:2]) + assertAllEqual(np_mgrid[1:5:3, :5], jnp.mgrid[1:5:3, :5]) + assertAllEqual(np_mgrid[:3, :2, :5], jnp.mgrid[:3, :2, :5]) + assertAllEqual(np_mgrid[:3:2, :2, :5], jnp.mgrid[:3:2, :2, :5]) + # Corner cases + assertAllEqual(np_mgrid[:], jnp.mgrid[:]) + # When the step length is a complex number, because of float calculation, + # the values between jnp and np might slightly different. + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_mgrid[-1:1:5j], + jnp.mgrid[-1:1:5j], + atol=atol, + rtol=rtol) + self.assertAllClose(np_mgrid[3:4:7j], + jnp.mgrid[3:4:7j], + atol=atol, + rtol=rtol) + self.assertAllClose(np_mgrid[1:6:8j, 2:4], + jnp.mgrid[1:6:8j, 2:4], + atol=atol, + rtol=rtol) + # Non-integer steps + self.assertAllClose(np_mgrid[0:3.5:0.5], + jnp.mgrid[0:3.5:0.5], + atol=atol, + rtol=rtol) + self.assertAllClose(np_mgrid[1.3:4.2:0.3], + jnp.mgrid[1.3:4.2:0.3], + atol=atol, + rtol=rtol) + # abstract tracer value for jnp.mgrid slice + with self.assertRaisesRegex(core.ConcretizationTypeError, + "slice start of jnp.mgrid"): + jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2) + def testOgrid(self): # wrap indexer for appropriate dtype defaults. np_ogrid = _indexer_with_default_outputs(np.ogrid) - def assertListOfArraysEqual(xs, ys): - self.assertIsInstance(xs, list) - self.assertIsInstance(ys, list) + def assertSequenceOfArraysEqual(xs, ys): + self.assertIsInstance(xs, (list, tuple)) + self.assertIsInstance(ys, (list, tuple)) self.assertEqual(len(xs), len(ys)) for x, y in zip(xs, ys): self.assertArraysEqual(x, y) + self.assertArraysEqual(np_ogrid[:5], jnp.ogrid[:5]) + self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])()) self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2]) # List of arrays - assertListOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,]) - assertListOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3]) - assertListOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3]) - assertListOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11]) + assertSequenceOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,]) + assertSequenceOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3]) + assertSequenceOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3]) + assertSequenceOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11]) + # Corner cases + self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:]) + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_ogrid[-1:1:5j], + jnp.ogrid[-1:1:5j], + atol=atol, + rtol=rtol) + # Non-integer steps + self.assertAllClose(np_ogrid[0:3.5:0.3], + jnp.ogrid[0:3.5:0.3], + atol=atol, + rtol=rtol) + self.assertAllClose(np_ogrid[1.2:4.8:0.24], + jnp.ogrid[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + # abstract tracer value for ogrid slice + with self.assertRaisesRegex(core.ConcretizationTypeError, + "slice start of jnp.ogrid"): + jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2) def testR_(self): a = np.arange(6).reshape((2,3)) @@ -1043,8 +4933,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]]) # matrix directives - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + with jtu.ignore_warning(category=PendingDeprecationWarning): self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]]) self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]]) @@ -1092,8 +4981,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]]) self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]]) # matrix directives, avoid numpy deprecation warning - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + with jtu.ignore_warning(category=PendingDeprecationWarning): self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]]) self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]]) @@ -1129,9 +5017,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertEqual(np.index_exp[5:3:2j],jnp.index_exp[5:3:2j]) @jtu.sample_product( - start_shape=[(2,), (2, 2)], - stop_shape=[(2,), (2, 2)], - num=[1, 2, 5, 20], + start_shape=[(), (2,), (2, 2)], + stop_shape=[(), (2,), (2, 2)], + num=[0, 1, 2, 5, 20], endpoint=[True, False], retstep=[True, False], # floating-point compute between jitted platforms and non-jit + rounding @@ -1171,20 +5059,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0) @jtu.sample_product( - start_shape=[(2,), (2, 2)], - stop_shape=[(2,), (2, 2)], - num=[1, 2, 5, 20], + start_shape=[(), (2,), (2, 2)], + stop_shape=[(), (2,), (2, 2)], + num=[0, 1, 2, 5, 20], endpoint=[True, False], base=[10.0, 2, np.e], # skip 16-bit floats due to insufficient precision for the test. - dtype=[np.float32] + [None,], + dtype=jtu.dtypes.inexact + [None,], ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testLogspace(self, start_shape, stop_shape, num, endpoint, base, dtype): if (dtype in int_dtypes and - jtu.device_under_test() in ("gpu", "tpu", "METAL") and - not config.x64_enabled): + jtu.test_device_matches(["gpu", "tpu"]) and + not config.enable_x64.value): raise unittest.SkipTest("GPUx32 truncated exponentiation" " doesn't exactly match other platforms.") rng = jtu.rand_default(self.rng()) @@ -1212,6 +5100,83 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_op, args_maker, check_dtypes=False, atol=atol, rtol=tol) + @jtu.sample_product( + [dict(start_shape=start_shape, stop_shape=stop_shape, axis=axis) + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for axis in range(-max(len(start_shape), len(stop_shape)), + max(len(start_shape), len(stop_shape))) + ], + num=[0, 1, 2, 5, 20], + endpoint=[True, False], + # NB: numpy's geomspace gives nonsense results on integer types + dtype=inexact_dtypes + [None,], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testGeomspace(self, start_shape, stop_shape, num, + endpoint, dtype, axis): + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = {dtypes.bfloat16: 2e-2, np.float16: 4e-3, np.float32: 2e-3, + np.float64: 1e-14, np.complex64: 2e-3, np.complex128: 1e-14} + def args_maker(): + """Test the set of inputs np.geomspace is well-defined on.""" + start, stop = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype])() + # np.geomspace can't handle differently ranked tensors + # w. negative numbers! + start, stop = jnp.broadcast_arrays(start, stop) + if dtype in complex_dtypes: + return start, stop + # to avoid NaNs, non-complex start and stop cannot + # differ in sign, elementwise + start = start * jnp.sign(start) * jnp.sign(stop) + return start, stop + start, stop = args_maker() + def jnp_op(start, stop): + return jnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, + axis=axis) + def np_op(start, stop): + start = start.astype(np.float32) if dtype == jnp.bfloat16 else start + stop = stop.astype(np.float32) if dtype == jnp.bfloat16 else stop + return np.geomspace( + start, stop, num, endpoint=endpoint, + dtype=dtype if dtype != jnp.bfloat16 else np.float32, + axis=axis).astype(dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker, + check_dtypes=False, tol=tol) + if dtype in (inexact_dtypes + [None,]): + self._CompileAndCheck(jnp_op, args_maker, + check_dtypes=False, atol=tol, rtol=tol) + + def testDisableNumpyRankPromotionBroadcasting(self): + with jax.numpy_rank_promotion('allow'): + jnp.ones(2) + jnp.ones((1, 2)) # works just fine + + with jax.numpy_rank_promotion('raise'): + self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) + jnp.ones(2) + 3 # don't want to raise for scalars + + with jax.numpy_rank_promotion('warn'): + self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " + r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) + jnp.ones(2) + 3 # don't want to warn for scalars + + @unittest.skip("Test fails on CI, perhaps due to JIT caching") + def testDisableNumpyRankPromotionBroadcastingDecorator(self): + with jax.numpy_rank_promotion("allow"): + jnp.ones(2) + jnp.ones((1, 2)) # works just fine + + with jax.numpy_rank_promotion("raise"): + self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) + jnp.ones(2) + 3 # don't want to raise for scalars + + with jax.numpy_rank_promotion("warn"): + self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " + r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) + jnp.ones(2) + 3 # don't want to warn for scalars + def testStackArrayArgument(self): # tests https://github.com/google/jax/issues/1271 @jax.jit @@ -1316,6 +5281,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): HIGHEST, partial(jnp.vdot, precision=HIGHEST), ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(jnp.vecdot, precision=HIGHEST), + ones_1d, ones_1d) jtu.assert_dot_precision( HIGHEST, partial(jnp.tensordot, axes=2, precision=HIGHEST), @@ -1341,6 +5310,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): partial(jnp.inner, precision=HIGHEST), ones_1d, ones_1d) + @jtu.sample_product( + funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot', 'vecdot'] + ) + def testPreferredElementType(self, funcname): + func = getattr(jnp, funcname) + kwargs = dict(axes=0) if funcname == 'tensordot' else {} + + ones_i32 = np.ones(2, dtype='int32') + ones_f32 = np.ones(2, dtype='float32') + + with jax.numpy_dtype_promotion('strict'): + jtu.assert_dot_preferred_element_type('int32', func, ones_i32, ones_i32, **kwargs) + jtu.assert_dot_preferred_element_type('float32', func, ones_f32, ones_f32, **kwargs) + jtu.assert_dot_preferred_element_type('bfloat16', func, ones_f32, ones_f32, **kwargs, + preferred_element_type='bfloat16') + with jax.numpy_dtype_promotion('standard'): + jtu.assert_dot_preferred_element_type('float32', func, ones_i32, ones_f32, **kwargs) + @jtu.sample_product( [dict(shape=shape, varargs=varargs, axis=axis) for shape in [(10,), (10, 15), (10, 15, 20)] @@ -1399,13 +5386,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): for order in ['C', 'F']: self.assertEqual(jnp.asarray(v).tobytes(order), v.tobytes(order)) + def testToBytesJitError(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + f = jax.jit(lambda x: x.tobytes()) + msg = r".*The tobytes\(\) method was called on traced array" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(v) + def testToList(self): v = np.arange(12, dtype=np.int32).reshape(3, 4) self.assertEqual(jnp.asarray(v).tolist(), v.tolist()) - @jtu.skip_on_devices("METAL") + def testToListJitError(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + f = jax.jit(lambda x: x.tolist()) + msg = r".*The tolist\(\) method was called on traced array" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(v) + def testArangeConcretizationError(self): - msg = r"It arose in jax.numpy.arange argument `{}`".format + msg = r"It arose in the jnp.arange argument '{}'".format with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): jax.jit(jnp.arange)(3) @@ -1437,12 +5437,66 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testIssue2347(self): # https://github.com/google/jax/issues/2347 - object_list = List[Tuple[jnp.array, float, float, jnp.array, bool]] + object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] self.assertRaises(TypeError, jnp.array, object_list) np_object_list = np.array(object_list) self.assertRaises(TypeError, jnp.array, np_object_list) + @unittest.skip("JAX-metal don't support complex type yet.") + @jtu.sample_product( + [dict(shapes=shapes, dtypes=dtypes) + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, complex_dtypes + [None]) for s in shapes)) + ], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLogaddexpComplex(self, shapes, dtypes): + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + def np_op(x1, x2): + return np.log(np.exp(x1) + np.exp(x2)) + + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) + if jtu.test_device_matches(["tpu"]): + tol = {np.complex64: 1e-3, np.complex128: 1e-10} + else: + tol = {np.complex64: 1e-5, np.complex128: 1e-14} + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol) + self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol) + + @unittest.skip("JAX-metal don't support complex type yet.") + @jtu.sample_product( + [dict(shapes=shapes, dtypes=dtypes) + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, complex_dtypes + [None]) for s in shapes)) + ], + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLogaddexp2Complex(self, shapes, dtypes): + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + def np_op(x1, x2): + return np.log2(np.exp2(x1) + np.exp2(x2)) + + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) + if jtu.test_device_matches(["tpu"]): + tol = {np.complex64: 1e-3, np.complex128: 1e-10} + else: + tol = {np.complex64: 1e-5, np.complex128: 1e-14} + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol) + self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol) + def testDefaultDtypes(self): precision = config.default_dtype_bits.value assert precision in ['32', '64'] @@ -1472,51 +5526,103 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): actual = jnp.fromstring(s, sep=',', dtype=int) self.assertArraysEqual(expected, actual) + @jtu.sample_product( + a_shape=nonempty_nonscalar_array_shapes, + v_shape=nonempty_shapes, + dtype=jtu.dtypes.all, + ) + def testPlace(self, a_shape, v_shape, dtype): + rng = jtu.rand_default(self.rng()) + mask_rng = jtu.rand_bool(self.rng()) -# Most grad tests are at the lax level (see lax_test.py), but we add some here -# as needed for e.g. particular compound ops of interest. + def args_maker(): + a = rng(a_shape, dtype) + m = mask_rng(a_shape, bool) + v = rng(v_shape, dtype) + return a, m, v -GradTestSpec = collections.namedtuple( - "GradTestSpec", - ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"]) -def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): - return GradTestSpec( - op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) + def np_fun(a, m, v): + a_copy = a.copy() + np.place(a_copy, m, v) + return a_copy -GRAD_TEST_RECORDS = [ - grad_test_spec(jnp.arcsinh, nargs=1, order=2, - rng_factory=jtu.rand_positive, - dtypes=[np.float64, np.complex64], - tol={np.complex64: 2e-2}), - grad_test_spec(jnp.arccosh, nargs=1, order=1, - rng_factory=jtu.rand_positive, - dtypes=[np.float64, np.complex64], - tol={np.complex64: 2e-2}), - grad_test_spec(jnp.arctanh, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9), - dtypes=[np.float64, np.complex64], - tol={np.complex64: 2e-2}), - grad_test_spec(jnp.logaddexp, nargs=2, order=1, - rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9), - dtypes=[np.float64], tol=1e-4), - grad_test_spec(jnp.logaddexp2, nargs=2, order=2, - rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9), - dtypes=[np.float64], tol=1e-4), -] + jnp_fun = partial(jnp.place, inplace=False) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) -GradSpecialValuesTestSpec = collections.namedtuple( - "GradSpecialValuesTestSpec", ["op", "values", "order"]) + @jtu.sample_product( + a_shape=nonempty_nonscalar_array_shapes, + i_shape=all_shapes, + v_shape=all_shapes, + dtype=jtu.dtypes.all, + mode=[None, 'wrap', 'clip'], + ) + def testPut(self, mode, a_shape, i_shape, v_shape, dtype): + size = math.prod(a_shape) + if math.prod(i_shape) > size: + self.skipTest("too many indices") + rng = jtu.rand_default(self.rng()) + # Must test unique integers, because overlapping updates in + # JAX have implementation-defined order + idx_rng = jtu.rand_unique_int(self.rng(), size) -GRAD_SPECIAL_VALUE_TEST_RECORDS = [ - GradSpecialValuesTestSpec(jnp.arcsinh, [0., 1000.], 2), - GradSpecialValuesTestSpec(jnp.arccosh, [1000.], 2), - GradSpecialValuesTestSpec(jnp.arctanh, [0.], 2), - GradSpecialValuesTestSpec(jnp.sinc, [0.], 1), -] + def args_maker(): + a = rng(a_shape, dtype) + i = idx_rng(i_shape, np.int32) + v = rng(v_shape, dtype) + # put some indices out of range without duplicating indices + if mode == "clip" and i.size: + np.put(i, np.argmax(i), size + 2) + np.put(i, np.argmin(i), -2) + if mode == "wrap" and i.size: + np.put(i, 0, np.take(i, 0) + size) + return a, i, v + + def np_fun(a, i, v): + a_copy = a.copy() + np.put(a_copy, i, v, mode=mode) + return a_copy + + jnp_fun = partial(jnp.put, mode=mode, inplace=False) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def test_rot90_error(self): + with self.assertRaisesRegex( + ValueError, + "rot90 requires its first argument to have ndim at least two, " + "but got first argument of"): + jnp.rot90(jnp.ones(2)) + + @parameterized.named_parameters( + ('ones', jnp.ones), + ('zeros', jnp.zeros), + ('empty', jnp.empty)) + def test_error_hint(self, fn): + with self.assertRaisesRegex( + TypeError, + r"Did you accidentally write `jax\.numpy\..*?\(2, 3\)` " + r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"): + fn(2, 3) + + @jtu.sample_product( + dtype=jtu.dtypes.all, + kind=['bool', 'signed integer', 'unsigned integer', 'integral', + 'real floating', 'complex floating', 'numeric'] + ) + def test_isdtype(self, dtype, kind): + # Full tests also in dtypes_test.py; here we just compare against numpy + jax_result = jnp.isdtype(dtype, kind) + if jtu.numpy_version() < (2, 0, 0) or dtype == dtypes.bfloat16: + # just a smoke test + self.assertIsInstance(jax_result, bool) + else: + numpy_result = np.isdtype(dtype, kind) + self.assertEqual(jax_result, numpy_result) from jaxlib import xla_client -from jax_plugins import metal_plugin +@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") class ReportedIssuesTests(jtu.JaxTestCase): def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): deviceArgs = [] @@ -1541,7 +5647,7 @@ class ReportedIssuesTests(jtu.JaxTestCase): return jtu.parse_version(curr_ver) >= jtu.parse_version(target_ver) return False - + #https://github.com/google/jax/issues/16420 def test_broadcast_dim(self): x = jnp.arange(2) @@ -1559,12 +5665,12 @@ class ReportedIssuesTests(jtu.JaxTestCase): def test_identity(self): x = jnp.identity(4) jtu.check_eq(x, np.identity(4)) - + def test_triu(self): x = np.ones((4,4)) res = jnp.triu(x) jtu.check_eq(res, np.triu(x)) - + #https://github.com/google/jax/issues/16471 def test_matmul_1d(self): x = np.array(np.random.rand(3, 3)) @@ -1574,7 +5680,7 @@ class ReportedIssuesTests(jtu.JaxTestCase): self.assertArraysAllClose(res, np.dot(y,z)) res = jnp.dot(x, y) self.assertArraysAllClose(res, np.dot(x,y)) - + #https://github.com/google/jax/issues/17175 def test_indexing(self): x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) @@ -1596,21 +5702,21 @@ class ReportedIssuesTests(jtu.JaxTestCase): [0.1071583, 0.26139486, 0.25664794, 0.8109596]]) return jnp.take_along_axis(x, idx, axis=1) jtu.check_eq(f(), self.dispatchOn([], f)) - + #https://github.com/google/jax/issues/17590 def test_in1d(self): a = np.array([123,2,4]) b = np.array([123,1]) res = jnp.isin(a,b) jtu.check_eq(res, np.isin(a, b)) - + def test_indexing_update(self): x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) @jax.vmap def f(x): return x.at[0].set(1.0) f = jax.jit(f) - res= f(x) + res = f(x) jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]])) #https://github.com/google/jax/issues/16326 @@ -1646,7 +5752,7 @@ module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, index], run_on_cpu = True) print(res) jtu.check_eq(res, res_ref) - + #https://github.com/google/jax/issues/16366 def test_pad_interior_1(self): if not ReportedIssuesTests.jax_metal_supported('0.0.6'): @@ -1677,5 +5783,23 @@ module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas res_ref = np.cumsum(seq, axis=-1, dtype=np.uint8) jtu.check_eq(res, res_ref) + @unittest.expectedFailure + def test_issue_pad(self): + ir = ''' + module @jit_dummy attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x2xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x4xf32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<4x4xf32> { + %12 = stablehlo.slice %arg0 [0:1, 1:2] : (tensor<2x2xf32>) -> tensor<1x1xf32> + %13 = stablehlo.reshape %12 : (tensor<1x1xf32>) -> tensor + %14 = stablehlo.pad %arg1, %13, low = [0, 0], high = [1, 0], interior = [0, 0] : (tensor<3x4xf32>, tensor) -> tensor<4x4xf32> + return %14 : tensor<4x4xf32> + } + } + ''' + data = np.array([[1, 3], [1, 3]], dtype=np.float32) + input = np.random.rand(3,4).astype(np.float32) + res = ReportedIssuesTests.compile_and_exec(ir, [data, input]) + res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, input], run_on_cpu = True) + jtu.check_eq(res, res_ref) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 5a93e15bd782f707a3dd0dd78702c160b7689897 Mon Sep 17 00:00:00 2001 From: Shuhan Ding Date: Tue, 12 Mar 2024 17:17:20 -0700 Subject: [PATCH 3/3] add to tests/BUILD --- tests/BUILD | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index 9c8ca9310..62fb9fb8a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -542,6 +542,21 @@ jax_test( ] + py_deps("numpy"), ) +jax_test( + name = "lax_metal_test", + srcs = ["lax_metal_test.py"], + tags = ["notap"], + disable_backends = [ + "cpu", + "gpu", + "tpu", + ], + deps = [ + "//jax:internal_test_util", + "//jax:lax_reference", + ] + py_deps("numpy"), +) + jax_test( name = "lax_autodiff_test", srcs = ["lax_autodiff_test.py"],