rocm_jax/jax/_src/numpy/array_creation.py

395 lines
14 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
from typing import Any
import numpy as np
import jax
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src.lib import xla_client as xc
from jax._src.numpy import util
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike
from jax._src.util import set_module
from jax.sharding import Sharding
export = set_module('jax.numpy')
# Like core.canonicalize_shape, but also accept int-like (non-sequence)
# arguments for `shape`.
def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
if (not isinstance(shape, (tuple, list)) and
(getattr(shape, 'ndim', None) == 0 or np.ndim(shape) == 0)):
return core.canonicalize_shape((shape,), context)
else:
return core.canonicalize_shape(shape, context)
@export
def zeros(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an array full of zeros.
JAX implementation of :func:`numpy.zeros`.
Args:
shape: int or sequence of ints specifying the shape of the created array.
dtype: optional dtype for the created array; defaults to floating point.
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
Returns:
Array of the specified shape and dtype, on the specified device if specified.
See also:
- :func:`jax.numpy.zeros_like`
- :func:`jax.numpy.empty`
- :func:`jax.numpy.ones`
- :func:`jax.numpy.full`
Examples:
>>> jnp.zeros(4)
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.zeros((2, 3), dtype=bool)
Array([[False, False, False],
[False, False, False]], dtype=bool)
"""
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
dtypes.check_user_dtype_supported(dtype, "zeros")
shape = canonicalize_shape(shape)
return lax.full(shape, 0, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device))
@export
def ones(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an array full of ones.
JAX implementation of :func:`numpy.ones`.
Args:
shape: int or sequence of ints specifying the shape of the created array.
dtype: optional dtype for the created array; defaults to floating point.
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
Returns:
Array of the specified shape and dtype, on the specified device if specified.
See also:
- :func:`jax.numpy.ones_like`
- :func:`jax.numpy.empty`
- :func:`jax.numpy.zeros`
- :func:`jax.numpy.full`
Examples:
>>> jnp.ones(4)
Array([1., 1., 1., 1.], dtype=float32)
>>> jnp.ones((2, 3), dtype=bool)
Array([[ True, True, True],
[ True, True, True]], dtype=bool)
"""
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
shape = canonicalize_shape(shape)
dtypes.check_user_dtype_supported(dtype, "ones")
return lax.full(shape, 1, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device))
@export
def empty(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an empty array.
JAX implementation of :func:`numpy.empty`. Because XLA cannot create an
un-initialized array, :func:`jax.numpy.empty` will always return an array
full of zeros.
Args:
shape: int or sequence of ints specifying the shape of the created array.
dtype: optional dtype for the created array; defaults to floating point.
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
Returns:
Array of the specified shape and dtype, on the specified device if specified.
See also:
- :func:`jax.numpy.empty_like`
- :func:`jax.numpy.zeros`
- :func:`jax.numpy.ones`
- :func:`jax.numpy.full`
Examples:
>>> jnp.empty(4)
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.empty((2, 3), dtype=bool)
Array([[False, False, False],
[False, False, False]], dtype=bool)
"""
if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m)
dtypes.check_user_dtype_supported(dtype, "empty")
return zeros(shape, dtype, device=device)
def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
if isinstance(dtype, int) and isinstance(shape, int):
return (f"Cannot interpret '{dtype}' as a data type."
f"\n\nDid you accidentally write "
f"`jax.numpy.{name}({shape}, {dtype})` "
f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. "
"with a single tuple argument for the shape?")
@export
def full(shape: Any, fill_value: ArrayLike,
dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an array full of a specified value.
JAX implementation of :func:`numpy.full`.
Args:
shape: int or sequence of ints specifying the shape of the created array.
fill_value: scalar or array with which to fill the created array.
dtype: optional dtype for the created array; defaults to the dtype of the
fill value.
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
Returns:
Array of the specified shape and dtype, on the specified device if specified.
See also:
- :func:`jax.numpy.full_like`
- :func:`jax.numpy.empty`
- :func:`jax.numpy.zeros`
- :func:`jax.numpy.ones`
Examples:
>>> jnp.full(4, 2, dtype=float)
Array([2., 2., 2., 2.], dtype=float32)
>>> jnp.full((2, 3), 0, dtype=bool)
Array([[False, False, False],
[False, False, False]], dtype=bool)
`fill_value` may also be an array that is broadcast to the specified shape:
>>> jnp.full((2, 3), fill_value=jnp.arange(3))
Array([[0, 1, 2],
[0, 1, 2]], dtype=int32)
"""
dtypes.check_user_dtype_supported(dtype, "full")
util.check_arraylike("full", fill_value)
if np.ndim(fill_value) == 0:
shape = canonicalize_shape(shape)
return lax.full(shape, fill_value, dtype, sharding=util.normalize_device_to_sharding(device))
else:
return jax.device_put(
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
@export
def zeros_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an array full of zeros with the same shape and dtype as an array.
JAX implementation of :func:`numpy.zeros_like`.
Args:
a: Array-like object with ``shape`` and ``dtype`` attributes.
shape: optionally override the shape of the created array.
dtype: optionally override the dtype of the created array.
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
Returns:
Array of the specified shape and dtype, on the specified device if specified.
See also:
- :func:`jax.numpy.zeros`
- :func:`jax.numpy.empty_like`
- :func:`jax.numpy.ones_like`
- :func:`jax.numpy.full_like`
Examples:
>>> x = jnp.arange(4)
>>> jnp.zeros_like(x)
Array([0, 0, 0, 0], dtype=int32)
>>> jnp.zeros_like(x, dtype=bool)
Array([False, False, False, False], dtype=bool)
>>> jnp.zeros_like(x, shape=(2, 3))
Array([[0, 0, 0],
[0, 0, 0]], dtype=int32)
"""
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
util.check_arraylike("zeros_like", a)
dtypes.check_user_dtype_supported(dtype, "zeros_like")
if shape is not None:
shape = canonicalize_shape(shape)
return lax.full_like(a, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device))
@export
def ones_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an array of ones with the same shape and dtype as an array.
JAX implementation of :func:`numpy.ones_like`.
Args:
a: Array-like object with ``shape`` and ``dtype`` attributes.
shape: optionally override the shape of the created array.
dtype: optionally override the dtype of the created array.
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
Returns:
Array of the specified shape and dtype, on the specified device if specified.
See also:
- :func:`jax.numpy.empty`
- :func:`jax.numpy.zeros_like`
- :func:`jax.numpy.ones_like`
- :func:`jax.numpy.full_like`
Examples:
>>> x = jnp.arange(4)
>>> jnp.ones_like(x)
Array([1, 1, 1, 1], dtype=int32)
>>> jnp.ones_like(x, dtype=bool)
Array([ True, True, True, True], dtype=bool)
>>> jnp.ones_like(x, shape=(2, 3))
Array([[1, 1, 1],
[1, 1, 1]], dtype=int32)
"""
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
util.check_arraylike("ones_like", a)
dtypes.check_user_dtype_supported(dtype, "ones_like")
if shape is not None:
shape = canonicalize_shape(shape)
return lax.full_like(a, 1, dtype, shape, sharding=util.normalize_device_to_sharding(device))
@export
def empty_like(prototype: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an empty array with the same shape and dtype as an array.
JAX implementation of :func:`numpy.empty_like`. Because XLA cannot create
an un-initialized array, :func:`jax.numpy.empty` will always return an
array full of zeros.
Args:
a: Array-like object with ``shape`` and ``dtype`` attributes.
shape: optionally override the shape of the created array.
dtype: optionally override the dtype of the created array.
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
Returns:
Array of the specified shape and dtype, on the specified device if specified.
See also:
- :func:`jax.numpy.empty`
- :func:`jax.numpy.zeros_like`
- :func:`jax.numpy.ones_like`
- :func:`jax.numpy.full_like`
Examples:
>>> x = jnp.arange(4)
>>> jnp.empty_like(x)
Array([0, 0, 0, 0], dtype=int32)
>>> jnp.empty_like(x, dtype=bool)
Array([False, False, False, False], dtype=bool)
>>> jnp.empty_like(x, shape=(2, 3))
Array([[0, 0, 0],
[0, 0, 0]], dtype=int32)
"""
if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing
util.check_arraylike("empty_like", prototype)
dtypes.check_user_dtype_supported(dtype, "empty_like")
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
@export
def full_like(a: ArrayLike | DuckTypedArray,
fill_value: ArrayLike, dtype: DTypeLike | None = None,
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an array full of a specified value with the same shape and dtype as an array.
JAX implementation of :func:`numpy.full_like`.
Args:
a: Array-like object with ``shape`` and ``dtype`` attributes.
fill_value: scalar or array with which to fill the created array.
shape: optionally override the shape of the created array.
dtype: optionally override the dtype of the created array.
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
Returns:
Array of the specified shape and dtype, on the specified device if specified.
See also:
- :func:`jax.numpy.full`
- :func:`jax.numpy.empty_like`
- :func:`jax.numpy.zeros_like`
- :func:`jax.numpy.ones_like`
Examples:
>>> x = jnp.arange(4.0)
>>> jnp.full_like(x, 2)
Array([2., 2., 2., 2.], dtype=float32)
>>> jnp.full_like(x, 0, shape=(2, 3))
Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
`fill_value` may also be an array that is broadcast to the specified shape:
>>> x = jnp.arange(6).reshape(2, 3)
>>> jnp.full_like(x, fill_value=jnp.array([[1], [2]]))
Array([[1, 1, 1],
[2, 2, 2]], dtype=int32)
"""
if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing
util.check_arraylike("full_like", 0, fill_value)
else:
util.check_arraylike("full_like", a, fill_value)
dtypes.check_user_dtype_supported(dtype, "full_like")
if shape is not None:
shape = canonicalize_shape(shape)
if np.ndim(fill_value) == 0:
return lax.full_like(a, fill_value, dtype, shape, sharding=util.normalize_device_to_sharding(device))
else:
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
dtype = dtypes.result_type(a) if dtype is None else dtype
return jax.device_put(
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)