remove inaccurate inline comment in PRNGKeyArray constructor

PiperOrigin-RevId: 748085747
This commit is contained in:
Roy Frostig 2025-04-15 17:38:52 -07:00 committed by jax authors
parent 47bc2f55dc
commit 90af597786
2 changed files with 1 additions and 2 deletions

View File

@ -166,7 +166,6 @@ class PRNGKeyArray(jax.Array):
_check_prng_key_data(impl, key_data)
self._impl = impl
self._consumed = False # TODO(jakevdp): default to True here?
# If key_data is a numpy array, convert it to an uncommitted CPU jax.Array
if isinstance(key_data, np.ndarray):
aval = core.get_aval(key_data)
device = pxla.get_default_device()

View File

@ -612,7 +612,7 @@ class KeyArrayTest(jtu.JaxTestCase):
def test_numpy_construction(self):
key = random.wrap_key_data(np.array([42, 173], dtype=np.uint32),
impl='threefry2x32')
impl='threefry2x32')
self.assertIsInstance(key, prng_internal.PRNGKeyArray)
self.assertIsInstance(key._base_array, jax.Array)
self.assertEqual(key._base_array.device, jax.devices()[0])