custom prng: remove stackable override for jnp.concatenate

This commit is contained in:
Jake VanderPlas 2023-04-24 12:26:58 -07:00
parent 075bbe3203
commit a5737f82af
5 changed files with 7 additions and 16 deletions

View File

@ -632,6 +632,10 @@ def _lattice_result_type(*args: Any) -> Tuple[DType, bool]:
if len(dtypes) == 1:
out_dtype = dtypes[0]
out_weak_type = weak_types[0]
elif len(set(dtypes)) == 1 and not all(weak_types):
# Trivial promotion case. This allows opaque dtypes through.
out_dtype = dtypes[0]
out_weak_type = False
elif all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
# If all inputs are weakly typed, we compute the bound of the strongly-typed
# counterparts and apply the weak type at the end. This avoids returning the

View File

@ -1790,15 +1790,13 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(arrays, (np.ndarray, Array)):
return _concatenate_array(arrays, axis, dtype=dtype)
util._stackable(*arrays) or util.check_arraylike("concatenate", *arrays)
util.check_arraylike("concatenate", *arrays)
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
if hasattr(arrays[0], "concatenate"):
return arrays[0].concatenate(arrays[1:], axis, dtype=dtype) # type: ignore[union-attr]
axis = _canonicalize_axis(axis, ndim(arrays[0]))
if dtype is None:
arrays_out = util.promote_dtypes(*arrays)

View File

@ -271,7 +271,7 @@ def promote_dtypes(*args: ArrayLike) -> List[Array]:
return [lax.asarray(arg) for arg in args]
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True)
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
@ -280,7 +280,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
Promotes arguments to an inexact type."""
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True)
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
for x in args]

View File

@ -161,8 +161,6 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
@abc.abstractmethod
def reshape(self, newshape, order=None) -> PRNGKeyArray: ...
@abc.abstractmethod
def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArray: ...
@abc.abstractmethod
def broadcast_to(self, shape) -> PRNGKeyArray: ...
@abc.abstractmethod
def expand_dims(self, dimensions: Sequence[int]) -> PRNGKeyArray: ...
@ -279,14 +277,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
reshaped_base = jnp.reshape(self._base_array, (*newshape, -1), order=order)
return PRNGKeyArrayImpl(self.impl, reshaped_base)
def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArrayImpl:
if dtype is not None:
raise ValueError(
'dtype argument not supported for concatenating PRNGKeyArray')
axis = canonicalize_axis(axis, self.ndim)
arrs = [self._base_array, *[k._base_array for k in key_arrs]]
return PRNGKeyArrayImpl(self.impl, jnp.concatenate(arrs, axis))
def broadcast_to(self, shape) -> PRNGKeyArrayImpl:
if jnp.ndim(shape) == 0:
shape = (shape,)

View File

@ -1999,7 +1999,6 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.assertEqual(out.shape, (3,))
def test_concatenate(self):
self.skipTest('jnp.concatenate on key arrays') # TODO(frostig)
key = random.PRNGKey(123)
keys = random.split(key, 2)
ref = jnp.concatenate([like(keys)] * 3, axis=0)