Speed up GDA initialization by making local_shards lazy and hiding checks behind config.jax_enable_checks flag.

PiperOrigin-RevId: 438115859
This commit is contained in:
Yash Katariya 2022-03-29 13:50:45 -07:00 committed by jax authors
parent 085d3901fd
commit e08bc27bf0
2 changed files with 57 additions and 25 deletions

View File

@ -87,6 +87,26 @@ def indices_replica_id_calc_cached(mesh_shape, mesh_axes, state):
gda.get_shard_indices_replica_ids(global_input_shape, global_mesh, mesh_axes)
def gda_local_shards(mesh_shape, mesh_axes, state):
# `device_put` time is not measured in this benchmark. All the devices here
# are local.
global_mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
global_input_shape = (2048, 2048)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
global_indices = gda.get_shard_indices(global_input_shape, global_mesh,
mesh_axes)
dbs = [
jax.device_put(global_input_data[global_indices[device]], device)
for device in global_mesh.local_devices
]
gda_inp = gda.GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes,
dbs)
while state:
gda_inp._create_local_shards()
benchmarks = []
for mesh_shape, axes in mesh_shapes_axes:
benchmarks.extend([
@ -102,6 +122,9 @@ for mesh_shape, axes in mesh_shapes_axes:
google_benchmark.register(
partial(indices_replica_id_calc_cached, mesh_shape, axes),
name=f"indices_replica_id_calc_cached_{mesh_shape}_{axes}"),
google_benchmark.register(
partial(gda_local_shards, mesh_shape, axes),
name=f"gda_local_shards_{mesh_shape}_{axes}"),
])

View File

@ -21,6 +21,7 @@ from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Di
from jax import core
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.config import config
from jax.interpreters import pxla, xla
from jax._src.util import prod, safe_zip, cache
from jax._src.api import device_put
@ -130,6 +131,12 @@ def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
return tuple(chunk_size)
def _set_aval(val):
if val.aval is None:
val.aval = core.ShapedArray(val.shape, val.dtype)
return val
@dataclasses.dataclass(frozen=True)
class Shard:
"""A single data shard of a GlobalDeviceArray.
@ -274,7 +281,8 @@ class GlobalDeviceArray:
def __init__(self, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes, device_buffers: Sequence[DeviceArray],
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None):
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
_enable_checks: bool = True):
self._global_shape = global_shape
self._global_mesh = global_mesh
self._mesh_axes = mesh_axes
@ -288,24 +296,25 @@ class GlobalDeviceArray:
else:
self._local_devices = self._gda_fast_path_args.local_devices
for db, ld in safe_zip(device_buffers, self._local_devices):
if db.device() != ld:
raise ValueError(
"The `global_mesh.local_devices` and `device_buffers` device order "
"doesn't match. Please use `global_mesh.local_devices` to put "
"arrays on devices instead of `jax.local_devices()`")
if _enable_checks or config.jax_enable_checks:
for db, ld in safe_zip(device_buffers, self._local_devices):
if db.device() != ld:
raise ValueError(
"The `global_mesh.local_devices` and `device_buffers` device "
"order doesn't match. Please use `global_mesh.local_devices` to "
"put arrays on devices instead of `jax.local_devices()`")
self._local_shards = self._create_local_shards()
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
assert all(db.shape == ss for db in device_buffers), (
f"Expected shard shape {ss} doesn't match the device buffer "
f"shape, got: {[db.shape for db in device_buffers]}")
if _enable_checks or config.jax_enable_checks:
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
assert all(db.shape == ss for db in device_buffers), (
f"Expected shard shape {ss} doesn't match the device buffer "
f"shape, got: {[db.shape for db in device_buffers]}")
dtype = device_buffers[0].dtype
assert all(db.dtype == dtype for db in device_buffers), (
"Input arrays to GlobalDeviceArray must have matching dtypes, "
f"got: {[db.dtype for db in device_buffers]}")
if _enable_checks or config.jax_enable_checks:
assert all(db.dtype == dtype for db in device_buffers), (
"Input arrays to GlobalDeviceArray must have matching dtypes, "
f"got: {[db.dtype for db in device_buffers]}")
self.dtype = dtype
def __eq__(self, other: object):
@ -358,16 +367,15 @@ class GlobalDeviceArray:
out = []
for db in self._device_buffers:
if db.aval is None:
db.aval = core.ShapedArray(db.shape, db.dtype)
db = _set_aval(db)
device = db.device()
index, rid = global_indices_rid[device]
out.append(Shard(device, index, rid, db))
return out
@property
@pxla.maybe_cached_property
def local_shards(self) -> Sequence[Shard]:
return self._local_shards
return self._create_local_shards()
@property
def global_shards(self) -> Sequence[Shard]:
@ -393,11 +401,11 @@ class GlobalDeviceArray:
return global_shards
def local_data(self, index) -> DeviceArray:
return self.local_shards[index].data
return _set_aval(self._device_buffers[index])
def block_until_ready(self):
for s in self.local_shards:
s.data.block_until_ready()
for db in self._device_buffers:
db.block_until_ready()
return self
@classmethod
@ -563,7 +571,7 @@ xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
def _gda_shard_arg(x, devices, indices):
return [s.data for s in x.local_shards]
return [d for d in x._device_buffers]
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg
@ -573,6 +581,7 @@ def _gda_array_result_handler(global_aval, out_axis_resources, global_mesh):
local_devices = global_mesh.local_devices
fast_path_args = _GdaFastPathArgs(global_idx_rid, local_devices)
return lambda bufs: GlobalDeviceArray(
global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args)
global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args,
_enable_checks=False)
pxla.global_result_handlers[core.ShapedArray] = _gda_array_result_handler
pxla.global_result_handlers[core.ConcreteArray] = _gda_array_result_handler