mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add _device attribute to PRNGKeyArray so that computation follows key placement
unrelated: remove some redundant hasattr + try / except AttributeError
This commit is contained in:
parent
e9204e312a
commit
47b2dfe92f
@ -31,8 +31,7 @@ class ArrayMeta(abc.ABCMeta):
|
||||
# that isinstance(x, ndarray) might return true but
|
||||
# issubclass(type(x), ndarray) might return false for an array tracer.
|
||||
try:
|
||||
return (hasattr(instance, "aval") and
|
||||
isinstance(instance.aval, core.UnshapedArray))
|
||||
return isinstance(instance.aval, core.UnshapedArray)
|
||||
except AttributeError:
|
||||
super().__instancecheck__(instance)
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
import abc
|
||||
from functools import partial
|
||||
import operator as op
|
||||
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence
|
||||
|
||||
import numpy as np
|
||||
@ -114,8 +115,7 @@ class PRNGKeyArrayMeta(abc.ABCMeta):
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
try:
|
||||
return (hasattr(instance, 'aval') and
|
||||
isinstance(instance.aval, core.ShapedArray) and
|
||||
return (isinstance(instance.aval, core.ShapedArray) and
|
||||
type(instance.aval.dtype) is KeyTy)
|
||||
except AttributeError:
|
||||
super().__instancecheck__(instance)
|
||||
@ -169,6 +169,10 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
def dtype(self):
|
||||
return KeyTy(self.impl)
|
||||
|
||||
_device = property(op.attrgetter('_base_array._device'))
|
||||
_committed = property(op.attrgetter('_base_array._committed'))
|
||||
sharding = property(op.attrgetter('_base_array.sharding'))
|
||||
|
||||
def _is_scalar(self):
|
||||
base_ndim = len(self.impl.key_shape)
|
||||
return self._base_array.ndim == base_ndim
|
||||
|
@ -150,6 +150,12 @@ class MultiDeviceTest(jtu.JaxTestCase):
|
||||
jax.device_put(x_uncommitted, devices[3])),
|
||||
devices[4])
|
||||
|
||||
def test_computation_follows_data_prng(self):
|
||||
_, device, *_ = self.get_devices()
|
||||
rng = jax.device_put(jax.random.PRNGKey(0), device)
|
||||
val = jax.random.normal(rng, ())
|
||||
self.assert_committed_to_device(val, device)
|
||||
|
||||
def test_primitive_compilation_cache(self):
|
||||
devices = self.get_devices()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user