Remove typing_extensions dependency

This commit is contained in:
Jake VanderPlas 2022-12-05 15:42:26 -08:00
parent 23261d78da
commit 4389216d0c
19 changed files with 21 additions and 42 deletions

View File

@ -26,10 +26,8 @@ import functools
from functools import partial
import inspect
import itertools as it
from typing import (Any, Callable, Generator, Iterable, NamedTuple, Mapping,
Optional, Sequence, Tuple, TypeVar, Union, overload, Dict,
Hashable, List)
from typing_extensions import Literal
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload)
from warnings import warn
import numpy as np

View File

@ -17,8 +17,7 @@ import dataclasses
import inspect
import threading
from typing import Any, Dict, Hashable, List, Optional, Tuple
from typing_extensions import Protocol
from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
import jax.numpy as jnp
from jax import core

View File

@ -21,9 +21,8 @@ from functools import partial
import itertools
import time
from typing import (
Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
Set, Tuple, List, Type, Union)
from typing_extensions import Protocol
Any, Callable, Dict, Iterable, Iterator, Optional, Protocol,
Sequence, Set, Tuple, List, Type, Union)
import logging
import os
import re

View File

@ -21,8 +21,7 @@
import functools
from typing import cast, overload, Any, Dict, List, Optional, Set, Tuple, Union
from typing_extensions import Literal
from typing import cast, overload, Any, Dict, List, Literal, Optional, Set, Tuple, Union
import numpy as np

View File

@ -15,11 +15,10 @@
import inspect
import functools
from functools import partial
from typing import cast, Any, Callable, List, Optional, Tuple, TypeVar, Union, overload
from typing import cast, Any, Callable, List, Literal, Optional, Tuple, TypeVar, Union, overload
import warnings
import numpy as np
from typing_extensions import Literal
import jax
from jax._src.numpy import lax_numpy as jnp

View File

@ -18,10 +18,9 @@ used in Keras and Sonnet.
"""
from typing import Any, Sequence, Tuple, Union
from typing import Any, Literal, Protocol, Sequence, Tuple, Union
import numpy as np
from typing_extensions import Literal, Protocol
import jax.numpy as jnp
from jax import lax

View File

@ -30,14 +30,13 @@ from functools import partial
import operator
import types
from typing import (
overload, Any, Callable, Dict, FrozenSet, List, Optional,
Sequence, Tuple, TypeVar, Union)
overload, Any, Callable, Dict, FrozenSet, List, Literal,
Optional, Sequence, Tuple, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings
import numpy as np
import opt_einsum
from typing_extensions import Literal
import jax
from jax import jit

View File

@ -18,8 +18,7 @@ from functools import partial
import numpy as np
import textwrap
import operator
from typing import Optional, Tuple, Union, cast, overload
from typing_extensions import Literal
from typing import Literal, Optional, Tuple, Union, cast, overload
import jax
from jax import jit, custom_jvp

View File

@ -15,11 +15,10 @@
import builtins
from functools import partial
import operator
from typing import overload, Any, Callable, Optional, Sequence, Tuple, Union
from typing import overload, Any, Callable, Literal, Optional, Sequence, Tuple, Union
import warnings
import numpy as np
from typing_extensions import Literal
from jax import core
from jax import lax

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import overload, Optional, Tuple, Union
from typing_extensions import Literal
from typing import overload, Literal, Optional, Tuple, Union
import jax
from jax import lax

View File

@ -19,8 +19,7 @@ import numpy as np
import scipy.linalg
import textwrap
import warnings
from typing import cast, overload, Any, Optional, Tuple, Union
from typing_extensions import Literal
from typing import cast, overload, Any, Literal, Optional, Tuple, Union
import jax
from jax import jit, vmap, jvp

View File

@ -33,8 +33,7 @@ from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple
from typing_extensions import Protocol
from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple
import jax
from jax import core

View File

@ -16,8 +16,7 @@ from __future__ import annotations
import dataclasses
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing_extensions import Protocol
from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
import numpy as np

View File

@ -14,8 +14,7 @@
"""Module for state primitives."""
from functools import partial
from typing import Any, List, Tuple, TypeVar, Union
from typing_extensions import Protocol
from typing import Any, List, Protocol, Tuple, TypeVar, Union
from jax import core
from jax import lax

View File

@ -26,8 +26,7 @@ https://github.com/google/jax/pull/11859/.
from __future__ import annotations
from typing import Any, Sequence, Union
from typing_extensions import Protocol
from typing import Any, Protocol, Sequence, Union
import numpy as np
from jax._src.basearray import Array

View File

@ -18,13 +18,11 @@ https://github.com/google/flax/tree/main/examples/sst2
import functools
from typing import Any, Callable, Optional
from typing_extensions import TypeAlias
from flax import linen as nn
import jax
from jax import numpy as jnp
Array: TypeAlias = jnp.ndarray
from jax._src.typing import Array
def sequence_mask(lengths: Array, max_length: int) -> Array:

View File

@ -25,8 +25,7 @@ import itertools
import re
import typing
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
Sequence, Set, Tuple, Type, Union, FrozenSet)
from typing_extensions import Protocol
Protocol, Sequence, Set, Tuple, Type, Union, FrozenSet)
import warnings
from jax import core

View File

@ -22,8 +22,7 @@ import itertools as it
import operator
import re
from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
Sequence, Set, Type, Tuple, Union)
from typing_extensions import Protocol
Protocol, Sequence, Set, Type, Tuple, Union)
import numpy as np

View File

@ -67,7 +67,6 @@ setup(
'numpy>=1.20',
'opt_einsum',
'scipy>=1.5',
'typing_extensions',
],
extras_require={
# Minimum jaxlib version; used in testing.