rocm_jax/tests/lax_numpy_ufuncs_test.py

445 lines
17 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax.numpy.ufunc and its methods."""
import itertools
from functools import partial
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
# Presumably lets JAX configuration options be set through absl command-line
# flags (as the function name indicates); it runs at import time, before the
# test class below is defined.
jax.config.parse_flags_with_absl()
def scalar_add(x, y):
  """Return ``x + y``; both operands must be 0-d (scalar) values.

  The shape assertion verifies that frompyfunc-built ufuncs really call the
  wrapped Python function element by element.
  """
  assert (np.shape(x), np.shape(y)) == ((), ())
  return x + y
def scalar_div(x, y):
  """Return ``x / y``; both operands must be 0-d (scalar) values.

  The shape assertion verifies that frompyfunc-built ufuncs really call the
  wrapped Python function element by element.
  """
  assert (np.shape(x), np.shape(y)) == ((), ())
  return x / y
def scalar_mul(x, y):
  """Return ``x * y``; both operands must be 0-d (scalar) values.

  The shape assertion verifies that frompyfunc-built ufuncs really call the
  wrapped Python function element by element.
  """
  assert (np.shape(x), np.shape(y)) == ((), ())
  return x * y
def scalar_sub(x, y):
  """Return ``x - y``; both operands must be 0-d (scalar) values.

  The shape assertion verifies that frompyfunc-built ufuncs really call the
  wrapped Python function element by element.
  """
  assert (np.shape(x), np.shape(y)) == ((), ())
  return x - y
# Parameter dicts for the jnp.frompyfunc tests: each entry names a binary
# scalar helper (two inputs, one output) and the identity to register with it
# (None where the helper is given no identity).
SCALAR_FUNCS = [
    dict(func=scalar_fn, nin=2, nout=1, identity=ident)
    for scalar_fn, ident in (
        (scalar_add, 0),
        (scalar_div, None),
        (scalar_mul, 1),
        (scalar_sub, None),
    )
]
def _jnp_ufunc_props(name):
jnp_func = getattr(jnp, name)
assert isinstance(jnp_func, jnp.ufunc)
np_func = getattr(np, name)
dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types]
return [dict(name=name, dtype=dtype) for dtype in dtypes]
# Every attribute name in `dir(jnp)` whose value is a jnp.ufunc instance.
JAX_NUMPY_UFUNCS = [
    attr_name
    for attr_name in dir(jnp)
    if isinstance(getattr(jnp, attr_name), jnp.ufunc)
]
# The per-ufunc {'name', 'dtype'} dicts from _jnp_ufunc_props, concatenated
# into a single flat list.
JAX_NUMPY_UFUNCS_WITH_DTYPES = list(
    itertools.chain.from_iterable(map(_jnp_ufunc_props, JAX_NUMPY_UFUNCS))
)
# Operand shapes that all broadcast pairwise against one another (trailing
# dims are 1 or 3, leading dims 1 or 4); the binary call/outer tests draw both
# operand shapes from this list.
broadcast_compatible_shapes = [
    (),
    (1,),
    (3,),
    (1, 3),
    (4, 1),
    (4, 3),
]
# Shapes with at least one axis, for the reduce/accumulate/at/reduceat tests.
nonscalar_shapes = [
    (3,),
    (4,),
    (4, 3),
]
def cast_outputs(fun):
  """Wrap ``fun`` so every output leaf is cast to its first argument's dtype.

  The tests apply this to np.frompyfunc-based references. NumPy documents that
  frompyfunc ufuncs return object arrays, so the cast restores a concrete
  dtype that can be compared against JAX's result.

  Args:
    fun: callable whose first positional argument is array-like.

  Returns:
    A wrapper with the same call signature. It maps ``np.asarray(..., dtype)``
    over the pytree returned by ``fun``.
  """
  def wrapped(*args, **kwargs):
    target_dtype = np.asarray(args[0]).dtype
    outputs = fun(*args, **kwargs)

    def to_target(leaf):
      return np.asarray(leaf, dtype=target_dtype)

    return jax.tree.map(to_target, outputs)

  return wrapped
class LaxNumpyUfuncTests(jtu.JaxTestCase):
  """Checks jnp.ufunc objects and jnp.frompyfunc against their NumPy equivalents.

  Each ufunc method (call, outer, reduce, accumulate, at, reduceat) is tested
  twice:
    * for ufuncs built with jnp.frompyfunc from the SCALAR_FUNCS helpers, whose
      NumPy references come from np.frompyfunc and are cast back with
      cast_outputs;
    * for every built-in jnp ufunc in JAX_NUMPY_UFUNCS, compared against the
      NumPy ufunc of the same name.
  """

  @jtu.sample_product(SCALAR_FUNCS)
  def test_frompyfunc_properties(self, func, nin, nout, identity):
    """frompyfunc exposes the nin/nout/identity it was constructed with."""
    jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
    self.assertEqual(jnp_fun.identity, identity)
    self.assertEqual(jnp_fun.nin, nin)
    self.assertEqual(jnp_fun.nout, nout)
    self.assertEqual(jnp_fun.nargs, nin)

  @jtu.sample_product(name=JAX_NUMPY_UFUNCS)
  def test_ufunc_properties(self, name):
    """Built-in jnp ufuncs report the same properties as NumPy's ufuncs."""
    jnp_fun = getattr(jnp, name)
    np_fun = getattr(np, name)
    self.assertEqual(jnp_fun.identity, np_fun.identity)
    self.assertEqual(jnp_fun.nin, np_fun.nin)
    self.assertEqual(jnp_fun.nout, np_fun.nout)
    self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1)  # -1 because NumPy accepts `out`

  @jtu.sample_product(SCALAR_FUNCS)
  def test_frompyfunc_properties_readonly(self, func, nin, nout, identity):
    """frompyfunc ufunc attributes are readable but cannot be reassigned."""
    jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
    for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
      getattr(jnp_fun, attr)  # no error on attribute access.
      with self.assertRaises(AttributeError):
        setattr(jnp_fun, attr, None)  # error when trying to mutate.

  @jtu.sample_product(name=JAX_NUMPY_UFUNCS)
  def test_ufunc_properties_readonly(self, name):
    """Built-in jnp ufunc attributes are readable but cannot be reassigned."""
    jnp_fun = getattr(jnp, name)
    for attr in ['nargs', 'nin', 'nout', 'identity', '_func']:
      getattr(jnp_fun, attr)  # no error on attribute access.
      with self.assertRaises(AttributeError):
        setattr(jnp_fun, attr, None)  # error when trying to mutate.

  @jtu.sample_product(SCALAR_FUNCS)
  def test_frompyfunc_hash(self, func, nin, nout, identity):
    """Identically constructed frompyfunc ufuncs compare and hash equal."""
    jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
    jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
    self.assertEqual(jnp_fun, jnp_fun_2)
    self.assertEqual(hash(jnp_fun), hash(jnp_fun_2))

    other_fun = jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0)
    self.assertNotEqual(jnp_fun, other_fun)
    # Note: don't test hash for non-equality because it may collide.

  @jtu.sample_product(
      SCALAR_FUNCS,
      lhs_shape=broadcast_compatible_shapes,
      rhs_shape=broadcast_compatible_shapes,
      dtype=jtu.dtypes.floating,
  )
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
    """Calling a frompyfunc ufunc broadcasts like np.frompyfunc."""
    jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
    np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity))
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      JAX_NUMPY_UFUNCS_WITH_DTYPES,
      lhs_shape=broadcast_compatible_shapes,
      rhs_shape=broadcast_compatible_shapes,
  )
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape):
    """Calling a built-in jnp ufunc matches the NumPy ufunc."""
    jnp_fun = getattr(jnp, name)
    np_fun = getattr(np, name)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      SCALAR_FUNCS,
      lhs_shape=broadcast_compatible_shapes,
      rhs_shape=broadcast_compatible_shapes,
      dtype=jtu.dtypes.floating,
  )
  def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype):
    """frompyfunc(...).outer matches np.frompyfunc(...).outer."""
    if (nin, nout) != (2, 1):
      self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}")
    jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer
    np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      JAX_NUMPY_UFUNCS_WITH_DTYPES,
      lhs_shape=broadcast_compatible_shapes,
      rhs_shape=broadcast_compatible_shapes,
  )
  def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype):
    """Built-in ufunc.outer matches NumPy's ufunc.outer."""
    jnp_fun = getattr(jnp, name)
    np_fun = getattr(np, name)
    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
      self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker)
    self._CompileAndCheck(jnp_fun.outer, args_maker)

  @jtu.sample_product(
      SCALAR_FUNCS,
      [{'shape': shape, 'axis': axis}
       for shape in nonscalar_shapes
       for axis in [None, *range(-len(shape), len(shape))]],
      dtype=jtu.dtypes.floating,
  )
  def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype):
    """frompyfunc(...).reduce matches NumPy for every axis, including None."""
    if (nin, nout) != (2, 1):
      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
    jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
    np_fun = cast_outputs(partial(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis))
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      JAX_NUMPY_UFUNCS_WITH_DTYPES,
      [{'shape': shape, 'axis': axis}
       for shape in nonscalar_shapes
       for axis in [None, *range(-len(shape), len(shape))]],
  )
  def test_ufunc_reduce(self, name, shape, axis, dtype):
    """Built-in ufunc.reduce matches NumPy for every axis, including None."""
    jnp_fun = getattr(jnp, name)
    np_fun = getattr(np, name)
    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
    jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
    np_fun_reduce = partial(np_fun.reduce, axis=axis)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker)
    self._CompileAndCheck(jnp_fun_reduce, args_maker)

  @jtu.sample_product(
      SCALAR_FUNCS,
      [{'shape': shape, 'axis': axis}
       for shape in nonscalar_shapes
       for axis in [None, *range(-len(shape), len(shape))]],
      dtype=jtu.dtypes.floating,
  )
  def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
    """frompyfunc(...).reduce with a boolean `where` mask matches NumPy."""
    if (nin, nout) != (2, 1):
      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")

    # Need initial if identity is None
    initial = 1 if identity is None else None

    def jnp_fun(arr, where):
      return jnp.frompyfunc(func, nin, nout, identity=identity).reduce(
          arr, where=where, axis=axis, initial=initial)

    @cast_outputs
    def np_fun(arr, where):
      # Workaround for https://github.com/numpy/numpy/issues/24530
      # TODO(jakevdp): remove this when possible.
      initial_workaround = identity if initial is None else initial
      return np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce(
          arr, where=where, axis=axis, initial=initial_workaround)

    rng = jtu.rand_default(self.rng())
    rng_where = jtu.rand_bool(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      JAX_NUMPY_UFUNCS_WITH_DTYPES,
      [{'shape': shape, 'axis': axis}
       for shape in nonscalar_shapes
       for axis in [None, *range(-len(shape), len(shape))]],
  )
  def test_ufunc_reduce_where(self, name, shape, axis, dtype):
    """Built-in ufunc.reduce with a `where` mask matches NumPy."""
    jnp_fun = getattr(jnp, name)
    np_fun = getattr(np, name)
    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
    # No `initial` is passed below, so masked reductions need an identity.
    if jnp_fun.identity is None:
      self.skipTest("reduce with where requires identity")

    jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where)
    np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where)

    rng = jtu.rand_default(self.rng())
    rng_where = jtu.rand_bool(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
    self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker)
    self._CompileAndCheck(jnp_fun_reduce, args_maker)

  @jtu.sample_product(
      SCALAR_FUNCS,
      [{'shape': shape, 'axis': axis}
       for shape in nonscalar_shapes
       for axis in range(-len(shape), len(shape))],
      dtype=jtu.dtypes.floating,
  )
  def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
    """frompyfunc(...).accumulate matches NumPy along each integer axis."""
    if (nin, nout) != (2, 1):
      self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
    jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
    np_fun = cast_outputs(partial(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis))
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      JAX_NUMPY_UFUNCS_WITH_DTYPES,
      [{'shape': shape, 'axis': axis}
       for shape in nonscalar_shapes
       for axis in range(-len(shape), len(shape))]
  )
  def test_ufunc_accumulate(self, name, shape, axis, dtype):
    """Built-in ufunc.accumulate matches NumPy along each integer axis."""
    jnp_fun = getattr(jnp, name)
    np_fun = getattr(np, name)
    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
      self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis)
    def np_fun_accumulate(x):
      # numpy accumulate has different dtype casting behavior.
      result = np_fun.accumulate(x, axis=axis)
      return result if x.dtype == bool else result.astype(x.dtype)

    self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker)
    self._CompileAndCheck(jnp_fun_accumulate, args_maker)

  @jtu.sample_product(
      SCALAR_FUNCS,
      shape=nonscalar_shapes,
      idx_shape=[(), (2,)],
      dtype=jtu.dtypes.floating,
  )
  def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype):
    """frompyfunc(...).at(inplace=False) matches NumPy's in-place ufunc.at."""
    if (nin, nout) != (2, 1):
      self.skipTest(f"at requires (nin, nout)=(2, 1); got {(nin, nout)=}")
    jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False)
    def np_fun(x, idx, y):
      # NumPy's ufunc.at mutates its first argument, so operate on a copy.
      x_copy = x.copy()
      np.frompyfunc(func, nin=nin, nout=nout, identity=identity).at(x_copy, idx, y)
      return x_copy

    rng = jtu.rand_default(self.rng())
    # Indices may be negative, covering both ends of the leading axis.
    idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
    args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]
    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      JAX_NUMPY_UFUNCS_WITH_DTYPES,
      shape=nonscalar_shapes,
      idx_shape=[(), (2,)],
  )
  def test_ufunc_at(self, name, shape, idx_shape, dtype):
    """Built-in ufunc.at(inplace=False) matches NumPy's in-place ufunc.at."""
    jnp_fun = getattr(jnp, name)
    np_fun = getattr(np, name)
    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
      self.skipTest(f"at requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")

    rng = jtu.rand_default(self.rng())
    # Indices may be negative, covering both ends of the leading axis.
    idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
    args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]

    jnp_fun_at = partial(jnp_fun.at, inplace=False)
    def np_fun_at(x, idx, y):
      # NumPy's ufunc.at mutates its first argument, so operate on a copy.
      x_copy = x.copy()
      np_fun.at(x_copy, idx, y)
      return x_copy

    self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker)
    self._CompileAndCheck(jnp_fun_at, args_maker)

  def test_frompyfunc_at_broadcasting(self):
    """ufunc.at broadcasts the update values against the indexed rows."""
    # Regression test for https://github.com/google/jax/issues/18004
    args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]),
                          np.arange(9.0).reshape(3, 3)]
    def np_fun(x, idx, y):
      x_copy = np.copy(x)
      np.add.at(x_copy, idx, y)
      return x_copy
    jnp_fun = partial(jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0).at, inplace=False)

    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      SCALAR_FUNCS,
      [{'shape': shape, 'axis': axis}
       for shape in nonscalar_shapes
       for axis in [*range(-len(shape), len(shape))]],
      idx_shape=[(0,), (3,), (5,)],
      dtype=jtu.dtypes.floating,
  )
  def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype):
    """frompyfunc(...).reduceat matches NumPy along each integer axis."""
    if (nin, nout) != (2, 1):
      self.skipTest(f"reduceat requires (nin, nout)=(2, 1); got {(nin, nout)=}")
    jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis)
    np_fun = cast_outputs(partial(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis))
    rng = jtu.rand_default(self.rng())
    # Indices must be valid positions along the reduced axis.
    idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis])
    args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]
    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
      JAX_NUMPY_UFUNCS_WITH_DTYPES,
      [{'shape': shape, 'axis': axis}
       for shape in nonscalar_shapes
       for axis in [*range(-len(shape), len(shape))]],
      idx_shape=[(0,), (3,), (5,)],
  )
  def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype):
    """Built-in ufunc.reduceat matches NumPy along each sampled axis."""
    jnp_fun = getattr(jnp, name)
    np_fun = getattr(np, name)
    if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
      self.skipTest(f"reduceat requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")
    if name in ['add', 'multiply'] and dtype == bool:
      # TODO(jakevdp): figure out how to fix these cases.
      self.skipTest(f"known failure for {name}.reduceat with {dtype=}")
    rng = jtu.rand_default(self.rng())
    # Indices must be valid positions along the reduced axis.
    idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis])
    args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]

    # Pass the sampled axis to both implementations; previously it was only
    # used to bound the indices, so every case reduced along axis 0.
    jnp_fun_reduceat = partial(jnp_fun.reduceat, axis=axis)
    def np_fun_reduceat(x, i):
      # Numpy has different casting behavior.
      return np_fun.reduceat(x, i, axis=axis).astype(x.dtype)

    self._CheckAgainstNumpy(jnp_fun_reduceat, np_fun_reduceat, args_maker)
    self._CompileAndCheck(jnp_fun_reduceat, args_maker)
if __name__ == "__main__":
  # Run the tests through absltest, using JAX's own test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())