Run pyupgrade --py310-plus.

Also apply manual fixes to import sorting and unused imports.
This commit is contained in:
Peter Hawkins 2024-06-26 14:44:52 -04:00
parent cdfe2df384
commit 7f4ef63cd8
140 changed files with 387 additions and 379 deletions

View File

@ -32,7 +32,7 @@ def extract_filename(path):
def generate_final_report(shell=False, env_vars={}): def generate_final_report(shell=False, env_vars={}):
env = os.environ env = os.environ
env = {**env, **env_vars} env = {**env, **env_vars}
cmd = ["pytest_html_merger", "-i", '{}'.format(base_dir), "-o", '{}/final_compiled_report.html'.format(base_dir)] cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
result = subprocess.run(cmd, result = subprocess.run(cmd,
shell=shell, shell=shell,
capture_output=True, capture_output=True,
@ -90,7 +90,7 @@ def run_test(testmodule, gpu_tokens):
"XLA_PYTHON_CLIENT_ALLOCATOR": "default", "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
} }
testfile = extract_filename(testmodule) testfile = extract_filename(testmodule)
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule] cmd = ["python3", "-m", "pytest", f'--html={base_dir}/{testfile}_log.html', "--reruns", "3", "-x", testmodule]
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars) return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
with GPU_LOCK: with GPU_LOCK:
gpu_tokens.append(target_gpu) gpu_tokens.append(target_gpu)

View File

