Merge pull request #11906 from alonfnt:dtype-arg

PiperOrigin-RevId: 467970800
jax authors 2022-08-16 10:59:39 -07:00
commit 332d7d0168
2 changed files with 42 additions and 27 deletions
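The diff below adds a NumPy-style dtype keyword to jnp.stack, jnp.vstack (and its alias row_stack), jnp.hstack and jnp.dstack, forwarding it to jnp.concatenate, which already accepted one. As a quick orientation, here is a hedged usage sketch (illustrative only; it assumes JAX's default 32-bit mode and is not part of the commit):

    import jax.numpy as jnp

    x = jnp.array([1, 2, 3])          # int32 under default 32-bit mode
    y = jnp.array([4, 5, 6])

    # Before this change the output dtype came only from type promotion.
    s = jnp.stack([x, y])                       # shape (2, 3), dtype int32

    # After it, the output dtype can be requested directly, as in NumPy 1.24+.
    s16 = jnp.stack([x, y], dtype=jnp.float16)  # shape (2, 3), dtype float16
    v = jnp.vstack([x, y], dtype=jnp.float32)   # shape (2, 3), dtype float32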


@@ -1616,14 +1616,14 @@ def pad(array, pad_width, mode="constant", **kwargs):
 
 
 @_wraps(np.stack, skip_params=['out'])
-def stack(arrays, axis: int = 0, out=None):
+def stack(arrays, axis: int = 0, out=None, dtype=None):
   if not len(arrays):
     raise ValueError("Need at least one array to stack.")
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
   if isinstance(arrays, (np.ndarray, ndarray)):
     axis = _canonicalize_axis(axis, arrays.ndim)
-    return concatenate(expand_dims(arrays, axis + 1), axis=axis)
+    return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
   else:
     _stackable(*arrays) or _check_arraylike("stack", *arrays)
     shape0 = shape(arrays[0])
@@ -1633,7 +1633,7 @@ def stack(arrays, axis: int = 0, out=None):
       if shape(a) != shape0:
         raise ValueError("All input arrays must have the same shape.")
       new_arrays.append(expand_dims(a, axis))
-    return concatenate(new_arrays, axis=axis)
+    return concatenate(new_arrays, axis=axis, dtype=dtype)
 
 @_wraps(np.tile)
 def tile(A, reps):
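Both return paths of stack simply pass dtype through to concatenate, so stacking with an explicit dtype behaves like expanding each input along the new axis and concatenating with that dtype. A rough equivalence sketch (illustrative, not from the commit):

    import jax.numpy as jnp

    a = jnp.zeros((3,), jnp.int32)
    b = jnp.ones((3,), jnp.int32)

    lhs = jnp.stack([a, b], axis=0, dtype=jnp.float32)
    rhs = jnp.concatenate([jnp.expand_dims(a, 0), jnp.expand_dims(b, 0)],
                          axis=0, dtype=jnp.float32)
    assert lhs.dtype == rhs.dtype == jnp.float32
    assert (lhs == rhs).all()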
@@ -1696,33 +1696,33 @@ def concatenate(arrays, axis: int = 0, dtype=None):
 
 
 @_wraps(np.vstack)
-def vstack(tup):
+def vstack(tup, dtype=None):
   if isinstance(tup, (np.ndarray, ndarray)):
     arrs = jax.vmap(atleast_2d)(tup)
   else:
     arrs = [atleast_2d(m) for m in tup]
-  return concatenate(arrs, axis=0)
+  return concatenate(arrs, axis=0, dtype=dtype)
 row_stack = vstack
 
 
 @_wraps(np.hstack)
-def hstack(tup):
+def hstack(tup, dtype=None):
   if isinstance(tup, (np.ndarray, ndarray)):
     arrs = jax.vmap(atleast_1d)(tup)
     arr0_ndim = arrs.ndim - 1
   else:
     arrs = [atleast_1d(m) for m in tup]
     arr0_ndim = arrs[0].ndim
-  return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1)
+  return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype)
 
 
 @_wraps(np.dstack)
-def dstack(tup):
+def dstack(tup, dtype=None):
   if isinstance(tup, (np.ndarray, ndarray)):
     arrs = jax.vmap(atleast_3d)(tup)
   else:
     arrs = [atleast_3d(m) for m in tup]
-  return concatenate(arrs, axis=2)
+  return concatenate(arrs, axis=2, dtype=dtype)
 
 
 @_wraps(np.column_stack)
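vstack, hstack and dstack likewise forward dtype to their single concatenate call. A hedged usage sketch of all three (shapes and dtypes assume default 32-bit mode):

    import jax.numpy as jnp

    a = jnp.arange(3)        # shape (3,), int32
    b = jnp.arange(3, 6)

    jnp.vstack([a, b], dtype=jnp.float32)   # shape (2, 3),    float32
    jnp.hstack([a, b], dtype=jnp.float32)   # shape (6,),      float32
    jnp.dstack([a, b], dtype=jnp.float32)   # shape (1, 3, 2), float32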


@@ -3311,9 +3311,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     self._CompileAndCheck(jnp_fun, args_maker)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "{}_axis={}_array={}".format(
-          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input),
-       "shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input}
+      {"testcase_name": "{}_axis={}_array={}_out={}".format(
+          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input,
+          np.dtype(out_dtype).name),
+       "shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input,
+       "out_dtype": out_dtype}
       for dtypes in [
         [np.float32],
         [np.float32, np.float32],
@@ -3323,23 +3325,30 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
       ]
       for shape in [(), (2,), (3, 4), (1, 100)]
       for axis in range(-len(shape), len(shape) + 1)
-      for array_input in [True, False]))
-  def testStack(self, shape, axis, dtypes, array_input):
+      for array_input in [True, False]
+      for out_dtype in [np.float32, np.int32]))
+  def testStack(self, shape, axis, dtypes, array_input, out_dtype):
     rng = jtu.rand_default(self.rng())
     if array_input:
       args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
     else:
       args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
-    np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
-    jnp_fun = partial(jnp.stack, axis=axis)
+    if numpy_version < (1, 24):
+      np_fun = _promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype))
+    else:
+      np_fun = _promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype))
+    jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
     with jtu.strict_promotion_if_dtypes_match(dtypes):
       self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
       self._CompileAndCheck(jnp_fun, args_maker)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_op={}_{}_array={}".format(
-          op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
-       "shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input}
+      {"testcase_name": "_op={}_{}_array={}_out={}".format(
+          op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input,
+          np.dtype(out_dtype).name),
+       "shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input, "out_dtype": out_dtype}
       for op in ["hstack", "vstack", "dstack"]
      for dtypes in [
        [np.float32],
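Both test updates guard on numpy_version because np.stack, np.vstack and np.hstack only gained their own dtype keyword in NumPy 1.24; on older NumPy the reference result is built first and then cast. The fallback pattern in isolation (the helper name and version parsing below are illustrative, not taken from the test suite):

    import numpy as np

    numpy_version = tuple(int(p) for p in np.__version__.split('.')[:2])

    def np_stack_with_dtype(arrays, axis, dtype):
      # Older NumPy: np.stack has no dtype=, so stack first, then cast.
      if numpy_version < (1, 24):
        return np.stack(arrays, axis=axis).astype(dtype)
      # NumPy >= 1.24: pass dtype straight through.
      return np.stack(arrays, axis=axis, dtype=dtype)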
@@ -3349,15 +3358,21 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
        [np.float32, np.int32, np.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
-      for array_input in [True, False]))
-  def testHVDStack(self, shape, op, dtypes, array_input):
+      for array_input in [True, False]
+      for out_dtype in [np.float32, np.int32]))
+  def testHVDStack(self, shape, op, dtypes, array_input, out_dtype):
     rng = jtu.rand_default(self.rng())
     if array_input:
       args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
     else:
       args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
-    np_fun = _promote_like_jnp(getattr(np, op))
-    jnp_fun = getattr(jnp, op)
+    if numpy_version < (1, 24) or op == "dstack":
+      np_fun = _promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
+    else:
+      np_fun = partial(_promote_like_jnp(getattr(np, op)), dtype=out_dtype)
+    jnp_fun = partial(getattr(jnp, op), dtype=out_dtype)
     with jtu.strict_promotion_if_dtypes_match(dtypes):
       self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
       self._CompileAndCheck(jnp_fun, args_maker)
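np.dstack is special-cased above because, unlike np.stack, np.vstack and np.hstack, it did not gain a dtype keyword in NumPy 1.24, so the NumPy reference is always cast after the fact, while jnp.dstack now accepts dtype directly. A small sketch of the asymmetry the test works around (illustrative values):

    import numpy as np
    import jax.numpy as jnp

    a = np.zeros(3, np.float32)
    b = np.ones(3, np.float32)

    ref = np.dstack([a, b]).astype(np.float16)   # NumPy: cast afterwards
    out = jnp.dstack([a, b], dtype=jnp.float16)  # JAX: dtype goes to concatenate
    assert out.shape == ref.shape == (1, 3, 2)
    assert out.dtype == ref.dtype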
@@ -6388,7 +6403,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
       'einsum': ['kwargs'],
       'einsum_path': ['einsum_call'],
       'eye': ['order', 'like'],
-      'hstack': ['dtype', 'casting'],
+      'hstack': ['casting'],
       'identity': ['like'],
       'in1d': ['kind'],
       'isin': ['kind'],
@@ -6400,11 +6415,11 @@ class NumpySignaturesTest(jtu.JaxTestCase):
       'histogramdd': ['normed'],
       'ones': ['order', 'like'],
       'ones_like': ['subok', 'order'],
-      'row_stack': ['dtype', 'casting'],
-      'stack': ['dtype', 'casting'],
+      'row_stack': ['casting'],
+      'stack': ['casting'],
       'tri': ['like'],
       'unique': ['equal_nan'],
-      'vstack': ['dtype', 'casting'],
+      'vstack': ['casting'],
       'zeros_like': ['subok', 'order']
     }
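The final two hunks trim the skip-lists in NumpySignaturesTest: 'dtype' is no longer recorded as an unsupported parameter of stack, row_stack, vstack and hstack, while 'casting' still is. The idea behind that kind of check can be sketched with inspect; this is a hypothetical re-creation of the concept, not the actual test code:

    import inspect
    import numpy as np
    import jax.numpy as jnp

    # Parameters jnp deliberately does not implement for a given function.
    unsupported = {'stack': ['casting']}

    np_params = set(inspect.signature(np.stack).parameters)
    jnp_params = set(inspect.signature(jnp.stack).parameters)
    missing = np_params - jnp_params - set(unsupported.get('stack', []))
    print(sorted(missing))  # expected to be empty now that jnp.stack takes dtype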