diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 92564b1d9..795b50f20 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2034,8 +2034,8 @@ def lower_sharding_computation(
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
       any(not is_unspecified(o) for o in out_shardings))
 
-  in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
-                       if is_unspecified(i) else i for i in in_shardings)
+  gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
+  in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
 
   da_object = _create_da_object(tuple(device_assignment))
 
@@ -2401,10 +2401,8 @@ def get_gspmd_shardings_from_executable(
   # just return SingleDeviceShardings since we know the computation is running
   # only on 1 device.
   if len(device_assignment) == 1:
-    return ([sharding_impls.SingleDeviceSharding(device_assignment[0])
-             for _ in range(num_in_avals)],
-            [sharding_impls.SingleDeviceSharding(device_assignment[0])
-             for _ in range(num_out_avals)])
+    ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
+    return [ss] * num_in_avals, [ss] * num_out_avals
 
   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
 
@@ -2445,9 +2443,10 @@ orig_out_sharding_handlers: OrigHandlerType = {}
 
 def _gspmd_to_named_sharding(
     op_sharding: xc.OpSharding,
     self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
-  return sharding_impls.NamedSharding._from_parsed_pspec(
-      self.mesh,
-      sharding_impls.parse_flatten_op_sharding(op_sharding, self.mesh)[0])
+  parsed_pspec = sharding_impls.parse_flatten_op_sharding(
+      op_sharding, self.mesh)[0]
+  return create_mesh_pspec_sharding(
+      self.mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
 orig_out_sharding_handlers[sharding_impls.NamedSharding] = _gspmd_to_named_sharding
 
@@ -2844,12 +2843,9 @@ def check_arg_avals_for_call(ref_avals, arg_avals):
 def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
   # Create replicated shardings for jit(pmap) path with local devices
   # because multihost jit(pmap) is not allowed.
-  in_shardings = [
-      sharding_impls.GSPMDSharding.get_replicated(local_devices)
-  ] * num_in_shardings
-  out_shardings = [
-      sharding_impls.GSPMDSharding.get_replicated(local_devices)
-  ] * num_out_shardings
+  gs = sharding_impls.GSPMDSharding.get_replicated(local_devices)
+  in_shardings = [gs] * num_in_shardings
+  out_shardings = [gs] * num_out_shardings
   # jit(pmap) will generate Arrays with multi-device sharding.
   # It is unsupported for these shardings to be uncommited, so force
   # the outputs to be committed.
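The pxla.py hunks above all apply the same micro-optimization: hoist the construction of one identical sharding out of a comprehension and alias it, so downstream hash- and identity-based caches see a single object instead of N equal ones. A minimal sketch of the pattern, using made-up names rather than the real JAX classes:

# Illustrative only; `Replicated` stands in for an immutable sharding object
# such as the result of GSPMDSharding.get_replicated(...).
class Replicated:
  pass

num_in_avals = 1000

# Before: N distinct (but equal) objects, each constructed and hashed separately.
per_element = [Replicated() for _ in range(num_in_avals)]

# After: one object aliased N times; safe because the object is immutable,
# and cheaper because it is built and hashed once.
shared = Replicated()
aliased = [shared] * num_in_avals

assert all(s is shared for s in aliased)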
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 815317f7c..82abfbe8b 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1858,13 +1858,16 @@ pxla.custom_resource_typing_rules[sharding_constraint_p] = \
 
 # -------------------- helpers --------------------
 
+@lru_cache(maxsize=2048)
+def _cached_gspmd_sharding(s, ndim):
+  gs = GSPMDSharding(s._device_assignment, s._to_xla_op_sharding(ndim))
+  gs._original_sharding = s
+  return gs
+
 def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
-  gspmd_sharding = GSPMDSharding(
-      s._device_assignment, s._to_xla_op_sharding(ndim))
-  gspmd_sharding._original_sharding = s
-  return gspmd_sharding
+  return _cached_gspmd_sharding(s, ndim)
 
 
 def get_unconstrained_dims(sharding: NamedSharding):
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index 3f5a96299..bc2526394 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -26,6 +26,13 @@ Index = Tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]
 
 
+@functools.lru_cache(maxsize=4096)
+def _addressable_devices_indices_map(
+    sharding: Sharding, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
+  return {d: ind for d, ind in sharding.devices_indices_map(global_shape).items()
+          if d.process_index == d.client.process_index()}
+
+
 @util.use_cpp_class(xc.Sharding)
 class Sharding:
   """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
@@ -86,7 +93,6 @@ class Sharding:
     # The pytype disable is because pytype can't recognize a cached property.
     return len(self.device_set) == len(self.addressable_devices)  # type: ignore
 
-  @functools.lru_cache(maxsize=4096)
   def addressable_devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
     """A mapping from addressable device to the slice of global data it contains.
@@ -94,5 +100,4 @@ class Sharding:
     ``addressable_devices_indices_map`` contains that part of
     ``device_indices_map`` that applies to the addressable devices.
     """
-    return {d: ind for d, ind in self.devices_indices_map(global_shape).items()
-            if d.process_index == d.client.process_index()}
+    return _addressable_devices_indices_map(self, global_shape)
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 7e31f5f70..8568ae689 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -42,7 +42,7 @@ from jax.experimental.maps import xmap
 from jax.experimental import multihost_utils
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src import array
-from jax._src.sharding import Sharding
+from jax._src.sharding import Sharding, _addressable_devices_indices_map
 from jax._src import op_shardings
 from jax._src import sharding_impls
 from jax._src.sharding_impls import (
@@ -3177,6 +3177,30 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertIsInstance(out4.sharding, SingleDeviceSharding)
     self.assertEqual(out4.device(), jax.devices()[1])
 
+  def test_get_indices_cache(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    ns = NamedSharding(mesh, P('x'))
+    ns2 = NamedSharding(mesh, P('x', 'y'))
+
+    np_inp = np.arange(16).reshape(8, 2)
+    arr1 = jax.device_put(np_inp, ns)
+    arr2 = jax.device_put(np_inp, ns2)
+    arr3 = jax.device_put(np_inp, ns)
+
+    cache_info1 = _addressable_devices_indices_map.cache_info()
+    out = pjit(lambda x, y, z: x + y + z)(arr1, arr2, arr3)
+    cache_info2 = _addressable_devices_indices_map.cache_info()
+    self.assertArraysEqual(out, np_inp * 3)
+
+    # arr3 and arr1 should have the same GSPMDSharding objects internally.
+    # So there will be 2 hits in _addressable_devices_indices_map:
+    # one in `pxla._get_input_indices` and a second in `_array_shard_arg`.
+    self.assertEqual(cache_info2.hits, cache_info1.hits + 2)
+    # There will be twice as many misses as hits because arr1's and arr2's
+    # shardings are not the same: 2 misses in _addressable_devices_indices_map
+    # and 2 in `_array_shard_arg`.
+    self.assertEqual(cache_info2.misses, cache_info1.misses + 4)
+
 
 class TempSharding(Sharding):
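The pjit.py and sharding.py changes both route construction through a module-level functools.lru_cache keyed by the (hashable) sharding object, so repeated lookups for equal shardings hit the cache; the new test observes this via cache_info(). A minimal, self-contained sketch of that pattern, with toy names rather than the real JAX classes:

import functools
from typing import Mapping, Tuple

@functools.lru_cache(maxsize=4096)
def _indices_map(sharding, global_shape: Tuple[int, ...]) -> Mapping[int, int]:
  # The expensive work runs once per distinct (sharding, shape) pair.
  return {d: i for i, d in enumerate(sharding.devices)}

class ToySharding:
  def __init__(self, devices: Tuple[int, ...]):
    self.devices = devices

  # Hashability and equality are what let equal instances share cache entries.
  def __hash__(self):
    return hash(self.devices)

  def __eq__(self, other):
    return isinstance(other, ToySharding) and self.devices == other.devices

  def addressable_devices_indices_map(self, global_shape):
    # Delegate to the module-level cached function, as the patch does.
    return _indices_map(self, global_shape)

s1 = ToySharding((0, 1, 2, 3))
s2 = ToySharding((0, 1, 2, 3))   # equal to s1, so it reuses the cached entry
s1.addressable_devices_indices_map((8, 2))
s2.addressable_devices_indices_map((8, 2))
info = _indices_map.cache_info()
assert info.misses == 1 and info.hits == 1

Functionally the cache key is the (sharding, global_shape) pair either way; hoisting the cache to module level also makes it importable and observable from tests, which is exactly what test_get_indices_cache does with cache_info().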