@ -14,7 +14,6 @@
from functools import partial, reduce from functools import partial, reduce
import math import math
from typing import Tuple
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
@ -325,9 +324,9 @@ class RmsNormFwdClass:
return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
@staticmethod @staticmethod
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh, def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
result_infos : Tuple[jax._src.core.ShapedArray]): result_infos: tuple[jax._src.core.ShapedArray, ...]):
del eps, result_infos # Not needed for this example. del eps, result_infos # Not needed for this example.
x_info, weight_info = arg_infos x_info, weight_info = arg_infos
assert len(x_info.shape) == 3 assert len(x_info.shape) == 3
@ -340,9 +339,9 @@ class RmsNormFwdClass:
return (output_sharding, invvar_sharding) return (output_sharding, invvar_sharding)
@staticmethod @staticmethod
def partition(eps : float, mesh : jax.sharding.Mesh, def partition(eps: float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]): result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
del result_infos # Not needed for this example. del result_infos # Not needed for this example.
x_info, weight_info = arg_infos x_info, weight_info = arg_infos
assert len(x_info.shape) == 3 assert len(x_info.shape) == 3
@ -395,9 +394,9 @@ class RmsNormBwdClass:
return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims
@staticmethod @staticmethod
def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh, def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
result_infos : Tuple[jax._src.core.ShapedArray]): result_infos: tuple[jax._src.core.ShapedArray, ...]):
del eps, result_infos # Not needed for this example. del eps, result_infos # Not needed for this example.
g_info, invvar_info, x_info, weight_info = arg_infos g_info, invvar_info, x_info, weight_info = arg_infos
assert len(g_info.shape) == 3 assert len(g_info.shape) == 3
@ -411,9 +410,9 @@ class RmsNormBwdClass:
return (output_sharding, invvar_sharding, output_sharding, ) return (output_sharding, invvar_sharding, output_sharding, )
@staticmethod @staticmethod
def partition(eps : float, mesh : jax.sharding.Mesh, def partition(eps: float, mesh : jax.sharding.Mesh,
arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
result_infos : Tuple[jax._src.api.ShapeDtypeStruct]): result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
del result_infos # Not needed for this example. del result_infos # Not needed for this example.
g_info, invvar_info, x_info, weight_info = arg_infos g_info, invvar_info, x_info, weight_info = arg_infos
assert len(g_info.shape) == 3 assert len(g_info.shape) == 3

View File

@ -167,15 +167,15 @@
"source": [ "source": [
"from collections.abc import Sequence\n", "from collections.abc import Sequence\n",
"from contextlib import contextmanager\n", "from contextlib import contextmanager\n",
"from typing import Optional, Any\n", "from typing import Any\n",
"\n", "\n",
"class MainTrace(NamedTuple):\n", "class MainTrace(NamedTuple):\n",
" level: int\n", " level: int\n",
" trace_type: type['Trace']\n", " trace_type: type['Trace']\n",
" global_data: Optional[Any]\n", " global_data: Any | None\n",
"\n", "\n",
"trace_stack: list[MainTrace] = []\n", "trace_stack: list[MainTrace] = []\n",
"dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3\n", "dynamic_trace: MainTrace | None = None # to be employed in Part 3\n",
"\n", "\n",
"@contextmanager\n", "@contextmanager\n",
"def new_main(trace_type: type['Trace'], global_data=None):\n", "def new_main(trace_type: type['Trace'], global_data=None):\n",
@ -912,7 +912,7 @@
"source": [ "source": [
"from collections.abc import Hashable, Iterable, Iterator\n", "from collections.abc import Hashable, Iterable, Iterator\n",
"import itertools as it\n", "import itertools as it\n",
"from typing import Callable\n", "from collections.abc import Callable\n",
"\n", "\n",
"class NodeType(NamedTuple):\n", "class NodeType(NamedTuple):\n",
" name: str\n", " name: str\n",
@ -1651,7 +1651,7 @@
"source": [ "source": [
"from functools import lru_cache\n", "from functools import lru_cache\n",
"\n", "\n",
"@lru_cache() # ShapedArrays are hashable\n", "@lru_cache # ShapedArrays are hashable\n",
"def make_jaxpr_v1(f, *avals_in):\n", "def make_jaxpr_v1(f, *avals_in):\n",
" avals_in, in_tree = tree_flatten(avals_in)\n", " avals_in, in_tree = tree_flatten(avals_in)\n",
" f, out_tree = flatten_fun(f, in_tree)\n", " f, out_tree = flatten_fun(f, in_tree)\n",
@ -1803,7 +1803,7 @@
" finally:\n", " finally:\n",
" dynamic_trace = prev_dynamic_trace\n", " dynamic_trace = prev_dynamic_trace\n",
"\n", "\n",
"@lru_cache()\n", "@lru_cache\n",
"def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n", "def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n",
" ) -> tuple[Jaxpr, list[Any], PyTreeDef]:\n", " ) -> tuple[Jaxpr, list[Any], PyTreeDef]:\n",
" avals_in, in_tree = tree_flatten(avals_in)\n", " avals_in, in_tree = tree_flatten(avals_in)\n",
@ -1994,7 +1994,7 @@
" return execute(*args)\n", " return execute(*args)\n",
"impl_rules[xla_call_p] = xla_call_impl\n", "impl_rules[xla_call_p] = xla_call_impl\n",
"\n", "\n",
"@lru_cache()\n", "@lru_cache\n",
"def xla_callable(hashable_jaxpr: IDHashable,\n", "def xla_callable(hashable_jaxpr: IDHashable,\n",
" hashable_consts: tuple[IDHashable, ...]):\n", " hashable_consts: tuple[IDHashable, ...]):\n",
" jaxpr: Jaxpr = hashable_jaxpr.val\n", " jaxpr: Jaxpr = hashable_jaxpr.val\n",
@ -2227,7 +2227,7 @@
" return primals_out, tangents_out\n", " return primals_out, tangents_out\n",
"jvp_rules[xla_call_p] = xla_call_jvp_rule\n", "jvp_rules[xla_call_p] = xla_call_jvp_rule\n",
"\n", "\n",
"@lru_cache()\n", "@lru_cache\n",
"def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:\n", "def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:\n",
" def jvp_traceable(*primals_and_tangents):\n", " def jvp_traceable(*primals_and_tangents):\n",
" n = len(primals_and_tangents) // 2\n", " n = len(primals_and_tangents) // 2\n",
@ -2253,7 +2253,7 @@
" return outs, [0] * len(outs)\n", " return outs, [0] * len(outs)\n",
"vmap_rules[xla_call_p] = xla_call_vmap_rule\n", "vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
"\n", "\n",
"@lru_cache()\n", "@lru_cache\n",
"def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]\n", "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]\n",
" ) -> tuple[Jaxpr, list[Any]]:\n", " ) -> tuple[Jaxpr, list[Any]]:\n",
" vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n", " vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
@ -2638,7 +2638,7 @@
"source": [ "source": [
"class PartialVal(NamedTuple):\n", "class PartialVal(NamedTuple):\n",
" aval: ShapedArray\n", " aval: ShapedArray\n",
" const: Optional[Any]\n", " const: Any | None\n",
"\n", "\n",
" @classmethod\n", " @classmethod\n",
" def known(cls, val: Any):\n", " def known(cls, val: Any):\n",
@ -2727,7 +2727,7 @@
"source": [ "source": [
"class PartialEvalTracer(Tracer):\n", "class PartialEvalTracer(Tracer):\n",
" pval: PartialVal\n", " pval: PartialVal\n",
" recipe: Optional[JaxprRecipe]\n", " recipe: JaxprRecipe | None\n",
"\n", "\n",
" def __init__(self, trace, pval, recipe):\n", " def __init__(self, trace, pval, recipe):\n",
" self._trace = trace\n", " self._trace = trace\n",
@ -2974,7 +2974,7 @@
"partial_eval_rules[xla_call_p] = xla_call_partial_eval\n", "partial_eval_rules[xla_call_p] = xla_call_partial_eval\n",
"\n", "\n",
"def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],\n", "def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],\n",
" instantiate: Optional[list[bool]] = None,\n", " instantiate: list[bool] | None = None,\n",
" ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:\n", " ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:\n",
" env: dict[Var, bool] = {}\n", " env: dict[Var, bool] = {}\n",
" residuals: set[Var] = set()\n", " residuals: set[Var] = set()\n",
@ -3271,7 +3271,7 @@
" return [next(outs) if undef else None for undef in undef_primals]\n", " return [next(outs) if undef else None for undef in undef_primals]\n",
"transpose_rules[xla_call_p] = xla_call_transpose_rule\n", "transpose_rules[xla_call_p] = xla_call_transpose_rule\n",
"\n", "\n",
"@lru_cache()\n", "@lru_cache\n",
"def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]\n", "def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]\n",
" ) -> tuple[Jaxpr, list[Any]]:\n", " ) -> tuple[Jaxpr, list[Any]]:\n",
" avals_in, avals_out = typecheck_jaxpr(jaxpr)\n", " avals_in, avals_out = typecheck_jaxpr(jaxpr)\n",

View File

@ -148,15 +148,15 @@ more descriptive.
```{code-cell} ```{code-cell}
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Any from typing import Any
class MainTrace(NamedTuple): class MainTrace(NamedTuple):
level: int level: int
trace_type: type['Trace'] trace_type: type['Trace']
global_data: Optional[Any] global_data: Any | None
trace_stack: list[MainTrace] = [] trace_stack: list[MainTrace] = []
dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 dynamic_trace: MainTrace | None = None # to be employed in Part 3
@contextmanager @contextmanager
def new_main(trace_type: type['Trace'], global_data=None): def new_main(trace_type: type['Trace'], global_data=None):
@ -705,7 +705,7 @@ class Store:
from collections.abc import Hashable, Iterable, Iterator from collections.abc import Hashable, Iterable, Iterator
import itertools as it import itertools as it
from typing import Callable from collections.abc import Callable
class NodeType(NamedTuple): class NodeType(NamedTuple):
name: str name: str
@ -1295,7 +1295,7 @@ transformation and a pretty-printer:
```{code-cell} ```{code-cell}
from functools import lru_cache from functools import lru_cache
@lru_cache() # ShapedArrays are hashable @lru_cache # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in): def make_jaxpr_v1(f, *avals_in):
avals_in, in_tree = tree_flatten(avals_in) avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree) f, out_tree = flatten_fun(f, in_tree)
@ -1415,7 +1415,7 @@ def new_dynamic(main: MainTrace):
finally: finally:
dynamic_trace = prev_dynamic_trace dynamic_trace = prev_dynamic_trace
@lru_cache() @lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray, def make_jaxpr(f: Callable, *avals_in: ShapedArray,
) -> tuple[Jaxpr, list[Any], PyTreeDef]: ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
avals_in, in_tree = tree_flatten(avals_in) avals_in, in_tree = tree_flatten(avals_in)
@ -1564,7 +1564,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
return execute(*args) return execute(*args)
impl_rules[xla_call_p] = xla_call_impl impl_rules[xla_call_p] = xla_call_impl
@lru_cache() @lru_cache
def xla_callable(hashable_jaxpr: IDHashable, def xla_callable(hashable_jaxpr: IDHashable,
hashable_consts: tuple[IDHashable, ...]): hashable_consts: tuple[IDHashable, ...]):
jaxpr: Jaxpr = hashable_jaxpr.val jaxpr: Jaxpr = hashable_jaxpr.val
@ -1734,7 +1734,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
return primals_out, tangents_out return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache() @lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]: def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
def jvp_traceable(*primals_and_tangents): def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2 n = len(primals_and_tangents) // 2
@ -1755,7 +1755,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
return outs, [0] * len(outs) return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache() @lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...] def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
) -> tuple[Jaxpr, list[Any]]: ) -> tuple[Jaxpr, list[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
@ -2065,7 +2065,7 @@ be either known or unknown:
```{code-cell} ```{code-cell}
class PartialVal(NamedTuple): class PartialVal(NamedTuple):
aval: ShapedArray aval: ShapedArray
const: Optional[Any] const: Any | None
@classmethod @classmethod
def known(cls, val: Any): def known(cls, val: Any):
@ -2129,7 +2129,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
```{code-cell} ```{code-cell}
class PartialEvalTracer(Tracer): class PartialEvalTracer(Tracer):
pval: PartialVal pval: PartialVal
recipe: Optional[JaxprRecipe] recipe: JaxprRecipe | None
def __init__(self, trace, pval, recipe): def __init__(self, trace, pval, recipe):
self._trace = trace self._trace = trace
@ -2329,7 +2329,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
partial_eval_rules[xla_call_p] = xla_call_partial_eval partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool], def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
instantiate: Optional[list[bool]] = None, instantiate: list[bool] | None = None,
) -> tuple[Jaxpr, Jaxpr, list[bool], int]: ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
env: dict[Var, bool] = {} env: dict[Var, bool] = {}
residuals: set[Var] = set() residuals: set[Var] = set()
@ -2586,7 +2586,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
return [next(outs) if undef else None for undef in undef_primals] return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache() @lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...] def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
) -> tuple[Jaxpr, list[Any]]: ) -> tuple[Jaxpr, list[Any]]:
avals_in, avals_out = typecheck_jaxpr(jaxpr) avals_in, avals_out = typecheck_jaxpr(jaxpr)

View File

@ -138,15 +138,15 @@ def bind1(prim, *args, **params):
# + # +
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Any from typing import Any
class MainTrace(NamedTuple): class MainTrace(NamedTuple):
level: int level: int
trace_type: type['Trace'] trace_type: type['Trace']
global_data: Optional[Any] global_data: Any | None
trace_stack: list[MainTrace] = [] trace_stack: list[MainTrace] = []
dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 dynamic_trace: MainTrace | None = None # to be employed in Part 3
@contextmanager @contextmanager
def new_main(trace_type: type['Trace'], global_data=None): def new_main(trace_type: type['Trace'], global_data=None):
@ -697,7 +697,7 @@ class Store:
# + tags=["hide-input"] # + tags=["hide-input"]
from collections.abc import Hashable, Iterable, Iterator from collections.abc import Hashable, Iterable, Iterator
import itertools as it import itertools as it
from typing import Callable from collections.abc import Callable
class NodeType(NamedTuple): class NodeType(NamedTuple):
name: str name: str
@ -1297,7 +1297,7 @@ abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
# + # +
from functools import lru_cache from functools import lru_cache
@lru_cache() # ShapedArrays are hashable @lru_cache # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in): def make_jaxpr_v1(f, *avals_in):
avals_in, in_tree = tree_flatten(avals_in) avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree) f, out_tree = flatten_fun(f, in_tree)
@ -1412,7 +1412,7 @@ def new_dynamic(main: MainTrace):
finally: finally:
dynamic_trace = prev_dynamic_trace dynamic_trace = prev_dynamic_trace
@lru_cache() @lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray, def make_jaxpr(f: Callable, *avals_in: ShapedArray,
) -> tuple[Jaxpr, list[Any], PyTreeDef]: ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
avals_in, in_tree = tree_flatten(avals_in) avals_in, in_tree = tree_flatten(avals_in)
@ -1556,7 +1556,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
return execute(*args) return execute(*args)
impl_rules[xla_call_p] = xla_call_impl impl_rules[xla_call_p] = xla_call_impl
@lru_cache() @lru_cache
def xla_callable(hashable_jaxpr: IDHashable, def xla_callable(hashable_jaxpr: IDHashable,
hashable_consts: tuple[IDHashable, ...]): hashable_consts: tuple[IDHashable, ...]):
jaxpr: Jaxpr = hashable_jaxpr.val jaxpr: Jaxpr = hashable_jaxpr.val
@ -1728,7 +1728,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
return primals_out, tangents_out return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache() @lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]: def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
def jvp_traceable(*primals_and_tangents): def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2 n = len(primals_and_tangents) // 2
@ -1749,7 +1749,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
return outs, [0] * len(outs) return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache() @lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...] def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
) -> tuple[Jaxpr, list[Any]]: ) -> tuple[Jaxpr, list[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
@ -2057,7 +2057,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
class PartialVal(NamedTuple): class PartialVal(NamedTuple):
aval: ShapedArray aval: ShapedArray
const: Optional[Any] const: Any | None
@classmethod @classmethod
def known(cls, val: Any): def known(cls, val: Any):
@ -2121,7 +2121,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer): class PartialEvalTracer(Tracer):
pval: PartialVal pval: PartialVal
recipe: Optional[JaxprRecipe] recipe: JaxprRecipe | None
def __init__(self, trace, pval, recipe): def __init__(self, trace, pval, recipe):
self._trace = trace self._trace = trace
@ -2322,7 +2322,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
partial_eval_rules[xla_call_p] = xla_call_partial_eval partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool], def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
instantiate: Optional[list[bool]] = None, instantiate: list[bool] | None = None,
) -> tuple[Jaxpr, Jaxpr, list[bool], int]: ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
env: dict[Var, bool] = {} env: dict[Var, bool] = {}
residuals: set[Var] = set() residuals: set[Var] = set()
@ -2585,7 +2585,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
return [next(outs) if undef else None for undef in undef_primals] return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache() @lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...] def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
) -> tuple[Jaxpr, list[Any]]: ) -> tuple[Jaxpr, list[Any]]:
avals_in, avals_out = typecheck_jaxpr(jaxpr) avals_in, avals_out = typecheck_jaxpr(jaxpr)

