From 634035abd7018d88f9d4db403a2ed257114e0919 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 15 Mar 2023 12:59:33 -0700 Subject: [PATCH] Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now PiperOrigin-RevId: 516905931 --- docs/jax.experimental.global_device_array.rst | 11 - docs/jax.experimental.rst | 1 - jax/BUILD | 11 - jax/_src/global_device_array.py | 650 ------------------ jax/_src/interpreters/pxla.py | 19 +- jax/_src/maps.py | 29 +- jax/_src/pjit.py | 110 +-- .../gda_serialization/serialization.py | 32 +- jax/experimental/global_device_array.py | 26 - jax/experimental/jax2tf/jax2tf.py | 14 +- jax/experimental/multihost_utils.py | 35 +- 11 files changed, 32 insertions(+), 906 deletions(-) delete mode 100644 docs/jax.experimental.global_device_array.rst delete mode 100644 jax/_src/global_device_array.py delete mode 100644 jax/experimental/global_device_array.py diff --git a/docs/jax.experimental.global_device_array.rst b/docs/jax.experimental.global_device_array.rst deleted file mode 100644 index 40fb1c95a..000000000 --- a/docs/jax.experimental.global_device_array.rst +++ /dev/null @@ -1,11 +0,0 @@ -``jax.experimental.global_device_array`` module -=============================================== - -.. automodule:: jax.experimental.global_device_array - -API ---- - -.. autoclass:: GlobalDeviceArray - :members: -.. autoclass:: Shard diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index 2cbb364f6..ffcc18ba6 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -15,7 +15,6 @@ Experimental Modules :maxdepth: 1 jax.experimental.checkify - jax.experimental.global_device_array jax.experimental.host_callback jax.experimental.maps jax.experimental.pjit diff --git a/jax/BUILD b/jax/BUILD index b42435876..0cd20ce9d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -107,7 +107,6 @@ py_library_providing_imports_info( "_src/dtypes.py", "_src/errors.py", "_src/flatten_util.py", - "_src/global_device_array.py", "_src/__init__.py", "_src/lax_reference.py", "_src/linear_util.py", @@ -381,9 +380,6 @@ py_library_providing_imports_info( "experimental/*.py", "example_libraries/*.py", ], - exclude = [ - "experimental/global_device_array.py", - ], ), visibility = ["//visibility:public"], deps = [ @@ -506,10 +502,3 @@ pytype_library( visibility = ["//visibility:public"], deps = [":jax"], ) - -pytype_library( - name = "global_device_array", - srcs = ["experimental/global_device_array.py"], - visibility = [":internal"], - deps = [":jax"], -) diff --git a/jax/_src/global_device_array.py b/jax/_src/global_device_array.py deleted file mode 100644 index a1c97fd45..000000000 --- a/jax/_src/global_device_array.py +++ /dev/null @@ -1,650 +0,0 @@ -# Copyright 2021 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from collections import Counter -import dataclasses -import functools -import math -import warnings -import numpy as np -from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple - -import jax -from jax._src import core -from jax._src import dispatch -from jax._src import api_util -from jax._src.lib import xla_client as xc -from jax._src.config import config -from jax._src.interpreters import pxla -from jax.interpreters import xla, mlir -from jax._src.util import safe_zip -from jax._src.interpreters.pxla import PartitionSpec - -Shape = Tuple[int, ...] -MeshAxes = PartitionSpec -DeviceArray = xc.Buffer -Device = xc.Device -ArrayLike = Union[np.ndarray, DeviceArray] -Index = Tuple[slice, ...] - - -_hashed_index = lambda x: hash(tuple((v.start, v.stop) for v in x)) - - -def _get_sharding_spec(global_shape, global_mesh, mesh_axes): - array_mapping = pxla.get_array_mapping(mesh_axes) - # The dtype doesn't matter for creating sharding specs. - aval = core.ShapedArray(global_shape, np.float32) - return pxla.mesh_sharding_specs(global_mesh.shape, - global_mesh.axis_names)(aval, array_mapping) - - -def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh, - mesh_axes: MeshAxes) -> Tuple[Index, ...]: - sharding_spec = _get_sharding_spec(global_shape, global_mesh, mesh_axes) - indices = pxla.spec_to_indices(global_shape, sharding_spec) - return indices # type: ignore - - -@functools.lru_cache(maxsize=4096) -def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh, - mesh_axes: MeshAxes) -> Mapping[Device, Index]: - indices = _get_indices(global_shape, global_mesh, mesh_axes) - # The type: ignore is to ignore the type returned by `spec_to_indices`. - return { - d: i - for d, i in safe_zip(global_mesh.devices.flat, indices)} # type: ignore - - -@functools.lru_cache(maxsize=4096) -def get_shard_indices_replica_ids( - global_shape: Shape, global_mesh: pxla.Mesh, - mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]: - return _get_shard_indices_replica_ids_uncached(global_shape, global_mesh, mesh_axes) - -def _get_shard_indices_replica_ids_uncached( - global_shape: Shape, global_mesh: pxla.Mesh, - mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]: - indices = _get_indices(global_shape, global_mesh, mesh_axes) - index_to_replica: Dict[int, int] = Counter() - out = {} - unique_shards = 0 - for device, index in safe_zip(global_mesh.devices.flat, indices): - h_index = _hashed_index(index) - replica_id = index_to_replica[h_index] - if replica_id == 0: - unique_shards += 1 - index_to_replica[h_index] += 1 - out[device] = (index, replica_id) - - shard_shape = get_shard_shape(global_shape, global_mesh, mesh_axes) - expected_unique_shards = math.prod( - [g // s for g, s in safe_zip(global_shape, shard_shape) if g != 0 or s != 0]) - if expected_unique_shards != unique_shards: - raise RuntimeError( - f'Number of expected unique shards are: {expected_unique_shards} but ' - f'got {unique_shards}. 
Please file a bug at ' - 'https://github.com/google/jax/issues.') - return out - - -@functools.lru_cache(maxsize=4096) -def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape: - chunk_size = [] - for mesh_axis, size in zip(mesh_axes, global_shape): - if not mesh_axis: - chunk_size.append(size) - elif isinstance(mesh_axis, tuple): - m = math.prod([global_mesh.shape[ma] for ma in mesh_axis]) - chunk_size.append(size // m) - else: - chunk_size.append(size // global_mesh.shape[mesh_axis]) - if len(chunk_size) != len(global_shape): - chunk_size.extend(global_shape[len(chunk_size):]) - return tuple(chunk_size) - - -@dataclasses.dataclass(frozen=True) -class Shard: - """A single data shard of a GlobalDeviceArray. - - Args: - device : Which device this shard resides on. - index : The index into the global array of this shard. - replica_id : Integer id indicating which replica of the global array this - shard is part of. Always 0 for fully sharded data - (i.e. when there’s only 1 replica). - data : The data of this shard. None if ``device`` is non-local. - """ - device: Device - index: Index - replica_id: int - # None if this `Shard` lives on a non-local device. - data: Optional[DeviceArray] = None - - -class _GdaFastPathArgs(NamedTuple): - global_indices_replica_ids: Mapping[Device, Tuple[Index, int]] - local_devices: Sequence[Device] - - -class GlobalDeviceArray: - """A logical array with data sharded across multiple devices and processes. - - If you’re not already familiar with JAX’s multi-process programming model, - please read https://jax.readthedocs.io/en/latest/multi_process.html. - You can also read about pjit (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html) - to learn about ``Mesh``, ``PartitionSpec`` and how arrays can be - partitioned or replicated. - - A GlobalDeviceArray (GDA) can be thought of as a view into a single logical - array sharded across processes. The logical array is the “global” array, and - each process has a GlobalDeviceArray object referring to the same global array - (similarly to how each process runs a multi-process pmap or pjit). Each process - can access the shape, dtype, etc. of the global array via the GDA, pass the - GDA into multi-process pjits, and get GDAs as pjit outputs (coming soon: xmap - and pmap). However, each process can only directly access the shards of the - global array data stored on its local devices. - - GDAs can help manage the inputs and outputs of multi-process computations. - A GDA keeps track of which shard of the global array belongs to which device, - and provides callback-based APIs to materialize the correct shard of the data - needed for each local device of each process. - - A GDA consists of data shards. Each shard is stored on a different device. - There are local shards and global shards. Local shards are those on local - devices, and the data is visible to the current process. Global shards are - those across all devices (including local devices), and the data isn’t visible - if the shard is on a non-local device with respect to the current process. - Please see the ``Shard`` class to see what information is stored inside that - data structure. - - Note: to make pjit output GlobalDeviceArrays, set the environment variable - ``JAX_PARALLEL_FUNCTIONS_OUTPUT_GDA=true`` or add the following to your code: - ``jax.config.update('jax_parallel_functions_output_gda', True)`` - - Args: - global_shape : The global shape of the array. - global_mesh : The global mesh representing devices across multiple - processes. 
- mesh_axes : A sequence with length less than or equal to the rank of the - global array (i.e. the length of the global shape). Each element can be: - - * An axis name of ``global_mesh``, indicating that the corresponding - global array axis is partitioned across the given device axis of - ``global_mesh``. - * A tuple of axis names of ``global_mesh``. This is like the above option - except the global array axis is partitioned across the product of axes - named in the tuple. - * None indicating that the corresponding global array axis is not - partitioned. - - For more information, please see: - https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html - device_buffers: DeviceArrays that are on the local devices of ``global_mesh``. - - Attributes: - shape : Global shape of the array. - dtype : Dtype of the global array. - ndim : Number of array dimensions in the global shape. - size: Number of elements in the global array. - local_shards : List of :class:`Shard` on the local devices of the current process. - Data is materialized for all local shards. - global_shards : List of all :class:`Shard` of the global array. Data isn’t - available if a shard is on a non-local device with respect to the current - process. - is_fully_replicated : True if the full array value is present on all devices - of the global mesh. - - Example: - - >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P - >>> import numpy as np - ... - >>> assert jax.device_count() == 8 - >>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y')) - >>> # Logical mesh is (hosts, devices) - >>> assert global_mesh.shape == {'x': 4, 'y': 2} - >>> global_input_shape = (8, 2) - >>> mesh_axes = P('x', 'y') - ... - >>> # Dummy example data; in practice we wouldn't necessarily materialize global data - >>> # in a single process. - >>> global_input_data = np.arange( - ... np.prod(global_input_shape)).reshape(global_input_shape) - ... - >>> def get_local_data_slice(index): - ... # index will be a tuple of slice objects, e.g. (slice(0, 16), slice(0, 4)) - ... # This method will be called per-local device from the GDA constructor. - ... return global_input_data[index] - ... - >>> gda = GlobalDeviceArray.from_callback( - ... global_input_shape, global_mesh, mesh_axes, get_local_data_slice) - >>> print(gda.shape) - (8, 2) - >>> print(gda.addressable_shards[0].data) # Access the data on a single local device - [[0] - [2]] - >>> print(gda.addressable_shards[0].data.shape) - (2, 1) - >>> # Numpy-style index into the global array that this data shard corresponds to - >>> print(gda.addressable_shards[0].index) - (slice(0, 2, None), slice(0, 1, None)) - - GDAs can also be given as an input to pjit and you can get GDAs as output from pjit:: - - # Allow pjit to output GDAs - jax.config.update('jax_parallel_functions_output_gda', True) - - f = pjit(lambda x: x @ x.T, in_shardings=P('x', 'y'), out_axis_resources = P('x', 'y')) - with global_mesh: - out = f(gda) - - # `out` can be passed to another pjit call, out.addressable_shards can be used to - # export the data to non-jax systems (e.g. for checkpointing or logging), etc. - - """ - - def __init__(self, - global_shape: Shape, - global_mesh: pxla.Mesh, - mesh_axes: MeshAxes, - device_buffers: Sequence[DeviceArray], - _gda_fast_path_args: Optional[_GdaFastPathArgs] = None, - _enable_checks: bool = True): - warnings.warn( - "GlobalDeviceArray has been deprecated. Please migrate to jax.Array. 
" - "See https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration " - "on how to migrate to jax.Array.", DeprecationWarning) - self._global_shape = global_shape - self._global_mesh = global_mesh - self._mesh_axes = mesh_axes - self._device_buffers = device_buffers - - # Optionally precomputed for performance. - self._gda_fast_path_args = _gda_fast_path_args - - if self._gda_fast_path_args is None: - self._local_devices = self._global_mesh.local_devices - else: - self._local_devices = self._gda_fast_path_args.local_devices - - if _enable_checks or config.jax_enable_checks: - for db, ld in safe_zip(self._device_buffers, self._local_devices): - if db.device() != ld: - raise ValueError( - "The `global_mesh.local_devices` and `device_buffers` device " - "order doesn't match. Please use `global_mesh.local_devices` to " - "put arrays on devices instead of `jax.local_devices()`") - - if _enable_checks or config.jax_enable_checks: - ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes) - assert all(db.shape == ss for db in self._device_buffers), ( - f"Expected shard shape {ss} doesn't match the device buffer " - f"shape, got: {[db.shape for db in self._device_buffers]}") - - dtype = device_buffers[0].dtype # type: ignore - if _enable_checks or config.jax_enable_checks: - assert all(db.dtype == dtype for db in self._device_buffers), ( - "Input arrays to GlobalDeviceArray must have matching dtypes, " - f"got: {[db.dtype for db in self._device_buffers]}") - self.dtype = dtype - - def __eq__(self, other: object): - raise NotImplementedError( - "GlobalDeviceArray equality is intentionally unimplemented. " - "Implement desired functionality explicitly, e.g. to check if all " - "values are equal: " - "pjit(lambda x, y: x == y, " - "in_shardings=FROM_GDA, out_shardings=None)" - ) - - def __str__(self): - return f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype})' - - def __repr__(self): - return (f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype}, ' - f'global_mesh_shape={dict(self.mesh.shape)}, ' - f'mesh_axes={self.mesh_axes})') - - @property - def shape(self) -> Shape: - return self._global_shape - - @property - def ndim(self): - return len(self.shape) - - @property - def size(self): - return math.prod(self.shape) - - @property - def mesh(self): - return self._global_mesh - - @property - def mesh_axes(self) -> MeshAxes: - return self._mesh_axes - - @property - def is_fully_replicated(self) -> bool: - return self.shape == self.addressable_data(0).shape - - def _create_local_shards(self) -> Sequence[Shard]: - if self._gda_fast_path_args is not None: - global_indices_rid = self._gda_fast_path_args.global_indices_replica_ids - else: - global_indices_rid = get_shard_indices_replica_ids( - self._global_shape, self._global_mesh, self.mesh_axes) - - out = [] - for db in self._device_buffers: - db = dispatch._set_aval(db) - device = db.device() - index, rid = global_indices_rid[device] - out.append(Shard(device, index, rid, db)) - return out - - @functools.cached_property - def local_shards(self) -> Sequence[Shard]: - self._check_if_deleted() - return self._create_local_shards() - - @functools.cached_property - def addressable_shards(self) -> Sequence[Shard]: - self._check_if_deleted() - return self.local_shards - - @property - def global_shards(self) -> Sequence[Shard]: - self._check_if_deleted() - if self.mesh.size == len(self._local_devices): - return self.addressable_shards - - # Populating global_shards lazily (i.e. 
when requested) because populating - # sthem eagerly leads to a performance regression when training on large - # models. - # Also as this a cached property, once calculated, it should be cached. So - # multiple accesses should be cheap. - global_indices_rid = get_shard_indices_replica_ids( - self._global_shape, self._global_mesh, self.mesh_axes) - device_to_buffer = {db.device(): db for db in self._device_buffers} - global_shards = [] - for device, (index, rid) in global_indices_rid.items(): - local_shard = device.process_index == device.client.process_index() - buf = device_to_buffer[device] if local_shard else None - if buf is not None and buf.aval is None: - buf.aval = core.ShapedArray(buf.shape, buf.dtype) - sh = Shard(device, index, rid, buf) - global_shards.append(sh) - return global_shards - - @property - def _value(self): - self._check_if_deleted() - if self.is_fully_replicated: - return np.asarray(self._device_buffers[0]) - - if self.mesh.is_multi_process: - raise RuntimeError("Fetching value for GDA that spans non-addressable " - "devices is not possible. You can use " - "`jax.experimental.multihost_utils.process_allgather` " - "for this use case.") - unique_shards = [s.data.copy_to_host_async() or s - for s in self.addressable_shards if s.replica_id == 0] - npy_value = np.empty(self.shape, self.dtype) - for s in unique_shards: - npy_value[s.index] = np.asarray(s.data) - return npy_value - - def __array__(self, dtype=None, context=None): - self._check_if_deleted() - return self._value if dtype is None else self._value.astype(dtype) - - def local_data(self, index) -> DeviceArray: - self._check_if_deleted() - return dispatch._set_aval(self._device_buffers[index]) - - def addressable_data(self, index) -> DeviceArray: - self._check_if_deleted() - return self.local_data(index) - - def block_until_ready(self): - self._check_if_deleted() - for db in self._device_buffers: - db.block_until_ready() - return self - - def _check_if_deleted(self): - if self.is_deleted(): - raise RuntimeError("GlobalDeviceArray has been deleted.") - - def is_deleted(self): - return self._device_buffers is None - - def delete(self): - for b in self._device_buffers: - b.delete() - self._device_buffers = None - - @property - def sharding(self): - return jax.sharding.NamedSharding(self._global_mesh, self.mesh_axes) - - @classmethod - def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh, - mesh_axes: MeshAxes, data_callback: Callable[[Index], - ArrayLike]): - """Constructs a GlobalDeviceArray via data fetched from ``data_callback``. - - ``data_callback`` is used to fetch the data for each local slice of the returned GlobalDeviceArray. - - Example: - - >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P - >>> import numpy as np - ... - >>> global_input_shape = (8, 8) - >>> mesh_axes = P('x', 'y') - >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) - >>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape) - ... - >>> def cb(index): - ... return global_input_data[index] - ... - >>> gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb) - >>> gda.addressable_data(0).shape - (4, 2) - - Args: - global_shape : The global shape of the array - global_mesh : The global mesh representing devices across multiple - processes. - mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray. 
- data_callback : Callback that takes indices into the global array value as input and - returns the corresponding data of the global array value. The data can be returned - as any array-like object, e.g. a ``numpy.ndarray``. - """ - global_indices_rid = get_shard_indices_replica_ids( - global_shape, global_mesh, mesh_axes) - local_devices = global_mesh.local_devices - dbs = [ - jax.device_put(data_callback(global_indices_rid[device][0]), device) - for device in local_devices - ] - if config.jax_array: - return jax.make_array_from_single_device_arrays( - global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs) - return cls(global_shape, global_mesh, mesh_axes, dbs, - _gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices)) - - @classmethod - def from_batched_callback(cls, global_shape: Shape, - global_mesh: pxla.Mesh, mesh_axes: MeshAxes, - data_callback: Callable[[Sequence[Index]], - Sequence[ArrayLike]]): - """Constructs a GlobalDeviceArray via batched data fetched from ``data_callback``. - - Like ``from_callback``, except the callback function is called only once to fetch all data - local to this process. - - Example: - - >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P - >>> import numpy as np - ... - >>> global_input_shape = (8, 2) - >>> mesh_axes = P('x') - >>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y')) - >>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape) - ... - >>> def batched_cb(indices): - ... assert len(indices) == len(global_mesh.local_devices) - ... return [global_input_data[index] for index in indices] - ... - >>> gda = GlobalDeviceArray.from_batched_callback(global_input_shape, global_mesh, mesh_axes, batched_cb) - >>> gda.addressable_data(0).shape - (2, 2) - - Args: - global_shape : The global shape of the array - global_mesh : The global mesh representing devices across multiple - processes. - mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray. - data_callback : Callback that takes a batch of indices into the global array value with - length equal to the number of local devices as input and returns the corresponding data for each index. - The data can be returned as any array-like objects, e.g. ``numpy.ndarray`` - """ - global_indices_rid = get_shard_indices_replica_ids( - global_shape, global_mesh, mesh_axes) - local_devices = global_mesh.local_devices - local_indices = [global_indices_rid[d][0] for d in local_devices] - local_arrays = data_callback(local_indices) - dbs = pxla.device_put(local_arrays, local_devices) - if config.jax_array: - return jax.make_array_from_single_device_arrays( - global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs) # type: ignore - return cls(global_shape, global_mesh, mesh_axes, dbs, - _gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices)) - - @classmethod - def from_batched_callback_with_devices( - cls, global_shape: Shape, global_mesh: pxla.Mesh, - mesh_axes: MeshAxes, - data_callback: Callable[[Sequence[Tuple[Index, Tuple[Device, ...]]]], - Sequence[DeviceArray]]): - """Constructs a GlobalDeviceArray via batched DeviceArrays fetched from ``data_callback``. - - Like ``from_batched_callback``, except the callback function is responsible for returning on-device data (e.g. by calling ``jax.device_put``). - - Example: - - >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P - >>> import numpy as np - ... 
- >>> global_input_shape = (8, 2) - >>> mesh_axes = P(('x', 'y')) - >>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y')) - >>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape) - ... - >>> def cb(cb_inp): - ... dbs = [] - ... for inp in cb_inp: - ... index, devices = inp - ... array = global_input_data[index] - ... dbs.extend([jax.device_put(array, device) for device in devices]) - ... return dbs - ... - >>> gda = GlobalDeviceArray.from_batched_callback_with_devices( - ... global_input_shape, global_mesh, mesh_axes, cb) - >>> gda.addressable_data(0).shape - (1, 2) - - Args: - global_shape : The global shape of the array - global_mesh : The global mesh representing devices across multiple - processes. - mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray. - data_callback : Callback that takes agets batch of indices into the global array value with - length equal to the number of local devices as input and returns the corresponding data for - each index. The data must be returned as jax DeviceArrays. - """ - global_indices_rid = get_shard_indices_replica_ids( - global_shape, global_mesh, mesh_axes) - local_devices = global_mesh.local_devices - - index_to_device: Dict[int, Tuple[Index, List[Device]]] = {} - for device in local_devices: - index = global_indices_rid[device][0] - h_index = _hashed_index(index) - if h_index not in index_to_device: - index_to_device[h_index] = (index, [device]) - else: - index_to_device[h_index][1].append(device) - - cb_inp = [ - (index, tuple(devices)) for index, devices in index_to_device.values() - ] - dbs = data_callback(cb_inp) - if config.jax_array: - return jax.make_array_from_single_device_arrays( - global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs) # type: ignore - return cls(global_shape, global_mesh, mesh_axes, dbs, - _gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices)) - - -core.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray( - x.shape, x.dtype) -xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray( - x.shape, x.dtype) -xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity -api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \ - lambda x: core.ShapedArray(x.shape, x.dtype) - -# This will only work when GDA is fully addressable i.e. on a single host or -# fully replicated. 
-def _gda_mlir_constant_handler(val, canonicalize_types=True): - return mlir.ir_constants(val._value, - canonicalize_types=canonicalize_types) -mlir.register_constant_handler(GlobalDeviceArray, _gda_mlir_constant_handler) - - -def _gda_shard_arg(x, devices, indices, sharding): - x._check_if_deleted() - return x._device_buffers -pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg - - -def _gda_array_result_handler(global_aval, out_sharding, committed, - is_out_sharding_from_xla): - if core.is_opaque_dtype(global_aval.dtype): - return global_aval.dtype._rules.global_sharded_result_handler( - global_aval, out_sharding, committed, is_out_sharding_from_xla) - global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec - global_idx_rid = get_shard_indices_replica_ids(global_aval.shape, global_mesh, - out_axis_resources) - local_devices = global_mesh.local_devices - fast_path_args = _GdaFastPathArgs(global_idx_rid, local_devices) - return lambda bufs: GlobalDeviceArray( - global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args, - _enable_checks=False) -pxla.global_result_handlers[ - (core.ShapedArray, pxla.OutputType.GlobalDeviceArray)] = _gda_array_result_handler -pxla.global_result_handlers[ - (core.ConcreteArray, pxla.OutputType.GlobalDeviceArray)] = _gda_array_result_handler diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 11ee28677..8bb1ffb64 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3706,30 +3706,21 @@ def check_device_backend_on_shardings(shardings) -> bool: def check_gda_or_array_xla_sharding_match( args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> None: - from jax._src.global_device_array import GlobalDeviceArray from jax._src.array import ArrayImpl for arg, xs in safe_zip(args, in_xla_shardings): - if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)): + if not isinstance(arg, ArrayImpl): continue - if isinstance(arg, GlobalDeviceArray): - arg_sharding = create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes) - arg_type = 'GDA' - committed = True - else: - arg_sharding = arg.sharding - arg_type = 'Array' - committed = arg._committed # No need to cache this check since MeshExecutable has a C++ fast path # for AOT compiled call. if (not check_device_backend_on_shardings([xs]) and - committed and - not are_op_shardings_equal(arg_sharding._to_xla_op_sharding(arg.ndim), + arg._committed and + not are_op_shardings_equal(arg.sharding._to_xla_op_sharding(arg.ndim), xs._to_xla_op_sharding(arg.ndim))): raise ValueError( - f"{arg_type} sharding does not match the input sharding. " - f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for " + f"Array sharding does not match the input sharding. 
" + f"Got Array sharding: {arg.sharding} and xla sharding: {xs} for " f"arg shape: {arg.shape}, arg value: {arg}") diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 212a22fab..ff7884fa7 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -37,7 +37,6 @@ from jax._src import traceback_util from jax._src.config import config from jax.errors import JAXTypeError from jax._src.array import ArrayImpl -from jax._src.global_device_array import GlobalDeviceArray from jax._src.sharding_impls import NamedSharding from jax._src.interpreters import mlir from jax.interpreters import partial_eval as pe @@ -521,13 +520,7 @@ def xmap(fun: Callable, lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)), closure=(out_axes_entries, out_axes_treedef)) - if config.jax_array: - in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat) - else: - in_positional_semantics = tuple( - _PositionalSemantics.GLOBAL - if isinstance(a, GlobalDeviceArray) else _positional_semantics.val - for a in args_flat) + in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat) out_positional_semantics = ( _PositionalSemantics.GLOBAL if config.jax_array or config.jax_parallel_functions_output_gda @@ -1638,7 +1631,7 @@ def _get_axis_sizes(args_flat: Iterable[Any], in_axes_flat: Iterable[AxisNamePos], global_axis_sizes: Dict[AxisName, int], axis_resource_count: Dict[AxisName, ResourceCount], - in_positional_semantics: Sequence[bool]): + in_positional_semantics: Sequence[_PositionalSemantics]): global_axis_sizes = dict(global_axis_sizes) for arg, in_axes, ips in zip(args_flat, in_axes_flat, in_positional_semantics): for name, dim in in_axes.items(): @@ -1824,26 +1817,22 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env, axis_resources, resource_env, global_axis_sizes, in_positional_semantics).to_mesh_axes(in_axes_flat) for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes): - if isinstance(arg, (GlobalDeviceArray, ArrayImpl)): - arr_flavor = 'GDA' if isinstance(arg, GlobalDeviceArray) else 'Array' - if arr_flavor == 'Array' and not isinstance(arg.sharding, NamedSharding): + if isinstance(arg, ArrayImpl): + if not isinstance(arg.sharding, NamedSharding): continue - mesh = arg.mesh if arr_flavor == 'GDA' else arg.sharding.mesh + mesh = arg.sharding.mesh if mesh != resource_env.physical_mesh: - raise ValueError(f"xmap's mesh and {arr_flavor}'s mesh should be equal. " + raise ValueError("xmap's mesh and Array's mesh should be equal. " f"Got xmap mesh: {resource_env.physical_mesh},\n" - f"{arr_flavor} mesh: {mesh}") + f"Array mesh: {mesh}") - if arr_flavor == 'GDA': - s = pxla.create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes) - else: - s = arg.sharding + s = arg.sharding xmap_sharding = pxla.create_mesh_pspec_sharding( mesh, pxla.array_mapping_to_axis_resources(xmap_array_mapping)) # This check is cached because comparing OpSharding is expensive during # dispatch and if the shardings are the same, then there is no need to # compare twice. 
- _check_sharding(s, xmap_sharding, arg.ndim, arr_flavor) + _check_sharding(s, xmap_sharding, arg.ndim, 'Array') # TODO: We should relax this at least for "constructor primitives" diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 3528fc95d..29c3911d0 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -28,7 +28,6 @@ import jax from jax._src import core from jax import stages from jax.errors import JAXTypeError -from jax._src.global_device_array import GlobalDeviceArray as GDA from jax.interpreters import partial_eval as pe from jax.interpreters import xla from jax._src.interpreters.pxla import PartitionSpec @@ -491,28 +490,17 @@ def common_infer_params(pjit_info_args, *args, **kwargs): out_shardings = tree_map( lambda x: x if _is_unspecified(x) else _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), user_out_shardings) - # This check fails extremely rarely and has a huge cost in the dispatch - # path. So hide it behind the jax_enable_checks flag. - if jax.config.jax_enable_checks: - _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh) del user_in_shardings, user_out_shardings local_in_avals = tuple(shaped_abstractify(a) for a in args_flat) - # TODO(yashkatariya): This is a hack. This should go away when avals have - # is_global attribute. - if jax.config.jax_array: - in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat) - else: - in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat)) - out_positional_semantics = ( - pxla._PositionalSemantics.GLOBAL - if jax.config.jax_parallel_functions_output_gda or jax.config.jax_array else - pxla.positional_semantics.val) + + in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat) + out_positional_semantics = pxla._PositionalSemantics.GLOBAL global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources( hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics, - tuple(isinstance(a, GDA) for a in args_flat), resource_env) + resource_env) jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( flat_fun, hashable_pytree(out_shardings), global_in_avals, @@ -842,7 +830,7 @@ class PytreeLeaf: @lru_cache(maxsize=4096) def _process_in_axis_resources(in_shardings_thunk, local_in_avals, - in_tree, in_positional_semantics, is_gda, + in_tree, in_positional_semantics, resource_env): orig_in_shardings = in_shardings_thunk() # Only do this if original in_shardings are unspecified. If they are @@ -854,66 +842,15 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals, "pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True) - # Fork here because the `Array` path is very simple and doesn't need all the - # complexity below. - if config.jax_array: - pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments", - allow_uneven_sharding=False) - global_in_avals = local_in_avals - # TODO(yashkatariya): Only check for is_auto or _is_unspecified when - # FROM_GDA is removed. 
- canonicalized_shardings = tuple( - i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim) - for i, aval in safe_zip(in_shardings_flat, global_in_avals)) - return tuple(global_in_avals), canonicalized_shardings - - if not local_in_avals: - assert not in_shardings_flat - return (), () - - in_axis_resources_flat = tuple( - i if _is_from_gda(i) or is_auto(i) else i._parsed_pspec - for i in in_shardings_flat) - - # This check should be above local_to_global call below otherwise if - # `FROM_GDA` is passed to any input other than GDA, a ugly error message - # will be raised because get_array_mapping (in local_to_global) of a - # FROM_GDA cannot happen. - tree_map(_check_resources_mismatch, in_axis_resources_flat, is_gda) - # If all inputs have global semantics or fully replicated, then the avals are - # global and the mesh should also be global. This split is because - # non-contiguous mesh can only be used if all inputs have global semantics or - # fully replicated. - # Use canonicalized in_axis_resources here because we want to treat P(None) - # and None (for example) as equivalent. - if all( - (not _is_from_gda(p) and not is_auto(p) and - CanonicalizedParsedPartitionSpec(p).partitions == ()) or - ips == pxla._PositionalSemantics.GLOBAL - for p, ips in safe_zip(in_axis_resources_flat, in_positional_semantics)): - # Shapes should be checked against non canonicalized in_axis_resources. - # For example, partitions of () and ((),) are not equivalent, since the - # first one is a valid spec for a scalar value, while the second is not! - pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments", - allow_uneven_sharding=False) - else: - pjit_check_aval_sharding( - [i if _is_from_gda(i) or is_auto(i) else - NamedSharding(i.mesh.local_mesh, i.spec) - for i in in_shardings_flat], - local_in_avals, "pjit arguments", allow_uneven_sharding=False) - - # Local or global avals doesn't matter for converting to op sharding because - # the `ndim` does not change. - canonicalized_in_shardings_flat = tuple( - i if _is_from_gda(i) or is_auto(i) else to_gspmd_sharding(i, aval.ndim) - for i, aval in safe_zip(in_shardings_flat, local_in_avals)) - - global_in_avals = local_to_global( - in_positional_semantics, local_in_avals, canonicalized_in_shardings_flat, - resource_env.physical_mesh) - - return tuple(global_in_avals), canonicalized_in_shardings_flat + pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments", + allow_uneven_sharding=False) + global_in_avals = local_in_avals + # TODO(yashkatariya): Only check for is_auto or _is_unspecified when + # FROM_GDA is removed. 
+ canonicalized_shardings = tuple( + i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim) + for i, aval in safe_zip(in_shardings_flat, global_in_avals)) + return tuple(global_in_avals), canonicalized_shardings @lu.cache @@ -1156,11 +1093,6 @@ def _prepare_axis_resources(axis_resources, return tree_unflatten(treedef, new_entries), new_entries, treedef -def _check_resources_mismatch(in_axis_resources_flat, is_gda): - if not is_gda and _is_from_gda(in_axis_resources_flat): - raise ValueError('For a non-GDA input, the corresponding resource in ' - 'in_axis_resources cannot be `pjit.FROM_GDA`.') - def _check_unique_resources(axis_resources, arg_name): for arg_axis_resources in axis_resources: if not arg_axis_resources: continue @@ -2204,11 +2136,6 @@ def _calc_is_global_sequence(in_positional_semantics, in_shardings): pxla.is_op_sharding_replicated(i._op_sharding)) for ips, i in safe_zip(in_positional_semantics, in_shardings)) -def _get_in_positional_semantics(arg) -> pxla._PositionalSemantics: - if isinstance(arg, GDA): - return pxla._PositionalSemantics.GLOBAL - return pxla.positional_semantics.val - def _fast_path_get_device_assignment( shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]: @@ -2244,20 +2171,11 @@ def _maybe_replace_from_gda_with_pspec( out.append(in_sharding_flat) elif isinstance(arg, array.ArrayImpl): out.append(to_gspmd_sharding(arg.sharding, arg.ndim)) - elif isinstance(arg, GDA): - gda_sharding = pxla.create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes) - out.append(_gda_check_and_get_sharding(gda_sharding, in_sharding_flat, arg.ndim)) else: out.append(in_sharding_flat) return tuple(out) -def _maybe_check_pjit_gda_mesh(args, mesh): - for x in args: - if isinstance(x, GDA) and x.mesh != mesh: - raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit " - f"mesh: {mesh},\n GDA mesh: {x.mesh}") - # -------------------- XLA OpSharding to PartitionSpec -------------------- # Note that OpSharding is more expressive than PartitionSpecs, so it's not # always possible to convert them, but the code below should at least diff --git a/jax/experimental/gda_serialization/serialization.py b/jax/experimental/gda_serialization/serialization.py index 1f560cf24..289e93802 100644 --- a/jax/experimental/gda_serialization/serialization.py +++ b/jax/experimental/gda_serialization/serialization.py @@ -26,7 +26,6 @@ from typing import Callable, Sequence, Optional, Dict, Any import jax from jax._src import distributed from jax._src.config import config -from jax._src import global_device_array as gda from jax._src import array from jax._src import sharding from jax._src import sharding_impls @@ -64,27 +63,6 @@ async def create_async_array_from_callback( global_shape, inp_sharding, dbs) -async def create_async_gda_from_callback( - global_shape: gda.Shape, - global_mesh: Mesh, - mesh_axes: gda.MeshAxes, - data_callback: Callable[[gda.Index], asyncio.Future], -): - global_idx_rid = gda.get_shard_indices_replica_ids( - global_shape, global_mesh, mesh_axes) - local_devices = global_mesh.local_devices - future_arrays = [data_callback(global_idx_rid[d][0]) - for d in local_devices] - # Pause here and come back to `from_async_callback()` when future_arrays are - # ready. device_put cannot happen with future_arrays. 
- local_arrays = await asyncio.gather(*future_arrays) - - dbs = [jax.device_put(array, device) - for array, device in zip(local_arrays, local_devices)] - return gda.GlobalDeviceArray(global_shape, global_mesh, mesh_axes, dbs, - gda._GdaFastPathArgs(global_idx_rid, local_devices)) - - def _get_metadata(arr): if arr.dtype == jnp.bfloat16: # Tensorstore uses 'bfloat16', not ' Tuple[TfVal, DType]: @@ -1153,8 +1142,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal, # The float0 type is not known to TF. if jax_dtype == dtypes.float0: val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype) - tf_val = tf.convert_to_tensor( - _maybe_decode_gda(val), dtype=conversion_dtype) + tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype) if do_memoize: _thread_local_state.constant_cache[const_key] = (val, tf_val) return tf_val, jax_dtype diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index d3712a9a5..713917648 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -30,7 +30,6 @@ from jax.interpreters import xla from jax._src import pjit as pjit_lib from jax.experimental.pjit import pjit, FROM_GDA from jax.sharding import PartitionSpec as P -from jax._src.global_device_array import GlobalDeviceArray from jax._src import distributed from jax._src import config as config_internal import numpy as np @@ -60,9 +59,6 @@ def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any: is_source = jax.process_index() == 0 def pre_pmap(x): - if isinstance(x, GlobalDeviceArray): - raise ValueError('GDAs cannot be broadcasted from source host to other ' - 'hosts.') if is_source: return np.concatenate([ x[None, ...], @@ -148,35 +144,8 @@ def process_allgather(in_tree: Any, tiled: bool = False) -> Any: """ def _pjit(inp): - if jax.config.jax_array: - return _handle_array_process_allgather(inp, tiled) - else: - if isinstance(inp, GlobalDeviceArray): - if inp.is_fully_replicated: - return np.asarray(inp.addressable_data(0)) - global_mesh = inp.mesh - in_axis_resources = FROM_GDA - else: - # DA/SDA/np.array will be sharded based on global_mesh.local_mesh. - # Shape of local_mesh will always be (1, local_device_count()) - devices = np.array(jax.devices()).reshape(jax.process_count(), - jax.local_device_count()) - global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices')) - in_axis_resources = P('processes') - if inp.ndim == 0 or not tiled: - inp = np.expand_dims(inp, axis=0) - - with global_mesh: - out = pjit( - _identity_fn, in_shardings=in_axis_resources, out_shardings=None - )(inp) - return np.asarray(out.addressable_data(0)) - - if jax.config.jax_array: - return jax.tree_map(_pjit, in_tree) # array route - else: - with config_internal.parallel_functions_output_gda(True): - return jax.tree_map(_pjit, in_tree) # gda route + return _handle_array_process_allgather(inp, tiled) + return jax.tree_map(_pjit, in_tree) def assert_equal(in_tree, fail_message: str = ''):
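
Migration sketch (illustrative, not part of the applied diff): the examples in the deleted GlobalDeviceArray docstrings map directly onto jax.Array. Below is a minimal equivalent of the GlobalDeviceArray.from_callback doctest, assuming 8 available devices as in the original example; jax.make_array_from_callback is the jax.Array counterpart described in the migration guide linked in the commit message.

import math
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

global_shape = (8, 2)
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
global_data = np.arange(math.prod(global_shape)).reshape(global_shape)

# The callback receives a per-device index (a tuple of slices) and returns that
# slice of the global value, just as the GDA data_callback did.
arr = jax.make_array_from_callback(global_shape, sharding,
                                   lambda index: global_data[index])

assert arr.shape == global_shape
print(arr.addressable_shards[0].data.shape)  # per-device shard shape, e.g. (2, 1)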
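
For the batched constructors (from_batched_callback, from_batched_callback_with_devices), the jax.Array branches that this patch keeps already show the replacement pattern: build per-device buffers and assemble them with jax.make_array_from_single_device_arrays. An illustrative sketch, assuming 8 devices and the P('x') spec from the deleted from_batched_callback example; addressable_devices_indices_map is used here to recover the device-to-index mapping that get_shard_indices_replica_ids used to provide.

import math
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

global_shape = (8, 2)
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x'))
global_data = np.arange(math.prod(global_shape)).reshape(global_shape)

# Map each addressable device to the numpy-style index of its shard, then place
# the corresponding slice of the global value on that device.
index_map = sharding.addressable_devices_indices_map(global_shape)
buffers = [jax.device_put(global_data[index], device)
           for device, index in index_map.items()]

arr = jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)
print(arr.addressable_shards[0].data.shape)  # (2, 2) with this mesh and spec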