rocm_jax/jax/_src/basearray.py
Peter Hawkins 474dcd409d Remove code to support jaxlib < v0.6.
New minimum jaxlib_extension_version is 330.

PiperOrigin-RevId: 748853497
2025-04-17 16:44:41 -07:00

205 lines
6.7 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note that type annotations for this file are defined in basearray.pyi
from __future__ import annotations
from collections.abc import Sequence
import sys
from typing import Any, Union
from jax._src.lib import xla_client as xc
from jax._src.util import use_cpp_class
import numpy as np
# TODO(jakevdp): fix import cycles and define these.
# Until the real types can be imported here, all three names are plain
# aliases of ``typing.Any`` and only serve as self-describing annotations.
Device = Shard = Sharding = Any
# Array is a type annotation for standard JAX arrays and tracers produced by
# core functions in jax.lax and jax.numpy; it is not meant to include
# future non-standard array types like KeyArray and BInt.
class Array:
  """Array base class for JAX.

  ``jax.Array`` is the public interface for instance checks and type annotation
  of JAX arrays and tracers. Its main applications are in instance checks and
  type annotations; for example::

    x = jnp.arange(5)
    isinstance(x, jax.Array)  # returns True both inside and outside traced functions.

    def f(x: Array) -> Array:  # type annotations are valid for traced and non-traced types.
      return x

  ``jax.Array`` should not be used directly for creation of arrays; instead you
  should use array creation routines offered in :mod:`jax.numpy`, such as
  :func:`jax.numpy.array`, :func:`jax.numpy.zeros`, :func:`jax.numpy.ones`,
  :func:`jax.numpy.full`, :func:`jax.numpy.arange`, etc.
  """

  # For the sake of static type analysis, these definitions are mirrored in the
  # associated basearray.pyi file.

  # Declaring only a ``__weakref__`` slot means instances get no per-instance
  # ``__dict__`` but can still be weakly referenced.
  __slots__ = ['__weakref__']

  # Setting ``__hash__`` to None makes instances unhashable.
  __hash__ = None

  @property
  def dtype(self) -> np.dtype:
    """The data type (:class:`numpy.dtype`) of the array."""
    raise NotImplementedError

  @property
  def ndim(self) -> int:
    """The number of dimensions in the array."""
    raise NotImplementedError

  @property
  def size(self) -> int:
    """The total number of elements in the array."""
    raise NotImplementedError

  @property
  def shape(self) -> tuple[int, ...]:
    """The shape of the array."""
    raise NotImplementedError

  # Documentation for sharding-related methods and properties defined on ArrayImpl:

  def addressable_data(self, index: int) -> Array:
    """Return an array of the addressable data at a particular index."""
    raise NotImplementedError

  @property
  def addressable_shards(self) -> Sequence[Shard]:
    """List of addressable shards."""
    raise NotImplementedError

  @property
  def global_shards(self) -> Sequence[Shard]:
    """List of global shards."""
    raise NotImplementedError

  @property
  def is_fully_addressable(self) -> bool:
    """Is this Array fully addressable?

    A jax.Array is fully addressable if the current process can address all of
    the devices named in the :class:`Sharding`. ``is_fully_addressable`` is
    equivalent to "is_local" in multi-process JAX.

    Note that being fully replicated is not the same as being fully
    addressable: a fully replicated jax.Array can span multiple hosts, in which
    case it is not fully addressable.
    """
    raise NotImplementedError

  @property
  def is_fully_replicated(self) -> bool:
    """Is this Array fully replicated?"""
    raise NotImplementedError

  @property
  def sharding(self) -> Sharding:
    """The sharding for the array."""
    raise NotImplementedError

  @property
  def committed(self) -> bool:
    """Whether the array is committed or not.

    An array is committed when it is explicitly placed on device(s) via JAX
    APIs. For example, ``jax.device_put(np.arange(8), jax.devices()[0])`` is
    committed to device 0, while ``jax.device_put(np.arange(8))`` is
    uncommitted and will be placed on the default device.

    Computations involving some committed inputs will happen on the committed
    device(s) and the result will be committed on the same device(s).
    Invoking an operation on arguments that are committed to different device(s)
    will raise an error.

    For example::

      a = jax.device_put(np.arange(8), jax.devices()[0])
      b = jax.device_put(np.arange(8), jax.devices()[1])
      a + b  # Raises an error

    See https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
    for more information.
    """
    raise NotImplementedError

  @property
  def device(self) -> Device | Sharding:
    """Array API-compatible device attribute.

    For single-device arrays, this returns a Device. For sharded arrays, this
    returns a Sharding.
    """
    raise NotImplementedError

  def copy_to_host_async(self):
    """Copies an ``Array`` to the host asynchronously.

    For arrays that live on an accelerator, such as a GPU or a TPU, JAX may
    cache the value of the array on the host. Normally this happens
    behind the scenes when the value of an on-device array is requested by the
    user, but waiting to initiate a device-to-host copy until the value is
    requested requires that JAX block the caller while waiting for the copy to
    complete.

    ``copy_to_host_async`` requests that JAX populate its on-host cache of an
    array, but does not wait for the copy to complete. This may speed up a
    future on-host access to the array's contents.
    """
    raise NotImplementedError
# Rebind ``Array`` to the result of applying ``use_cpp_class(xc.Array)`` to the
# pure-Python class defined above.
# NOTE(review): ``use_cpp_class`` lives in jax._src.util; presumably it returns
# the C++ ``xc.Array`` type with the Python-level members/docstrings attached,
# so ``isinstance`` checks target the C++ type -- confirm against that helper.
Array = use_cpp_class(xc.Array)(Array)
# Present the (rebound) class under the public ``jax`` namespace, matching the
# ``jax.Array`` name used in its docstring. This must run after the rebinding
# so it applies to the returned class.
Array.__module__ = "jax"
# StaticScalar is the Union of all scalar types that can be converted to
# JAX arrays, and are possible to mark as static arguments.
StaticScalar = Union[
  # NumPy scalar types
  np.bool_,
  np.number,
  # Python scalar types
  bool,
  int,
  float,
  complex,
]

# Attaching a docstring only works before Python 3.14; there the assignment
# fails with:
#   AttributeError: 'typing.Union' object attribute '__doc__' is read-only
if sys.version_info < (3, 14):
  StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars."
# ArrayLike is a Union of all objects that can be implicitly converted to a
# standard JAX array (i.e. not including future non-standard array types like
# KeyArray and BInt). Unlike np.typing.ArrayLike, it accepts neither arbitrary
# sequences nor string data.
#
# Members, in order: the JAX array type, the NumPy array type, and the valid
# scalars described by StaticScalar.
ArrayLike = Union[Array, np.ndarray, StaticScalar]

# Attaching a docstring only works before Python 3.14; there the assignment
# fails with:
#   AttributeError: 'typing.Union' object attribute '__doc__' is read-only
if sys.version_info < (3, 14):
  ArrayLike.__doc__ = "Type annotation for JAX array-like objects."