mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #26376 from jakevdp:array-creation
PiperOrigin-RevId: 724399604
This commit is contained in:
commit
ec477634f1
@ -214,6 +214,28 @@ _default_types: dict[str, type[Any]] = {
|
||||
'c': complex_,
|
||||
}
|
||||
|
||||
|
||||
def jax_dtype(obj: DTypeLike | None, *, align: bool = False,
|
||||
copy: bool = False) -> DType:
|
||||
"""Cast an object to a dtype, respecting JAX dtype defaults.
|
||||
|
||||
Arguments mirror those of :func:`numpy.dtype`.
|
||||
"""
|
||||
if obj is None:
|
||||
obj = float_
|
||||
elif issubdtype(obj, extended):
|
||||
return obj # type: ignore[return-value]
|
||||
elif isinstance(obj, type):
|
||||
obj = _DEFAULT_TYPEMAP.get(obj, obj)
|
||||
return np.dtype(obj, align=align, copy=copy)
|
||||
|
||||
_DEFAULT_TYPEMAP: dict[type, DTypeLike] = {
|
||||
bool: bool,
|
||||
int: int_,
|
||||
float: float_,
|
||||
complex: complex_,
|
||||
}
|
||||
|
||||
def bit_width(dtype: DTypeLike) -> int:
|
||||
"""Number of bits per element for the dtype."""
|
||||
# Note: we cannot use dtype.itemsize here because this is
|
||||
|
394
jax/_src/numpy/array_creation.py
Normal file
394
jax/_src/numpy/array_creation.py
Normal file
@ -0,0 +1,394 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import types
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.numpy import util
|
||||
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike
|
||||
from jax._src.util import set_module
|
||||
from jax.sharding import Sharding
|
||||
|
||||
|
||||
export = set_module('jax.numpy')
|
||||
|
||||
|
||||
# Like core.canonicalize_shape, but also accept int-like (non-sequence)
|
||||
# arguments for `shape`.
|
||||
def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
|
||||
if (not isinstance(shape, (tuple, list)) and
|
||||
(getattr(shape, 'ndim', None) == 0 or np.ndim(shape) == 0)):
|
||||
return core.canonicalize_shape((shape,), context)
|
||||
else:
|
||||
return core.canonicalize_shape(shape, context)
|
||||
|
||||
|
||||
@export
|
||||
def zeros(shape: Any, dtype: DTypeLike | None = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of zeros.
|
||||
|
||||
JAX implementation of :func:`numpy.zeros`.
|
||||
|
||||
Args:
|
||||
shape: int or sequence of ints specifying the shape of the created array.
|
||||
dtype: optional dtype for the created array; defaults to floating point.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.zeros_like`
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.ones`
|
||||
- :func:`jax.numpy.full`
|
||||
|
||||
Examples:
|
||||
>>> jnp.zeros(4)
|
||||
Array([0., 0., 0., 0.], dtype=float32)
|
||||
>>> jnp.zeros((2, 3), dtype=bool)
|
||||
Array([[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
"""
|
||||
if isinstance(shape, types.GeneratorType):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
|
||||
dtypes.check_user_dtype_supported(dtype, "zeros")
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full(shape, 0, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device))
|
||||
|
||||
|
||||
@export
|
||||
def ones(shape: Any, dtype: DTypeLike | None = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of ones.
|
||||
|
||||
JAX implementation of :func:`numpy.ones`.
|
||||
|
||||
Args:
|
||||
shape: int or sequence of ints specifying the shape of the created array.
|
||||
dtype: optional dtype for the created array; defaults to floating point.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.ones_like`
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.zeros`
|
||||
- :func:`jax.numpy.full`
|
||||
|
||||
Examples:
|
||||
>>> jnp.ones(4)
|
||||
Array([1., 1., 1., 1.], dtype=float32)
|
||||
>>> jnp.ones((2, 3), dtype=bool)
|
||||
Array([[ True, True, True],
|
||||
[ True, True, True]], dtype=bool)
|
||||
"""
|
||||
if isinstance(shape, types.GeneratorType):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
|
||||
shape = canonicalize_shape(shape)
|
||||
dtypes.check_user_dtype_supported(dtype, "ones")
|
||||
return lax.full(shape, 1, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device))
|
||||
|
||||
|
||||
@export
|
||||
def empty(shape: Any, dtype: DTypeLike | None = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an empty array.
|
||||
|
||||
JAX implementation of :func:`numpy.empty`. Because XLA cannot create an
|
||||
un-initialized array, :func:`jax.numpy.empty` will always return an array
|
||||
full of zeros.
|
||||
|
||||
Args:
|
||||
shape: int or sequence of ints specifying the shape of the created array.
|
||||
dtype: optional dtype for the created array; defaults to floating point.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.empty_like`
|
||||
- :func:`jax.numpy.zeros`
|
||||
- :func:`jax.numpy.ones`
|
||||
- :func:`jax.numpy.full`
|
||||
|
||||
Examples:
|
||||
>>> jnp.empty(4)
|
||||
Array([0., 0., 0., 0.], dtype=float32)
|
||||
>>> jnp.empty((2, 3), dtype=bool)
|
||||
Array([[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
"""
|
||||
if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m)
|
||||
dtypes.check_user_dtype_supported(dtype, "empty")
|
||||
return zeros(shape, dtype, device=device)
|
||||
|
||||
|
||||
def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
|
||||
if isinstance(dtype, int) and isinstance(shape, int):
|
||||
return (f"Cannot interpret '{dtype}' as a data type."
|
||||
f"\n\nDid you accidentally write "
|
||||
f"`jax.numpy.{name}({shape}, {dtype})` "
|
||||
f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. "
|
||||
"with a single tuple argument for the shape?")
|
||||
|
||||
@export
|
||||
def full(shape: Any, fill_value: ArrayLike,
|
||||
dtype: DTypeLike | None = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of a specified value.
|
||||
|
||||
JAX implementation of :func:`numpy.full`.
|
||||
|
||||
Args:
|
||||
shape: int or sequence of ints specifying the shape of the created array.
|
||||
fill_value: scalar or array with which to fill the created array.
|
||||
dtype: optional dtype for the created array; defaults to the dtype of the
|
||||
fill value.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.full_like`
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.zeros`
|
||||
- :func:`jax.numpy.ones`
|
||||
|
||||
Examples:
|
||||
>>> jnp.full(4, 2, dtype=float)
|
||||
Array([2., 2., 2., 2.], dtype=float32)
|
||||
>>> jnp.full((2, 3), 0, dtype=bool)
|
||||
Array([[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
|
||||
`fill_value` may also be an array that is broadcast to the specified shape:
|
||||
|
||||
>>> jnp.full((2, 3), fill_value=jnp.arange(3))
|
||||
Array([[0, 1, 2],
|
||||
[0, 1, 2]], dtype=int32)
|
||||
"""
|
||||
dtypes.check_user_dtype_supported(dtype, "full")
|
||||
util.check_arraylike("full", fill_value)
|
||||
|
||||
if np.ndim(fill_value) == 0:
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full(shape, fill_value, dtype, sharding=util.normalize_device_to_sharding(device))
|
||||
else:
|
||||
return jax.device_put(
|
||||
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
|
||||
|
||||
|
||||
@export
|
||||
def zeros_like(a: ArrayLike | DuckTypedArray,
|
||||
dtype: DTypeLike | None = None,
|
||||
shape: Any = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of zeros with the same shape and dtype as an array.
|
||||
|
||||
JAX implementation of :func:`numpy.zeros_like`.
|
||||
|
||||
Args:
|
||||
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
||||
shape: optionally override the shape of the created array.
|
||||
dtype: optionally override the dtype of the created array.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.zeros`
|
||||
- :func:`jax.numpy.empty_like`
|
||||
- :func:`jax.numpy.ones_like`
|
||||
- :func:`jax.numpy.full_like`
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.zeros_like(x)
|
||||
Array([0, 0, 0, 0], dtype=int32)
|
||||
>>> jnp.zeros_like(x, dtype=bool)
|
||||
Array([False, False, False, False], dtype=bool)
|
||||
>>> jnp.zeros_like(x, shape=(2, 3))
|
||||
Array([[0, 0, 0],
|
||||
[0, 0, 0]], dtype=int32)
|
||||
"""
|
||||
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
|
||||
util.check_arraylike("zeros_like", a)
|
||||
dtypes.check_user_dtype_supported(dtype, "zeros_like")
|
||||
if shape is not None:
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full_like(a, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device))
|
||||
|
||||
|
||||
@export
|
||||
def ones_like(a: ArrayLike | DuckTypedArray,
|
||||
dtype: DTypeLike | None = None,
|
||||
shape: Any = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array of ones with the same shape and dtype as an array.
|
||||
|
||||
JAX implementation of :func:`numpy.ones_like`.
|
||||
|
||||
Args:
|
||||
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
||||
shape: optionally override the shape of the created array.
|
||||
dtype: optionally override the dtype of the created array.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.zeros_like`
|
||||
- :func:`jax.numpy.ones_like`
|
||||
- :func:`jax.numpy.full_like`
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.ones_like(x)
|
||||
Array([1, 1, 1, 1], dtype=int32)
|
||||
>>> jnp.ones_like(x, dtype=bool)
|
||||
Array([ True, True, True, True], dtype=bool)
|
||||
>>> jnp.ones_like(x, shape=(2, 3))
|
||||
Array([[1, 1, 1],
|
||||
[1, 1, 1]], dtype=int32)
|
||||
"""
|
||||
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
|
||||
util.check_arraylike("ones_like", a)
|
||||
dtypes.check_user_dtype_supported(dtype, "ones_like")
|
||||
if shape is not None:
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full_like(a, 1, dtype, shape, sharding=util.normalize_device_to_sharding(device))
|
||||
|
||||
|
||||
@export
|
||||
def empty_like(prototype: ArrayLike | DuckTypedArray,
|
||||
dtype: DTypeLike | None = None,
|
||||
shape: Any = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an empty array with the same shape and dtype as an array.
|
||||
|
||||
JAX implementation of :func:`numpy.empty_like`. Because XLA cannot create
|
||||
an un-initialized array, :func:`jax.numpy.empty` will always return an
|
||||
array full of zeros.
|
||||
|
||||
Args:
|
||||
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
||||
shape: optionally override the shape of the created array.
|
||||
dtype: optionally override the dtype of the created array.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.zeros_like`
|
||||
- :func:`jax.numpy.ones_like`
|
||||
- :func:`jax.numpy.full_like`
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.empty_like(x)
|
||||
Array([0, 0, 0, 0], dtype=int32)
|
||||
>>> jnp.empty_like(x, dtype=bool)
|
||||
Array([False, False, False, False], dtype=bool)
|
||||
>>> jnp.empty_like(x, shape=(2, 3))
|
||||
Array([[0, 0, 0],
|
||||
[0, 0, 0]], dtype=int32)
|
||||
"""
|
||||
if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing
|
||||
util.check_arraylike("empty_like", prototype)
|
||||
dtypes.check_user_dtype_supported(dtype, "empty_like")
|
||||
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
|
||||
|
||||
|
||||
@export
|
||||
def full_like(a: ArrayLike | DuckTypedArray,
|
||||
fill_value: ArrayLike, dtype: DTypeLike | None = None,
|
||||
shape: Any = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of a specified value with the same shape and dtype as an array.
|
||||
|
||||
JAX implementation of :func:`numpy.full_like`.
|
||||
|
||||
Args:
|
||||
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
||||
fill_value: scalar or array with which to fill the created array.
|
||||
shape: optionally override the shape of the created array.
|
||||
dtype: optionally override the dtype of the created array.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.full`
|
||||
- :func:`jax.numpy.empty_like`
|
||||
- :func:`jax.numpy.zeros_like`
|
||||
- :func:`jax.numpy.ones_like`
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4.0)
|
||||
>>> jnp.full_like(x, 2)
|
||||
Array([2., 2., 2., 2.], dtype=float32)
|
||||
>>> jnp.full_like(x, 0, shape=(2, 3))
|
||||
Array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
|
||||
`fill_value` may also be an array that is broadcast to the specified shape:
|
||||
|
||||
>>> x = jnp.arange(6).reshape(2, 3)
|
||||
>>> jnp.full_like(x, fill_value=jnp.array([[1], [2]]))
|
||||
Array([[1, 1, 1],
|
||||
[2, 2, 2]], dtype=int32)
|
||||
"""
|
||||
if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing
|
||||
util.check_arraylike("full_like", 0, fill_value)
|
||||
else:
|
||||
util.check_arraylike("full_like", a, fill_value)
|
||||
dtypes.check_user_dtype_supported(dtype, "full_like")
|
||||
if shape is not None:
|
||||
shape = canonicalize_shape(shape)
|
||||
if np.ndim(fill_value) == 0:
|
||||
return lax.full_like(a, fill_value, dtype, shape, sharding=util.normalize_device_to_sharding(device))
|
||||
else:
|
||||
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
|
||||
dtype = dtypes.result_type(a) if dtype is None else dtype
|
||||
return jax.device_put(
|
||||
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
|
@ -34,7 +34,6 @@ import math
|
||||
import operator
|
||||
import os
|
||||
import string
|
||||
import types
|
||||
from typing import (Any, IO, Literal, NamedTuple,
|
||||
Protocol, TypeVar, Union, overload)
|
||||
import warnings
|
||||
@ -58,14 +57,15 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy,
|
||||
_sort_le_comparator, _sort_lt_comparator)
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.numpy.array_creation import (empty, empty_like, full, full_like,
|
||||
ones, ones_like, zeros, zeros_like)
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy import util
|
||||
from jax._src.numpy.sorting import argsort, sort
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.typing import (
|
||||
Array, ArrayLike,
|
||||
DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar,
|
||||
Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, StaticScalar,
|
||||
)
|
||||
from jax._src.util import (
|
||||
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
|
||||
@ -94,16 +94,6 @@ newaxis = None
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
# Like core.canonicalize_shape, but also accept int-like (non-sequence)
|
||||
# arguments for `shape`.
|
||||
def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
|
||||
if (not isinstance(shape, (tuple, list)) and
|
||||
(getattr(shape, 'ndim', None) == 0 or ndim(shape) == 0)):
|
||||
return core.canonicalize_shape((shape,), context)
|
||||
else:
|
||||
return core.canonicalize_shape(shape, context)
|
||||
|
||||
|
||||
# NumPy constants
|
||||
|
||||
pi = np.pi
|
||||
@ -182,27 +172,6 @@ array_repr = np.array_repr
|
||||
save = np.save
|
||||
savez = np.savez
|
||||
|
||||
|
||||
def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False,
|
||||
copy: bool = False) -> DType:
|
||||
"""Similar to np.dtype, but respects JAX dtype defaults."""
|
||||
if dtypes.issubdtype(obj, dtypes.extended):
|
||||
return obj # type: ignore[return-value]
|
||||
if obj is None:
|
||||
obj = dtypes.float_
|
||||
elif isinstance(obj, type) and obj in dtypes.python_scalar_dtypes:
|
||||
obj = _DEFAULT_TYPEMAP[obj]
|
||||
return np.dtype(obj, align=align, copy=copy)
|
||||
|
||||
### utility functions
|
||||
|
||||
_DEFAULT_TYPEMAP: dict[type, np.dtype] = {
|
||||
bool: np.dtype(bool),
|
||||
int: np.dtype(dtypes.int_),
|
||||
float: np.dtype(dtypes.float_),
|
||||
complex: np.dtype(dtypes.complex_),
|
||||
}
|
||||
|
||||
_lax_const = lax_internal._const
|
||||
|
||||
|
||||
@ -5679,7 +5648,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
|
||||
# We offer a more specific warning than the usual ComplexWarning so we prefer
|
||||
# to issue our warning.
|
||||
result = lax_internal._convert_element_type(
|
||||
x_arr, dtype, sharding=_normalize_to_sharding(device),
|
||||
x_arr, dtype, sharding=util.normalize_device_to_sharding(device),
|
||||
warn_on_complex_to_real_cast=False)
|
||||
return _array_copy(result) if copy else result
|
||||
|
||||
@ -5815,366 +5784,6 @@ def copy(a: ArrayLike, order: str | None = None) -> Array:
|
||||
return array(a, copy=True, order=order)
|
||||
|
||||
|
||||
@export
|
||||
def zeros_like(a: ArrayLike | DuckTypedArray,
|
||||
dtype: DTypeLike | None = None,
|
||||
shape: Any = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of zeros with the same shape and dtype as an array.
|
||||
|
||||
JAX implementation of :func:`numpy.zeros_like`.
|
||||
|
||||
Args:
|
||||
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
||||
shape: optionally override the shape of the created array.
|
||||
dtype: optionally override the dtype of the created array.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.zeros`
|
||||
- :func:`jax.numpy.empty_like`
|
||||
- :func:`jax.numpy.ones_like`
|
||||
- :func:`jax.numpy.full_like`
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.zeros_like(x)
|
||||
Array([0, 0, 0, 0], dtype=int32)
|
||||
>>> jnp.zeros_like(x, dtype=bool)
|
||||
Array([False, False, False, False], dtype=bool)
|
||||
>>> jnp.zeros_like(x, shape=(2, 3))
|
||||
Array([[0, 0, 0],
|
||||
[0, 0, 0]], dtype=int32)
|
||||
"""
|
||||
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
|
||||
util.check_arraylike("zeros_like", a)
|
||||
dtypes.check_user_dtype_supported(dtype, "zeros_like")
|
||||
if shape is not None:
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device))
|
||||
|
||||
|
||||
@export
|
||||
def ones_like(a: ArrayLike | DuckTypedArray,
|
||||
dtype: DTypeLike | None = None,
|
||||
shape: Any = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array of ones with the same shape and dtype as an array.
|
||||
|
||||
JAX implementation of :func:`numpy.ones_like`.
|
||||
|
||||
Args:
|
||||
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
||||
shape: optionally override the shape of the created array.
|
||||
dtype: optionally override the dtype of the created array.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.zeros_like`
|
||||
- :func:`jax.numpy.ones_like`
|
||||
- :func:`jax.numpy.full_like`
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.ones_like(x)
|
||||
Array([1, 1, 1, 1], dtype=int32)
|
||||
>>> jnp.ones_like(x, dtype=bool)
|
||||
Array([ True, True, True, True], dtype=bool)
|
||||
>>> jnp.ones_like(x, shape=(2, 3))
|
||||
Array([[1, 1, 1],
|
||||
[1, 1, 1]], dtype=int32)
|
||||
"""
|
||||
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
|
||||
util.check_arraylike("ones_like", a)
|
||||
dtypes.check_user_dtype_supported(dtype, "ones_like")
|
||||
if shape is not None:
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device))
|
||||
|
||||
|
||||
@export
|
||||
def empty_like(prototype: ArrayLike | DuckTypedArray,
|
||||
dtype: DTypeLike | None = None,
|
||||
shape: Any = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an empty array with the same shape and dtype as an array.
|
||||
|
||||
JAX implementation of :func:`numpy.empty_like`. Because XLA cannot create
|
||||
an un-initialized array, :func:`jax.numpy.empty` will always return an
|
||||
array full of zeros.
|
||||
|
||||
Args:
|
||||
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
||||
shape: optionally override the shape of the created array.
|
||||
dtype: optionally override the dtype of the created array.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.zeros_like`
|
||||
- :func:`jax.numpy.ones_like`
|
||||
- :func:`jax.numpy.full_like`
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.empty_like(x)
|
||||
Array([0, 0, 0, 0], dtype=int32)
|
||||
>>> jnp.empty_like(x, dtype=bool)
|
||||
Array([False, False, False, False], dtype=bool)
|
||||
>>> jnp.empty_like(x, shape=(2, 3))
|
||||
Array([[0, 0, 0],
|
||||
[0, 0, 0]], dtype=int32)
|
||||
"""
|
||||
if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing
|
||||
util.check_arraylike("empty_like", prototype)
|
||||
dtypes.check_user_dtype_supported(dtype, "empty_like")
|
||||
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
|
||||
|
||||
|
||||
def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None:
|
||||
if isinstance(device, xc.Device):
|
||||
return SingleDeviceSharding(device)
|
||||
else:
|
||||
return device
|
||||
|
||||
|
||||
@export
|
||||
def full(shape: Any, fill_value: ArrayLike,
|
||||
dtype: DTypeLike | None = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of a specified value.
|
||||
|
||||
JAX implementation of :func:`numpy.full`.
|
||||
|
||||
Args:
|
||||
shape: int or sequence of ints specifying the shape of the created array.
|
||||
fill_value: scalar or array with which to fill the created array.
|
||||
dtype: optional dtype for the created array; defaults to the dtype of the
|
||||
fill value.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.full_like`
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.zeros`
|
||||
- :func:`jax.numpy.ones`
|
||||
|
||||
Examples:
|
||||
>>> jnp.full(4, 2, dtype=float)
|
||||
Array([2., 2., 2., 2.], dtype=float32)
|
||||
>>> jnp.full((2, 3), 0, dtype=bool)
|
||||
Array([[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
|
||||
`fill_value` may also be an array that is broadcast to the specified shape:
|
||||
|
||||
>>> jnp.full((2, 3), fill_value=jnp.arange(3))
|
||||
Array([[0, 1, 2],
|
||||
[0, 1, 2]], dtype=int32)
|
||||
"""
|
||||
dtypes.check_user_dtype_supported(dtype, "full")
|
||||
util.check_arraylike("full", fill_value)
|
||||
|
||||
if ndim(fill_value) == 0:
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device))
|
||||
else:
|
||||
return jax.device_put(
|
||||
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
|
||||
|
||||
|
||||
@export
|
||||
def full_like(a: ArrayLike | DuckTypedArray,
|
||||
fill_value: ArrayLike, dtype: DTypeLike | None = None,
|
||||
shape: Any = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of a specified value with the same shape and dtype as an array.
|
||||
|
||||
JAX implementation of :func:`numpy.full_like`.
|
||||
|
||||
Args:
|
||||
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
||||
fill_value: scalar or array with which to fill the created array.
|
||||
shape: optionally override the shape of the created array.
|
||||
dtype: optionally override the dtype of the created array.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.full`
|
||||
- :func:`jax.numpy.empty_like`
|
||||
- :func:`jax.numpy.zeros_like`
|
||||
- :func:`jax.numpy.ones_like`
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4.0)
|
||||
>>> jnp.full_like(x, 2)
|
||||
Array([2., 2., 2., 2.], dtype=float32)
|
||||
>>> jnp.full_like(x, 0, shape=(2, 3))
|
||||
Array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
|
||||
`fill_value` may also be an array that is broadcast to the specified shape:
|
||||
|
||||
>>> x = jnp.arange(6).reshape(2, 3)
|
||||
>>> jnp.full_like(x, fill_value=jnp.array([[1], [2]]))
|
||||
Array([[1, 1, 1],
|
||||
[2, 2, 2]], dtype=int32)
|
||||
"""
|
||||
if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing
|
||||
util.check_arraylike("full_like", 0, fill_value)
|
||||
else:
|
||||
util.check_arraylike("full_like", a, fill_value)
|
||||
dtypes.check_user_dtype_supported(dtype, "full_like")
|
||||
if shape is not None:
|
||||
shape = canonicalize_shape(shape)
|
||||
if ndim(fill_value) == 0:
|
||||
return lax.full_like(a, fill_value, dtype, shape, sharding=_normalize_to_sharding(device))
|
||||
else:
|
||||
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
|
||||
dtype = result_type(a) if dtype is None else dtype
|
||||
return jax.device_put(
|
||||
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
|
||||
|
||||
|
||||
@export
|
||||
def zeros(shape: Any, dtype: DTypeLike | None = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of zeros.
|
||||
|
||||
JAX implementation of :func:`numpy.zeros`.
|
||||
|
||||
Args:
|
||||
shape: int or sequence of ints specifying the shape of the created array.
|
||||
dtype: optional dtype for the created array; defaults to floating point.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.zeros_like`
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.ones`
|
||||
- :func:`jax.numpy.full`
|
||||
|
||||
Examples:
|
||||
>>> jnp.zeros(4)
|
||||
Array([0., 0., 0., 0.], dtype=float32)
|
||||
>>> jnp.zeros((2, 3), dtype=bool)
|
||||
Array([[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
"""
|
||||
if isinstance(shape, types.GeneratorType):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
|
||||
dtypes.check_user_dtype_supported(dtype, "zeros")
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
|
||||
|
||||
|
||||
@export
|
||||
def ones(shape: Any, dtype: DTypeLike | None = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an array full of ones.
|
||||
|
||||
JAX implementation of :func:`numpy.ones`.
|
||||
|
||||
Args:
|
||||
shape: int or sequence of ints specifying the shape of the created array.
|
||||
dtype: optional dtype for the created array; defaults to floating point.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.ones_like`
|
||||
- :func:`jax.numpy.empty`
|
||||
- :func:`jax.numpy.zeros`
|
||||
- :func:`jax.numpy.full`
|
||||
|
||||
Examples:
|
||||
>>> jnp.ones(4)
|
||||
Array([1., 1., 1., 1.], dtype=float32)
|
||||
>>> jnp.ones((2, 3), dtype=bool)
|
||||
Array([[ True, True, True],
|
||||
[ True, True, True]], dtype=bool)
|
||||
"""
|
||||
if isinstance(shape, types.GeneratorType):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
|
||||
shape = canonicalize_shape(shape)
|
||||
dtypes.check_user_dtype_supported(dtype, "ones")
|
||||
return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
|
||||
|
||||
|
||||
@export
|
||||
def empty(shape: Any, dtype: DTypeLike | None = None, *,
|
||||
device: xc.Device | Sharding | None = None) -> Array:
|
||||
"""Create an empty array.
|
||||
|
||||
JAX implementation of :func:`numpy.empty`. Because XLA cannot create an
|
||||
un-initialized array, :func:`jax.numpy.empty` will always return an array
|
||||
full of zeros.
|
||||
|
||||
Args:
|
||||
shape: int or sequence of ints specifying the shape of the created array.
|
||||
dtype: optional dtype for the created array; defaults to floating point.
|
||||
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
||||
to which the created array will be committed.
|
||||
|
||||
Returns:
|
||||
Array of the specified shape and dtype, on the specified device if specified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.empty_like`
|
||||
- :func:`jax.numpy.zeros`
|
||||
- :func:`jax.numpy.ones`
|
||||
- :func:`jax.numpy.full`
|
||||
|
||||
Examples:
|
||||
>>> jnp.empty(4)
|
||||
Array([0., 0., 0., 0.], dtype=float32)
|
||||
>>> jnp.empty((2, 3), dtype=bool)
|
||||
Array([[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
"""
|
||||
if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m)
|
||||
dtypes.check_user_dtype_supported(dtype, "empty")
|
||||
return zeros(shape, dtype, device=device)
|
||||
|
||||
def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
|
||||
if isinstance(dtype, int) and isinstance(shape, int):
|
||||
return (f"Cannot interpret '{dtype}' as a data type."
|
||||
f"\n\nDid you accidentally write "
|
||||
f"`jax.numpy.{name}({shape}, {dtype})` "
|
||||
f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. "
|
||||
"with a single tuple argument for the shape?")
|
||||
|
||||
|
||||
@export
|
||||
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
|
||||
"""Check if two arrays are element-wise equal.
|
||||
@ -6741,7 +6350,7 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
|
||||
return _arange_dynamic(start, stop, step, dtype or dtypes.canonicalize_dtype(np.int64))
|
||||
if dtype is None:
|
||||
dtype = result_type(start, *(x for x in [stop, step] if x is not None))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
dtype = dtypes.jax_dtype(dtype)
|
||||
if stop is None and step is None:
|
||||
start_dtype = _dtype(start)
|
||||
if (not dtypes.issubdtype(start_dtype, np.integer) and
|
||||
@ -6884,7 +6493,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
|
||||
if dtype is None:
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
dtype = dtypes.jax_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
start = start.astype(computation_dtype)
|
||||
stop = stop.astype(computation_dtype)
|
||||
@ -7004,7 +6613,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
dtypes.check_user_dtype_supported(dtype, "logspace")
|
||||
if dtype is None:
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
dtype = dtypes.jax_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
start, stop = util.ensure_arraylike("logspace", start, stop)
|
||||
start = start.astype(computation_dtype)
|
||||
@ -7074,7 +6683,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
|
||||
dtypes.check_user_dtype_supported(dtype, "geomspace")
|
||||
if dtype is None:
|
||||
dtype = dtypes.to_inexact_dtype(result_type(start, stop))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
dtype = dtypes.jax_dtype(dtype)
|
||||
computation_dtype = dtypes.to_inexact_dtype(dtype)
|
||||
start, stop = util.ensure_arraylike("geomspace", start, stop)
|
||||
start = start.astype(computation_dtype)
|
||||
|
@ -24,8 +24,11 @@ from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.sharding_impls import SingleDeviceSharding
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
|
||||
from jax.sharding import Sharding
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -299,3 +302,10 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
|
||||
except:
|
||||
is_always_empty = False # can fail with dynamic shapes
|
||||
return lax.select(condition, x_arr, y_arr) if not is_always_empty else x_arr
|
||||
|
||||
|
||||
def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None:
|
||||
if isinstance(device, xc.Device):
|
||||
return SingleDeviceSharding(device)
|
||||
else:
|
||||
return device
|
||||
|
@ -83,8 +83,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
ediff1d as ediff1d,
|
||||
einsum as einsum,
|
||||
einsum_path as einsum_path,
|
||||
empty as empty,
|
||||
empty_like as empty_like,
|
||||
euler_gamma as euler_gamma,
|
||||
expand_dims as expand_dims,
|
||||
extract as extract,
|
||||
@ -104,8 +102,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
fromiter as fromiter,
|
||||
fromstring as fromstring,
|
||||
from_dlpack as from_dlpack,
|
||||
full as full,
|
||||
full_like as full_like,
|
||||
gcd as gcd,
|
||||
geomspace as geomspace,
|
||||
get_printoptions as get_printoptions,
|
||||
@ -154,8 +150,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
ndim as ndim,
|
||||
newaxis as newaxis,
|
||||
nonzero as nonzero,
|
||||
ones as ones,
|
||||
ones_like as ones_like,
|
||||
outer as outer,
|
||||
packbits as packbits,
|
||||
pad as pad,
|
||||
@ -215,6 +209,15 @@ from jax._src.numpy.lax_numpy import (
|
||||
vsplit as vsplit,
|
||||
vstack as vstack,
|
||||
where as where,
|
||||
)
|
||||
|
||||
from jax._src.numpy.array_creation import (
|
||||
empty as empty,
|
||||
empty_like as empty_like,
|
||||
full as full,
|
||||
full_like as full_like,
|
||||
ones as ones,
|
||||
ones_like as ones_like,
|
||||
zeros as zeros,
|
||||
zeros_like as zeros_like,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user