Run pyupgrade --py310-plus.

Also apply manual fixes to import sorting and unused imports.
This commit is contained in:
Peter Hawkins 2024-06-26 14:44:52 -04:00
parent cdfe2df384
commit 7f4ef63cd8
140 changed files with 387 additions and 379 deletions

View File

@ -32,7 +32,7 @@ def extract_filename(path):
def generate_final_report(shell=False, env_vars={}):
env = os.environ
env = {**env, **env_vars}
cmd = ["pytest_html_merger", "-i", '{}'.format(base_dir), "-o", '{}/final_compiled_report.html'.format(base_dir)]
cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
result = subprocess.run(cmd,
shell=shell,
capture_output=True,
@ -90,7 +90,7 @@ def run_test(testmodule, gpu_tokens):
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
}
testfile = extract_filename(testmodule)
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule]
cmd = ["python3", "-m", "pytest", f'--html={base_dir}/{testfile}_log.html', "--reruns", "3", "-x", testmodule]
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
with GPU_LOCK:
gpu_tokens.append(target_gpu)

View File

@ -14,7 +14,6 @@
from functools import partial, reduce
import math
from typing import Tuple
import jax
import jax.numpy as jnp
@ -325,9 +324,9 @@ class RmsNormFwdClass:
return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
@staticmethod
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
result_infos : Tuple[jax._src.core.ShapedArray]):
def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
result_infos: tuple[jax._src.core.ShapedArray, ...]):
del eps, result_infos # Not needed for this example.
x_info, weight_info = arg_infos
assert len(x_info.shape) == 3
@ -340,9 +339,9 @@ class RmsNormFwdClass:
return (output_sharding, invvar_sharding)
@staticmethod
def partition(eps : float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
def partition(eps: float, mesh : jax.sharding.Mesh,
arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
del result_infos # Not needed for this example.
x_info, weight_info = arg_infos
assert len(x_info.shape) == 3
@ -395,9 +394,9 @@ class RmsNormBwdClass:
return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
@staticmethod
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
result_infos : Tuple[jax._src.core.ShapedArray]):
def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
result_infos: tuple[jax._src.core.ShapedArray, ...]):
del eps, result_infos # Not needed for this example.
g_info, invvar_info, x_info, weight_info = arg_infos
assert len(g_info.shape) == 3
@ -411,9 +410,9 @@ class RmsNormBwdClass:
return (output_sharding, invvar_sharding, output_sharding, )
@staticmethod
def partition(eps : float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
def partition(eps: float, mesh : jax.sharding.Mesh,
arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
del result_infos # Not needed for this example.
g_info, invvar_info, x_info, weight_info = arg_infos
assert len(g_info.shape) == 3

View File

@ -167,15 +167,15 @@
"source": [
"from collections.abc import Sequence\n",
"from contextlib import contextmanager\n",
"from typing import Optional, Any\n",
"from typing import Any\n",
"\n",
"class MainTrace(NamedTuple):\n",
" level: int\n",
" trace_type: type['Trace']\n",
" global_data: Optional[Any]\n",
" global_data: Any | None\n",
"\n",
"trace_stack: list[MainTrace] = []\n",
"dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3\n",
"dynamic_trace: MainTrace | None = None # to be employed in Part 3\n",
"\n",
"@contextmanager\n",
"def new_main(trace_type: type['Trace'], global_data=None):\n",
@ -912,7 +912,7 @@
"source": [
"from collections.abc import Hashable, Iterable, Iterator\n",
"import itertools as it\n",
"from typing import Callable\n",
"from collections.abc import Callable\n",
"\n",
"class NodeType(NamedTuple):\n",
" name: str\n",
@ -1651,7 +1651,7 @@
"source": [
"from functools import lru_cache\n",
"\n",
"@lru_cache() # ShapedArrays are hashable\n",
"@lru_cache # ShapedArrays are hashable\n",
"def make_jaxpr_v1(f, *avals_in):\n",
" avals_in, in_tree = tree_flatten(avals_in)\n",
" f, out_tree = flatten_fun(f, in_tree)\n",
@ -1803,7 +1803,7 @@
" finally:\n",
" dynamic_trace = prev_dynamic_trace\n",
"\n",
"@lru_cache()\n",
"@lru_cache\n",
"def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n",
" ) -> tuple[Jaxpr, list[Any], PyTreeDef]:\n",
" avals_in, in_tree = tree_flatten(avals_in)\n",
@ -1994,7 +1994,7 @@
" return execute(*args)\n",
"impl_rules[xla_call_p] = xla_call_impl\n",
"\n",
"@lru_cache()\n",
"@lru_cache\n",
"def xla_callable(hashable_jaxpr: IDHashable,\n",
" hashable_consts: tuple[IDHashable, ...]):\n",
" jaxpr: Jaxpr = hashable_jaxpr.val\n",
@ -2227,7 +2227,7 @@
" return primals_out, tangents_out\n",
"jvp_rules[xla_call_p] = xla_call_jvp_rule\n",
"\n",
"@lru_cache()\n",
"@lru_cache\n",
"def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:\n",
" def jvp_traceable(*primals_and_tangents):\n",
" n = len(primals_and_tangents) // 2\n",
@ -2253,7 +2253,7 @@
" return outs, [0] * len(outs)\n",
"vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
"\n",
"@lru_cache()\n",
"@lru_cache\n",
"def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]\n",
" ) -> tuple[Jaxpr, list[Any]]:\n",
" vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
@ -2638,7 +2638,7 @@
"source": [
"class PartialVal(NamedTuple):\n",
" aval: ShapedArray\n",
" const: Optional[Any]\n",
" const: Any | None\n",
"\n",
" @classmethod\n",
" def known(cls, val: Any):\n",
@ -2727,7 +2727,7 @@
"source": [
"class PartialEvalTracer(Tracer):\n",
" pval: PartialVal\n",
" recipe: Optional[JaxprRecipe]\n",
" recipe: JaxprRecipe | None\n",
"\n",
" def __init__(self, trace, pval, recipe):\n",
" self._trace = trace\n",
@ -2974,7 +2974,7 @@
"partial_eval_rules[xla_call_p] = xla_call_partial_eval\n",
"\n",
"def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],\n",
" instantiate: Optional[list[bool]] = None,\n",
" instantiate: list[bool] | None = None,\n",
" ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:\n",
" env: dict[Var, bool] = {}\n",
" residuals: set[Var] = set()\n",
@ -3271,7 +3271,7 @@
" return [next(outs) if undef else None for undef in undef_primals]\n",
"transpose_rules[xla_call_p] = xla_call_transpose_rule\n",
"\n",
"@lru_cache()\n",
"@lru_cache\n",
"def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]\n",
" ) -> tuple[Jaxpr, list[Any]]:\n",
" avals_in, avals_out = typecheck_jaxpr(jaxpr)\n",

View File

@ -148,15 +148,15 @@ more descriptive.
```{code-cell}
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Optional, Any
from typing import Any
class MainTrace(NamedTuple):
level: int
trace_type: type['Trace']
global_data: Optional[Any]
global_data: Any | None
trace_stack: list[MainTrace] = []
dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3
dynamic_trace: MainTrace | None = None # to be employed in Part 3
@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
@ -705,7 +705,7 @@ class Store:
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from typing import Callable
from collections.abc import Callable
class NodeType(NamedTuple):
name: str
@ -1295,7 +1295,7 @@ transformation and a pretty-printer:
```{code-cell}
from functools import lru_cache
@lru_cache() # ShapedArrays are hashable
@lru_cache # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree)
@ -1415,7 +1415,7 @@ def new_dynamic(main: MainTrace):
finally:
dynamic_trace = prev_dynamic_trace
@lru_cache()
@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
) -> tuple[Jaxpr, list[Any], PyTreeDef]:
avals_in, in_tree = tree_flatten(avals_in)
@ -1564,7 +1564,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
return execute(*args)
impl_rules[xla_call_p] = xla_call_impl
@lru_cache()
@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
hashable_consts: tuple[IDHashable, ...]):
jaxpr: Jaxpr = hashable_jaxpr.val
@ -1734,7 +1734,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache()
@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2
@ -1755,7 +1755,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache()
@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
) -> tuple[Jaxpr, list[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
@ -2065,7 +2065,7 @@ be either known or unknown:
```{code-cell}
class PartialVal(NamedTuple):
aval: ShapedArray
const: Optional[Any]
const: Any | None
@classmethod
def known(cls, val: Any):
@ -2129,7 +2129,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
```{code-cell}
class PartialEvalTracer(Tracer):
pval: PartialVal
recipe: Optional[JaxprRecipe]
recipe: JaxprRecipe | None
def __init__(self, trace, pval, recipe):
self._trace = trace
@ -2329,7 +2329,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
instantiate: Optional[list[bool]] = None,
instantiate: list[bool] | None = None,
) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
env: dict[Var, bool] = {}
residuals: set[Var] = set()
@ -2586,7 +2586,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache()
@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
) -> tuple[Jaxpr, list[Any]]:
avals_in, avals_out = typecheck_jaxpr(jaxpr)

View File

@ -138,15 +138,15 @@ def bind1(prim, *args, **params):
# +
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Optional, Any
from typing import Any
class MainTrace(NamedTuple):
level: int
trace_type: type['Trace']
global_data: Optional[Any]
global_data: Any | None
trace_stack: list[MainTrace] = []
dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3
dynamic_trace: MainTrace | None = None # to be employed in Part 3
@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
@ -697,7 +697,7 @@ class Store:
# + tags=["hide-input"]
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from typing import Callable
from collections.abc import Callable
class NodeType(NamedTuple):
name: str
@ -1297,7 +1297,7 @@ abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
# +
from functools import lru_cache
@lru_cache() # ShapedArrays are hashable
@lru_cache # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree)
@ -1412,7 +1412,7 @@ def new_dynamic(main: MainTrace):
finally:
dynamic_trace = prev_dynamic_trace
@lru_cache()
@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
) -> tuple[Jaxpr, list[Any], PyTreeDef]:
avals_in, in_tree = tree_flatten(avals_in)
@ -1556,7 +1556,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
return execute(*args)
impl_rules[xla_call_p] = xla_call_impl
@lru_cache()
@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
hashable_consts: tuple[IDHashable, ...]):
jaxpr: Jaxpr = hashable_jaxpr.val
@ -1728,7 +1728,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache()
@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2
@ -1749,7 +1749,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache()
@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
) -> tuple[Jaxpr, list[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
@ -2057,7 +2057,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
class PartialVal(NamedTuple):
aval: ShapedArray
const: Optional[Any]
const: Any | None
@classmethod
def known(cls, val: Any):
@ -2121,7 +2121,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
pval: PartialVal
recipe: Optional[JaxprRecipe]
recipe: JaxprRecipe | None
def __init__(self, trace, pval, recipe):
self._trace = trace
@ -2322,7 +2322,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
instantiate: Optional[list[bool]] = None,
instantiate: list[bool] | None = None,
) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
env: dict[Var, bool] = {}
residuals: set[Var] = set()
@ -2585,7 +2585,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache()
@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
) -> tuple[Jaxpr, list[Any]]:
avals_in, avals_out = typecheck_jaxpr(jaxpr)

View File

@ -14,11 +14,11 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import functools
from functools import partial
import logging
from typing import Any, Callable
from typing import Any
import types
import numpy as np

View File

@ -13,8 +13,9 @@
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable
import types
from typing import Any, Callable, TypeVar
from typing import Any, TypeVar
from jax._src import core
from jax._src import traceback_util

View File

@ -23,12 +23,12 @@ arrays.
from __future__ import annotations
import collections
from collections.abc import Generator, Hashable, Iterable, Sequence
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from functools import partial, lru_cache
import inspect
import math
import typing
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
from typing import (Any, Literal, NamedTuple, TypeVar, overload,
cast)
import weakref

View File

@ -14,11 +14,11 @@
from __future__ import annotations
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
import inspect
import operator
from functools import partial, lru_cache
from typing import Any, Callable, Type
from typing import Any
import numpy as np
@ -713,6 +713,6 @@ class _HashableByObjectId:
def __eq__(self, other):
return self.val is other.val
def register_class_with_attrs(t: Type) -> None:
def register_class_with_attrs(t: type) -> None:
_class_with_attrs.add(t)
_class_with_attrs: set[Type] = set()
_class_with_attrs: set[type] = set()

View File

@ -15,12 +15,12 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import enum
import functools
import math
import operator as op
from typing import Any, Callable, TYPE_CHECKING, cast
from typing import Any, TYPE_CHECKING, cast
from jax._src import abstract_arrays
from jax._src import api

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Protocol, Sequence
from collections.abc import Sequence
from typing import Any, Protocol
import jax
from jax._src import random
from jax._src.typing import Array, ArrayLike

View File

@ -14,11 +14,11 @@
"""Module for JAX callbacks."""
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
import logging
from typing import Any, Callable
from typing import Any
import jax
from jax._src import core

View File

@ -13,11 +13,11 @@
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
import itertools as it
from typing import Callable, TypeVar, Any, Union
from typing import TypeVar, Any, Union
import numpy as np

View File

@ -16,7 +16,6 @@ import os
from jax import version
from jax._src import config
from jax._src import hardware_utils
from typing import Optional
running_in_cloud_tpu_vm: bool = False
@ -35,7 +34,7 @@ def maybe_import_libtpu():
return libtpu
def get_tpu_library_path() -> Optional[str]:
def get_tpu_library_path() -> str | None:
path_from_env = os.getenv("TPU_LIBRARY_PATH")
if path_from_env is not None and os.path.isfile(path_from_env):
return path_from_env

View File

@ -21,7 +21,7 @@ import logging
import os
import tempfile
import time
from typing import Any, Optional
from typing import Any
import warnings
from jax._src import compilation_cache
@ -393,7 +393,7 @@ def _share_fdo_profiles(
backend: xc.Client,
global_client: lib.xla_extension.DistributedRuntimeClient,
min_process_id
) -> Optional[bytes]:
) -> bytes | None:
sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value
fdo_profile = compile_options.executable_build_options.fdo_profile

