Merge pull request #15179 from jakevdp:fix-jet-mypy

PiperOrigin-RevId: 518958675
This commit is contained in:
jax authors 2023-03-23 14:07:27 -07:00
commit d857187503

View File

@ -52,18 +52,16 @@ r"""Jet is an experimental module for higher-order automatic differentiation
`outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
"""
from typing import Callable, Any, Tuple
from typing import Any, Callable, Dict, Tuple
from functools import partial
import numpy as np
import jax
from jax import lax
from jax.interpreters import xla
import jax.numpy as jnp
from jax.experimental import pjit
from jax.interpreters import partial_eval as pe, pxla
from jax.interpreters import partial_eval as pe
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten,)
@ -265,7 +263,7 @@ zero_series = ZeroSeries()
register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series)
call_param_updaters = {}
call_param_updaters: Dict[core.Primitive, Callable[..., Any]] = {}
### rule definitions