diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index c960846fa..fffc4196c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1650,7 +1650,12 @@ def _pjit_batcher_for_sharding( tad = list(new_op.tile_assignment_dimensions) tad.insert(dim, 1) new_op.tile_assignment_dimensions = tad - new_gs = GSPMDSharding(s._device_assignment, new_op) # type: ignore + if xla_extension_version >= 234: + new_gs = GSPMDSharding( + s._device_assignment, new_op, # type: ignore + _device_list=getattr(s, '_internal_device_list', None)) + else: + new_gs = GSPMDSharding(s._device_assignment, new_op) # type: ignore if hasattr(s, '_original_sharding'): vmapped_s, _ = pxla._get_out_sharding_from_orig_sharding( [new_gs], [None], s._original_sharding, None, [False])[0] # type: ignore @@ -1673,7 +1678,11 @@ def _pjit_batcher_for_sharding( parsed_pspec = parse_flatten_op_sharding(s._hlo_sharding, mesh)[0] # type: ignore parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val) mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec) - return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim)) + if xla_extension_version >= 234: + return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim), + _device_list=getattr(mps, '_internal_device_list', None)) + else: + return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim)) def _pjit_jvp(primals_in, tangents_in, @@ -2166,8 +2175,13 @@ def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int, device_or_backend_set: bool = False) -> GSPMDSharding: if isinstance(s, GSPMDSharding): return s - gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim), - memory_kind=s.memory_kind) + if xla_extension_version >= 234: + gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim), + memory_kind=s.memory_kind, + _device_list=getattr(s, '_internal_device_list', None)) + else: + gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim), + memory_kind=s.memory_kind) gs._original_sharding = s if device_or_backend_set: gs._original_sharding._device_backend = device_or_backend_set diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index d648a0d2b..ead6c6991 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -831,11 +831,13 @@ class GSPMDSharding(XLACompatibleSharding): _devices: tuple[Device, ...] _hlo_sharding: xc.HloSharding _memory_kind: str | None + _device_list: xc.DeviceList | None @use_cpp_method() def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding | xc.HloSharding, - *, memory_kind: str | None = None): + *, memory_kind: str | None = None, + _device_list: xc.DeviceList | None = None): self._devices = tuple(devices) if isinstance(op_sharding, xc.OpSharding): self._hlo_sharding = xc.HloSharding.from_proto(op_sharding)