mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add efficient path for array input to jnp.stack, jnp.[hvd]stack, jnp.column_stack
This commit is contained in:
parent
3550732a74
commit
17710c0711
def stack(arrays, axis: int = 0, out=None):
  """Join a sequence of arrays along a new axis (jnp equivalent of np.stack).

  Args:
    arrays: a sequence of array-likes with identical shapes, or a single
      ndarray whose leading elements along ``axis`` are stacked.
    axis: position of the new axis in the result.
    out: unsupported; must be None.

  Returns:
    The stacked array.

  Raises:
    ValueError: if ``arrays`` is empty or the shapes differ.
    NotImplementedError: if ``out`` is given.
  """
  if not len(arrays):
    raise ValueError("Need at least one array to stack.")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
  if isinstance(arrays, ndarray):
    # Efficient path: the input is already one stacked array, so avoid
    # per-element Python iteration — insert the new axis and concatenate.
    # (presumably `concatenate` was extended in the same commit to accept an
    # ndarray directly; confirm against the concatenate hunk.)
    axis = _canonicalize_axis(axis, arrays.ndim)
    return concatenate(expand_dims(arrays, axis + 1), axis=axis)
  else:
    # General path: validate each element and expand a new axis per array.
    _check_arraylike("stack", *arrays)
    shape0 = shape(arrays[0])
    axis = _canonicalize_axis(axis, len(shape0) + 1)
    new_arrays = []
    for a in arrays:
      if shape(a) != shape0:
        raise ValueError("All input arrays must have the same shape.")
      new_arrays.append(expand_dims(a, axis))
    return concatenate(new_arrays, axis=axis)
|
||||
|
||||
@_wraps(np.tile)
|
||||
def tile(A, reps):
|
||||
@ -2868,32 +2872,41 @@ def concatenate(arrays, axis: int = 0):
|
||||
|
||||
@_wraps(np.vstack)
def vstack(tup):
  """Stack arrays vertically (row-wise), promoting inputs to at least 2D.

  Args:
    tup: a sequence of array-likes, or a single ndarray whose leading-axis
      elements are stacked.

  Returns:
    The row-wise concatenation of the (at-least-2D) inputs.
  """
  if isinstance(tup, ndarray):
    # Efficient path for a single array input: promote every leading-axis
    # element to 2D with one vmapped call instead of a Python-level loop.
    arrs = jax.vmap(atleast_2d)(tup)
  else:
    arrs = [atleast_2d(m) for m in tup]
  return concatenate(arrs, axis=0)
row_stack = vstack
|
||||
|
||||
|
||||
@_wraps(np.hstack)
def hstack(tup):
  """Stack arrays horizontally (column-wise for >=2D inputs).

  Follows NumPy semantics: 1D inputs are concatenated along axis 0,
  higher-rank inputs along axis 1.

  Args:
    tup: a sequence of array-likes, or a single ndarray whose leading-axis
      elements are stacked.

  Returns:
    The horizontally concatenated array.
  """
  if isinstance(tup, ndarray):
    # Efficient path: promote all leading-axis elements at once via vmap.
    arrs = jax.vmap(atleast_1d)(tup)
    # Each stacked element has one less dimension than the vmapped result.
    arr0_ndim = arrs.ndim - 1
  else:
    arrs = [atleast_1d(m) for m in tup]
    arr0_ndim = arrs[0].ndim
  return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1)
|
||||
|
||||
|
||||
@_wraps(np.dstack)
def dstack(tup):
  """Stack arrays depth-wise (along the third axis), promoting to >=3D.

  Args:
    tup: a sequence of array-likes, or a single ndarray whose leading-axis
      elements are stacked.

  Returns:
    The depth-wise concatenation of the (at-least-3D) inputs.
  """
  if isinstance(tup, ndarray):
    # Efficient path for a single array input: one vmapped promotion.
    arrs = jax.vmap(atleast_3d)(tup)
  else:
    arrs = [atleast_3d(m) for m in tup]
  return concatenate(arrs, axis=2)
