Make complex_arr.astype(bool) follow NumPy's semantics

This commit is contained in:
Jake VanderPlas 2024-04-09 16:15:59 -07:00
parent f1ae6232e9
commit e07325a672
3 changed files with 26 additions and 1 deletions

View File

@ -13,6 +13,8 @@ Remember to align the itemized text with the first line of an item within a list
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
the old behavior by transforming the arguments via
`jax.tree.map(np.asarray, args)` before passing them to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old

View File

@ -2263,11 +2263,16 @@ In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
util.check_arraylike("astype", x)
x_arr = asarray(x)
del copy # unused in JAX
if dtype is None:
dtype = dtypes.canonicalize_dtype(float_)
dtypes.check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(x, dtype)
# convert_element_type(complex, bool) has the wrong semantics.
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
return (x_arr != _lax_const(x_arr, 0))
return lax.convert_element_type(x_arr, dtype)
@util.implements(np.asarray, lax_description=_ARRAY_DOC)

View File

@ -3822,6 +3822,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
from_dtype=['int32', 'float32', 'complex64'],
use_method=[True, False],
)
def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng((3, 4), from_dtype)]
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
np_op = lambda x: np.astype(x, to_dtype)
else:
np_op = lambda x: np.asarray(x).astype(to_dtype)
if use_method:
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
else:
jnp_op = lambda x: jnp.astype(x, to_dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testAstypeInt4(self):
# Test converting from int4 to int8
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)