Add jax.numpy.astype function

This commit is contained in:
Jake VanderPlas 2023-11-30 15:50:22 -08:00
parent 5b3fc1bd5d
commit d77cd9a0f4
7 changed files with 35 additions and 16 deletions

View File

@ -78,6 +78,7 @@ namespace; they are listed below.
array_split
array_str
asarray
astype
atleast_1d
atleast_2d
atleast_3d

View File

@ -61,10 +61,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
some cases. In particular, the details of float-to-int and int-to-float
casts are implementation dependent.
"""
if dtype is None:
dtype = dtypes.canonicalize_dtype(lax_numpy.float_)
dtypes.check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(arr, dtype)
return lax_numpy.astype(arr, dtype)
def _nbytes(arr: ArrayLike) -> int:

View File

@ -2179,6 +2179,20 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
return x
@util._wraps(getattr(np, "astype", None), lax_description="""
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
del copy # unused in JAX
if dtype is None:
dtype = dtypes.canonicalize_dtype(float_)
dtypes.check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(x, dtype)
@util._wraps(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "asarray")

View File

@ -156,6 +156,8 @@ def _wraps(
op.__np_wrapped__ = fun
# Allows this pattern: @wraps(getattr(np, 'new_function', None))
if fun is None:
if lax_description:
op.__doc__ = lax_description
return op
docstr = getattr(fun, "__doc__", None)
name = getattr(fun, "__name__", getattr(op, "__name__", str(op)))

View File

@ -39,6 +39,7 @@ from jax._src.numpy.lax_numpy import (
array_repr as array_repr,
array_split as array_split,
array_str as array_str,
astype as astype,
asarray as asarray,
atleast_1d as atleast_1d,
atleast_2d as atleast_2d,

View File

@ -103,6 +103,7 @@ array_str = _np.array_str
def asarray(
a: Any, dtype: Optional[DTypeLike] = ..., order: Optional[str] = ...
) -> Array: ...
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: bool = ...) -> Array: ...
@overload
def atleast_1d() -> list[Array]: ...

View File

@ -3531,19 +3531,22 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testAstype(self):
@jtu.sample_product(
from_dtype=['int32', 'float32'],
to_dtype=['int32', 'float32', None],
use_method=[True, False],
)
def testAstype(self, from_dtype, to_dtype, use_method):
rng = self.rng()
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
np_op = lambda x: np.asarray(x).astype(jnp.int32)
jnp_op = lambda x: jnp.asarray(x).astype(jnp.int32)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
def testAstypeNone(self):
rng = self.rng()
args_maker = lambda: [rng.randn(3, 4).astype("int32")]
np_op = jtu.with_jax_dtype_defaults(lambda x: np.asarray(x).astype(None))
jnp_op = lambda x: jnp.asarray(x).astype(None)
args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)]
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
np_op = lambda x: np.astype(x, to_dtype)
else:
np_op = lambda x: np.asarray(x).astype(to_dtype)
if use_method:
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
else:
jnp_op = lambda x: jnp.astype(x, to_dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)