mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove typing_extensions dependency
This commit is contained in:
parent
23261d78da
commit
4389216d0c
@ -26,10 +26,8 @@ import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import itertools as it
|
||||
from typing import (Any, Callable, Generator, Iterable, NamedTuple, Mapping,
|
||||
Optional, Sequence, Tuple, TypeVar, Union, overload, Dict,
|
||||
Hashable, List)
|
||||
from typing_extensions import Literal
|
||||
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
|
||||
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload)
|
||||
from warnings import warn
|
||||
|
||||
import numpy as np
|
||||
|
@ -17,8 +17,7 @@ import dataclasses
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
from typing import Any, Dict, Hashable, List, Optional, Tuple
|
||||
from typing_extensions import Protocol
|
||||
from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
|
@ -21,9 +21,8 @@ from functools import partial
|
||||
import itertools
|
||||
import time
|
||||
from typing import (
|
||||
Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
|
||||
Set, Tuple, List, Type, Union)
|
||||
from typing_extensions import Protocol
|
||||
Any, Callable, Dict, Iterable, Iterator, Optional, Protocol,
|
||||
Sequence, Set, Tuple, List, Type, Union)
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
@ -21,8 +21,7 @@
|
||||
|
||||
|
||||
import functools
|
||||
from typing import cast, overload, Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing_extensions import Literal
|
||||
from typing import cast, overload, Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -15,11 +15,10 @@
|
||||
import inspect
|
||||
import functools
|
||||
from functools import partial
|
||||
from typing import cast, Any, Callable, List, Optional, Tuple, TypeVar, Union, overload
|
||||
from typing import cast, Any, Callable, List, Literal, Optional, Tuple, TypeVar, Union, overload
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import Literal
|
||||
|
||||
import jax
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
|
@ -18,10 +18,9 @@ used in Keras and Sonnet.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Sequence, Tuple, Union
|
||||
from typing import Any, Literal, Protocol, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
|
@ -30,14 +30,13 @@ from functools import partial
|
||||
import operator
|
||||
import types
|
||||
from typing import (
|
||||
overload, Any, Callable, Dict, FrozenSet, List, Optional,
|
||||
Sequence, Tuple, TypeVar, Union)
|
||||
overload, Any, Callable, Dict, FrozenSet, List, Literal,
|
||||
Optional, Sequence, Tuple, TypeVar, Union)
|
||||
from textwrap import dedent as _dedent
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import opt_einsum
|
||||
from typing_extensions import Literal
|
||||
|
||||
import jax
|
||||
from jax import jit
|
||||
|
@ -18,8 +18,7 @@ from functools import partial
|
||||
import numpy as np
|
||||
import textwrap
|
||||
import operator
|
||||
from typing import Optional, Tuple, Union, cast, overload
|
||||
from typing_extensions import Literal
|
||||
from typing import Literal, Optional, Tuple, Union, cast, overload
|
||||
|
||||
import jax
|
||||
from jax import jit, custom_jvp
|
||||
|
@ -15,11 +15,10 @@
|
||||
import builtins
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import overload, Any, Callable, Optional, Sequence, Tuple, Union
|
||||
from typing import overload, Any, Callable, Literal, Optional, Sequence, Tuple, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import Literal
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
|
@ -12,8 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import overload, Optional, Tuple, Union
|
||||
from typing_extensions import Literal
|
||||
from typing import overload, Literal, Optional, Tuple, Union
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
|
@ -19,8 +19,7 @@ import numpy as np
|
||||
import scipy.linalg
|
||||
import textwrap
|
||||
import warnings
|
||||
from typing import cast, overload, Any, Optional, Tuple, Union
|
||||
from typing_extensions import Literal
|
||||
from typing import cast, overload, Any, Literal, Optional, Tuple, Union
|
||||
|
||||
import jax
|
||||
from jax import jit, vmap, jvp
|
||||
|
@ -33,8 +33,7 @@ from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple
|
||||
from typing_extensions import Protocol
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
|
@ -16,8 +16,7 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing_extensions import Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -14,8 +14,7 @@
|
||||
"""Module for state primitives."""
|
||||
from functools import partial
|
||||
|
||||
from typing import Any, List, Tuple, TypeVar, Union
|
||||
from typing_extensions import Protocol
|
||||
from typing import Any, List, Protocol, Tuple, TypeVar, Union
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
|
@ -26,8 +26,7 @@ https://github.com/google/jax/pull/11859/.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Sequence, Union
|
||||
from typing_extensions import Protocol
|
||||
from typing import Any, Protocol, Sequence, Union
|
||||
import numpy as np
|
||||
|
||||
from jax._src.basearray import Array
|
||||
|
@ -18,13 +18,11 @@ https://github.com/google/flax/tree/main/examples/sst2
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Optional
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from flax import linen as nn
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
|
||||
Array: TypeAlias = jnp.ndarray
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
def sequence_mask(lengths: Array, max_length: int) -> Array:
|
||||
|
@ -25,8 +25,7 @@ import itertools
|
||||
import re
|
||||
import typing
|
||||
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
|
||||
Sequence, Set, Tuple, Type, Union, FrozenSet)
|
||||
from typing_extensions import Protocol
|
||||
Protocol, Sequence, Set, Tuple, Type, Union, FrozenSet)
|
||||
import warnings
|
||||
|
||||
from jax import core
|
||||
|
@ -22,8 +22,7 @@ import itertools as it
|
||||
import operator
|
||||
import re
|
||||
from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
|
||||
Sequence, Set, Type, Tuple, Union)
|
||||
from typing_extensions import Protocol
|
||||
Protocol, Sequence, Set, Type, Tuple, Union)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user