From e07325a6729ebbe63382f8bf826d4791d6b60859 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 9 Apr 2024 16:15:59 -0700 Subject: [PATCH] Make complex_arr.astype(bool) follow NumPy's semantics --- CHANGELOG.md | 2 ++ jax/_src/numpy/lax_numpy.py | 7 ++++++- tests/lax_numpy_test.py | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8acd30c6..e9f6acf70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ Remember to align the itemized text with the first line of an item within a list now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover the old behavior by transforming the arguments via `jax.tree.map(np.asarray, args)` before passing them to the callback. + * `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning + False where `complex_arr` is equal to `0 + 0j`, and True otherwise. * Deprecations & Removals * Pallas now exclusively uses XLA for compiling kernels on GPU. The old diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3df6dcaf1..4b23ca210 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2263,11 +2263,16 @@ In particular, the details of float-to-int and int-to-float casts are implementation dependent. """) def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array: + util.check_arraylike("astype", x) + x_arr = asarray(x) del copy # unused in JAX if dtype is None: dtype = dtypes.canonicalize_dtype(float_) dtypes.check_user_dtype_supported(dtype, "astype") - return lax.convert_element_type(x, dtype) + # convert_element_type(complex, bool) has the wrong semantics. + if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating): + return (x_arr != _lax_const(x_arr, 0)) + return lax.convert_element_type(x_arr, dtype) @util.implements(np.asarray, lax_description=_ARRAY_DOC) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 96232fd87..c957ed669 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3822,6 +3822,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + from_dtype=['int32', 'float32', 'complex64'], + use_method=[True, False], + ) + def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng((3, 4), from_dtype)] + if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0 + np_op = lambda x: np.astype(x, to_dtype) + else: + np_op = lambda x: np.asarray(x).astype(to_dtype) + if use_method: + jnp_op = lambda x: jnp.asarray(x).astype(to_dtype) + else: + jnp_op = lambda x: jnp.astype(x, to_dtype) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + def testAstypeInt4(self): # Test converting from int4 to int8 x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)