diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 953381781..d8cdeecea 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -214,6 +214,28 @@ _default_types: dict[str, type[Any]] = { 'c': complex_, } + +def jax_dtype(obj: DTypeLike | None, *, align: bool = False, + copy: bool = False) -> DType: + """Cast an object to a dtype, respecting JAX dtype defaults. + + Arguments mirror those of :func:`numpy.dtype`. + """ + if obj is None: + obj = float_ + elif issubdtype(obj, extended): + return obj # type: ignore[return-value] + elif isinstance(obj, type): + obj = _DEFAULT_TYPEMAP.get(obj, obj) + return np.dtype(obj, align=align, copy=copy) + +_DEFAULT_TYPEMAP: dict[type, DTypeLike] = { + bool: bool, + int: int_, + float: float_, + complex: complex_, +} + def bit_width(dtype: DTypeLike) -> int: """Number of bits per element for the dtype.""" # Note: we cannot use dtype.itemsize here because this is diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py new file mode 100644 index 000000000..67418e732 --- /dev/null +++ b/jax/_src/numpy/array_creation.py @@ -0,0 +1,394 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types +from typing import Any + +import numpy as np + +import jax +from jax import lax +from jax._src import core +from jax._src import dtypes +from jax._src.lib import xla_client as xc +from jax._src.numpy import util +from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike +from jax._src.util import set_module +from jax.sharding import Sharding + + +export = set_module('jax.numpy') + + +# Like core.canonicalize_shape, but also accept int-like (non-sequence) +# arguments for `shape`. +def canonicalize_shape(shape: Any, context: str="") -> core.Shape: + if (not isinstance(shape, (tuple, list)) and + (getattr(shape, 'ndim', None) == 0 or np.ndim(shape) == 0)): + return core.canonicalize_shape((shape,), context) + else: + return core.canonicalize_shape(shape, context) + + +@export +def zeros(shape: Any, dtype: DTypeLike | None = None, *, + device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of zeros. + + JAX implementation of :func:`numpy.zeros`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.ones` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.zeros(4) + Array([0., 0., 0., 0.], dtype=float32) + >>> jnp.zeros((2, 3), dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + """ + if isinstance(shape, types.GeneratorType): + raise TypeError("expected sequence object with len >= 0 or a single integer") + if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m) + dtypes.check_user_dtype_supported(dtype, "zeros") + shape = canonicalize_shape(shape) + return lax.full(shape, 0, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device)) + + +@export +def ones(shape: Any, dtype: DTypeLike | None = None, *, + device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of ones. + + JAX implementation of :func:`numpy.ones`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.ones(4) + Array([1., 1., 1., 1.], dtype=float32) + >>> jnp.ones((2, 3), dtype=bool) + Array([[ True, True, True], + [ True, True, True]], dtype=bool) + """ + if isinstance(shape, types.GeneratorType): + raise TypeError("expected sequence object with len >= 0 or a single integer") + if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m) + shape = canonicalize_shape(shape) + dtypes.check_user_dtype_supported(dtype, "ones") + return lax.full(shape, 1, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device)) + + +@export +def empty(shape: Any, dtype: DTypeLike | None = None, *, + device: xc.Device | Sharding | None = None) -> Array: + """Create an empty array. + + JAX implementation of :func:`numpy.empty`. Because XLA cannot create an + un-initialized array, :func:`jax.numpy.empty` will always return an array + full of zeros. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.ones` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.empty(4) + Array([0., 0., 0., 0.], dtype=float32) + >>> jnp.empty((2, 3), dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + """ + if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m) + dtypes.check_user_dtype_supported(dtype, "empty") + return zeros(shape, dtype, device=device) + + +def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore + if isinstance(dtype, int) and isinstance(shape, int): + return (f"Cannot interpret '{dtype}' as a data type." + f"\n\nDid you accidentally write " + f"`jax.numpy.{name}({shape}, {dtype})` " + f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. " + "with a single tuple argument for the shape?") + +@export +def full(shape: Any, fill_value: ArrayLike, + dtype: DTypeLike | None = None, *, + device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of a specified value. + + JAX implementation of :func:`numpy.full`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + fill_value: scalar or array with which to fill the created array. + dtype: optional dtype for the created array; defaults to the dtype of the + fill value. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.full_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.ones` + + Examples: + >>> jnp.full(4, 2, dtype=float) + Array([2., 2., 2., 2.], dtype=float32) + >>> jnp.full((2, 3), 0, dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + + `fill_value` may also be an array that is broadcast to the specified shape: + + >>> jnp.full((2, 3), fill_value=jnp.arange(3)) + Array([[0, 1, 2], + [0, 1, 2]], dtype=int32) + """ + dtypes.check_user_dtype_supported(dtype, "full") + util.check_arraylike("full", fill_value) + + if np.ndim(fill_value) == 0: + shape = canonicalize_shape(shape) + return lax.full(shape, fill_value, dtype, sharding=util.normalize_device_to_sharding(device)) + else: + return jax.device_put( + util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) + + +@export +def zeros_like(a: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = None, + shape: Any = None, *, + device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of zeros with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.zeros_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.zeros_like(x) + Array([0, 0, 0, 0], dtype=int32) + >>> jnp.zeros_like(x, dtype=bool) + Array([False, False, False, False], dtype=bool) + >>> jnp.zeros_like(x, shape=(2, 3)) + Array([[0, 0, 0], + [0, 0, 0]], dtype=int32) + """ + if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + util.check_arraylike("zeros_like", a) + dtypes.check_user_dtype_supported(dtype, "zeros_like") + if shape is not None: + shape = canonicalize_shape(shape) + return lax.full_like(a, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device)) + + +@export +def ones_like(a: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = None, + shape: Any = None, *, + device: xc.Device | Sharding | None = None) -> Array: + """Create an array of ones with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.ones_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.ones_like(x) + Array([1, 1, 1, 1], dtype=int32) + >>> jnp.ones_like(x, dtype=bool) + Array([ True, True, True, True], dtype=bool) + >>> jnp.ones_like(x, shape=(2, 3)) + Array([[1, 1, 1], + [1, 1, 1]], dtype=int32) + """ + if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + util.check_arraylike("ones_like", a) + dtypes.check_user_dtype_supported(dtype, "ones_like") + if shape is not None: + shape = canonicalize_shape(shape) + return lax.full_like(a, 1, dtype, shape, sharding=util.normalize_device_to_sharding(device)) + + +@export +def empty_like(prototype: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = None, + shape: Any = None, *, + device: xc.Device | Sharding | None = None) -> Array: + """Create an empty array with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.empty_like`. Because XLA cannot create + an un-initialized array, :func:`jax.numpy.empty` will always return an + array full of zeros. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.empty_like(x) + Array([0, 0, 0, 0], dtype=int32) + >>> jnp.empty_like(x, dtype=bool) + Array([False, False, False, False], dtype=bool) + >>> jnp.empty_like(x, shape=(2, 3)) + Array([[0, 0, 0], + [0, 0, 0]], dtype=int32) + """ + if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing + util.check_arraylike("empty_like", prototype) + dtypes.check_user_dtype_supported(dtype, "empty_like") + return zeros_like(prototype, dtype=dtype, shape=shape, device=device) + + +@export +def full_like(a: ArrayLike | DuckTypedArray, + fill_value: ArrayLike, dtype: DTypeLike | None = None, + shape: Any = None, *, + device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of a specified value with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.full_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + fill_value: scalar or array with which to fill the created array. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.full` + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + + Examples: + >>> x = jnp.arange(4.0) + >>> jnp.full_like(x, 2) + Array([2., 2., 2., 2.], dtype=float32) + >>> jnp.full_like(x, 0, shape=(2, 3)) + Array([[0., 0., 0.], + [0., 0., 0.]], dtype=float32) + + `fill_value` may also be an array that is broadcast to the specified shape: + + >>> x = jnp.arange(6).reshape(2, 3) + >>> jnp.full_like(x, fill_value=jnp.array([[1], [2]])) + Array([[1, 1, 1], + [2, 2, 2]], dtype=int32) + """ + if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing + util.check_arraylike("full_like", 0, fill_value) + else: + util.check_arraylike("full_like", a, fill_value) + dtypes.check_user_dtype_supported(dtype, "full_like") + if shape is not None: + shape = canonicalize_shape(shape) + if np.ndim(fill_value) == 0: + return lax.full_like(a, fill_value, dtype, shape, sharding=util.normalize_device_to_sharding(device)) + else: + shape = np.shape(a) if shape is None else shape # type: ignore[arg-type] + dtype = dtypes.result_type(a) if dtype is None else dtype + return jax.device_put( + util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 50a4b88ff..838f65e8d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -34,7 +34,6 @@ import math import operator import os import string -import types from typing import (Any, IO, Literal, NamedTuple, Protocol, TypeVar, Union, overload) import warnings @@ -58,14 +57,15 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version +from jax._src.numpy.array_creation import (empty, empty_like, full, full_like, + ones, ones_like, zeros, zeros_like) from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.numpy import util from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize from jax._src.typing import ( - Array, ArrayLike, - DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar, + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, StaticScalar, ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, @@ -94,16 +94,6 @@ newaxis = None T = TypeVar('T') -# Like core.canonicalize_shape, but also accept int-like (non-sequence) -# arguments for `shape`. -def canonicalize_shape(shape: Any, context: str="") -> core.Shape: - if (not isinstance(shape, (tuple, list)) and - (getattr(shape, 'ndim', None) == 0 or ndim(shape) == 0)): - return core.canonicalize_shape((shape,), context) - else: - return core.canonicalize_shape(shape, context) - - # NumPy constants pi = np.pi @@ -182,27 +172,6 @@ array_repr = np.array_repr save = np.save savez = np.savez - -def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False, - copy: bool = False) -> DType: - """Similar to np.dtype, but respects JAX dtype defaults.""" - if dtypes.issubdtype(obj, dtypes.extended): - return obj # type: ignore[return-value] - if obj is None: - obj = dtypes.float_ - elif isinstance(obj, type) and obj in dtypes.python_scalar_dtypes: - obj = _DEFAULT_TYPEMAP[obj] - return np.dtype(obj, align=align, copy=copy) - -### utility functions - -_DEFAULT_TYPEMAP: dict[type, np.dtype] = { - bool: np.dtype(bool), - int: np.dtype(dtypes.int_), - float: np.dtype(dtypes.float_), - complex: np.dtype(dtypes.complex_), -} - _lax_const = lax_internal._const @@ -5679,7 +5648,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, # We offer a more specific warning than the usual ComplexWarning so we prefer # to issue our warning. result = lax_internal._convert_element_type( - x_arr, dtype, sharding=_normalize_to_sharding(device), + x_arr, dtype, sharding=util.normalize_device_to_sharding(device), warn_on_complex_to_real_cast=False) return _array_copy(result) if copy else result @@ -5815,366 +5784,6 @@ def copy(a: ArrayLike, order: str | None = None) -> Array: return array(a, copy=True, order=order) -@export -def zeros_like(a: ArrayLike | DuckTypedArray, - dtype: DTypeLike | None = None, - shape: Any = None, *, - device: xc.Device | Sharding | None = None) -> Array: - """Create an array full of zeros with the same shape and dtype as an array. - - JAX implementation of :func:`numpy.zeros_like`. - - Args: - a: Array-like object with ``shape`` and ``dtype`` attributes. - shape: optionally override the shape of the created array. - dtype: optionally override the dtype of the created array. - device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - Array of the specified shape and dtype, on the specified device if specified. - - See also: - - :func:`jax.numpy.zeros` - - :func:`jax.numpy.empty_like` - - :func:`jax.numpy.ones_like` - - :func:`jax.numpy.full_like` - - Examples: - >>> x = jnp.arange(4) - >>> jnp.zeros_like(x) - Array([0, 0, 0, 0], dtype=int32) - >>> jnp.zeros_like(x, dtype=bool) - Array([False, False, False, False], dtype=bool) - >>> jnp.zeros_like(x, shape=(2, 3)) - Array([[0, 0, 0], - [0, 0, 0]], dtype=int32) - """ - if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing - util.check_arraylike("zeros_like", a) - dtypes.check_user_dtype_supported(dtype, "zeros_like") - if shape is not None: - shape = canonicalize_shape(shape) - return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device)) - - -@export -def ones_like(a: ArrayLike | DuckTypedArray, - dtype: DTypeLike | None = None, - shape: Any = None, *, - device: xc.Device | Sharding | None = None) -> Array: - """Create an array of ones with the same shape and dtype as an array. - - JAX implementation of :func:`numpy.ones_like`. - - Args: - a: Array-like object with ``shape`` and ``dtype`` attributes. - shape: optionally override the shape of the created array. - dtype: optionally override the dtype of the created array. - device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - Array of the specified shape and dtype, on the specified device if specified. - - See also: - - :func:`jax.numpy.empty` - - :func:`jax.numpy.zeros_like` - - :func:`jax.numpy.ones_like` - - :func:`jax.numpy.full_like` - - Examples: - >>> x = jnp.arange(4) - >>> jnp.ones_like(x) - Array([1, 1, 1, 1], dtype=int32) - >>> jnp.ones_like(x, dtype=bool) - Array([ True, True, True, True], dtype=bool) - >>> jnp.ones_like(x, shape=(2, 3)) - Array([[1, 1, 1], - [1, 1, 1]], dtype=int32) - """ - if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing - util.check_arraylike("ones_like", a) - dtypes.check_user_dtype_supported(dtype, "ones_like") - if shape is not None: - shape = canonicalize_shape(shape) - return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device)) - - -@export -def empty_like(prototype: ArrayLike | DuckTypedArray, - dtype: DTypeLike | None = None, - shape: Any = None, *, - device: xc.Device | Sharding | None = None) -> Array: - """Create an empty array with the same shape and dtype as an array. - - JAX implementation of :func:`numpy.empty_like`. Because XLA cannot create - an un-initialized array, :func:`jax.numpy.empty` will always return an - array full of zeros. - - Args: - a: Array-like object with ``shape`` and ``dtype`` attributes. - shape: optionally override the shape of the created array. - dtype: optionally override the dtype of the created array. - device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - Array of the specified shape and dtype, on the specified device if specified. - - See also: - - :func:`jax.numpy.empty` - - :func:`jax.numpy.zeros_like` - - :func:`jax.numpy.ones_like` - - :func:`jax.numpy.full_like` - - Examples: - >>> x = jnp.arange(4) - >>> jnp.empty_like(x) - Array([0, 0, 0, 0], dtype=int32) - >>> jnp.empty_like(x, dtype=bool) - Array([False, False, False, False], dtype=bool) - >>> jnp.empty_like(x, shape=(2, 3)) - Array([[0, 0, 0], - [0, 0, 0]], dtype=int32) - """ - if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing - util.check_arraylike("empty_like", prototype) - dtypes.check_user_dtype_supported(dtype, "empty_like") - return zeros_like(prototype, dtype=dtype, shape=shape, device=device) - - -def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None: - if isinstance(device, xc.Device): - return SingleDeviceSharding(device) - else: - return device - - -@export -def full(shape: Any, fill_value: ArrayLike, - dtype: DTypeLike | None = None, *, - device: xc.Device | Sharding | None = None) -> Array: - """Create an array full of a specified value. - - JAX implementation of :func:`numpy.full`. - - Args: - shape: int or sequence of ints specifying the shape of the created array. - fill_value: scalar or array with which to fill the created array. - dtype: optional dtype for the created array; defaults to the dtype of the - fill value. - device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - Array of the specified shape and dtype, on the specified device if specified. - - See also: - - :func:`jax.numpy.full_like` - - :func:`jax.numpy.empty` - - :func:`jax.numpy.zeros` - - :func:`jax.numpy.ones` - - Examples: - >>> jnp.full(4, 2, dtype=float) - Array([2., 2., 2., 2.], dtype=float32) - >>> jnp.full((2, 3), 0, dtype=bool) - Array([[False, False, False], - [False, False, False]], dtype=bool) - - `fill_value` may also be an array that is broadcast to the specified shape: - - >>> jnp.full((2, 3), fill_value=jnp.arange(3)) - Array([[0, 1, 2], - [0, 1, 2]], dtype=int32) - """ - dtypes.check_user_dtype_supported(dtype, "full") - util.check_arraylike("full", fill_value) - - if ndim(fill_value) == 0: - shape = canonicalize_shape(shape) - return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device)) - else: - return jax.device_put( - broadcast_to(asarray(fill_value, dtype=dtype), shape), device) - - -@export -def full_like(a: ArrayLike | DuckTypedArray, - fill_value: ArrayLike, dtype: DTypeLike | None = None, - shape: Any = None, *, - device: xc.Device | Sharding | None = None) -> Array: - """Create an array full of a specified value with the same shape and dtype as an array. - - JAX implementation of :func:`numpy.full_like`. - - Args: - a: Array-like object with ``shape`` and ``dtype`` attributes. - fill_value: scalar or array with which to fill the created array. - shape: optionally override the shape of the created array. - dtype: optionally override the dtype of the created array. - device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - Array of the specified shape and dtype, on the specified device if specified. - - See also: - - :func:`jax.numpy.full` - - :func:`jax.numpy.empty_like` - - :func:`jax.numpy.zeros_like` - - :func:`jax.numpy.ones_like` - - Examples: - >>> x = jnp.arange(4.0) - >>> jnp.full_like(x, 2) - Array([2., 2., 2., 2.], dtype=float32) - >>> jnp.full_like(x, 0, shape=(2, 3)) - Array([[0., 0., 0.], - [0., 0., 0.]], dtype=float32) - - `fill_value` may also be an array that is broadcast to the specified shape: - - >>> x = jnp.arange(6).reshape(2, 3) - >>> jnp.full_like(x, fill_value=jnp.array([[1], [2]])) - Array([[1, 1, 1], - [2, 2, 2]], dtype=int32) - """ - if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing - util.check_arraylike("full_like", 0, fill_value) - else: - util.check_arraylike("full_like", a, fill_value) - dtypes.check_user_dtype_supported(dtype, "full_like") - if shape is not None: - shape = canonicalize_shape(shape) - if ndim(fill_value) == 0: - return lax.full_like(a, fill_value, dtype, shape, sharding=_normalize_to_sharding(device)) - else: - shape = np.shape(a) if shape is None else shape # type: ignore[arg-type] - dtype = result_type(a) if dtype is None else dtype - return jax.device_put( - broadcast_to(asarray(fill_value, dtype=dtype), shape), device) - - -@export -def zeros(shape: Any, dtype: DTypeLike | None = None, *, - device: xc.Device | Sharding | None = None) -> Array: - """Create an array full of zeros. - - JAX implementation of :func:`numpy.zeros`. - - Args: - shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. - device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - Array of the specified shape and dtype, on the specified device if specified. - - See also: - - :func:`jax.numpy.zeros_like` - - :func:`jax.numpy.empty` - - :func:`jax.numpy.ones` - - :func:`jax.numpy.full` - - Examples: - >>> jnp.zeros(4) - Array([0., 0., 0., 0.], dtype=float32) - >>> jnp.zeros((2, 3), dtype=bool) - Array([[False, False, False], - [False, False, False]], dtype=bool) - """ - if isinstance(shape, types.GeneratorType): - raise TypeError("expected sequence object with len >= 0 or a single integer") - if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m) - dtypes.check_user_dtype_supported(dtype, "zeros") - shape = canonicalize_shape(shape) - return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) - - -@export -def ones(shape: Any, dtype: DTypeLike | None = None, *, - device: xc.Device | Sharding | None = None) -> Array: - """Create an array full of ones. - - JAX implementation of :func:`numpy.ones`. - - Args: - shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. - device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - Array of the specified shape and dtype, on the specified device if specified. - - See also: - - :func:`jax.numpy.ones_like` - - :func:`jax.numpy.empty` - - :func:`jax.numpy.zeros` - - :func:`jax.numpy.full` - - Examples: - >>> jnp.ones(4) - Array([1., 1., 1., 1.], dtype=float32) - >>> jnp.ones((2, 3), dtype=bool) - Array([[ True, True, True], - [ True, True, True]], dtype=bool) - """ - if isinstance(shape, types.GeneratorType): - raise TypeError("expected sequence object with len >= 0 or a single integer") - if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m) - shape = canonicalize_shape(shape) - dtypes.check_user_dtype_supported(dtype, "ones") - return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) - - -@export -def empty(shape: Any, dtype: DTypeLike | None = None, *, - device: xc.Device | Sharding | None = None) -> Array: - """Create an empty array. - - JAX implementation of :func:`numpy.empty`. Because XLA cannot create an - un-initialized array, :func:`jax.numpy.empty` will always return an array - full of zeros. - - Args: - shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. - device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - Array of the specified shape and dtype, on the specified device if specified. - - See also: - - :func:`jax.numpy.empty_like` - - :func:`jax.numpy.zeros` - - :func:`jax.numpy.ones` - - :func:`jax.numpy.full` - - Examples: - >>> jnp.empty(4) - Array([0., 0., 0., 0.], dtype=float32) - >>> jnp.empty((2, 3), dtype=bool) - Array([[False, False, False], - [False, False, False]], dtype=bool) - """ - if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m) - dtypes.check_user_dtype_supported(dtype, "empty") - return zeros(shape, dtype, device=device) - -def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore - if isinstance(dtype, int) and isinstance(shape, int): - return (f"Cannot interpret '{dtype}' as a data type." - f"\n\nDid you accidentally write " - f"`jax.numpy.{name}({shape}, {dtype})` " - f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. " - "with a single tuple argument for the shape?") - - @export def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: """Check if two arrays are element-wise equal. @@ -6741,7 +6350,7 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, return _arange_dynamic(start, stop, step, dtype or dtypes.canonicalize_dtype(np.int64)) if dtype is None: dtype = result_type(start, *(x for x in [stop, step] if x is not None)) - dtype = _jnp_dtype(dtype) + dtype = dtypes.jax_dtype(dtype) if stop is None and step is None: start_dtype = _dtype(start) if (not dtypes.issubdtype(start_dtype, np.integer) and @@ -6884,7 +6493,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, if dtype is None: dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = _jnp_dtype(dtype) + dtype = dtypes.jax_dtype(dtype) computation_dtype = dtypes.to_inexact_dtype(dtype) start = start.astype(computation_dtype) stop = stop.astype(computation_dtype) @@ -7004,7 +6613,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtypes.check_user_dtype_supported(dtype, "logspace") if dtype is None: dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = _jnp_dtype(dtype) + dtype = dtypes.jax_dtype(dtype) computation_dtype = dtypes.to_inexact_dtype(dtype) start, stop = util.ensure_arraylike("logspace", start, stop) start = start.astype(computation_dtype) @@ -7074,7 +6683,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool dtypes.check_user_dtype_supported(dtype, "geomspace") if dtype is None: dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = _jnp_dtype(dtype) + dtype = dtypes.jax_dtype(dtype) computation_dtype = dtypes.to_inexact_dtype(dtype) start, stop = util.ensure_arraylike("geomspace", start, stop) start = start.astype(computation_dtype) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index a6edadf92..dd9e38a0a 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -24,8 +24,11 @@ from jax._src import config from jax._src import core from jax._src import dtypes from jax._src.lax import lax +from jax._src.lib import xla_client as xc +from jax._src.sharding_impls import SingleDeviceSharding from jax._src.util import safe_zip, safe_map from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape +from jax.sharding import Sharding import numpy as np @@ -299,3 +302,10 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array: except: is_always_empty = False # can fail with dynamic shapes return lax.select(condition, x_arr, y_arr) if not is_always_empty else x_arr + + +def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None: + if isinstance(device, xc.Device): + return SingleDeviceSharding(device) + else: + return device diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 363f1c3bc..41badcacc 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -83,8 +83,6 @@ from jax._src.numpy.lax_numpy import ( ediff1d as ediff1d, einsum as einsum, einsum_path as einsum_path, - empty as empty, - empty_like as empty_like, euler_gamma as euler_gamma, expand_dims as expand_dims, extract as extract, @@ -104,8 +102,6 @@ from jax._src.numpy.lax_numpy import ( fromiter as fromiter, fromstring as fromstring, from_dlpack as from_dlpack, - full as full, - full_like as full_like, gcd as gcd, geomspace as geomspace, get_printoptions as get_printoptions, @@ -154,8 +150,6 @@ from jax._src.numpy.lax_numpy import ( ndim as ndim, newaxis as newaxis, nonzero as nonzero, - ones as ones, - ones_like as ones_like, outer as outer, packbits as packbits, pad as pad, @@ -215,6 +209,15 @@ from jax._src.numpy.lax_numpy import ( vsplit as vsplit, vstack as vstack, where as where, +) + +from jax._src.numpy.array_creation import ( + empty as empty, + empty_like as empty_like, + full as full, + full_like as full_like, + ones as ones, + ones_like as ones_like, zeros as zeros, zeros_like as zeros_like, )