View File

@ -14,7 +14,7 @@
from __future__ import annotations
from collections.abc import Hashable, Iterator, Sequence
from collections.abc import Callable, Hashable, Iterator, Sequence
import contextlib
import functools
import itertools
@ -22,9 +22,7 @@ import logging
import os
import sys
import threading
from typing import (
Any, Callable, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast,
)
from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast
from jax._src import lib
from jax._src.lib import jax_jit

View File

@ -14,8 +14,8 @@
from __future__ import annotations
from collections import Counter, defaultdict, deque, namedtuple
from collections.abc import (Collection, Generator, Hashable, Iterable,
Iterator, Set, Sequence, MutableSet,
from collections.abc import (Callable, Collection, Generator, Hashable,
Iterable, Iterator, Set, Sequence, MutableSet,
MutableMapping)
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
@ -28,7 +28,7 @@ import math
import operator
import threading
import types
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, TypeVar,
from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar,
cast, overload, Union)
import warnings
from weakref import ref

View File

@ -15,7 +15,6 @@
from enum import Enum
from functools import partial, reduce
import operator
from typing import Optional
import json
import jax
@ -927,10 +926,10 @@ _dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_atte
def dot_product_attention(query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
mask: Optional[Array] = None,
q_seqlen: Optional[Array] = None,
kv_seqlen: Optional[Array] = None,
bias: Array | None = None,
mask: Array | None = None,
q_seqlen: Array | None = None,
kv_seqlen: Array | None = None,
*,
scale: float = 1.0,
mask_type: MaskType = MaskType.NO_MASK,

View File

@ -14,9 +14,9 @@
from __future__ import annotations
from collections.abc import Callable
import functools
import operator
from typing import Callable
from jax import lax
from jax._src import api

View File

@ -14,11 +14,11 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
from functools import update_wrapper, reduce, partial
import inspect
from typing import Any, Callable, Generic, TypeVar
from typing import Any, Generic, TypeVar
from jax._src import config
from jax._src import core

View File

@ -14,8 +14,9 @@
from __future__ import annotations
from collections.abc import Callable
import functools
from typing import Any, Callable
from typing import Any
from jax._src import ad_util
from jax._src import api_util

View File

@ -16,12 +16,12 @@
from __future__ import annotations
import importlib.util
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import functools
import logging
import string
import sys
from typing import Any, Callable, Union
from typing import Any, Union
import weakref
import numpy as np

View File

@ -16,13 +16,13 @@
from __future__ import annotations
import atexit
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
import contextlib
import dataclasses
from functools import partial
import itertools
import time
from typing import Any, Callable, NamedTuple
from typing import Any, NamedTuple
import logging
import threading

View File

@ -17,13 +17,13 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import copy
import dataclasses
import functools
import itertools
import re
from typing import Any, Callable, Union
from typing import Any, Union
import warnings
from absl import logging

View File

@ -16,9 +16,9 @@
from __future__ import annotations
from typing import Callable, TypeVar
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
from typing import TypeVar
try:
import flatbuffers

View File

@ -21,7 +21,7 @@ import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
class PyTreeDefKind(object):
class PyTreeDefKind:
leaf = 0
none = 1
tuple = 2
@ -29,12 +29,12 @@ class PyTreeDefKind(object):
dict = 4
class AbstractValueKind(object):
class AbstractValueKind:
shapedArray = 0
abstractToken = 1
class DType(object):
class DType:
bool = 0
i8 = 1
i16 = 2
@ -60,18 +60,18 @@ class DType(object):
f0 = 22
class ShardingKind(object):
class ShardingKind:
unspecified = 0
hlo_sharding = 1
class DisabledSafetyCheckKind(object):
class DisabledSafetyCheckKind:
platform = 0
custom_call = 1
shape_assertions = 2
class PyTreeDef(object):
class PyTreeDef:
__slots__ = ['_tab']
@classmethod
@ -163,7 +163,7 @@ def PyTreeDefEnd(builder):
class AbstractValue(object):
class AbstractValue:
__slots__ = ['_tab']
@classmethod
@ -235,7 +235,7 @@ def AbstractValueEnd(builder):
class Sharding(object):
class Sharding:
__slots__ = ['_tab']
@classmethod
@ -304,7 +304,7 @@ def ShardingEnd(builder):
class Effect(object):
class Effect:
__slots__ = ['_tab']
@classmethod
@ -340,7 +340,7 @@ def EffectEnd(builder):
class DisabledSafetyCheck(object):
class DisabledSafetyCheck:
__slots__ = ['_tab']
@classmethod
@ -386,7 +386,7 @@ def DisabledSafetyCheckEnd(builder):
class Exported(object):
class Exported:
__slots__ = ['_tab']
@classmethod

View File

@ -19,7 +19,7 @@ See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html
from __future__ import annotations
import enum
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
from enum import Enum
import functools
@ -28,7 +28,7 @@ import io
import copy
import operator as op
import tokenize
from typing import Any, Callable, Union, overload
from typing import Any, Union, overload
import warnings
import numpy as np

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from collections.abc import Hashable
from collections.abc import Callable, Hashable
from jax import Array

View File

@ -14,10 +14,9 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
import enum
from typing import Callable
import numpy as np

View File

@ -70,13 +70,13 @@ then update `test_custom_call_coverage`, and then update your `test_foo_call`:
from __future__ import annotations
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
import dataclasses
import datetime
import os
import re
import sys
from typing import Any, Callable
from typing import Any
from absl import logging

View File

@ -38,11 +38,11 @@ to fail. A Limitation is specific to a harness.
from __future__ import annotations
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
import operator
import os
from functools import partial
from typing import Any, Callable, NamedTuple, Union
from typing import Any, NamedTuple, Union
from absl import testing
import numpy as np

View File

@ -14,12 +14,12 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import contextlib
import functools
import itertools as it
from functools import partial
from typing import Any, Callable
from typing import Any
import jax
from jax._src import config

View File

@ -14,10 +14,10 @@
from __future__ import annotations
import collections
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
import dataclasses
from functools import partial
from typing import Any, Callable, Union
from typing import Any, Union
import numpy as np

View File

@ -16,7 +16,7 @@
from __future__ import annotations
import collections
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
import dataclasses
import functools
from functools import partial
@ -27,7 +27,7 @@ import os
import re
import types
import typing
from typing import Any, Callable, NamedTuple, Protocol, Union, cast as type_cast
from typing import Any, NamedTuple, Protocol, Union, cast as type_cast
import warnings
import numpy as np

View File

@ -14,13 +14,13 @@
from __future__ import annotations
from collections import namedtuple
from collections.abc import Sequence, Hashable
from collections.abc import Callable, Sequence, Hashable
from contextlib import contextmanager, AbstractContextManager
from functools import partial
import inspect
import itertools as it
import operator as op
from typing import Any, Callable, NamedTuple, Union
from typing import Any, NamedTuple, Union
from weakref import ref
import numpy as np

View File

@ -19,15 +19,14 @@ import enum
from contextlib import contextmanager
import collections
from collections import namedtuple
from collections.abc import Sequence, Iterable
from collections.abc import Callable, Sequence, Iterable, Iterator
import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import threading
from typing import Any, Callable, NamedTuple, TypeVar, Union, cast
from collections.abc import Iterator
from typing import Any, NamedTuple, TypeVar, Union, cast
import warnings
import numpy as np

View File

@ -17,12 +17,12 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
from functools import partial
import itertools as it
from typing import Any, Callable, Protocol, Union
from typing import Any, Protocol, Union
import numpy as np

View File

@ -17,11 +17,12 @@
from __future__ import annotations
from collections import Counter, defaultdict
from collections.abc import Callable
import gzip
import itertools
import json
import types
from typing import Any, Callable, Union
from typing import Any, Union
from jax._src import core
from jax._src import util

View File

@ -15,10 +15,10 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import os
from functools import partial
from typing import Any, Callable
from typing import Any
from jax._src import core
from jax._src import linear_util as lu

View File

@ -15,13 +15,13 @@
from __future__ import annotations
import collections
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import functools
from functools import partial
import inspect
import itertools
import operator
from typing import Any, Callable, TypeVar
from typing import Any, TypeVar
import jax
from jax.tree_util import tree_flatten, tree_unflatten

View File

@ -15,10 +15,10 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import functools
import operator
from typing import Any, Callable, Generic, TypeVar
from typing import Any, Generic, TypeVar
import jax.numpy as jnp
from jax import lax

View File

@ -14,12 +14,12 @@
"""Module for the loop primitives."""
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
import inspect
import itertools
import operator
from typing import Any, Callable, TypeVar
from typing import Any, TypeVar
import weakref
import jax

View File

@ -15,14 +15,14 @@
from __future__ import annotations
import builtins
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import enum
import functools
from functools import partial
import itertools
import math
import operator
from typing import Any, Callable, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
from typing import Any, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
import warnings
import numpy as np
@ -2986,10 +2986,10 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
m, k = lhs.shape
group_count, rk, n = rhs.shape
if k != rk:
raise TypeError("ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {} and {}.".format(k, rk))
raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
num_groups = group_sizes.shape[0]
if group_count != num_groups:
raise TypeError("ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {} and {}.".format(group_count, num_groups))
raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
return (m, n)
# DotDimensionNumbers used in the dot_general call for ragged_dot().

View File

@ -14,10 +14,11 @@
from __future__ import annotations
from collections.abc import Callable
import functools
from functools import partial
import math
from typing import Any, Callable, Literal, TypeVar, overload
from typing import Any, Literal, TypeVar, overload
import numpy as np

View File

@ -14,12 +14,12 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import enum
import operator
from functools import partial
import math
from typing import Callable, NamedTuple
from typing import NamedTuple
import weakref
import numpy as np

View File

@ -14,9 +14,8 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
from typing import Callable
import warnings
from jax import tree_util

View File

@ -14,9 +14,9 @@
"""A LazyLoader class."""
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import importlib
from typing import Any, Callable
from typing import Any
def attach(package_name: str, submodules: Sequence[str]) -> tuple[

View File

@ -63,8 +63,9 @@ data must be immutable, because it will be stored in function memoization tables
"""
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import Any, Callable, NamedTuple
from typing import Any, NamedTuple
import weakref
from jax._src import config

View File

@ -15,12 +15,12 @@
from __future__ import annotations
from collections import OrderedDict, abc
from collections.abc import Iterable, Sequence, Mapping
from collections.abc import Callable, Iterable, Sequence, Mapping
import contextlib
from functools import wraps, partial, partialmethod, lru_cache
import itertools as it
import math
from typing import Callable, Any, NamedTuple, Union, cast as type_cast
from typing import Any, NamedTuple, Union, cast as type_cast
import numpy as np

View File

@ -27,13 +27,13 @@ from __future__ import annotations
import builtins
import collections
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
import importlib
import math
import operator
import types
from typing import (cast, overload, Any, Callable, Literal, NamedTuple,
from typing import (cast, overload, Any, Literal, NamedTuple,
Protocol, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings

View File

@ -15,11 +15,11 @@
from __future__ import annotations
import builtins
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
import math
import operator
from typing import overload, Any, Callable, Literal, Protocol, Union
from typing import overload, Any, Literal, Protocol, Union
import warnings
import numpy as np

View File

@ -16,10 +16,11 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import math
import operator
from typing import Any, Callable
from typing import Any
import jax
from jax._src.typing import Array, ArrayLike, DTypeLike

View File

@ -18,9 +18,9 @@ Implements ufuncs for jax.numpy.
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import operator
from typing import Callable
import warnings

View File

@ -13,11 +13,11 @@
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
import re
import textwrap
from typing import Any, Callable, NamedTuple, TypeVar
from typing import Any, NamedTuple, TypeVar
import warnings

View File

@ -13,10 +13,10 @@
# limitations under the License.
from __future__ import annotations
from collections.abc import Collection, Sequence
from collections.abc import Callable, Collection, Sequence
import functools
import re
from typing import Any, Callable
from typing import Any
from jax._src import api
from jax import lax

View File

@ -16,8 +16,8 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Callable, Union
from collections.abc import Callable, Sequence
from typing import Union
import warnings
import numpy as np

View File

@ -15,13 +15,13 @@
"""Module for pallas-core functionality."""
from __future__ import annotations
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
import copy
import contextlib
import dataclasses
import functools
import threading
from typing import Any, Callable, Union
from typing import Any, Union
import jax
from jax._src import api_util

View File

@ -78,7 +78,7 @@ class AbstractSemaphoreTy(dtypes.ExtendedDType):
return self.__class__ == other.__class__
def __hash__(self) -> int:
return hash((self.__class__))
return hash(self.__class__)
# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy
@ -109,7 +109,7 @@ class SemaphoreType(enum.Enum):
dtype = SemaphoreTy()
return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)
def get_aval(self) -> "AbstractMemoryRef":
def get_aval(self) -> AbstractMemoryRef:
return self(()).get_aval()
@dataclasses.dataclass(frozen=True)

View File

@ -15,11 +15,11 @@
"""Module for lowering JAX to Mosaic-compatible MLIR dialects."""
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
import string
from typing import Any, Callable
from typing import Any
import jax
from jax import core as jax_core

View File

@ -15,12 +15,13 @@
"""Module for emitting custom TPU pipelines within a Pallas call."""
from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import enum
import functools
import itertools
import operator
from typing import Optional, Union, Any, Sequence
from typing import Union, Any
import jax
from jax import lax
@ -201,12 +202,12 @@ class BufferedRef:
spec: pl.BlockSpec # static metadata
dtype: Any # static metadata
buffer_type: BufferType # static metadata
vmem_ref: Optional[REF]
accum_ref: Optional[REF]
current_slot: Optional[ArrayRef]
next_slot: Optional[ArrayRef]
sem_recv: Optional[SemaphoreType]
sem_send: Optional[SemaphoreType]
vmem_ref: REF | None
accum_ref: REF | None
current_slot: ArrayRef | None
next_slot: ArrayRef | None
sem_recv: SemaphoreType | None
sem_send: SemaphoreType | None
def tree_flatten(self):
return ((self.vmem_ref, self.accum_ref, self.current_slot,
@ -218,7 +219,7 @@ class BufferedRef:
return cls(*meta, *data)
@classmethod
def create(cls, spec, dtype, buffer_type) -> 'BufferedRef':
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
"""Create a BufferedRef.
Args:
@ -810,9 +811,9 @@ def _partition_grid(
if isinstance(grid[i], int) and grid[i] % num_cores == 0
}
if divisible_dimensions:
first_divisible_dimension, *_ = [
first_divisible_dimension, *_ = (
i for i in range(len(dimension_semantics)) if i in divisible_dimensions
]
)
partitioned_dim_size = grid[first_divisible_dimension] // num_cores
partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size
new_grid = jax_util.tuple_update(
@ -828,11 +829,11 @@ def _partition_grid(
# potentially divide it more evenly
largest_parallel_dimension = max(grid[i] for i in parallel_dimensions
if isinstance(grid[i], int)) # type: ignore
partition_dimension, *_ = [
partition_dimension, *_ = (
i
for i, d in enumerate(grid)
if isinstance(d, int) and d == largest_parallel_dimension
]
)
base_num_iters, rem = divmod(grid[partition_dimension], num_cores)
assert rem > 0, rem
# We have some remainder iterations that we need to assign somewhere. We

View File

@ -15,9 +15,10 @@
"""Module for Pallas:TPU-specific JAX primitives and functions."""
from __future__ import annotations
from collections.abc import Callable
import dataclasses
import enum
from typing import Any, Callable
from typing import Any
import jax
from jax._src import api_util

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
from collections.abc import Callable
import jax
import numpy as np
@ -172,7 +173,7 @@ def sample_block(sampler_fn: SampleFnType,
block_size: Shape,
tile_size: Shape,
total_size: Shape,
block_index: Optional[tuple[typing.ArrayLike, ...]] = None,
block_index: tuple[typing.ArrayLike, ...] | None = None,
**kwargs) -> jax.Array:
"""Samples a block of random values with invariance guarantees.

View File

@ -15,10 +15,10 @@
"""Module for calling pallas functions from JAX."""
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial, reduce
import itertools
from typing import Any, Callable
from typing import Any
import jax
from jax import api_util

View File

@ -16,12 +16,12 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
import math
import operator
from typing import Any, Callable, TypeVar
from typing import Any, TypeVar
import jax
from jax import lax

View File

@ -15,7 +15,7 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Sequence, Iterable
from collections.abc import Callable, Sequence, Iterable
import dataclasses
from functools import partial
import inspect
@ -23,7 +23,7 @@ import itertools as it
import logging
import operator as op
import weakref
from typing import Callable, NamedTuple, Any, Union, Optional, cast
from typing import NamedTuple, Any, Union, cast
import threading
import warnings
@ -245,7 +245,7 @@ def _need_to_rebuild_with_fdo(pgle_profiler):
def _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
consts, abstracted_axes, pgle_profiler
) -> Optional[pxla.MeshExecutableFastpathData]:
) -> pxla.MeshExecutableFastpathData | None:
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
use_fastpath = (
@ -608,7 +608,7 @@ def _infer_params_impl(
assert None not in in_shardings_leaves
assert None not in out_shardings_leaves
in_type: Union[core.InputType, tuple[core.AbstractValue, ...]]
in_type: core.InputType | tuple[core.AbstractValue, ...]
if config.dynamic_shapes.value:
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
in_avals = tuple(a for a, e in in_type if e)

View File

@ -13,11 +13,11 @@
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, NamedTuple
from typing import Any, NamedTuple
import numpy as np

View File

@ -14,6 +14,7 @@
from __future__ import annotations
from collections.abc import Callable
from contextlib import contextmanager
from functools import wraps
import glob
@ -24,7 +25,7 @@ import logging
import os
import socketserver
import threading
from typing import Callable, List, Optional, Union, Any
from typing import Any
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
@ -210,7 +211,7 @@ def stop_trace():
_profile_state.reset()
def stop_and_get_fdo_profile() -> Union[bytes, str]:
def stop_and_get_fdo_profile() -> bytes | str:
"""Stops the currently-running profiler trace and export fdo_profile.
Currently, this is only supported for GPU.
@ -391,10 +392,10 @@ class PGLEProfiler:
self.percentile: int = percentile
self.collected_fdo: str | None = None
self.called_times: int = 0
self.fdo_profiles: List[Any] = []
self.fdo_profiles: list[Any] = []
self.current_session: xla_client.profiler.ProfilerSession | None = None
def consume_fdo_profile(self) -> Optional[str]:
def consume_fdo_profile(self) -> str | None:
if self.collected_fdo is not None:
return self.collected_fdo

View File

@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import functools
import itertools
import operator
from typing import Callable
from jax._src import api
from jax._src import util

View File

@ -15,8 +15,9 @@
from __future__ import annotations
from typing import Callable, NamedTuple
from collections.abc import Callable
from functools import partial
from typing import NamedTuple
import jax
import jax.numpy as jnp

View File

@ -15,8 +15,9 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import Callable, NamedTuple
from typing import NamedTuple
import jax
import jax.numpy as jnp

View File

@ -14,8 +14,8 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Callable
from collections.abc import Callable, Mapping
from typing import Any
import jax
from jax._src.scipy.optimize.bfgs import minimize_bfgs

View File

@ -14,11 +14,10 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
import math
import operator
from typing import Callable
import warnings
import numpy as np

View File

@ -18,9 +18,10 @@ An implementation of sourcemaps following `TC39 <https://tc39.es/source-map>`_.
from __future__ import annotations
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
import json
from typing import Iterable, Sequence, Union
from typing import Union
# A Segment encodes how parts in the generated source relate to the original source.
# Each segment is made up of 1, 4 or 5 variable-length fields. For their semantics see

View File

@ -315,7 +315,7 @@ class XlaLowering(Lowering):
def hlo(self) -> xc.XlaComputation:
"""Return an HLO representation of this computation."""
hlo = self.stablehlo()
m: Union[str, bytes]
m: str | bytes
m = mlir.module_to_bytecode(hlo)
return xla_extension.mlir.mlir_module_to_xla_computation(
m, use_tuple_args=self.compile_args["tuple_args"])

View File

@ -14,11 +14,11 @@
"""Module for discharging state primitives."""
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
from functools import partial
import operator
from typing import Any, Callable, Protocol
from typing import Any, Protocol
import numpy as np

View File

@ -65,9 +65,9 @@ class Slice:
@classmethod
def tree_unflatten(cls, aux_data, children) -> Slice:
start, size = [
start, size = (
a if a is not None else b for a, b in zip(children, aux_data[:2])
]
)
return cls(start, size, aux_data[2])
@classmethod

View File

@ -16,7 +16,7 @@
from __future__ import annotations
import collections
from collections.abc import Generator, Iterable, Sequence
from collections.abc import Callable, Generator, Iterable, Sequence
from contextlib import ExitStack, contextmanager
import datetime
import functools
@ -28,7 +28,7 @@ import re
import sys
import tempfile
import textwrap
from typing import Any, Callable
from typing import Any
import unittest
import warnings
import zlib

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Callable
from collections.abc import Callable
from jax import jit, lax
import jax.numpy as jnp

View File

@ -19,13 +19,13 @@ from __future__ import annotations
import base64
import collections.abc
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
import io
import os
import time
from typing import Any, Callable
from typing import Any
import jax
from jax import core

View File

@ -14,12 +14,13 @@
from __future__ import annotations
from collections.abc import Callable
import functools
import os
import sys
import traceback
import types
from typing import Any, Callable, TypeVar, cast
from typing import Any, TypeVar, cast
from jax._src import config
from jax._src import util

View File

@ -13,7 +13,8 @@
# limitations under the License.
from __future__ import annotations
from typing import Any, Callable, Iterable, TypeVar, overload
from collections.abc import Callable, Iterable
from typing import Any, TypeVar, overload
from jax._src import tree_util

View File

@ -14,14 +14,14 @@
from __future__ import annotations
import collections
from collections.abc import Hashable, Iterable
from collections.abc import Callable, Hashable, Iterable, Sequence
from dataclasses import dataclass
import difflib
import functools
from functools import partial
import operator as op
import textwrap
from typing import Any, Callable, NamedTuple, Sequence, TypeVar, Union, overload
from typing import Any, NamedTuple, TypeVar, Union, overload
from jax._src import traceback_util
from jax._src.lib import pytree

View File

@ -15,14 +15,14 @@
from __future__ import annotations
import abc
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
import dataclasses
import functools
from functools import partial
import itertools as it
import logging
import operator
from typing import (Any, Callable, Generic, TypeVar, overload, TYPE_CHECKING, cast)
from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast)
import weakref
import numpy as np

View File

@ -21,7 +21,7 @@ XLA. There are also a handful of related casting utilities.
from __future__ import annotations
import atexit
from collections.abc import Mapping
from collections.abc import Callable, Mapping
import dataclasses
from functools import lru_cache, partial
import importlib
@ -32,7 +32,7 @@ import pkgutil
import platform as py_platform
import threading
import traceback
from typing import Any, Callable, Union
from typing import Any, Union
import warnings
from jax._src import config

View File

@ -91,7 +91,8 @@ Example Usage:
from __future__ import annotations
from typing import Any, Callable, NamedTuple
from collections.abc import Callable
from typing import Any, NamedTuple
from collections import namedtuple
import functools

View File

@ -15,7 +15,6 @@
from __future__ import annotations
import jax
from typing import Tuple
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config
@ -71,7 +70,7 @@ class __array_namespace_info__:
def dtypes(
self, *,
device: xc.Device | Sharding | None = None,
kind: str | Tuple[str, ...] | None = None):
kind: str | tuple[str, ...] | None = None):
# Array API supported dtypes are device-independent in JAX
del device
data_types = self._build_dtype_dict()

View File

@ -17,16 +17,15 @@ from __future__ import annotations
import abc
import asyncio
from collections.abc import Awaitable, Sequence
from collections.abc import Awaitable, Callable, Sequence
from functools import partial
import itertools
import logging
import os
import re
import sys
import threading
import time
from typing import Any, Callable, Optional, Union
from typing import Any
import jax
from jax._src import array
@ -130,7 +129,7 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
return spec
def is_remote_storage(tspec: Union[dict[str, Any], str]) -> bool:
def is_remote_storage(tspec: dict[str, Any] | str) -> bool:
"""Detect if user is using cloud storages.
This can detect common defines and unable to detect some corner cases such as
@ -190,7 +189,7 @@ async def async_serialize(
tensorstore_spec,
commit_future=None,
context=TS_CONTEXT,
primary_host: Optional[int] = 0,
primary_host: int | None = 0,
replica_id: int = 0,
):
"""Serialize an array using TensorStore.

View File

@ -503,14 +503,14 @@ from __future__ import annotations
import atexit
import enum
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import functools
import itertools
import logging
import math
import threading
import traceback
from typing import Any, Callable, cast
from typing import Any, cast
import jax
from jax._src import api

View File

@ -25,10 +25,10 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#callin
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
from typing import Any, Callable, Optional
from typing import Any
from absl import logging
import jax

View File

@ -21,12 +21,12 @@ See README.md for how these are used.
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import functools
import logging
import re
import time
from typing import Any, Callable, Optional
from typing import Any
from absl import flags
import flax

View File

@ -26,8 +26,8 @@ customize this function as needed.
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Sequence
from typing import Any
from jax.experimental import jax2tf
import tensorflow as tf

View File

@ -16,12 +16,12 @@
from __future__ import annotations
import builtins
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import dataclasses
from functools import partial, wraps
import math
import string
from typing import Any, Callable, Optional
from typing import Any
from jax._src import core
from jax import lax

View File

@ -15,7 +15,7 @@
from __future__ import annotations
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from functools import partial
import contextlib
import math
@ -23,7 +23,7 @@ import operator
import os
import re
import threading
from typing import Any, Callable, Union
from typing import Any, Union
import warnings
from absl import logging

View File

@ -20,11 +20,10 @@ these tests.
from __future__ import annotations
import base64
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import io
import os
import tarfile
from typing import Callable, Optional
from absl.testing import absltest
import jax

View File

@ -13,10 +13,10 @@
# limitations under the License.
"""Tests for call_tf."""
from collections.abc import Callable
import contextlib
from functools import partial
import os
from typing import Callable
import unittest
from absl import logging

View File

@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converters for jax2tf."""
from collections.abc import Callable
import dataclasses
import functools
import tempfile
from typing import Any, Callable
from typing import Any
from jax.experimental import jax2tf
import tensorflow as tf
import tensorflowjs as tfjs

View File

@ -26,12 +26,11 @@ currently saved file with the saved one.
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import os
import re
from typing import Callable, Optional
import zlib
from absl import app

View File

@ -18,8 +18,9 @@ https://github.com/google/flax/tree/main/examples/sst2
from __future__ import annotations
from collections.abc import Callable
import functools
from typing import Any, Callable, Optional
from typing import Any
from flax import linen as nn
import jax

View File

@ -16,8 +16,7 @@
https://github.com/google/flax/tree/main/examples/ogbg_molpcba
"""
from collections.abc import Sequence
from typing import Callable
from collections.abc import Callable, Sequence
from flax import linen as nn

View File

@ -19,9 +19,9 @@ https://github.com/google/flax/tree/main/examples/imagenet
# See issue #620.
# pytype: disable=wrong-arg-count
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any, Callable
from typing import Any
from flax import linen as nn
import jax.numpy as jnp

Some files were not shown because too many files have changed in this diff Show More