add a "last" symbol for vmap axis specs, use it in api.jacfwd. tests and fixes #1372

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Roy Frostig 2019-09-23 13:35:52 -07:00
parent db694bed2c
commit d1c66614e8
3 changed files with 43 additions and 3 deletions

View File

@ -448,7 +448,7 @@ def jacfwd(fun, argnums=0, holomorphic=False):
f_partial, dyn_args = _argnums_partial(f, argnums, args)
holomorphic or tree_map(_check_real_input_jacfwd, dyn_args)
pushfwd = partial(jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
y, jac = vmap(pushfwd, out_axes=(None, batching.last))(_std_basis(dyn_args))
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
@ -617,12 +617,10 @@ def _flatten_axes(treedef, axis_tree):
dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
axes = []
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
# TODO(mattjj): remove _replace_nones / list comp after jaxlib 0.1.25
tree_multimap(add_leaves, _replace_nones(axis_tree), dummy)
axes = [None if a is _none_proxy else a for a in axes]
return axes
# TODO(mattjj): remove this when jaxlib is updated past 0.1.25
def _replace_nones(tuptree):
if type(tuptree) in (list, tuple):
return tuple(map(_replace_nones, tuptree))

View File

@ -241,9 +241,14 @@ defvectorized(xla.device_put_p)
# almost works, except for broadcast, for which raw numpy.ndarrays don't have a
# method. To handle that case, the `broadcast` function uses a try/except.
class _Last(object): pass
last = _Last()
def broadcast(x, sz, axis):
if core.get_aval(x) is core.abstract_unit:
return core.unit
if axis is last:
axis = onp.ndim(x)
shape = list(onp.shape(x))
shape.insert(axis, sz)
if isinstance(x, onp.ndarray) or onp.isscalar(x):
@ -267,6 +272,8 @@ def matchaxis(sz, src, dst, x):
return x
elif type(src) == type(dst) == int:
return moveaxis(x, src, dst)
elif type(src) == int and dst is last:
return moveaxis(x, src, -1)
elif src is not_mapped and dst is not not_mapped:
return broadcast(x, sz, dst)
else:

View File

@ -357,6 +357,41 @@ class APITest(jtu.JaxTestCase):
(onp.array([0., 0.]), onp.array([0., 2.])))
self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def test_issue1372(self):
  # Regression test for issue #1372: checks the output shapes of every
  # two-level nesting of jacrev/jacfwd over a scalar-valued function of two
  # vector arguments.
  def sum_of_squares(v):
    return np.dot(v, v)

  def f(x, u):
    return sum_of_squares(x) + sum_of_squares(u)

  x, u = np.ones(5), np.ones(2)
  sizes = (x.shape[0], u.shape[0])  # argument length, indexed by argnum
  modes = (jacrev, jacfwd)

  def check(outer, inner, inner_argnum, outer_argnum):
    # Differentiating w.r.t. `inner_argnum` first and `outer_argnum` second
    # gives a block of shape (len(inner arg), len(outer arg)).
    block = outer(inner(f, inner_argnum), outer_argnum)(x, u)
    self.assertEqual(block.shape,
                     (sizes[inner_argnum], sizes[outer_argnum]))

  # Diagonal entries: both derivatives taken w.r.t. the same argument.
  for argnum in (0, 1):
    for outer in modes:
      for inner in modes:
        check(outer, inner, argnum, argnum)

  # Off-diagonal entries: first with reverse-mode on the outside, then with
  # forward-mode on the outside.
  for outer in modes:
    for inner_argnum, outer_argnum in ((1, 0), (0, 1)):
      for inner in modes:
        check(outer, inner, inner_argnum, outer_argnum)
def test_disable_jit(self):
effects = []