mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8381 from LenaMartens:changelist/405399581
PiperOrigin-RevId: 410371788
This commit is contained in:
commit
9e09b511f9
@ -30,7 +30,7 @@ import collections
|
||||
from functools import partial
|
||||
import operator
|
||||
import types
|
||||
from typing import Sequence, FrozenSet, Optional, Tuple, Union
|
||||
from typing import Sequence, FrozenSet, Optional, Tuple, Union, Set, Type, Callable
|
||||
from textwrap import dedent as _dedent
|
||||
import warnings
|
||||
|
||||
@ -553,6 +553,12 @@ def _arraylike(x):
|
||||
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
|
||||
hasattr(x, '__jax_array__') or isscalar(x))
|
||||
|
||||
|
||||
def _stackable(*args):
|
||||
return _all(type(arg) in stackables for arg in args)
|
||||
stackables: Set[Type] = set()
|
||||
_register_stackable: Callable[[Type], None] = stackables.add
|
||||
|
||||
def _check_arraylike(fun_name, *args):
|
||||
"""Check if all args fit JAX's definition of arraylike."""
|
||||
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
|
||||
@ -1718,7 +1724,7 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
|
||||
|
||||
@_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
|
||||
def reshape(a, newshape, order="C"):
|
||||
_check_arraylike("reshape", a)
|
||||
_stackable(a) or _check_arraylike("reshape", a)
|
||||
try:
|
||||
return a.reshape(newshape, order=order) # forward to method for ndarrays
|
||||
except AttributeError:
|
||||
@ -1761,7 +1767,7 @@ def _transpose(a, *args):
|
||||
@_wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)
|
||||
@partial(jit, static_argnames=('order',), inline=True)
|
||||
def ravel(a, order="C"):
|
||||
_check_arraylike("ravel", a)
|
||||
_stackable(a) or _check_arraylike("ravel", a)
|
||||
if order == "K":
|
||||
raise NotImplementedError("Ravel not implemented for order='K'.")
|
||||
return reshape(a, (size(a),), order)
|
||||
@ -2224,6 +2230,8 @@ def broadcast_arrays(*args):
|
||||
The JAX version does not necessarily return a view of the input.
|
||||
""")
|
||||
def broadcast_to(arr, shape):
|
||||
if hasattr(arr, "broadcast_to"):
|
||||
return arr.broadcast_to(shape)
|
||||
arr = arr if isinstance(arr, ndarray) else array(arr)
|
||||
shape = (shape,) if ndim(shape) == 0 else shape
|
||||
shape = canonicalize_shape(shape) # check that shape is concrete
|
||||
@ -3361,7 +3369,7 @@ def stack(arrays, axis: int = 0, out=None):
|
||||
|
||||
@_wraps(np.tile)
|
||||
def tile(A, reps):
|
||||
_check_arraylike("tile", A)
|
||||
_stackable(A) or _check_arraylike("tile", A)
|
||||
try:
|
||||
iter(reps)
|
||||
except TypeError:
|
||||
@ -3392,13 +3400,15 @@ def _concatenate_array(arr, axis: int):
|
||||
def concatenate(arrays, axis: int = 0):
|
||||
if isinstance(arrays, (np.ndarray, ndarray)):
|
||||
return _concatenate_array(arrays, axis)
|
||||
_check_arraylike("concatenate", *arrays)
|
||||
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
|
||||
if not len(arrays):
|
||||
raise ValueError("Need at least one array to concatenate.")
|
||||
if ndim(arrays[0]) == 0:
|
||||
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
|
||||
if axis is None:
|
||||
return concatenate([ravel(a) for a in arrays], axis=0)
|
||||
if hasattr(arrays[0], "concatenate"):
|
||||
return arrays[0].concatenate(arrays[1:], axis)
|
||||
axis = _canonicalize_axis(axis, ndim(arrays[0]))
|
||||
arrays = _promote_dtypes(*arrays)
|
||||
# lax.concatenate can be slow to compile for wide concatenations, so form a
|
||||
|
@ -30,6 +30,7 @@ from jax.interpreters import xla
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import cuda_prng
|
||||
from jax._src.numpy.lax_numpy import _register_stackable
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.util import prod
|
||||
|
||||
@ -189,6 +190,19 @@ class PRNGKeyArray:
|
||||
def _split(self, num: int) -> 'PRNGKeyArray':
|
||||
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
|
||||
|
||||
def reshape(self, newshape, order=None):
|
||||
reshaped_keys = jnp.reshape(self._keys, (*newshape, -1), order=order)
|
||||
return PRNGKeyArray(self.impl, reshaped_keys)
|
||||
|
||||
def concatenate(self, key_arrs, axis):
|
||||
axis = axis % len(self.shape)
|
||||
arrs = [self._keys, *[k._keys for k in key_arrs]]
|
||||
return PRNGKeyArray(self.impl, jnp.stack(arrs, axis))
|
||||
|
||||
def broadcast_to(self, shape):
|
||||
new_shape = tuple(shape)+(self._keys.shape[-1],)
|
||||
return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape))
|
||||
|
||||
def __repr__(self):
|
||||
arr_shape = self._shape
|
||||
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
|
||||
@ -201,6 +215,7 @@ class PRNGKeyArray:
|
||||
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
|
||||
return PRNGKeyArray(impl, impl.seed(seed))
|
||||
|
||||
_register_stackable(PRNGKeyArray)
|
||||
|
||||
# -- threefry2x32 PRNG implementation --
|
||||
|
||||
|
@ -1237,6 +1237,49 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
|
||||
def seed_prng(self, seed):
|
||||
return prng.seed_with_impl(prng.unsafe_rbg_prng_impl, seed)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng,
|
||||
'custom PRNG tests require config.jax_enable_custom_prng')
|
||||
class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase):
|
||||
def test_reshape(self):
|
||||
key = random.PRNGKey(123)
|
||||
keys = random.split(key, 4)
|
||||
keys = jnp.reshape(keys, (2, 2))
|
||||
self.assertEqual(keys.shape, (2, 2))
|
||||
|
||||
def test_tile(self):
|
||||
key = random.PRNGKey(123)
|
||||
keys = jnp.tile(key, 3)
|
||||
self.assertEqual(keys.shape, (3,))
|
||||
|
||||
def test_concatenate(self):
|
||||
key = random.PRNGKey(123)
|
||||
keys = random.split(key, 2)
|
||||
keys = jnp.concatenate([keys, keys, keys], axis=0)
|
||||
self.assertEqual(keys.shape, (3, 2))
|
||||
|
||||
def test_broadcast_to(self):
|
||||
key = random.PRNGKey(123)
|
||||
keys = jnp.broadcast_to(key, (3,))
|
||||
self.assertEqual(keys.shape, (3,))
|
||||
|
||||
def test_broadcast_arrays(self):
|
||||
key = random.PRNGKey(123)
|
||||
keys = jax.random.split(key, 3)
|
||||
key, _ = jnp.broadcast_arrays(key, keys)
|
||||
self.assertEqual(key.shape, (3,))
|
||||
|
||||
def test_append(self):
|
||||
key = random.PRNGKey(123)
|
||||
keys = jnp.append(key, key)
|
||||
self.assertEqual(keys.shape, (2, 1))
|
||||
|
||||
def test_ravel(self):
|
||||
key = random.PRNGKey(123)
|
||||
keys = jax.random.split(key, 4)
|
||||
keys = jnp.reshape(keys, (2, 2))
|
||||
keys = jnp.ravel(keys)
|
||||
self.assertEqual(keys.shape, (4,))
|
||||
|
||||
def _sampler_unimplemented_with_rbg(*args, **kwargs):
|
||||
# TODO(mattjj): enable these tests if/when RngBitGenerator supports them
|
||||
raise SkipTest('8- and 16-bit types not supported with RBG PRNG')
|
||||
|
Loading…
x
Reference in New Issue
Block a user