mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix transpose issue in jacfwd and jacrev
This commit is contained in:
parent
e788539e0a
commit
89349e5e6d
@ -64,7 +64,7 @@ def jacfwd(fun, x):
|
||||
fun = lu.wrap_init(fun)
|
||||
pushfwd = partial(jvp, fun, (x,))
|
||||
std_basis = onp.eye(onp.size(x)).reshape((-1,) + onp.shape(x)),
|
||||
y, jac_flat = vmap(pushfwd, std_basis, out_axes=(None, 0))
|
||||
y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis)
|
||||
return jac_flat.reshape(onp.shape(y) + onp.shape(x))
|
||||
|
||||
@curry
|
||||
@ -72,7 +72,7 @@ def jacrev(fun, x):
|
||||
fun = lu.wrap_init(fun)
|
||||
y, pullback = vjp(fun, x)
|
||||
std_basis = onp.eye(onp.size(y)).reshape((-1,) + onp.shape(y))
|
||||
jac_flat, = vmap(pullback, std_basis, out_axes=onp.ndim(y))
|
||||
jac_flat, = vmap(pullback, out_axes=0)(std_basis)
|
||||
return jac_flat.reshape(onp.shape(y) + onp.shape(x))
|
||||
|
||||
def hessian(fun):
|
||||
|
@ -281,6 +281,7 @@ def moveaxis(sz, dst, src, x):
|
||||
else:
|
||||
return pack(map(partial(moveaxis, sz, dst, src), x))
|
||||
elif isinstance(aval, ShapedArray):
|
||||
dst = (dst % aval.ndim) if dst is not None and aval.ndim else dst
|
||||
if src == dst:
|
||||
return x
|
||||
else:
|
||||
|
@ -24,7 +24,7 @@ from jax import test_util as jtu
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.config import config
|
||||
from jax import jit, grad, device_get, device_put
|
||||
from jax import jit, grad, device_get, device_put, jacfwd, jacrev
|
||||
from jax.core import Primitive
|
||||
from jax.interpreters.partial_eval import def_abstract_eval
|
||||
from jax.interpreters.ad import defjvp
|
||||
@ -235,6 +235,18 @@ class APITest(jtu.JaxTestCase):
|
||||
assert isinstance(y2[1][1], onp.ndarray)
|
||||
assert onp.all(y2[1][1] == 3 * x)
|
||||
|
||||
def test_jacobian(self):
|
||||
R = onp.random.RandomState(0).randn
|
||||
A = R(4, 3)
|
||||
x = R(3)
|
||||
|
||||
f = lambda x: np.dot(A, x)
|
||||
assert onp.allclose(jacfwd(f)(x), A)
|
||||
assert onp.allclose(jacrev(f)(x), A)
|
||||
|
||||
f = lambda x: np.tanh(np.dot(A, x))
|
||||
assert onp.allclose(jacfwd(f)(x), jacrev(f)(x))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.config_with_absl()
|
||||
|
@ -24,7 +24,7 @@ import jax.numpy as np
|
||||
from jax import test_util as jtu
|
||||
from jax.abstract_arrays import ShapedArray
|
||||
from jax import lax
|
||||
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr
|
||||
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr, jacfwd, jacrev
|
||||
from jax.api import vmap
|
||||
from jax.config import config
|
||||
from jax.core import unit
|
||||
@ -272,6 +272,16 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
onp.moveaxis(z, 2, 0)], 2)
|
||||
self.assertAllClose(ans, expected_ans, check_dtypes=False)
|
||||
|
||||
def testJacobianIssue54(self):
|
||||
# test modeling the code in https://github.com/google/jax/issues/54
|
||||
|
||||
def func(xs):
|
||||
return np.array([x for x in xs])
|
||||
|
||||
xs = np.ones((5, 1))
|
||||
jacrev(func)(xs) # don't crash
|
||||
jacfwd(func)(xs) # don't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.config_with_absl()
|
||||
|
Loading…
x
Reference in New Issue
Block a user