View File

@ -14,11 +14,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import functools import functools
from functools import partial from functools import partial
import logging import logging
from typing import Any, Callable from typing import Any
import types import types
import numpy as np import numpy as np

View File

@ -13,8 +13,9 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import types import types
from typing import Any, Callable, TypeVar from typing import Any, TypeVar
from jax._src import core from jax._src import core
from jax._src import traceback_util from jax._src import traceback_util

View File

@ -23,12 +23,12 @@ arrays.
from __future__ import annotations from __future__ import annotations
import collections import collections
from collections.abc import Generator, Hashable, Iterable, Sequence from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from functools import partial, lru_cache from functools import partial, lru_cache
import inspect import inspect
import math import math
import typing import typing
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload, from typing import (Any, Literal, NamedTuple, TypeVar, overload,
cast) cast)
import weakref import weakref

View File

@ -14,11 +14,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
import inspect import inspect
import operator import operator
from functools import partial, lru_cache from functools import partial, lru_cache
from typing import Any, Callable, Type from typing import Any
import numpy as np import numpy as np
@ -713,6 +713,6 @@ class _HashableByObjectId:
def __eq__(self, other): def __eq__(self, other):
return self.val is other.val return self.val is other.val
def register_class_with_attrs(t: Type) -> None: def register_class_with_attrs(t: type) -> None:
_class_with_attrs.add(t) _class_with_attrs.add(t)
_class_with_attrs: set[Type] = set() _class_with_attrs: set[type] = set()

View File

@ -15,12 +15,12 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Callable, Sequence
import enum import enum
import functools import functools
import math import math
import operator as op import operator as op
from typing import Any, Callable, TYPE_CHECKING, cast from typing import Any, TYPE_CHECKING, cast
from jax._src import abstract_arrays from jax._src import abstract_arrays
from jax._src import api from jax._src import api

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Protocol, Sequence from collections.abc import Sequence
from typing import Any, Protocol
import jax import jax
from jax._src import random from jax._src import random
from jax._src.typing import Array, ArrayLike from jax._src.typing import Array, ArrayLike

View File

