Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Remove GDA from JAX, since jax.Array is the default type and can no longer be disabled (see https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now).
PiperOrigin-RevId: 516905931
This commit is contained in:
parent 6f52388ecc
commit 634035abd7
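For context on the migration this commit completes, here is a minimal sketch of the jax.Array equivalent of the GDA-plus-flag workflow being deleted. It is illustrative only: it assumes 8 visible devices, and the names (data, arr, out) are mine, not anything defined in this change.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build the global mesh and sharding directly; no GDA output flag is needed.
global_shape = (8, 2)
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
data = np.arange(np.prod(global_shape)).reshape(global_shape)

# The callback receives the index (a tuple of slices) of each shard, much like
# GlobalDeviceArray.from_callback did.
arr = jax.make_array_from_callback(global_shape, sharding, lambda index: data[index])

out = jax.jit(lambda x: x @ x.T)(arr)          # jax.Array in, jax.Array out
print(arr.addressable_shards[0].data.shape)    # (2, 1), the per-device shard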
@ -1,11 +0,0 @@
``jax.experimental.global_device_array`` module
===============================================

.. automodule:: jax.experimental.global_device_array

API
---

.. autoclass:: GlobalDeviceArray
  :members:
.. autoclass:: Shard
@ -15,7 +15,6 @@ Experimental Modules
    :maxdepth: 1

    jax.experimental.checkify
    jax.experimental.global_device_array
    jax.experimental.host_callback
    jax.experimental.maps
    jax.experimental.pjit
jax/BUILD (11 changed lines)
@ -107,7 +107,6 @@ py_library_providing_imports_info(
        "_src/dtypes.py",
        "_src/errors.py",
        "_src/flatten_util.py",
        "_src/global_device_array.py",
        "_src/__init__.py",
        "_src/lax_reference.py",
        "_src/linear_util.py",
@ -381,9 +380,6 @@ py_library_providing_imports_info(
            "experimental/*.py",
            "example_libraries/*.py",
        ],
        exclude = [
            "experimental/global_device_array.py",
        ],
    ),
    visibility = ["//visibility:public"],
    deps = [
@ -506,10 +502,3 @@ pytype_library(
    visibility = ["//visibility:public"],
    deps = [":jax"],
)

pytype_library(
    name = "global_device_array",
    srcs = ["experimental/global_device_array.py"],
    visibility = [":internal"],
    deps = [":jax"],
)
@ -1,650 +0,0 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import Counter
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import api_util
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import xla, mlir
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.interpreters.pxla import PartitionSpec
|
||||
|
||||
Shape = Tuple[int, ...]
|
||||
MeshAxes = PartitionSpec
|
||||
DeviceArray = xc.Buffer
|
||||
Device = xc.Device
|
||||
ArrayLike = Union[np.ndarray, DeviceArray]
|
||||
Index = Tuple[slice, ...]
|
||||
|
||||
|
||||
_hashed_index = lambda x: hash(tuple((v.start, v.stop) for v in x))
|
||||
|
||||
|
||||
def _get_sharding_spec(global_shape, global_mesh, mesh_axes):
|
||||
array_mapping = pxla.get_array_mapping(mesh_axes)
|
||||
# The dtype doesn't matter for creating sharding specs.
|
||||
aval = core.ShapedArray(global_shape, np.float32)
|
||||
return pxla.mesh_sharding_specs(global_mesh.shape,
|
||||
global_mesh.axis_names)(aval, array_mapping)
|
||||
|
||||
|
||||
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Tuple[Index, ...]:
|
||||
sharding_spec = _get_sharding_spec(global_shape, global_mesh, mesh_axes)
|
||||
indices = pxla.spec_to_indices(global_shape, sharding_spec)
|
||||
return indices # type: ignore
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Mapping[Device, Index]:
|
||||
indices = _get_indices(global_shape, global_mesh, mesh_axes)
|
||||
# The type: ignore is to ignore the type returned by `spec_to_indices`.
|
||||
return {
|
||||
d: i
|
||||
for d, i in safe_zip(global_mesh.devices.flat, indices)} # type: ignore
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def get_shard_indices_replica_ids(
|
||||
global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
|
||||
return _get_shard_indices_replica_ids_uncached(global_shape, global_mesh, mesh_axes)
|
||||
|
||||
def _get_shard_indices_replica_ids_uncached(
|
||||
global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
|
||||
indices = _get_indices(global_shape, global_mesh, mesh_axes)
|
||||
index_to_replica: Dict[int, int] = Counter()
|
||||
out = {}
|
||||
unique_shards = 0
|
||||
for device, index in safe_zip(global_mesh.devices.flat, indices):
|
||||
h_index = _hashed_index(index)
|
||||
replica_id = index_to_replica[h_index]
|
||||
if replica_id == 0:
|
||||
unique_shards += 1
|
||||
index_to_replica[h_index] += 1
|
||||
out[device] = (index, replica_id)
|
||||
|
||||
shard_shape = get_shard_shape(global_shape, global_mesh, mesh_axes)
|
||||
expected_unique_shards = math.prod(
|
||||
[g // s for g, s in safe_zip(global_shape, shard_shape) if g != 0 or s != 0])
|
||||
if expected_unique_shards != unique_shards:
|
||||
raise RuntimeError(
|
||||
f'Number of expected unique shards is {expected_unique_shards} but '
|
||||
f'got {unique_shards}. Please file a bug at '
|
||||
'https://github.com/google/jax/issues.')
|
||||
return out
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
|
||||
chunk_size = []
|
||||
for mesh_axis, size in zip(mesh_axes, global_shape):
|
||||
if not mesh_axis:
|
||||
chunk_size.append(size)
|
||||
elif isinstance(mesh_axis, tuple):
|
||||
m = math.prod([global_mesh.shape[ma] for ma in mesh_axis])
|
||||
chunk_size.append(size // m)
|
||||
else:
|
||||
chunk_size.append(size // global_mesh.shape[mesh_axis])
|
||||
if len(chunk_size) != len(global_shape):
|
||||
chunk_size.extend(global_shape[len(chunk_size):])
|
||||
return tuple(chunk_size)
|
||||
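A worked example may make the chunking rule above concrete. This is illustrative only: it assumes 8 devices arranged as the 4x2 ('x', 'y') mesh used in the docstrings below, with Mesh, P and numpy imported as in those examples.

>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> get_shard_shape((8, 2), global_mesh, P('x', 'y'))          # (8 // 4, 2 // 2)
(2, 1)
>>> get_shard_shape((8, 2), global_mesh, P(('x', 'y'), None))  # (8 // (4 * 2), 2)
(1, 2)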
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Shard:
|
||||
"""A single data shard of a GlobalDeviceArray.
|
||||
|
||||
Args:
|
||||
device : Which device this shard resides on.
|
||||
index : The index into the global array of this shard.
|
||||
replica_id : Integer id indicating which replica of the global array this
|
||||
shard is part of. Always 0 for fully sharded data
|
||||
(i.e. when there’s only 1 replica).
|
||||
data : The data of this shard. None if ``device`` is non-local.
|
||||
"""
|
||||
device: Device
|
||||
index: Index
|
||||
replica_id: int
|
||||
# None if this `Shard` lives on a non-local device.
|
||||
data: Optional[DeviceArray] = None
|
||||
|
||||
|
||||
class _GdaFastPathArgs(NamedTuple):
|
||||
global_indices_replica_ids: Mapping[Device, Tuple[Index, int]]
|
||||
local_devices: Sequence[Device]
|
||||
|
||||
|
||||
class GlobalDeviceArray:
|
||||
"""A logical array with data sharded across multiple devices and processes.
|
||||
|
||||
If you’re not already familiar with JAX’s multi-process programming model,
|
||||
please read https://jax.readthedocs.io/en/latest/multi_process.html.
|
||||
You can also read about pjit (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html)
|
||||
to learn about ``Mesh``, ``PartitionSpec`` and how arrays can be
|
||||
partitioned or replicated.
|
||||
|
||||
A GlobalDeviceArray (GDA) can be thought of as a view into a single logical
|
||||
array sharded across processes. The logical array is the “global” array, and
|
||||
each process has a GlobalDeviceArray object referring to the same global array
|
||||
(similarly to how each process runs a multi-process pmap or pjit). Each process
|
||||
can access the shape, dtype, etc. of the global array via the GDA, pass the
|
||||
GDA into multi-process pjits, and get GDAs as pjit outputs (coming soon: xmap
|
||||
and pmap). However, each process can only directly access the shards of the
|
||||
global array data stored on its local devices.
|
||||
|
||||
GDAs can help manage the inputs and outputs of multi-process computations.
|
||||
A GDA keeps track of which shard of the global array belongs to which device,
|
||||
and provides callback-based APIs to materialize the correct shard of the data
|
||||
needed for each local device of each process.
|
||||
|
||||
A GDA consists of data shards. Each shard is stored on a different device.
|
||||
There are local shards and global shards. Local shards are those on local
|
||||
devices, and the data is visible to the current process. Global shards are
|
||||
those across all devices (including local devices), and the data isn’t visible
|
||||
if the shard is on a non-local device with respect to the current process.
|
||||
Please see the ``Shard`` class to see what information is stored inside that
|
||||
data structure.
|
||||
|
||||
Note: to make pjit output GlobalDeviceArrays, set the environment variable
|
||||
``JAX_PARALLEL_FUNCTIONS_OUTPUT_GDA=true`` or add the following to your code:
|
||||
``jax.config.update('jax_parallel_functions_output_gda', True)``
|
||||
|
||||
Args:
|
||||
global_shape : The global shape of the array.
|
||||
global_mesh : The global mesh representing devices across multiple
|
||||
processes.
|
||||
mesh_axes : A sequence with length less than or equal to the rank of the
|
||||
global array (i.e. the length of the global shape). Each element can be:
|
||||
|
||||
* An axis name of ``global_mesh``, indicating that the corresponding
|
||||
global array axis is partitioned across the given device axis of
|
||||
``global_mesh``.
|
||||
* A tuple of axis names of ``global_mesh``. This is like the above option
|
||||
except the global array axis is partitioned across the product of axes
|
||||
named in the tuple.
|
||||
* None indicating that the corresponding global array axis is not
|
||||
partitioned.
|
||||
|
||||
For more information, please see:
|
||||
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
|
||||
device_buffers: DeviceArrays that are on the local devices of ``global_mesh``.
|
||||
|
||||
Attributes:
|
||||
shape : Global shape of the array.
|
||||
dtype : Dtype of the global array.
|
||||
ndim : Number of array dimensions in the global shape.
|
||||
size: Number of elements in the global array.
|
||||
local_shards : List of :class:`Shard` on the local devices of the current process.
|
||||
Data is materialized for all local shards.
|
||||
global_shards : List of all :class:`Shard` of the global array. Data isn’t
|
||||
available if a shard is on a non-local device with respect to the current
|
||||
process.
|
||||
is_fully_replicated : True if the full array value is present on all devices
|
||||
of the global mesh.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from jax.sharding import Mesh
|
||||
>>> from jax.sharding import PartitionSpec as P
|
||||
>>> import numpy as np
|
||||
...
|
||||
>>> assert jax.device_count() == 8
|
||||
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
|
||||
>>> # Logical mesh is (hosts, devices)
|
||||
>>> assert global_mesh.shape == {'x': 4, 'y': 2}
|
||||
>>> global_input_shape = (8, 2)
|
||||
>>> mesh_axes = P('x', 'y')
|
||||
...
|
||||
>>> # Dummy example data; in practice we wouldn't necessarily materialize global data
|
||||
>>> # in a single process.
|
||||
>>> global_input_data = np.arange(
|
||||
... np.prod(global_input_shape)).reshape(global_input_shape)
|
||||
...
|
||||
>>> def get_local_data_slice(index):
|
||||
>>> # index will be a tuple of slice objects, e.g. (slice(0, 2), slice(0, 1))
|
||||
... # This method will be called per-local device from the GDA constructor.
|
||||
... return global_input_data[index]
|
||||
...
|
||||
>>> gda = GlobalDeviceArray.from_callback(
|
||||
... global_input_shape, global_mesh, mesh_axes, get_local_data_slice)
|
||||
>>> print(gda.shape)
|
||||
(8, 2)
|
||||
>>> print(gda.addressable_shards[0].data) # Access the data on a single local device
|
||||
[[0]
|
||||
[2]]
|
||||
>>> print(gda.addressable_shards[0].data.shape)
|
||||
(2, 1)
|
||||
>>> # Numpy-style index into the global array that this data shard corresponds to
|
||||
>>> print(gda.addressable_shards[0].index)
|
||||
(slice(0, 2, None), slice(0, 1, None))
|
||||
|
||||
GDAs can also be given as an input to pjit and you can get GDAs as output from pjit::
|
||||
|
||||
# Allow pjit to output GDAs
|
||||
jax.config.update('jax_parallel_functions_output_gda', True)
|
||||
|
||||
f = pjit(lambda x: x @ x.T, in_shardings=P('x', 'y'), out_shardings=P('x', 'y'))
|
||||
with global_mesh:
|
||||
out = f(gda)
|
||||
|
||||
# `out` can be passed to another pjit call, out.addressable_shards can be used to
|
||||
# export the data to non-jax systems (e.g. for checkpointing or logging), etc.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
global_shape: Shape,
|
||||
global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes,
|
||||
device_buffers: Sequence[DeviceArray],
|
||||
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
|
||||
_enable_checks: bool = True):
|
||||
warnings.warn(
|
||||
"GlobalDeviceArray has been deprecated. Please migrate to jax.Array. "
|
||||
"See https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration "
|
||||
"on how to migrate to jax.Array.", DeprecationWarning)
|
||||
self._global_shape = global_shape
|
||||
self._global_mesh = global_mesh
|
||||
self._mesh_axes = mesh_axes
|
||||
self._device_buffers = device_buffers
|
||||
|
||||
# Optionally precomputed for performance.
|
||||
self._gda_fast_path_args = _gda_fast_path_args
|
||||
|
||||
if self._gda_fast_path_args is None:
|
||||
self._local_devices = self._global_mesh.local_devices
|
||||
else:
|
||||
self._local_devices = self._gda_fast_path_args.local_devices
|
||||
|
||||
if _enable_checks or config.jax_enable_checks:
|
||||
for db, ld in safe_zip(self._device_buffers, self._local_devices):
|
||||
if db.device() != ld:
|
||||
raise ValueError(
|
||||
"The `global_mesh.local_devices` and `device_buffers` device "
|
||||
"order doesn't match. Please use `global_mesh.local_devices` to "
|
||||
"put arrays on devices instead of `jax.local_devices()`")
|
||||
|
||||
if _enable_checks or config.jax_enable_checks:
|
||||
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
assert all(db.shape == ss for db in self._device_buffers), (
|
||||
f"Expected shard shape {ss} doesn't match the device buffer "
|
||||
f"shape, got: {[db.shape for db in self._device_buffers]}")
|
||||
|
||||
dtype = device_buffers[0].dtype # type: ignore
|
||||
if _enable_checks or config.jax_enable_checks:
|
||||
assert all(db.dtype == dtype for db in self._device_buffers), (
|
||||
"Input arrays to GlobalDeviceArray must have matching dtypes, "
|
||||
f"got: {[db.dtype for db in self._device_buffers]}")
|
||||
self.dtype = dtype
|
||||
|
||||
def __eq__(self, other: object):
|
||||
raise NotImplementedError(
|
||||
"GlobalDeviceArray equality is intentionally unimplemented. "
|
||||
"Implement desired functionality explicitly, e.g. to check if all "
|
||||
"values are equal: "
|
||||
"pjit(lambda x, y: x == y, "
|
||||
"in_shardings=FROM_GDA, out_shardings=None)"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype})'
|
||||
|
||||
def __repr__(self):
|
||||
return (f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype}, '
|
||||
f'global_mesh_shape={dict(self.mesh.shape)}, '
|
||||
f'mesh_axes={self.mesh_axes})')
|
||||
|
||||
@property
|
||||
def shape(self) -> Shape:
|
||||
return self._global_shape
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
return len(self.shape)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return math.prod(self.shape)
|
||||
|
||||
@property
|
||||
def mesh(self):
|
||||
return self._global_mesh
|
||||
|
||||
@property
|
||||
def mesh_axes(self) -> MeshAxes:
|
||||
return self._mesh_axes
|
||||
|
||||
@property
|
||||
def is_fully_replicated(self) -> bool:
|
||||
return self.shape == self.addressable_data(0).shape
|
||||
|
||||
def _create_local_shards(self) -> Sequence[Shard]:
|
||||
if self._gda_fast_path_args is not None:
|
||||
global_indices_rid = self._gda_fast_path_args.global_indices_replica_ids
|
||||
else:
|
||||
global_indices_rid = get_shard_indices_replica_ids(
|
||||
self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
|
||||
out = []
|
||||
for db in self._device_buffers:
|
||||
db = dispatch._set_aval(db)
|
||||
device = db.device()
|
||||
index, rid = global_indices_rid[device]
|
||||
out.append(Shard(device, index, rid, db))
|
||||
return out
|
||||
|
||||
@functools.cached_property
|
||||
def local_shards(self) -> Sequence[Shard]:
|
||||
self._check_if_deleted()
|
||||
return self._create_local_shards()
|
||||
|
||||
@functools.cached_property
|
||||
def addressable_shards(self) -> Sequence[Shard]:
|
||||
self._check_if_deleted()
|
||||
return self.local_shards
|
||||
|
||||
@property
|
||||
def global_shards(self) -> Sequence[Shard]:
|
||||
self._check_if_deleted()
|
||||
if self.mesh.size == len(self._local_devices):
|
||||
return self.addressable_shards
|
||||
|
||||
# Populating global_shards lazily (i.e. when requested) because populating
|
||||
# them eagerly leads to a performance regression when training on large
|
||||
# models.
|
||||
# Also, as this is a cached property, once calculated, it should be cached. So
|
||||
# multiple accesses should be cheap.
|
||||
global_indices_rid = get_shard_indices_replica_ids(
|
||||
self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
device_to_buffer = {db.device(): db for db in self._device_buffers}
|
||||
global_shards = []
|
||||
for device, (index, rid) in global_indices_rid.items():
|
||||
local_shard = device.process_index == device.client.process_index()
|
||||
buf = device_to_buffer[device] if local_shard else None
|
||||
if buf is not None and buf.aval is None:
|
||||
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
|
||||
sh = Shard(device, index, rid, buf)
|
||||
global_shards.append(sh)
|
||||
return global_shards
|
||||
|
||||
@property
|
||||
def _value(self):
|
||||
self._check_if_deleted()
|
||||
if self.is_fully_replicated:
|
||||
return np.asarray(self._device_buffers[0])
|
||||
|
||||
if self.mesh.is_multi_process:
|
||||
raise RuntimeError("Fetching value for GDA that spans non-addressable "
|
||||
"devices is not possible. You can use "
|
||||
"`jax.experimental.multihost_utils.process_allgather` "
|
||||
"for this use case.")
|
||||
unique_shards = [s.data.copy_to_host_async() or s
|
||||
for s in self.addressable_shards if s.replica_id == 0]
|
||||
npy_value = np.empty(self.shape, self.dtype)
|
||||
for s in unique_shards:
|
||||
npy_value[s.index] = np.asarray(s.data)
|
||||
return npy_value
|
||||
|
||||
def __array__(self, dtype=None, context=None):
|
||||
self._check_if_deleted()
|
||||
return self._value if dtype is None else self._value.astype(dtype)
|
||||
|
||||
def local_data(self, index) -> DeviceArray:
|
||||
self._check_if_deleted()
|
||||
return dispatch._set_aval(self._device_buffers[index])
|
||||
|
||||
def addressable_data(self, index) -> DeviceArray:
|
||||
self._check_if_deleted()
|
||||
return self.local_data(index)
|
||||
|
||||
def block_until_ready(self):
|
||||
self._check_if_deleted()
|
||||
for db in self._device_buffers:
|
||||
db.block_until_ready()
|
||||
return self
|
||||
|
||||
def _check_if_deleted(self):
|
||||
if self.is_deleted():
|
||||
raise RuntimeError("GlobalDeviceArray has been deleted.")
|
||||
|
||||
def is_deleted(self):
|
||||
return self._device_buffers is None
|
||||
|
||||
def delete(self):
|
||||
for b in self._device_buffers:
|
||||
b.delete()
|
||||
self._device_buffers = None
|
||||
|
||||
@property
|
||||
def sharding(self):
|
||||
return jax.sharding.NamedSharding(self._global_mesh, self.mesh_axes)
|
||||
|
||||
@classmethod
|
||||
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes, data_callback: Callable[[Index],
|
||||
ArrayLike]):
|
||||
"""Constructs a GlobalDeviceArray via data fetched from ``data_callback``.
|
||||
|
||||
``data_callback`` is used to fetch the data for each local slice of the returned GlobalDeviceArray.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from jax.sharding import Mesh
|
||||
>>> from jax.sharding import PartitionSpec as P
|
||||
>>> import numpy as np
|
||||
...
|
||||
>>> global_input_shape = (8, 8)
|
||||
>>> mesh_axes = P('x', 'y')
|
||||
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
|
||||
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
...
|
||||
>>> def cb(index):
|
||||
... return global_input_data[index]
|
||||
...
|
||||
>>> gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)
|
||||
>>> gda.addressable_data(0).shape
|
||||
(4, 2)
|
||||
|
||||
Args:
|
||||
global_shape : The global shape of the array
|
||||
global_mesh : The global mesh representing devices across multiple
|
||||
processes.
|
||||
mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
|
||||
data_callback : Callback that takes indices into the global array value as input and
|
||||
returns the corresponding data of the global array value. The data can be returned
|
||||
as any array-like object, e.g. a ``numpy.ndarray``.
|
||||
"""
|
||||
global_indices_rid = get_shard_indices_replica_ids(
|
||||
global_shape, global_mesh, mesh_axes)
|
||||
local_devices = global_mesh.local_devices
|
||||
dbs = [
|
||||
jax.device_put(data_callback(global_indices_rid[device][0]), device)
|
||||
for device in local_devices
|
||||
]
|
||||
if config.jax_array:
|
||||
return jax.make_array_from_single_device_arrays(
|
||||
global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs)
|
||||
return cls(global_shape, global_mesh, mesh_axes, dbs,
|
||||
_gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))
|
||||
|
||||
@classmethod
|
||||
def from_batched_callback(cls, global_shape: Shape,
|
||||
global_mesh: pxla.Mesh, mesh_axes: MeshAxes,
|
||||
data_callback: Callable[[Sequence[Index]],
|
||||
Sequence[ArrayLike]]):
|
||||
"""Constructs a GlobalDeviceArray via batched data fetched from ``data_callback``.
|
||||
|
||||
Like ``from_callback``, except the callback function is called only once to fetch all data
|
||||
local to this process.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from jax.sharding import Mesh
|
||||
>>> from jax.sharding import PartitionSpec as P
|
||||
>>> import numpy as np
|
||||
...
|
||||
>>> global_input_shape = (8, 2)
|
||||
>>> mesh_axes = P('x')
|
||||
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
|
||||
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
...
|
||||
>>> def batched_cb(indices):
|
||||
... assert len(indices) == len(global_mesh.local_devices)
|
||||
... return [global_input_data[index] for index in indices]
|
||||
...
|
||||
>>> gda = GlobalDeviceArray.from_batched_callback(global_input_shape, global_mesh, mesh_axes, batched_cb)
|
||||
>>> gda.addressable_data(0).shape
|
||||
(2, 2)
|
||||
|
||||
Args:
|
||||
global_shape : The global shape of the array
|
||||
global_mesh : The global mesh representing devices across multiple
|
||||
processes.
|
||||
mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
|
||||
data_callback : Callback that takes a batch of indices into the global array value with
|
||||
length equal to the number of local devices as input and returns the corresponding data for each index.
|
||||
The data can be returned as any array-like object, e.g. a ``numpy.ndarray``.
|
||||
"""
|
||||
global_indices_rid = get_shard_indices_replica_ids(
|
||||
global_shape, global_mesh, mesh_axes)
|
||||
local_devices = global_mesh.local_devices
|
||||
local_indices = [global_indices_rid[d][0] for d in local_devices]
|
||||
local_arrays = data_callback(local_indices)
|
||||
dbs = pxla.device_put(local_arrays, local_devices)
|
||||
if config.jax_array:
|
||||
return jax.make_array_from_single_device_arrays(
|
||||
global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs) # type: ignore
|
||||
return cls(global_shape, global_mesh, mesh_axes, dbs,
|
||||
_gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))
|
||||
|
||||
@classmethod
|
||||
def from_batched_callback_with_devices(
|
||||
cls, global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes,
|
||||
data_callback: Callable[[Sequence[Tuple[Index, Tuple[Device, ...]]]],
|
||||
Sequence[DeviceArray]]):
|
||||
"""Constructs a GlobalDeviceArray via batched DeviceArrays fetched from ``data_callback``.
|
||||
|
||||
Like ``from_batched_callback``, except the callback function is responsible for returning on-device data (e.g. by calling ``jax.device_put``).
|
||||
|
||||
Example:
|
||||
|
||||
>>> from jax.sharding import Mesh
|
||||
>>> from jax.sharding import PartitionSpec as P
|
||||
>>> import numpy as np
|
||||
...
|
||||
>>> global_input_shape = (8, 2)
|
||||
>>> mesh_axes = P(('x', 'y'))
|
||||
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
|
||||
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
...
|
||||
>>> def cb(cb_inp):
|
||||
... dbs = []
|
||||
... for inp in cb_inp:
|
||||
... index, devices = inp
|
||||
... array = global_input_data[index]
|
||||
... dbs.extend([jax.device_put(array, device) for device in devices])
|
||||
... return dbs
|
||||
...
|
||||
>>> gda = GlobalDeviceArray.from_batched_callback_with_devices(
|
||||
... global_input_shape, global_mesh, mesh_axes, cb)
|
||||
>>> gda.addressable_data(0).shape
|
||||
(1, 2)
|
||||
|
||||
Args:
|
||||
global_shape : The global shape of the array
|
||||
global_mesh : The global mesh representing devices across multiple
|
||||
processes.
|
||||
mesh_axes : See the ``mesh_axes`` parameter of GlobalDeviceArray.
|
||||
data_callback : Callback that takes a batch of indices into the global array value with
|
||||
length equal to the number of local devices as input and returns the corresponding data for
|
||||
each index. The data must be returned as jax DeviceArrays.
|
||||
"""
|
||||
global_indices_rid = get_shard_indices_replica_ids(
|
||||
global_shape, global_mesh, mesh_axes)
|
||||
local_devices = global_mesh.local_devices
|
||||
|
||||
index_to_device: Dict[int, Tuple[Index, List[Device]]] = {}
|
||||
for device in local_devices:
|
||||
index = global_indices_rid[device][0]
|
||||
h_index = _hashed_index(index)
|
||||
if h_index not in index_to_device:
|
||||
index_to_device[h_index] = (index, [device])
|
||||
else:
|
||||
index_to_device[h_index][1].append(device)
|
||||
|
||||
cb_inp = [
|
||||
(index, tuple(devices)) for index, devices in index_to_device.values()
|
||||
]
|
||||
dbs = data_callback(cb_inp)
|
||||
if config.jax_array:
|
||||
return jax.make_array_from_single_device_arrays(
|
||||
global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), dbs) # type: ignore
|
||||
return cls(global_shape, global_mesh, mesh_axes, dbs,
|
||||
_gda_fast_path_args=_GdaFastPathArgs(global_indices_rid, local_devices))
|
||||
|
||||
|
||||
core.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
|
||||
x.shape, x.dtype)
|
||||
xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
|
||||
x.shape, x.dtype)
|
||||
xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
|
||||
api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \
|
||||
lambda x: core.ShapedArray(x.shape, x.dtype)
|
||||
|
||||
# This will only work when GDA is fully addressable i.e. on a single host or
|
||||
# fully replicated.
|
||||
def _gda_mlir_constant_handler(val, canonicalize_types=True):
|
||||
return mlir.ir_constants(val._value,
|
||||
canonicalize_types=canonicalize_types)
|
||||
mlir.register_constant_handler(GlobalDeviceArray, _gda_mlir_constant_handler)
|
||||
|
||||
|
||||
def _gda_shard_arg(x, devices, indices, sharding):
|
||||
x._check_if_deleted()
|
||||
return x._device_buffers
|
||||
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg
|
||||
|
||||
|
||||
def _gda_array_result_handler(global_aval, out_sharding, committed,
|
||||
is_out_sharding_from_xla):
|
||||
if core.is_opaque_dtype(global_aval.dtype):
|
||||
return global_aval.dtype._rules.global_sharded_result_handler(
|
||||
global_aval, out_sharding, committed, is_out_sharding_from_xla)
|
||||
global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec
|
||||
global_idx_rid = get_shard_indices_replica_ids(global_aval.shape, global_mesh,
|
||||
out_axis_resources)
|
||||
local_devices = global_mesh.local_devices
|
||||
fast_path_args = _GdaFastPathArgs(global_idx_rid, local_devices)
|
||||
return lambda bufs: GlobalDeviceArray(
|
||||
global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args,
|
||||
_enable_checks=False)
|
||||
pxla.global_result_handlers[
|
||||
(core.ShapedArray, pxla.OutputType.GlobalDeviceArray)] = _gda_array_result_handler
|
||||
pxla.global_result_handlers[
|
||||
(core.ConcreteArray, pxla.OutputType.GlobalDeviceArray)] = _gda_array_result_handler
|
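To keep the removed Shard/local_shards API above comparable with what replaces it, a short hedged sketch of the same per-shard introspection on a jax.Array (assumes 8 visible devices; nothing here is defined by this commit).

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
arr = jax.device_put(np.arange(16.0).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))
for s in arr.addressable_shards:
    # device, index (a tuple of slices), replica_id and data mirror the fields
    # of the GDA Shard dataclass deleted above.
    print(s.device, s.index, s.replica_id, s.data.shape)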
@ -3706,30 +3706,21 @@ def check_device_backend_on_shardings(shardings) -> bool:
|
||||
|
||||
def check_gda_or_array_xla_sharding_match(
|
||||
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> None:
|
||||
from jax._src.global_device_array import GlobalDeviceArray
|
||||
from jax._src.array import ArrayImpl
|
||||
|
||||
for arg, xs in safe_zip(args, in_xla_shardings):
|
||||
if not isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
|
||||
if not isinstance(arg, ArrayImpl):
|
||||
continue
|
||||
if isinstance(arg, GlobalDeviceArray):
|
||||
arg_sharding = create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
|
||||
arg_type = 'GDA'
|
||||
committed = True
|
||||
else:
|
||||
arg_sharding = arg.sharding
|
||||
arg_type = 'Array'
|
||||
committed = arg._committed
|
||||
|
||||
# No need to cache this check since MeshExecutable has a C++ fast path
|
||||
# for AOT compiled call.
|
||||
if (not check_device_backend_on_shardings([xs]) and
|
||||
committed and
|
||||
not are_op_shardings_equal(arg_sharding._to_xla_op_sharding(arg.ndim),
|
||||
arg._committed and
|
||||
not are_op_shardings_equal(arg.sharding._to_xla_op_sharding(arg.ndim),
|
||||
xs._to_xla_op_sharding(arg.ndim))):
|
||||
raise ValueError(
|
||||
f"{arg_type} sharding does not match the input sharding. "
|
||||
f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "
|
||||
f"Array sharding does not match the input sharding. "
|
||||
f"Got Array sharding: {arg.sharding} and xla sharding: {xs} for "
|
||||
f"arg shape: {arg.shape}, arg value: {arg}")
|
||||
|
||||
|
||||
|
@ -37,7 +37,6 @@ from jax._src import traceback_util
|
||||
from jax._src.config import config
|
||||
from jax.errors import JAXTypeError
|
||||
from jax._src.array import ArrayImpl
|
||||
from jax._src.global_device_array import GlobalDeviceArray
|
||||
from jax._src.sharding_impls import NamedSharding
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
@ -521,13 +520,7 @@ def xmap(fun: Callable,
|
||||
lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
|
||||
closure=(out_axes_entries, out_axes_treedef))
|
||||
|
||||
if config.jax_array:
|
||||
in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat)
|
||||
else:
|
||||
in_positional_semantics = tuple(
|
||||
_PositionalSemantics.GLOBAL
|
||||
if isinstance(a, GlobalDeviceArray) else _positional_semantics.val
|
||||
for a in args_flat)
|
||||
in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat)
|
||||
out_positional_semantics = (
|
||||
_PositionalSemantics.GLOBAL
|
||||
if config.jax_array or config.jax_parallel_functions_output_gda
|
||||
@ -1638,7 +1631,7 @@ def _get_axis_sizes(args_flat: Iterable[Any],
|
||||
in_axes_flat: Iterable[AxisNamePos],
|
||||
global_axis_sizes: Dict[AxisName, int],
|
||||
axis_resource_count: Dict[AxisName, ResourceCount],
|
||||
in_positional_semantics: Sequence[bool]):
|
||||
in_positional_semantics: Sequence[_PositionalSemantics]):
|
||||
global_axis_sizes = dict(global_axis_sizes)
|
||||
for arg, in_axes, ips in zip(args_flat, in_axes_flat, in_positional_semantics):
|
||||
for name, dim in in_axes.items():
|
||||
@ -1824,26 +1817,22 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
|
||||
axis_resources, resource_env, global_axis_sizes,
|
||||
in_positional_semantics).to_mesh_axes(in_axes_flat)
|
||||
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
|
||||
if isinstance(arg, (GlobalDeviceArray, ArrayImpl)):
|
||||
arr_flavor = 'GDA' if isinstance(arg, GlobalDeviceArray) else 'Array'
|
||||
if arr_flavor == 'Array' and not isinstance(arg.sharding, NamedSharding):
|
||||
if isinstance(arg, ArrayImpl):
|
||||
if not isinstance(arg.sharding, NamedSharding):
|
||||
continue
|
||||
mesh = arg.mesh if arr_flavor == 'GDA' else arg.sharding.mesh
|
||||
mesh = arg.sharding.mesh
|
||||
if mesh != resource_env.physical_mesh:
|
||||
raise ValueError(f"xmap's mesh and {arr_flavor}'s mesh should be equal. "
|
||||
raise ValueError("xmap's mesh and Array's mesh should be equal. "
|
||||
f"Got xmap mesh: {resource_env.physical_mesh},\n"
|
||||
f"{arr_flavor} mesh: {mesh}")
|
||||
f"Array mesh: {mesh}")
|
||||
|
||||
if arr_flavor == 'GDA':
|
||||
s = pxla.create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
|
||||
else:
|
||||
s = arg.sharding
|
||||
s = arg.sharding
|
||||
xmap_sharding = pxla.create_mesh_pspec_sharding(
|
||||
mesh, pxla.array_mapping_to_axis_resources(xmap_array_mapping))
|
||||
# This check is cached because comparing OpSharding is expensive during
|
||||
# dispatch and if the shardings are the same, then there is no need to
|
||||
# compare twice.
|
||||
_check_sharding(s, xmap_sharding, arg.ndim, arr_flavor)
|
||||
_check_sharding(s, xmap_sharding, arg.ndim, 'Array')
|
||||
|
||||
|
||||
# TODO: We should relax this at least for "constructor primitives"
|
||||
|
jax/_src/pjit.py (110 changed lines)
@ -28,7 +28,6 @@ import jax
|
||||
from jax._src import core
|
||||
from jax import stages
|
||||
from jax.errors import JAXTypeError
|
||||
from jax._src.global_device_array import GlobalDeviceArray as GDA
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters.pxla import PartitionSpec
|
||||
@ -491,28 +490,17 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
out_shardings = tree_map(
|
||||
lambda x: x if _is_unspecified(x) else
|
||||
_create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), user_out_shardings)
|
||||
# This check fails extremely rarely and has a huge cost in the dispatch
|
||||
# path. So hide it behind the jax_enable_checks flag.
|
||||
if jax.config.jax_enable_checks:
|
||||
_maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
|
||||
|
||||
del user_in_shardings, user_out_shardings
|
||||
|
||||
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
|
||||
# TODO(yashkatariya): This is a hack. This should go away when avals have
|
||||
# is_global attribute.
|
||||
if jax.config.jax_array:
|
||||
in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
|
||||
else:
|
||||
in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
|
||||
out_positional_semantics = (
|
||||
pxla._PositionalSemantics.GLOBAL
|
||||
if jax.config.jax_parallel_functions_output_gda or jax.config.jax_array else
|
||||
pxla.positional_semantics.val)
|
||||
|
||||
in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
|
||||
out_positional_semantics = pxla._PositionalSemantics.GLOBAL
|
||||
|
||||
global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
|
||||
hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
|
||||
tuple(isinstance(a, GDA) for a in args_flat), resource_env)
|
||||
resource_env)
|
||||
|
||||
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
|
||||
flat_fun, hashable_pytree(out_shardings), global_in_avals,
|
||||
@ -842,7 +830,7 @@ class PytreeLeaf:
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
|
||||
in_tree, in_positional_semantics, is_gda,
|
||||
in_tree, in_positional_semantics,
|
||||
resource_env):
|
||||
orig_in_shardings = in_shardings_thunk()
|
||||
# Only do this if original in_shardings are unspecified. If they are
|
||||
@ -854,66 +842,15 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
|
||||
"pjit in_shardings", in_tree, orig_in_shardings,
|
||||
tupled_args=True)
|
||||
|
||||
# Fork here because the `Array` path is very simple and doesn't need all the
|
||||
# complexity below.
|
||||
if config.jax_array:
|
||||
pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments",
|
||||
allow_uneven_sharding=False)
|
||||
global_in_avals = local_in_avals
|
||||
# TODO(yashkatariya): Only check for is_auto or _is_unspecified when
|
||||
# FROM_GDA is removed.
|
||||
canonicalized_shardings = tuple(
|
||||
i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
|
||||
for i, aval in safe_zip(in_shardings_flat, global_in_avals))
|
||||
return tuple(global_in_avals), canonicalized_shardings
|
||||
|
||||
if not local_in_avals:
|
||||
assert not in_shardings_flat
|
||||
return (), ()
|
||||
|
||||
in_axis_resources_flat = tuple(
|
||||
i if _is_from_gda(i) or is_auto(i) else i._parsed_pspec
|
||||
for i in in_shardings_flat)
|
||||
|
||||
# This check should be above local_to_global call below otherwise if
|
||||
# `FROM_GDA` is passed to any input other than GDA, an ugly error message
|
||||
# will be raised because get_array_mapping (in local_to_global) of a
|
||||
# FROM_GDA cannot happen.
|
||||
tree_map(_check_resources_mismatch, in_axis_resources_flat, is_gda)
|
||||
# If all inputs have global semantics or fully replicated, then the avals are
|
||||
# global and the mesh should also be global. This split is because
|
||||
# non-contiguous mesh can only be used if all inputs have global semantics or
|
||||
# fully replicated.
|
||||
# Use canonicalized in_axis_resources here because we want to treat P(None)
|
||||
# and None (for example) as equivalent.
|
||||
if all(
|
||||
(not _is_from_gda(p) and not is_auto(p) and
|
||||
CanonicalizedParsedPartitionSpec(p).partitions == ()) or
|
||||
ips == pxla._PositionalSemantics.GLOBAL
|
||||
for p, ips in safe_zip(in_axis_resources_flat, in_positional_semantics)):
|
||||
# Shapes should be checked against non canonicalized in_axis_resources.
|
||||
# For example, partitions of () and ((),) are not equivalent, since the
|
||||
# first one is a valid spec for a scalar value, while the second is not!
|
||||
pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments",
|
||||
allow_uneven_sharding=False)
|
||||
else:
|
||||
pjit_check_aval_sharding(
|
||||
[i if _is_from_gda(i) or is_auto(i) else
|
||||
NamedSharding(i.mesh.local_mesh, i.spec)
|
||||
for i in in_shardings_flat],
|
||||
local_in_avals, "pjit arguments", allow_uneven_sharding=False)
|
||||
|
||||
# Local or global avals doesn't matter for converting to op sharding because
|
||||
# the `ndim` does not change.
|
||||
canonicalized_in_shardings_flat = tuple(
|
||||
i if _is_from_gda(i) or is_auto(i) else to_gspmd_sharding(i, aval.ndim)
|
||||
for i, aval in safe_zip(in_shardings_flat, local_in_avals))
|
||||
|
||||
global_in_avals = local_to_global(
|
||||
in_positional_semantics, local_in_avals, canonicalized_in_shardings_flat,
|
||||
resource_env.physical_mesh)
|
||||
|
||||
return tuple(global_in_avals), canonicalized_in_shardings_flat
|
||||
pjit_check_aval_sharding(in_shardings_flat, local_in_avals, "pjit arguments",
|
||||
allow_uneven_sharding=False)
|
||||
global_in_avals = local_in_avals
|
||||
# TODO(yashkatariya): Only check for is_auto or _is_unspecified when
|
||||
# FROM_GDA is removed.
|
||||
canonicalized_shardings = tuple(
|
||||
i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
|
||||
for i, aval in safe_zip(in_shardings_flat, global_in_avals))
|
||||
return tuple(global_in_avals), canonicalized_shardings
|
||||
|
||||
|
||||
@lu.cache
|
||||
@ -1156,11 +1093,6 @@ def _prepare_axis_resources(axis_resources,
|
||||
return tree_unflatten(treedef, new_entries), new_entries, treedef
|
||||
|
||||
|
||||
def _check_resources_mismatch(in_axis_resources_flat, is_gda):
|
||||
if not is_gda and _is_from_gda(in_axis_resources_flat):
|
||||
raise ValueError('For a non-GDA input, the corresponding resource in '
|
||||
'in_axis_resources cannot be `pjit.FROM_GDA`.')
|
||||
|
||||
def _check_unique_resources(axis_resources, arg_name):
|
||||
for arg_axis_resources in axis_resources:
|
||||
if not arg_axis_resources: continue
|
||||
@ -2204,11 +2136,6 @@ def _calc_is_global_sequence(in_positional_semantics, in_shardings):
|
||||
pxla.is_op_sharding_replicated(i._op_sharding))
|
||||
for ips, i in safe_zip(in_positional_semantics, in_shardings))
|
||||
|
||||
def _get_in_positional_semantics(arg) -> pxla._PositionalSemantics:
|
||||
if isinstance(arg, GDA):
|
||||
return pxla._PositionalSemantics.GLOBAL
|
||||
return pxla.positional_semantics.val
|
||||
|
||||
|
||||
def _fast_path_get_device_assignment(
|
||||
shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
|
||||
@ -2244,20 +2171,11 @@ def _maybe_replace_from_gda_with_pspec(
|
||||
out.append(in_sharding_flat)
|
||||
elif isinstance(arg, array.ArrayImpl):
|
||||
out.append(to_gspmd_sharding(arg.sharding, arg.ndim))
|
||||
elif isinstance(arg, GDA):
|
||||
gda_sharding = pxla.create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
|
||||
out.append(_gda_check_and_get_sharding(gda_sharding, in_sharding_flat, arg.ndim))
|
||||
else:
|
||||
out.append(in_sharding_flat)
|
||||
return tuple(out)
|
||||
|
||||
|
||||
def _maybe_check_pjit_gda_mesh(args, mesh):
|
||||
for x in args:
|
||||
if isinstance(x, GDA) and x.mesh != mesh:
|
||||
raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit "
|
||||
f"mesh: {mesh},\n GDA mesh: {x.mesh}")
|
||||
|
||||
# -------------------- XLA OpSharding to PartitionSpec --------------------
|
||||
# Note that OpSharding is more expressive than PartitionSpecs, so it's not
|
||||
# always possible to convert them, but the code below should at least
|
||||
|
@ -26,7 +26,6 @@ from typing import Callable, Sequence, Optional, Dict, Any
|
||||
import jax
|
||||
from jax._src import distributed
|
||||
from jax._src.config import config
|
||||
from jax._src import global_device_array as gda
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
from jax._src import sharding_impls
|
||||
@ -64,27 +63,6 @@ async def create_async_array_from_callback(
|
||||
global_shape, inp_sharding, dbs)
|
||||
|
||||
|
||||
async def create_async_gda_from_callback(
|
||||
global_shape: gda.Shape,
|
||||
global_mesh: Mesh,
|
||||
mesh_axes: gda.MeshAxes,
|
||||
data_callback: Callable[[gda.Index], asyncio.Future],
|
||||
):
|
||||
global_idx_rid = gda.get_shard_indices_replica_ids(
|
||||
global_shape, global_mesh, mesh_axes)
|
||||
local_devices = global_mesh.local_devices
|
||||
future_arrays = [data_callback(global_idx_rid[d][0])
|
||||
for d in local_devices]
|
||||
# Pause here and come back to `from_async_callback()` when future_arrays are
|
||||
# ready. device_put cannot happen with future_arrays.
|
||||
local_arrays = await asyncio.gather(*future_arrays)
|
||||
|
||||
dbs = [jax.device_put(array, device)
|
||||
for array, device in zip(local_arrays, local_devices)]
|
||||
return gda.GlobalDeviceArray(global_shape, global_mesh, mesh_axes, dbs,
|
||||
gda._GdaFastPathArgs(global_idx_rid, local_devices))
|
||||
|
||||
|
||||
def _get_metadata(arr):
|
||||
if arr.dtype == jnp.bfloat16:
|
||||
# Tensorstore uses 'bfloat16', not '<V2'.
|
||||
@ -280,15 +258,7 @@ async def async_deserialize(in_sharding, tensorstore_spec,
|
||||
await byte_limiter.release_bytes(requested_bytes)
|
||||
return out
|
||||
|
||||
if config.jax_array:
|
||||
return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
|
||||
else:
|
||||
if not isinstance(in_sharding, sharding_impls.NamedSharding):
|
||||
raise ValueError('Deserializing a GlobalDeviceArray is only possible with '
|
||||
'a `NamedSharding` which consists of a `mesh` and '
|
||||
f'`pspec`, but got {in_sharding}')
|
||||
return await create_async_gda_from_callback(
|
||||
tuple(shape), in_sharding.mesh, in_sharding.spec, cb)
|
||||
return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
|
||||
|
||||
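With the GDA branch removed, deserialization always goes through create_async_array_from_callback. As a rough, hedged sketch of that shape of code: the helper below is illustrative, not the library's implementation, and assumes a NamedSharding ``in_sharding`` plus an async ``fetch(index)`` coroutine that returns host arrays.

import asyncio
import jax

async def array_from_async_callback(global_shape, in_sharding, fetch):
    # One shard index per addressable device of the sharding on this process.
    index_map = in_sharding.addressable_devices_indices_map(global_shape)
    host_chunks = await asyncio.gather(*(fetch(index) for index in index_map.values()))
    bufs = [jax.device_put(chunk, device)
            for chunk, device in zip(host_chunks, index_map.keys())]
    return jax.make_array_from_single_device_arrays(global_shape, in_sharding, bufs)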
|
||||
def run_deserialization(shardings: Sequence[sharding.Sharding],
|
||||
|
@ -1,26 +0,0 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.global_device_array import (
|
||||
Device as Device,
|
||||
GlobalDeviceArray as GlobalDeviceArray,
|
||||
MeshAxes as MeshAxes,
|
||||
PartitionSpec as PartitionSpec,
|
||||
Shard as Shard,
|
||||
Shape as Shape,
|
||||
get_shard_indices as get_shard_indices,
|
||||
get_shard_shape as get_shard_shape,
|
||||
_get_sharding_spec as _get_sharding_spec,
|
||||
_hashed_index as _hashed_index,
|
||||
)
|
@ -54,7 +54,6 @@ from jax._src import random as random_internal
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.global_device_array import GlobalDeviceArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
@ -1093,16 +1092,6 @@ def _to_jax_dtype(tf_dtype):
|
||||
return dt
|
||||
|
||||
|
||||
def _maybe_decode_gda(gda_or_py_object: Any):
|
||||
"""Convert GlobalDeviceArray into numpy object."""
|
||||
if isinstance(gda_or_py_object, GlobalDeviceArray):
|
||||
if jax.process_count() != 1:
|
||||
raise RuntimeError("GlobalDeviceArray does not support multi-process"
|
||||
f" currently. Process num = {jax.process_count()}")
|
||||
return gda_or_py_object._value
|
||||
return gda_or_py_object
|
||||
|
||||
|
||||
def _tfval_to_tensor_jax_dtype(val: TfVal,
|
||||
jax_dtype: Optional[DType] = None,
|
||||
memoize_constants=False) -> Tuple[TfVal, DType]:
|
||||
@ -1153,8 +1142,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
|
||||
# The float0 type is not known to TF.
|
||||
if jax_dtype == dtypes.float0:
|
||||
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
|
||||
tf_val = tf.convert_to_tensor(
|
||||
_maybe_decode_gda(val), dtype=conversion_dtype)
|
||||
tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
|
||||
if do_memoize:
|
||||
_thread_local_state.constant_cache[const_key] = (val, tf_val)
|
||||
return tf_val, jax_dtype
|
||||
|
@ -30,7 +30,6 @@ from jax.interpreters import xla
|
||||
from jax._src import pjit as pjit_lib
|
||||
from jax.experimental.pjit import pjit, FROM_GDA
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src.global_device_array import GlobalDeviceArray
|
||||
from jax._src import distributed
|
||||
from jax._src import config as config_internal
|
||||
import numpy as np
|
||||
@ -60,9 +59,6 @@ def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any:
|
||||
is_source = jax.process_index() == 0
|
||||
|
||||
def pre_pmap(x):
|
||||
if isinstance(x, GlobalDeviceArray):
|
||||
raise ValueError('GDAs cannot be broadcasted from source host to other '
|
||||
'hosts.')
|
||||
if is_source:
|
||||
return np.concatenate([
|
||||
x[None, ...],
|
||||
@ -148,35 +144,8 @@ def process_allgather(in_tree: Any, tiled: bool = False) -> Any:
|
||||
"""
|
||||
|
||||
def _pjit(inp):
|
||||
if jax.config.jax_array:
|
||||
return _handle_array_process_allgather(inp, tiled)
|
||||
else:
|
||||
if isinstance(inp, GlobalDeviceArray):
|
||||
if inp.is_fully_replicated:
|
||||
return np.asarray(inp.addressable_data(0))
|
||||
global_mesh = inp.mesh
|
||||
in_axis_resources = FROM_GDA
|
||||
else:
|
||||
# DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
|
||||
# Shape of local_mesh will always be (1, local_device_count())
|
||||
devices = np.array(jax.devices()).reshape(jax.process_count(),
|
||||
jax.local_device_count())
|
||||
global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
|
||||
in_axis_resources = P('processes')
|
||||
if inp.ndim == 0 or not tiled:
|
||||
inp = np.expand_dims(inp, axis=0)
|
||||
|
||||
with global_mesh:
|
||||
out = pjit(
|
||||
_identity_fn, in_shardings=in_axis_resources, out_shardings=None
|
||||
)(inp)
|
||||
return np.asarray(out.addressable_data(0))
|
||||
|
||||
if jax.config.jax_array:
|
||||
return jax.tree_map(_pjit, in_tree) # array route
|
||||
else:
|
||||
with config_internal.parallel_functions_output_gda(True):
|
||||
return jax.tree_map(_pjit, in_tree) # gda route
|
||||
return _handle_array_process_allgather(inp, tiled)
|
||||
return jax.tree_map(_pjit, in_tree)
|
||||
|
||||
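Since process_allgather now always takes the jax.Array route, a hedged usage sketch (assumes a multi-process run; ``local_batch`` is an illustrative name, not part of this change):

import numpy as np
import jax
from jax.experimental import multihost_utils

# Each process contributes its own host-local value.
local_batch = np.ones((8, 2)) * jax.process_index()
stacked = multihost_utils.process_allgather(local_batch)            # adds a leading process axis
tiled = multihost_utils.process_allgather(local_batch, tiled=True)  # concatenates along axis 0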
|
||||
def assert_equal(in_tree, fail_message: str = ''):
|
||||
|