mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
cbadfd41ce
commit
eae47b2330
33
jax/api.py
33
jax/api.py
@ -647,12 +647,43 @@ def vmap(fun, in_axes=0, out_axes=0):
|
||||
"or a (nested) tuple of those types, got {} and {} respectively.")
|
||||
raise TypeError(msg.format(type(in_axes), type(out_axes)))
|
||||
|
||||
def _check_axis_sizes(tree, vals, dims):
|
||||
try:
|
||||
sizes, = {x.shape[d] for x, d in zip(vals, dims) if d is not None}
|
||||
except ValueError:
|
||||
msg = "vmap got inconsistent sizes for array axes to be mapped:\n{}"
|
||||
# we switch the error message based on whether args is a tuple of arrays,
|
||||
# in which case we can produce an error message based on argument indices,
|
||||
# or if it has nested containers.
|
||||
# TODO(mattjj,phawkins): add a way to inspect pytree kind more directly
|
||||
if tree == tree_flatten((core.unit,) * tree.num_leaves)[1]:
|
||||
lines1 = ["arg {} has shape {} and axis {} is to be mapped"
|
||||
.format(i, x.shape, d) for i, (x, d) in enumerate(zip(vals, dims))]
|
||||
sizes = collections.defaultdict(list)
|
||||
for i, (x, d) in enumerate(zip(vals, dims)):
|
||||
if d is not None:
|
||||
sizes[x.shape[d]].append(i)
|
||||
lines2 = ["{} {} {} {} to be mapped of size {}".format(
|
||||
"args" if len(idxs) > 1 else "arg",
|
||||
", ".join(map(str, idxs)),
|
||||
"have" if len(idxs) > 1 else "has",
|
||||
"axes" if len(idxs) > 1 else "an axis",
|
||||
size)
|
||||
for size, idxs in sizes.items()]
|
||||
raise ValueError(msg.format("\n".join(lines1 + ["so"] + lines2)))
|
||||
else:
|
||||
sizes = [x.shape[d] if d is not None else None for x, d in zip(vals, dims)]
|
||||
sizes = tree_unflatten(tree, sizes)
|
||||
raise ValueError(msg.format("the tree of axis sizes is:\n{}".format(sizes)))
|
||||
|
||||
@wraps(fun, docstr=docstr)
|
||||
def batched_fun(*args):
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
f = lu.wrap_init(fun)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
|
||||
out_flat = batching.batch(flat_fun, args_flat, _flatten_axes(in_tree, in_axes),
|
||||
in_axes_flat = _flatten_axes(in_tree, in_axes)
|
||||
_check_axis_sizes(in_tree, args_flat, in_axes_flat)
|
||||
out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
|
||||
lambda: _flatten_axes(out_tree(), out_axes))
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
|
@ -38,8 +38,8 @@ map = safe_map
|
||||
|
||||
|
||||
def batch(fun, in_vals, in_dims, out_dim_dests):
|
||||
out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
|
||||
size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
|
||||
out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
|
||||
return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
|
||||
|
||||
def batch_fun(fun, in_vals, in_dims):
|
||||
@ -163,8 +163,8 @@ def get_primitive_batcher(p):
|
||||
try:
|
||||
return primitive_batchers[p]
|
||||
except KeyError:
|
||||
raise NotImplementedError(
|
||||
"Batching rule for '{}' not implemented".format(p))
|
||||
msg = "Batching rule for '{}' not implemented"
|
||||
raise NotImplementedError(msg.format(p))
|
||||
|
||||
def defvectorized(prim):
|
||||
primitive_batchers[prim] = partial(vectorized_batcher, prim)
|
||||
|
@ -1106,6 +1106,29 @@ class APITest(jtu.JaxTestCase):
|
||||
ans = vfun(lambda x: x + 1, np.arange(3))
|
||||
self.assertAllClose(ans, onp.arange(1, 4), check_dtypes=False)
|
||||
|
||||
def test_vmap_error_message_issue_705(self):
|
||||
# https://github.com/google/jax/issues/705
|
||||
def h(a, b):
|
||||
return np.sum(a) + np.sum(b)
|
||||
|
||||
X = onp.random.randn(10, 4)
|
||||
U = onp.random.randn(10, 2)
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
"arg 0 has shape \(10, 4\) and axis 0 is to be mapped\n"
|
||||
"arg 1 has shape \(10, 2\) and axis 1 is to be mapped\n"
|
||||
"so\n"
|
||||
"arg 0 has an axis to be mapped of size 10\n"
|
||||
"arg 1 has an axis to be mapped of size 2",
|
||||
lambda: api.vmap(h, in_axes=(0, 1))(X, U))
|
||||
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
"the tree of axis sizes is:\n",
|
||||
lambda: api.vmap(h, in_axes=(0, 1))(X, [U, U]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user