@ -14,11 +14,11 @@
"""Module for JAX callbacks.""" """Module for JAX callbacks."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
import functools import functools
import logging import logging
from typing import Any, Callable from typing import Any
import jax import jax
from jax._src import core from jax._src import core

View File

@ -13,11 +13,11 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
import functools import functools
import itertools as it import itertools as it
from typing import Callable, TypeVar, Any, Union from typing import TypeVar, Any, Union
import numpy as np import numpy as np

View File

@ -16,7 +16,6 @@ import os
from jax import version from jax import version
from jax._src import config from jax._src import config
from jax._src import hardware_utils from jax._src import hardware_utils
from typing import Optional
running_in_cloud_tpu_vm: bool = False running_in_cloud_tpu_vm: bool = False
@ -35,7 +34,7 @@ def maybe_import_libtpu():
return libtpu return libtpu
def get_tpu_library_path() -> Optional[str]: def get_tpu_library_path() -> str | None:
path_from_env = os.getenv("TPU_LIBRARY_PATH") path_from_env = os.getenv("TPU_LIBRARY_PATH")
if path_from_env is not None and os.path.isfile(path_from_env): if path_from_env is not None and os.path.isfile(path_from_env):
return path_from_env return path_from_env

View File

@ -21,7 +21,7 @@ import logging
import os import os
import tempfile import tempfile
import time import time
from typing import Any, Optional from typing import Any
import warnings import warnings
from jax._src import compilation_cache from jax._src import compilation_cache
@ -393,7 +393,7 @@ def _share_fdo_profiles(
backend: xc.Client, backend: xc.Client,
global_client: lib.xla_extension.DistributedRuntimeClient, global_client: lib.xla_extension.DistributedRuntimeClient,
min_process_id min_process_id
) -> Optional[bytes]: ) -> bytes | None:
sym_name = computation.operation.attributes['sym_name'] sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value module_name = ir.StringAttr(sym_name).value
fdo_profile = compile_options.executable_build_options.fdo_profile fdo_profile = compile_options.executable_build_options.fdo_profile

View File

@ -14,7 +14,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Hashable, Iterator, Sequence from collections.abc import Callable, Hashable, Iterator, Sequence
import contextlib import contextlib
import functools import functools
import itertools import itertools
@ -22,9 +22,7 @@ import logging
import os import os
import sys import sys
import threading import threading
from typing import ( from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast
Any, Callable, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast,
)
from jax._src import lib from jax._src import lib
from jax._src.lib import jax_jit from jax._src.lib import jax_jit

View File

@ -14,8 +14,8 @@
from __future__ import annotations from __future__ import annotations
from collections import Counter, defaultdict, deque, namedtuple from collections import Counter, defaultdict, deque, namedtuple
from collections.abc import (Collection, Generator, Hashable, Iterable, from collections.abc import (Callable, Collection, Generator, Hashable,
Iterator, Set, Sequence, MutableSet, Iterable, Iterator, Set, Sequence, MutableSet,
MutableMapping) MutableMapping)
from contextlib import contextmanager, ExitStack from contextlib import contextmanager, ExitStack
from dataclasses import dataclass from dataclasses import dataclass
@ -28,7 +28,7 @@ import math
import operator import operator
import threading import threading
import types import types
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar,
cast, overload, Union) cast, overload, Union)
import warnings import warnings
from weakref import ref from weakref import ref

View File

@ -15,7 +15,6 @@
from enum import Enum from enum import Enum
from functools import partial, reduce from functools import partial, reduce
import operator import operator
from typing import Optional
import json import json
import jax import jax
@ -927,10 +926,10 @@ _dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_atte
def dot_product_attention(query: Array, def dot_product_attention(query: Array,
key: Array, key: Array,
value: Array, value: Array,
bias: Optional[Array] = None, bias: Array | None = None,
mask: Optional[Array] = None, mask: Array | None = None,
q_seqlen: Optional[Array] = None, q_seqlen: Array | None = None,
kv_seqlen: Optional[Array] = None, kv_seqlen: Array | None = None,
*, *,
scale: float = 1.0, scale: float = 1.0,
mask_type: MaskType = MaskType.NO_MASK, mask_type: MaskType = MaskType.NO_MASK,

View File

@ -14,9 +14,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
import operator import operator
from typing import Callable
from jax import lax from jax import lax
from jax._src import api from jax._src import api

View File

@ -14,11 +14,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
from functools import update_wrapper, reduce, partial from functools import update_wrapper, reduce, partial
import inspect import inspect
from typing import Any, Callable, Generic, TypeVar from typing import Any, Generic, TypeVar
from jax._src import config from jax._src import config
from jax._src import core from jax._src import core

View File

@ -14,8 +14,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
from typing import Any, Callable from typing import Any
from jax._src import ad_util from jax._src import ad_util
from jax._src import api_util from jax._src import api_util

View File

@ -16,12 +16,12 @@
from __future__ import annotations from __future__ import annotations
import importlib.util import importlib.util
from collections.abc import Sequence from collections.abc import Callable, Sequence
import functools import functools
import logging import logging
import string import string
import sys import sys
from typing import Any, Callable, Union from typing import Any, Union
import weakref import weakref
import numpy as np import numpy as np

View File

@ -16,13 +16,13 @@
from __future__ import annotations from __future__ import annotations
import atexit import atexit
from collections.abc import Iterator, Sequence from collections.abc import Callable, Iterator, Sequence
import contextlib import contextlib
import dataclasses import dataclasses
from functools import partial from functools import partial
import itertools import itertools
import time import time
from typing import Any, Callable, NamedTuple from typing import Any, NamedTuple
import logging import logging
import threading import threading

View File

@ -17,13 +17,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import copy import copy
import dataclasses import dataclasses
import functools import functools
import itertools import itertools
import re import re
from typing import Any, Callable, Union from typing import Any, Union
import warnings import warnings
from absl import logging from absl import logging

View File

@ -16,9 +16,9 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, TypeVar from collections.abc import Callable, Sequence
from collections.abc import Sequence
from functools import partial from functools import partial
from typing import TypeVar
try: try:
import flatbuffers import flatbuffers

View File

@ -21,7 +21,7 @@ import flatbuffers
from flatbuffers.compat import import_numpy from flatbuffers.compat import import_numpy
np = import_numpy() np = import_numpy()
class PyTreeDefKind(object): class PyTreeDefKind:
leaf = 0 leaf = 0
none = 1 none = 1
tuple = 2 tuple = 2
@ -29,12 +29,12 @@ class PyTreeDefKind(object):
dict = 4 dict = 4
class AbstractValueKind(object): class AbstractValueKind:
shapedArray = 0 shapedArray = 0
abstractToken = 1 abstractToken = 1
class DType(object): class DType:
bool = 0 bool = 0
i8 = 1 i8 = 1
i16 = 2 i16 = 2
@ -60,18 +60,18 @@ class DType(object):
f0 = 22 f0 = 22
class ShardingKind(object): class ShardingKind:
unspecified = 0 unspecified = 0
hlo_sharding = 1 hlo_sharding = 1
class DisabledSafetyCheckKind(object): class DisabledSafetyCheckKind:
platform = 0 platform = 0
custom_call = 1 custom_call = 1
shape_assertions = 2 shape_assertions = 2
class PyTreeDef(object): class PyTreeDef:
__slots__ = ['_tab'] __slots__ = ['_tab']
@classmethod @classmethod
@ -163,7 +163,7 @@ def PyTreeDefEnd(builder):
class AbstractValue(object): class AbstractValue:
__slots__ = ['_tab'] __slots__ = ['_tab']
@classmethod @classmethod
@ -235,7 +235,7 @@ def AbstractValueEnd(builder):
class Sharding(object): class Sharding:
__slots__ = ['_tab'] __slots__ = ['_tab']
@classmethod @classmethod
@ -304,7 +304,7 @@ def ShardingEnd(builder):
class Effect(object): class Effect:
__slots__ = ['_tab'] __slots__ = ['_tab']
@classmethod @classmethod
@ -340,7 +340,7 @@ def EffectEnd(builder):
class DisabledSafetyCheck(object): class DisabledSafetyCheck:
__slots__ = ['_tab'] __slots__ = ['_tab']
@classmethod @classmethod
@ -386,7 +386,7 @@ def DisabledSafetyCheckEnd(builder):
class Exported(object): class Exported:
__slots__ = ['_tab'] __slots__ = ['_tab']
@classmethod @classmethod

View File

@ -19,7 +19,7 @@ See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html
from __future__ import annotations from __future__ import annotations
import enum import enum
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
from enum import Enum from enum import Enum
import functools import functools
@ -28,7 +28,7 @@ import io
import copy import copy
import operator as op import operator as op
import tokenize import tokenize
from typing import Any, Callable, Union, overload from typing import Any, Union, overload
import warnings import warnings
import numpy as np import numpy as np

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable from collections.abc import Callable, Hashable
from collections.abc import Hashable
from jax import Array from jax import Array

View File

@ -14,10 +14,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
import enum import enum
from typing import Callable
import numpy as np import numpy as np

View File

@ -70,13 +70,13 @@ then update `test_custom_call_coverage`, and then update your `test_foo_call`:
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
import dataclasses import dataclasses
import datetime import datetime
import os import os
import re import re
import sys import sys
from typing import Any, Callable from typing import Any
from absl import logging from absl import logging

View File

@ -38,11 +38,11 @@ to fail. A Limitation is specific to a harness.
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
import operator import operator
import os import os
from functools import partial from functools import partial
from typing import Any, Callable, NamedTuple, Union from typing import Any, NamedTuple, Union
from absl import testing from absl import testing
import numpy as np import numpy as np

View File

@ -14,12 +14,12 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import contextlib import contextlib
import functools import functools
import itertools as it import itertools as it
from functools import partial from functools import partial
from typing import Any, Callable from typing import Any
import jax import jax
from jax._src import config from jax._src import config

View File

@ -14,10 +14,10 @@
from __future__ import annotations from __future__ import annotations
import collections import collections
from collections.abc import Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
import dataclasses import dataclasses
from functools import partial from functools import partial
from typing import Any, Callable, Union from typing import Any, Union
import numpy as np import numpy as np

View File

@ -16,7 +16,7 @@
from __future__ import annotations from __future__ import annotations
import collections import collections
from collections.abc import Iterator, Sequence from collections.abc import Callable, Iterator, Sequence
import dataclasses import dataclasses
import functools import functools
from functools import partial from functools import partial
@ -27,7 +27,7 @@ import os
import re import re
import types import types
import typing import typing
from typing import Any, Callable, NamedTuple, Protocol, Union, cast as type_cast from typing import Any, NamedTuple, Protocol, Union, cast as type_cast
import warnings import warnings
import numpy as np import numpy as np

View File

@ -14,13 +14,13 @@
from __future__ import annotations from __future__ import annotations
from collections import namedtuple from collections import namedtuple
from collections.abc import Sequence, Hashable from collections.abc import Callable, Sequence, Hashable
from contextlib import contextmanager, AbstractContextManager from contextlib import contextmanager, AbstractContextManager
from functools import partial from functools import partial
import inspect import inspect
import itertools as it import itertools as it
import operator as op import operator as op
from typing import Any, Callable, NamedTuple, Union from typing import Any, NamedTuple, Union
from weakref import ref from weakref import ref
import numpy as np import numpy as np

View File

@ -19,15 +19,14 @@ import enum
from contextlib import contextmanager from contextlib import contextmanager
import collections import collections
from collections import namedtuple from collections import namedtuple
from collections.abc import Sequence, Iterable from collections.abc import Callable, Sequence, Iterable, Iterator
import dataclasses import dataclasses
from functools import partial, lru_cache, cached_property from functools import partial, lru_cache, cached_property
import itertools as it import itertools as it
import logging import logging
import math import math
import threading import threading
from typing import Any, Callable, NamedTuple, TypeVar, Union, cast from typing import Any, NamedTuple, TypeVar, Union, cast
from collections.abc import Iterator
import warnings import warnings
import numpy as np import numpy as np

View File

@ -17,12 +17,12 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
import functools import functools
from functools import partial from functools import partial
import itertools as it import itertools as it
from typing import Any, Callable, Protocol, Union from typing import Any, Protocol, Union
import numpy as np import numpy as np

View File

@ -17,11 +17,12 @@
from __future__ import annotations from __future__ import annotations
from collections import Counter, defaultdict from collections import Counter, defaultdict
from collections.abc import Callable
import gzip import gzip
import itertools import itertools
import json import json
import types import types
from typing import Any, Callable, Union from typing import Any, Union
from jax._src import core from jax._src import core
from jax._src import util from jax._src import util

View File

@ -15,10 +15,10 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import os import os
from functools import partial from functools import partial
from typing import Any, Callable from typing import Any
from jax._src import core from jax._src import core
from jax._src import linear_util as lu from jax._src import linear_util as lu

View File

@ -15,13 +15,13 @@
from __future__ import annotations from __future__ import annotations
import collections import collections
from collections.abc import Sequence from collections.abc import Callable, Sequence
import functools import functools
from functools import partial from functools import partial
import inspect import inspect
import itertools import itertools
import operator import operator
from typing import Any, Callable, TypeVar from typing import Any, TypeVar
import jax import jax
from jax.tree_util import tree_flatten, tree_unflatten from jax.tree_util import tree_flatten, tree_unflatten

View File

@ -15,10 +15,10 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import functools import functools
import operator import operator
from typing import Any, Callable, Generic, TypeVar from typing import Any, Generic, TypeVar
import jax.numpy as jnp import jax.numpy as jnp
from jax import lax from jax import lax

View File

@ -14,12 +14,12 @@
"""Module for the loop primitives.""" """Module for the loop primitives."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
import inspect import inspect
import itertools import itertools
import operator import operator
from typing import Any, Callable, TypeVar from typing import Any, TypeVar
import weakref import weakref
import jax import jax

