fix transpose issue in jacfwd and jacrev

This commit is contained in:
Matthew Johnson 2018-12-11 16:24:20 -08:00
parent e788539e0a
commit 89349e5e6d
4 changed files with 27 additions and 4 deletions

View File

@ -64,7 +64,7 @@ def jacfwd(fun, x):
fun = lu.wrap_init(fun)
pushfwd = partial(jvp, fun, (x,))
std_basis = onp.eye(onp.size(x)).reshape((-1,) + onp.shape(x)),
y, jac_flat = vmap(pushfwd, std_basis, out_axes=(None, 0))
y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis)
return jac_flat.reshape(onp.shape(y) + onp.shape(x))
@curry
@ -72,7 +72,7 @@ def jacrev(fun, x):
fun = lu.wrap_init(fun)
y, pullback = vjp(fun, x)
std_basis = onp.eye(onp.size(y)).reshape((-1,) + onp.shape(y))
jac_flat, = vmap(pullback, std_basis, out_axes=onp.ndim(y))
jac_flat, = vmap(pullback, out_axes=0)(std_basis)
return jac_flat.reshape(onp.shape(y) + onp.shape(x))
def hessian(fun):

View File

@ -281,6 +281,7 @@ def moveaxis(sz, dst, src, x):
else:
return pack(map(partial(moveaxis, sz, dst, src), x))
elif isinstance(aval, ShapedArray):
dst = (dst % aval.ndim) if dst is not None and aval.ndim else dst
if src == dst:
return x
else:

View File

@ -24,7 +24,7 @@ from jax import test_util as jtu
import jax.numpy as np
from jax.config import config
from jax import jit, grad, device_get, device_put
from jax import jit, grad, device_get, device_put, jacfwd, jacrev
from jax.core import Primitive
from jax.interpreters.partial_eval import def_abstract_eval
from jax.interpreters.ad import defjvp
@ -235,6 +235,18 @@ class APITest(jtu.JaxTestCase):
assert isinstance(y2[1][1], onp.ndarray)
assert onp.all(y2[1][1] == 3 * x)
def test_jacobian(self):
R = onp.random.RandomState(0).randn
A = R(4, 3)
x = R(3)
f = lambda x: np.dot(A, x)
assert onp.allclose(jacfwd(f)(x), A)
assert onp.allclose(jacrev(f)(x), A)
f = lambda x: np.tanh(np.dot(A, x))
assert onp.allclose(jacfwd(f)(x), jacrev(f)(x))
if __name__ == '__main__':
config.config_with_absl()

View File

@ -24,7 +24,7 @@ import jax.numpy as np
from jax import test_util as jtu
from jax.abstract_arrays import ShapedArray
from jax import lax
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr, jacfwd, jacrev
from jax.api import vmap
from jax.config import config
from jax.core import unit
@ -272,6 +272,16 @@ class BatchingTest(jtu.JaxTestCase):
onp.moveaxis(z, 2, 0)], 2)
self.assertAllClose(ans, expected_ans, check_dtypes=False)
def testJacobianIssue54(self):
# test modeling the code in https://github.com/google/jax/issues/54
def func(xs):
return np.array([x for x in xs])
xs = np.ones((5, 1))
jacrev(func)(xs) # don't crash
jacfwd(func)(xs) # don't crash
if __name__ == '__main__':
config.config_with_absl()