mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
address reviewer comments, fix test error
This commit is contained in:
parent
eae47b2330
commit
14acca7b51
29
jax/api.py
29
jax/api.py
@ -572,12 +572,11 @@ def vmap(fun, in_axes=0, out_axes=0):
|
||||
length equal to the number of positional arguments to ``fun``. An integer
|
||||
or None indicates which array axis to map over for all arguments (with
|
||||
None indicating not to map any axis), and a tuple indicates which axis to
|
||||
map for each corresponding positional argument. More generally, if the
|
||||
positinal arguments to ``fun`` are container types, the corresponding
|
||||
element of ``in_axes`` can itself be a matching container, so that
|
||||
distinct array axes can be mapped for different container elements. The
|
||||
constraint is that ``in_axes`` must be a container tree prefix of the
|
||||
positional argument tuple passed to ``fun``.
|
||||
map for each corresponding positional argument. If the positinal arguments
|
||||
to ``fun`` are container types, the corresponding element of ``in_axes``
|
||||
can itself be a matching container, so that distinct array axes can be
|
||||
mapped for different container elements. ``in_axes`` must be a container
|
||||
tree prefix of the positional argument tuple passed to ``fun``.
|
||||
out_axes: A nonnegative integer, None, or (nested) standard Python container
|
||||
(tuple/list/dict) thereof indicating where the mapped axis should appear
|
||||
in the output.
|
||||
@ -623,19 +622,8 @@ def vmap(fun, in_axes=0, out_axes=0):
|
||||
>>> z = np.ones((C, D, K))
|
||||
>>> tree = (x, (y, z))
|
||||
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
|
||||
>>> print(vfoo(tree))
|
||||
[[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]]
|
||||
>>> print(vfoo(tree)).shape
|
||||
(6, 2, 5)
|
||||
"""
|
||||
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
|
||||
"but with additional array axes over which {fun} is mapped.")
|
||||
@ -648,8 +636,9 @@ def vmap(fun, in_axes=0, out_axes=0):
|
||||
raise TypeError(msg.format(type(in_axes), type(out_axes)))
|
||||
|
||||
def _check_axis_sizes(tree, vals, dims):
|
||||
mapped_axis_sizes = {x.shape[d] for x, d in zip(vals, dims) if d is not None}
|
||||
try:
|
||||
sizes, = {x.shape[d] for x, d in zip(vals, dims) if d is not None}
|
||||
sizes, = mapped_axis_sizes
|
||||
except ValueError:
|
||||
msg = "vmap got inconsistent sizes for array axes to be mapped:\n{}"
|
||||
# we switch the error message based on whether args is a tuple of arrays,
|
||||
|
@ -1099,14 +1099,14 @@ class APITest(jtu.JaxTestCase):
|
||||
lambda: api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3)),
|
||||
ValueError, "axes specification must be a tree prefix")
|
||||
|
||||
def test_vmap_objects_issue_183(self):
|
||||
def test_vmap_unbatched_object_passthrough_issue_183(self):
|
||||
# https://github.com/google/jax/issues/183
|
||||
fun = lambda f, x: f(x)
|
||||
vfun = api.vmap(fun, (None, 0))
|
||||
ans = vfun(lambda x: x + 1, np.arange(3))
|
||||
self.assertAllClose(ans, onp.arange(1, 4), check_dtypes=False)
|
||||
|
||||
def test_vmap_error_message_issue_705(self):
|
||||
def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
|
||||
# https://github.com/google/jax/issues/705
|
||||
def h(a, b):
|
||||
return np.sum(a) + np.sum(b)
|
||||
@ -1126,7 +1126,8 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
"the tree of axis sizes is:\n",
|
||||
"the tree of axis sizes is:\n"
|
||||
"\(10, \[2, 2\]\)",
|
||||
lambda: api.vmap(h, in_axes=(0, 1))(X, [U, U]))
|
||||
|
||||
|
||||
|
@ -369,20 +369,22 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
def testDynamicSlice(self):
|
||||
# test dynamic_slice via numpy indexing syntax
|
||||
x = onp.arange(30).reshape((10, 3))
|
||||
# see https://github.com/google/jax/issues/1613 for an explanation of why we
|
||||
# need to use np rather than onp to create x and idx
|
||||
x = np.arange(30).reshape((10, 3))
|
||||
|
||||
ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
|
||||
expected = x[:, 1]
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
||||
idx = onp.array([0, 1, 2, 1, 0] * 2)
|
||||
idx = np.array([0, 1, 2, 1, 0] * 2)
|
||||
ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
|
||||
expected = x[onp.arange(10), idx]
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
x = onp.arange(3)
|
||||
idx = onp.array([0, 1, 2, 1, 0] * 2)
|
||||
x = np.arange(3)
|
||||
idx = np.array([0, 1, 2, 1, 0] * 2)
|
||||
ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
|
||||
expected = x[idx]
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user