View File

@ -15,14 +15,14 @@
from __future__ import annotations from __future__ import annotations
import builtins import builtins
from collections.abc import Sequence from collections.abc import Callable, Sequence
import enum import enum
import functools import functools
from functools import partial from functools import partial
import itertools import itertools
import math import math
import operator import operator
from typing import Any, Callable, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING from typing import Any, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
import warnings import warnings
import numpy as np import numpy as np
@ -2986,10 +2986,10 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
m, k = lhs.shape m, k = lhs.shape
group_count, rk, n = rhs.shape group_count, rk, n = rhs.shape
if k != rk: if k != rk:
raise TypeError("ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {} and {}.".format(k, rk)) raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
num_groups = group_sizes.shape[0] num_groups = group_sizes.shape[0]
if group_count != num_groups: if group_count != num_groups:
raise TypeError("ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {} and {}.".format(group_count, num_groups)) raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
return (m, n) return (m, n)
# DotDimensionNumbers used in the dot_general call for ragged_dot(). # DotDimensionNumbers used in the dot_general call for ragged_dot().

View File

@ -14,10 +14,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
from functools import partial from functools import partial
import math import math
from typing import Any, Callable, Literal, TypeVar, overload from typing import Any, Literal, TypeVar, overload
import numpy as np import numpy as np

View File

@ -14,12 +14,12 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import enum import enum
import operator import operator
from functools import partial from functools import partial
import math import math
from typing import Callable, NamedTuple from typing import NamedTuple
import weakref import weakref
import numpy as np import numpy as np

View File

@ -14,9 +14,8 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
from typing import Callable
import warnings import warnings
from jax import tree_util from jax import tree_util

View File

