diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 78c3a3a18..a4217f3d7 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -210,6 +210,8 @@ class PRNGKeyArray: return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis)) def broadcast_to(self, shape): + if jnp.ndim(shape) == 0: + shape = (shape,) new_shape = (*shape, *self.impl.key_shape) return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape)) diff --git a/tests/random_test.py b/tests/random_test.py index 861f47458..55a580130 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1553,6 +1553,9 @@ class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase): ref = jnp.broadcast_to(like(key), (3,)) self.assertEqual(out.shape, ref.shape) self.assertEqual(out.shape, (3,)) + out = jnp.broadcast_to(key, 3) + self.assertEqual(out.shape, ref.shape) + self.assertEqual(out.shape, (3,)) def test_expand_dims(self): key = random.PRNGKey(123)