mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add dtype arg to jnp.concatenate and update tests
This commit is contained in:
parent
be8939771c
commit
1987ca7389
@ -1650,9 +1650,9 @@ def tile(A, reps):
|
||||
[k for pair in zip(reps, A_shape) for k in pair])
|
||||
return reshape(result, tuple(np.multiply(A_shape, reps)))
|
||||
|
||||
def _concatenate_array(arr, axis: int):
|
||||
def _concatenate_array(arr, axis: int, dtype=None):
|
||||
# Fast path for concatenation when the input is an ndarray rather than a list.
|
||||
arr = asarray(arr)
|
||||
arr = asarray(arr, dtype=dtype)
|
||||
if arr.ndim == 0 or arr.shape[0] == 0:
|
||||
raise ValueError("Need at least one array to concatenate.")
|
||||
if axis is None:
|
||||
@ -1665,26 +1665,29 @@ def _concatenate_array(arr, axis: int):
|
||||
return lax.reshape(arr, shape, dimensions)
|
||||
|
||||
@_wraps(np.concatenate)
|
||||
def concatenate(arrays, axis: int = 0):
|
||||
def concatenate(arrays, axis: int = 0, dtype=None):
|
||||
if isinstance(arrays, (np.ndarray, ndarray)):
|
||||
return _concatenate_array(arrays, axis)
|
||||
return _concatenate_array(arrays, axis, dtype=dtype)
|
||||
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
|
||||
if not len(arrays):
|
||||
raise ValueError("Need at least one array to concatenate.")
|
||||
if ndim(arrays[0]) == 0:
|
||||
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
|
||||
if axis is None:
|
||||
return concatenate([ravel(a) for a in arrays], axis=0)
|
||||
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
|
||||
if hasattr(arrays[0], "concatenate"):
|
||||
return arrays[0].concatenate(arrays[1:], axis)
|
||||
return arrays[0].concatenate(arrays[1:], axis, dtype=dtype)
|
||||
axis = _canonicalize_axis(axis, ndim(arrays[0]))
|
||||
arrays = _promote_dtypes(*arrays)
|
||||
if dtype is None:
|
||||
arrays = _promote_dtypes(*arrays)
|
||||
else:
|
||||
arrays = [asarray(arr, dtype=dtype) for arr in arrays]
|
||||
# lax.concatenate can be slow to compile for wide concatenations, so form a
|
||||
# tree of concatenations as a workaround especially for op-by-op mode.
|
||||
# (https://github.com/google/jax/issues/653).
|
||||
k = 16
|
||||
if len(arrays) == 1:
|
||||
return asarray(arrays[0])
|
||||
return asarray(arrays[0], dtype=dtype)
|
||||
else:
|
||||
while len(arrays) > 1:
|
||||
arrays = [lax.concatenate(arrays[i:i+k], axis)
|
||||
|
@ -203,7 +203,9 @@ class PRNGKeyArray:
|
||||
reshaped_keys = jnp.reshape(self._keys, (*newshape, -1), order=order)
|
||||
return PRNGKeyArray(self.impl, reshaped_keys)
|
||||
|
||||
def concatenate(self, key_arrs, axis):
|
||||
def concatenate(self, key_arrs, axis, dtype=None):
|
||||
if dtype is not None:
|
||||
raise ValueError('dtype argument not supported for concatenating PRNGKeyArray')
|
||||
axis = canonicalize_axis(axis, self.ndim)
|
||||
arrs = [self._keys, *[k._keys for k in key_arrs]]
|
||||
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))
|
||||
|
@ -2299,15 +2299,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
|
||||
axis, ",".join(str(d) for d in base_shape),
|
||||
{"testcase_name": "_axis={}_dtype={}_baseshape=[{}]_argdtypes=[{}]".format(
|
||||
axis, dtype and np.dtype(dtype).name, ",".join(str(d) for d in base_shape),
|
||||
",".join(np.dtype(dtype).name for dtype in arg_dtypes)),
|
||||
"axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes}
|
||||
"axis": axis, "dtype": dtype, "base_shape": base_shape, "arg_dtypes": arg_dtypes}
|
||||
for num_arrs in [3]
|
||||
for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs)
|
||||
for base_shape in [(4,), (3, 4), (2, 3, 4)]
|
||||
for dtype in [None] + default_dtypes
|
||||
for axis in range(-len(base_shape)+1, len(base_shape))))
|
||||
def testConcatenate(self, axis, base_shape, arg_dtypes):
|
||||
def testConcatenate(self, axis, dtype, base_shape, arg_dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
wrapped_axis = axis % len(base_shape)
|
||||
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
|
||||
@ -2315,9 +2316,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(*args):
|
||||
args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32)
|
||||
for x in args]
|
||||
dtype = functools.reduce(jnp.promote_types, arg_dtypes)
|
||||
return np.concatenate(args, axis=axis).astype(dtype)
|
||||
jnp_fun = lambda *args: jnp.concatenate(args, axis=axis)
|
||||
if numpy_version < (1, 20):
|
||||
_dtype = dtype or jnp.result_type(*arg_dtypes)
|
||||
return np.concatenate(args, axis=axis).astype(_dtype)
|
||||
return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe')
|
||||
jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype)
|
||||
|
||||
def args_maker():
|
||||
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user