Remove GDA from JAX, since jax.Array is now the default array type and can no longer be disabled; see https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now

PiperOrigin-RevId: 516905931
This commit is contained in:
Yash Katariya 2023-03-15 12:59:33 -07:00 committed by jax authors
parent 6f52388ecc
commit 634035abd7
11 changed files with 32 additions and 906 deletions
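
For code that built GDAs from per-shard callbacks, the usual jax.Array replacement is ``jax.make_array_from_callback``. Below is a minimal, hedged sketch of that migration; the 8-device single-process setup and the (4, 2) 'x'/'y' mesh mirror the deleted docstring example further down and are illustrative assumptions, not requirements.

# Hedged migration sketch: GlobalDeviceArray.from_callback -> jax.make_array_from_callback.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

global_shape = (8, 2)
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))  # assumes 8 local devices
sharding = NamedSharding(mesh, P('x', 'y'))
data = np.arange(np.prod(global_shape)).reshape(global_shape)

# The callback receives the numpy-style index (a tuple of slices) of each
# addressable shard, just like the data_callback of the removed
# GlobalDeviceArray.from_callback.
arr = jax.make_array_from_callback(global_shape, sharding, lambda idx: data[idx])
assert arr.shape == global_shape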

View File

@@ -1,11 +0,0 @@
``jax.experimental.global_device_array`` module
===============================================
.. automodule:: jax.experimental.global_device_array
API
---
.. autoclass:: GlobalDeviceArray
:members:
.. autoclass:: Shard

View File

@@ -15,7 +15,6 @@ Experimental Modules
:maxdepth: 1
jax.experimental.checkify
jax.experimental.global_device_array
jax.experimental.host_callback
jax.experimental.maps
jax.experimental.pjit

View File

@@ -107,7 +107,6 @@ py_library_providing_imports_info(
"_src/dtypes.py",
"_src/errors.py",
"_src/flatten_util.py",
"_src/global_device_array.py",
"_src/__init__.py",
"_src/lax_reference.py",
"_src/linear_util.py",
@@ -381,9 +380,6 @@ py_library_providing_imports_info(
"experimental/*.py",
"example_libraries/*.py",
],
exclude = [
"experimental/global_device_array.py",
],
),
visibility = ["//visibility:public"],
deps = [
@@ -506,10 +502,3 @@ pytype_library(
visibility = ["//visibility:public"],
deps = [":jax"],
)
pytype_library(
name = "global_device_array",
srcs = ["experimental/global_device_array.py"],
visibility = [":internal"],
deps = [":jax"],
)

View File

@@ -1,650 +0,0 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter
import dataclasses
import functools
import math
import warnings
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
import jax
from jax._src import core
from jax._src import dispatch
from jax._src import api_util
from jax._src.lib import xla_client as xc
from jax._src.config import config
from jax._src.interpreters import pxla
from jax.interpreters import xla, mlir
from jax._src.util import safe_zip
from jax._src.interpreters.pxla import PartitionSpec
Shape = Tuple[int, ...]
MeshAxes = PartitionSpec
DeviceArray = xc.Buffer
Device = xc.Device
ArrayLike = Union[np.ndarray, DeviceArray]
Index = Tuple[slice, ...]
_hashed_index = lambda x: hash(tuple((v.start, v.stop) for v in x))
def _get_sharding_spec(global_shape, global_mesh, mesh_axes):
array_mapping = pxla.get_array_mapping(mesh_axes)
# The dtype doesn't matter for creating sharding specs.
aval = core.ShapedArray(global_shape, np.float32)
return pxla.mesh_sharding_specs(global_mesh.shape,
global_mesh.axis_names)(aval, array_mapping)
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Tuple[Index, ...]:
sharding_spec = _get_sharding_spec(global_shape, global_mesh, mesh_axes)
indices = pxla.spec_to_indices(global_shape, sharding_spec)
return indices # type: ignore
@functools.lru_cache(maxsize=4096)
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Mapping[Device, Index]:
indices = _get_indices(global_shape, global_mesh, mesh_axes)
# The type: ignore is to ignore the type returned by `spec_to_indices`.
return {
d: i
for d, i in safe_zip(global_mesh.devices.flat, indices)} # type: ignore
@functools.lru_cache(maxsize=4096)
def get_shard_indices_replica_ids(
global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
return _get_shard_indices_replica_ids_uncached(global_shape, global_mesh, mesh_axes)
def _get_shard_indices_replica_ids_uncached(
global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
indices = _get_indices(global_shape, global_mesh, mesh_axes)
index_to_replica: Dict[int, int] = Counter()
out = {}
unique_shards = 0
for device, index in safe_zip(global_mesh.devices.flat, indices):
h_index = _hashed_index(index)
replica_id = index_to_replica[h_index]
if replica_id == 0:
unique_shards += 1
index_to_replica[h_index] += 1
out[device] = (index, replica_id)
shard_shape = get_shard_shape(global_shape, global_mesh, mesh_axes)
expected_unique_shards = math.prod(
[g // s for g, s in safe_zip(global_shape, shard_shape) if g != 0 or s != 0])
if expected_unique_shards != unique_shards:
raise RuntimeError(
f'Number of expected unique shards is: {expected_unique_shards} but '
f'got {unique_shards}. Please file a bug at '
'https://github.com/google/jax/issues.')
return out
@functools.lru_cache(maxsize=4096)
def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
chunk_size = []
for mesh_axis, size in zip(mesh_axes, global_shape):
if not mesh_axis:
chunk_size.append(size)
elif isinstance(mesh_axis, tuple):
m = math.prod([global_mesh.shape[ma] for ma in mesh_axis])
chunk_size.append(size // m)
else:
chunk_size.append(size // global_mesh.shape[mesh_axis])
if len(chunk_size) != len(global_shape):
chunk_size.extend(global_shape[len(chunk_size):])
return tuple(chunk_size)
@dataclasses.dataclass(frozen=True)
class Shard:
"""A single data shard of a GlobalDeviceArray.
Args:
device : Which device this shard resides on.
index : The index into the global array of this shard.
replica_id : Integer id indicating which replica of the global array this
shard is part of. Always 0 for fully sharded data
(i.e. when there's only 1 replica).
data : The data of this shard. None if ``device`` is non-local.
"""
device: Device
index: Index
replica_id: int
# None if this `Shard` lives on a non-local device.
data: Optional[DeviceArray] = None
class _GdaFastPathArgs(NamedTuple):
global_indices_replica_ids: Mapping[Device, Tuple[Index, int]]
local_devices: Sequence[Device]
class GlobalDeviceArray:
"""A logical array with data sharded across multiple devices and processes.
If you're not already familiar with JAX's multi-process programming model,
please read https://jax.readthedocs.io/en/latest/multi_process.html.
You can also read about pjit (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html)
to learn about ``Mesh``, ``PartitionSpec`` and how arrays can be
partitioned or replicated.
A GlobalDeviceArray (GDA) can be thought of as a view into a single logical
array sharded across processes. The logical array is the global array, and
each process has a GlobalDeviceArray object referring to the same global array
(similarly to how each process runs a multi-process pmap or pjit). Each process
can access the shape, dtype, etc. of the global array via the GDA, pass the
GDA into multi-process pjits, and get GDAs as pjit outputs (coming soon: xmap
and pmap). However, each process can only directly access the shards of the
global array data stored on its local devices.
GDAs can help manage the inputs and outputs of multi-process computations.
A GDA keeps track of which shard of the global array belongs to which device,
and provides callback-based APIs to materialize the correct shard of the data
needed for each local device of each process.
A GDA consists of data shards. Each shard is stored on a different device.
There are local shards and global shards. Local shards are those on local
devices, and the data is visible to the current process. Global shards are
those across all devices (including local devices), and the data isn't visible
if the shard is on a non-local device with respect to the current process.
Please see the ``Shard`` class to see what information is stored inside that
data structure.
Note: to make pjit output GlobalDeviceArrays, set the environment variable
``JAX_PARALLEL_FUNCTIONS_OUTPUT_GDA=true`` or add the following to your code:
``jax.config.update('jax_parallel_functions_output_gda', True)``
Args:
global_shape : The global shape of the array.
global_mesh : The global mesh representing devices across multiple
processes.
mesh_axes : A sequence with length less than or equal to the rank of the
global array (i.e. the length of the global shape). Each element can be:
* An axis name of ``global_mesh``, indicating that the corresponding
global array axis is partitioned across the given device axis of
``global_mesh``.
* A tuple of axis names of ``global_mesh``. This is like the above option
except the global array axis is partitioned across the product of axes
named in the tuple.
* None indicating that the corresponding global array axis is not
partitioned.
For more information, please see:
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
device_buffers: DeviceArrays that are on the local devices of ``global_mesh``.
Attributes:
shape : Global shape of the array.
dtype : Dtype of the global array.
ndim : Number of array dimensions in the global shape.
size: Number of elements in the global array.
local_shards : List of :class:`Shard` on the local devices of the current process.
Data is materialized for all local shards.
global_shards : List of all :class:`Shard` of the global array. Data isn't
available if a shard is on a non-local device with respect to the current
process.
is_fully_replicated : True if the full array value is present on all devices
of the global mesh.
Example:
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> assert jax.device_count() == 8
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> # Logical mesh is (hosts, devices)
>>> assert global_mesh.shape == {'x': 4, 'y': 2}
>>> global_input_shape = (8, 2)
>>> mesh_axes = P('x', 'y')
...
>>> # Dummy example data; in practice we wouldn't necessarily materialize global data
>>> # in a single process.
>>> global_input_data = np.arange(
... np.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def get_local_data_slice(index):
... # index will be a tuple of slice objects, e.g. (slice(0, 16), slice(0, 4))
... # This method will be called per-local device from the GDA constructor.
... return global_input_data[index]
...
>>> gda = GlobalDeviceArray.from_callback(
... global_input_shape, global_mesh, mesh_axes, get_local_data_slice)
>>> print(gda.shape)
(8, 2)
>>> print(gda.addressable_shards[0].data) # Access the data on a single local device
[[0]
[2]]
>>> print(gda.addressable_shards[0].data.shape)
(2, 1)
>>> # Numpy-style index into the global array that this data shard corresponds to
>>> print(gda.addressable_shards[0].index)
(slice(0, 2, None), slice(0, 1, None))
GDAs can also be given as an input to pjit and you can get GDAs as output from pjit::
# Allow pjit to output GDAs
jax.config.update('jax_parallel_functions_output_gda', True)
f = pjit(lambda x: x @ x.T, in_shardings=P('x', 'y'), out_shardings=P('x', 'y'))
with global_mesh:
out = f(gda)
# `out` can be passed to another pjit call, out.addressable_shards can be used to
# export the data to non-jax systems (e.g. for checkpointing or logging), etc.
"""
def __init__(self,
global_shape: Shape,
global_mesh: pxla.Mesh,
mesh_axes: MeshAxes,
device_buffers: Sequence[DeviceArray],
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
_enable_checks: bool = True):
warnings.warn(
"GlobalDeviceArray has been deprecated. Please migrate to jax.Array. "
"See https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration "
"on how to migrate to jax.Array.", DeprecationWarning)
self._global_shape = global_shape
self._global_mesh = global_mesh
self._mesh_axes = mesh_axes
self._device_buffers = device_buffers
# Optionally precomputed for performance.
self._gda_fast_path_args = _gda_fast_path_args
if self._gda_fast_path_args is None:
self._local_devices = self._global_mesh.local_devices
else:
self._local_devices = self._gda_fast_path_args.local_devices
if _enable_checks or config.jax_enable_checks:
for db, ld in safe_zip(self._device_buffers, self._local_devices):
if db.device() != ld:
raise ValueError(
"The `global_mesh.local_devices` and `device_buffers` device "
"order doesn't match. Please use `global_mesh.local_devices` to "
"put arrays on devices instead of `jax.local_devices()`")
if _enable_checks or config.jax_enable_checks:
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
assert all(db.shape == ss for db in self._device_buffers), (
f"Expected shard shape {ss} doesn't match the device buffer "
f"shape, got: {[db.shape for db in self._device_buffers]}")
dtype = device_buffers[0].dtype # type: ignore
if _enable_checks or config.jax_enable_checks:
assert all(db.dtype == dtype for db in self._device_buffers), (
"Input arrays to GlobalDeviceArray must have matching dtypes, "
f"got: {[db.dtype for db in self._device_buffers]}")
self.dtype = dtype
def __eq__(self, other: object):
raise NotImplementedError(
"GlobalDeviceArray equality is intentionally unimplemented. "
"Implement desired functionality explicitly, e.g. to check if all "
"values are equal: "
"pjit(lambda x, y: x == y, "
"in_shardings=FROM_GDA, out_shardings=None)"
)
def __str__(self):
return f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype})'
def __repr__(self):
return (f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype}, '
f'global_mesh_shape={dict(self.mesh.shape)}, '
f'mesh_axes={self.mesh_axes})')
@property
def shape(self) -> Shape:
return self._global_shape
@property
def ndim(self):
return len(self.shape)
@property
def size(self):
return math.prod(self.shape)
@property
def mesh(self):
return self._global_mesh
@property
def mesh_axes(self) -> MeshAxes:
return self._mesh_axes
@property
def is_fully_replicated(self) -> bool:
return self.shape == self.addressable_data(0).shape
def _create_local_shards(self) -> Sequence[Shard]:
if self._gda_fast_path_args is not None:
global_indices_rid = self._gda_fast_path_args.global_indices_replica_ids
else:
global_indices_rid = get_shard_indices_replica_ids(
self._global_shape, self._global_mesh, self.mesh_axes)
out = []
for db in self._device_buffers:
db = dispatch._set_aval(db)
device = db.device()
index, rid = global_indices_rid[device]
out.append(Shard(device, index, rid, db))
return out
@functools.cached_property
def local_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
return self._create_local_shards()
@functools.cached_property
def addressable_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
return self.local_shards
@property
def global_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
if self.mesh.size == len(self._local_devices):
return self.addressable_shards
# Populating global_shards lazily (i.e. when requested) because populating
# them eagerly leads to a performance regression when training on large
# models.
# Also, as this is a cached property, once calculated it should be cached. So
# multiple accesses should be cheap.
global_indices_rid = get_shard_indices_replica_ids(
self._global_shape, self._global_mesh, self.mesh_axes)
device_to_buffer = {db.device(): db for db in self._device_buffers}
global_shards = []
for device, (index, rid) in global_indices_rid.items():
local_shard = device.process_index == device.client.process_index()
buf = device_to_buffer[device] if local_shard else None
if buf is not None and buf.aval is None:
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
sh = Shard(device, index, rid, buf)
global_shards.append(sh)
return global_shards
@property
def _value(self):
self._check_if_deleted()
if self.is_fully_replicated:
return np.asarray(self._device_buffers[0])
if self.mesh.is_multi_process:
raise RuntimeError("Fetching value for GDA that spans non-addressable "
"devices is not possible. You can use "
"`jax.experimental.multihost_utils.process_allgather` "
"for this use case.")
unique_shards = [s.data.copy_to_host_async() or s
for s in self.addressable_shards if s.replica_id == 0]
npy_value = np.empty(self.shape, self.dtype)
for s in unique_shards:
npy_value[s.index] = np.asarray(s.data)
return npy_value
def __array__(self, dtype=None, context=None):
self._check_if_deleted()
return self._value if dtype is None else self._value.astype(dtype)
def local_data(self, index) -> DeviceArray:
self._check_if_deleted()
return dispatch._set_aval(self._device_buffers[index])
def addressable_data(self, index) -> DeviceArray:
self._check_if_deleted()
return self.local_data(index)
def block_until_ready(self):
self._check_if_deleted()
for db in self._device_buffers:
db.block_until_ready()
return self
def _check_if_deleted(self):
if self.is_deleted():
raise RuntimeError("GlobalDeviceArray has been deleted.")
def is_deleted(self):
return self._device_buffers is None
def delete(self):
for b in self._device_buffers:
b.delete()
self._device_buffers = None
@property
def sharding(self):
return jax.sharding.NamedSharding(self._global_mesh, self.mesh_axes)
@classmethod
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes, data_callback: Callable[[Index],
ArrayLike]):
"""Constructs a GlobalDeviceArray via data fetched from ``data_callback``.
``data_callback`` is used to fetch the data for each local slice of the returned GlobalDeviceArray.
Example:
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 8)
>>> mesh_axes = P('x', 'y')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def cb(index):
... return global_input_data[index]
...
>>> gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)
>>> gda.addressable_data(0).shape
(4, 2)
Args:
global_shape : The global shape of the array
global_mesh : The global mesh representing devices across multiple
processes.
mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
data_callback : Callback that takes indices into the global array value as input and
returns the corresponding data of the global array value. The data can be returned
as any array-like object, e.g. a ``numpy.ndarray``.
"""
global_indices_rid = get_shard_indices_replica_ids(
global_shape, global_mesh, mesh_axes)
local_devices = global_mesh.local_devices
dbs = [
jax.device_put(data_callback(global_indices_rid[device][0]), device)
for device in local_devices
]
if config.jax_array:
return jax.make_array_from_single_device_arrays(
global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs)
return cls(global_shape, global_mesh, mesh_axes, dbs,
_gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))
@classmethod
def from_batched_callback(cls, global_shape: Shape,
global_mesh: pxla.Mesh, mesh_axes: MeshAxes,
data_callback: Callable[[Sequence[Index]],
Sequence[ArrayLike]]):
"""Constructs a GlobalDeviceArray via batched data fetched from ``data_callback``.
Like ``from_callback``, except the callback function is called only once to fetch all data
local to this process.
Example:
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 2)
>>> mesh_axes = P('x')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def batched_cb(indices):
... assert len(indices) == len(global_mesh.local_devices)
... return [global_input_data[index] for index in indices]
...
>>> gda = GlobalDeviceArray.from_batched_callback(global_input_shape, global_mesh, mesh_axes, batched_cb)
>>> gda.addressable_data(0).shape
(2, 2)
Args:
global_shape : The global shape of the array
global_mesh : The global mesh representing devices across multiple
processes.
mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
data_callback : Callback that takes a batch of indices into the global array value with
length equal to the number of local devices as input and returns the corresponding data for each index.
The data can be returned as any array-like object, e.g. ``numpy.ndarray``.
"""
global_indices_rid = get_shard_indices_replica_ids(
global_shape, global_mesh, mesh_axes)
local_devices = global_mesh.local_devices
local_indices = [global_indices_rid[d][0] for d in local_devices]
local_arrays = data_callback(local_indices)
dbs = pxla.device_put(local_arrays, local_devices)
if config.jax_array:
return jax.make_array_from_single_device_arrays(
global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs) # type: ignore
return cls(global_shape, global_mesh, mesh_axes, dbs,
_gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))
@classmethod
def from_batched_callback_with_devices(
cls, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes,
data_callback: Callable[[Sequence[Tuple[Index, Tuple[Device, ...]]]],
Sequence[DeviceArray]]):
"""Constructs a GlobalDeviceArray via batched DeviceArrays fetched from ``data_callback``.
Like ``from_batched_callback``, except the callback function is responsible for returning on-device data (e.g. by calling ``jax.device_put``).
Example:
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 2)
>>> mesh_axes = P(('x', 'y'))
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def cb(cb_inp):
... dbs = []
... for inp in cb_inp:
... index, devices = inp
... array = global_input_data[index]
... dbs.extend([jax.device_put(array, device) for device in devices])
... return dbs
...
>>> gda = GlobalDeviceArray.from_batched_callback_with_devices(
... global_input_shape, global_mesh, mesh_axes, cb)
>>> gda.addressable_data(0).shape
(1, 2)
Args:
global_shape : The global shape of the array
global_mesh : The global mesh representing devices across multiple
processes.
mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
data_callback : Callback that takes a batch of indices into the global array value with
length equal to the number of local devices as input and returns the corresponding data for
each index. The data must be returned as jax DeviceArrays.
"""
global_indices_rid = get_shard_indices_replica_ids(
global_shape, global_mesh, mesh_axes)
local_devices = global_mesh.local_devices
index_to_device: Dict[int, Tuple[Index, List[Device]]] = {}
for device in local_devices:
index = global_indices_rid[device][0]
h_index = _hashed_index(index)
if h_index not in index_to_device:
index_to_device[h_index] = (index, [device])
else:
index_to_device[h_index][1].append(device)
cb_inp = [
(index, tuple(devices)) for index, devices in index_to_device.values()
]
dbs = data_callback(cb_inp)
if config.jax_array:
return jax.make_array_from_single_device_arrays(
global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs) # type: ignore
return cls(global_shape, global_mesh, mesh_axes, dbs,
_gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))
core.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
x.shape, x.dtype)
xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
x.shape, x.dtype)
xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \
lambda x: core.ShapedArray(x.shape, x.dtype)
# This will only work when GDA is fully addressable i.e. on a single host or
# fully replicated.
def _gda_mlir_constant_handler(val, canonicalize_types=True):
return mlir.ir_constants(val._value,
canonicalize_types=canonicalize_types)
mlir.register_constant_handler(GlobalDeviceArray, _gda_mlir_constant_handler)
def _gda_shard_arg(x, devices, indices, sharding):
x._check_if_deleted()
return x._device_buffers
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg
def _gda_array_result_handler(global_aval, out_sharding, committed,
is_out_sharding_from_xla):
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed, is_out_sharding_from_xla)
global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec
global_idx_rid = get_shard_indices_replica_ids(global_aval.shape, global_mesh,
out_axis_resources)
local_devices = global_mesh.local_devices
fast_path_args = _GdaFastPathArgs(global_idx_rid, local_devices)
return lambda bufs: GlobalDeviceArray(
global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args,
_enable_checks=False)
pxla.global_result_handlers[
(core.ShapedArray, pxla.OutputType.GlobalDeviceArray)] = _gda_array_result_handler
pxla.global_result_handlers[
(core.ConcreteArray, pxla.OutputType.GlobalDeviceArray)] = _gda_array_result_handler
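
For reference, the ``config.jax_array`` branches of the deleted constructors above already forwarded to ``jax.make_array_from_single_device_arrays``; a standalone sketch of that path, with device placement done by hand and the mesh/spec as illustrative assumptions, looks roughly like the following.

# Hedged sketch of the jax.Array path that from_callback forwarded to:
# device_put each addressable shard, then assemble the global array.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

global_shape = (8, 2)
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))  # assumes 8 local devices
sharding = NamedSharding(mesh, P('x', 'y'))
data = np.arange(np.prod(global_shape)).reshape(global_shape)

# addressable_devices_indices_map plays the role of the deleted
# get_shard_indices_replica_ids: it maps each local device to its global index.
index_map = sharding.addressable_devices_indices_map(global_shape)
shards = [jax.device_put(data[idx], d) for d, idx in index_map.items()]
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, shards)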

View File

@@ -3706,30 +3706,21 @@ def check_device_backend_on_shardings(shardings) -> bool:
def check_gda_or_array_xla_sharding_match(
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> None:
from jax._src.global_device_array import GlobalDeviceArray
from jax._src.array import ArrayImpl
for arg, xs in safe_zip(args, in_xla_shardings):
if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
if not isinstance(arg, ArrayImpl):
continue
if isinstance(arg, GlobalDeviceArray):
arg_sharding = create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
arg_type = 'GDA'
committed = True
else:
arg_sharding = arg.sharding
arg_type = 'Array'
committed = arg._committed
# No need to cache this check since MeshExecutable has a C++ fast path
# for AOT compiled call.
if (not check_device_backend_on_shardings([xs]) and
committed and
not are_op_shardings_equal(arg_sharding._to_xla_op_sharding(arg.ndim),
arg._committed and
not are_op_shardings_equal(arg.sharding._to_xla_op_sharding(arg.ndim),
xs._to_xla_op_sharding(arg.ndim))):
raise ValueError(
f"{arg_type} sharding does not match the input sharding. "
f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "
f"Array sharding does not match the input sharding. "
f"Got Array sharding: {arg.sharding} and xla sharding: {xs} for "
f"arg shape: {arg.shape}, arg value: {arg}")

View File

@@ -37,7 +37,6 @@ from jax._src import traceback_util
from jax._src.config import config
from jax.errors import JAXTypeError
from jax._src.array import ArrayImpl
from jax._src.global_device_array import GlobalDeviceArray
from jax._src.sharding_impls import NamedSharding
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
@@ -521,13 +520,7 @@ def xmap(fun: Callable,
lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
closure=(out_axes_entries, out_axes_treedef))
if config.jax_array:
in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat)
else:
in_positional_semantics = tuple(
_PositionalSemantics.GLOBAL
if isinstance(a, GlobalDeviceArray) else _positional_semantics.val
for a in args_flat)
in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat)
out_positional_semantics = (
_PositionalSemantics.GLOBAL
if config.jax_array or config.jax_parallel_functions_output_gda
@@ -1638,7 +1631,7 @@ def _get_axis_sizes(args_flat: Iterable[Any],
in_axes_flat: Iterable[AxisNamePos],
global_axis_sizes: Dict[AxisName, int],
axis_resource_count: Dict[AxisName, ResourceCount],
in_positional_semantics: Sequence[bool]):
in_positional_semantics: Sequence[_PositionalSemantics]):
global_axis_sizes = dict(global_axis_sizes)
for arg, in_axes, ips in zip(args_flat, in_axes_flat, in_positional_semantics):
for name, dim in in_axes.items():
@@ -1824,26 +1817,22 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
axis_resources, resource_env, global_axis_sizes,
in_positional_semantics).to_mesh_axes(in_axes_flat)
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
if isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
arr_flavor = 'GDA' if isinstance(arg, GlobalDeviceArray) else 'Array'
if arr_flavor == 'Array' and not isinstance(arg.sharding, NamedSharding):
if isinstance(arg, ArrayImpl):
if not isinstance(arg.sharding, NamedSharding):
continue
mesh = arg.mesh if arr_flavor == 'GDA' else arg.sharding.mesh
mesh = arg.sharding.mesh
if mesh != resource_env.physical_mesh:
raise ValueError(f"xmap's mesh and {arr_flavor}'s mesh should be equal. "
raise ValueError("xmap's mesh and Array's mesh should be equal. "
f"Got xmap mesh: {resource_env.physical_mesh},\n"
f"{arr_flavor} mesh: {mesh}")
f"Array mesh: {mesh}")
if arr_flavor == 'GDA':
s = pxla.create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
else:
s = arg.sharding
s = arg.sharding
xmap_sharding = pxla.create_mesh_pspec_sharding(
mesh, pxla.array_mapping_to_axis_resources(xmap_array_mapping))
# This check is cached because comparing OpSharding is expensive during
# dispatch and if the shardings are the same, then there is no need to
# compare twice.
_check_sharding(s, xmap_sharding, arg.ndim, arr_flavor)
_check_sharding(s, xmap_sharding, arg.ndim, 'Array')
# TODO: We should relax this at least for "constructor primitives"

View File

@@ -28,7 +28,6 @@ import jax
from jax._src import core
from jax import stages
from jax.errors import JAXTypeError
from jax._src.global_device_array import GlobalDeviceArray as GDA
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters.pxla import PartitionSpec
@@ -491,28 +490,17 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
out_shardings = tree_map(
lambda x: x if _is_unspecified(x) else
_create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), user_out_shardings)
# This check fails extremely rarely and has a huge cost in the dispatch
# path. So hide it behind the jax_enable_checks flag.
if jax.config.jax_enable_checks:
_maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
del user_in_shardings, user_out_shardings
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
# TODO(yashkatariya): This is a hack. This should go away when avals have
# is_global attribute.
if jax.config.jax_array:
in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
else:
in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
out_positional_semantics = (
pxla._PositionalSemantics.GLOBAL
if jax.config.jax_parallel_functions_output_gda or jax.config.jax_array else
pxla.positional_semantics.val)
in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
out_positional_semantics = pxla._PositionalSemantics.GLOBAL
global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
tuple(isinstance(a, GDA) for a in args_flat), resource_env)
resource_env)
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
flat_fun, hashable_pytree(out_shardings), global_in_avals,
@@ -842,7 +830,7 @@ class PytreeLeaf:
@lru_cache(maxsize=4096)
def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
in_tree, in_positional_semantics, is_gda,
in_tree, in_positional_semantics,
resource_env):
orig_in_shardings = in_shardings_thunk()
# Only do this if original in_shardings are unspecified. If they are
@@ -854,66 +842,15 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
"pjit in_shardings", in_tree, orig_in_shardings,
tupled_args=True)
# Fork here because the `Array` path is very simple and doesn't need all the
# complexity below.
if config.jax_array:
pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments",
allow_uneven_sharding=False)
global_in_avals = local_in_avals
# TODO(yashkatariya): Only check for is_auto or _is_unspecified when
# FROM_GDA is removed.
canonicalized_shardings = tuple(
i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
for i, aval in safe_zip(in_shardings_flat, global_in_avals))
return tuple(global_in_avals), canonicalized_shardings
if not local_in_avals:
assert not in_shardings_flat
return (), ()
in_axis_resources_flat = tuple(
i if _is_from_gda(i) or is_auto(i) else i._parsed_pspec
for i in in_shardings_flat)
# This check should be above local_to_global call below otherwise if
# `FROM_GDA` is passed to any input other than GDA, an ugly error message
# will be raised because get_array_mapping (in local_to_global) of a
# FROM_GDA cannot happen.
tree_map(_check_resources_mismatch, in_axis_resources_flat, is_gda)
# If all inputs have global semantics or are fully replicated, then the avals
# are global and the mesh should also be global. This split is because a
# non-contiguous mesh can only be used if all inputs have global semantics or
# are fully replicated.
# Use canonicalized in_axis_resources here because we want to treat P(None)
# and None (for example) as equivalent.
if all(
(not _is_from_gda(p) and not is_auto(p) and
CanonicalizedParsedPartitionSpec(p).partitions == ()) or
ips == pxla._PositionalSemantics.GLOBAL
for p, ips in safe_zip(in_axis_resources_flat, in_positional_semantics)):
# Shapes should be checked against non canonicalized in_axis_resources.
# For example, partitions of () and ((),) are not equivalent, since the
# first one is a valid spec for a scalar value, while the second is not!
pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments",
allow_uneven_sharding=False)
else:
pjit_check_aval_sharding(
[i if _is_from_gda(i) or is_auto(i) else
NamedSharding(i.mesh.local_mesh, i.spec)
for i in in_shardings_flat],
local_in_avals, "pjit arguments", allow_uneven_sharding=False)
# Local or global avals doesn't matter for converting to op sharding because
# the `ndim` does not change.
canonicalized_in_shardings_flat = tuple(
i if _is_from_gda(i) or is_auto(i) else to_gspmd_sharding(i, aval.ndim)
for i, aval in safe_zip(in_shardings_flat, local_in_avals))
global_in_avals = local_to_global(
in_positional_semantics, local_in_avals, canonicalized_in_shardings_flat,
resource_env.physical_mesh)
return tuple(global_in_avals), canonicalized_in_shardings_flat
pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments",
allow_uneven_sharding=False)
global_in_avals = local_in_avals
# TODO(yashkatariya): Only check for is_auto or _is_unspecified when
# FROM_GDA is removed.
canonicalized_shardings = tuple(
i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
for i, aval in safe_zip(in_shardings_flat, global_in_avals))
return tuple(global_in_avals), canonicalized_shardings
@lu.cache
@@ -1156,11 +1093,6 @@ def _prepare_axis_resources(axis_resources,
return tree_unflatten(treedef, new_entries), new_entries, treedef
def _check_resources_mismatch(in_axis_resources_flat, is_gda):
if not is_gda and _is_from_gda(in_axis_resources_flat):
raise ValueError('For a non-GDA input, the corresponding resource in '
'in_axis_resources cannot be `pjit.FROM_GDA`.')
def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
@@ -2204,11 +2136,6 @@ def _calc_is_global_sequence(in_positional_semantics, in_shardings):
pxla.is_op_sharding_replicated(i._op_sharding))
for ips, i in safe_zip(in_positional_semantics, in_shardings))
def _get_in_positional_semantics(arg) -> pxla._PositionalSemantics:
if isinstance(arg, GDA):
return pxla._PositionalSemantics.GLOBAL
return pxla.positional_semantics.val
def _fast_path_get_device_assignment(
shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
@@ -2244,20 +2171,11 @@ def _maybe_replace_from_gda_with_pspec(
out.append(in_sharding_flat)
elif isinstance(arg, array.ArrayImpl):
out.append(to_gspmd_sharding(arg.sharding, arg.ndim))
elif isinstance(arg, GDA):
gda_sharding = pxla.create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
out.append(_gda_check_and_get_sharding(gda_sharding, in_sharding_flat, arg.ndim))
else:
out.append(in_sharding_flat)
return tuple(out)
def _maybe_check_pjit_gda_mesh(args, mesh):
for x in args:
if isinstance(x, GDA) and x.mesh != mesh:
raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit "
f"mesh: {mesh},\n GDA mesh: {x.mesh}")
# -------------------- XLA OpSharding to PartitionSpec --------------------
# Note that OpSharding is more expressive than PartitionSpecs, so it's not
# always possible to convert them, but the code below should at least
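
The simplified ``_process_in_axis_resources`` above keeps only the branch in which avals are already global, so pjit inputs are ``jax.Array``s with explicit shardings rather than GDAs flagged with ``FROM_GDA``. A hedged usage sketch of that surviving path follows; the mesh, shapes, and axis names are illustrative assumptions.

# Hedged sketch: with jax.Array inputs, shardings are passed directly to pjit
# and no local-to-global aval conversion (or FROM_GDA placeholder) is needed.
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))  # assumes 8 local devices
s = NamedSharding(mesh, P('x', 'y'))
x = jax.device_put(np.arange(16.0).reshape(8, 2), s)  # a committed jax.Array

f = pjit(lambda a: a @ a.T,
         in_shardings=s,
         out_shardings=NamedSharding(mesh, P('x', None)))
out = f(x)  # out is a jax.Array whose sharding reflects out_shardings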

View File

@@ -26,7 +26,6 @@ from typing import Callable, Sequence, Optional, Dict, Any
import jax
from jax._src import distributed
from jax._src.config import config
from jax._src import global_device_array as gda
from jax._src import array
from jax._src import sharding
from jax._src import sharding_impls
@@ -64,27 +63,6 @@ async def create_async_array_from_callback(
global_shape, inp_sharding, dbs)
async def create_async_gda_from_callback(
global_shape: gda.Shape,
global_mesh: Mesh,
mesh_axes: gda.MeshAxes,
data_callback: Callable[[gda.Index], asyncio.Future],
):
global_idx_rid = gda.get_shard_indices_replica_ids(
global_shape, global_mesh, mesh_axes)
local_devices = global_mesh.local_devices
future_arrays = [data_callback(global_idx_rid[d][0])
for d in local_devices]
# Pause here and come back to `from_async_callback()` when future_arrays are
# ready. device_put cannot happen with future_arrays.
local_arrays = await asyncio.gather(*future_arrays)
dbs = [jax.device_put(array, device)
for array, device in zip(local_arrays, local_devices)]
return gda.GlobalDeviceArray(global_shape, global_mesh, mesh_axes, dbs,
gda._GdaFastPathArgs(global_idx_rid, local_devices))
def _get_metadata(arr):
if arr.dtype == jnp.bfloat16:
# Tensorstore uses 'bfloat16', not '<V2'.
@@ -280,15 +258,7 @@ async def async_deserialize(in_sharding, tensorstore_spec,
await byte_limiter.release_bytes(requested_bytes)
return out
if config.jax_array:
return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
else:
if not isinstance(in_sharding, sharding_impls.NamedSharding):
raise ValueError('Deserializing a GlobalDeviceArray is only possible with '
'a `NamedSharding` which consists of a `mesh` and '
f'`pspec`, but got {in_sharding}')
return await create_async_gda_from_callback(
tuple(shape), in_sharding.mesh, in_sharding.spec, cb)
return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
def run_deserialization(shardings: Sequence[sharding.Sharding],
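
Deserialization now always goes through ``create_async_array_from_callback``. The deleted ``create_async_gda_from_callback`` pattern maps onto ``jax.Array`` roughly as in the sketch below; the helper name and its future-returning ``data_callback`` are illustrative assumptions, not the module's real API.

# Hedged sketch mirroring the deleted create_async_gda_from_callback, but
# producing a jax.Array: gather the per-shard futures, then device_put and
# assemble with make_array_from_single_device_arrays.
import asyncio
import jax

async def create_array_from_async_callback(global_shape, sharding, data_callback):
  index_map = sharding.addressable_devices_indices_map(global_shape)
  devices = list(index_map)
  # data_callback(index) returns an asyncio future of the host data for that
  # shard; device_put has to wait until those futures have resolved.
  local_arrays = await asyncio.gather(*(data_callback(index_map[d]) for d in devices))
  shards = [jax.device_put(a, d) for a, d in zip(local_arrays, devices)]
  return jax.make_array_from_single_device_arrays(global_shape, sharding, shards)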

View File

@@ -1,26 +0,0 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.global_device_array import (
Device as Device,
GlobalDeviceArray as GlobalDeviceArray,
MeshAxes as MeshAxes,
PartitionSpec as PartitionSpec,
Shard as Shard,
Shape as Shape,
get_shard_indices as get_shard_indices,
get_shard_shape as get_shard_shape,
_get_sharding_spec as _get_sharding_spec,
_hashed_index as _hashed_index,
)

View File

@@ -54,7 +54,6 @@ from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.global_device_array import GlobalDeviceArray
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
@@ -1093,16 +1092,6 @@ def _to_jax_dtype(tf_dtype):
return dt
def _maybe_decode_gda(gda_or_py_object: Any):
"""Convert GlobalDeviceArray into numpy object."""
if isinstance(gda_or_py_object, GlobalDeviceArray):
if jax.process_count() != 1:
raise RuntimeError("GlobalDeviceArray does not support multi-process"
f" currently. Process num = {jax.process_count()}")
return gda_or_py_object._value
return gda_or_py_object
def _tfval_to_tensor_jax_dtype(val: TfVal,
jax_dtype: Optional[DType] = None,
memoize_constants=False) -> Tuple[TfVal, DType]:
@@ -1153,8 +1142,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
# The float0 type is not known to TF.
if jax_dtype == dtypes.float0:
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
tf_val = tf.convert_to_tensor(
_maybe_decode_gda(val), dtype=conversion_dtype)
tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
if do_memoize:
_thread_local_state.constant_cache[const_key] = (val, tf_val)
return tf_val, jax_dtype

View File

@@ -30,7 +30,6 @@ from jax.interpreters import xla
from jax._src import pjit as pjit_lib
from jax.experimental.pjit import pjit, FROM_GDA
from jax.sharding import PartitionSpec as P
from jax._src.global_device_array import GlobalDeviceArray
from jax._src import distributed
from jax._src import config as config_internal
import numpy as np
@@ -60,9 +59,6 @@ def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any:
is_source = jax.process_index() == 0
def pre_pmap(x):
if isinstance(x, GlobalDeviceArray):
raise ValueError('GDAs cannot be broadcasted from source host to other '
'hosts.')
if is_source:
return np.concatenate([
x[None, ...],
@@ -148,35 +144,8 @@ def process_allgather(in_tree: Any, tiled: bool = False) -> Any:
"""
def _pjit(inp):
if jax.config.jax_array:
return _handle_array_process_allgather(inp, tiled)
else:
if isinstance(inp, GlobalDeviceArray):
if inp.is_fully_replicated:
return np.asarray(inp.addressable_data(0))
global_mesh = inp.mesh
in_axis_resources = FROM_GDA
else:
# DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
# Shape of local_mesh will always be (1, local_device_count())
devices = np.array(jax.devices()).reshape(jax.process_count(),
jax.local_device_count())
global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
in_axis_resources = P('processes')
if inp.ndim == 0 or not tiled:
inp = np.expand_dims(inp, axis=0)
with global_mesh:
out = pjit(
_identity_fn, in_shardings=in_axis_resources, out_shardings=None
)(inp)
return np.asarray(out.addressable_data(0))
if jax.config.jax_array:
return jax.tree_map(_pjit, in_tree) # array route
else:
with config_internal.parallel_functions_output_gda(True):
return jax.tree_map(_pjit, in_tree) # gda route
return _handle_array_process_allgather(inp, tiled)
return jax.tree_map(_pjit, in_tree)
def assert_equal(in_tree, fail_message: str = ''):
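
With the GDA route removed, ``process_allgather`` always takes the ``_handle_array_process_allgather`` path regardless of input type. A hedged usage sketch is below; the shapes are illustrative, and the stack/concat behaviour described in the comments is an assumption based on the ``tiled`` flag's documented meaning.

# Hedged sketch: gathering host-local data onto every process via process_allgather.
import numpy as np
import jax
from jax.experimental import multihost_utils

# Each process contributes a same-shaped local array.
local_data = np.arange(4) + 4 * jax.process_index()
stacked = multihost_utils.process_allgather(local_data)             # expected shape (process_count, 4)
tiled = multihost_utils.process_allgather(local_data, tiled=True)   # expected shape (process_count * 4,)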