mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #11906 from alonfnt:dtype-arg
PiperOrigin-RevId: 467970800
This commit is contained in:
commit
332d7d0168
@ -1616,14 +1616,14 @@ def pad(array, pad_width, mode="constant", **kwargs):
|
||||
|
||||
|
||||
@_wraps(np.stack, skip_params=['out'])
|
||||
def stack(arrays, axis: int = 0, out=None):
|
||||
def stack(arrays, axis: int = 0, out=None, dtype=None):
|
||||
if not len(arrays):
|
||||
raise ValueError("Need at least one array to stack.")
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
|
||||
if isinstance(arrays, (np.ndarray, ndarray)):
|
||||
axis = _canonicalize_axis(axis, arrays.ndim)
|
||||
return concatenate(expand_dims(arrays, axis + 1), axis=axis)
|
||||
return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
|
||||
else:
|
||||
_stackable(*arrays) or _check_arraylike("stack", *arrays)
|
||||
shape0 = shape(arrays[0])
|
||||
@ -1633,7 +1633,7 @@ def stack(arrays, axis: int = 0, out=None):
|
||||
if shape(a) != shape0:
|
||||
raise ValueError("All input arrays must have the same shape.")
|
||||
new_arrays.append(expand_dims(a, axis))
|
||||
return concatenate(new_arrays, axis=axis)
|
||||
return concatenate(new_arrays, axis=axis, dtype=dtype)
|
||||
|
||||
@_wraps(np.tile)
|
||||
def tile(A, reps):
|
||||
@ -1696,33 +1696,33 @@ def concatenate(arrays, axis: int = 0, dtype=None):
|
||||
|
||||
|
||||
@_wraps(np.vstack)
|
||||
def vstack(tup):
|
||||
def vstack(tup, dtype=None):
|
||||
if isinstance(tup, (np.ndarray, ndarray)):
|
||||
arrs = jax.vmap(atleast_2d)(tup)
|
||||
else:
|
||||
arrs = [atleast_2d(m) for m in tup]
|
||||
return concatenate(arrs, axis=0)
|
||||
return concatenate(arrs, axis=0, dtype=dtype)
|
||||
row_stack = vstack
|
||||
|
||||
|
||||
@_wraps(np.hstack)
|
||||
def hstack(tup):
|
||||
def hstack(tup, dtype=None):
|
||||
if isinstance(tup, (np.ndarray, ndarray)):
|
||||
arrs = jax.vmap(atleast_1d)(tup)
|
||||
arr0_ndim = arrs.ndim - 1
|
||||
else:
|
||||
arrs = [atleast_1d(m) for m in tup]
|
||||
arr0_ndim = arrs[0].ndim
|
||||
return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1)
|
||||
return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype)
|
||||
|
||||
|
||||
@_wraps(np.dstack)
|
||||
def dstack(tup):
|
||||
def dstack(tup, dtype=None):
|
||||
if isinstance(tup, (np.ndarray, ndarray)):
|
||||
arrs = jax.vmap(atleast_3d)(tup)
|
||||
else:
|
||||
arrs = [atleast_3d(m) for m in tup]
|
||||
return concatenate(arrs, axis=2)
|
||||
return concatenate(arrs, axis=2, dtype=dtype)
|
||||
|
||||
|
||||
@_wraps(np.column_stack)
|
||||
|
@ -3311,9 +3311,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "{}_axis={}_array={}".format(
|
||||
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input),
|
||||
"shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input}
|
||||
{"testcase_name": "{}_axis={}_array={}_out={}".format(
|
||||
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input,
|
||||
np.dtype(out_dtype).name),
|
||||
"shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input,
|
||||
"out_dtype": out_dtype}
|
||||
for dtypes in [
|
||||
[np.float32],
|
||||
[np.float32, np.float32],
|
||||
@ -3323,23 +3325,30 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
]
|
||||
for shape in [(), (2,), (3, 4), (1, 100)]
|
||||
for axis in range(-len(shape), len(shape) + 1)
|
||||
for array_input in [True, False]))
|
||||
def testStack(self, shape, axis, dtypes, array_input):
|
||||
for array_input in [True, False]
|
||||
for out_dtype in [np.float32, np.int32]))
|
||||
def testStack(self, shape, axis, dtypes, array_input, out_dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
if array_input:
|
||||
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
|
||||
else:
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
|
||||
jnp_fun = partial(jnp.stack, axis=axis)
|
||||
|
||||
if numpy_version < (1, 24):
|
||||
np_fun = _promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype))
|
||||
else:
|
||||
np_fun = _promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype))
|
||||
|
||||
jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_{}_array={}".format(
|
||||
op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
|
||||
"shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input}
|
||||
{"testcase_name": "_op={}_{}_array={}_out={}".format(
|
||||
op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input,
|
||||
np.dtype(out_dtype).name),
|
||||
"shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input, "out_dtype": out_dtype}
|
||||
for op in ["hstack", "vstack", "dstack"]
|
||||
for dtypes in [
|
||||
[np.float32],
|
||||
@ -3349,15 +3358,21 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
[np.float32, np.int32, np.float64],
|
||||
]
|
||||
for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
|
||||
for array_input in [True, False]))
|
||||
def testHVDStack(self, shape, op, dtypes, array_input):
|
||||
for array_input in [True, False]
|
||||
for out_dtype in [np.float32, np.int32]))
|
||||
def testHVDStack(self, shape, op, dtypes, array_input, out_dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
if array_input:
|
||||
args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
|
||||
else:
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
np_fun = _promote_like_jnp(getattr(np, op))
|
||||
jnp_fun = getattr(jnp, op)
|
||||
|
||||
if numpy_version < (1, 24) or op == "dstack":
|
||||
np_fun = _promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
|
||||
else:
|
||||
np_fun = partial(_promote_like_jnp(getattr(np, op)), dtype=out_dtype)
|
||||
|
||||
jnp_fun = partial(getattr(jnp, op), dtype=out_dtype)
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
@ -6388,7 +6403,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'einsum': ['kwargs'],
|
||||
'einsum_path': ['einsum_call'],
|
||||
'eye': ['order', 'like'],
|
||||
'hstack': ['dtype', 'casting'],
|
||||
'hstack': ['casting'],
|
||||
'identity': ['like'],
|
||||
'in1d': ['kind'],
|
||||
'isin': ['kind'],
|
||||
@ -6400,11 +6415,11 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'histogramdd': ['normed'],
|
||||
'ones': ['order', 'like'],
|
||||
'ones_like': ['subok', 'order'],
|
||||
'row_stack': ['dtype', 'casting'],
|
||||
'stack': ['dtype', 'casting'],
|
||||
'row_stack': ['casting'],
|
||||
'stack': ['casting'],
|
||||
'tri': ['like'],
|
||||
'unique': ['equal_nan'],
|
||||
'vstack': ['dtype', 'casting'],
|
||||
'vstack': ['casting'],
|
||||
'zeros_like': ['subok', 'order']
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user