mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
remove inaccurate inline comment in PRNGKeyArray
constructor
PiperOrigin-RevId: 748085747
This commit is contained in:
parent
47bc2f55dc
commit
90af597786
@ -166,7 +166,6 @@ class PRNGKeyArray(jax.Array):
|
||||
_check_prng_key_data(impl, key_data)
|
||||
self._impl = impl
|
||||
self._consumed = False # TODO(jakevdp): default to True here?
|
||||
# If key_data is a numpy array, convert it to an uncommitted CPU jax.Array
|
||||
if isinstance(key_data, np.ndarray):
|
||||
aval = core.get_aval(key_data)
|
||||
device = pxla.get_default_device()
|
||||
|
@ -612,7 +612,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_numpy_construction(self):
|
||||
key = random.wrap_key_data(np.array([42, 173], dtype=np.uint32),
|
||||
impl='threefry2x32')
|
||||
impl='threefry2x32')
|
||||
self.assertIsInstance(key, prng_internal.PRNGKeyArray)
|
||||
self.assertIsInstance(key._base_array, jax.Array)
|
||||
self.assertEqual(key._base_array.device, jax.devices()[0])
|
||||
|
Loading…
x
Reference in New Issue
Block a user