mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #16213 from jakevdp:keyarray-shards
PiperOrigin-RevId: 537262590
This commit is contained in:
commit
6a89abcc76
@ -18,8 +18,8 @@ import math
|
||||
import operator as op
|
||||
import numpy as np
|
||||
import functools
|
||||
from typing import (Sequence, Tuple, Callable, Optional, List, cast, Set,
|
||||
TYPE_CHECKING)
|
||||
from typing import (Any, Callable, List, Optional, Sequence, Set, Tuple,
|
||||
Union, cast, TYPE_CHECKING)
|
||||
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import api
|
||||
@ -45,6 +45,7 @@ from jax._src.util import use_cpp_class, use_cpp_method
|
||||
Shape = Tuple[int, ...]
|
||||
Device = xc.Device
|
||||
Index = Tuple[slice, ...]
|
||||
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
|
||||
|
||||
|
||||
class Shard:
|
||||
@ -60,7 +61,7 @@ class Shard:
|
||||
"""
|
||||
|
||||
def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
|
||||
data: Optional[ArrayImpl] = None):
|
||||
data: Union[None, ArrayImpl, PRNGKeyArrayImpl] = None):
|
||||
self._device = device
|
||||
self._sharding = sharding
|
||||
self._global_shape = global_shape
|
||||
|
@ -17,7 +17,8 @@ import abc
|
||||
from functools import partial, reduce
|
||||
import math
|
||||
import operator as op
|
||||
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Set, Sequence, Tuple, Union
|
||||
from typing import (Any, Callable, Hashable, Iterator, List, NamedTuple,
|
||||
Set, Sequence, Tuple, Union)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -62,6 +63,7 @@ map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
Device = xc.Device
|
||||
Shard = Any # TODO(jakevdp): fix circular imports and import Shard
|
||||
|
||||
UINT_DTYPES = {
|
||||
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
|
||||
@ -216,9 +218,17 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
|
||||
def is_deleted(self) -> bool: ...
|
||||
@abc.abstractmethod
|
||||
def on_device_size_in_bytes(self) -> int: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def addressable_shards(self) -> List[Shard]: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def global_shards(self) -> List[Shard]: ...
|
||||
@abc.abstractmethod
|
||||
def addressable_data(self, index: int) -> PRNGKeyArray: ...
|
||||
|
||||
# TODO(jakevdp): potentially add tolist(), tobytes(), device_buffer, device_buffers,
|
||||
# addressable_data(), addressable_shards(), global_shards(), __cuda_interface__()
|
||||
# TODO(jakevdp): potentially add tolist(), tobytes(),
|
||||
# device_buffer, device_buffers, __cuda_interface__()
|
||||
|
||||
|
||||
class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
@ -291,6 +301,33 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
|
||||
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]
|
||||
|
||||
def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
|
||||
return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))
|
||||
|
||||
@property
|
||||
def addressable_shards(self) -> List[Shard]:
|
||||
return [
|
||||
type(s)(
|
||||
device=s._device,
|
||||
sharding=s._sharding,
|
||||
global_shape=s._global_shape,
|
||||
data=PRNGKeyArrayImpl(self.impl, s._data),
|
||||
)
|
||||
for s in self._base_array.addressable_shards
|
||||
]
|
||||
|
||||
@property
|
||||
def global_shards(self) -> List[Shard]:
|
||||
return [
|
||||
type(s)(
|
||||
device=s._device,
|
||||
sharding=s._sharding,
|
||||
global_shape=s._global_shape,
|
||||
data=PRNGKeyArrayImpl(self.impl, s._data),
|
||||
)
|
||||
for s in self._base_array.global_shards
|
||||
]
|
||||
|
||||
@property
|
||||
def sharding(self):
|
||||
phys_sharding = self._base_array.sharding
|
||||
|
@ -1950,6 +1950,10 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(key.devices(), key._base_array.devices())
|
||||
self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes)
|
||||
self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer)
|
||||
self.assertArraysEqual(key.addressable_data(0)._base_array,
|
||||
key._base_array.addressable_data(0))
|
||||
self.assertLen(key.addressable_shards, len(key._base_array.addressable_shards))
|
||||
self.assertLen(key.global_shards, len(key._base_array.global_shards))
|
||||
|
||||
def test_delete(self):
|
||||
key = self.make_keys(10)
|
||||
|
Loading…
x
Reference in New Issue
Block a user