mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Speed up GDA initialization by making `local_shards` lazy and hiding checks behind the
`config.jax_enable_checks` flag.
PiperOrigin-RevId: 438115859
This commit is contained in:
parent
085d3901fd
commit
e08bc27bf0
@ -87,6 +87,26 @@ def indices_replica_id_calc_cached(mesh_shape, mesh_axes, state):
|
||||
gda.get_shard_indices_replica_ids(global_input_shape, global_mesh, mesh_axes)
|
||||
|
||||
|
||||
def gda_local_shards(mesh_shape, mesh_axes, state):
  """Benchmark `GlobalDeviceArray._create_local_shards()`.

  Setup builds a global mesh with axis names ("x", "y"), fills a 2048x2048
  array with `np.arange`, places one slice per local device with
  `jax.device_put`, and wraps the resulting buffers in a `GlobalDeviceArray`.
  Only the `while state:` loop at the end calls `_create_local_shards()`.

  Args:
    mesh_shape: shape passed to `jtu.create_global_mesh`.
    mesh_axes: mesh axes used both to compute the per-device indices and to
      construct the `GlobalDeviceArray`.
    state: benchmark state object; the loop runs for as long as it is truthy.
  """
  # `device_put` time is not measured in this benchmark. All the devices here
  # are local.
  global_mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
  global_input_shape = (2048, 2048)
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)
  global_indices = gda.get_shard_indices(global_input_shape, global_mesh,
                                         mesh_axes)
  # One device buffer per local device, in `global_mesh.local_devices` order.
  dbs = [
      jax.device_put(global_input_data[global_indices[device]], device)
      for device in global_mesh.local_devices
  ]
  gda_inp = gda.GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes,
                                  dbs)

  while state:
    # Call the private builder directly so every iteration rebuilds the
    # shard list from the device buffers.
    gda_inp._create_local_shards()
