Use traced identity in jacobian std_basis

This commit is contained in:
Jake VanderPlas 2021-09-22 16:08:18 -07:00
parent b0541802fa
commit 0957e81655

View File

@ -1118,9 +1118,8 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def _std_basis(pytree):
leaves, _ = tree_flatten(pytree)
ndim = sum(map(np.size, leaves))
# TODO(mattjj): use a symbolic identity matrix here
dtype = dtypes.result_type(*leaves)
flat_basis = np.eye(ndim, dtype=dtype)
flat_basis = jax.numpy.eye(ndim, dtype=dtype)
return _unravel_array_into_pytree(pytree, 1, flat_basis)
def _unravel_array_into_pytree(pytree, axis, arr):