@ -14,9 +14,9 @@
"""A LazyLoader class.""" """A LazyLoader class."""
from collections.abc import Sequence from collections.abc import Callable, Sequence
import importlib import importlib
from typing import Any, Callable from typing import Any
def attach(package_name: str, submodules: Sequence[str]) -> tuple[ def attach(package_name: str, submodules: Sequence[str]) -> tuple[

View File

@ -63,8 +63,9 @@ data must be immutable, because it will be stored in function memoization tables
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from functools import partial from functools import partial
from typing import Any, Callable, NamedTuple from typing import Any, NamedTuple
import weakref import weakref
from jax._src import config from jax._src import config

View File

@ -15,12 +15,12 @@
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict, abc from collections import OrderedDict, abc
from collections.abc import Iterable, Sequence, Mapping from collections.abc import Callable, Iterable, Sequence, Mapping
import contextlib import contextlib
from functools import wraps, partial, partialmethod, lru_cache from functools import wraps, partial, partialmethod, lru_cache
import itertools as it import itertools as it
import math import math
from typing import Callable, Any, NamedTuple, Union, cast as type_cast from typing import Any, NamedTuple, Union, cast as type_cast
import numpy as np import numpy as np

View File

@ -27,13 +27,13 @@ from __future__ import annotations
import builtins import builtins
import collections import collections
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
import importlib import importlib
import math import math
import operator import operator
import types import types
from typing import (cast, overload, Any, Callable, Literal, NamedTuple, from typing import (cast, overload, Any, Literal, NamedTuple,
Protocol, TypeVar, Union) Protocol, TypeVar, Union)
from textwrap import dedent as _dedent from textwrap import dedent as _dedent
import warnings import warnings

View File

@ -15,11 +15,11 @@
from __future__ import annotations from __future__ import annotations
import builtins import builtins
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
import math import math
import operator import operator
from typing import overload, Any, Callable, Literal, Protocol, Union from typing import overload, Any, Literal, Protocol, Union
import warnings import warnings
import numpy as np import numpy as np

View File

@ -16,10 +16,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from functools import partial from functools import partial
import math import math
import operator import operator
from typing import Any, Callable from typing import Any
import jax import jax
from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.typing import Array, ArrayLike, DTypeLike

View File

@ -18,9 +18,9 @@ Implements ufuncs for jax.numpy.
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from functools import partial from functools import partial
import operator import operator
from typing import Callable
import warnings import warnings

View File

@ -13,11 +13,11 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
import re import re
import textwrap import textwrap
from typing import Any, Callable, NamedTuple, TypeVar from typing import Any, NamedTuple, TypeVar
import warnings import warnings

View File

@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from collections.abc import Collection, Sequence from collections.abc import Callable, Collection, Sequence
import functools import functools
import re import re
from typing import Any, Callable from typing import Any
from jax._src import api from jax._src import api
from jax import lax from jax import lax

View File

@ -16,8 +16,8 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
from typing import Callable, Union from typing import Union
import warnings import warnings
import numpy as np import numpy as np

View File

@ -15,13 +15,13 @@
"""Module for pallas-core functionality.""" """Module for pallas-core functionality."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterator, Sequence from collections.abc import Callable, Iterator, Sequence
import copy import copy
import contextlib import contextlib
import dataclasses import dataclasses
import functools import functools
import threading import threading
from typing import Any, Callable, Union from typing import Any, Union
import jax import jax
from jax._src import api_util from jax._src import api_util

View File

@ -78,7 +78,7 @@ class AbstractSemaphoreTy(dtypes.ExtendedDType):
return self.__class__ == other.__class__ return self.__class__ == other.__class__
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((self.__class__)) return hash(self.__class__)
# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy # TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy
@ -109,7 +109,7 @@ class SemaphoreType(enum.Enum):
dtype = SemaphoreTy() dtype = SemaphoreTy()
return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)
def get_aval(self) -> "AbstractMemoryRef": def get_aval(self) -> AbstractMemoryRef:
return self(()).get_aval() return self(()).get_aval()
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)

View File

@ -15,11 +15,11 @@
"""Module for lowering JAX to Mosaic-compatible MLIR dialects.""" """Module for lowering JAX to Mosaic-compatible MLIR dialects."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
import functools import functools
import string import string
from typing import Any, Callable from typing import Any
import jax import jax
from jax import core as jax_core from jax import core as jax_core

View File

@ -15,12 +15,13 @@
"""Module for emitting custom TPU pipelines within a Pallas call.""" """Module for emitting custom TPU pipelines within a Pallas call."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence
import dataclasses import dataclasses
import enum import enum
import functools import functools
import itertools import itertools
import operator import operator
from typing import Optional, Union, Any, Sequence from typing import Union, Any
import jax import jax
from jax import lax from jax import lax
@ -201,12 +202,12 @@ class BufferedRef:
spec: pl.BlockSpec # static metadata spec: pl.BlockSpec # static metadata
dtype: Any # static metadata dtype: Any # static metadata
buffer_type: BufferType # static metadata buffer_type: BufferType # static metadata
vmem_ref: Optional[REF] vmem_ref: REF | None
accum_ref: Optional[REF] accum_ref: REF | None
current_slot: Optional[ArrayRef] current_slot: ArrayRef | None
next_slot: Optional[ArrayRef] next_slot: ArrayRef | None
sem_recv: Optional[SemaphoreType] sem_recv: SemaphoreType | None
sem_send: Optional[SemaphoreType] sem_send: SemaphoreType | None
def tree_flatten(self): def tree_flatten(self):
return ((self.vmem_ref, self.accum_ref, self.current_slot, return ((self.vmem_ref, self.accum_ref, self.current_slot,
@ -218,7 +219,7 @@ class BufferedRef:
return cls(*meta, *data) return cls(*meta, *data)
@classmethod @classmethod
def create(cls, spec, dtype, buffer_type) -> 'BufferedRef': def create(cls, spec, dtype, buffer_type) -> BufferedRef:
"""Create a BufferedRef. """Create a BufferedRef.
Args: Args:
@ -810,9 +811,9 @@ def _partition_grid(
if isinstance(grid[i], int) and grid[i] % num_cores == 0 if isinstance(grid[i], int) and grid[i] % num_cores == 0
} }
if divisible_dimensions: if divisible_dimensions:
first_divisible_dimension, *_ = [ first_divisible_dimension, *_ = (
i for i in range(len(dimension_semantics)) if i in divisible_dimensions i for i in range(len(dimension_semantics)) if i in divisible_dimensions
] )
partitioned_dim_size = grid[first_divisible_dimension] // num_cores partitioned_dim_size = grid[first_divisible_dimension] // num_cores
partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size
new_grid = jax_util.tuple_update( new_grid = jax_util.tuple_update(
@ -828,11 +829,11 @@ def _partition_grid(
# potentially divide it more evenly # potentially divide it more evenly
largest_parallel_dimension = max(grid[i] for i in parallel_dimensions largest_parallel_dimension = max(grid[i] for i in parallel_dimensions
if isinstance(grid[i], int)) # type: ignore if isinstance(grid[i], int)) # type: ignore
partition_dimension, *_ = [ partition_dimension, *_ = (
i i
for i, d in enumerate(grid) for i, d in enumerate(grid)
if isinstance(d, int) and d == largest_parallel_dimension if isinstance(d, int) and d == largest_parallel_dimension
] )
base_num_iters, rem = divmod(grid[partition_dimension], num_cores) base_num_iters, rem = divmod(grid[partition_dimension], num_cores)
assert rem > 0, rem assert rem > 0, rem
# We have some remainder iterations that we need to assign somewhere. We # We have some remainder iterations that we need to assign somewhere. We

View File

@ -15,9 +15,10 @@
"""Module for Pallas:TPU-specific JAX primitives and functions.""" """Module for Pallas:TPU-specific JAX primitives and functions."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import dataclasses import dataclasses
import enum import enum
from typing import Any, Callable from typing import Any
import jax import jax
from jax._src import api_util from jax._src import api_util

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, Optional
from collections.abc import Callable
import jax import jax
import numpy as np import numpy as np
@ -172,7 +173,7 @@ def sample_block(sampler_fn: SampleFnType,
block_size: Shape, block_size: Shape,
tile_size: Shape, tile_size: Shape,
total_size: Shape, total_size: Shape,
block_index: Optional[tuple[typing.ArrayLike, ...]] = None, block_index: tuple[typing.ArrayLike, ...] | None = None,
**kwargs) -> jax.Array: **kwargs) -> jax.Array:
"""Samples a block of random values with invariance guarantees. """Samples a block of random values with invariance guarantees.

View File

@ -15,10 +15,10 @@
"""Module for calling pallas functions from JAX.""" """Module for calling pallas functions from JAX."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial, reduce from functools import partial, reduce
import itertools import itertools
from typing import Any, Callable from typing import Any
import jax import jax
from jax import api_util from jax import api_util

View File

@ -16,12 +16,12 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
import functools import functools
import math import math
import operator import operator
from typing import Any, Callable, TypeVar from typing import Any, TypeVar
import jax import jax
from jax import lax from jax import lax

View File

@ -15,7 +15,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence, Iterable from collections.abc import Callable, Sequence, Iterable
import dataclasses import dataclasses
from functools import partial from functools import partial
import inspect import inspect
@ -23,7 +23,7 @@ import itertools as it
import logging import logging
import operator as op import operator as op
import weakref import weakref
from typing import Callable, NamedTuple, Any, Union, Optional, cast from typing import NamedTuple, Any, Union, cast
import threading import threading
import warnings import warnings
@ -245,7 +245,7 @@ def _need_to_rebuild_with_fdo(pgle_profiler):
def _get_fastpath_data( def _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, effects, executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
consts, abstracted_axes, pgle_profiler consts, abstracted_axes, pgle_profiler
) -> Optional[pxla.MeshExecutableFastpathData]: ) -> pxla.MeshExecutableFastpathData | None:
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
use_fastpath = ( use_fastpath = (
@ -608,7 +608,7 @@ def _infer_params_impl(
assert None not in in_shardings_leaves assert None not in in_shardings_leaves
assert None not in out_shardings_leaves assert None not in out_shardings_leaves
in_type: Union[core.InputType, tuple[core.AbstractValue, ...]] in_type: core.InputType | tuple[core.AbstractValue, ...]
if config.dynamic_shapes.value: if config.dynamic_shapes.value:
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
in_avals = tuple(a for a, e in in_type if e) in_avals = tuple(a for a, e in in_type if e)

View File

@ -13,11 +13,11 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterator, Sequence from collections.abc import Callable, Iterator, Sequence
from functools import partial, reduce from functools import partial, reduce
import math import math
import operator as op import operator as op
from typing import Any, Callable, NamedTuple from typing import Any, NamedTuple
import numpy as np import numpy as np

View File

@ -14,6 +14,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
import glob import glob
@ -24,7 +25,7 @@ import logging
import os import os
import socketserver import socketserver
import threading import threading
from typing import Callable, List, Optional, Union, Any from typing import Any
from jax._src import traceback_util from jax._src import traceback_util
traceback_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__)
@ -210,7 +211,7 @@ def stop_trace():
_profile_state.reset() _profile_state.reset()
def stop_and_get_fdo_profile() -> Union[bytes, str]: def stop_and_get_fdo_profile() -> bytes | str:
"""Stops the currently-running profiler trace and export fdo_profile. """Stops the currently-running profiler trace and export fdo_profile.
Currently, this is only supported for GPU. Currently, this is only supported for GPU.
@ -391,10 +392,10 @@ class PGLEProfiler:
self.percentile: int = percentile self.percentile: int = percentile
self.collected_fdo: str | None = None self.collected_fdo: str | None = None
self.called_times: int = 0 self.called_times: int = 0
self.fdo_profiles: List[Any] = [] self.fdo_profiles: list[Any] = []
self.current_session: xla_client.profiler.ProfilerSession | None = None self.current_session: xla_client.profiler.ProfilerSession | None = None
def consume_fdo_profile(self) -> Optional[str]: def consume_fdo_profile(self) -> str | None:
if self.collected_fdo is not None: if self.collected_fdo is not None:
return self.collected_fdo return self.collected_fdo

View File

@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Sequence from collections.abc import Callable, Sequence
import functools import functools
import itertools import itertools
import operator import operator
from typing import Callable
from jax._src import api from jax._src import api
from jax._src import util from jax._src import util

View File

@ -15,8 +15,9 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, NamedTuple from collections.abc import Callable
from functools import partial from functools import partial
from typing import NamedTuple
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp

View File

@ -15,8 +15,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from functools import partial from functools import partial
from typing import Callable, NamedTuple from typing import NamedTuple
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp

View File

@ -14,8 +14,8 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Callable, Mapping
from typing import Any, Callable from typing import Any
import jax import jax
from jax._src.scipy.optimize.bfgs import minimize_bfgs from jax._src.scipy.optimize.bfgs import minimize_bfgs

View File

@ -14,11 +14,10 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
import math import math
import operator import operator
from typing import Callable
import warnings import warnings
import numpy as np import numpy as np

View File

@ -18,9 +18,10 @@ An implementation of sourcemaps following `TC39 <https://tc39.es/source-map>`_.
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
import json import json
from typing import Iterable, Sequence, Union from typing import Union
# A Segment encodes how parts in the generated source relate to the original source. # A Segment encodes how parts in the generated source relate to the original source.
# Each segment is made up of 1, 4 or 5 variable-length fields. For their semantics see # Each segment is made up of 1, 4 or 5 variable-length fields. For their semantics see

View File

@ -315,7 +315,7 @@ class XlaLowering(Lowering):
def hlo(self) -> xc.XlaComputation: def hlo(self) -> xc.XlaComputation:
"""Return an HLO representation of this computation.""" """Return an HLO representation of this computation."""
hlo = self.stablehlo() hlo = self.stablehlo()
m: Union[str, bytes] m: str | bytes
m = mlir.module_to_bytecode(hlo) m = mlir.module_to_bytecode(hlo)
return xla_extension.mlir.mlir_module_to_xla_computation( return xla_extension.mlir.mlir_module_to_xla_computation(
m, use_tuple_args=self.compile_args["tuple_args"]) m, use_tuple_args=self.compile_args["tuple_args"])

View File

@ -14,11 +14,11 @@
"""Module for discharging state primitives.""" """Module for discharging state primitives."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
from functools import partial from functools import partial
import operator import operator
from typing import Any, Callable, Protocol from typing import Any, Protocol
import numpy as np import numpy as np

View File

@ -65,9 +65,9 @@ class Slice:
@classmethod @classmethod
def tree_unflatten(cls, aux_data, children) -> Slice: def tree_unflatten(cls, aux_data, children) -> Slice:
start, size = [ start, size = (
a if a is not None else b for a, b in zip(children, aux_data[:2]) a if a is not None else b for a, b in zip(children, aux_data[:2])
] )
return cls(start, size, aux_data[2]) return cls(start, size, aux_data[2])
@classmethod @classmethod

View File

@ -16,7 +16,7 @@
from __future__ import annotations from __future__ import annotations
import collections import collections
from collections.abc import Generator, Iterable, Sequence from collections.abc import Callable, Generator, Iterable, Sequence
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
import datetime import datetime
import functools import functools
@ -28,7 +28,7 @@ import re
import sys import sys
import tempfile import tempfile
import textwrap import textwrap
from typing import Any, Callable from typing import Any
import unittest import unittest
import warnings import warnings
import zlib import zlib

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable from collections.abc import Callable
from jax import jit, lax from jax import jit, lax
import jax.numpy as jnp import jax.numpy as jnp

View File

@ -19,13 +19,13 @@ from __future__ import annotations
import base64 import base64
import collections.abc import collections.abc
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
import functools import functools
import io import io
import os import os
import time import time
from typing import Any, Callable from typing import Any
import jax import jax
from jax import core from jax import core

View File

@ -14,12 +14,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
import os import os
import sys import sys
import traceback import traceback
import types import types
from typing import Any, Callable, TypeVar, cast from typing import Any, TypeVar, cast
from jax._src import config from jax._src import config
from jax._src import util from jax._src import util

View File

@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Iterable, TypeVar, overload from collections.abc import Callable, Iterable
from typing import Any, TypeVar, overload
from jax._src import tree_util from jax._src import tree_util

View File

@ -14,14 +14,14 @@
from __future__ import annotations from __future__ import annotations
import collections import collections
from collections.abc import Hashable, Iterable from collections.abc import Callable, Hashable, Iterable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
import difflib import difflib
import functools import functools
from functools import partial from functools import partial
import operator as op import operator as op
import textwrap import textwrap
from typing import Any, Callable, NamedTuple, Sequence, TypeVar, Union, overload from typing import Any, NamedTuple, TypeVar, Union, overload
from jax._src import traceback_util from jax._src import traceback_util
from jax._src.lib import pytree from jax._src.lib import pytree

View File

@ -15,14 +15,14 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
from collections.abc import Iterable, Iterator, Sequence from collections.abc import Callable, Iterable, Iterator, Sequence
import dataclasses import dataclasses
import functools import functools
from functools import partial from functools import partial
import itertools as it import itertools as it
import logging import logging
import operator import operator
from typing import (Any, Callable, Generic, TypeVar, overload, TYPE_CHECKING, cast) from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast)
import weakref import weakref
import numpy as np import numpy as np