|
||||
|
||||
|
||||
benchmarks = []
|
||||
for mesh_shape, axes in mesh_shapes_axes:
|
||||
benchmarks.extend([
|
||||
@ -102,6 +122,9 @@ for mesh_shape, axes in mesh_shapes_axes:
|
||||
google_benchmark.register(
|
||||
partial(indices_replica_id_calc_cached, mesh_shape, axes),
|
||||
name=f"indices_replica_id_calc_cached_{mesh_shape}_{axes}"),
|
||||
google_benchmark.register(
|
||||
partial(gda_local_shards, mesh_shape, axes),
|
||||
name=f"gda_local_shards_{mesh_shape}_{axes}"),
|
||||
])
|
||||
|
||||
|
||||
|
@ -21,6 +21,7 @@ from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Di
|
||||
from jax import core
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.config import config
|
||||
from jax.interpreters import pxla, xla
|
||||
from jax._src.util import prod, safe_zip, cache
|
||||
from jax._src.api import device_put
|
||||
@ -130,6 +131,12 @@ def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
|
||||
return tuple(chunk_size)
|
||||
|
||||
|
||||
def _set_aval(val):
  """Make sure `val` carries an abstract value, then return `val`.

  If `val.aval` is `None`, it is set in place to a `core.ShapedArray` built
  from `val.shape` and `val.dtype`. An existing `aval` is left untouched.

  Args:
    val: object exposing `aval`, `shape` and `dtype` attributes.

  Returns:
    The same `val` object that was passed in.
  """
  if val.aval is None:
    val.aval = core.ShapedArray(val.shape, val.dtype)
  return val
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Shard:
|
||||
"""A single data shard of a GlobalDeviceArray.
|
||||
@ -274,7 +281,8 @@ class GlobalDeviceArray:
|
||||
|
||||
def __init__(self, global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes, device_buffers: Sequence[DeviceArray],
|
||||
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None):
|
||||
_gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
|
||||
_enable_checks: bool = True):
|
||||
self._global_shape = global_shape
|
||||
self._global_mesh = global_mesh
|
||||
self._mesh_axes = mesh_axes
|
||||
@ -288,24 +296,25 @@ class GlobalDeviceArray:
|
||||
else:
|
||||
self._local_devices = self._gda_fast_path_args.local_devices
|
||||
|
||||
for db, ld in safe_zip(device_buffers, self._local_devices):
|
||||
if db.device() != ld:
|
||||
raise ValueError(
|
||||
"The `global_mesh.local_devices` and `device_buffers` device order "
|
||||
"doesn't match. Please use `global_mesh.local_devices` to put "
|
||||
"arrays on devices instead of `jax.local_devices()`")
|
||||
if _enable_checks or config.jax_enable_checks:
|
||||
for db, ld in safe_zip(device_buffers, self._local_devices):
|
||||
if db.device() != ld:
|
||||
raise ValueError(
|
||||
"The `global_mesh.local_devices` and `device_buffers` device "
|
||||
"order doesn't match. Please use `global_mesh.local_devices` to "
|
||||
"put arrays on devices instead of `jax.local_devices()`")
|
||||
|
||||
self._local_shards = self._create_local_shards()
|
||||
|
||||
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
assert all(db.shape == ss for db in device_buffers), (
|
||||
f"Expected shard shape {ss} doesn't match the device buffer "
|
||||
f"shape, got: {[db.shape for db in device_buffers]}")
|
||||
if _enable_checks or config.jax_enable_checks:
|
||||
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
assert all(db.shape == ss for db in device_buffers), (
|
||||
f"Expected shard shape {ss} doesn't match the device buffer "
|
||||
f"shape, got: {[db.shape for db in device_buffers]}")
|
||||
|
||||
dtype = device_buffers[0].dtype
|
||||
assert all(db.dtype == dtype for db in device_buffers), (
|
||||
"Input arrays to GlobalDeviceArray must have matching dtypes, "
|
||||
f"got: {[db.dtype for db in device_buffers]}")
|
||||
if _enable_checks or config.jax_enable_checks:
|
||||
assert all(db.dtype == dtype for db in device_buffers), (
|
||||
"Input arrays to GlobalDeviceArray must have matching dtypes, "
|
||||
f"got: {[db.dtype for db in device_buffers]}")
|
||||
self.dtype = dtype
|
||||
|
||||
def __eq__(self, other: object):
|
||||
@ -358,16 +367,15 @@ class GlobalDeviceArray:
|
||||
|
||||
out = []
|
||||
for db in self._device_buffers:
|
||||
if db.aval is None:
|
||||
db.aval = core.ShapedArray(db.shape, db.dtype)
|
||||
db = _set_aval(db)
|
||||
device = db.device()
|
||||
index, rid = global_indices_rid[device]
|
||||
out.append(Shard(device, index, rid, db))
|
||||
return out
|
||||
|
||||
@property
|
||||
@pxla.maybe_cached_property
|
||||
def local_shards(self) -> Sequence[Shard]:
|
||||
return self._local_shards
|
||||
return self._create_local_shards()
|
||||
|
||||
@property
|
||||
def global_shards(self) -> Sequence[Shard]:
|
||||
@ -393,11 +401,11 @@ class GlobalDeviceArray:
|
||||
return global_shards
|
||||
|
||||
def local_data(self, index) -> DeviceArray:
|
||||
return self.local_shards[index].data
|
||||
return _set_aval(self._device_buffers[index])
|
||||
|
||||
def block_until_ready(self):
|
||||
for s in self.local_shards:
|
||||
s.data.block_until_ready()
|
||||
for db in self._device_buffers:
|
||||
db.block_until_ready()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@ -563,7 +571,7 @@ xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
|
||||
xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
|
||||
|
||||
def _gda_shard_arg(x, devices, indices):
|
||||
return [s.data for s in x.local_shards]
|
||||
return [d for d in x._device_buffers]
|
||||
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg
|
||||
|
||||
|
||||
@ -573,6 +581,7 @@ def _gda_array_result_handler(global_aval, out_axis_resources, global_mesh):
|
||||
local_devices = global_mesh.local_devices
|
||||
fast_path_args = _GdaFastPathArgs(global_idx_rid, local_devices)
|
||||
return lambda bufs: GlobalDeviceArray(
|
||||
global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args)
|
||||
global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args,
|
||||
_enable_checks=False)
|
||||
pxla.global_result_handlers[core.ShapedArray] = _gda_array_result_handler
|
||||
pxla.global_result_handlers[core.ConcreteArray] = _gda_array_result_handler
|
||||
|
Loading…
x
Reference in New Issue
Block a user