mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Run pyupgrade --py310-plus
.
Also apply manual fixes to import sorting and unused imports.
This commit is contained in:
parent
cdfe2df384
commit
7f4ef63cd8
@ -32,7 +32,7 @@ def extract_filename(path):
|
||||
def generate_final_report(shell=False, env_vars={}):
|
||||
env = os.environ
|
||||
env = {**env, **env_vars}
|
||||
cmd = ["pytest_html_merger", "-i", '{}'.format(base_dir), "-o", '{}/final_compiled_report.html'.format(base_dir)]
|
||||
cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
|
||||
result = subprocess.run(cmd,
|
||||
shell=shell,
|
||||
capture_output=True,
|
||||
@ -90,7 +90,7 @@ def run_test(testmodule, gpu_tokens):
|
||||
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
|
||||
}
|
||||
testfile = extract_filename(testmodule)
|
||||
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule]
|
||||
cmd = ["python3", "-m", "pytest", f'--html={base_dir}/{testfile}_log.html', "--reruns", "3", "-x", testmodule]
|
||||
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
|
||||
with GPU_LOCK:
|
||||
gpu_tokens.append(target_gpu)
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
from functools import partial, reduce
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@ -325,9 +324,9 @@ class RmsNormFwdClass:
|
||||
return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
|
||||
|
||||
@staticmethod
|
||||
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
|
||||
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
|
||||
result_infos : Tuple[jax._src.core.ShapedArray]):
|
||||
def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
|
||||
arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
|
||||
result_infos: tuple[jax._src.core.ShapedArray, ...]):
|
||||
del eps, result_infos # Not needed for this example.
|
||||
x_info, weight_info = arg_infos
|
||||
assert len(x_info.shape) == 3
|
||||
@ -340,9 +339,9 @@ class RmsNormFwdClass:
|
||||
return (output_sharding, invvar_sharding)
|
||||
|
||||
@staticmethod
|
||||
def partition(eps : float, mesh : jax.sharding.Mesh,
|
||||
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
|
||||
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
|
||||
def partition(eps: float, mesh : jax.sharding.Mesh,
|
||||
arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
|
||||
result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
|
||||
del result_infos # Not needed for this example.
|
||||
x_info, weight_info = arg_infos
|
||||
assert len(x_info.shape) == 3
|
||||
@ -395,9 +394,9 @@ class RmsNormBwdClass:
|
||||
return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
|
||||
|
||||
@staticmethod
|
||||
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
|
||||
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
|
||||
result_infos : Tuple[jax._src.core.ShapedArray]):
|
||||
def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
|
||||
arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
|
||||
result_infos: tuple[jax._src.core.ShapedArray, ...]):
|
||||
del eps, result_infos # Not needed for this example.
|
||||
g_info, invvar_info, x_info, weight_info = arg_infos
|
||||
assert len(g_info.shape) == 3
|
||||
@ -411,9 +410,9 @@ class RmsNormBwdClass:
|
||||
return (output_sharding, invvar_sharding, output_sharding, )
|
||||
|
||||
@staticmethod
|
||||
def partition(eps : float, mesh : jax.sharding.Mesh,
|
||||
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
|
||||
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
|
||||
def partition(eps: float, mesh : jax.sharding.Mesh,
|
||||
arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
|
||||
result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
|
||||
del result_infos # Not needed for this example.
|
||||
g_info, invvar_info, x_info, weight_info = arg_infos
|
||||
assert len(g_info.shape) == 3
|
||||
|
@ -167,15 +167,15 @@
|
||||
"source": [
|
||||
"from collections.abc import Sequence\n",
|
||||
"from contextlib import contextmanager\n",
|
||||
"from typing import Optional, Any\n",
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"class MainTrace(NamedTuple):\n",
|
||||
" level: int\n",
|
||||
" trace_type: type['Trace']\n",
|
||||
" global_data: Optional[Any]\n",
|
||||
" global_data: Any | None\n",
|
||||
"\n",
|
||||
"trace_stack: list[MainTrace] = []\n",
|
||||
"dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3\n",
|
||||
"dynamic_trace: MainTrace | None = None # to be employed in Part 3\n",
|
||||
"\n",
|
||||
"@contextmanager\n",
|
||||
"def new_main(trace_type: type['Trace'], global_data=None):\n",
|
||||
@ -912,7 +912,7 @@
|
||||
"source": [
|
||||
"from collections.abc import Hashable, Iterable, Iterator\n",
|
||||
"import itertools as it\n",
|
||||
"from typing import Callable\n",
|
||||
"from collections.abc import Callable\n",
|
||||
"\n",
|
||||
"class NodeType(NamedTuple):\n",
|
||||
" name: str\n",
|
||||
@ -1651,7 +1651,7 @@
|
||||
"source": [
|
||||
"from functools import lru_cache\n",
|
||||
"\n",
|
||||
"@lru_cache() # ShapedArrays are hashable\n",
|
||||
"@lru_cache # ShapedArrays are hashable\n",
|
||||
"def make_jaxpr_v1(f, *avals_in):\n",
|
||||
" avals_in, in_tree = tree_flatten(avals_in)\n",
|
||||
" f, out_tree = flatten_fun(f, in_tree)\n",
|
||||
@ -1803,7 +1803,7 @@
|
||||
" finally:\n",
|
||||
" dynamic_trace = prev_dynamic_trace\n",
|
||||
"\n",
|
||||
"@lru_cache()\n",
|
||||
"@lru_cache\n",
|
||||
"def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n",
|
||||
" ) -> tuple[Jaxpr, list[Any], PyTreeDef]:\n",
|
||||
" avals_in, in_tree = tree_flatten(avals_in)\n",
|
||||
@ -1994,7 +1994,7 @@
|
||||
" return execute(*args)\n",
|
||||
"impl_rules[xla_call_p] = xla_call_impl\n",
|
||||
"\n",
|
||||
"@lru_cache()\n",
|
||||
"@lru_cache\n",
|
||||
"def xla_callable(hashable_jaxpr: IDHashable,\n",
|
||||
" hashable_consts: tuple[IDHashable, ...]):\n",
|
||||
" jaxpr: Jaxpr = hashable_jaxpr.val\n",
|
||||
@ -2227,7 +2227,7 @@
|
||||
" return primals_out, tangents_out\n",
|
||||
"jvp_rules[xla_call_p] = xla_call_jvp_rule\n",
|
||||
"\n",
|
||||
"@lru_cache()\n",
|
||||
"@lru_cache\n",
|
||||
"def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:\n",
|
||||
" def jvp_traceable(*primals_and_tangents):\n",
|
||||
" n = len(primals_and_tangents) // 2\n",
|
||||
@ -2253,7 +2253,7 @@
|
||||
" return outs, [0] * len(outs)\n",
|
||||
"vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
|
||||
"\n",
|
||||
"@lru_cache()\n",
|
||||
"@lru_cache\n",
|
||||
"def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]\n",
|
||||
" ) -> tuple[Jaxpr, list[Any]]:\n",
|
||||
" vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
|
||||
@ -2638,7 +2638,7 @@
|
||||
"source": [
|
||||
"class PartialVal(NamedTuple):\n",
|
||||
" aval: ShapedArray\n",
|
||||
" const: Optional[Any]\n",
|
||||
" const: Any | None\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def known(cls, val: Any):\n",
|
||||
@ -2727,7 +2727,7 @@
|
||||
"source": [
|
||||
"class PartialEvalTracer(Tracer):\n",
|
||||
" pval: PartialVal\n",
|
||||
" recipe: Optional[JaxprRecipe]\n",
|
||||
" recipe: JaxprRecipe | None\n",
|
||||
"\n",
|
||||
" def __init__(self, trace, pval, recipe):\n",
|
||||
" self._trace = trace\n",
|
||||
@ -2974,7 +2974,7 @@
|
||||
"partial_eval_rules[xla_call_p] = xla_call_partial_eval\n",
|
||||
"\n",
|
||||
"def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],\n",
|
||||
" instantiate: Optional[list[bool]] = None,\n",
|
||||
" instantiate: list[bool] | None = None,\n",
|
||||
" ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:\n",
|
||||
" env: dict[Var, bool] = {}\n",
|
||||
" residuals: set[Var] = set()\n",
|
||||
@ -3271,7 +3271,7 @@
|
||||
" return [next(outs) if undef else None for undef in undef_primals]\n",
|
||||
"transpose_rules[xla_call_p] = xla_call_transpose_rule\n",
|
||||
"\n",
|
||||
"@lru_cache()\n",
|
||||
"@lru_cache\n",
|
||||
"def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]\n",
|
||||
" ) -> tuple[Jaxpr, list[Any]]:\n",
|
||||
" avals_in, avals_out = typecheck_jaxpr(jaxpr)\n",
|
||||
|
@ -148,15 +148,15 @@ more descriptive.
|
||||
```{code-cell}
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Any
|
||||
from typing import Any
|
||||
|
||||
class MainTrace(NamedTuple):
|
||||
level: int
|
||||
trace_type: type['Trace']
|
||||
global_data: Optional[Any]
|
||||
global_data: Any | None
|
||||
|
||||
trace_stack: list[MainTrace] = []
|
||||
dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3
|
||||
dynamic_trace: MainTrace | None = None # to be employed in Part 3
|
||||
|
||||
@contextmanager
|
||||
def new_main(trace_type: type['Trace'], global_data=None):
|
||||
@ -705,7 +705,7 @@ class Store:
|
||||
|
||||
from collections.abc import Hashable, Iterable, Iterator
|
||||
import itertools as it
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
class NodeType(NamedTuple):
|
||||
name: str
|
||||
@ -1295,7 +1295,7 @@ transformation and a pretty-printer:
|
||||
```{code-cell}
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache() # ShapedArrays are hashable
|
||||
@lru_cache # ShapedArrays are hashable
|
||||
def make_jaxpr_v1(f, *avals_in):
|
||||
avals_in, in_tree = tree_flatten(avals_in)
|
||||
f, out_tree = flatten_fun(f, in_tree)
|
||||
@ -1415,7 +1415,7 @@ def new_dynamic(main: MainTrace):
|
||||
finally:
|
||||
dynamic_trace = prev_dynamic_trace
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
|
||||
) -> tuple[Jaxpr, list[Any], PyTreeDef]:
|
||||
avals_in, in_tree = tree_flatten(avals_in)
|
||||
@ -1564,7 +1564,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
|
||||
return execute(*args)
|
||||
impl_rules[xla_call_p] = xla_call_impl
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def xla_callable(hashable_jaxpr: IDHashable,
|
||||
hashable_consts: tuple[IDHashable, ...]):
|
||||
jaxpr: Jaxpr = hashable_jaxpr.val
|
||||
@ -1734,7 +1734,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
|
||||
return primals_out, tangents_out
|
||||
jvp_rules[xla_call_p] = xla_call_jvp_rule
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
|
||||
def jvp_traceable(*primals_and_tangents):
|
||||
n = len(primals_and_tangents) // 2
|
||||
@ -1755,7 +1755,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
|
||||
return outs, [0] * len(outs)
|
||||
vmap_rules[xla_call_p] = xla_call_vmap_rule
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
|
||||
) -> tuple[Jaxpr, list[Any]]:
|
||||
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
|
||||
@ -2065,7 +2065,7 @@ be either known or unknown:
|
||||
```{code-cell}
|
||||
class PartialVal(NamedTuple):
|
||||
aval: ShapedArray
|
||||
const: Optional[Any]
|
||||
const: Any | None
|
||||
|
||||
@classmethod
|
||||
def known(cls, val: Any):
|
||||
@ -2129,7 +2129,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
|
||||
```{code-cell}
|
||||
class PartialEvalTracer(Tracer):
|
||||
pval: PartialVal
|
||||
recipe: Optional[JaxprRecipe]
|
||||
recipe: JaxprRecipe | None
|
||||
|
||||
def __init__(self, trace, pval, recipe):
|
||||
self._trace = trace
|
||||
@ -2329,7 +2329,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
|
||||
partial_eval_rules[xla_call_p] = xla_call_partial_eval
|
||||
|
||||
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
|
||||
instantiate: Optional[list[bool]] = None,
|
||||
instantiate: list[bool] | None = None,
|
||||
) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
|
||||
env: dict[Var, bool] = {}
|
||||
residuals: set[Var] = set()
|
||||
@ -2586,7 +2586,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
|
||||
return [next(outs) if undef else None for undef in undef_primals]
|
||||
transpose_rules[xla_call_p] = xla_call_transpose_rule
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
|
||||
) -> tuple[Jaxpr, list[Any]]:
|
||||
avals_in, avals_out = typecheck_jaxpr(jaxpr)
|
||||
|
@ -138,15 +138,15 @@ def bind1(prim, *args, **params):
|
||||
# +
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Any
|
||||
from typing import Any
|
||||
|
||||
class MainTrace(NamedTuple):
|
||||
level: int
|
||||
trace_type: type['Trace']
|
||||
global_data: Optional[Any]
|
||||
global_data: Any | None
|
||||
|
||||
trace_stack: list[MainTrace] = []
|
||||
dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3
|
||||
dynamic_trace: MainTrace | None = None # to be employed in Part 3
|
||||
|
||||
@contextmanager
|
||||
def new_main(trace_type: type['Trace'], global_data=None):
|
||||
@ -697,7 +697,7 @@ class Store:
|
||||
# + tags=["hide-input"]
|
||||
from collections.abc import Hashable, Iterable, Iterator
|
||||
import itertools as it
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
class NodeType(NamedTuple):
|
||||
name: str
|
||||
@ -1297,7 +1297,7 @@ abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
|
||||
# +
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache() # ShapedArrays are hashable
|
||||
@lru_cache # ShapedArrays are hashable
|
||||
def make_jaxpr_v1(f, *avals_in):
|
||||
avals_in, in_tree = tree_flatten(avals_in)
|
||||
f, out_tree = flatten_fun(f, in_tree)
|
||||
@ -1412,7 +1412,7 @@ def new_dynamic(main: MainTrace):
|
||||
finally:
|
||||
dynamic_trace = prev_dynamic_trace
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
|
||||
) -> tuple[Jaxpr, list[Any], PyTreeDef]:
|
||||
avals_in, in_tree = tree_flatten(avals_in)
|
||||
@ -1556,7 +1556,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
|
||||
return execute(*args)
|
||||
impl_rules[xla_call_p] = xla_call_impl
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def xla_callable(hashable_jaxpr: IDHashable,
|
||||
hashable_consts: tuple[IDHashable, ...]):
|
||||
jaxpr: Jaxpr = hashable_jaxpr.val
|
||||
@ -1728,7 +1728,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
|
||||
return primals_out, tangents_out
|
||||
jvp_rules[xla_call_p] = xla_call_jvp_rule
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
|
||||
def jvp_traceable(*primals_and_tangents):
|
||||
n = len(primals_and_tangents) // 2
|
||||
@ -1749,7 +1749,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
|
||||
return outs, [0] * len(outs)
|
||||
vmap_rules[xla_call_p] = xla_call_vmap_rule
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
|
||||
) -> tuple[Jaxpr, list[Any]]:
|
||||
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
|
||||
@ -2057,7 +2057,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
|
||||
|
||||
class PartialVal(NamedTuple):
|
||||
aval: ShapedArray
|
||||
const: Optional[Any]
|
||||
const: Any | None
|
||||
|
||||
@classmethod
|
||||
def known(cls, val: Any):
|
||||
@ -2121,7 +2121,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
|
||||
|
||||
class PartialEvalTracer(Tracer):
|
||||
pval: PartialVal
|
||||
recipe: Optional[JaxprRecipe]
|
||||
recipe: JaxprRecipe | None
|
||||
|
||||
def __init__(self, trace, pval, recipe):
|
||||
self._trace = trace
|
||||
@ -2322,7 +2322,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
|
||||
partial_eval_rules[xla_call_p] = xla_call_partial_eval
|
||||
|
||||
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
|
||||
instantiate: Optional[list[bool]] = None,
|
||||
instantiate: list[bool] | None = None,
|
||||
) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
|
||||
env: dict[Var, bool] = {}
|
||||
residuals: set[Var] = set()
|
||||
@ -2585,7 +2585,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
|
||||
return [next(outs) if undef else None for undef in undef_primals]
|
||||
transpose_rules[xla_call_p] = xla_call_transpose_rule
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
|
||||
) -> tuple[Jaxpr, list[Any]]:
|
||||
avals_in, avals_out = typecheck_jaxpr(jaxpr)
|
||||
|
@ -14,11 +14,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
|
@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import types
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import traceback_util
|
||||
|
@ -23,12 +23,12 @@ arrays.
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Generator, Hashable, Iterable, Sequence
|
||||
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
|
||||
from functools import partial, lru_cache
|
||||
import inspect
|
||||
import math
|
||||
import typing
|
||||
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
|
||||
from typing import (Any, Literal, NamedTuple, TypeVar, overload,
|
||||
cast)
|
||||
import weakref
|
||||
|
||||
|
@ -14,11 +14,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
import inspect
|
||||
import operator
|
||||
from functools import partial, lru_cache
|
||||
from typing import Any, Callable, Type
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -713,6 +713,6 @@ class _HashableByObjectId:
|
||||
def __eq__(self, other):
|
||||
return self.val is other.val
|
||||
|
||||
def register_class_with_attrs(t: Type) -> None:
|
||||
def register_class_with_attrs(t: type) -> None:
|
||||
_class_with_attrs.add(t)
|
||||
_class_with_attrs: set[Type] = set()
|
||||
_class_with_attrs: set[type] = set()
|
||||
|
@ -15,12 +15,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import enum
|
||||
import functools
|
||||
import math
|
||||
import operator as op
|
||||
from typing import Any, Callable, TYPE_CHECKING, cast
|
||||
from typing import Any, TYPE_CHECKING, cast
|
||||
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import api
|
||||
|
@ -12,7 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Protocol, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol
|
||||
import jax
|
||||
from jax._src import random
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
@ -14,11 +14,11 @@
|
||||
"""Module for JAX callbacks."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
|
@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools as it
|
||||
from typing import Callable, TypeVar, Any, Union
|
||||
from typing import TypeVar, Any, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -16,7 +16,6 @@ import os
|
||||
from jax import version
|
||||
from jax._src import config
|
||||
from jax._src import hardware_utils
|
||||
from typing import Optional
|
||||
|
||||
running_in_cloud_tpu_vm: bool = False
|
||||
|
||||
@ -35,7 +34,7 @@ def maybe_import_libtpu():
|
||||
return libtpu
|
||||
|
||||
|
||||
def get_tpu_library_path() -> Optional[str]:
|
||||
def get_tpu_library_path() -> str | None:
|
||||
path_from_env = os.getenv("TPU_LIBRARY_PATH")
|
||||
if path_from_env is not None and os.path.isfile(path_from_env):
|
||||
return path_from_env
|
||||
|
@ -21,7 +21,7 @@ import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from jax._src import compilation_cache
|
||||
@ -393,7 +393,7 @@ def _share_fdo_profiles(
|
||||
backend: xc.Client,
|
||||
global_client: lib.xla_extension.DistributedRuntimeClient,
|
||||
min_process_id
|
||||
) -> Optional[bytes]:
|
||||
) -> bytes | None:
|
||||
sym_name = computation.operation.attributes['sym_name']
|
||||
module_name = ir.StringAttr(sym_name).value
|
||||
fdo_profile = compile_options.executable_build_options.fdo_profile
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Hashable, Iterator, Sequence
|
||||
from collections.abc import Callable, Hashable, Iterator, Sequence
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
@ -22,9 +22,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from typing import (
|
||||
Any, Callable, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast,
|
||||
)
|
||||
from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src.lib import jax_jit
|
||||
|
@ -14,8 +14,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter, defaultdict, deque, namedtuple
|
||||
from collections.abc import (Collection, Generator, Hashable, Iterable,
|
||||
Iterator, Set, Sequence, MutableSet,
|
||||
from collections.abc import (Callable, Collection, Generator, Hashable,
|
||||
Iterable, Iterator, Set, Sequence, MutableSet,
|
||||
MutableMapping)
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from dataclasses import dataclass
|
||||
@ -28,7 +28,7 @@ import math
|
||||
import operator
|
||||
import threading
|
||||
import types
|
||||
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, TypeVar,
|
||||
from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar,
|
||||
cast, overload, Union)
|
||||
import warnings
|
||||
from weakref import ref
|
||||
|
@ -15,7 +15,6 @@
|
||||
from enum import Enum
|
||||
from functools import partial, reduce
|
||||
import operator
|
||||
from typing import Optional
|
||||
import json
|
||||
|
||||
import jax
|
||||
@ -927,10 +926,10 @@ _dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_atte
|
||||
def dot_product_attention(query: Array,
|
||||
key: Array,
|
||||
value: Array,
|
||||
bias: Optional[Array] = None,
|
||||
mask: Optional[Array] = None,
|
||||
q_seqlen: Optional[Array] = None,
|
||||
kv_seqlen: Optional[Array] = None,
|
||||
bias: Array | None = None,
|
||||
mask: Array | None = None,
|
||||
q_seqlen: Array | None = None,
|
||||
kv_seqlen: Array | None = None,
|
||||
*,
|
||||
scale: float = 1.0,
|
||||
mask_type: MaskType = MaskType.NO_MASK,
|
||||
|
@ -14,9 +14,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
import operator
|
||||
from typing import Callable
|
||||
|
||||
from jax import lax
|
||||
from jax._src import api
|
||||
|
@ -14,11 +14,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
from functools import update_wrapper, reduce, partial
|
||||
import inspect
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
|
@ -14,8 +14,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
|
@ -16,12 +16,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
import logging
|
||||
import string
|
||||
import sys
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Union
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
|
@ -16,13 +16,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import itertools
|
||||
import time
|
||||
from typing import Any, Callable, NamedTuple
|
||||
from typing import Any, NamedTuple
|
||||
import logging
|
||||
import threading
|
||||
|
||||
|
@ -17,13 +17,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import re
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Union
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
|
@ -16,9 +16,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, TypeVar
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import TypeVar
|
||||
|
||||
try:
|
||||
import flatbuffers
|
||||
|
@ -21,7 +21,7 @@ import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class PyTreeDefKind(object):
|
||||
class PyTreeDefKind:
|
||||
leaf = 0
|
||||
none = 1
|
||||
tuple = 2
|
||||
@ -29,12 +29,12 @@ class PyTreeDefKind(object):
|
||||
dict = 4
|
||||
|
||||
|
||||
class AbstractValueKind(object):
|
||||
class AbstractValueKind:
|
||||
shapedArray = 0
|
||||
abstractToken = 1
|
||||
|
||||
|
||||
class DType(object):
|
||||
class DType:
|
||||
bool = 0
|
||||
i8 = 1
|
||||
i16 = 2
|
||||
@ -60,18 +60,18 @@ class DType(object):
|
||||
f0 = 22
|
||||
|
||||
|
||||
class ShardingKind(object):
|
||||
class ShardingKind:
|
||||
unspecified = 0
|
||||
hlo_sharding = 1
|
||||
|
||||
|
||||
class DisabledSafetyCheckKind(object):
|
||||
class DisabledSafetyCheckKind:
|
||||
platform = 0
|
||||
custom_call = 1
|
||||
shape_assertions = 2
|
||||
|
||||
|
||||
class PyTreeDef(object):
|
||||
class PyTreeDef:
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -163,7 +163,7 @@ def PyTreeDefEnd(builder):
|
||||
|
||||
|
||||
|
||||
class AbstractValue(object):
|
||||
class AbstractValue:
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -235,7 +235,7 @@ def AbstractValueEnd(builder):
|
||||
|
||||
|
||||
|
||||
class Sharding(object):
|
||||
class Sharding:
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -304,7 +304,7 @@ def ShardingEnd(builder):
|
||||
|
||||
|
||||
|
||||
class Effect(object):
|
||||
class Effect:
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -340,7 +340,7 @@ def EffectEnd(builder):
|
||||
|
||||
|
||||
|
||||
class DisabledSafetyCheck(object):
|
||||
class DisabledSafetyCheck:
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -386,7 +386,7 @@ def DisabledSafetyCheckEnd(builder):
|
||||
|
||||
|
||||
|
||||
class Exported(object):
|
||||
class Exported:
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
|
@ -19,7 +19,7 @@ See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
from enum import Enum
|
||||
import functools
|
||||
@ -28,7 +28,7 @@ import io
|
||||
import copy
|
||||
import operator as op
|
||||
import tokenize
|
||||
from typing import Any, Callable, Union, overload
|
||||
from typing import Any, Union, overload
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
@ -12,8 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable
|
||||
from collections.abc import Hashable
|
||||
from collections.abc import Callable, Hashable
|
||||
|
||||
from jax import Array
|
||||
|
||||
|
@ -14,10 +14,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
import enum
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -70,13 +70,13 @@ then update `test_custom_call_coverage`, and then update your `test_foo_call`:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
import dataclasses
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from absl import logging
|
||||
|
||||
|
@ -38,11 +38,11 @@ to fail. A Limitation is specific to a harness.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
import operator
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Callable, NamedTuple, Union
|
||||
from typing import Any, NamedTuple, Union
|
||||
|
||||
from absl import testing
|
||||
import numpy as np
|
||||
|
@ -14,12 +14,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools as it
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import config
|
||||
|
@ -14,10 +14,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
from functools import partial
|
||||
@ -27,7 +27,7 @@ import os
|
||||
import re
|
||||
import types
|
||||
import typing
|
||||
from typing import Any, Callable, NamedTuple, Protocol, Union, cast as type_cast
|
||||
from typing import Any, NamedTuple, Protocol, Union, cast as type_cast
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
@ -14,13 +14,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import namedtuple
|
||||
from collections.abc import Sequence, Hashable
|
||||
from collections.abc import Callable, Sequence, Hashable
|
||||
from contextlib import contextmanager, AbstractContextManager
|
||||
from functools import partial
|
||||
import inspect
|
||||
import itertools as it
|
||||
import operator as op
|
||||
from typing import Any, Callable, NamedTuple, Union
|
||||
from typing import Any, NamedTuple, Union
|
||||
from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
|
@ -19,15 +19,14 @@ import enum
|
||||
from contextlib import contextmanager
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from collections.abc import Sequence, Iterable
|
||||
from collections.abc import Callable, Sequence, Iterable, Iterator
|
||||
import dataclasses
|
||||
from functools import partial, lru_cache, cached_property
|
||||
import itertools as it
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
from typing import Any, Callable, NamedTuple, TypeVar, Union, cast
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, NamedTuple, TypeVar, Union, cast
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
@ -17,12 +17,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
from typing import Any, Callable, Protocol, Union
|
||||
from typing import Any, Protocol, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -17,11 +17,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Callable
|
||||
import gzip
|
||||
import itertools
|
||||
import json
|
||||
import types
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import util
|
||||
|
@ -15,10 +15,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
|
@ -15,13 +15,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import jax
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
|
@ -15,10 +15,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
import operator
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
|
@ -14,12 +14,12 @@
|
||||
"""Module for the loop primitives."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
import inspect
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
import weakref
|
||||
|
||||
import jax
|
||||
|
@ -15,14 +15,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import enum
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools
|
||||
import math
|
||||
import operator
|
||||
from typing import Any, Callable, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
|
||||
from typing import Any, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -2986,10 +2986,10 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
|
||||
m, k = lhs.shape
|
||||
group_count, rk, n = rhs.shape
|
||||
if k != rk:
|
||||
raise TypeError("ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {} and {}.".format(k, rk))
|
||||
raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
|
||||
num_groups = group_sizes.shape[0]
|
||||
if group_count != num_groups:
|
||||
raise TypeError("ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {} and {}.".format(group_count, num_groups))
|
||||
raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
|
||||
return (m, n)
|
||||
|
||||
# DotDimensionNumbers used in the dot_general call for ragged_dot().
|
||||
|
@ -14,10 +14,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Any, Callable, Literal, TypeVar, overload
|
||||
from typing import Any, Literal, TypeVar, overload
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -14,12 +14,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import enum
|
||||
import operator
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Callable, NamedTuple
|
||||
from typing import NamedTuple
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
|
@ -14,9 +14,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
import warnings
|
||||
|
||||
from jax import tree_util
|
||||
|
@ -14,9 +14,9 @@
|
||||
|
||||
"""A LazyLoader class."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import importlib
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
def attach(package_name: str, submodules: Sequence[str]) -> tuple[
|
||||
|
@ -63,8 +63,9 @@ data must be immutable, because it will be stored in function memoization tables
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Any, Callable, NamedTuple
|
||||
from typing import Any, NamedTuple
|
||||
import weakref
|
||||
|
||||
from jax._src import config
|
||||
|
@ -15,12 +15,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict, abc
|
||||
from collections.abc import Iterable, Sequence, Mapping
|
||||
from collections.abc import Callable, Iterable, Sequence, Mapping
|
||||
import contextlib
|
||||
from functools import wraps, partial, partialmethod, lru_cache
|
||||
import itertools as it
|
||||
import math
|
||||
from typing import Callable, Any, NamedTuple, Union, cast as type_cast
|
||||
from typing import Any, NamedTuple, Union, cast as type_cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -27,13 +27,13 @@ from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import collections
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
import importlib
|
||||
import math
|
||||
import operator
|
||||
import types
|
||||
from typing import (cast, overload, Any, Callable, Literal, NamedTuple,
|
||||
from typing import (cast, overload, Any, Literal, NamedTuple,
|
||||
Protocol, TypeVar, Union)
|
||||
from textwrap import dedent as _dedent
|
||||
import warnings
|
||||
|
@ -15,11 +15,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from typing import overload, Any, Callable, Literal, Protocol, Union
|
||||
from typing import overload, Any, Literal, Protocol, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
@ -16,10 +16,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
|
@ -18,9 +18,9 @@ Implements ufuncs for jax.numpy.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Callable
|
||||
|
||||
import warnings
|
||||
|
||||
|
@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Any, Callable, NamedTuple, TypeVar
|
||||
from typing import Any, NamedTuple, TypeVar
|
||||
|
||||
import warnings
|
||||
|
||||
|
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Collection, Sequence
|
||||
from collections.abc import Callable, Collection, Sequence
|
||||
import functools
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from jax._src import api
|
||||
from jax import lax
|
||||
|
@ -16,8 +16,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Union
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
@ -15,13 +15,13 @@
|
||||
"""Module for pallas-core functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
import copy
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import threading
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
|
@ -78,7 +78,7 @@ class AbstractSemaphoreTy(dtypes.ExtendedDType):
|
||||
return self.__class__ == other.__class__
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.__class__))
|
||||
return hash(self.__class__)
|
||||
|
||||
# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy
|
||||
|
||||
@ -109,7 +109,7 @@ class SemaphoreType(enum.Enum):
|
||||
dtype = SemaphoreTy()
|
||||
return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)
|
||||
|
||||
def get_aval(self) -> "AbstractMemoryRef":
|
||||
def get_aval(self) -> AbstractMemoryRef:
|
||||
return self(()).get_aval()
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
@ -15,11 +15,11 @@
|
||||
"""Module for lowering JAX to Mosaic-compatible MLIR dialects."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
import string
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
|
@ -15,12 +15,13 @@
|
||||
"""Module for emitting custom TPU pipelines within a Pallas call."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Optional, Union, Any, Sequence
|
||||
from typing import Union, Any
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
@ -201,12 +202,12 @@ class BufferedRef:
|
||||
spec: pl.BlockSpec # static metadata
|
||||
dtype: Any # static metadata
|
||||
buffer_type: BufferType # static metadata
|
||||
vmem_ref: Optional[REF]
|
||||
accum_ref: Optional[REF]
|
||||
current_slot: Optional[ArrayRef]
|
||||
next_slot: Optional[ArrayRef]
|
||||
sem_recv: Optional[SemaphoreType]
|
||||
sem_send: Optional[SemaphoreType]
|
||||
vmem_ref: REF | None
|
||||
accum_ref: REF | None
|
||||
current_slot: ArrayRef | None
|
||||
next_slot: ArrayRef | None
|
||||
sem_recv: SemaphoreType | None
|
||||
sem_send: SemaphoreType | None
|
||||
|
||||
def tree_flatten(self):
|
||||
return ((self.vmem_ref, self.accum_ref, self.current_slot,
|
||||
@ -218,7 +219,7 @@ class BufferedRef:
|
||||
return cls(*meta, *data)
|
||||
|
||||
@classmethod
|
||||
def create(cls, spec, dtype, buffer_type) -> 'BufferedRef':
|
||||
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
|
||||
"""Create a BufferedRef.
|
||||
|
||||
Args:
|
||||
@ -810,9 +811,9 @@ def _partition_grid(
|
||||
if isinstance(grid[i], int) and grid[i] % num_cores == 0
|
||||
}
|
||||
if divisible_dimensions:
|
||||
first_divisible_dimension, *_ = [
|
||||
first_divisible_dimension, *_ = (
|
||||
i for i in range(len(dimension_semantics)) if i in divisible_dimensions
|
||||
]
|
||||
)
|
||||
partitioned_dim_size = grid[first_divisible_dimension] // num_cores
|
||||
partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size
|
||||
new_grid = jax_util.tuple_update(
|
||||
@ -828,11 +829,11 @@ def _partition_grid(
|
||||
# potentially divide it more evenly
|
||||
largest_parallel_dimension = max(grid[i] for i in parallel_dimensions
|
||||
if isinstance(grid[i], int)) # type: ignore
|
||||
partition_dimension, *_ = [
|
||||
partition_dimension, *_ = (
|
||||
i
|
||||
for i, d in enumerate(grid)
|
||||
if isinstance(d, int) and d == largest_parallel_dimension
|
||||
]
|
||||
)
|
||||
base_num_iters, rem = divmod(grid[partition_dimension], num_cores)
|
||||
assert rem > 0, rem
|
||||
# We have some remainder iterations that we need to assign somewhere. We
|
||||
|
@ -15,9 +15,10 @@
|
||||
"""Module for Pallas:TPU-specific JAX primitives and functions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
|
@ -11,7 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Optional
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
@ -172,7 +173,7 @@ def sample_block(sampler_fn: SampleFnType,
|
||||
block_size: Shape,
|
||||
tile_size: Shape,
|
||||
total_size: Shape,
|
||||
block_index: Optional[tuple[typing.ArrayLike, ...]] = None,
|
||||
block_index: tuple[typing.ArrayLike, ...] | None = None,
|
||||
**kwargs) -> jax.Array:
|
||||
"""Samples a block of random values with invariance guarantees.
|
||||
|
||||
|
@ -15,10 +15,10 @@
|
||||
"""Module for calling pallas functions from JAX."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial, reduce
|
||||
import itertools
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
|
@ -16,12 +16,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
import operator
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
|
@ -15,7 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence, Iterable
|
||||
from collections.abc import Callable, Sequence, Iterable
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import inspect
|
||||
@ -23,7 +23,7 @@ import itertools as it
|
||||
import logging
|
||||
import operator as op
|
||||
import weakref
|
||||
from typing import Callable, NamedTuple, Any, Union, Optional, cast
|
||||
from typing import NamedTuple, Any, Union, cast
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
@ -245,7 +245,7 @@ def _need_to_rebuild_with_fdo(pgle_profiler):
|
||||
def _get_fastpath_data(
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
|
||||
consts, abstracted_axes, pgle_profiler
|
||||
) -> Optional[pxla.MeshExecutableFastpathData]:
|
||||
) -> pxla.MeshExecutableFastpathData | None:
|
||||
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
|
||||
|
||||
use_fastpath = (
|
||||
@ -608,7 +608,7 @@ def _infer_params_impl(
|
||||
assert None not in in_shardings_leaves
|
||||
assert None not in out_shardings_leaves
|
||||
|
||||
in_type: Union[core.InputType, tuple[core.AbstractValue, ...]]
|
||||
in_type: core.InputType | tuple[core.AbstractValue, ...]
|
||||
if config.dynamic_shapes.value:
|
||||
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
|
||||
in_avals = tuple(a for a, e in in_type if e)
|
||||
|
@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
from functools import partial, reduce
|
||||
import math
|
||||
import operator as op
|
||||
from typing import Any, Callable, NamedTuple
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
import glob
|
||||
@ -24,7 +25,7 @@ import logging
|
||||
import os
|
||||
import socketserver
|
||||
import threading
|
||||
from typing import Callable, List, Optional, Union, Any
|
||||
from typing import Any
|
||||
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
@ -210,7 +211,7 @@ def stop_trace():
|
||||
_profile_state.reset()
|
||||
|
||||
|
||||
def stop_and_get_fdo_profile() -> Union[bytes, str]:
|
||||
def stop_and_get_fdo_profile() -> bytes | str:
|
||||
"""Stops the currently-running profiler trace and export fdo_profile.
|
||||
|
||||
Currently, this is only supported for GPU.
|
||||
@ -391,10 +392,10 @@ class PGLEProfiler:
|
||||
self.percentile: int = percentile
|
||||
self.collected_fdo: str | None = None
|
||||
self.called_times: int = 0
|
||||
self.fdo_profiles: List[Any] = []
|
||||
self.fdo_profiles: list[Any] = []
|
||||
self.current_session: xla_client.profiler.ProfilerSession | None = None
|
||||
|
||||
def consume_fdo_profile(self) -> Optional[str]:
|
||||
def consume_fdo_profile(self) -> str | None:
|
||||
if self.collected_fdo is not None:
|
||||
return self.collected_fdo
|
||||
|
||||
|
@ -12,11 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Callable
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import util
|
||||
|
@ -15,8 +15,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, NamedTuple
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import NamedTuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
@ -15,8 +15,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Callable, NamedTuple
|
||||
from typing import NamedTuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
@ -14,8 +14,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Callable
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src.scipy.optimize.bfgs import minimize_bfgs
|
||||
|
@ -14,11 +14,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from typing import Callable
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
@ -18,9 +18,10 @@ An implementation of sourcemaps following `TC39 <https://tc39.es/source-map>`_.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from typing import Iterable, Sequence, Union
|
||||
from typing import Union
|
||||
|
||||
# A Segment encodes how parts in the generated source relate to the original source.
|
||||
# Each segment is made up of 1, 4 or 5 variable-length fields. For their semantics see
|
||||
|
@ -315,7 +315,7 @@ class XlaLowering(Lowering):
|
||||
def hlo(self) -> xc.XlaComputation:
|
||||
"""Return an HLO representation of this computation."""
|
||||
hlo = self.stablehlo()
|
||||
m: Union[str, bytes]
|
||||
m: str | bytes
|
||||
m = mlir.module_to_bytecode(hlo)
|
||||
return xla_extension.mlir.mlir_module_to_xla_computation(
|
||||
m, use_tuple_args=self.compile_args["tuple_args"])
|
||||
|
@ -14,11 +14,11 @@
|
||||
"""Module for discharging state primitives."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Any, Callable, Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -65,9 +65,9 @@ class Slice:
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children) -> Slice:
|
||||
start, size = [
|
||||
start, size = (
|
||||
a if a is not None else b for a, b in zip(children, aux_data[:2])
|
||||
]
|
||||
)
|
||||
return cls(start, size, aux_data[2])
|
||||
|
||||
@classmethod
|
||||
|
@ -16,7 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from collections.abc import Callable, Generator, Iterable, Sequence
|
||||
from contextlib import ExitStack, contextmanager
|
||||
import datetime
|
||||
import functools
|
||||
@ -28,7 +28,7 @@ import re
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
import unittest
|
||||
import warnings
|
||||
import zlib
|
||||
|
2
jax/_src/third_party/scipy/linalg.py
vendored
2
jax/_src/third_party/scipy/linalg.py
vendored
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from jax import jit, lax
|
||||
import jax.numpy as jnp
|
||||
|
@ -19,13 +19,13 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import collections.abc
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
|
@ -14,12 +14,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from typing import Any, Callable, TypeVar, cast
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import util
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Iterable, TypeVar, overload
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from jax._src import tree_util
|
||||
|
||||
|
@ -14,14 +14,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Hashable, Iterable
|
||||
from collections.abc import Callable, Hashable, Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
import difflib
|
||||
import functools
|
||||
from functools import partial
|
||||
import operator as op
|
||||
import textwrap
|
||||
from typing import Any, Callable, NamedTuple, Sequence, TypeVar, Union, overload
|
||||
from typing import Any, NamedTuple, TypeVar, Union, overload
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
|
@ -15,14 +15,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import logging
|
||||
import operator
|
||||
from typing import (Any, Callable, Generic, TypeVar, overload, TYPE_CHECKING, cast)
|
||||
from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast)
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
|
@ -21,7 +21,7 @@ XLA. There are also a handful of related casting utilities.
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Callable, Mapping
|
||||
import dataclasses
|
||||
from functools import lru_cache, partial
|
||||
import importlib
|
||||
@ -32,7 +32,7 @@ import pkgutil
|
||||
import platform as py_platform
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Union
|
||||
import warnings
|
||||
|
||||
from jax._src import config
|
||||
|
@ -91,7 +91,8 @@ Example Usage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, NamedTuple
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
|
@ -15,7 +15,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import jax
|
||||
from typing import Tuple
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src import dtypes as _dtypes, config
|
||||
@ -71,7 +70,7 @@ class __array_namespace_info__:
|
||||
def dtypes(
|
||||
self, *,
|
||||
device: xc.Device | Sharding | None = None,
|
||||
kind: str | Tuple[str, ...] | None = None):
|
||||
kind: str | tuple[str, ...] | None = None):
|
||||
# Array API supported dtypes are device-independent in JAX
|
||||
del device
|
||||
data_types = self._build_dtype_dict()
|
||||
|
@ -17,16 +17,15 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import array
|
||||
@ -130,7 +129,7 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
|
||||
return spec
|
||||
|
||||
|
||||
def is_remote_storage(tspec: Union[dict[str, Any], str]) -> bool:
|
||||
def is_remote_storage(tspec: dict[str, Any] | str) -> bool:
|
||||
"""Detect if user is using cloud storages.
|
||||
|
||||
This can detect common defines and unable to detect some corner cases such as
|
||||
@ -190,7 +189,7 @@ async def async_serialize(
|
||||
tensorstore_spec,
|
||||
commit_future=None,
|
||||
context=TS_CONTEXT,
|
||||
primary_host: Optional[int] = 0,
|
||||
primary_host: int | None = 0,
|
||||
replica_id: int = 0,
|
||||
):
|
||||
"""Serialize an array using TensorStore.
|
||||
|
@ -503,14 +503,14 @@ from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import enum
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Callable, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import jax
|
||||
from jax._src import api
|
||||
|
@ -25,10 +25,10 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#callin
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any
|
||||
|
||||
from absl import logging
|
||||
import jax
|
||||
|
@ -21,12 +21,12 @@ See README.md for how these are used.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any
|
||||
from absl import flags
|
||||
|
||||
import flax
|
||||
|
@ -26,8 +26,8 @@ customize this function as needed.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from jax.experimental import jax2tf
|
||||
import tensorflow as tf
|
||||
|
@ -16,12 +16,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
from functools import partial, wraps
|
||||
import math
|
||||
import string
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any
|
||||
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import partial
|
||||
import contextlib
|
||||
import math
|
||||
@ -23,7 +23,7 @@ import operator
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Union
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
|
@ -20,11 +20,10 @@ these tests.
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import io
|
||||
import os
|
||||
import tarfile
|
||||
from typing import Callable, Optional
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
|
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
"""Tests for call_tf."""
|
||||
|
||||
from collections.abc import Callable
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import os
|
||||
from typing import Callable
|
||||
import unittest
|
||||
|
||||
from absl import logging
|
||||
|
@ -12,10 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Converters for jax2tf."""
|
||||
|
||||
from collections.abc import Callable
|
||||
import dataclasses
|
||||
import functools
|
||||
import tempfile
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from jax.experimental import jax2tf
|
||||
import tensorflow as tf
|
||||
import tensorflowjs as tfjs
|
||||
|
@ -26,12 +26,11 @@ currently saved file with the saved one.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import os
|
||||
import re
|
||||
from typing import Callable, Optional
|
||||
import zlib
|
||||
|
||||
from absl import app
|
||||
|
@ -18,8 +18,9 @@ https://github.com/google/flax/tree/main/examples/sst2
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any
|
||||
|
||||
from flax import linen as nn
|
||||
import jax
|
||||
|
@ -16,8 +16,7 @@
|
||||
https://github.com/google/flax/tree/main/examples/ogbg_molpcba
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from flax import linen as nn
|
||||
|
||||
|
@ -19,9 +19,9 @@ https://github.com/google/flax/tree/main/examples/imagenet
|
||||
# See issue #620.
|
||||
# pytype: disable=wrong-arg-count
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from flax import linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user