mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make _remake internal and add return type hints
PiperOrigin-RevId: 550721261
This commit is contained in:
parent
727af17cfd
commit
7821516105
@ -577,7 +577,7 @@ def _op_sharding_to_pos_sharding(
|
||||
name = device_assignment[0].platform.upper()
|
||||
ids = np.array([DeviceIdSet(name, i)
|
||||
for i in op_sharding.tile_assignment_devices])
|
||||
p = PositionalSharding.remake(tuple(device_assignment), ids)
|
||||
p = PositionalSharding._remake(tuple(device_assignment), ids)
|
||||
p = p.reshape(op_sharding.tile_assignment_dimensions)
|
||||
if replicate_on_last_tile_dim:
|
||||
p = p.replicate(-1, keepdims=False)
|
||||
@ -619,19 +619,19 @@ class PositionalSharding(XLACompatibleSharding):
|
||||
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
|
||||
return f'{cls_name}({body}{mem})'
|
||||
|
||||
def reshape(self, *shape):
|
||||
return self.remake(self._devices, self._ids.reshape(*shape))
|
||||
def reshape(self, *shape) -> PositionalSharding:
|
||||
return self._remake(self._devices, self._ids.reshape(*shape))
|
||||
|
||||
def transpose(self, *axes):
|
||||
return self.remake(self._devices, self._ids.transpose(*axes))
|
||||
def transpose(self, *axes) -> PositionalSharding:
|
||||
return self._remake(self._devices, self._ids.transpose(*axes))
|
||||
T = property(transpose)
|
||||
|
||||
def replicate(self, axis=None, keepdims=True):
|
||||
def replicate(self, axis=None, keepdims=True) -> PositionalSharding:
|
||||
new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union
|
||||
return self.remake(self._devices, new_ids)
|
||||
return self._remake(self._devices, new_ids)
|
||||
|
||||
@classmethod
|
||||
def remake(
|
||||
def _remake(
|
||||
cls, devices: tuple[xc.Device, ...], ids: np.ndarray,
|
||||
*, memory_kind: str | None = None) -> PositionalSharding:
|
||||
self = cls.__new__(cls)
|
||||
|
Loading…
x
Reference in New Issue
Block a user