View File

@ -21,7 +21,7 @@ XLA. There are also a handful of related casting utilities.
from __future__ import annotations from __future__ import annotations
import atexit import atexit
from collections.abc import Mapping from collections.abc import Callable, Mapping
import dataclasses import dataclasses
from functools import lru_cache, partial from functools import lru_cache, partial
import importlib import importlib
@ -32,7 +32,7 @@ import pkgutil
import platform as py_platform import platform as py_platform
import threading import threading
import traceback import traceback
from typing import Any, Callable, Union from typing import Any, Union
import warnings import warnings
from jax._src import config from jax._src import config

View File

@ -91,7 +91,8 @@ Example Usage:
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, NamedTuple from collections.abc import Callable
from typing import Any, NamedTuple
from collections import namedtuple from collections import namedtuple
import functools import functools

View File

@ -15,7 +15,6 @@
from __future__ import annotations from __future__ import annotations
import jax import jax
from typing import Tuple
from jax._src.sharding import Sharding from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config from jax._src import dtypes as _dtypes, config
@ -71,7 +70,7 @@ class __array_namespace_info__:
def dtypes( def dtypes(
self, *, self, *,
device: xc.Device | Sharding | None = None, device: xc.Device | Sharding | None = None,
kind: str | Tuple[str, ...] | None = None): kind: str | tuple[str, ...] | None = None):
# Array API supported dtypes are device-independent in JAX # Array API supported dtypes are device-independent in JAX
del device del device
data_types = self._build_dtype_dict() data_types = self._build_dtype_dict()

