mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Pass axis name to _match_axes and add to error message.
This commit is contained in:
parent
454f5e67b1
commit
eb9d6e4d21
@ -57,7 +57,8 @@ def batch_subtrace(main, in_dims, *in_vals):
|
||||
yield out_vals, out_dims
|
||||
|
||||
@lu.transformation
|
||||
def _match_axes(axis_size, in_dims, out_dims_thunk, out_dim_dests, *in_vals):
|
||||
def _match_axes(axis_size, axis_name, in_dims, out_dims_thunk, out_dim_dests,
|
||||
*in_vals):
|
||||
if axis_size is None:
|
||||
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
|
||||
out_vals = yield in_vals, {}
|
||||
@ -65,7 +66,10 @@ def _match_axes(axis_size, in_dims, out_dims_thunk, out_dim_dests, *in_vals):
|
||||
out_dims = out_dims_thunk()
|
||||
for od, od_dest in zip(out_dims, out_dim_dests):
|
||||
if od is not None and not isinstance(od_dest, int):
|
||||
msg = f"vmap has mapped output but out_axes is {od_dest}"
|
||||
if not isinstance(axis_name, core._TempAxisName):
|
||||
msg = f"vmap has mapped output (axis_name={axis_name}) but out_axes is {od_dest}"
|
||||
else:
|
||||
msg = f"vmap has mapped output but out_axes is {od_dest}"
|
||||
raise ValueError(msg)
|
||||
yield map(partial(matchaxis, axis_size), out_dims, out_dim_dests, out_vals)
|
||||
|
||||
@ -280,8 +284,9 @@ def batch(fun: lu.WrappedFun,
|
||||
# anlogue of `jvp` in ad.py
|
||||
# TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
|
||||
fun, out_dims_thunk = batch_subtrace(fun)
|
||||
return _match_axes(batchfun(fun, axis_name, axis_size, in_dims, main_type),
|
||||
axis_size, in_dims, out_dims_thunk, out_dim_dests)
|
||||
return _match_axes(
|
||||
batchfun(fun, axis_name, axis_size, in_dims, main_type), axis_size,
|
||||
axis_name, in_dims, out_dims_thunk, out_dim_dests)
|
||||
|
||||
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
|
||||
def vtile(f_flat: lu.WrappedFun,
|
||||
|
@ -1908,8 +1908,17 @@ class APITest(jtu.JaxTestCase):
|
||||
api.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(jnp.array([1., 2.]))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "vmap has mapped output but out_axes is None"):
|
||||
# If the output is mapped, then there must be some out_axes specified
|
||||
ValueError,
|
||||
r"vmap has mapped output \(axis_name=foo\) but out_axes is None"):
|
||||
# If the output is mapped (user-named axis), then there must be some
|
||||
# out_axes specified.
|
||||
api.vmap(lambda x: x, out_axes=None, axis_name="foo")(jnp.array([1., 2.]))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap has mapped output but out_axes is None"):
|
||||
# If the output is mapped (unnamed axis), then there must be some out_axes
|
||||
# specified.
|
||||
api.vmap(lambda x: x, out_axes=None)(jnp.array([1., 2.]))
|
||||
|
||||
def test_vmap_structured_in_axes(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user