From a5737f82afb430e050726175fac961b0496427b4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 24 Apr 2023 12:26:58 -0700 Subject: [PATCH] custom prng: remove stackable override for jnp.concatenate --- jax/_src/dtypes.py | 4 ++++ jax/_src/numpy/lax_numpy.py | 4 +--- jax/_src/numpy/util.py | 4 ++-- jax/_src/prng.py | 10 ---------- tests/random_test.py | 1 - 5 files changed, 7 insertions(+), 16 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 2450804a2..0c472a13b 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -632,6 +632,10 @@ def _lattice_result_type(*args: Any) -> Tuple[DType, bool]: if len(dtypes) == 1: out_dtype = dtypes[0] out_weak_type = weak_types[0] + elif len(set(dtypes)) == 1 and not all(weak_types): + # Trivial promotion case. This allows opaque dtypes through. + out_dtype = dtypes[0] + out_weak_type = False elif all(weak_types) and config.jax_numpy_dtype_promotion != 'strict': # If all inputs are weakly typed, we compute the bound of the strongly-typed # counterparts and apply the weak type at the end. This avoids returning the diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index acebeb895..73bd555a3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1790,15 +1790,13 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]], axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array: if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) - util._stackable(*arrays) or util.check_arraylike("concatenate", *arrays) + util.check_arraylike("concatenate", *arrays) if not len(arrays): raise ValueError("Need at least one array to concatenate.") if ndim(arrays[0]) == 0: raise ValueError("Zero-dimensional arrays cannot be concatenated.") if axis is None: return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype) - if hasattr(arrays[0], "concatenate"): - return arrays[0].concatenate(arrays[1:], axis, dtype=dtype) # type: ignore[union-attr] axis = _canonicalize_axis(axis, ndim(arrays[0])) if dtype is None: arrays_out = util.promote_dtypes(*arrays) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 0cda54957..1e611f229 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -271,7 +271,7 @@ def promote_dtypes(*args: ArrayLike) -> List[Array]: return [lax.asarray(arg) for arg in args] else: to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype) + to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True) return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] @@ -280,7 +280,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]: Promotes arguments to an inexact type.""" to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype) + to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True) to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] diff --git a/jax/_src/prng.py b/jax/_src/prng.py index a10d463ac..54ed0a142 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -161,8 +161,6 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta): @abc.abstractmethod def reshape(self, newshape, order=None) -> PRNGKeyArray: ... @abc.abstractmethod - def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArray: ... - @abc.abstractmethod def broadcast_to(self, shape) -> PRNGKeyArray: ... @abc.abstractmethod def expand_dims(self, dimensions: Sequence[int]) -> PRNGKeyArray: ... @@ -279,14 +277,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray): reshaped_base = jnp.reshape(self._base_array, (*newshape, -1), order=order) return PRNGKeyArrayImpl(self.impl, reshaped_base) - def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArrayImpl: - if dtype is not None: - raise ValueError( - 'dtype argument not supported for concatenating PRNGKeyArray') - axis = canonicalize_axis(axis, self.ndim) - arrs = [self._base_array, *[k._base_array for k in key_arrs]] - return PRNGKeyArrayImpl(self.impl, jnp.concatenate(arrs, axis)) - def broadcast_to(self, shape) -> PRNGKeyArrayImpl: if jnp.ndim(shape) == 0: shape = (shape,) diff --git a/tests/random_test.py b/tests/random_test.py index 561b336a6..fbe84e993 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1999,7 +1999,6 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase): self.assertEqual(out.shape, (3,)) def test_concatenate(self): - self.skipTest('jnp.concatenate on key arrays') # TODO(frostig) key = random.PRNGKey(123) keys = random.split(key, 2) ref = jnp.concatenate([like(keys)] * 3, axis=0)