Create same Sharding objects wherever possible to get maximum cache hits

PiperOrigin-RevId: 524116574
Yash Katariya 2023-04-13 15:18:56 -07:00 committed by jax authors
parent 0fd5b2ca61
commit c235f214d0
4 changed files with 51 additions and 23 deletions
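In short, the caches this commit targets (including the lru_cache-backed helpers added below) are keyed on the sharding objects passed to them, so constructing one Sharding and reusing it everywhere turns repeated lookups into cache hits. A minimal sketch of the pattern, using a hypothetical FakeSharding stand-in rather than JAX internals (real JAX shardings define value-based __eq__/__hash__, but hoisting the construction still avoids rebuilding the object and guarantees a shared cache entry):

import functools

class FakeSharding:
  """Hypothetical stand-in for a sharding; hashes by object identity."""
  def __init__(self, devices):
    self.devices = tuple(devices)

@functools.lru_cache(maxsize=4096)
def indices_for(sharding, global_shape):
  # Pretend this is an expensive index computation keyed on the sharding.
  return {d: (slice(None),) * len(global_shape) for d in sharding.devices}

devices = ('dev0', 'dev1')

# Before: a fresh object per use -> every lookup is a cache miss.
for _ in range(3):
  indices_for(FakeSharding(devices), (8, 2))

# After: hoist the construction and reuse one object -> one miss, then hits.
gs = FakeSharding(devices)
for _ in range(3):
  indices_for(gs, (8, 2))

print(indices_for.cache_info())  # CacheInfo(hits=2, misses=4, ...)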


@@ -2034,8 +2034,8 @@ def lower_sharding_computation(
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
       any(not is_unspecified(o) for o in out_shardings))
-  in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
-                       if is_unspecified(i) else i for i in in_shardings)
+  gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
+  in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
   da_object = _create_da_object(tuple(device_assignment))
@@ -2401,10 +2401,8 @@ def get_gspmd_shardings_from_executable(
   # just return SingleDeviceShardings since we know the computation is running
   # only on 1 device.
   if len(device_assignment) == 1:
-    return ([sharding_impls.SingleDeviceSharding(device_assignment[0])
-             for _ in range(num_in_avals)],
-            [sharding_impls.SingleDeviceSharding(device_assignment[0])
-             for _ in range(num_out_avals)])
+    ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
+    return [ss] * num_in_avals, [ss] * num_out_avals
   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
@@ -2445,9 +2443,10 @@ orig_out_sharding_handlers: OrigHandlerType = {}
 def _gspmd_to_named_sharding(
     op_sharding: xc.OpSharding,
     self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
-  return sharding_impls.NamedSharding._from_parsed_pspec(
-      self.mesh,
-      sharding_impls.parse_flatten_op_sharding(op_sharding, self.mesh)[0])
+  parsed_pspec = sharding_impls.parse_flatten_op_sharding(
+      op_sharding, self.mesh)[0]
+  return create_mesh_pspec_sharding(
+      self.mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
 
 orig_out_sharding_handlers[sharding_impls.NamedSharding] = _gspmd_to_named_sharding
@@ -2844,12 +2843,9 @@ def check_arg_avals_for_call(ref_avals, arg_avals):
 def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
   # Create replicated shardings for jit(pmap) path with local devices
   # because multihost jit(pmap) is not allowed.
-  in_shardings = [
-      sharding_impls.GSPMDSharding.get_replicated(local_devices)
-  ] * num_in_shardings
-  out_shardings = [
-      sharding_impls.GSPMDSharding.get_replicated(local_devices)
-  ] * num_out_shardings
+  gs = sharding_impls.GSPMDSharding.get_replicated(local_devices)
+  in_shardings = [gs] * num_in_shardings
+  out_shardings = [gs] * num_out_shardings
   # jit(pmap) will generate Arrays with multi-device sharding.
   # It is unsupported for these shardings to be uncommited, so force
   # the outputs to be committed.


@@ -1858,13 +1858,16 @@ pxla.custom_resource_typing_rules[sharding_constraint_p] = \
 # -------------------- helpers --------------------
 
+@lru_cache(maxsize=2048)
+def _cached_gspmd_sharding(s, ndim):
+  gs = GSPMDSharding(s._device_assignment, s._to_xla_op_sharding(ndim))
+  gs._original_sharding = s
+  return gs
+
 def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
-  gspmd_sharding = GSPMDSharding(
-      s._device_assignment, s._to_xla_op_sharding(ndim))
-  gspmd_sharding._original_sharding = s
-  return gspmd_sharding
+  return _cached_gspmd_sharding(s, ndim)
 
 def get_unconstrained_dims(sharding: NamedSharding):
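A note on the shape of the to_gspmd_sharding refactor above: once the slow path goes through an lru_cache'd helper, converting the same source sharding repeatedly returns the identical GSPMDSharding object, which is exactly what downstream object-keyed caches want. A rough sketch under illustrative names (Source, Target, and to_target are made up, not JAX APIs):

from functools import lru_cache

class Source:
  """Illustrative input type; stands in for an XLACompatibleSharding."""

class Target:
  """Illustrative output type; stands in for GSPMDSharding."""
  def __init__(self, original, ndim):
    self.original = original   # analogous to _original_sharding
    self.ndim = ndim

@lru_cache(maxsize=2048)
def _cached_convert(s, ndim):
  # The expensive construction runs once per (s, ndim) key.
  return Target(s, ndim)

def to_target(s, ndim):
  if isinstance(s, Target):
    return s                   # fast path: already converted
  return _cached_convert(s, ndim)

s = Source()
assert to_target(s, 2) is to_target(s, 2)   # identical object on repeat calls

The usual lru_cache tradeoff applies: cached entries, and the shardings they reference, stay alive for the lifetime of the cache.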


@@ -26,6 +26,13 @@ Index = Tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]
 
+@functools.lru_cache(maxsize=4096)
+def _addressable_devices_indices_map(
+    sharding: Sharding, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
+  return {d: ind for d, ind in sharding.devices_indices_map(global_shape).items()
+          if d.process_index == d.client.process_index()}
+
 @util.use_cpp_class(xc.Sharding)
 class Sharding:
   """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
@@ -86,7 +93,6 @@ class Sharding:
     # The pytype disable is because pytype can't recognize a cached property.
     return len(self.device_set) == len(self.addressable_devices)  # type: ignore
 
-  @functools.lru_cache(maxsize=4096)
   def addressable_devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
     """A mapping from addressable device to the slice of global data it contains.
@@ -94,5 +100,4 @@ class Sharding:
     ``addressable_devices_indices_map`` contains that part of
     ``device_indices_map`` that applies to the addressable devices.
     """
-    return {d: ind for d, ind in self.devices_indices_map(global_shape).items()
-            if d.process_index == d.client.process_index()}
+    return _addressable_devices_indices_map(self, global_shape)
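A practical upshot of moving the cache from the instance method to a module-level function is that callers, including the new test further down, can import the function and read its hit/miss counters via cache_info(). A tiny self-contained illustration (describe_shape is a made-up placeholder, not JAX code):

import functools

@functools.lru_cache(maxsize=4096)
def describe_shape(shape):
  # Stand-in for an expensive, cacheable computation keyed on its argument.
  return 'x'.join(str(d) for d in shape)

describe_shape((8, 2))   # miss
describe_shape((8, 2))   # hit
info = describe_shape.cache_info()
assert (info.hits, info.misses) == (1, 1)

(A commonly cited reason to keep functools.lru_cache off instance methods is that the cache would otherwise hold a reference to self; whether that motivated this particular move is not stated in the diff.)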


@@ -42,7 +42,7 @@ from jax.experimental.maps import xmap
 from jax.experimental import multihost_utils
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src import array
-from jax._src.sharding import Sharding
+from jax._src.sharding import Sharding, _addressable_devices_indices_map
 from jax._src import op_shardings
 from jax._src import sharding_impls
 from jax._src.sharding_impls import (
@@ -3177,6 +3177,30 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertIsInstance(out4.sharding, SingleDeviceSharding)
     self.assertEqual(out4.device(), jax.devices()[1])
 
+  def test_get_indices_cache(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    ns = NamedSharding(mesh, P('x'))
+    ns2 = NamedSharding(mesh, P('x', 'y'))
+
+    np_inp = np.arange(16).reshape(8, 2)
+    arr1 = jax.device_put(np_inp, ns)
+    arr2 = jax.device_put(np_inp, ns2)
+    arr3 = jax.device_put(np_inp, ns)
+
+    cache_info1 = _addressable_devices_indices_map.cache_info()
+    out = pjit(lambda x, y, z: x + y + z)(arr1, arr2, arr3)
+    cache_info2 = _addressable_devices_indices_map.cache_info()
+    self.assertArraysEqual(out, np_inp * 3)
+
+    # arr1 and arr3 should share the same GSPMDSharding object internally, so
+    # there will be 2 hits in _addressable_devices_indices_map: one in
+    # `pxla._get_input_indices` and a second in `_array_shard_arg`.
+    self.assertEqual(cache_info2.hits, cache_info1.hits + 2)
+
+    # There will be double the number of misses as hits because arr1's and
+    # arr2's shardings are not the same: 2 misses in
+    # _addressable_devices_indices_map and 2 in `_array_shard_arg`.
+    self.assertEqual(cache_info2.misses, cache_info1.misses + 4)
+
 class TempSharding(Sharding):