mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add jax.numpy.astype function
This commit is contained in:
parent
5b3fc1bd5d
commit
d77cd9a0f4
@ -78,6 +78,7 @@ namespace; they are listed below.
|
||||
array_split
|
||||
array_str
|
||||
asarray
|
||||
astype
|
||||
atleast_1d
|
||||
atleast_2d
|
||||
atleast_3d
|
||||
|
@ -61,10 +61,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
|
||||
some cases. In particular, the details of float-to-int and int-to-float
|
||||
casts are implementation dependent.
|
||||
"""
|
||||
if dtype is None:
|
||||
dtype = dtypes.canonicalize_dtype(lax_numpy.float_)
|
||||
dtypes.check_user_dtype_supported(dtype, "astype")
|
||||
return lax.convert_element_type(arr, dtype)
|
||||
return lax_numpy.astype(arr, dtype)
|
||||
|
||||
|
||||
def _nbytes(arr: ArrayLike) -> int:
|
||||
|
@ -2179,6 +2179,20 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
|
||||
return x
|
||||
|
||||
|
||||
@util._wraps(getattr(np, "astype", None), lax_description="""
|
||||
This is implemented via :func:`jax.lax.convert_element_type`, which may
|
||||
have slightly different behavior than :func:`numpy.astype` in some cases.
|
||||
In particular, the details of float-to-int and int-to-float casts are
|
||||
implementation dependent.
|
||||
""")
|
||||
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
|
||||
del copy # unused in JAX
|
||||
if dtype is None:
|
||||
dtype = dtypes.canonicalize_dtype(float_)
|
||||
dtypes.check_user_dtype_supported(dtype, "astype")
|
||||
return lax.convert_element_type(x, dtype)
|
||||
|
||||
|
||||
@util._wraps(np.asarray, lax_description=_ARRAY_DOC)
|
||||
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None) -> Array:
|
||||
dtypes.check_user_dtype_supported(dtype, "asarray")
|
||||
|
@ -156,6 +156,8 @@ def _wraps(
|
||||
op.__np_wrapped__ = fun
|
||||
# Allows this pattern: @wraps(getattr(np, 'new_function', None))
|
||||
if fun is None:
|
||||
if lax_description:
|
||||
op.__doc__ = lax_description
|
||||
return op
|
||||
docstr = getattr(fun, "__doc__", None)
|
||||
name = getattr(fun, "__name__", getattr(op, "__name__", str(op)))
|
||||
|
@ -39,6 +39,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
array_repr as array_repr,
|
||||
array_split as array_split,
|
||||
array_str as array_str,
|
||||
astype as astype,
|
||||
asarray as asarray,
|
||||
atleast_1d as atleast_1d,
|
||||
atleast_2d as atleast_2d,
|
||||
|
@ -103,6 +103,7 @@ array_str = _np.array_str
|
||||
def asarray(
|
||||
a: Any, dtype: Optional[DTypeLike] = ..., order: Optional[str] = ...
|
||||
) -> Array: ...
|
||||
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: bool = ...) -> Array: ...
|
||||
|
||||
@overload
|
||||
def atleast_1d() -> list[Array]: ...
|
||||
|
@ -3531,19 +3531,22 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testAstype(self):
|
||||
@jtu.sample_product(
|
||||
from_dtype=['int32', 'float32'],
|
||||
to_dtype=['int32', 'float32', None],
|
||||
use_method=[True, False],
|
||||
)
|
||||
def testAstype(self, from_dtype, to_dtype, use_method):
|
||||
rng = self.rng()
|
||||
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
|
||||
np_op = lambda x: np.asarray(x).astype(jnp.int32)
|
||||
jnp_op = lambda x: jnp.asarray(x).astype(jnp.int32)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
def testAstypeNone(self):
|
||||
rng = self.rng()
|
||||
args_maker = lambda: [rng.randn(3, 4).astype("int32")]
|
||||
np_op = jtu.with_jax_dtype_defaults(lambda x: np.asarray(x).astype(None))
|
||||
jnp_op = lambda x: jnp.asarray(x).astype(None)
|
||||
args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)]
|
||||
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
|
||||
np_op = lambda x: np.astype(x, to_dtype)
|
||||
else:
|
||||
np_op = lambda x: np.asarray(x).astype(to_dtype)
|
||||
if use_method:
|
||||
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
|
||||
else:
|
||||
jnp_op = lambda x: jnp.astype(x, to_dtype)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user