Added device kwargs to jnp.linspace, jnp.array, jnp.asarray

This commit is contained in:
vfdev-5 2024-06-25 10:38:37 +02:00
parent f17d0f382a
commit 76d61f9d8f
6 changed files with 63 additions and 61 deletions

View File

@ -508,7 +508,7 @@ def convert_element_type(operand: ArrayLike,
Similar to a C++ `static_cast`.
Args:
operand: an array or scalar value to be cast
operand: an array or scalar value to be cast.
new_dtype: a NumPy dtype representing the target type.
Returns:

View File

@ -3351,7 +3351,10 @@ https://jax.readthedocs.io/en/latest/faq.html).
deprecations.register("jax-numpy-array-none")
@util.implements(np.array, lax_description=_ARRAY_DOC)
@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params="""
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array:
@ -3453,7 +3456,6 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
out = np.array(object) if copy else np.asarray(object)
else:
raise TypeError(f"Unexpected input type for array: {type(object)}")
out_array: Array = lax_internal._convert_element_type(
out, dtype, weak_type=weak_type, sharding=sharding)
if ndmin > ndim(out_array):
@ -3544,9 +3546,13 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
return _array_copy(result) if copy else result
@util.implements(np.asarray, lax_description=_ARRAY_DOC)
@util.implements(np.asarray, lax_description=_ARRAY_DOC, extra_params="""
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
*, copy: bool | None = None) -> Array:
*, copy: bool | None = None,
device: xc.Device | Sharding | None = None) -> Array:
# For copy=False, the array API specifies that we raise a ValueError if the input supports
# the buffer protocol but a copy is required. Since array() supports the buffer protocol
# via numpy, this is only the case when the default device is not 'cpu'
@ -3559,7 +3565,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
dtypes.check_user_dtype_supported(dtype, "asarray")
if dtype is not None:
dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
return array(a, dtype=dtype, copy=bool(copy), order=order)
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
@util.implements(np.copy, lax_description=_ARRAY_DOC)
@ -4329,36 +4335,45 @@ def _arange_dynamic(
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: Literal[False] = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array: ...
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: bool, retstep: Literal[True],
dtype: DTypeLike | None = None,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, *, retstep: Literal[True],
dtype: DTypeLike | None = None,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace)
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace, extra_params="""
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]:
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)
return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device)
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device'))
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]:
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
"""Implementation of linspace differentiable in start and stop args."""
dtypes.check_user_dtype_supported(dtype, "linspace")
if num < 0:
@ -4406,10 +4421,9 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
if issubdtype(dtype, integer) and not issubdtype(out.dtype, integer):
out = lax.floor(out)
if retstep:
return lax.convert_element_type(out, dtype), delta
else:
return lax.convert_element_type(out, dtype)
sharding = canonicalize_device_to_sharding(device)
result = lax_internal._convert_element_type(out, dtype, sharding=sharding)
return (result, delta) if retstep else result
@util.implements(np.logspace)

View File

@ -52,6 +52,7 @@ from jax.numpy import (
argmax as argmax,
argmin as argmin,
argsort as argsort,
asarray as asarray,
asin as asin,
asinh as asinh,
atan as atan,
@ -109,6 +110,7 @@ from jax.numpy import (
isnan as isnan,
less as less,
less_equal as less_equal,
linspace as linspace,
log as log,
log10 as log10,
log1p as log1p,
@ -187,11 +189,6 @@ from jax.experimental.array_api._manipulation_functions import (
reshape as reshape,
)
from jax.experimental.array_api._creation_functions import (
asarray as asarray,
linspace as linspace,
)
from jax.experimental.array_api._data_type_functions import (
astype as astype,
)

View File

@ -1,25 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import jax
import jax.numpy as jnp
def asarray(obj, /, *, dtype=None, device=None, copy=None):
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)
def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)

View File

@ -114,7 +114,8 @@ def array_split(
array_str = _np.array_str
def asarray(
a: Any, dtype: DTypeLike | None = ..., order: str | None = ...,
*, copy: builtins.bool | None = ...
*, copy: builtins.bool | None = ...,
device: _Device | _Sharding | None = ...,
) -> Array: ...
def asin(x: ArrayLike, /) -> Array: ...
def asinh(x: ArrayLike, /) -> Array: ...
@ -523,22 +524,26 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ...
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, retstep: Literal[False] = False,
dtype: DTypeLike | None = ...,
axis: int = 0) -> Array: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: builtins.bool, retstep: Literal[True],
dtype: DTypeLike | None = ...,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, *, retstep: Literal[True],
dtype: DTypeLike | None = ...,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, retstep: builtins.bool = False,
dtype: DTypeLike | None = ...,
axis: int = 0) -> Array | tuple[Array, Array]: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...
def load(*args: Any, **kwargs: Any) -> Array: ...
def log(x: ArrayLike, /) -> Array: ...

View File

@ -2998,25 +2998,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithDevice(self, func, dtype):
def testArangeEyeLinspaceArrayWithDevice(self, func, dtype):
device = jax.devices()[-1]
out = func(dtype=dtype, device=device)
self.assertEqual(out.devices(), {device})
output = func(dtype=dtype, device=device)
if isinstance(output, tuple):
self.assertEqual(output[0].devices(), {device})
else:
self.assertEqual(output.devices(), {device})
@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithSharding(self, func, dtype):
def testArangeEyeLinspaceArrayWithSharding(self, func, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
out = func(dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)
output = func(dtype=dtype, device=sharding)
if isinstance(output, tuple):
self.assertEqual(output[0].sharding, sharding)
else:
self.assertEqual(output.sharding, sharding)
@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
@ -6066,7 +6078,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'histogram': ['normed'],
'histogram2d': ['normed'],
'histogramdd': ['normed'],
'linspace': ['device'],
'nanpercentile': ['weights'],
'nanquantile': ['weights'],
'nanstd': ['correction', 'mean'],