mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #22664 from jakevdp:astype-device
PiperOrigin-RevId: 656016734
This commit is contained in:
commit
f17d0f382a
@ -3479,15 +3479,44 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
|
||||
|
||||
deprecations.register("jax-numpy-astype-complex-to-real")
|
||||
|
||||
@util.implements(getattr(np, "astype", None), lax_description="""
|
||||
This is implemented via :func:`jax.lax.convert_element_type`, which may
|
||||
have slightly different behavior than :func:`numpy.astype` in some cases.
|
||||
In particular, the details of float-to-int and int-to-float casts are
|
||||
implementation dependent.
|
||||
""")
|
||||
def astype(x: ArrayLike, dtype: DTypeLike | None,
|
||||
/, *, copy: bool = False,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Convert an array to a specified dtype.
|
||||
|
||||
JAX imlementation of :func:`numpy.astype`.
|
||||
|
||||
This is implemented via :func:`jax.lax.convert_element_type`, which may
|
||||
have slightly different behavior than :func:`numpy.astype` in some cases.
|
||||
In particular, the details of float-to-int and int-to-float casts are
|
||||
implementation dependent.
|
||||
|
||||
Args:
|
||||
x: input array to convert
|
||||
dtype: output dtype
|
||||
copy: if True, then always return a copy. If False (default) then only
|
||||
return a copy if necessary.
|
||||
device: optionally specify the device to which the output will be committed.
|
||||
|
||||
Returns:
|
||||
An array with the same shape as ``x``, containing values of the specified
|
||||
dtype.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.lax.convert_element_type`: lower-level function for XLA-style
|
||||
dtype conversions.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([0, 1, 2, 3])
|
||||
>>> x
|
||||
Array([0, 1, 2, 3], dtype=int32)
|
||||
>>> x.astype('float32')
|
||||
Array([0.0, 1.0, 2.0, 3.0], dtype=float32)
|
||||
|
||||
>>> y = jnp.array([0.0, 0.5, 1.0])
|
||||
>>> y.astype(int) # truncates fractional values
|
||||
Array([0, 0, 1], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("astype", x)
|
||||
x_arr = asarray(x)
|
||||
|
||||
@ -3510,17 +3539,9 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
|
||||
# to issue our warning.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", ComplexWarning)
|
||||
return _place_array(
|
||||
lax.convert_element_type(x_arr, dtype),
|
||||
device=device, copy=copy,
|
||||
)
|
||||
|
||||
def _place_array(x, device=None, copy=None):
|
||||
# TODO(micky774): Implement in future PRs as we formalize device placement
|
||||
# semantics
|
||||
if copy:
|
||||
return _array_copy(x)
|
||||
return x
|
||||
result = lax_internal._convert_element_type(
|
||||
x_arr, dtype, sharding=_normalize_to_sharding(device))
|
||||
return _array_copy(result) if copy else result
|
||||
|
||||
|
||||
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
|
||||
|
Loading…
x
Reference in New Issue
Block a user