mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make complex_arr.astype(bool) follow NumPy's semantics
This commit is contained in:
parent
f1ae6232e9
commit
e07325a672
@ -13,6 +13,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
|
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
|
||||||
the old behavior by transforming the arguments via
|
the old behavior by transforming the arguments via
|
||||||
`jax.tree.map(np.asarray, args)` before passing them to the callback.
|
`jax.tree.map(np.asarray, args)` before passing them to the callback.
|
||||||
|
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
|
||||||
|
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
|
||||||
|
|
||||||
* Deprecations & Removals
|
* Deprecations & Removals
|
||||||
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
|
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
|
||||||
|
@ -2263,11 +2263,16 @@ In particular, the details of float-to-int and int-to-float casts are
|
|||||||
implementation dependent.
|
implementation dependent.
|
||||||
""")
|
""")
|
||||||
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
|
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
|
||||||
|
util.check_arraylike("astype", x)
|
||||||
|
x_arr = asarray(x)
|
||||||
del copy # unused in JAX
|
del copy # unused in JAX
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = dtypes.canonicalize_dtype(float_)
|
dtype = dtypes.canonicalize_dtype(float_)
|
||||||
dtypes.check_user_dtype_supported(dtype, "astype")
|
dtypes.check_user_dtype_supported(dtype, "astype")
|
||||||
return lax.convert_element_type(x, dtype)
|
# convert_element_type(complex, bool) has the wrong semantics.
|
||||||
|
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
|
||||||
|
return (x_arr != _lax_const(x_arr, 0))
|
||||||
|
return lax.convert_element_type(x_arr, dtype)
|
||||||
|
|
||||||
|
|
||||||
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
|
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
|
||||||
|
@ -3822,6 +3822,24 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||||
self._CompileAndCheck(jnp_op, args_maker)
|
self._CompileAndCheck(jnp_op, args_maker)
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
from_dtype=['int32', 'float32', 'complex64'],
|
||||||
|
use_method=[True, False],
|
||||||
|
)
|
||||||
|
def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
|
||||||
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
|
args_maker = lambda: [rng((3, 4), from_dtype)]
|
||||||
|
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
|
||||||
|
np_op = lambda x: np.astype(x, to_dtype)
|
||||||
|
else:
|
||||||
|
np_op = lambda x: np.asarray(x).astype(to_dtype)
|
||||||
|
if use_method:
|
||||||
|
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
|
||||||
|
else:
|
||||||
|
jnp_op = lambda x: jnp.astype(x, to_dtype)
|
||||||
|
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||||
|
self._CompileAndCheck(jnp_op, args_maker)
|
||||||
|
|
||||||
def testAstypeInt4(self):
|
def testAstypeInt4(self):
|
||||||
# Test converting from int4 to int8
|
# Test converting from int4 to int8
|
||||||
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
|
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user