Add dtype arg to jnp.concatenate and update tests

This commit is contained in:
Penn 2022-08-01 15:48:40 -07:00
parent be8939771c
commit 1987ca7389
3 changed files with 24 additions and 16 deletions

View File

@ -1650,9 +1650,9 @@ def tile(A, reps):
[k for pair in zip(reps, A_shape) for k in pair])
return reshape(result, tuple(np.multiply(A_shape, reps)))
def _concatenate_array(arr, axis: int):
def _concatenate_array(arr, axis: int, dtype=None):
# Fast path for concatenation when the input is an ndarray rather than a list.
arr = asarray(arr)
arr = asarray(arr, dtype=dtype)
if arr.ndim == 0 or arr.shape[0] == 0:
raise ValueError("Need at least one array to concatenate.")
if axis is None:
@ -1665,26 +1665,29 @@ def _concatenate_array(arr, axis: int):
return lax.reshape(arr, shape, dimensions)
@_wraps(np.concatenate)
def concatenate(arrays, axis: int = 0):
def concatenate(arrays, axis: int = 0, dtype=None):
if isinstance(arrays, (np.ndarray, ndarray)):
return _concatenate_array(arrays, axis)
return _concatenate_array(arrays, axis, dtype=dtype)
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0)
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
if hasattr(arrays[0], "concatenate"):
return arrays[0].concatenate(arrays[1:], axis)
return arrays[0].concatenate(arrays[1:], axis, dtype=dtype)
axis = _canonicalize_axis(axis, ndim(arrays[0]))
arrays = _promote_dtypes(*arrays)
if dtype is None:
arrays = _promote_dtypes(*arrays)
else:
arrays = [asarray(arr, dtype=dtype) for arr in arrays]
# lax.concatenate can be slow to compile for wide concatenations, so form a
# tree of concatenations as a workaround especially for op-by-op mode.
# (https://github.com/google/jax/issues/653).
k = 16
if len(arrays) == 1:
return asarray(arrays[0])
return asarray(arrays[0], dtype=dtype)
else:
while len(arrays) > 1:
arrays = [lax.concatenate(arrays[i:i+k], axis)

View File

@ -203,7 +203,9 @@ class PRNGKeyArray:
reshaped_keys = jnp.reshape(self._keys, (*newshape, -1), order=order)
return PRNGKeyArray(self.impl, reshaped_keys)
def concatenate(self, key_arrs, axis):
def concatenate(self, key_arrs, axis, dtype=None):
if dtype is not None:
raise ValueError('dtype argument not supported for concatenating PRNGKeyArray')
axis = canonicalize_axis(axis, self.ndim)
arrs = [self._keys, *[k._keys for k in key_arrs]]
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))

View File

@ -2299,15 +2299,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
axis, ",".join(str(d) for d in base_shape),
{"testcase_name": "_axis={}_dtype={}_baseshape=[{}]_argdtypes=[{}]".format(
axis, dtype and np.dtype(dtype).name, ",".join(str(d) for d in base_shape),
",".join(np.dtype(dtype).name for dtype in arg_dtypes)),
"axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes}
"axis": axis, "dtype": dtype, "base_shape": base_shape, "arg_dtypes": arg_dtypes}
for num_arrs in [3]
for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for dtype in [None] + default_dtypes
for axis in range(-len(base_shape)+1, len(base_shape))))
def testConcatenate(self, axis, base_shape, arg_dtypes):
def testConcatenate(self, axis, dtype, base_shape, arg_dtypes):
rng = jtu.rand_default(self.rng())
wrapped_axis = axis % len(base_shape)
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
@ -2315,9 +2316,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(*args):
args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32)
for x in args]
dtype = functools.reduce(jnp.promote_types, arg_dtypes)
return np.concatenate(args, axis=axis).astype(dtype)
jnp_fun = lambda *args: jnp.concatenate(args, axis=axis)
if numpy_version < (1, 20):
_dtype = dtype or jnp.result_type(*arg_dtypes)
return np.concatenate(args, axis=axis).astype(_dtype)
return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe')
jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype)
def args_maker():
return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]