mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jax.numpy: improve support for boolean inputs
This commit is contained in:
parent
70bc0da995
commit
3f06195994
@ -69,6 +69,11 @@ _dtype_to_inexact = {
|
||||
]
|
||||
}
|
||||
|
||||
def to_numeric_dtype(dtype):
|
||||
"""Promotes a dtype into an numeric dtype, if it is not already one."""
|
||||
dtype = np.dtype(dtype)
|
||||
return np.dtype('int32') if dtype == np.dtype('bool') else dtype
|
||||
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
"""Promotes a dtype into an inexact dtype, if it is not already one."""
|
||||
|
@ -73,8 +73,8 @@ from jax._src.numpy.ufuncs import ( # noqa: F401
|
||||
from jax._src.numpy.util import ( # noqa: F401
|
||||
_arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike,
|
||||
_complex_elem_type, _promote_args, _promote_args_inexact, _promote_dtypes,
|
||||
_promote_dtypes_inexact, _promote_shapes, _register_stackable, _stackable,
|
||||
_where, _wraps)
|
||||
_promote_dtypes_numeric, _promote_dtypes_inexact, _promote_shapes,
|
||||
_register_stackable, _stackable, _where, _wraps)
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.ops import scatter
|
||||
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
|
||||
@ -4100,10 +4100,9 @@ def _gcd_body_fn(xs):
|
||||
@jit
|
||||
def gcd(x1, x2):
|
||||
_check_arraylike("gcd", x1, x2)
|
||||
if (not issubdtype(_dtype(x1), integer) or
|
||||
not issubdtype(_dtype(x2), integer)):
|
||||
raise ValueError("Arguments to jax.numpy.gcd must be integers.")
|
||||
x1, x2 = _promote_dtypes(x1, x2)
|
||||
if not issubdtype(_dtype(x1), integer):
|
||||
raise ValueError("Arguments to jax.numpy.gcd must be integers.")
|
||||
x1, x2 = broadcast_arrays(x1, x2)
|
||||
gcd, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, (abs(x1), abs(x2)))
|
||||
return gcd
|
||||
@ -4114,6 +4113,8 @@ def gcd(x1, x2):
|
||||
def lcm(x1, x2):
|
||||
_check_arraylike("lcm", x1, x2)
|
||||
x1, x2 = _promote_dtypes(x1, x2)
|
||||
if not issubdtype(_dtype(x1), integer):
|
||||
raise ValueError("Arguments to jax.numpy.lcm must be integers.")
|
||||
d = gcd(x1, x2)
|
||||
return where(d == 0, _lax_const(d, 0),
|
||||
abs(multiply(x1, floor_divide(x2, d))))
|
||||
|
@ -28,7 +28,8 @@ from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.util import (
|
||||
_check_arraylike, _promote_args, _promote_args_inexact,
|
||||
_promote_dtypes_inexact, _promote_shapes, _where, _wraps)
|
||||
_promote_args_numeric, _promote_dtypes_inexact, _promote_dtypes_numeric,
|
||||
_promote_shapes, _where, _wraps)
|
||||
from jax import core
|
||||
from jax import lax
|
||||
|
||||
@ -62,9 +63,11 @@ def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
|
||||
return _wraps(numpy_fn, module='numpy')(fn)
|
||||
|
||||
|
||||
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
|
||||
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False, promote_to_numeric=False):
|
||||
if promote_to_inexact:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
|
||||
elif promote_to_numeric:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args_numeric(numpy_fn.__name__, x1, x2))
|
||||
else:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
|
||||
fn = jit(fn, inline=True)
|
||||
@ -143,7 +146,7 @@ add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
|
||||
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
|
||||
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
|
||||
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
|
||||
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left)
|
||||
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True)
|
||||
equal = _one_to_one_binop(np.equal, lax.eq)
|
||||
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
|
||||
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
|
||||
@ -179,7 +182,7 @@ def arccosh(x):
|
||||
@_wraps(np.right_shift, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def right_shift(x1, x2):
|
||||
x1, x2 = _promote_args(np.right_shift.__name__, x1, x2)
|
||||
x1, x2 = _promote_args_numeric(np.right_shift.__name__, x1, x2)
|
||||
lax_fn = lax.shift_right_logical if \
|
||||
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
|
||||
return lax_fn(x1, x2)
|
||||
@ -199,7 +202,7 @@ abs = _wraps(np.abs, module='numpy')(absolute)
|
||||
def rint(x):
|
||||
_check_arraylike('rint', x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.integer):
|
||||
if dtype == bool or dtypes.issubdtype(x.dtype, np.integer):
|
||||
return lax.convert_element_type(x, dtypes.float_)
|
||||
if dtypes.issubdtype(dtype, np.complexfloating):
|
||||
return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
|
||||
@ -239,7 +242,7 @@ divide = true_divide
|
||||
@_wraps(np.floor_divide, module='numpy')
|
||||
@jit
|
||||
def floor_divide(x1, x2):
|
||||
x1, x2 = _promote_args("floor_divide", x1, x2)
|
||||
x1, x2 = _promote_args_numeric("floor_divide", x1, x2)
|
||||
dtype = dtypes.dtype(x1)
|
||||
if dtypes.issubdtype(dtype, np.integer):
|
||||
quotient = lax.div(x1, x2)
|
||||
@ -264,7 +267,7 @@ def floor_divide(x1, x2):
|
||||
@_wraps(np.divmod, module='numpy')
|
||||
@jit
|
||||
def divmod(x1, x2):
|
||||
x1, x2 = _promote_args("divmod", x1, x2)
|
||||
x1, x2 = _promote_args_numeric("divmod", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
|
||||
return floor_divide(x1, x2), remainder(x1, x2)
|
||||
else:
|
||||
@ -285,7 +288,7 @@ def _float_divmod(x1, x2):
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def _power(x1, x2):
|
||||
x1, x2 = _promote_args("power", x1, x2)
|
||||
x1, x2 = _promote_args_numeric("power", x1, x2)
|
||||
dtype = dtypes.dtype(x1)
|
||||
if not dtypes.issubdtype(dtype, np.integer):
|
||||
return lax.pow(x1, x2)
|
||||
@ -307,6 +310,7 @@ def _power(x1, x2):
|
||||
|
||||
@_wraps(np.power, module='numpy')
|
||||
def power(x1, x2):
|
||||
_check_arraylike("power", x1, x2)
|
||||
# Special case for concrete integer scalars: use binary exponentiation.
|
||||
# Using lax.pow may be imprecise for floating-point values; the goal of this
|
||||
# code path is to make sure we end up with a precise output for the common
|
||||
@ -317,6 +321,7 @@ def power(x1, x2):
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
x1, = _promote_dtypes_numeric(x1)
|
||||
return lax.integer_pow(x1, x2)
|
||||
return _power(x1, x2)
|
||||
|
||||
@ -522,7 +527,7 @@ def frexp(x):
|
||||
@_wraps(np.remainder, module='numpy')
|
||||
@jit
|
||||
def remainder(x1, x2):
|
||||
x1, x2 = _promote_args("remainder", x1, x2)
|
||||
x1, x2 = _promote_args_numeric("remainder", x1, x2)
|
||||
zero = _constant_like(x1, 0)
|
||||
trunc_mod = lax.rem(x1, x2)
|
||||
trunc_mod_not_zero = lax.ne(trunc_mod, zero)
|
||||
@ -538,13 +543,14 @@ def fmod(x1, x2):
|
||||
_check_arraylike("fmod", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
|
||||
x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
|
||||
return lax.rem(*_promote_args("fmod", x1, x2))
|
||||
return lax.rem(*_promote_args_numeric("fmod", x1, x2))
|
||||
|
||||
|
||||
@_wraps(np.square, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def square(x):
|
||||
_check_arraylike("square", x)
|
||||
x, = _promote_dtypes_numeric(x)
|
||||
return lax.integer_pow(x, 2)
|
||||
|
||||
|
||||
|
@ -287,6 +287,17 @@ def _promote_dtypes_inexact(*args):
|
||||
for x in args]
|
||||
|
||||
|
||||
def _promote_dtypes_numeric(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to a numeric (non-bool) type."""
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_numeric = dtypes.to_numeric_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_numeric, weak_type)
|
||||
for x in args]
|
||||
|
||||
|
||||
def _promote_dtypes_complex(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
@ -344,6 +355,12 @@ def _promote_args(fun_name, *args):
|
||||
return _promote_shapes(fun_name, *_promote_dtypes(*args))
|
||||
|
||||
|
||||
def _promote_args_numeric(fun_name, *args):
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes_numeric(*args))
|
||||
|
||||
|
||||
def _promote_args_inexact(fun_name, *args):
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion.
|
||||
|
||||
|
@ -6497,10 +6497,6 @@ class NumpyUfuncTests(jtu.JaxTestCase):
|
||||
for name in _all_numpy_ufuncs()
|
||||
for arg_dtypes in jtu.cases_from_list(_dtypes_for_ufunc(name)))
|
||||
def testUfuncInputTypes(self, name, arg_dtypes):
|
||||
# TODO(jakevdp): fix following failures and remove from this exception list.
|
||||
if (name in ['gcd', 'left_shift', 'power', 'remainder', 'right_shift', 'rint', 'square']
|
||||
and 'bool_' in arg_dtypes):
|
||||
self.skipTest(f"jax.numpy does not support {name}{tuple(arg_dtypes)}")
|
||||
if name == 'arctanh' and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating):
|
||||
self.skipTest("np.arctanh & jnp.arctanh have mismatched NaNs for complex input.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user