mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 13:26:09 +00:00
395 lines
14 KiB
Python
395 lines
14 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import types
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax import lax
|
|
from jax._src import core
|
|
from jax._src import dtypes
|
|
from jax._src.lib import xla_client as xc
|
|
from jax._src.numpy import util
|
|
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike
|
|
from jax._src.util import set_module
|
|
from jax.sharding import Sharding
|
|
|
|
|
|
export = set_module('jax.numpy')
|
|
|
|
|
|
# Like core.canonicalize_shape, but also accept int-like (non-sequence)
|
|
# arguments for `shape`.
|
|
def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
|
|
if (not isinstance(shape, (tuple, list)) and
|
|
(getattr(shape, 'ndim', None) == 0 or np.ndim(shape) == 0)):
|
|
return core.canonicalize_shape((shape,), context)
|
|
else:
|
|
return core.canonicalize_shape(shape, context)
|
|
|
|
|
|
@export
|
|
def zeros(shape: Any, dtype: DTypeLike | None = None, *,
|
|
device: xc.Device | Sharding | None = None) -> Array:
|
|
"""Create an array full of zeros.
|
|
|
|
JAX implementation of :func:`numpy.zeros`.
|
|
|
|
Args:
|
|
shape: int or sequence of ints specifying the shape of the created array.
|
|
dtype: optional dtype for the created array; defaults to floating point.
|
|
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
to which the created array will be committed.
|
|
|
|
Returns:
|
|
Array of the specified shape and dtype, on the specified device if specified.
|
|
|
|
See also:
|
|
- :func:`jax.numpy.zeros_like`
|
|
- :func:`jax.numpy.empty`
|
|
- :func:`jax.numpy.ones`
|
|
- :func:`jax.numpy.full`
|
|
|
|
Examples:
|
|
>>> jnp.zeros(4)
|
|
Array([0., 0., 0., 0.], dtype=float32)
|
|
>>> jnp.zeros((2, 3), dtype=bool)
|
|
Array([[False, False, False],
|
|
[False, False, False]], dtype=bool)
|
|
"""
|
|
if isinstance(shape, types.GeneratorType):
|
|
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
|
if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
|
|
dtypes.check_user_dtype_supported(dtype, "zeros")
|
|
shape = canonicalize_shape(shape)
|
|
return lax.full(shape, 0, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device))
|
|
|
|
|
|
@export
|
|
def ones(shape: Any, dtype: DTypeLike | None = None, *,
|
|
device: xc.Device | Sharding | None = None) -> Array:
|
|
"""Create an array full of ones.
|
|
|
|
JAX implementation of :func:`numpy.ones`.
|
|
|
|
Args:
|
|
shape: int or sequence of ints specifying the shape of the created array.
|
|
dtype: optional dtype for the created array; defaults to floating point.
|
|
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
to which the created array will be committed.
|
|
|
|
Returns:
|
|
Array of the specified shape and dtype, on the specified device if specified.
|
|
|
|
See also:
|
|
- :func:`jax.numpy.ones_like`
|
|
- :func:`jax.numpy.empty`
|
|
- :func:`jax.numpy.zeros`
|
|
- :func:`jax.numpy.full`
|
|
|
|
Examples:
|
|
>>> jnp.ones(4)
|
|
Array([1., 1., 1., 1.], dtype=float32)
|
|
>>> jnp.ones((2, 3), dtype=bool)
|
|
Array([[ True, True, True],
|
|
[ True, True, True]], dtype=bool)
|
|
"""
|
|
if isinstance(shape, types.GeneratorType):
|
|
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
|
if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
|
|
shape = canonicalize_shape(shape)
|
|
dtypes.check_user_dtype_supported(dtype, "ones")
|
|
return lax.full(shape, 1, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device))
|
|
|
|
|
|
@export
|
|
def empty(shape: Any, dtype: DTypeLike | None = None, *,
|
|
device: xc.Device | Sharding | None = None) -> Array:
|
|
"""Create an empty array.
|
|
|
|
JAX implementation of :func:`numpy.empty`. Because XLA cannot create an
|
|
un-initialized array, :func:`jax.numpy.empty` will always return an array
|
|
full of zeros.
|
|
|
|
Args:
|
|
shape: int or sequence of ints specifying the shape of the created array.
|
|
dtype: optional dtype for the created array; defaults to floating point.
|
|
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
to which the created array will be committed.
|
|
|
|
Returns:
|
|
Array of the specified shape and dtype, on the specified device if specified.
|
|
|
|
See also:
|
|
- :func:`jax.numpy.empty_like`
|
|
- :func:`jax.numpy.zeros`
|
|
- :func:`jax.numpy.ones`
|
|
- :func:`jax.numpy.full`
|
|
|
|
Examples:
|
|
>>> jnp.empty(4)
|
|
Array([0., 0., 0., 0.], dtype=float32)
|
|
>>> jnp.empty((2, 3), dtype=bool)
|
|
Array([[False, False, False],
|
|
[False, False, False]], dtype=bool)
|
|
"""
|
|
if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m)
|
|
dtypes.check_user_dtype_supported(dtype, "empty")
|
|
return zeros(shape, dtype, device=device)
|
|
|
|
|
|
def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
|
|
if isinstance(dtype, int) and isinstance(shape, int):
|
|
return (f"Cannot interpret '{dtype}' as a data type."
|
|
f"\n\nDid you accidentally write "
|
|
f"`jax.numpy.{name}({shape}, {dtype})` "
|
|
f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. "
|
|
"with a single tuple argument for the shape?")
|
|
|
|
@export
|
|
def full(shape: Any, fill_value: ArrayLike,
|
|
dtype: DTypeLike | None = None, *,
|
|
device: xc.Device | Sharding | None = None) -> Array:
|
|
"""Create an array full of a specified value.
|
|
|
|
JAX implementation of :func:`numpy.full`.
|
|
|
|
Args:
|
|
shape: int or sequence of ints specifying the shape of the created array.
|
|
fill_value: scalar or array with which to fill the created array.
|
|
dtype: optional dtype for the created array; defaults to the dtype of the
|
|
fill value.
|
|
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
to which the created array will be committed.
|
|
|
|
Returns:
|
|
Array of the specified shape and dtype, on the specified device if specified.
|
|
|
|
See also:
|
|
- :func:`jax.numpy.full_like`
|
|
- :func:`jax.numpy.empty`
|
|
- :func:`jax.numpy.zeros`
|
|
- :func:`jax.numpy.ones`
|
|
|
|
Examples:
|
|
>>> jnp.full(4, 2, dtype=float)
|
|
Array([2., 2., 2., 2.], dtype=float32)
|
|
>>> jnp.full((2, 3), 0, dtype=bool)
|
|
Array([[False, False, False],
|
|
[False, False, False]], dtype=bool)
|
|
|
|
`fill_value` may also be an array that is broadcast to the specified shape:
|
|
|
|
>>> jnp.full((2, 3), fill_value=jnp.arange(3))
|
|
Array([[0, 1, 2],
|
|
[0, 1, 2]], dtype=int32)
|
|
"""
|
|
dtypes.check_user_dtype_supported(dtype, "full")
|
|
util.check_arraylike("full", fill_value)
|
|
|
|
if np.ndim(fill_value) == 0:
|
|
shape = canonicalize_shape(shape)
|
|
return lax.full(shape, fill_value, dtype, sharding=util.normalize_device_to_sharding(device))
|
|
else:
|
|
return jax.device_put(
|
|
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
|
|
|
|
|
|
@export
|
|
def zeros_like(a: ArrayLike | DuckTypedArray,
|
|
dtype: DTypeLike | None = None,
|
|
shape: Any = None, *,
|
|
device: xc.Device | Sharding | None = None) -> Array:
|
|
"""Create an array full of zeros with the same shape and dtype as an array.
|
|
|
|
JAX implementation of :func:`numpy.zeros_like`.
|
|
|
|
Args:
|
|
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
|
shape: optionally override the shape of the created array.
|
|
dtype: optionally override the dtype of the created array.
|
|
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
to which the created array will be committed.
|
|
|
|
Returns:
|
|
Array of the specified shape and dtype, on the specified device if specified.
|
|
|
|
See also:
|
|
- :func:`jax.numpy.zeros`
|
|
- :func:`jax.numpy.empty_like`
|
|
- :func:`jax.numpy.ones_like`
|
|
- :func:`jax.numpy.full_like`
|
|
|
|
Examples:
|
|
>>> x = jnp.arange(4)
|
|
>>> jnp.zeros_like(x)
|
|
Array([0, 0, 0, 0], dtype=int32)
|
|
>>> jnp.zeros_like(x, dtype=bool)
|
|
Array([False, False, False, False], dtype=bool)
|
|
>>> jnp.zeros_like(x, shape=(2, 3))
|
|
Array([[0, 0, 0],
|
|
[0, 0, 0]], dtype=int32)
|
|
"""
|
|
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
|
|
util.check_arraylike("zeros_like", a)
|
|
dtypes.check_user_dtype_supported(dtype, "zeros_like")
|
|
if shape is not None:
|
|
shape = canonicalize_shape(shape)
|
|
return lax.full_like(a, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device))
|
|
|
|
|
|
@export
|
|
def ones_like(a: ArrayLike | DuckTypedArray,
|
|
dtype: DTypeLike | None = None,
|
|
shape: Any = None, *,
|
|
device: xc.Device | Sharding | None = None) -> Array:
|
|
"""Create an array of ones with the same shape and dtype as an array.
|
|
|
|
JAX implementation of :func:`numpy.ones_like`.
|
|
|
|
Args:
|
|
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
|
shape: optionally override the shape of the created array.
|
|
dtype: optionally override the dtype of the created array.
|
|
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
to which the created array will be committed.
|
|
|
|
Returns:
|
|
Array of the specified shape and dtype, on the specified device if specified.
|
|
|
|
See also:
|
|
- :func:`jax.numpy.empty`
|
|
- :func:`jax.numpy.zeros_like`
|
|
- :func:`jax.numpy.ones_like`
|
|
- :func:`jax.numpy.full_like`
|
|
|
|
Examples:
|
|
>>> x = jnp.arange(4)
|
|
>>> jnp.ones_like(x)
|
|
Array([1, 1, 1, 1], dtype=int32)
|
|
>>> jnp.ones_like(x, dtype=bool)
|
|
Array([ True, True, True, True], dtype=bool)
|
|
>>> jnp.ones_like(x, shape=(2, 3))
|
|
Array([[1, 1, 1],
|
|
[1, 1, 1]], dtype=int32)
|
|
"""
|
|
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
|
|
util.check_arraylike("ones_like", a)
|
|
dtypes.check_user_dtype_supported(dtype, "ones_like")
|
|
if shape is not None:
|
|
shape = canonicalize_shape(shape)
|
|
return lax.full_like(a, 1, dtype, shape, sharding=util.normalize_device_to_sharding(device))
|
|
|
|
|
|
@export
|
|
def empty_like(prototype: ArrayLike | DuckTypedArray,
|
|
dtype: DTypeLike | None = None,
|
|
shape: Any = None, *,
|
|
device: xc.Device | Sharding | None = None) -> Array:
|
|
"""Create an empty array with the same shape and dtype as an array.
|
|
|
|
JAX implementation of :func:`numpy.empty_like`. Because XLA cannot create
|
|
an un-initialized array, :func:`jax.numpy.empty` will always return an
|
|
array full of zeros.
|
|
|
|
Args:
|
|
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
|
shape: optionally override the shape of the created array.
|
|
dtype: optionally override the dtype of the created array.
|
|
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
to which the created array will be committed.
|
|
|
|
Returns:
|
|
Array of the specified shape and dtype, on the specified device if specified.
|
|
|
|
See also:
|
|
- :func:`jax.numpy.empty`
|
|
- :func:`jax.numpy.zeros_like`
|
|
- :func:`jax.numpy.ones_like`
|
|
- :func:`jax.numpy.full_like`
|
|
|
|
Examples:
|
|
>>> x = jnp.arange(4)
|
|
>>> jnp.empty_like(x)
|
|
Array([0, 0, 0, 0], dtype=int32)
|
|
>>> jnp.empty_like(x, dtype=bool)
|
|
Array([False, False, False, False], dtype=bool)
|
|
>>> jnp.empty_like(x, shape=(2, 3))
|
|
Array([[0, 0, 0],
|
|
[0, 0, 0]], dtype=int32)
|
|
"""
|
|
if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing
|
|
util.check_arraylike("empty_like", prototype)
|
|
dtypes.check_user_dtype_supported(dtype, "empty_like")
|
|
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
|
|
|
|
|
|
@export
|
|
def full_like(a: ArrayLike | DuckTypedArray,
|
|
fill_value: ArrayLike, dtype: DTypeLike | None = None,
|
|
shape: Any = None, *,
|
|
device: xc.Device | Sharding | None = None) -> Array:
|
|
"""Create an array full of a specified value with the same shape and dtype as an array.
|
|
|
|
JAX implementation of :func:`numpy.full_like`.
|
|
|
|
Args:
|
|
a: Array-like object with ``shape`` and ``dtype`` attributes.
|
|
fill_value: scalar or array with which to fill the created array.
|
|
shape: optionally override the shape of the created array.
|
|
dtype: optionally override the dtype of the created array.
|
|
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
to which the created array will be committed.
|
|
|
|
Returns:
|
|
Array of the specified shape and dtype, on the specified device if specified.
|
|
|
|
See also:
|
|
- :func:`jax.numpy.full`
|
|
- :func:`jax.numpy.empty_like`
|
|
- :func:`jax.numpy.zeros_like`
|
|
- :func:`jax.numpy.ones_like`
|
|
|
|
Examples:
|
|
>>> x = jnp.arange(4.0)
|
|
>>> jnp.full_like(x, 2)
|
|
Array([2., 2., 2., 2.], dtype=float32)
|
|
>>> jnp.full_like(x, 0, shape=(2, 3))
|
|
Array([[0., 0., 0.],
|
|
[0., 0., 0.]], dtype=float32)
|
|
|
|
`fill_value` may also be an array that is broadcast to the specified shape:
|
|
|
|
>>> x = jnp.arange(6).reshape(2, 3)
|
|
>>> jnp.full_like(x, fill_value=jnp.array([[1], [2]]))
|
|
Array([[1, 1, 1],
|
|
[2, 2, 2]], dtype=int32)
|
|
"""
|
|
if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing
|
|
util.check_arraylike("full_like", 0, fill_value)
|
|
else:
|
|
util.check_arraylike("full_like", a, fill_value)
|
|
dtypes.check_user_dtype_supported(dtype, "full_like")
|
|
if shape is not None:
|
|
shape = canonicalize_shape(shape)
|
|
if np.ndim(fill_value) == 0:
|
|
return lax.full_like(a, fill_value, dtype, shape, sharding=util.normalize_device_to_sharding(device))
|
|
else:
|
|
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
|
|
dtype = dtypes.result_type(a) if dtype is None else dtype
|
|
return jax.device_put(
|
|
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
|