add _device attribute to PRNGKeyArray so that computation follows key placement

unrelated: remove some redundant hasattr + try / except AttributeError
This commit is contained in:
Matthew Johnson 2022-09-08 13:45:06 -07:00
parent e9204e312a
commit 47b2dfe92f
3 changed files with 13 additions and 4 deletions

View File

@ -31,8 +31,7 @@ class ArrayMeta(abc.ABCMeta):
# that isinstance(x, ndarray) might return true but
# issubclass(type(x), ndarray) might return false for an array tracer.
try:
return (hasattr(instance, "aval") and
isinstance(instance.aval, core.UnshapedArray))
return isinstance(instance.aval, core.UnshapedArray)
except AttributeError:
super().__instancecheck__(instance)

View File

@ -15,6 +15,7 @@
import abc
from functools import partial
import operator as op
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence
import numpy as np
@ -114,8 +115,7 @@ class PRNGKeyArrayMeta(abc.ABCMeta):
def __instancecheck__(self, instance):
try:
return (hasattr(instance, 'aval') and
isinstance(instance.aval, core.ShapedArray) and
return (isinstance(instance.aval, core.ShapedArray) and
type(instance.aval.dtype) is KeyTy)
except AttributeError:
super().__instancecheck__(instance)
@ -169,6 +169,10 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
def dtype(self):
return KeyTy(self.impl)
_device = property(op.attrgetter('_base_array._device'))
_committed = property(op.attrgetter('_base_array._committed'))
sharding = property(op.attrgetter('_base_array.sharding'))
def _is_scalar(self):
base_ndim = len(self.impl.key_shape)
return self._base_array.ndim == base_ndim

View File

@ -150,6 +150,12 @@ class MultiDeviceTest(jtu.JaxTestCase):
jax.device_put(x_uncommitted, devices[3])),
devices[4])
def test_computation_follows_data_prng(self):
_, device, *_ = self.get_devices()
rng = jax.device_put(jax.random.PRNGKey(0), device)
val = jax.random.normal(rng, ())
self.assert_committed_to_device(val, device)
def test_primitive_compilation_cache(self):
devices = self.get_devices()