Add an internal _device_list parameter to GSPMDSharding so that we can save on the initialization cost of PyDeviceList when creating GSPMDSharding from other shardings

PiperOrigin-RevId: 601055733
This commit is contained in:
Yash Katariya 2024-01-24 02:27:07 -08:00 committed by jax authors
parent a74b04a43f
commit a63197fed8
2 changed files with 21 additions and 5 deletions

View File

@ -1650,7 +1650,12 @@ def _pjit_batcher_for_sharding(
tad = list(new_op.tile_assignment_dimensions)
tad.insert(dim, 1)
new_op.tile_assignment_dimensions = tad
new_gs = GSPMDSharding(s._device_assignment, new_op) # type: ignore
if xla_extension_version >= 234:
new_gs = GSPMDSharding(
s._device_assignment, new_op, # type: ignore
_device_list=getattr(s, '_internal_device_list', None))
else:
new_gs = GSPMDSharding(s._device_assignment, new_op) # type: ignore
if hasattr(s, '_original_sharding'):
vmapped_s, _ = pxla._get_out_sharding_from_orig_sharding(
[new_gs], [None], s._original_sharding, None, [False])[0] # type: ignore
@ -1673,7 +1678,11 @@ def _pjit_batcher_for_sharding(
parsed_pspec = parse_flatten_op_sharding(s._hlo_sharding, mesh)[0] # type: ignore
parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim))
if xla_extension_version >= 234:
return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim),
_device_list=getattr(mps, '_internal_device_list', None))
else:
return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim))
def _pjit_jvp(primals_in, tangents_in,
@ -2166,8 +2175,13 @@ def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int,
device_or_backend_set: bool = False) -> GSPMDSharding:
if isinstance(s, GSPMDSharding):
return s
gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
memory_kind=s.memory_kind)
if xla_extension_version >= 234:
gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
memory_kind=s.memory_kind,
_device_list=getattr(s, '_internal_device_list', None))
else:
gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
memory_kind=s.memory_kind)
gs._original_sharding = s
if device_or_backend_set:
gs._original_sharding._device_backend = device_or_backend_set

View File

@ -831,11 +831,13 @@ class GSPMDSharding(XLACompatibleSharding):
_devices: tuple[Device, ...]
_hlo_sharding: xc.HloSharding
_memory_kind: str | None
_device_list: xc.DeviceList | None
@use_cpp_method()
def __init__(self, devices: Sequence[Device],
op_sharding: xc.OpSharding | xc.HloSharding,
*, memory_kind: str | None = None):
*, memory_kind: str | None = None,
_device_list: xc.DeviceList | None = None):
self._devices = tuple(devices)
if isinstance(op_sharding, xc.OpSharding):
self._hlo_sharding = xc.HloSharding.from_proto(op_sharding)