mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Added device kwargs to jnp.linspace, jnp.array, jnp.asarray
This commit is contained in:
parent
f17d0f382a
commit
76d61f9d8f
@ -508,7 +508,7 @@ def convert_element_type(operand: ArrayLike,
|
||||
Similar to a C++ `static_cast`.
|
||||
|
||||
Args:
|
||||
operand: an array or scalar value to be cast
|
||||
operand: an array or scalar value to be cast.
|
||||
new_dtype: a NumPy dtype representing the target type.
|
||||
|
||||
Returns:
|
||||
|
@ -3351,7 +3351,10 @@ https://jax.readthedocs.io/en/latest/faq.html).
|
||||
|
||||
deprecations.register("jax-numpy-array-none")
|
||||
|
||||
@util.implements(np.array, lax_description=_ARRAY_DOC)
|
||||
@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params="""
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
""")
|
||||
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
order: str | None = "K", ndmin: int = 0,
|
||||
*, device: xc.Device | Sharding | None = None) -> Array:
|
||||
@ -3453,7 +3456,6 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
out = np.array(object) if copy else np.asarray(object)
|
||||
else:
|
||||
raise TypeError(f"Unexpected input type for array: {type(object)}")
|
||||
|
||||
out_array: Array = lax_internal._convert_element_type(
|
||||
out, dtype, weak_type=weak_type, sharding=sharding)
|
||||
if ndmin > ndim(out_array):
|
||||
@ -3544,9 +3546,13 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
|
||||
return _array_copy(result) if copy else result
|
||||
|
||||
|
||||
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
|
||||
@util.implements(np.asarray, lax_description=_ARRAY_DOC, extra_params="""
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
""")
|
||||
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
|
||||
*, copy: bool | None = None) -> Array:
|
||||
*, copy: bool | None = None,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
# For copy=False, the array API specifies that we raise a ValueError if the input supports
|
||||
# the buffer protocol but a copy is required. Since array() supports the buffer protocol
|
||||
# via numpy, this is only the case when the default device is not 'cpu'
|
||||
@ -3559,7 +3565,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
|
||||
dtypes.check_user_dtype_supported(dtype, "asarray")
|
||||
if dtype is not None:
|
||||
dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
|
||||
return array(a, dtype=dtype, copy=bool(copy), order=order)
|
||||
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
|
||||
|
||||
|
||||
@util.implements(np.copy, lax_description=_ARRAY_DOC)
|
||||
@ -4329,36 +4335,45 @@ def _arange_dynamic(
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: bool = True, retstep: Literal[False] = False,
|
||||
dtype: DTypeLike | None = None,
|
||||
axis: int = 0) -> Array: ...
|
||||
axis: int = 0,
|
||||
*, device: xc.Device | Sharding | None = None) -> Array: ...
|
||||
@overload
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
|
||||
endpoint: bool, retstep: Literal[True],
|
||||
dtype: DTypeLike | None = None,
|
||||
axis: int = 0) -> tuple[Array, Array]: ...
|
||||
axis: int = 0,
|
||||
*, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
|
||||
@overload
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: bool = True, *, retstep: Literal[True],
|
||||
dtype: DTypeLike | None = None,
|
||||
axis: int = 0) -> tuple[Array, Array]: ...
|
||||
axis: int = 0,
|
||||
device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
|
||||
@overload
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: bool = True, retstep: bool = False,
|
||||
dtype: DTypeLike | None = None,
|
||||
axis: int = 0) -> Array | tuple[Array, Array]: ...
|
||||
@util.implements(np.linspace)
|
||||
axis: int = 0,
|
||||
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ...
|
||||
@util.implements(np.linspace, extra_params="""
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
""")
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: bool = True, retstep: bool = False,
|
||||
dtype: DTypeLike | None = None,
|
||||
axis: int = 0) -> Array | tuple[Array, Array]:
|
||||
axis: int = 0,
|
||||
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
|
||||
num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace")
|
||||
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
|
||||
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)
|
||||
return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device)
|
||||
|
||||
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
|
||||
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device'))
|
||||
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: bool = True, retstep: bool = False,
|
||||
dtype: DTypeLike | None = None,
|
||||
axis: int = 0) -> Array | tuple[Array, Array]:
|
||||
axis: int = 0,
|
||||
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
|
||||
"""Implementation of linspace differentiable in start and stop args."""
|
||||
dtypes.check_user_dtype_supported(dtype, "linspace")
|
||||
if num < 0:
|
||||
@ -4406,10 +4421,9 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
if issubdtype(dtype, integer) and not issubdtype(out.dtype, integer):
|
||||
out = lax.floor(out)
|
||||
|
||||
if retstep:
|
||||
return lax.convert_element_type(out, dtype), delta
|
||||
else:
|
||||
return lax.convert_element_type(out, dtype)
|
||||
sharding = canonicalize_device_to_sharding(device)
|
||||
result = lax_internal._convert_element_type(out, dtype, sharding=sharding)
|
||||
return (result, delta) if retstep else result
|
||||
|
||||
|
||||
@util.implements(np.logspace)
|
||||
|
@ -52,6 +52,7 @@ from jax.numpy import (
|
||||
argmax as argmax,
|
||||
argmin as argmin,
|
||||
argsort as argsort,
|
||||
asarray as asarray,
|
||||
asin as asin,
|
||||
asinh as asinh,
|
||||
atan as atan,
|
||||
@ -109,6 +110,7 @@ from jax.numpy import (
|
||||
isnan as isnan,
|
||||
less as less,
|
||||
less_equal as less_equal,
|
||||
linspace as linspace,
|
||||
log as log,
|
||||
log10 as log10,
|
||||
log1p as log1p,
|
||||
@ -187,11 +189,6 @@ from jax.experimental.array_api._manipulation_functions import (
|
||||
reshape as reshape,
|
||||
)
|
||||
|
||||
from jax.experimental.array_api._creation_functions import (
|
||||
asarray as asarray,
|
||||
linspace as linspace,
|
||||
)
|
||||
|
||||
from jax.experimental.array_api._data_type_functions import (
|
||||
astype as astype,
|
||||
)
|
||||
|
@ -1,25 +0,0 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def asarray(obj, /, *, dtype=None, device=None, copy=None):
|
||||
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)
|
||||
|
||||
def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
|
||||
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)
|
@ -114,7 +114,8 @@ def array_split(
|
||||
array_str = _np.array_str
|
||||
def asarray(
|
||||
a: Any, dtype: DTypeLike | None = ..., order: str | None = ...,
|
||||
*, copy: builtins.bool | None = ...
|
||||
*, copy: builtins.bool | None = ...,
|
||||
device: _Device | _Sharding | None = ...,
|
||||
) -> Array: ...
|
||||
def asin(x: ArrayLike, /) -> Array: ...
|
||||
def asinh(x: ArrayLike, /) -> Array: ...
|
||||
@ -523,22 +524,26 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ...
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: builtins.bool = True, retstep: Literal[False] = False,
|
||||
dtype: DTypeLike | None = ...,
|
||||
axis: int = 0) -> Array: ...
|
||||
axis: int = 0,
|
||||
*, device: _Device | _Sharding | None = ...) -> Array: ...
|
||||
@overload
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
|
||||
endpoint: builtins.bool, retstep: Literal[True],
|
||||
dtype: DTypeLike | None = ...,
|
||||
axis: int = 0) -> tuple[Array, Array]: ...
|
||||
axis: int = 0,
|
||||
*, device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
|
||||
@overload
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: builtins.bool = True, *, retstep: Literal[True],
|
||||
dtype: DTypeLike | None = ...,
|
||||
axis: int = 0) -> tuple[Array, Array]: ...
|
||||
axis: int = 0,
|
||||
device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
|
||||
@overload
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: builtins.bool = True, retstep: builtins.bool = False,
|
||||
dtype: DTypeLike | None = ...,
|
||||
axis: int = 0) -> Array | tuple[Array, Array]: ...
|
||||
axis: int = 0,
|
||||
*, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...
|
||||
|
||||
def load(*args: Any, **kwargs: Any) -> Array: ...
|
||||
def log(x: ArrayLike, /) -> Array: ...
|
||||
|
@ -2998,25 +2998,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
func=[
|
||||
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
|
||||
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
|
||||
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
|
||||
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
|
||||
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
|
||||
],
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
def testArangeEyeWithDevice(self, func, dtype):
|
||||
def testArangeEyeLinspaceArrayWithDevice(self, func, dtype):
|
||||
device = jax.devices()[-1]
|
||||
out = func(dtype=dtype, device=device)
|
||||
self.assertEqual(out.devices(), {device})
|
||||
output = func(dtype=dtype, device=device)
|
||||
if isinstance(output, tuple):
|
||||
self.assertEqual(output[0].devices(), {device})
|
||||
else:
|
||||
self.assertEqual(output.devices(), {device})
|
||||
|
||||
@jtu.sample_product(
|
||||
func=[
|
||||
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
|
||||
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
|
||||
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
|
||||
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
|
||||
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
|
||||
],
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
def testArangeEyeWithSharding(self, func, dtype):
|
||||
def testArangeEyeLinspaceArrayWithSharding(self, func, dtype):
|
||||
sharding = SingleDeviceSharding(jax.devices()[-1])
|
||||
out = func(dtype=dtype, device=sharding)
|
||||
self.assertEqual(out.sharding, sharding)
|
||||
output = func(dtype=dtype, device=sharding)
|
||||
if isinstance(output, tuple):
|
||||
self.assertEqual(output[0].sharding, sharding)
|
||||
else:
|
||||
self.assertEqual(output.sharding, sharding)
|
||||
|
||||
@jtu.sample_product(
|
||||
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
|
||||
@ -6066,7 +6078,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'histogram': ['normed'],
|
||||
'histogram2d': ['normed'],
|
||||
'histogramdd': ['normed'],
|
||||
'linspace': ['device'],
|
||||
'nanpercentile': ['weights'],
|
||||
'nanquantile': ['weights'],
|
||||
'nanstd': ['correction', 'mean'],
|
||||
|
Loading…
x
Reference in New Issue
Block a user