mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add support for copy kwarg in astype to match Array API
This commit is contained in:
parent
667a0c1fe5
commit
30cd3b88fd
@ -50,6 +50,12 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
|
||||
related functions now raise an error, following a similar change in NumPy.
|
||||
|
||||
* Bug fixes
|
||||
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
|
||||
Previously, no copy would be made when the output array would have the same
|
||||
dtype as the input array. This may result in some increased memory usage.
|
||||
The default value is set to `copy=False` to preserve backwards compatability.
|
||||
|
||||
## jaxlib 0.4.27
|
||||
|
||||
## jax 0.4.26 (April 3, 2024)
|
||||
|
@ -31,11 +31,13 @@ from typing import Any
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.sharding import Sharding
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.api_util import _ensure_index_tuple
|
||||
from jax._src.array import ArrayImpl
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
@ -55,7 +57,7 @@ zip, unsafe_zip = safe_zip, zip
|
||||
# functions, which can themselves handle instances from any of these classes.
|
||||
|
||||
|
||||
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
|
||||
def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Copy the array and cast to a specified dtype.
|
||||
|
||||
This is implemented via :func:`jax.lax.convert_element_type`, which may
|
||||
@ -63,7 +65,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
|
||||
some cases. In particular, the details of float-to-int and int-to-float
|
||||
casts are implementation dependent.
|
||||
"""
|
||||
return lax_numpy.astype(arr, dtype)
|
||||
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
|
||||
|
||||
|
||||
def _nbytes(arr: ArrayLike) -> int:
|
||||
|
@ -2272,17 +2272,42 @@ have slightly different behavior than :func:`numpy.astype` in some cases.
|
||||
In particular, the details of float-to-int and int-to-float casts are
|
||||
implementation dependent.
|
||||
""")
|
||||
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
|
||||
def astype(x: ArrayLike, dtype: DTypeLike | None,
|
||||
/, *, copy: bool = False,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
util.check_arraylike("astype", x)
|
||||
x_arr = asarray(x)
|
||||
del copy # unused in JAX
|
||||
|
||||
if dtype is None:
|
||||
dtype = dtypes.canonicalize_dtype(float_)
|
||||
dtypes.check_user_dtype_supported(dtype, "astype")
|
||||
# convert_element_type(complex, bool) has the wrong semantics.
|
||||
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
|
||||
return (x_arr != _lax_const(x_arr, 0))
|
||||
return lax.convert_element_type(x_arr, dtype)
|
||||
if issubdtype(x_arr.dtype, complexfloating):
|
||||
if dtypes.isdtype(dtype, ("integral", "real floating")):
|
||||
warnings.warn(
|
||||
"Casting from complex to real dtypes will soon raise a ValueError. "
|
||||
"Please first use jnp.real or jnp.imag to take the real/imaginary "
|
||||
"component of your input.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
)
|
||||
elif np.dtype(dtype) == bool:
|
||||
# convert_element_type(complex, bool) has the wrong semantics.
|
||||
x_arr = (x_arr != _lax_const(x_arr, 0))
|
||||
|
||||
# We offer a more specific warning than the usual ComplexWarning so we prefer
|
||||
# to issue our warning.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", ComplexWarning)
|
||||
return _place_array(
|
||||
lax.convert_element_type(x_arr, dtype),
|
||||
device=device, copy=copy,
|
||||
)
|
||||
|
||||
def _place_array(x, device=None, copy=None):
|
||||
# TODO(micky774): Implement in future PRs as we formalize device placement
|
||||
# semantics
|
||||
if copy:
|
||||
return _array_copy(x)
|
||||
return x
|
||||
|
||||
|
||||
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import functools
|
||||
from typing import NamedTuple
|
||||
@ -19,6 +21,9 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax.experimental.array_api._dtypes import (
|
||||
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
|
||||
float32, float64, complex64, complex128
|
||||
@ -124,8 +129,19 @@ def _promote_types(t1, t2):
|
||||
raise ValueError("No promotion path for {t1} & {t2}")
|
||||
|
||||
|
||||
def astype(x, dtype, /, *, copy=True):
|
||||
return jnp.array(x, dtype=dtype, copy=copy)
|
||||
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
|
||||
src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
|
||||
if (
|
||||
src_dtype is not None
|
||||
and _dtypes.isdtype(src_dtype, "complex floating")
|
||||
and _dtypes.isdtype(dtype, ("integral", "real floating"))
|
||||
):
|
||||
raise ValueError(
|
||||
"Casting from complex to non-complex dtypes is not permitted. Please "
|
||||
"first use jnp.real or jnp.imag to take the real/imaginary component of "
|
||||
"your input."
|
||||
)
|
||||
return jnp.astype(x, dtype, copy=copy, device=device)
|
||||
|
||||
|
||||
def can_cast(from_, to, /):
|
||||
|
@ -115,7 +115,7 @@ def asarray(
|
||||
) -> Array: ...
|
||||
def asin(x: ArrayLike, /) -> Array: ...
|
||||
def asinh(x: ArrayLike, /) -> Array: ...
|
||||
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ...
|
||||
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: _Device | _Sharding | None = ...) -> Array: ...
|
||||
def atan(x: ArrayLike, /) -> Array: ...
|
||||
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def atanh(x: ArrayLike, /) -> Array: ...
|
||||
|
@ -776,8 +776,13 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
for axis in list(
|
||||
range(-len(shape), len(shape))
|
||||
) + ([None] if len(shape) == 1 else [])],
|
||||
dtype=all_dtypes + [None],
|
||||
out_dtype=all_dtypes,
|
||||
[dict(dtype=dtype, out_dtype=out_dtype)
|
||||
for dtype in (all_dtypes+[None])
|
||||
for out_dtype in (
|
||||
complex_dtypes if np.issubdtype(dtype, np.complexfloating)
|
||||
else all_dtypes
|
||||
)
|
||||
],
|
||||
include_initial=[False, True],
|
||||
)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
|
@ -3870,6 +3870,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
change_dtype=[True, False],
|
||||
copy=[True, False],
|
||||
)
|
||||
def testAstypeCopy(self, change_dtype, copy):
|
||||
dtype = 'float32' if change_dtype else 'int32'
|
||||
expect_copy = change_dtype or copy
|
||||
x = jnp.arange(5, dtype='int32')
|
||||
y = x.astype(dtype, copy=copy)
|
||||
|
||||
self.assertEqual(y.dtype, dtype)
|
||||
y.delete()
|
||||
self.assertNotEqual(x.is_deleted(), expect_copy)
|
||||
|
||||
def testAstypeComplexDowncast(self):
|
||||
x = jnp.array(2.0+1.5j, dtype='complex64')
|
||||
msg = "Casting from complex to non-complex dtypes will soon raise "
|
||||
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||
x.astype('float32')
|
||||
|
||||
def testAstypeInt4(self):
|
||||
# Test converting from int4 to int8
|
||||
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user