mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
custom prng: remove stackable override for jnp.concatenate
This commit is contained in:
parent
075bbe3203
commit
a5737f82af
@ -632,6 +632,10 @@ def _lattice_result_type(*args: Any) -> Tuple[DType, bool]:
|
||||
if len(dtypes) == 1:
|
||||
out_dtype = dtypes[0]
|
||||
out_weak_type = weak_types[0]
|
||||
elif len(set(dtypes)) == 1 and not all(weak_types):
|
||||
# Trivial promotion case. This allows opaque dtypes through.
|
||||
out_dtype = dtypes[0]
|
||||
out_weak_type = False
|
||||
elif all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
|
||||
# If all inputs are weakly typed, we compute the bound of the strongly-typed
|
||||
# counterparts and apply the weak type at the end. This avoids returning the
|
||||
|
@ -1790,15 +1790,13 @@ def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
|
||||
axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array:
|
||||
if isinstance(arrays, (np.ndarray, Array)):
|
||||
return _concatenate_array(arrays, axis, dtype=dtype)
|
||||
util._stackable(*arrays) or util.check_arraylike("concatenate", *arrays)
|
||||
util.check_arraylike("concatenate", *arrays)
|
||||
if not len(arrays):
|
||||
raise ValueError("Need at least one array to concatenate.")
|
||||
if ndim(arrays[0]) == 0:
|
||||
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
|
||||
if axis is None:
|
||||
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
|
||||
if hasattr(arrays[0], "concatenate"):
|
||||
return arrays[0].concatenate(arrays[1:], axis, dtype=dtype) # type: ignore[union-attr]
|
||||
axis = _canonicalize_axis(axis, ndim(arrays[0]))
|
||||
if dtype is None:
|
||||
arrays_out = util.promote_dtypes(*arrays)
|
||||
|
@ -271,7 +271,7 @@ def promote_dtypes(*args: ArrayLike) -> List[Array]:
|
||||
return [lax.asarray(arg) for arg in args]
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True)
|
||||
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
|
||||
|
||||
@ -280,7 +280,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
|
||||
|
||||
Promotes arguments to an inexact type."""
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True)
|
||||
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
|
||||
return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
|
||||
for x in args]
|
||||
|
@ -161,8 +161,6 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
|
||||
@abc.abstractmethod
|
||||
def reshape(self, newshape, order=None) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def broadcast_to(self, shape) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def expand_dims(self, dimensions: Sequence[int]) -> PRNGKeyArray: ...
|
||||
@ -279,14 +277,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
reshaped_base = jnp.reshape(self._base_array, (*newshape, -1), order=order)
|
||||
return PRNGKeyArrayImpl(self.impl, reshaped_base)
|
||||
|
||||
def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArrayImpl:
|
||||
if dtype is not None:
|
||||
raise ValueError(
|
||||
'dtype argument not supported for concatenating PRNGKeyArray')
|
||||
axis = canonicalize_axis(axis, self.ndim)
|
||||
arrs = [self._base_array, *[k._base_array for k in key_arrs]]
|
||||
return PRNGKeyArrayImpl(self.impl, jnp.concatenate(arrs, axis))
|
||||
|
||||
def broadcast_to(self, shape) -> PRNGKeyArrayImpl:
|
||||
if jnp.ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
|
@ -1999,7 +1999,6 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.shape, (3,))
|
||||
|
||||
def test_concatenate(self):
|
||||
self.skipTest('jnp.concatenate on key arrays') # TODO(frostig)
|
||||
key = random.PRNGKey(123)
|
||||
keys = random.split(key, 2)
|
||||
ref = jnp.concatenate([like(keys)] * 3, axis=0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user