# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax.numpy.ufunc and its methods."""
import itertools
from functools import partial
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
jax.config.parse_flags_with_absl()
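

# Scalar kernels wrapped via jnp.frompyfunc in the tests below; each asserts
# that it is only ever called with scalar (shape ()) inputs.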
def scalar_add(x, y):
  assert np.shape(x) == np.shape(y) == ()
  return x + y


def scalar_div(x, y):
  assert np.shape(x) == np.shape(y) == ()
  return x / y


def scalar_mul(x, y):
  assert np.shape(x) == np.shape(y) == ()
  return x * y


def scalar_sub(x, y):
  assert np.shape(x) == np.shape(y) == ()
  return x - y


SCALAR_FUNCS = [
    {'func': scalar_add, 'nin': 2, 'nout': 1, 'identity': 0},
    {'func': scalar_div, 'nin': 2, 'nout': 1, 'identity': None},
    {'func': scalar_mul, 'nin': 2, 'nout': 1, 'identity': 1},
    {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
]


def _jnp_ufunc_props(name):
  jnp_func = getattr(jnp, name)
  assert isinstance(jnp_func, jnp.ufunc)
  np_func = getattr(np, name)
  # Select test dtypes from the NumPy ufunc's registered type signatures:
  # 'F' (complex64), 'f' (float32), 'i' (intc), '?' (bool).
  dtypes = [np.dtype(c) for c in "Ffi?"
            if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types]
  return [dict(name=name, dtype=dtype) for dtype in dtypes]
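

# Every ufunc instance exposed in the jax.numpy namespace, split by arity and
# paired with the dtypes supported by its NumPy counterpart.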
JAX_NUMPY_UFUNCS = [
    name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc)
]

BINARY_UFUNCS = [
    name for name in JAX_NUMPY_UFUNCS if getattr(jnp, name).nin == 2
]

UNARY_UFUNCS = [
    name for name in JAX_NUMPY_UFUNCS if getattr(jnp, name).nin == 1
]

JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
    _jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS
))

BINARY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
    _jnp_ufunc_props(name) for name in BINARY_UFUNCS
))

UNARY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
    _jnp_ufunc_props(name) for name in UNARY_UFUNCS
))


broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
nonscalar_shapes = [(3,), (4,), (4, 3)]


def cast_outputs(fun):
  # np.frompyfunc returns object-dtype results; cast them back to the dtype of
  # the first argument so they are comparable with jnp.frompyfunc results.
  def wrapped(*args, **kwargs):
    dtype = np.asarray(args[0]).dtype
    return jax.tree.map(lambda x: np.asarray(x, dtype=dtype), fun(*args, **kwargs))
  return wrapped
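

# The tests below compare jnp.ufunc instances and jnp.frompyfunc-constructed
# ufuncs against their NumPy counterparts, covering __call__ and the outer,
# reduce, accumulate, at, and reduceat methods.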
class LaxNumpyUfuncTests(jtu.JaxTestCase):

  @jtu.sample_product(SCALAR_FUNCS)
def test_frompyfunc_properties(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
self.assertEqual(jnp_fun.identity, identity)
self.assertEqual(jnp_fun.nin, nin)
self.assertEqual(jnp_fun.nout, nout)
self.assertEqual(jnp_fun.nargs, nin)

  @jtu.sample_product(name=JAX_NUMPY_UFUNCS)
def test_ufunc_properties(self, name):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
self.assertEqual(jnp_fun.identity, np_fun.identity)
self.assertEqual(jnp_fun.nin, np_fun.nin)
self.assertEqual(jnp_fun.nout, np_fun.nout)
self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1) # -1 because NumPy accepts `out`

  @jtu.sample_product(SCALAR_FUNCS)
def test_frompyfunc_properties_readonly(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
getattr(jnp_fun, attr) # no error on attribute access.
with self.assertRaises(AttributeError):
setattr(jnp_fun, attr, None) # error when trying to mutate.

  @jtu.sample_product(name=JAX_NUMPY_UFUNCS)
def test_ufunc_properties_readonly(self, name):
jnp_fun = getattr(jnp, name)
for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
getattr(jnp_fun, attr) # no error on attribute access.
with self.assertRaises(AttributeError):
setattr(jnp_fun, attr, None) # error when trying to mutate.

  @jtu.sample_product(SCALAR_FUNCS)
def test_frompyfunc_hash(self, func, nin, nout, identity):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
self.assertEqual(jnp_fun, jnp_fun_2)
self.assertEqual(hash(jnp_fun), hash(jnp_fun_2))
other_fun = jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0)
self.assertNotEqual(jnp_fun, other_fun)
# Note: don't test hash for non-equality because it may collide.

  @jtu.sample_product(
SCALAR_FUNCS,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
dtype=jtu.dtypes.floating,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity))
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
UNARY_UFUNCS_WITH_DTYPES,
shape=broadcast_compatible_shapes,
)
def test_unary_ufunc_call(self, name, dtype, shape):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_binary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
SCALAR_FUNCS,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
dtype=jtu.dtypes.floating,
)
def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer
np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
)
def test_binary_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun.outer, args_maker)

  @jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [None, *range(-len(shape), len(shape))]],
dtype=jtu.dtypes.floating,
)
def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
np_fun = cast_outputs(partial(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis))
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [None, *range(-len(shape), len(shape))]],
)
def test_binary_ufunc_reduce(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if jnp_fun.identity is None and axis is None and len(shape) > 1:
self.skipTest("Multiple-axis reduction over non-reorderable ufunc.")
jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
np_fun_reduce = partial(np_fun.reduce, axis=axis)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun_reduce, args_maker)

  @jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [None, *range(-len(shape), len(shape))]],
dtype=jtu.dtypes.floating,
)
def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
    # reduce() requires an explicit initial value when the ufunc has no identity.
    initial = 1 if identity is None else None
def jnp_fun(arr, where):
return jnp.frompyfunc(func, nin, nout, identity=identity).reduce(
arr, where=where, axis=axis, initial=initial)
@cast_outputs
def np_fun(arr, where):
# Workaround for https://github.com/numpy/numpy/issues/24530
# TODO(jakevdp): remove this when possible.
initial_workaround = identity if initial is None else initial
return np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce(
arr, where=where, axis=axis, initial=initial_workaround)
rng = jtu.rand_default(self.rng())
rng_where = jtu.rand_bool(self.rng())
args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [None, *range(-len(shape), len(shape))]],
)
def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if jnp_fun.identity is None:
self.skipTest("reduce with where requires identity")
jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where)
np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where)
rng = jtu.rand_default(self.rng())
rng_where = jtu.rand_bool(self.rng())
args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun_reduce, args_maker)

  @jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
np_fun = cast_outputs(partial(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis))
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))]
)
def test_binary_ufunc_accumulate(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis)
def np_fun_accumulate(x):
# numpy accumulate has different dtype casting behavior.
result = np_fun.accumulate(x, axis=axis)
return result if x.dtype == bool else result.astype(x.dtype)
tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun_accumulate, args_maker, tol=tol)

  @jtu.sample_product(
SCALAR_FUNCS,
shape=nonscalar_shapes,
idx_shape=[(), (2,)],
dtype=jtu.dtypes.floating,
)
def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False)
def np_fun(x, idx, y):
x_copy = x.copy()
np.frompyfunc(func, nin=nin, nout=nout, identity=identity).at(x_copy, idx, y)
return x_copy
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
UNARY_UFUNCS_WITH_DTYPES,
shape=nonscalar_shapes,
idx_shape=[(), (2,)],
)
def test_unary_ufunc_at(self, name, shape, idx_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]
jnp_fun_at = partial(jnp_fun.at, inplace=False)
def np_fun_at(x, idx):
x_copy = x.copy()
np_fun.at(x_copy, idx)
return x_copy
tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun_at, args_maker)

  @jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
shape=nonscalar_shapes,
idx_shape=[(), (2,)],
)
def test_binary_ufunc_at(self, name, shape, idx_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]
jnp_fun_at = partial(jnp_fun.at, inplace=False)
def np_fun_at(x, idx, y):
x_copy = x.copy()
np_fun.at(x_copy, idx, y)
return x_copy
tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun_at, args_maker)

  def test_frompyfunc_at_broadcasting(self):
# Regression test for https://github.com/jax-ml/jax/issues/18004
args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]),
np.arange(9.0).reshape(3, 3)]
def np_fun(x, idx, y):
x_copy = np.copy(x)
np.add.at(x_copy, idx, y)
return x_copy
jnp_fun = partial(jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0).at, inplace=False)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [*range(-len(shape), len(shape))]],
idx_shape=[(0,), (3,), (5,)],
dtype=jtu.dtypes.floating,
)
def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis)
np_fun = cast_outputs(partial(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis))
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis])
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [*range(-len(shape), len(shape))]],
idx_shape=[(0,), (3,), (5,)],
)
def test_binary_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
if name in ['add', 'multiply'] and dtype == bool:
      # TODO(jakevdp): figure out how to fix these cases.
self.skipTest(f"known failure for {name}.reduceat with {dtype=}")
rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis])
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]
def np_fun_reduceat(x, i):
      # NumPy reduceat has different dtype casting behavior.
return np_fun.reduceat(x, i).astype(x.dtype)
tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun.reduceat, args_maker)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())