mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use traced identity in jacobian std_basis
This commit is contained in:
parent
b0541802fa
commit
0957e81655
@ -1118,9 +1118,8 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
def _std_basis(pytree):
|
||||
leaves, _ = tree_flatten(pytree)
|
||||
ndim = sum(map(np.size, leaves))
|
||||
# TODO(mattjj): use a symbolic identity matrix here
|
||||
dtype = dtypes.result_type(*leaves)
|
||||
flat_basis = np.eye(ndim, dtype=dtype)
|
||||
flat_basis = jax.numpy.eye(ndim, dtype=dtype)
|
||||
return _unravel_array_into_pytree(pytree, 1, flat_basis)
|
||||
|
||||
def _unravel_array_into_pytree(pytree, axis, arr):
|
||||
|
Loading…
x
Reference in New Issue
Block a user