Merge pull request #16213 from jakevdp:keyarray-shards

PiperOrigin-RevId: 537262590
This commit is contained in:
jax authors 2023-06-02 03:12:55 -07:00
commit 6a89abcc76
3 changed files with 48 additions and 6 deletions

View File

@ -18,8 +18,8 @@ import math
import operator as op
import numpy as np
import functools
from typing import (Sequence, Tuple, Callable, Optional, List, cast, Set,
TYPE_CHECKING)
from typing import (Any, Callable, List, Optional, Sequence, Set, Tuple,
Union, cast, TYPE_CHECKING)
from jax._src import abstract_arrays
from jax._src import api
@ -45,6 +45,7 @@ from jax._src.util import use_cpp_class, use_cpp_method
Shape = Tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
class Shard:
@ -60,7 +61,7 @@ class Shard:
"""
def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
data: Optional[ArrayImpl] = None):
data: Union[None, ArrayImpl, PRNGKeyArrayImpl] = None):
self._device = device
self._sharding = sharding
self._global_shape = global_shape

View File

@ -17,7 +17,8 @@ import abc
from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Set, Sequence, Tuple, Union
from typing import (Any, Callable, Hashable, Iterator, List, NamedTuple,
Set, Sequence, Tuple, Union)
import numpy as np
@ -62,6 +63,7 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Device = xc.Device
Shard = Any # TODO(jakevdp): fix circular imports and import Shard
UINT_DTYPES = {
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
@ -216,9 +218,17 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
def is_deleted(self) -> bool: ...
@abc.abstractmethod
def on_device_size_in_bytes(self) -> int: ...
@property
@abc.abstractmethod
def addressable_shards(self) -> List[Shard]: ...
@property
@abc.abstractmethod
def global_shards(self) -> List[Shard]: ...
@abc.abstractmethod
def addressable_data(self, index: int) -> PRNGKeyArray: ...
# TODO(jakevdp): potentially add tolist(), tobytes(), device_buffer, device_buffers,
# addressable_data(), addressable_shards(), global_shards(), __cuda_interface__()
# TODO(jakevdp): potentially add tolist(), tobytes(),
# device_buffer, device_buffers, __cuda_interface__()
class PRNGKeyArrayImpl(PRNGKeyArray):
@ -291,6 +301,33 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]
def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))
@property
def addressable_shards(self) -> List[Shard]:
return [
type(s)(
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
data=PRNGKeyArrayImpl(self.impl, s._data),
)
for s in self._base_array.addressable_shards
]
@property
def global_shards(self) -> List[Shard]:
return [
type(s)(
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
data=PRNGKeyArrayImpl(self.impl, s._data),
)
for s in self._base_array.global_shards
]
@property
def sharding(self):
phys_sharding = self._base_array.sharding

View File

@ -1950,6 +1950,10 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(key.devices(), key._base_array.devices())
self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes)
self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer)
self.assertArraysEqual(key.addressable_data(0)._base_array,
key._base_array.addressable_data(0))
self.assertLen(key.addressable_shards, len(key._base_array.addressable_shards))
self.assertLen(key.global_shards, len(key._base_array.global_shards))
def test_delete(self):
key = self.make_keys(10)