diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index f1dfc2b2b..4159e8d29 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -48,7 +48,7 @@ DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424, import functools from typing import ( - Any, Callable, Dict, NamedTuple, NewType, List, Optional, Sequence, Tuple, Union) + Any, Callable, Dict, NamedTuple, List, Optional, Sequence, Tuple, Union) import numpy as np