Add an internal _device_list parameter to GSPMDSharding so that we can save on the initialization cost of PyDeviceList when creating a GSPMDSharding from other shardings.

PiperOrigin-RevId: 601055733
This commit is contained in:
parent a74b04a43f
commit a63197fed8
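The idea behind the change, as a minimal sketch: per the commit message, initializing a PyDeviceList is costly, so a GSPMDSharding built from another sharding can reuse the device list that sharding already holds instead of constructing a fresh one. FakeDeviceList and FakeGSPMDSharding below are illustrative stand-ins, not the real jaxlib types:

class FakeDeviceList:
  """Stand-in for xc.DeviceList; assume __init__ is the expensive step."""
  def __init__(self, devices):
    self.devices = tuple(devices)  # imagine costly validation happening here

class FakeGSPMDSharding:
  """Stand-in for GSPMDSharding with the new optional _device_list hook."""
  def __init__(self, devices, hlo_sharding, *, _device_list=None):
    self._devices = tuple(devices)
    self._hlo_sharding = hlo_sharding
    # Reuse the caller's already-built device list when one is supplied.
    self._internal_device_list = (
        FakeDeviceList(devices) if _device_list is None else _device_list)

src = FakeGSPMDSharding(range(8), 'replicated')
dst = FakeGSPMDSharding(src._devices, 'tiled',
                        _device_list=getattr(src, '_internal_device_list', None))
assert dst._internal_device_list is src._internal_device_list  # no rebuild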
@@ -1650,7 +1650,12 @@ def _pjit_batcher_for_sharding(
     tad = list(new_op.tile_assignment_dimensions)
     tad.insert(dim, 1)
     new_op.tile_assignment_dimensions = tad
-    new_gs = GSPMDSharding(s._device_assignment, new_op)  # type: ignore
+    if xla_extension_version >= 234:
+      new_gs = GSPMDSharding(
+          s._device_assignment, new_op,  # type: ignore
+          _device_list=getattr(s, '_internal_device_list', None))
+    else:
+      new_gs = GSPMDSharding(s._device_assignment, new_op)  # type: ignore
     if hasattr(s, '_original_sharding'):
       vmapped_s, _ = pxla._get_out_sharding_from_orig_sharding(
           [new_gs], [None], s._original_sharding, None, [False])[0]  # type: ignore
@@ -1673,7 +1678,11 @@ def _pjit_batcher_for_sharding(
     parsed_pspec = parse_flatten_op_sharding(s._hlo_sharding, mesh)[0]  # type: ignore
     parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
     mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
-    return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim))
+    if xla_extension_version >= 234:
+      return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim),
+                           _device_list=getattr(mps, '_internal_device_list', None))
+    else:
+      return GSPMDSharding(mps._device_assignment, mps._to_xla_hlo_sharding(ndim))
 
 
 def _pjit_jvp(primals_in, tangents_in,
@@ -2166,8 +2175,13 @@ def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int,
                       device_or_backend_set: bool = False) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
-  gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
-                     memory_kind=s.memory_kind)
+  if xla_extension_version >= 234:
+    gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
+                       memory_kind=s.memory_kind,
+                       _device_list=getattr(s, '_internal_device_list', None))
+  else:
+    gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
+                       memory_kind=s.memory_kind)
   gs._original_sharding = s
   if device_or_backend_set:
     gs._original_sharding._device_backend = device_or_backend_set
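The same guard appears at all three call sites above. As a sketch of the idiom, with make_sharding as a hypothetical stand-in for the GSPMDSharding constructor: older jaxlib builds do not accept the new keyword, so it is only passed when the extension version is new enough, and getattr(..., None) keeps the call safe for sharding objects that never cached a device list.

XLA_EXTENSION_VERSION = 234  # assumed runtime value for this example

def make_sharding(devices, hlo, *, memory_kind=None, _device_list=None):
  # Hypothetical stand-in for the GSPMDSharding constructor.
  return (tuple(devices), hlo, memory_kind, _device_list)

def convert(s, devices, hlo, memory_kind=None):
  # Pass the cached device list only on runtimes that know the keyword.
  if XLA_EXTENSION_VERSION >= 234:
    return make_sharding(devices, hlo, memory_kind=memory_kind,
                         _device_list=getattr(s, '_internal_device_list', None))
  else:
    return make_sharding(devices, hlo, memory_kind=memory_kind)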
@@ -831,11 +831,13 @@ class GSPMDSharding(XLACompatibleSharding):
   _devices: tuple[Device, ...]
   _hlo_sharding: xc.HloSharding
   _memory_kind: str | None
+  _device_list: xc.DeviceList | None
 
   @use_cpp_method()
   def __init__(self, devices: Sequence[Device],
                op_sharding: xc.OpSharding | xc.HloSharding,
-               *, memory_kind: str | None = None):
+               *, memory_kind: str | None = None,
+               _device_list: xc.DeviceList | None = None):
     self._devices = tuple(devices)
     if isinstance(op_sharding, xc.OpSharding):
       self._hlo_sharding = xc.HloSharding.from_proto(op_sharding)
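One plausible way the new optional parameter composes with lazy construction, again sketched with stand-in types rather than the actual jaxlib implementation: pre-seed a cached property so the list is built at most once, and never when the caller already has it.

import functools

class StubDeviceList:
  """Illustrative stand-in for xc.DeviceList."""
  def __init__(self, devices):
    self.devices = tuple(devices)

class StubSharding:
  """Illustrative stand-in for a sharding with a cached device list."""
  def __init__(self, devices, *, _device_list=None):
    self._devices = tuple(devices)
    if _device_list is not None:
      # Seed the cached_property slot so no rebuild can ever happen.
      self.__dict__['_internal_device_list'] = _device_list

  @functools.cached_property
  def _internal_device_list(self):
    # Built lazily, at most once, and only if the caller supplied nothing.
    return StubDeviceList(self._devices)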