Make _remake internal and add return type hints

PiperOrigin-RevId: 550721261
This commit is contained in:
Yash Katariya 2023-07-24 17:35:46 -07:00 committed by jax authors
parent 727af17cfd
commit 7821516105

View File

@ -577,7 +577,7 @@ def _op_sharding_to_pos_sharding(
name = device_assignment[0].platform.upper()
ids = np.array([DeviceIdSet(name, i)
for i in op_sharding.tile_assignment_devices])
p = PositionalSharding.remake(tuple(device_assignment), ids)
p = PositionalSharding._remake(tuple(device_assignment), ids)
p = p.reshape(op_sharding.tile_assignment_dimensions)
if replicate_on_last_tile_dim:
p = p.replicate(-1, keepdims=False)
@ -619,19 +619,19 @@ class PositionalSharding(XLACompatibleSharding):
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
return f'{cls_name}({body}{mem})'
def reshape(self, *shape):
return self.remake(self._devices, self._ids.reshape(*shape))
def reshape(self, *shape) -> PositionalSharding:
return self._remake(self._devices, self._ids.reshape(*shape))
def transpose(self, *axes):
return self.remake(self._devices, self._ids.transpose(*axes))
def transpose(self, *axes) -> PositionalSharding:
return self._remake(self._devices, self._ids.transpose(*axes))
T = property(transpose)
def replicate(self, axis=None, keepdims=True):
def replicate(self, axis=None, keepdims=True) -> PositionalSharding:
new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union
return self.remake(self._devices, new_ids)
return self._remake(self._devices, new_ids)
@classmethod
def remake(
def _remake(
cls, devices: tuple[xc.Device, ...], ids: np.ndarray,
*, memory_kind: str | None = None) -> PositionalSharding:
self = cls.__new__(cls)