mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Make complex_arr.astype(bool) follow NumPy's semantics
This commit is contained in:
parent
f1ae6232e9
commit
e07325a672
@ -13,6 +13,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
|
||||
the old behavior by transforming the arguments via
|
||||
`jax.tree.map(np.asarray, args)` before passing them to the callback.
|
||||
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
|
||||
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
|
||||
|
||||
* Deprecations & Removals
|
||||
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
|
||||
|
@ -2263,11 +2263,16 @@ In particular, the details of float-to-int and int-to-float casts are
|
||||
implementation dependent.
|
||||
""")
|
||||
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
|
||||
util.check_arraylike("astype", x)
|
||||
x_arr = asarray(x)
|
||||
del copy # unused in JAX
|
||||
if dtype is None:
|
||||
dtype = dtypes.canonicalize_dtype(float_)
|
||||
dtypes.check_user_dtype_supported(dtype, "astype")
|
||||
return lax.convert_element_type(x, dtype)
|
||||
# convert_element_type(complex, bool) has the wrong semantics.
|
||||
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
|
||||
return (x_arr != _lax_const(x_arr, 0))
|
||||
return lax.convert_element_type(x_arr, dtype)
|
||||
|
||||
|
||||
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
|
||||
|
@ -3822,6 +3822,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
from_dtype=['int32', 'float32', 'complex64'],
|
||||
use_method=[True, False],
|
||||
)
|
||||
def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng((3, 4), from_dtype)]
|
||||
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
|
||||
np_op = lambda x: np.astype(x, to_dtype)
|
||||
else:
|
||||
np_op = lambda x: np.asarray(x).astype(to_dtype)
|
||||
if use_method:
|
||||
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
|
||||
else:
|
||||
jnp_op = lambda x: jnp.astype(x, to_dtype)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
def testAstypeInt4(self):
|
||||
# Test converting from int4 to int8
|
||||
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user