mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix mypy issue in jax/experimental/jet.py
This commit is contained in:
parent
a9e48af260
commit
1286446b52
@ -52,18 +52,16 @@ r"""Jet is an experimental module for higher-order automatic differentiation
|
||||
`outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
|
||||
"""
|
||||
|
||||
from typing import Callable, Any, Tuple
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.interpreters import xla
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import partial_eval as pe, pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
treedef_is_leaf, tree_flatten, tree_unflatten,)
|
||||
|
||||
@ -265,7 +263,7 @@ zero_series = ZeroSeries()
|
||||
register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series)
|
||||
|
||||
|
||||
call_param_updaters = {}
|
||||
call_param_updaters: Dict[core.Primitive, Callable[..., Any]] = {}
|
||||
|
||||
|
||||
### rule definitions
|
||||
|
Loading…
x
Reference in New Issue
Block a user