From 3f0619599499fc0751cd6181c04d50245ef5dcce Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 12 Aug 2022 09:51:25 -0700 Subject: [PATCH] jax.numpy: improve support for boolean inputs --- jax/_src/dtypes.py | 5 +++++ jax/_src/numpy/lax_numpy.py | 11 ++++++----- jax/_src/numpy/ufuncs.py | 26 ++++++++++++++++---------- jax/_src/numpy/util.py | 17 +++++++++++++++++ tests/lax_numpy_test.py | 4 ---- 5 files changed, 44 insertions(+), 19 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index d1c3c528d..2488a9a38 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -69,6 +69,11 @@ _dtype_to_inexact = { ] } +def to_numeric_dtype(dtype): + """Promotes a dtype into an numeric dtype, if it is not already one.""" + dtype = np.dtype(dtype) + return np.dtype('int32') if dtype == np.dtype('bool') else dtype + def _to_inexact_dtype(dtype): """Promotes a dtype into an inexact dtype, if it is not already one.""" diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 85f78c561..670eae9f3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -73,8 +73,8 @@ from jax._src.numpy.ufuncs import ( # noqa: F401 from jax._src.numpy.util import ( # noqa: F401 _arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike, _complex_elem_type, _promote_args, _promote_args_inexact, _promote_dtypes, - _promote_dtypes_inexact, _promote_shapes, _register_stackable, _stackable, - _where, _wraps) + _promote_dtypes_numeric, _promote_dtypes_inexact, _promote_shapes, + _register_stackable, _stackable, _where, _wraps) from jax._src.numpy.vectorize import vectorize from jax._src.ops import scatter from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio, @@ -4100,10 +4100,9 @@ def _gcd_body_fn(xs): @jit def gcd(x1, x2): _check_arraylike("gcd", x1, x2) - if (not issubdtype(_dtype(x1), integer) or - not issubdtype(_dtype(x2), integer)): - raise ValueError("Arguments to jax.numpy.gcd must be integers.") x1, x2 = _promote_dtypes(x1, x2) + if not issubdtype(_dtype(x1), integer): + raise ValueError("Arguments to jax.numpy.gcd must be integers.") x1, x2 = broadcast_arrays(x1, x2) gcd, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, (abs(x1), abs(x2))) return gcd @@ -4114,6 +4113,8 @@ def gcd(x1, x2): def lcm(x1, x2): _check_arraylike("lcm", x1, x2) x1, x2 = _promote_dtypes(x1, x2) + if not issubdtype(_dtype(x1), integer): + raise ValueError("Arguments to jax.numpy.lcm must be integers.") d = gcd(x1, x2) return where(d == 0, _lax_const(d, 0), abs(multiply(x1, floor_divide(x2, d)))) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index bd4a1ca45..24c8f3796 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -28,7 +28,8 @@ from jax._src import dtypes from jax._src.lax import lax as lax_internal from jax._src.numpy.util import ( _check_arraylike, _promote_args, _promote_args_inexact, - _promote_dtypes_inexact, _promote_shapes, _where, _wraps) + _promote_args_numeric, _promote_dtypes_inexact, _promote_dtypes_numeric, + _promote_shapes, _where, _wraps) from jax import core from jax import lax @@ -62,9 +63,11 @@ def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): return _wraps(numpy_fn, module='numpy')(fn) -def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): +def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False, promote_to_numeric=False): if promote_to_inexact: fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2)) + elif promote_to_numeric: + fn = lambda x1, x2: lax_fn(*_promote_args_numeric(numpy_fn.__name__, x1, x2)) else: fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2)) fn = jit(fn, inline=True) @@ -143,7 +146,7 @@ add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) -left_shift = _one_to_one_binop(np.left_shift, lax.shift_left) +left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True) equal = _one_to_one_binop(np.equal, lax.eq) multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and) not_equal = _one_to_one_binop(np.not_equal, lax.ne) @@ -179,7 +182,7 @@ def arccosh(x): @_wraps(np.right_shift, module='numpy') @partial(jit, inline=True) def right_shift(x1, x2): - x1, x2 = _promote_args(np.right_shift.__name__, x1, x2) + x1, x2 = _promote_args_numeric(np.right_shift.__name__, x1, x2) lax_fn = lax.shift_right_logical if \ np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) @@ -199,7 +202,7 @@ abs = _wraps(np.abs, module='numpy')(absolute) def rint(x): _check_arraylike('rint', x) dtype = dtypes.dtype(x) - if dtypes.issubdtype(dtype, np.integer): + if dtype == bool or dtypes.issubdtype(x.dtype, np.integer): return lax.convert_element_type(x, dtypes.float_) if dtypes.issubdtype(dtype, np.complexfloating): return lax.complex(rint(lax.real(x)), rint(lax.imag(x))) @@ -239,7 +242,7 @@ divide = true_divide @_wraps(np.floor_divide, module='numpy') @jit def floor_divide(x1, x2): - x1, x2 = _promote_args("floor_divide", x1, x2) + x1, x2 = _promote_args_numeric("floor_divide", x1, x2) dtype = dtypes.dtype(x1) if dtypes.issubdtype(dtype, np.integer): quotient = lax.div(x1, x2) @@ -264,7 +267,7 @@ def floor_divide(x1, x2): @_wraps(np.divmod, module='numpy') @jit def divmod(x1, x2): - x1, x2 = _promote_args("divmod", x1, x2) + x1, x2 = _promote_args_numeric("divmod", x1, x2) if dtypes.issubdtype(dtypes.dtype(x1), np.integer): return floor_divide(x1, x2), remainder(x1, x2) else: @@ -285,7 +288,7 @@ def _float_divmod(x1, x2): @partial(jit, inline=True) def _power(x1, x2): - x1, x2 = _promote_args("power", x1, x2) + x1, x2 = _promote_args_numeric("power", x1, x2) dtype = dtypes.dtype(x1) if not dtypes.issubdtype(dtype, np.integer): return lax.pow(x1, x2) @@ -307,6 +310,7 @@ def _power(x1, x2): @_wraps(np.power, module='numpy') def power(x1, x2): + _check_arraylike("power", x1, x2) # Special case for concrete integer scalars: use binary exponentiation. # Using lax.pow may be imprecise for floating-point values; the goal of this # code path is to make sure we end up with a precise output for the common @@ -317,6 +321,7 @@ def power(x1, x2): except TypeError: pass else: + x1, = _promote_dtypes_numeric(x1) return lax.integer_pow(x1, x2) return _power(x1, x2) @@ -522,7 +527,7 @@ def frexp(x): @_wraps(np.remainder, module='numpy') @jit def remainder(x1, x2): - x1, x2 = _promote_args("remainder", x1, x2) + x1, x2 = _promote_args_numeric("remainder", x1, x2) zero = _constant_like(x1, 0) trunc_mod = lax.rem(x1, x2) trunc_mod_not_zero = lax.ne(trunc_mod, zero) @@ -538,13 +543,14 @@ def fmod(x1, x2): _check_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax_internal._ones(x2), x2) - return lax.rem(*_promote_args("fmod", x1, x2)) + return lax.rem(*_promote_args_numeric("fmod", x1, x2)) @_wraps(np.square, module='numpy') @partial(jit, inline=True) def square(x): _check_arraylike("square", x) + x, = _promote_dtypes_numeric(x) return lax.integer_pow(x, 2) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e6b395bd6..e6945cc41 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -287,6 +287,17 @@ def _promote_dtypes_inexact(*args): for x in args] +def _promote_dtypes_numeric(*args): + """Convenience function to apply Numpy argument dtype promotion. + + Promotes arguments to a numeric (non-bool) type.""" + to_dtype, weak_type = dtypes._lattice_result_type(*args) + to_dtype = dtypes.canonicalize_dtype(to_dtype) + to_dtype_numeric = dtypes.to_numeric_dtype(to_dtype) + return [lax_internal._convert_element_type(x, to_dtype_numeric, weak_type) + for x in args] + + def _promote_dtypes_complex(*args): """Convenience function to apply Numpy argument dtype promotion. @@ -344,6 +355,12 @@ def _promote_args(fun_name, *args): return _promote_shapes(fun_name, *_promote_dtypes(*args)) +def _promote_args_numeric(fun_name, *args): + _check_arraylike(fun_name, *args) + _check_no_float0s(fun_name, *args) + return _promote_shapes(fun_name, *_promote_dtypes_numeric(*args)) + + def _promote_args_inexact(fun_name, *args): """Convenience function to apply Numpy argument shape and dtype promotion. diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a469fd774..337c98b15 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6497,10 +6497,6 @@ class NumpyUfuncTests(jtu.JaxTestCase): for name in _all_numpy_ufuncs() for arg_dtypes in jtu.cases_from_list(_dtypes_for_ufunc(name))) def testUfuncInputTypes(self, name, arg_dtypes): - # TODO(jakevdp): fix following failures and remove from this exception list. - if (name in ['gcd', 'left_shift', 'power', 'remainder', 'right_shift', 'rint', 'square'] - and 'bool_' in arg_dtypes): - self.skipTest(f"jax.numpy does not support {name}{tuple(arg_dtypes)}") if name == 'arctanh' and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating): self.skipTest("np.arctanh & jnp.arctanh have mismatched NaNs for complex input.")