Merge pull request #8381 from LenaMartens:changelist/405399581

PiperOrigin-RevId: 410371788
This commit is contained in:
jax authors 2021-11-16 15:55:44 -08:00
commit 9e09b511f9
3 changed files with 73 additions and 5 deletions

View File

@ -30,7 +30,7 @@ import collections
from functools import partial
import operator
import types
from typing import Sequence, FrozenSet, Optional, Tuple, Union
from typing import Sequence, FrozenSet, Optional, Tuple, Union, Set, Type, Callable
from textwrap import dedent as _dedent
import warnings
@ -553,6 +553,12 @@ def _arraylike(x):
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
hasattr(x, '__jax_array__') or isscalar(x))
def _stackable(*args):
return _all(type(arg) in stackables for arg in args)
stackables: Set[Type] = set()
_register_stackable: Callable[[Type], None] = stackables.add
def _check_arraylike(fun_name, *args):
"""Check if all args fit JAX's definition of arraylike."""
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
@ -1718,7 +1724,7 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
@_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
def reshape(a, newshape, order="C"):
_check_arraylike("reshape", a)
_stackable(a) or _check_arraylike("reshape", a)
try:
return a.reshape(newshape, order=order) # forward to method for ndarrays
except AttributeError:
@ -1761,7 +1767,7 @@ def _transpose(a, *args):
@_wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('order',), inline=True)
def ravel(a, order="C"):
_check_arraylike("ravel", a)
_stackable(a) or _check_arraylike("ravel", a)
if order == "K":
raise NotImplementedError("Ravel not implemented for order='K'.")
return reshape(a, (size(a),), order)
@ -2224,6 +2230,8 @@ def broadcast_arrays(*args):
The JAX version does not necessarily return a view of the input.
""")
def broadcast_to(arr, shape):
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape)
arr = arr if isinstance(arr, ndarray) else array(arr)
shape = (shape,) if ndim(shape) == 0 else shape
shape = canonicalize_shape(shape) # check that shape is concrete
@ -3361,7 +3369,7 @@ def stack(arrays, axis: int = 0, out=None):
@_wraps(np.tile)
def tile(A, reps):
_check_arraylike("tile", A)
_stackable(A) or _check_arraylike("tile", A)
try:
iter(reps)
except TypeError:
@ -3392,13 +3400,15 @@ def _concatenate_array(arr, axis: int):
def concatenate(arrays, axis: int = 0):
if isinstance(arrays, (np.ndarray, ndarray)):
return _concatenate_array(arrays, axis)
_check_arraylike("concatenate", *arrays)
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0)
if hasattr(arrays[0], "concatenate"):
return arrays[0].concatenate(arrays[1:], axis)
axis = _canonicalize_axis(axis, ndim(arrays[0]))
arrays = _promote_dtypes(*arrays)
# lax.concatenate can be slow to compile for wide concatenations, so form a

View File

@ -30,6 +30,7 @@ from jax.interpreters import xla
from jax._src.api import jit, vmap
from jax._src.lib import xla_client
from jax._src.lib import cuda_prng
from jax._src.numpy.lax_numpy import _register_stackable
import jax._src.pretty_printer as pp
from jax._src.util import prod
@ -189,6 +190,19 @@ class PRNGKeyArray:
def _split(self, num: int) -> 'PRNGKeyArray':
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
def reshape(self, newshape, order=None):
reshaped_keys = jnp.reshape(self._keys, (*newshape, -1), order=order)
return PRNGKeyArray(self.impl, reshaped_keys)
def concatenate(self, key_arrs, axis):
axis = axis % len(self.shape)
arrs = [self._keys, *[k._keys for k in key_arrs]]
return PRNGKeyArray(self.impl, jnp.stack(arrs, axis))
def broadcast_to(self, shape):
new_shape = tuple(shape)+(self._keys.shape[-1],)
return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape))
def __repr__(self):
arr_shape = self._shape
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
@ -201,6 +215,7 @@ class PRNGKeyArray:
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
return PRNGKeyArray(impl, impl.seed(seed))
_register_stackable(PRNGKeyArray)
# -- threefry2x32 PRNG implementation --

View File

@ -1237,6 +1237,49 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
def seed_prng(self, seed):
return prng.seed_with_impl(prng.unsafe_rbg_prng_impl, seed)
@skipIf(not config.jax_enable_custom_prng,
'custom PRNG tests require config.jax_enable_custom_prng')
class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase):
def test_reshape(self):
key = random.PRNGKey(123)
keys = random.split(key, 4)
keys = jnp.reshape(keys, (2, 2))
self.assertEqual(keys.shape, (2, 2))
def test_tile(self):
key = random.PRNGKey(123)
keys = jnp.tile(key, 3)
self.assertEqual(keys.shape, (3,))
def test_concatenate(self):
key = random.PRNGKey(123)
keys = random.split(key, 2)
keys = jnp.concatenate([keys, keys, keys], axis=0)
self.assertEqual(keys.shape, (3, 2))
def test_broadcast_to(self):
key = random.PRNGKey(123)
keys = jnp.broadcast_to(key, (3,))
self.assertEqual(keys.shape, (3,))
def test_broadcast_arrays(self):
key = random.PRNGKey(123)
keys = jax.random.split(key, 3)
key, _ = jnp.broadcast_arrays(key, keys)
self.assertEqual(key.shape, (3,))
def test_append(self):
key = random.PRNGKey(123)
keys = jnp.append(key, key)
self.assertEqual(keys.shape, (2, 1))
def test_ravel(self):
key = random.PRNGKey(123)
keys = jax.random.split(key, 4)
keys = jnp.reshape(keys, (2, 2))
keys = jnp.ravel(keys)
self.assertEqual(keys.shape, (4,))
def _sampler_unimplemented_with_rbg(*args, **kwargs):
# TODO(mattjj): enable these tests if/when RngBitGenerator supports them
raise SkipTest('8- and 16-bit types not supported with RBG PRNG')