|
||||
|
||||
|
||||
@_wraps(np.column_stack)
def column_stack(tup):
  """Stack 1D arrays as columns into a 2D array (np.column_stack semantics).

  Inputs of rank < 2 are promoted to 2D and transposed so each becomes a
  column; rank >= 2 inputs are used as-is. All pieces are then concatenated
  along axis 1.

  Args:
    tup: a sequence of array-likes, or a single ndarray whose leading-axis
      elements are the columns.

  Returns:
    The column-wise concatenation.
  """
  if isinstance(tup, ndarray):
    # Efficient path: when leading-axis elements have rank < 2, promote and
    # transpose them all in one vmapped call; rank >= 2 elements (tup.ndim
    # >= 3) are already usable directly.
    arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup
  else:
    arrs = [atleast_2d(arr).T if arr.ndim < 2 else arr for arr in map(asarray, tup)]
  return concatenate(arrs, 1)
|
||||
|
||||
|
||||
@_wraps(np.choose, skip_params=['out'])
|
||||
|
@ -2762,9 +2762,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_array={}".format(
        jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
     "shape": shape, "dtypes": dtypes, "array_input": array_input}
    for dtypes in [
      [np.float32],
      [np.float32, np.float32],
      # NOTE(review): the entry below falls between the visible diff hunks and
      # was reconstructed from the surrounding pattern — confirm against the
      # full file.
      [np.float32, np.int32, np.float32],
      [np.float32, np.int64, np.float32],
      [np.float32, np.int32, np.float64],
    ]
    for shape in [(), (2,), (3, 4), (1, 5)]
    for array_input in [True, False]))
def testColumnStack(self, shape, dtypes, array_input):
  """Check jnp.column_stack against np.column_stack for list and array inputs."""
  rng = jtu.rand_default(self.rng())
  if array_input:
    # Exercise the efficient single-ndarray path added to jnp.column_stack.
    args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
  else:
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
  np_fun = _promote_like_jnp(np.column_stack)
  jnp_fun = jnp.column_stack
  self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_array={}".format(
        jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input),
     "shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input}
    for dtypes in [
      [np.float32],
      [np.float32, np.float32],
      # NOTE(review): the two entries below fall between the visible diff
      # hunks and were reconstructed from the surrounding pattern — confirm
      # against the full file.
      [np.float32, np.int32, np.float32],
      [np.float32, np.int64, np.float32],
      [np.float32, np.int32, np.float64],
    ]
    for shape in [(), (2,), (3, 4), (1, 100)]
    for axis in range(-len(shape), len(shape) + 1)
    for array_input in [True, False]))
def testStack(self, shape, axis, dtypes, array_input):
  """Check jnp.stack against np.stack for list and single-ndarray inputs."""
  rng = jtu.rand_default(self.rng())
  if array_input:
    # Exercise the efficient single-ndarray path added to jnp.stack.
    args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
  else:
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
  np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
  jnp_fun = partial(jnp.stack, axis=axis)
  self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_{}".format(
|
||||
op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
|
||||
"shape": shape, "op": op, "dtypes": dtypes}
|
||||
{"testcase_name": "_op={}_{}_array={}".format(
|
||||
op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
|
||||
"shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input}
|
||||
for op in ["hstack", "vstack", "dstack"]
|
||||
for dtypes in [
|
||||
[np.float32],
|
||||
@ -2814,10 +2822,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
[np.float32, np.int64, np.float32],
|
||||
[np.float32, np.int32, np.float64],
|
||||
]
|
||||
for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]))
|
||||
def testHVDStack(self, shape, op, dtypes):
|
||||
for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
|
||||
for array_input in [True, False]))
|
||||
def testHVDStack(self, shape, op, dtypes, array_input):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
if array_input:
|
||||
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
|
||||
else:
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
np_fun = _promote_like_jnp(getattr(np, op))
|
||||
jnp_fun = getattr(jnp, op)
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
|
Loading…
x
Reference in New Issue
Block a user