Create same Sharding objects wherever possible to get maximum cache hits
PiperOrigin-RevId: 524116574
Parent: 0fd5b2ca61
Commit: c235f214d0
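The change is motivated by how Python's functools.lru_cache keys its entries: arguments are looked up by hash and equality, so building one Sharding object and reusing it everywhere keeps cache keys identical and turns repeated lookups into hits, besides avoiding redundant construction. A minimal, hypothetical sketch of that effect (ReplicatedSpec and indices_for are stand-ins, not JAX APIs):

# Hypothetical illustration, not part of the diff: lru_cache keys entries on
# the hash/equality of its arguments, so reusing one sharding-like object
# across calls turns repeated lookups into cache hits.
import functools

class ReplicatedSpec:
    """Stand-in for a sharding object; hashable and comparable by value."""
    def __init__(self, devices):
        self.devices = tuple(devices)
    def __hash__(self):
        return hash(self.devices)
    def __eq__(self, other):
        return isinstance(other, ReplicatedSpec) and self.devices == other.devices

@functools.lru_cache(maxsize=4096)
def indices_for(spec, shape):
    # Pretend this is an expensive index computation keyed on the spec.
    return {d: shape for d in spec.devices}

spec = ReplicatedSpec(range(4))      # build once ...
for _ in range(3):
    indices_for(spec, (8, 2))        # ... and reuse: 1 miss, then 2 hits
print(indices_for.cache_info())      # CacheInfo(hits=2, misses=1, ...)

The hunks below apply this idea at several call sites: replicated and single-device shardings are constructed once and reused, GSPMDSharding conversion is memoized, and the addressable-devices index map gets a module-level cache.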
@@ -2034,8 +2034,8 @@ def lower_sharding_computation(
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
       any(not is_unspecified(o) for o in out_shardings))

-  in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
-                       if is_unspecified(i) else i for i in in_shardings)
+  gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
+  in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)

   da_object = _create_da_object(tuple(device_assignment))

@@ -2401,10 +2401,8 @@ def get_gspmd_shardings_from_executable(
   # just return SingleDeviceShardings since we know the computation is running
   # only on 1 device.
   if len(device_assignment) == 1:
-    return ([sharding_impls.SingleDeviceSharding(device_assignment[0])
-             for _ in range(num_in_avals)],
-            [sharding_impls.SingleDeviceSharding(device_assignment[0])
-             for _ in range(num_out_avals)])
+    ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
+    return [ss] * num_in_avals, [ss] * num_out_avals

   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)

@@ -2445,9 +2443,10 @@ orig_out_sharding_handlers: OrigHandlerType = {}
 def _gspmd_to_named_sharding(
     op_sharding: xc.OpSharding,
     self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
-  return sharding_impls.NamedSharding._from_parsed_pspec(
-      self.mesh,
-      sharding_impls.parse_flatten_op_sharding(op_sharding, self.mesh)[0])
+  parsed_pspec = sharding_impls.parse_flatten_op_sharding(
+      op_sharding, self.mesh)[0]
+  return create_mesh_pspec_sharding(
+      self.mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
 orig_out_sharding_handlers[sharding_impls.NamedSharding] = _gspmd_to_named_sharding


@@ -2844,12 +2843,9 @@ def check_arg_avals_for_call(ref_avals, arg_avals):
 def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
   # Create replicated shardings for jit(pmap) path with local devices
   # because multihost jit(pmap) is not allowed.
-  in_shardings = [
-      sharding_impls.GSPMDSharding.get_replicated(local_devices)
-  ] * num_in_shardings
-  out_shardings = [
-      sharding_impls.GSPMDSharding.get_replicated(local_devices)
-  ] * num_out_shardings
+  gs = sharding_impls.GSPMDSharding.get_replicated(local_devices)
+  in_shardings = [gs] * num_in_shardings
+  out_shardings = [gs] * num_out_shardings
   # jit(pmap) will generate Arrays with multi-device sharding.
   # It is unsupported for these shardings to be uncommited, so force
   # the outputs to be committed.
@@ -1858,13 +1858,16 @@ pxla.custom_resource_typing_rules[sharding_constraint_p] = \

 # -------------------- helpers --------------------

+@lru_cache(maxsize=2048)
+def _cached_gspmd_sharding(s, ndim):
+  gs = GSPMDSharding(s._device_assignment, s._to_xla_op_sharding(ndim))
+  gs._original_sharding = s
+  return gs
+
 def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
-  gspmd_sharding = GSPMDSharding(
-      s._device_assignment, s._to_xla_op_sharding(ndim))
-  gspmd_sharding._original_sharding = s
-  return gspmd_sharding
+  return _cached_gspmd_sharding(s, ndim)


 def get_unconstrained_dims(sharding: NamedSharding):
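The hunk above turns to_gspmd_sharding into a memoized constructor: equal (hashable) inputs now return the identical GSPMDSharding object, so downstream caches keyed on that object see the same key instead of fresh duplicates. A simplified, hypothetical sketch of the pattern (FrozenSharding is a stand-in, not the real GSPMDSharding):

# Hypothetical illustration of the memoized-constructor pattern.
from functools import lru_cache

class FrozenSharding:
    """Stand-in for an immutable sharding description."""
    def __init__(self, devices, op_sharding_repr):
        self.devices = devices
        self.op_sharding_repr = op_sharding_repr

@lru_cache(maxsize=2048)
def _cached_frozen_sharding(devices, op_sharding_repr):
    # lru_cache requires hashable arguments; equal inputs return the *same*
    # object, so one construction serves every later call.
    return FrozenSharding(devices, op_sharding_repr)

a = _cached_frozen_sharding(("dev:0", "dev:1"), "replicated")
b = _cached_frozen_sharding(("dev:0", "dev:1"), "replicated")
assert a is b  # single cache entry, identical object reused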
@@ -26,6 +26,13 @@ Index = Tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]


+@functools.lru_cache(maxsize=4096)
+def _addressable_devices_indices_map(
+    sharding: Sharding, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
+  return {d: ind for d, ind in sharding.devices_indices_map(global_shape).items()
+          if d.process_index == d.client.process_index()}
+
+
 @util.use_cpp_class(xc.Sharding)
 class Sharding:
   """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
@@ -86,7 +93,6 @@ class Sharding:
     # The pytype disable is because pytype can't recognize a cached property.
     return len(self.device_set) == len(self.addressable_devices) # type: ignore

-  @functools.lru_cache(maxsize=4096)
   def addressable_devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
     """A mapping from addressable device to the slice of global data it contains.
@@ -94,5 +100,4 @@ class Sharding:
     ``addressable_devices_indices_map`` contains that part of
     ``device_indices_map`` that applies to the addressable devices.
     """
-    return {d: ind for d, ind in self.devices_indices_map(global_shape).items()
-            if d.process_index == d.client.process_index()}
+    return _addressable_devices_indices_map(self, global_shape)
@@ -42,7 +42,7 @@ from jax.experimental.maps import xmap
 from jax.experimental import multihost_utils
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src import array
-from jax._src.sharding import Sharding
+from jax._src.sharding import Sharding, _addressable_devices_indices_map
 from jax._src import op_shardings
 from jax._src import sharding_impls
 from jax._src.sharding_impls import (
@@ -3177,6 +3177,30 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertIsInstance(out4.sharding, SingleDeviceSharding)
     self.assertEqual(out4.device(), jax.devices()[1])

+  def test_get_indices_cache(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    ns = NamedSharding(mesh, P('x'))
+    ns2 = NamedSharding(mesh, P('x', 'y'))
+
+    np_inp = np.arange(16).reshape(8, 2)
+    arr1 = jax.device_put(np_inp, ns)
+    arr2 = jax.device_put(np_inp, ns2)
+    arr3 = jax.device_put(np_inp, ns)
+
+    cache_info1 = _addressable_devices_indices_map.cache_info()
+    out = pjit(lambda x, y, z: x + y + z)(arr1, arr2, arr3)
+    cache_info2 = _addressable_devices_indices_map.cache_info()
+    self.assertArraysEqual(out, np_inp * 3)
+
+    # arr3 and arr1 should have the same GSPMDSharding objects internally.
+    # So there will be 2 hits in _addressable_devices_indices_map,
+    # One in `pxla._get_input_indices` and second in `_array_shard_arg`.
+    self.assertEqual(cache_info2.hits, cache_info1.hits + 2)
+    # There will double the amount of misses as hits because arr1 and arr2's
+    # sharding are not the same. So 2 misses in _addressable_devices_indices_map
+    # and 2 in _array_shard_arg.
+    self.assertEqual(cache_info2.misses, cache_info1.misses + 4)
+

 class TempSharding(Sharding):

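The test above relies on functools.lru_cache's built-in accounting: snapshot cache_info() before and after the operation under test and assert on the deltas. A generic, hypothetical sketch of that testing pattern (lookup is a stand-in function, not a JAX API):

# Hypothetical illustration of delta-based cache assertions.
from functools import lru_cache

@lru_cache(maxsize=None)
def lookup(key):
    return key * 2  # pretend this is expensive

before = lookup.cache_info()
lookup("a"); lookup("a"); lookup("b")
after = lookup.cache_info()
assert after.misses - before.misses == 2  # "a" and "b" each computed once
assert after.hits - before.hits == 1      # second call to "a" was a hit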