View File

@ -17,16 +17,15 @@ from __future__ import annotations
import abc import abc
import asyncio import asyncio
from collections.abc import Awaitable, Sequence from collections.abc import Awaitable, Callable, Sequence
from functools import partial from functools import partial
import itertools import itertools
import logging import logging
import os import os
import re import re
import sys
import threading import threading
import time import time
from typing import Any, Callable, Optional, Union from typing import Any
import jax import jax
from jax._src import array from jax._src import array
@ -130,7 +129,7 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
return spec return spec
def is_remote_storage(tspec: Union[dict[str, Any], str]) -> bool: def is_remote_storage(tspec: dict[str, Any] | str) -> bool:
"""Detect if user is using cloud storages. """Detect if user is using cloud storages.
This can detect common defines and unable to detect some corner cases such as This can detect common defines and unable to detect some corner cases such as
@ -190,7 +189,7 @@ async def async_serialize(
tensorstore_spec, tensorstore_spec,
commit_future=None, commit_future=None,
context=TS_CONTEXT, context=TS_CONTEXT,
primary_host: Optional[int] = 0, primary_host: int | None = 0,
replica_id: int = 0, replica_id: int = 0,
): ):
"""Serialize an array using TensorStore. """Serialize an array using TensorStore.

View File

@ -503,14 +503,14 @@ from __future__ import annotations
import atexit import atexit
import enum import enum
from collections.abc import Sequence from collections.abc import Callable, Sequence
import functools import functools
import itertools import itertools
import logging import logging
import math import math
import threading import threading
import traceback import traceback
from typing import Any, Callable, cast from typing import Any, cast
import jax import jax
from jax._src import api from jax._src import api

View File

@ -25,10 +25,10 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#callin
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
import functools import functools
from typing import Any, Callable, Optional from typing import Any
from absl import logging from absl import logging
import jax import jax

View File

@ -21,12 +21,12 @@ See README.md for how these are used.
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import functools import functools
import logging import logging
import re import re
import time import time
from typing import Any, Callable, Optional from typing import Any
from absl import flags from absl import flags
import flax import flax

View File

@ -26,8 +26,8 @@ customize this function as needed.
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
from typing import Any, Callable, Optional, Union from typing import Any
from jax.experimental import jax2tf from jax.experimental import jax2tf
import tensorflow as tf import tensorflow as tf

View File

@ -16,12 +16,12 @@
from __future__ import annotations from __future__ import annotations
import builtins import builtins
from collections.abc import Sequence from collections.abc import Callable, Sequence
import dataclasses import dataclasses
from functools import partial, wraps from functools import partial, wraps
import math import math
import string import string
from typing import Any, Callable, Optional from typing import Any
from jax._src import core from jax._src import core
from jax import lax from jax import lax

View File

@ -15,7 +15,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from functools import partial from functools import partial
import contextlib import contextlib
import math import math
@ -23,7 +23,7 @@ import operator
import os import os
import re import re
import threading import threading
from typing import Any, Callable, Union from typing import Any, Union
import warnings import warnings
from absl import logging from absl import logging

View File

@ -20,11 +20,10 @@ these tests.
from __future__ import annotations from __future__ import annotations
import base64 import base64
from collections.abc import Sequence from collections.abc import Callable, Sequence
import io import io
import os import os
import tarfile import tarfile
from typing import Callable, Optional
from absl.testing import absltest from absl.testing import absltest
import jax import jax

View File

@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
"""Tests for call_tf.""" """Tests for call_tf."""
from collections.abc import Callable
import contextlib import contextlib
from functools import partial from functools import partial
import os import os
from typing import Callable
import unittest import unittest
from absl import logging from absl import logging

View File

@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Converters for jax2tf.""" """Converters for jax2tf."""
from collections.abc import Callable
import dataclasses import dataclasses
import functools import functools
import tempfile import tempfile
from typing import Any, Callable from typing import Any
from jax.experimental import jax2tf from jax.experimental import jax2tf
import tensorflow as tf import tensorflow as tf
import tensorflowjs as tfjs import tensorflowjs as tfjs

View File

@ -26,12 +26,11 @@ currently saved file with the saved one.
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Callable, Sequence
import contextlib import contextlib
import dataclasses import dataclasses
import os import os
import re import re
from typing import Callable, Optional
import zlib import zlib
from absl import app from absl import app

View File

@ -18,8 +18,9 @@ https://github.com/google/flax/tree/main/examples/sst2
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
from typing import Any, Callable, Optional from typing import Any
from flax import linen as nn from flax import linen as nn
import jax import jax

View File

@ -16,8 +16,7 @@
https://github.com/google/flax/tree/main/examples/ogbg_molpcba https://github.com/google/flax/tree/main/examples/ogbg_molpcba
""" """
from collections.abc import Sequence from collections.abc import Callable, Sequence
from typing import Callable
from flax import linen as nn from flax import linen as nn

View File

@ -19,9 +19,9 @@ https://github.com/google/flax/tree/main/examples/imagenet
# See issue #620. # See issue #620.
# pytype: disable=wrong-arg-count # pytype: disable=wrong-arg-count
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
from typing import Any, Callable from typing import Any
from flax import linen as nn from flax import linen as nn
import jax.numpy as jnp import jax.numpy as jnp

Some files were not shown because too many files have changed in this diff Show More