Merge pull request #22664 from jakevdp:astype-device

PiperOrigin-RevId: 656016734
This commit is contained in:
jax authors 2024-07-25 11:11:49 -07:00
commit f17d0f382a

View File

@ -3479,15 +3479,44 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
deprecations.register("jax-numpy-astype-complex-to-real")
@util.implements(getattr(np, "astype", None), lax_description="""
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None,
/, *, copy: bool = False,
device: xc.Device | Sharding | None = None) -> Array:
"""Convert an array to a specified dtype.
JAX imlementation of :func:`numpy.astype`.
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
Args:
x: input array to convert
dtype: output dtype
copy: if True, then always return a copy. If False (default) then only
return a copy if necessary.
device: optionally specify the device to which the output will be committed.
Returns:
An array with the same shape as ``x``, containing values of the specified
dtype.
See Also:
- :func:`jax.lax.convert_element_type`: lower-level function for XLA-style
dtype conversions.
Examples:
>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0.0, 1.0, 2.0, 3.0], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int) # truncates fractional values
Array([0, 0, 1], dtype=int32)
"""
util.check_arraylike("astype", x)
x_arr = asarray(x)
@ -3510,17 +3539,9 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
# to issue our warning.
with warnings.catch_warnings():
warnings.simplefilter("ignore", ComplexWarning)
return _place_array(
lax.convert_element_type(x_arr, dtype),
device=device, copy=copy,
)
def _place_array(x, device=None, copy=None):
# TODO(micky774): Implement in future PRs as we formalize device placement
# semantics
if copy:
return _array_copy(x)
return x
result = lax_internal._convert_element_type(
x_arr, dtype, sharding=_normalize_to_sharding(device))
return _array_copy(result) if copy else result
@util.implements(np.asarray, lax_description=_ARRAY_DOC)