Merge pull request #10288 from YouJiacheng:patch-7

PiperOrigin-RevId: 442043193
This commit is contained in:
jax authors 2022-04-15 10:19:08 -07:00
commit a4b8a443be
2 changed files with 5 additions and 0 deletions

View File

@ -210,6 +210,8 @@ class PRNGKeyArray:
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))
def broadcast_to(self, shape):
if jnp.ndim(shape) == 0:
shape = (shape,)
new_shape = (*shape, *self.impl.key_shape)
return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape))

View File

@ -1553,6 +1553,9 @@ class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase):
ref = jnp.broadcast_to(like(key), (3,))
self.assertEqual(out.shape, ref.shape)
self.assertEqual(out.shape, (3,))
out = jnp.broadcast_to(key, 3)
self.assertEqual(out.shape, ref.shape)
self.assertEqual(out.shape, (3,))
def test_expand_dims(self):
key = random.PRNGKey(123)