mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add a "last" symbol for vmap axis specs, use it in api.jacfwd
. tests and fixes #1372
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
db694bed2c
commit
d1c66614e8
@ -448,7 +448,7 @@ def jacfwd(fun, argnums=0, holomorphic=False):
|
||||
f_partial, dyn_args = _argnums_partial(f, argnums, args)
|
||||
holomorphic or tree_map(_check_real_input_jacfwd, dyn_args)
|
||||
pushfwd = partial(jvp, f_partial, dyn_args)
|
||||
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
|
||||
y, jac = vmap(pushfwd, out_axes=(None, batching.last))(_std_basis(dyn_args))
|
||||
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
|
||||
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
|
||||
|
||||
@ -617,12 +617,10 @@ def _flatten_axes(treedef, axis_tree):
|
||||
dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
|
||||
axes = []
|
||||
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
|
||||
# TODO(mattjj): remove _replace_nones / list comp after jaxlib 0.1.25
|
||||
tree_multimap(add_leaves, _replace_nones(axis_tree), dummy)
|
||||
axes = [None if a is _none_proxy else a for a in axes]
|
||||
return axes
|
||||
|
||||
# TODO(mattjj): remove this when jaxlib is updated past 0.1.25
|
||||
def _replace_nones(tuptree):
|
||||
if type(tuptree) in (list, tuple):
|
||||
return tuple(map(_replace_nones, tuptree))
|
||||
|
@ -241,9 +241,14 @@ defvectorized(xla.device_put_p)
|
||||
# almost works, except for broadcast, for which raw numpy.ndarrays don't have a
|
||||
# method. To handle that case, the `broadcast` function uses a try/except.
|
||||
|
||||
class _Last(object): pass
|
||||
last = _Last()
|
||||
|
||||
def broadcast(x, sz, axis):
|
||||
if core.get_aval(x) is core.abstract_unit:
|
||||
return core.unit
|
||||
if axis is last:
|
||||
axis = onp.ndim(x)
|
||||
shape = list(onp.shape(x))
|
||||
shape.insert(axis, sz)
|
||||
if isinstance(x, onp.ndarray) or onp.isscalar(x):
|
||||
@ -267,6 +272,8 @@ def matchaxis(sz, src, dst, x):
|
||||
return x
|
||||
elif type(src) == type(dst) == int:
|
||||
return moveaxis(x, src, dst)
|
||||
elif type(src) == int and dst is last:
|
||||
return moveaxis(x, src, -1)
|
||||
elif src is not_mapped and dst is not not_mapped:
|
||||
return broadcast(x, sz, dst)
|
||||
else:
|
||||
|
@ -357,6 +357,41 @@ class APITest(jtu.JaxTestCase):
|
||||
(onp.array([0., 0.]), onp.array([0., 2.])))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_issue1372(self):
|
||||
def quad(x):
|
||||
return np.dot(x, x)
|
||||
|
||||
def f(x, u):
|
||||
return quad(x) + quad(u)
|
||||
|
||||
x, u = np.ones(5), np.ones(2)
|
||||
|
||||
rev = jacrev
|
||||
fwd = jacfwd
|
||||
|
||||
# Diagonal entries
|
||||
self.assertEqual(rev(rev(f, 0), 0)(x, u).shape, (5, 5))
|
||||
self.assertEqual(rev(fwd(f, 0), 0)(x, u).shape, (5, 5))
|
||||
self.assertEqual(fwd(rev(f, 0), 0)(x, u).shape, (5, 5))
|
||||
self.assertEqual(fwd(fwd(f, 0), 0)(x, u).shape, (5, 5))
|
||||
self.assertEqual(rev(rev(f, 1), 1)(x, u).shape, (2, 2))
|
||||
self.assertEqual(rev(fwd(f, 1), 1)(x, u).shape, (2, 2))
|
||||
self.assertEqual(fwd(rev(f, 1), 1)(x, u).shape, (2, 2))
|
||||
self.assertEqual(fwd(fwd(f, 1), 1)(x, u).shape, (2, 2))
|
||||
|
||||
# Off-diagonal entries by reverse-mode on the outside
|
||||
self.assertEqual(rev(rev(f, 1), 0)(x, u).shape, (2, 5))
|
||||
self.assertEqual(rev(fwd(f, 1), 0)(x, u).shape, (2, 5))
|
||||
self.assertEqual(rev(rev(f, 0), 1)(x, u).shape, (5, 2))
|
||||
self.assertEqual(rev(fwd(f, 0), 1)(x, u).shape, (5, 2))
|
||||
|
||||
# Off-diagonal entries by forward-mode on the outside
|
||||
self.assertEqual(fwd(rev(f, 1), 0)(x, u).shape, (2, 5))
|
||||
self.assertEqual(fwd(fwd(f, 1), 0)(x, u).shape, (2, 5))
|
||||
self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2))
|
||||
self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2))
|
||||
|
||||
def test_disable_jit(self):
|
||||
effects = []
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user