Fix mypy issue in jax/experimental/jet.py

This commit is contained in:
Jake VanderPlas 2023-03-23 13:56:11 -07:00
parent a9e48af260
commit 1286446b52

View File

@ -52,18 +52,16 @@ r"""Jet is an experimental module for higher-order automatic differentiation
`outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
"""
from typing import Callable, Any, Tuple
from typing import Any, Callable, Dict, Tuple
from functools import partial
import numpy as np
import jax
from jax import lax
from jax.interpreters import xla
import jax.numpy as jnp
from jax.experimental import pjit
from jax.interpreters import partial_eval as pe, pxla
from jax.interpreters import partial_eval as pe
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten,)
@ -265,7 +263,7 @@ zero_series = ZeroSeries()
register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series)
call_param_updaters = {}
call_param_updaters: Dict[core.Primitive, Callable[..., Any]] = {}
### rule definitions