Use lower-case PEP 585 names for types.

Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
Peter Hawkins authored 2023-06-23 15:11:37 -07:00, committed by jax authors
parent f67acee129
commit 816ba91263
148 changed files with 1492 additions and 1526 deletions
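
The change pattern is the same in every file below: the capitalized generic aliases from `typing` (`List`, `Tuple`, `Dict`, `Set`, `FrozenSet`, `Type`) are replaced with the lower-case built-in types that PEP 585 (Python 3.9+) accepts directly in annotations, and the now-unused imports are dropped. A minimal before/after sketch of the pattern (the `summarize_*` functions are hypothetical, not taken from this commit):

```python
# Before: container generics imported from typing.
from typing import Dict, List, Optional, Tuple

def summarize_old(entries: List[str]) -> Tuple[Dict[str, int], Optional[str]]:
    """Count occurrences and report the first entry, if any."""
    counts: Dict[str, int] = {}
    for e in entries:
        counts[e] = counts.get(e, 0) + 1
    return counts, (entries[0] if entries else None)

# After: lower-case built-ins used as generics (PEP 585, Python 3.9+);
# non-container helpers such as Optional/Union/Callable still come from typing.
from typing import Optional

def summarize_new(entries: list[str]) -> tuple[dict[str, int], Optional[str]]:
    """Count occurrences and report the first entry, if any."""
    counts: dict[str, int] = {}
    for e in entries:
        counts[e] = counts.get(e, 0) + 1
    return counts, (entries[0] if entries else None)
```

Annotations guarded by `from __future__ import annotations` (which several files below already use) are not evaluated at runtime, but module-level aliases such as `Shape = tuple[int, ...]` are, so the new spelling presumes a Python 3.9+ runtime.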


@ -15,7 +15,6 @@
import argparse
import os
from typing import List
ISSUE_FORMAT = """\
<details><summary>Failure summary {name}</summary>
@ -27,7 +26,7 @@ ISSUE_FORMAT = """\
</details>
"""
def main(logfiles: List[str], outfile: str):
def main(logfiles: list[str], outfile: str):
print(f"extracting content of {logfiles}")
print(f"and writing to {outfile}")
with open(outfile, 'w') as f:


@ -536,9 +536,9 @@
}
],
"source": [
"from typing import Tuple, Iterable\n",
"from typing import Iterable\n",
"\n",
"def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:\n",
"def flatten_MyContainer(container) -> tuple[Iterable[int], str]:\n",
" \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
" flat_contents = [container.a, container.b, container.c]\n",
"\n",
@ -593,7 +593,7 @@
"class MyKeyPathContainer(MyContainer):\n",
" pass\n",
"\n",
"def flatten_with_keys_MyKeyPathContainer(container) -> Tuple[Iterable[int], str]:\n",
"def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]:\n",
" \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
" \n",
" # GetAttrKey is a common way to express an attribute key. Users are free\n",


@ -277,9 +277,9 @@ except TypeError as e:
To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:
```{code-cell} ipython3
from typing import Tuple, Iterable
from typing import Iterable
def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
def flatten_MyContainer(container) -> tuple[Iterable[int], str]:
"""Returns an iterable over container contents, and aux data."""
flat_contents = [container.a, container.b, container.c]
@ -312,7 +312,7 @@ Alternatively, using the key path API mentioned above, you can register this con
class MyKeyPathContainer(MyContainer):
pass
def flatten_with_keys_MyKeyPathContainer(container) -> Tuple[Iterable[int], str]:
def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]:
"""Returns an iterable over container contents, and aux data."""
# GetAttrKey is a common way to express an attribute key. Users are free


@ -541,7 +541,7 @@
},
"outputs": [],
"source": [
"from typing import NamedTuple, Tuple\n",
"from typing import NamedTuple\n",
"import functools\n",
"\n",
"class Params(NamedTuple):\n",
@ -571,7 +571,7 @@
"# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it\n",
"# 'num_devices', but could have used anything, so long as `pmean` used the same.\n",
"@functools.partial(jax.pmap, axis_name='num_devices')\n",
"def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:\n",
"def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]:\n",
" \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n",
"\n",
" # Compute the gradients on the given minibatch (individually on each device).\n",


@ -219,7 +219,7 @@ If this example is too confusing, you can find the same example, but without par
```{code-cell} ipython3
:id: cI8xQqzRrc-4
from typing import NamedTuple, Tuple
from typing import NamedTuple
import functools
class Params(NamedTuple):
@ -249,7 +249,7 @@ LEARNING_RATE = 0.005
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]:
"""Performs one SGD update step on params using the given data."""
# Compute the gradients on the given minibatch (individually on each device).


@ -161,13 +161,11 @@
}
],
"source": [
"from typing import Tuple\n",
"\n",
"CounterState = int\n",
"\n",
"class CounterV2:\n",
"\n",
" def count(self, n: CounterState) -> Tuple[int, CounterState]:\n",
" def count(self, n: CounterState) -> tuple[int, CounterState]:\n",
" # You could just return n+1, but here we separate its role as \n",
" # the output and as the counter state for didactic purposes.\n",
" return n+1, n+1\n",


@ -102,13 +102,11 @@ Part of the problem with our counter was that the returned value didn't depend o
:id: 53pSdK4KoOEZ
:outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79
from typing import Tuple
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> Tuple[int, CounterState]:
def count(self, n: CounterState) -> tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1


@ -1141,10 +1141,8 @@
},
"source": [
"```python\n",
"from typing import Tuple, List\n",
"\n",
"LayerParam = Tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer\n",
"ParamsList = List[LayerParam]\n",
"LayerParam = tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer\n",
"ParamsList = list[LayerParam]\n",
"\n",
"def net(params: ParamsList, x: jnp.ndarray):\n",
" for W, b in params:\n",


@ -497,10 +497,8 @@ For example, one common pattern in large [Transformer models](https://en.wikiped
+++ {"id": "BUeqKFRS5yPU"}
```python
from typing import Tuple, List
LayerParam = Tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer
ParamsList = List[LayerParam]
LayerParam = tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer
ParamsList = list[LayerParam]
def net(params: ParamsList, x: jnp.ndarray):
for W, b in params:


@ -15,7 +15,6 @@
from __future__ import annotations
from functools import partial
from typing import Set
import numpy as np
@ -40,7 +39,7 @@ def zeros_like_array(x):
aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
return ad_util.zeros_like_aval(aval)
numpy_scalar_types: Set[type] = { # pylint: disable=g-bare-generic
numpy_scalar_types: set[type] = { # pylint: disable=g-bare-generic
np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.complex64, np.complex128,
@ -52,7 +51,7 @@ if dtypes.int4 is not None:
if dtypes.uint4 is not None:
numpy_scalar_types.add(dtypes.uint4)
array_types: Set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic
array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic
def canonical_concrete_aval(val, weak_type=None):
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,


@ -15,8 +15,7 @@
import functools
from functools import partial
import logging
from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
Union)
from typing import Any, Callable, Optional, Sequence, Union
import types
import numpy as np
@ -134,7 +133,7 @@ checkpoint_policies = types.SimpleNamespace(
@api_boundary
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
policy: Optional[Callable[..., bool]] = None,
static_argnums: Union[int, Tuple[int, ...]] = (),
static_argnums: Union[int, tuple[int, ...]] = (),
) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.
@ -348,14 +347,14 @@ class WrapHashably:
# See api_benchmark.py:bench_remat_eager_retracing_overheads_static_argnums.
# On that benchmark, including this caching makes a ~10x difference (which can
# be made arbitrary large by involving larger functions to be traced).
def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
static_args: Tuple[WrapHashably, ...], nargs: int):
def _dyn_args_fun(fun: Callable, static_argnums: frozenset[int],
static_args: tuple[WrapHashably, ...], nargs: int):
if any(isinstance(x.val, core.Tracer) for x in static_args):
return _dyn_args_fun_uncached(fun, static_argnums, static_args, nargs)
return _dyn_args_fun_cached(fun, static_argnums, static_args, nargs)
def _dyn_args_fun_uncached(fun: Callable, static_argnums: FrozenSet[int],
static_args: Tuple[WrapHashably, ...], nargs: int):
def _dyn_args_fun_uncached(fun: Callable, static_argnums: frozenset[int],
static_args: tuple[WrapHashably, ...], nargs: int):
def new_fun(*dyn_args, **kwargs):
static_args_, dyn_args_ = iter(static_args), iter(dyn_args)
full_args = [next(static_args_).val if i in static_argnums
@ -391,7 +390,7 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):
### Utilities
def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
in_leaves, in_tree = tree_flatten((args, kwargs))
def f_(*args):
@ -409,7 +408,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
arg_info = pe.arg_info_all(dbg)
return _saved_residuals(jaxpr, arg_info)
def _saved_residuals(jaxpr, arg_info) -> List[Tuple[core.AbstractValue, str]]:
def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}
@ -579,7 +578,7 @@ ad.reducing_transposes[remat_p] = remat_transpose
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: Union[bool, Sequence[bool]],
out_zeros: Union[bool, Sequence[bool]],
reduce_axes: Sequence[core.AxisName],
) -> Tuple[core.ClosedJaxpr, List[bool]]:
) -> tuple[core.ClosedJaxpr, list[bool]]:
if type(in_linear) is bool:
in_linear = (in_linear,) * len(jaxpr.in_avals)
if type(out_zeros) is bool:
@ -640,8 +639,8 @@ batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None)
batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], Optional[core.JaxprEqn]]:
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
new_params = dict(eqn.params, jaxpr=new_jaxpr)
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
@ -781,7 +780,7 @@ def checkpoint_wrapper(
*,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
static_argnums: Union[int, tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
) -> Callable:
if concrete:


@ -14,7 +14,7 @@
from __future__ import annotations
import types
from typing import Any, Callable, Dict, TypeVar, Union, cast
from typing import Any, Callable, TypeVar, Union, cast
from jax._src import core
from jax._src import traceback_util
@ -30,7 +30,7 @@ T = TypeVar('T')
map = safe_map
jaxval_adders: Dict[type, Callable[[ArrayLike, ArrayLike], Array]] = {}
jaxval_adders: dict[type, Callable[[ArrayLike, ArrayLike], Array]] = {}
def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:
return add_jaxvals_p.bind(x, y)
@ -46,7 +46,7 @@ def add_impl(xs, ys):
def add_abstract(xs, ys):
return lattice_join(xs, ys)
jaxval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}
jaxval_zeros_likers: dict[type, Callable[[Any], Array]] = {}
def instantiate(z: Union[Zero, Array]) -> Array:
if type(z) is Zero:
@ -56,7 +56,7 @@ def instantiate(z: Union[Zero, Array]) -> Array:
def zeros_like_aval(aval: core.AbstractValue) -> Array:
return aval_zeros_likers[type(aval)](aval)
aval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}
aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}
def zeros_like_jaxval(val: ArrayLike) -> Array:
return zeros_like_p.bind(val)


@ -27,8 +27,8 @@ from functools import partial
import inspect
import math
import typing
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union,
from typing import (Any, Callable, Generator, Hashable, Iterable, Literal,
NamedTuple, Optional, Sequence, TypeVar, Union,
overload, cast)
import weakref
@ -363,7 +363,7 @@ def disable_jit(disable: bool = True):
def xla_computation(fun: Callable,
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
axis_env: Optional[Sequence[tuple[AxisName, int]]] = None,
in_parts=None, out_parts=None,
backend: Optional[str] = None,
tuple_args: bool = False,
@ -658,7 +658,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
) -> Callable[..., Tuple[Any, Any]]:
) -> Callable[..., tuple[Any, Any]]:
"""Create a function that evaluates both ``fun`` and the gradient of ``fun``.
Args:
@ -1069,7 +1069,7 @@ def vmap(fun: F,
out_axes: Any = 0,
axis_name: Optional[AxisName] = None,
axis_size: Optional[int] = None,
spmd_axis_name: Optional[Union[AxisName, Tuple[AxisName, ...]]] = None
spmd_axis_name: Optional[Union[AxisName, tuple[AxisName, ...]]] = None
) -> F:
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.
@ -1254,7 +1254,7 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
f"containing an array, got empty *args={args} and **kwargs={kwargs}"
)
def _get_axis_size(name: str, shape: Tuple[core.AxisSize, ...], axis: int
def _get_axis_size(name: str, shape: tuple[core.AxisSize, ...], axis: int
) -> core.AxisSize:
try:
return shape[axis]
@ -1328,7 +1328,7 @@ def pmap(
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
global_arg_shapes: Optional[tuple[tuple[int, ...], ...]] = None,
) -> Any:
"""Parallel map with support for collective operations.
@ -1888,7 +1888,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
def jvp(
fun: Callable, primals, tangents, has_aux: bool = False
) -> Tuple[Any, ...]:
) -> tuple[Any, ...]:
"""Computes a (forward-mode) Jacobian-vector product of ``fun``.
Args:
@ -1971,16 +1971,16 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
@overload
def linearize(fun: Callable, *primals, has_aux: Literal[False] = False
) -> Tuple[Any, Callable]:
) -> tuple[Any, Callable]:
...
@overload
def linearize(fun: Callable, *primals, has_aux: Literal[True]
) -> Tuple[Any, Callable, Any]:
) -> tuple[Any, Callable, Any]:
...
def linearize(fun: Callable, *primals, has_aux: bool = False
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
) -> Union[tuple[Any, Callable], tuple[Any, Callable, Any]]:
"""Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.
Args:
@ -2137,17 +2137,17 @@ def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree,
def vjp(fun: Callable[..., T],
*primals: Any,
has_aux: Literal[False] = False,
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]:
reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable]:
...
@overload
def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any,
def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]:
reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
...
def vjp( # type: ignore
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
) -> Union[tuple[Any, Callable], tuple[Any, Callable, Any]]:
"""Compute a (reverse-mode) vector-Jacobian product of ``fun``.
:py:func:`grad` is implemented as a special case of :py:func:`vjp`.
@ -2310,7 +2310,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
def _flat_axes_specs(abstracted_axes, *args, **kwargs
) -> List[pe.AbstractedAxesSpec]:
) -> list[pe.AbstractedAxesSpec]:
if kwargs: raise NotImplementedError
def ax_leaf(l):
return (isinstance(l, dict) and all_leaves(l.values()) or
@ -2323,7 +2323,7 @@ def _flat_axes_specs(abstracted_axes, *args, **kwargs
@overload
def make_jaxpr(fun: Callable, # type: ignore
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
axis_env: Optional[Sequence[tuple[AxisName, int]]] = None,
return_shape: Literal[False] = ...,
abstracted_axes: Optional[Any] = None,
) -> Callable[..., core.ClosedJaxpr]:
@ -2332,19 +2332,19 @@ def make_jaxpr(fun: Callable, # type: ignore
@overload
def make_jaxpr(fun: Callable, # type: ignore
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
axis_env: Optional[Sequence[tuple[AxisName, int]]] = None,
return_shape: Literal[True] = ...,
abstracted_axes: Optional[Any] = None,
) -> Callable[..., Tuple[core.ClosedJaxpr, Any]]:
) -> Callable[..., tuple[core.ClosedJaxpr, Any]]:
...
def make_jaxpr(fun: Callable,
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
axis_env: Optional[Sequence[tuple[AxisName, int]]] = None,
return_shape: bool = False,
abstracted_axes: Optional[Any] = None,
) -> Callable[..., Union[core.ClosedJaxpr,
Tuple[core.ClosedJaxpr, Any]]]:
tuple[core.ClosedJaxpr, Any]]]:
"""Creates a function that produces its jaxpr given example args.
Args:


@ -15,8 +15,7 @@
import inspect
import operator
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
Set, Tuple, Union)
from typing import Any, Callable, Iterable, Optional, Sequence, Union
import warnings
import numpy as np
@ -39,7 +38,7 @@ traceback_util.register_exclusion(__file__)
map = safe_map
def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
def _ensure_index(x: Any) -> Union[int, tuple[int, ...]]:
"""Ensure x is either an index or a tuple of indices."""
x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
try:
@ -47,7 +46,7 @@ def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
except TypeError:
return tuple(map(operator.index, x))
def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
"""Convert x to a tuple of indices."""
x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
try:
@ -60,7 +59,7 @@ def _ensure_str(x: str) -> str:
raise TypeError(f"argument is not a string: {x}")
return x
def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> Tuple[str, ...]:
def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> tuple[str, ...]:
"""Convert x to a tuple of strings."""
if isinstance(x, str):
return (x,)
@ -97,7 +96,7 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
def flattened_fun_in_tree(
fn: lu.WrappedFun
) -> Optional[Tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]:
) -> Optional[tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]:
# This implementation relies on internal details of linear_util.py's
# WrappedFun, but it's for the worthy cause of better user error messages.
# It can fail (i.e. return None) if its WrappedFun argument is not transformed
@ -149,7 +148,7 @@ _POSITIONAL_ARGUMENTS = (
inspect.Parameter.POSITIONAL_OR_KEYWORD
)
def validate_argnums(sig: inspect.Signature, argnums: Tuple[int, ...], argnums_name: str) -> None:
def validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None:
"""
Validate that the argnums are sensible for a given function.
@ -183,7 +182,7 @@ _KEYWORD_ARGUMENTS = (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
)
def validate_argnames(sig: inspect.Signature, argnames: Tuple[str, ...], argnames_name: str) -> None:
def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str) -> None:
"""
Validate that the argnames are sensible for a given function.
@ -192,8 +191,8 @@ def validate_argnames(sig: inspect.Signature, argnames: Tuple[str, ...], argname
marked as position-only (`f(pos_only, /, ...)`).
"""
var_kwargs = False
valid_kwargs: Set[str] = set()
invalid_kwargs: Set[str] = set()
valid_kwargs: set[str] = set()
invalid_kwargs: set[str] = set()
for param_name, param in sig.parameters.items():
if param.kind in _KEYWORD_ARGUMENTS:
valid_kwargs.add(param_name)
@ -253,7 +252,7 @@ def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int]
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
"""Ensure argnum is within bounds. Also resolves negative argnums."""
result = []
for i in argnums:
@ -267,8 +266,8 @@ def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int]
return tuple(result)
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
args: Tuple[Any, ...], *, allow_invalid: bool):
def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
args: tuple[Any, ...], *, allow_invalid: bool):
"Version of ``argnums_partial`` that checks hashability of static_argnums."
if not static_argnums:
return f, args
@ -305,13 +304,13 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
yield ans
def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
kwargs: Dict[str, Any]):
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
kwargs: dict[str, Any]):
if not static_argnames:
return f, kwargs
dyn_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}
fixed_kwargs: Dict[str, Any] = {}
fixed_kwargs: dict[str, Any] = {}
for k, arg in kwargs.items():
if k not in dyn_kwargs:
try:
@ -333,16 +332,16 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
yield ans
def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
def donation_vector(donate_argnums, args, kwargs) -> tuple[bool, ...]:
"""Returns a tuple with a boolean value for each leaf in args."""
res: List[bool] = []
res: list[bool] = []
for i, arg in enumerate(args):
donate = bool(i in donate_argnums)
res.extend((donate,) * tree_structure(arg).num_leaves)
res.extend((False,) * tree_structure(kwargs).num_leaves)
return tuple(res)
def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]:
"""Shifts donate to account for static.
>>> rebase_donate_argnums((3, 4), (0, 1))
@ -426,7 +425,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
def flat_out_axes(
f: lu.WrappedFun, out_spec: Any
) -> Tuple[lu.WrappedFun, Callable]:
) -> tuple[lu.WrappedFun, Callable]:
leaves, treedef = tree_flatten(out_spec)
f, out_axes = _flat_out_axes(f, tuple(leaves), treedef)
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
@ -473,7 +472,7 @@ def infer_argnums_and_argnames(
sig: inspect.Signature,
argnums: Union[int, Iterable[int], None],
argnames: Union[str, Iterable[str], None],
) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
) -> tuple[tuple[int, ...], tuple[str, ...]]:
"""Infer missing argnums and argnames for a function with inspect."""
if argnums is None and argnames is None:
return (), ()
@ -504,7 +503,7 @@ def infer_argnums_and_argnames(
def resolve_argnums(
fun, donate_argnums, static_argnums, static_argnames
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[str, ...]]:
# Coerce input
donate_argnums = _ensure_index_tuple(donate_argnums)
@ -563,7 +562,7 @@ def shaped_abstractify(x):
return _shaped_abstractify_handlers[type(x)](x)
except KeyError:
return _shaped_abstractify_slow(x)
_shaped_abstractify_handlers: Dict[Any, Callable[[Any], core.ShapedArray]] = {}
_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {}
def _str_abstractify(x):
@ -591,9 +590,9 @@ def api_hook(fun, tag: str):
return fun
def debug_info(traced_for: str, fun: Callable, args: Tuple[Any],
kwargs: Dict[str, Any], static_argnums: Tuple[int, ...],
static_argnames: Tuple[str, ...]) -> Optional[TracingDebugInfo]:
def debug_info(traced_for: str, fun: Callable, args: tuple[Any],
kwargs: dict[str, Any], static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]) -> Optional[TracingDebugInfo]:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
src = fun_sourceinfo(fun)
arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames)
@ -613,7 +612,7 @@ def fun_sourceinfo(fun: Callable) -> Optional[str]:
return None
def _arg_names(fn, args, kwargs, static_argnums, static_argnames,
) -> Optional[Tuple[str, ...]]:
) -> Optional[tuple[str, ...]]:
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
@ -633,7 +632,7 @@ def result_paths(*args, **kwargs):
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
result_paths: Optional[Tuple[Optional[str], ...]] = None,
result_paths: Optional[tuple[Optional[str], ...]] = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
@ -647,7 +646,7 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
return jaxpr.replace(debug_info=debug_info)
def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
res_paths: Callable[[], Tuple[str, ...]]) -> lu.WrappedFun:
res_paths: Callable[[], tuple[str, ...]]) -> lu.WrappedFun:
"Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
if dbg is None: return f
assert dbg.result_paths is None


@ -18,8 +18,8 @@ import math
import operator as op
import numpy as np
import functools
from typing import (Any, Callable, List, Optional, Sequence, Set, Tuple,
Union, cast, TYPE_CHECKING)
from typing import (Any, Callable, Optional, Sequence, Union, cast,
TYPE_CHECKING)
from jax._src import abstract_arrays
from jax._src import api
@ -42,9 +42,9 @@ from jax._src.sharding_impls import (
from jax._src.typing import ArrayLike
from jax._src.util import use_cpp_class, use_cpp_method
Shape = Tuple[int, ...]
Shape = tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
Index = tuple[slice, ...]
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
@ -149,7 +149,7 @@ class ArrayImpl(basearray.Array):
aval: core.ShapedArray
_sharding: Sharding
_arrays: List[ArrayImpl]
_arrays: list[ArrayImpl]
_committed: bool
_skip_checks: bool
_npy_value: Optional[np.ndarray]
@ -402,7 +402,7 @@ class ArrayImpl(basearray.Array):
raise ValueError('Length of devices is greater than 1. '
'Please use `.devices()`.')
def devices(self) -> Set[Device]:
def devices(self) -> set[Device]:
self._check_if_deleted()
return self.sharding.device_set


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Set
from typing import Any, Callable, Optional, Sequence, Union
import numpy as np
from jax._src.sharding import Sharding
@ -33,7 +33,7 @@ class Array(abc.ABC):
aval: Any
@property
def shape(self) -> Tuple[int, ...]: ...
def shape(self) -> tuple[int, ...]: ...
@property
def sharding(self) -> Sharding: ...
@ -110,9 +110,9 @@ class Array(abc.ABC):
def __index__(self) -> int: ...
# np.ndarray methods:
def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
def all(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
keepdims=None) -> Array: ...
def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
def any(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
keepdims=None) -> Array: ...
def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
@ -125,9 +125,9 @@ class Array(abc.ABC):
def conj(self) -> Array: ...
def conjugate(self) -> Array: ...
def copy(self) -> Array: ...
def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def cumprod(self, axis: Optional[Union[int, tuple[int, ...]]] = None,
dtype=None, out=None) -> Array: ...
def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def cumsum(self, axis: Optional[Union[int, tuple[int, ...]]] = None,
dtype=None, out=None) -> Array: ...
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ...
def dot(self, b, *, precision=None) -> Array: ...
@ -135,18 +135,18 @@ class Array(abc.ABC):
@property
def imag(self) -> Array: ...
def item(self, *args) -> Any: ...
def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
def max(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None) -> Array: ...
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
def mean(self, axis: Optional[Union[int, tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False, *, where=None,) -> Array: ...
def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
def min(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None) -> Array: ...
@property
def nbytes(self) -> int: ...
def nonzero(self, *, size=None, fill_value=None) -> Array: ...
def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
def prod(self, axis: Optional[Union[int, tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None) -> Array: ...
def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
def ptp(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
keepdims=False,) -> Array: ...
def ravel(self, order='C') -> Array: ...
@property
@ -157,16 +157,16 @@ class Array(abc.ABC):
def round(self, decimals=0, out=None) -> Array: ...
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def squeeze(self, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array: ...
def std(self, axis: Optional[Union[int, tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
def sum(self, axis: Optional[Union[int, tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None) -> Array: ...
def swapaxes(self, axis1: int, axis2: int) -> Array: ...
def take(self, indices, axis: Optional[int] = None, out=None,
mode=None) -> Array: ...
def tobytes(self, order='C') -> bytes: ...
def tolist(self) -> List[Any]: ...
def tolist(self) -> list[Any]: ...
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
out=None) -> Array: ...
def transpose(self, *args) -> Array: ...
@ -174,7 +174,7 @@ class Array(abc.ABC):
def T(self) -> Array: ...
@property
def mT(self) -> Array: ...
def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def var(self, axis: Optional[Union[int, tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
def view(self, dtype=None, type=None) -> Array: ...
@ -196,7 +196,7 @@ class Array(abc.ABC):
def copy_to_host_async(self) -> None: ...
def delete(self) -> None: ...
def device(self) -> Device: ...
def devices(self) -> Set[Device]: ...
def devices(self) -> set[Device]: ...
@property
def global_shards(self) -> Sequence[Shard]: ...
def is_deleted(self) -> bool: ...

View File

@ -16,8 +16,7 @@ from __future__ import annotations
import dataclasses
import functools
import itertools as it
from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar,
FrozenSet, Type, Set, List, Sequence, Any)
from typing import Union, Optional, Callable, TypeVar, Sequence, Any
import numpy as np
@ -57,8 +56,8 @@ zip, unsafe_zip = safe_zip, zip
Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = Type['JaxException']
Payload = List[Union[np.ndarray, Array]]
ErrorCategory = type['JaxException']
Payload = list[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
Out = TypeVar('Out')
@ -102,8 +101,8 @@ class JaxException(Exception):
@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
error_type: Type[JaxException]
shape_dtypes: Tuple[api.ShapeDtypeStruct, ...]
error_type: type[JaxException]
shape_dtypes: tuple[api.ShapeDtypeStruct, ...]
def __lt__(self, other: 'ErrorEffect'):
shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype)) # dtype is not comparable
@ -195,7 +194,7 @@ class FailedCheckError(JaxException):
@dataclasses.dataclass
class BatchedError(JaxException):
error_mapping: Dict[Tuple[int, ...], JaxException]
error_mapping: dict[tuple[int, ...], JaxException]
def __post_init__(self):
traceback_info = list(self.error_mapping.values())[0].traceback_info
@ -212,10 +211,10 @@ class BatchedError(JaxException):
@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
_pred: Dict[ErrorEffect, Bool]
_code: Dict[ErrorEffect, Int]
_metadata: Dict[Int, PyTreeDef] # mapping of code to JaxException treedef.
_payload: Dict[ErrorEffect, Payload]
_pred: dict[ErrorEffect, Bool]
_code: dict[ErrorEffect, Int]
_metadata: dict[Int, PyTreeDef] # mapping of code to JaxException treedef.
_payload: dict[ErrorEffect, Payload]
def get(self) -> Optional[str]:
"""Returns error message if error happened, None if no error happened."""
@ -278,7 +277,7 @@ class Error:
new_metadata = {**self._metadata, **metadata}
return Error(new_errs, new_codes, new_metadata, new_payload)
def _add_placeholder_effects(self, effects: Set[ErrorEffect]):
def _add_placeholder_effects(self, effects: set[ErrorEffect]):
"""Fill out Error with `effects` and np.ones arrays of their payloads."""
new_err = self._pred.copy()
new_code = self._code.copy()
@ -337,7 +336,7 @@ def _flatten_and_get_error_metadata_thunk(*invals):
def default_checkify_rule(primitive: core.Primitive, error: Error,
enabled_errors, *invals: core.Value,
**params: Any) -> Tuple[Error, Sequence[core.Value]]:
**params: Any) -> tuple[Error, Sequence[core.Value]]:
"""Default rule for primitives in `checkify` interpreter."""
if 'call_jaxpr' not in params:
# Default non-HOP case: just call primitive and don't update error.
@ -389,15 +388,15 @@ def get_shaped_aval(val):
return core.raise_to_shaped(core.get_aval(val))
def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
error: Error, *args) -> Tuple[Error, List[core.Value]]:
error: Error, *args) -> tuple[Error, list[core.Value]]:
err_vals, err_tree = jtu.tree_flatten(error)
return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts,
enabled_errors, err_tree, *err_vals, *args)
def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
enabled_errors, err_tree: PyTreeDef,
*args: core.Value) -> Tuple[Error, List[Any]]:
env: Dict[core.Var, Any] = {}
*args: core.Value) -> tuple[Error, list[Any]]:
env: dict[core.Var, Any] = {}
err_vals, in_args = split_list(args, [err_tree.num_leaves])
error = jtu.tree_unflatten(err_tree, err_vals)
@ -558,7 +557,7 @@ ad.primitive_jvps[check_p] = check_jvp_rule
## checkify rules
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
error_checks: dict[core.Primitive, ErrorCheckRule] = {}
def summary() -> str:
@ -740,7 +739,7 @@ error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
*flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
*flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
jaxpr.consts, enabled_errors,
err_tree)
@ -823,7 +822,7 @@ error_checks[lax.scan_p] = scan_error_check
def checkify_while_body_jaxpr(
cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
enabled_errors, error: Error,
c_consts_num: int) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
c_consts_num: int) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
cond_f = core.jaxpr_as_fun(cond_jaxpr)
body_f = core.jaxpr_as_fun(body_jaxpr)
def new_body_f(*c_consts_and_vals):
@ -1062,8 +1061,8 @@ all_checks = automatic_checks | user_checks
def checkify(f: Callable[..., Out],
errors: FrozenSet[ErrorCategory] = user_checks
) -> Callable[..., Tuple[Error, Out]]:
errors: frozenset[ErrorCategory] = user_checks
) -> Callable[..., tuple[Error, Out]]:
"""Functionalize `check` calls in `fun`, and optionally add run-time error checks.
Run-time errors are either user-added :func:`~check` assertions, or


@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import List, Optional, Type, Sequence, Tuple
from typing import Optional, Sequence
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
logger = logging.getLogger(__name__)
@ -28,7 +28,7 @@ class ClusterEnv:
:class:`ClusterEnv` subclasses are automatically detected when imported.
"""
_cluster_types: List[Type['ClusterEnv']] = []
_cluster_types: list[type['ClusterEnv']] = []
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
@ -41,7 +41,7 @@ class ClusterEnv:
num_processes: Optional[int],
process_id: Optional[int],
local_device_ids: Optional[Sequence[int]]
) -> Tuple[Optional[str], Optional[int], Optional[int],
) -> tuple[Optional[str], Optional[int], Optional[int],
Optional[Sequence[int]]]:
if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)):


@ -18,7 +18,7 @@ import logging
import os
import re
import sys
from typing import Any, List, Optional
from typing import Any, Optional
import zlib
import numpy as np
@ -310,7 +310,7 @@ _xla_flags_to_exclude_from_cache_key = [
"--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
]
extra_flag_prefixes_to_include_in_cache_key: List[str] = []
extra_flag_prefixes_to_include_in_cache_key: list[str] = []
def _hash_xla_flags(hash_obj):


@ -19,7 +19,7 @@ import logging
import os
import sys
import threading
from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
from typing import Any, Callable, Hashable, NamedTuple, Iterator, Optional
from jax._src import lib
from jax._src.lib import jax_jit
@ -246,7 +246,7 @@ class Config:
default_value=True)
def define_enum_state(
self, name: str, enum_values: List[str], default: Optional[str],
self, name: str, enum_values: list[str], default: Optional[str],
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
= None):


@ -27,10 +27,9 @@ import operator
from operator import attrgetter
import threading
import types
from typing import (Any, Callable, ClassVar, DefaultDict, Dict, FrozenSet,
Generator, Generic, Hashable, Iterable, Iterator, List,
NamedTuple, Optional, Sequence, Set, Tuple, Type, TypeVar,
Union, cast, overload)
from typing import (Any, Callable, ClassVar, DefaultDict, Generator, Generic,
Hashable, Iterable, Iterator, NamedTuple, Optional,
Sequence, TypeVar, Union, cast, overload)
import warnings
from weakref import ref
@ -71,17 +70,17 @@ no_effects: Effects = effects.no_effects
class JaxprDebugInfo(NamedTuple):
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: Tuple[Optional[str], ...] # e.g. ('args[0]', ... )
result_paths: Tuple[Optional[str], ...] # e.g. ('[0]', '[1]', ...)
arg_names: tuple[Optional[str], ...] # e.g. ('args[0]', ... )
result_paths: tuple[Optional[str], ...] # e.g. ('[0]', '[1]', ...)
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
'_effects', '_debug_info']
_constvars: List[Var]
_invars: List[Var]
_outvars: List[Atom]
_eqns: List[JaxprEqn]
_constvars: list[Var]
_invars: list[Var]
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: Optional[JaxprDebugInfo]
@ -171,7 +170,7 @@ class ClosedJaxpr:
__slots__ = ['__weakref__', '_jaxpr', '_consts']
_jaxpr: Jaxpr
_consts: List[Any]
_consts: list[Any]
jaxpr = property(lambda self: self._jaxpr)
consts = property(lambda self: self._consts)
@ -230,10 +229,10 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
class JaxprEqn(NamedTuple):
invars: List[Atom]
outvars: List[Var]
invars: list[Atom]
outvars: list[Var]
primitive: Primitive
params: Dict[str, Any]
params: dict[str, Any]
effects: Effects
source_info: source_info_util.SourceInfo
@ -242,10 +241,10 @@ class JaxprEqn(NamedTuple):
def replace(
self,
invars: Optional[List[Atom]] = None,
outvars: Optional[List[Var]] = None,
invars: Optional[list[Atom]] = None,
outvars: Optional[list[Var]] = None,
primitive: Optional[Primitive] = None,
params: Optional[Dict[str, Any]] = None,
params: Optional[dict[str, Any]] = None,
effects: Optional[Effects] = None,
source_info: Optional[source_info_util.SourceInfo] = None,
):
@ -355,7 +354,7 @@ class Literal:
else:
return f'Literal(val={self.val})'
literalable_types: Set[type] = set()
literalable_types: set[type] = set()
Atom = Union[Var, Literal]
@ -436,7 +435,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
assert typecheck(v.aval, val), (v.aval, val)
env[v] = val
env: Dict[Var, Any] = {}
env: dict[Var, Any] = {}
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
lu = last_used(jaxpr)
@ -836,8 +835,8 @@ class EvalTrace(Trace):
class MainTrace:
level: int
trace_type: Type[Trace]
payload: Dict[str, Any]
trace_type: type[Trace]
payload: dict[str, Any]
def __init__(self, level, trace_type, **payload) -> None:
self.level = level
@ -861,7 +860,7 @@ class MainTrace:
class TraceStack:
# See comments in https://github.com/google/jax/pull/3370
stack: List[MainTrace]
stack: list[MainTrace]
dynamic: MainTrace
def __init__(self):
@ -912,8 +911,8 @@ no_axis_name = object()
class TraceState:
trace_stack: TraceStack
substack: List[Sublevel]
axis_env: List[AxisEnvFrame]
substack: list[Sublevel]
axis_env: list[AxisEnvFrame]
def __init__(self) -> None:
self.trace_stack = TraceStack()
@ -995,7 +994,7 @@ the following:
"""
def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
) -> List[Tracer]:
) -> list[Tracer]:
"""Find the leaked tracers holding a reference to the MainTrace or SubLevel.
It's possible there's none! eg. there's some cases where JAX itself holds a
@ -1012,14 +1011,14 @@ def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
return tracers
def leaked_tracer_error(name: str, t, tracers: List[Tracer]) -> Exception:
def leaked_tracer_error(name: str, t, tracers: list[Tracer]) -> Exception:
assert tracers
why = partial(_why_alive, {id(tracers)})
msgs = '\n\n'.join(f'{tracers[i]}{tracers[i]._origin_msg()}{why(tracers[i])}'
for i in range(len(tracers)))
return Exception(f'Leaked {name} {t}. Leaked tracer(s):\n\n{msgs}\n')
def _why_alive(ignore_ids: Set[int], x: Any) -> str:
def _why_alive(ignore_ids: set[int], x: Any) -> str:
parents = lambda x: [r for r in gc.get_referrers(x) if id(r) not in ignore_ids]
child, lines, seen = x, [], set()
while (id(child) not in seen and type(child) is not types.ModuleType
@ -1078,7 +1077,7 @@ def _why_alive_container_info(container, obj_id) -> str:
@contextmanager
def new_main(trace_type: Type[Trace],
def new_main(trace_type: type[Trace],
dynamic: bool = False,
**payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
@ -1106,7 +1105,7 @@ def new_main(trace_type: Type[Trace],
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
@contextmanager
def new_base_main(trace_type: Type[Trace],
def new_base_main(trace_type: type[Trace],
**payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
@ -1234,7 +1233,7 @@ def get_referent(x: Any) -> Any:
def same_referent(x: Any, y: Any) -> bool:
return get_referent(x) is get_referent(y)
def dedup_referents(itr: Iterable[Any]) -> List[Any]:
def dedup_referents(itr: Iterable[Any]) -> list[Any]:
return list({HashableWrapper(get_referent(x)):x for x in itr}.values())
def definitely_equal(x, y):
@ -1247,7 +1246,7 @@ def definitely_equal(x, y):
# -------------------- abstract values --------------------
class AbstractValue:
__slots__: List[str] = []
__slots__: list[str] = []
def at_least_vspace(self):
raise NotImplementedError("must override")
@ -1292,13 +1291,13 @@ class OutDBIdx:
# a sequence of pairs where the first element of each pair is an AbstractValue
# (possibly containing DBIdx instances in its shape) and the second is a boolean
# indicating whether that argument is explicit (i.e. passed to the callable).
InputType = Tuple[Tuple[AbstractValue, bool], ...] # DBIdx in shapes
InputType = tuple[tuple[AbstractValue, bool], ...] # DBIdx in shapes
# For annotating jaxpr output types, we use a sequence of pairs where the first
# element of each pair is an AbstractValue (possibly containing InDBIdx and/or
# OutDBIdx instances in its shape) and the second is a boolean indicating
# whether that argument is explicit (i.e. returned by the callable).
OutputType = Tuple[Tuple[AbstractValue, bool], ...] # InDBIdx / OutDBIdx shapes
OutputType = tuple[tuple[AbstractValue, bool], ...] # InDBIdx / OutDBIdx shapes
def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
@ -1669,7 +1668,7 @@ def primal_dtype_to_tangent_dtype(primal_dtype):
# it's kind of convenient!
class DShapedArray(UnshapedArray):
__slots__ = ['shape']
shape: Tuple[AxisSize, ...] # noqa: F821
shape: tuple[AxisSize, ...] # noqa: F821
array_abstraction_level: int = 3
def __init__(self, shape, dtype, weak_type=False):
@ -1731,7 +1730,7 @@ class DConcreteArray(DShapedArray):
self.val = val
pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
class DArray:
@ -1813,7 +1812,7 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
if handler: return handler(aval, weak_type)
raise TypeError(type(aval))
raise_to_shaped_mappings : Dict[type, Callable] = {
raise_to_shaped_mappings : dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
@ -1907,7 +1906,7 @@ class DimensionHandler:
return d
_dimension_handler_int = DimensionHandler()
_SPECIAL_DIMENSION_HANDLERS: Dict[type, DimensionHandler] = {}
_SPECIAL_DIMENSION_HANDLERS: dict[type, DimensionHandler] = {}
DArrayDimHandler = type('DArrayDimHandler', (DimensionHandler,), {})()
def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
@ -1917,7 +1916,7 @@ def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
return DArrayDimHandler
return _SPECIAL_DIMENSION_HANDLERS.get(type(dim))
def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple[DimSize, ...]]:
def _dim_handler_and_canonical(*dlist: DimSize) -> tuple[DimensionHandler, tuple[DimSize, ...]]:
"""Finds the handler for the given dimensions; also returns the canonical dimensions.
A dimension is canonical if it is a Python integer scalar, or has a type
@ -2081,7 +2080,7 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
else:
raise type_error
def canonicalize_shape(shape: Shape, context: str="") -> Tuple[Any, ...]:
def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
"""Canonicalizes and checks for errors in a user-provided shape value.
Args:
@ -2321,7 +2320,7 @@ closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)
outfeed_primitives: Set[Primitive] = set()
outfeed_primitives: set[Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
@ -2336,7 +2335,7 @@ def _param_uses_outfeed(param):
return True
return False
def primitive_uses_outfeed(prim: Primitive, params: Dict) -> bool:
def primitive_uses_outfeed(prim: Primitive, params: dict) -> bool:
if prim in outfeed_primitives:
return True
for param in params.values():
@ -2476,8 +2475,8 @@ def _unmap_dshaped_array(
else:
raise TypeError(axis)
AvalMapHandlerPair = Tuple[Callable, Callable]
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
AvalMapHandlerPair = tuple[Callable, Callable]
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
@ -2501,7 +2500,7 @@ def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
if f.name is not no_axis_name))
@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]], tag: Any = None):
def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None):
frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
ts = thread_local_state.trace_state
ts.axis_env.extend(frames)
@ -2573,8 +2572,8 @@ def axis_frame(axis_name: AxisName, main_trace: Optional[MainTrace] = None
f'by pmap) are available to collective operations: {named_axes}')
ParamDict = Dict[str, Any]
AxisSubst = Callable[[AxisName], Tuple[AxisName, ...]]
ParamDict = dict[str, Any]
AxisSubst = Callable[[AxisName], tuple[AxisName, ...]]
class NameGatheringSubst:
def __init__(self):
@ -2583,7 +2582,7 @@ class NameGatheringSubst:
self.axis_names.add(axis_name)
return (axis_name,)
def used_axis_names(primitive: Primitive, params: ParamDict) -> Set[AxisName]:
def used_axis_names(primitive: Primitive, params: ParamDict) -> set[AxisName]:
subst = NameGatheringSubst()
subst_axis_names(primitive, params, subst)
return subst.axis_names
@ -2612,7 +2611,7 @@ class DuplicateAxisNameError(Exception):
self.var = var
self.eqn = None
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> Var:
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var:
# Var identity is load-bearing, so we can't have duplicates!
if isinstance(v, DropVar): return v
assert v not in var_map
@ -2627,8 +2626,8 @@ def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> V
var_map[v] = new_v
return new_v
def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: Dict[Var, Var]) -> JaxprEqn:
invars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]
def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn:
invars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]
try:
outvars = [subst_axis_names_var(v, subst, var_map) for v in eqn.outvars]
except DuplicateAxisNameError as e:
@ -2642,11 +2641,11 @@ def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst
if isinstance(jaxpr, ClosedJaxpr):
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
var_map: Dict[Var, Var] = {}
var_map: dict[Var, Var] = {}
invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars] # type: ignore[union-attr]
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr]
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] # type: ignore[union-attr]
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr]
outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr]
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
if consts is not None:
return ClosedJaxpr(new_jaxpr, consts)
@ -2668,11 +2667,11 @@ def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects):
return _replace_jaxpr_effects(jaxpr, frozenset(effects))
@weakref_lru_cache
def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: FrozenSet[Effect]):
def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: frozenset[Effect]):
return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(effects=set(effects)))
axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
axis_substitution_rules: dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
# ------------------- AxisPrimitive -------------------
# Primitives that store axis names in params and want those axis names to
@ -2722,7 +2721,7 @@ def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
class JaxprTypeError(TypeError): pass
custom_typechecks: Dict[Primitive, Callable] = {}
custom_typechecks: dict[Primitive, Callable] = {}
def _check_closed_call(_, *in_atoms, call_jaxpr):
in_avals = [x.aval for x in in_atoms]
@ -2765,11 +2764,11 @@ def check_jaxpr(jaxpr: Jaxpr):
raise JaxprTypeError(msg) from None
def _check_jaxpr(
ctx_factory: Callable[[], Tuple[JaxprPpContext, JaxprPpSettings]],
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
jaxpr: Jaxpr
) -> None:
# Use set of variables to types to check that variables are in scope.
env: Set[Var] = set()
env: set[Var] = set()
def read(x: Atom) -> Atom:
# Check the type annotation is itself well-typed.
@ -2874,8 +2873,8 @@ def _check_jaxpr(
map(read, jaxpr.outvars)
def check_type(
ctx_factory: Callable[[], Tuple[JaxprPpContext, JaxprPpSettings]],
env: Set[Var],
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
env: set[Var],
ty: AbstractValue,
) -> None:
if isinstance(ty, DShapedArray):
@ -2912,7 +2911,7 @@ def substitute_vars_in_output_ty(
out_type: Sequence[AbstractValue], # shapes may contain InDBIdx / OutDBIdx
in_atoms: Sequence[Atom],
out_binders: Sequence[Var],
) -> List[AbstractValue]: # shapes may contain Vars
) -> list[AbstractValue]: # shapes may contain Vars
in_atoms = [x.val if type(x) is Literal else x for x in in_atoms]
result = []
for aval in out_type:
@ -2945,7 +2944,7 @@ def _check_call(ctx_factory, prim, in_atoms, params):
f"{len(call_jaxpr.invars)} inputs")
# Check `call_jaxpr` can be applied to in_atoms.
env: Dict[Var, Atom] = {}
env: dict[Var, Atom] = {}
def substitute(aval: AbstractValue):
if isinstance(aval, DShapedArray):
aval = aval.update(shape=tuple([env.get(d, d) for d in aval.shape])) # type: ignore
@ -2961,8 +2960,8 @@ def _check_call(ctx_factory, prim, in_atoms, params):
_check_jaxpr(ctx_factory, call_jaxpr)
invars, outvars = call_jaxpr.invars, call_jaxpr.outvars
in_map : Dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: Dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
if type(x) is Var}
out_avals = [x.aval for x in call_jaxpr.outvars]
out_type = [a.update(shape=tuple(in_map.get(d, out_map.get(d))
@ -3094,7 +3093,7 @@ def _pp_eqn(eqn, context, settings) -> pp.Doc:
pp.text(" ") + pp_vars(eqn.invars, context)]
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext, JaxprPpSettings], pp.Doc]
pp_eqn_rules: Dict[Primitive, CustomPpEqnRule] = {}
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}
def pp_eqns(eqns, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
return pp.join(
@ -3109,7 +3108,7 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool:
return False
return True
def str_eqn_compact(primitive_name: str, params: Dict) -> str:
def str_eqn_compact(primitive_name: str, params: dict) -> str:
"Compact equation to string conversion used in HLO metadata."
kvs = " ".join(f"{k}={v}" for k, v in params.items()
if _compact_eqn_should_include(k, v))
@ -3187,9 +3186,9 @@ def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:
# ------------------- Jaxpr util -------------------
def last_used(jaxpr: Jaxpr) -> Dict[Var, Optional[JaxprEqn]]:
def last_used(jaxpr: Jaxpr) -> dict[Var, Optional[JaxprEqn]]:
"""Returns a mapping from every var in jaxpr to what equation uses it last."""
last_used: Dict[Var, Optional[JaxprEqn]] = {
last_used: dict[Var, Optional[JaxprEqn]] = {
v: None for v in jaxpr.outvars if not isinstance(v, Literal)}
for eqn in reversed(jaxpr.eqns):
for v in eqn.invars:
@ -3197,8 +3196,8 @@ def last_used(jaxpr: Jaxpr) -> Dict[Var, Optional[JaxprEqn]]:
last_used[v] = eqn
return last_used
def clean_up_dead_vars(eqn: JaxprEqn, env: Dict[Var, Any],
last_used: Dict[Var, Optional[JaxprEqn]]):
def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any],
last_used: dict[Var, Optional[JaxprEqn]]):
"""Remove all eqn.invars from env if eqn is the last time they were used."""
for v in set(v for v in eqn.invars if not isinstance(v, Literal)):
if last_used[v] is eqn:
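
The two helpers above give jaxpr interpreters a simple liveness analysis: map every variable to the equation that consumes it last, then drop its value from the environment as soon as that equation has run. A minimal standalone sketch of the same idea, using plain strings for variables and hypothetical `(outputs, inputs)` pairs rather than the real `Jaxpr`/`JaxprEqn` classes, and written directly with the PEP 585 spellings this change adopts:

```python
from typing import Optional

Eqn = tuple[list[str], list[str]]   # hypothetical stand-in: (output vars, input vars)

def last_used(eqns: list[Eqn], outvars: list[str]) -> dict[str, Optional[Eqn]]:
    """Map each variable to the equation that uses it last (None for outputs)."""
    last: dict[str, Optional[Eqn]] = {v: None for v in outvars}
    for eqn in reversed(eqns):          # walking backwards, the first hit is the last use
        for v in eqn[1]:
            last.setdefault(v, eqn)
    return last

def run(eqns: list[Eqn], env: dict[str, float], outvars: list[str]) -> list[float]:
    last = last_used(eqns, outvars)
    for eqn in eqns:
        outs, ins = eqn
        for o in outs:
            env[o] = sum(env[v] for v in ins)   # dummy "primitive": sum the inputs
        for v in set(ins):                      # free dead values eagerly
            if last[v] is eqn:
                del env[v]
    return [env[v] for v in outvars]

print(run([(["b"], ["a", "a"]), (["c"], ["b"])], {"a": 1.0}, ["c"]))  # [2.0]
```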

View File

@ -15,8 +15,7 @@
import dataclasses
from functools import update_wrapper, reduce, partial
import inspect
from typing import (
Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar)
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar
from jax._src import core
from jax._src import custom_api_util
@ -137,13 +136,13 @@ class custom_jvp(Generic[ReturnValue]):
.. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
"""
fun: Callable[..., ReturnValue]
nondiff_argnums: Tuple[int, ...]
jvp: Optional[Callable[..., Tuple[ReturnValue, ReturnValue]]] = None
nondiff_argnums: tuple[int, ...]
jvp: Optional[Callable[..., tuple[ReturnValue, ReturnValue]]] = None
symbolic_zeros: bool = False
def __init__(self,
fun: Callable[..., ReturnValue],
nondiff_argnums: Tuple[int, ...] = (),
nondiff_argnums: tuple[int, ...] = (),
):
update_wrapper(self, fun)
self.fun = fun
@ -152,9 +151,9 @@ class custom_jvp(Generic[ReturnValue]):
__getattr__ = custom_api_util.forward_attr
def defjvp(self,
jvp: Callable[..., Tuple[ReturnValue, ReturnValue]],
jvp: Callable[..., tuple[ReturnValue, ReturnValue]],
symbolic_zeros: bool = False,
) -> Callable[..., Tuple[ReturnValue, ReturnValue]]:
) -> Callable[..., tuple[ReturnValue, ReturnValue]]:
"""Define a custom JVP rule for the function represented by this instance.
Args:
@ -491,19 +490,19 @@ class custom_vjp(Generic[ReturnValue]):
def __init__(self,
fun: Callable[..., ReturnValue],
nondiff_argnums: Tuple[int, ...] = ()):
nondiff_argnums: tuple[int, ...] = ()):
update_wrapper(self, fun)
self.fun = fun
self.nondiff_argnums = nondiff_argnums
self.fwd: Optional[Callable[..., Tuple[ReturnValue, Any]]] = None
self.bwd: Optional[Callable[..., Tuple[Any, ...]]] = None
self.fwd: Optional[Callable[..., tuple[ReturnValue, Any]]] = None
self.bwd: Optional[Callable[..., tuple[Any, ...]]] = None
self.symbolic_zeros = False
__getattr__ = custom_api_util.forward_attr
def defvjp(self,
fwd: Callable[..., Tuple[ReturnValue, Any]],
bwd: Callable[..., Tuple[Any, ...]],
fwd: Callable[..., tuple[ReturnValue, Any]],
bwd: Callable[..., tuple[Any, ...]],
symbolic_zeros: bool = False,
) -> None:
"""Define a custom VJP rule for the function represented by this instance.
@ -831,7 +830,7 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
def _custom_vjp_call_jaxpr_jvp(
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
_, args = split_list(primals, [num_consts])
consts_dot, args_dot = split_list(tangents, [num_consts])
@ -857,7 +856,7 @@ ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
def _custom_vjp_call_jaxpr_vmap(
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]
@ -1006,7 +1005,7 @@ class Residuals:
return cls(jaxpr, in_tree, out_tree, consts)
def closure_convert(fun: Callable, *example_args) -> Tuple[Callable, List[Any]]:
def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
"""Closure conversion utility, for use with higher-order custom derivatives.
To define custom derivatives such as with ``jax.custom_vjp(f)``, the target
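
The `defvjp(fwd, bwd)` signatures above describe the usual two-function contract: `fwd` returns the primal output plus residuals, and `bwd` maps those residuals and an output cotangent to one cotangent per primal input. A short illustration in the style of the custom-derivatives tutorial (standard public API, independent of this diff):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    # Primal output plus residuals saved for the backward pass.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    # One cotangent per primal input, given the output cotangent g.
    cos_x, sin_x, y = res
    return (cos_x * y * g, sin_x * g)

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(0.5, 2.0))   # 2 * cos(0.5)
```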

View File

@ -13,7 +13,7 @@
# limitations under the License.
import functools
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional
from jax._src import ad_util
from jax._src import api_util
@ -51,7 +51,7 @@ class StoreEqual(lu.Store):
@util.curry
def transformation_with_aux(
gen, fun: lu.WrappedFun, *gen_static_args) -> Tuple[lu.WrappedFun, Any]:
gen, fun: lu.WrappedFun, *gen_static_args) -> tuple[lu.WrappedFun, Any]:
out_store = StoreEqual()
out_thunk = lambda: out_store.val
return fun.wrap(gen, gen_static_args, out_store), out_thunk

View File

@ -18,7 +18,7 @@ import pprint
import sys
import traceback
from typing import Any, IO, List, Optional
from typing import Any, IO, Optional
from jax._src.debugger import core as debugger_core
@ -28,7 +28,7 @@ class CliDebugger(cmd.Cmd):
"""A text-based debugger."""
prompt = '(jdb) '
def __init__(self, frames: List[DebuggerFrame], thread_id,
def __init__(self, frames: list[DebuggerFrame], thread_id,
stdin: Optional[IO[str]] = None, stdout: Optional[IO[str]] = None,
completekey: str = "tab"):
super().__init__(stdin=stdin, stdout=stdout, completekey=completekey)
@ -163,7 +163,7 @@ class CliDebugger(cmd.Cmd):
except KeyboardInterrupt:
print('--KeyboardInterrupt--', file=sys.stdout)
def run_debugger(frames: List[DebuggerFrame], thread_id: Optional[int],
def run_debugger(frames: list[DebuggerFrame], thread_id: Optional[int],
**kwargs: Any):
CliDebugger(frames, thread_id, **kwargs).run()
debugger_core.register_debugger("cli", run_debugger, -1)

View File

@ -18,8 +18,6 @@ import html
import inspect
import traceback
from typing import List
import uuid
from jax._src.debugger import colab_lib
@ -42,7 +40,7 @@ except ImportError:
class CodeViewer(colab_lib.DynamicDOMElement):
"""A mutable DOM element that displays code as HTML."""
def __init__(self, code_: str, highlights: List[int], linenostart: int = 1):
def __init__(self, code_: str, highlights: list[int], linenostart: int = 1):
self._code = code_
self._highlights = highlights
self._view = colab_lib.dynamic(colab_lib.div())
@ -226,7 +224,7 @@ class ColabDebugger(cli_debugger.CliDebugger):
"""A JAX debugger for a Colab environment."""
def __init__(self,
frames: List[debugger_core.DebuggerFrame],
frames: list[debugger_core.DebuggerFrame],
thread_id: int):
super().__init__(frames, thread_id)
self._debugger_view = DebuggerView(self.current_frame())

View File

@ -20,7 +20,7 @@ import functools
import sys
import uuid
from typing import Any, Dict, List, Union
from typing import Any, Union
IS_COLAB_ENABLED = "google.colab" in sys.modules
if IS_COLAB_ENABLED:
@ -106,8 +106,8 @@ class StaticDOMElement(DOMElement):
"""An immutable DOM element."""
_uuid: str = dataclasses.field(init=False)
name: str
children: List[Union[str, DOMElement]]
attrs: Dict[str, str]
children: list[Union[str, DOMElement]]
attrs: dict[str, str]
def html(self):
attr_str = ""
@ -137,7 +137,7 @@ class StaticDOMElement(DOMElement):
return dataclasses.replace(self, **kwargs)
def _style_dict_to_str(style_dict: Dict[str, Any]) -> str:
def _style_dict_to_str(style_dict: dict[str, Any]) -> str:
return " ".join([f"{k}: {v};" for k, v in style_dict.items()])

View File

@ -16,7 +16,7 @@ from __future__ import annotations
import dataclasses
import inspect
import threading
from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
from typing import Any, Hashable, Optional, Protocol
import numpy as np
@ -72,10 +72,10 @@ def _safe_flatten_dict(dct: dict[Any, Any]
class DebuggerFrame:
"""Encapsulates Python frame information."""
filename: str
locals: Dict[str, Any]
globals: Dict[str, Any]
locals: dict[str, Any]
globals: dict[str, Any]
code_context: str
source: List[str]
source: list[str]
lineno: int
offset: Optional[int]
@ -131,10 +131,10 @@ class DebuggerFrame:
class Debugger(Protocol):
def __call__(self, frames: List[DebuggerFrame], thread_id: Optional[int],
def __call__(self, frames: list[DebuggerFrame], thread_id: Optional[int],
**kwargs: Any) -> None:
...
_debugger_registry: Dict[str, Tuple[int, Debugger]] = {}
_debugger_registry: dict[str, tuple[int, Debugger]] = {}
def get_debugger(backend: Optional[str] = None) -> Debugger:
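
`Debugger` is a structural `Protocol`, so any callable with the matching signature can serve as an implementation, and the registry maps a backend name to an `(int, Debugger)` pair, where the `int` reads as a priority. A generic sketch of that registration pattern with toy types, including a priority-based default that is an assumption for illustration rather than a claim about the real `get_debugger` logic:

```python
from typing import Optional, Protocol

class Debugger(Protocol):
    def __call__(self, frames: list[str], thread_id: Optional[int]) -> None: ...

_registry: dict[str, tuple[int, Debugger]] = {}   # name -> (priority, implementation)

def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
    _registry[name] = (priority, debugger)

def get_debugger(backend: Optional[str] = None) -> Debugger:
    if backend is not None:
        return _registry[backend][1]
    _, debugger = max(_registry.values(), key=lambda pair: pair[0])
    return debugger

def cli_debugger(frames: list[str], thread_id: Optional[int]) -> None:
    print(f"entering CLI debugger with {len(frames)} frame(s)")

register_debugger("cli", cli_debugger, -1)
get_debugger()(["<frame>"], None)
```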

View File

@ -15,12 +15,12 @@ from __future__ import annotations
import os
import weakref
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from jax._src.debugger import cli_debugger
from jax._src.debugger import core as debugger_core
web_pdb_version: Optional[Tuple[int, ...]] = None
web_pdb_version: Optional[tuple[int, ...]] = None
try:
import web_pdb # pytype: disable=import-error
web_pdb_version = tuple(map(int, web_pdb.__version__.split(".")))
@ -29,14 +29,14 @@ except:
WEB_PDB_ENABLED = False
_web_consoles: Dict[Tuple[str, int], web_pdb.WebConsole] = {}
_web_consoles: dict[tuple[str, int], web_pdb.WebConsole] = {}
class WebDebugger(cli_debugger.CliDebugger):
"""A web-based debugger."""
prompt = '(jdb) '
use_rawinput: bool = False
def __init__(self, frames: List[debugger_core.DebuggerFrame], thread_id,
def __init__(self, frames: list[debugger_core.DebuggerFrame], thread_id,
completekey: str = "tab", host: str = "", port: int = 5555):
if (host, port) not in _web_consoles:
_web_consoles[host, port] = web_pdb.WebConsole(host, port, self)
@ -87,7 +87,7 @@ class WebDebugger(cli_debugger.CliDebugger):
def run(self):
return self.cmdloop()
def run_debugger(frames: List[debugger_core.DebuggerFrame],
def run_debugger(frames: list[debugger_core.DebuggerFrame],
thread_id: Optional[int], **kwargs: Any):
WebDebugger(frames, thread_id, **kwargs).run()

View File

@ -16,7 +16,7 @@
import functools
import string
import sys
from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Union
import weakref
import numpy as np
@ -413,8 +413,8 @@ def _raise_to_slice(slc: Union[slice, int]):
return slice(slc, slc + 1)
return slc
Color = Union[Tuple[float, float, float], str]
ColorMap = Callable[[float], Tuple[float, float, float, float]]
Color = Union[tuple[float, float, float], str]
ColorMap = Callable[[float], tuple[float, float, float, float]]
def _canonicalize_color(color: Color) -> str:
if isinstance(color, str):
@ -464,9 +464,9 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
device_kind = next(iter(sharding.device_set)).platform.upper()
device_indices_map = sharding.devices_indices_map(tuple(shape))
slices: Dict[Tuple[int, ...], Set[int]] = {}
heights: Dict[Tuple[int, ...], Optional[float]] = {}
widths: Dict[Tuple[int, ...], float] = {}
slices: dict[tuple[int, ...], set[int]] = {}
heights: dict[tuple[int, ...], Optional[float]] = {}
widths: dict[tuple[int, ...], float] = {}
for i, (dev, slcs) in enumerate(device_indices_map.items()):
assert slcs is not None
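
`visualize_sharding` collects, per device, the slice of the global shape that device holds and renders those slices as a grid. The related public entry point is `jax.debug.visualize_array_sharding`; a minimal usage sketch (a pointer for context, not part of this change):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(4, 4)
# On a single device this draws one cell; for a sharded Array it draws one cell
# per device, labelled with the device that owns that slice.
jax.debug.visualize_array_sharding(x)
```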

View File

@ -21,8 +21,8 @@ import dataclasses
from functools import partial
import itertools
import time
from typing import (Any, Callable, Dict, Iterator, Optional,
Set, Tuple, List, Union, NamedTuple, Sequence)
from typing import (Any, Callable, Iterator, Optional, Union, NamedTuple,
Sequence)
import logging
import os
import re
@ -94,7 +94,7 @@ _on_exit = False
### op-by-op execution
ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
ArgSpec = tuple[core.AbstractValue, Optional[Device]]
def arg_spec(x: Any) -> ArgSpec:
from jax._src import pjit
@ -150,9 +150,9 @@ def simple_impl(prim):
RuntimeToken = Any
class RuntimeTokenSet(threading.local):
tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
output_tokens: Dict[Device, RuntimeToken]
output_runtime_tokens: Dict[Device, RuntimeToken]
tokens: dict[core.Effect, tuple[RuntimeToken, Device]]
output_tokens: dict[Device, RuntimeToken]
output_runtime_tokens: dict[Device, RuntimeToken]
def __init__(self):
self.tokens = {}
@ -324,7 +324,7 @@ class SourceInfo(NamedTuple):
def jaxpr_shardings(
jaxpr) -> Iterator[Tuple[XLACompatibleSharding, SourceInfo]]:
jaxpr) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
from jax._src import pjit
from jax.experimental import shard_map
@ -368,7 +368,7 @@ def _is_bint_axis_size(d: core.AxisSize) -> bool:
return False
def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
jaxpr: core.Jaxpr) -> tuple[core.Jaxpr, set[int], set[int]]:
used_outputs = [True] * len(jaxpr.outvars)
new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
kept_const_idx = {i for i, b in enumerate(used_consts) if b}
@ -534,7 +534,7 @@ def _cache_write(cache_key: str,
compile_time_secs: float,
module_name: str,
backend: Backend, executable: xc.LoadedExecutable,
host_callbacks: List[Any]):
host_callbacks: list[Any]):
"""Writes `serialized_computation` to the persistent compilation cache."""
if host_callbacks:
logger.info(

View File

@ -22,8 +22,7 @@
import builtins
import functools
from typing import (cast, overload, Any, Dict, List, Literal, Optional, Set,
Tuple, Type, Union)
from typing import cast, overload, Any, Literal, Optional, Union
import warnings
import ml_dtypes
@ -38,23 +37,23 @@ traceback_util.register_exclusion(__file__)
FLAGS = flags.FLAGS
# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
opaque_dtypes: Set[OpaqueDType] = set()
opaque_dtypes: set[OpaqueDType] = set()
def is_opaque_dtype(dtype: Any) -> bool:
return type(dtype) in opaque_dtypes
# fp8 support
# TODO(jakevdp): remove this if statement when minimum ml_dtypes version > 0.1
float8_e4m3b11fnuz: Optional[Type[np.generic]] = None
float8_e4m3fn: Type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e5m2: Type[np.generic] = ml_dtypes.float8_e5m2
float8_e4m3b11fnuz: Optional[type[np.generic]] = None
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
_float8_e4m3b11fnuz_dtype: Optional[np.dtype] = None
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
# bfloat16 support
bfloat16: Type[np.generic] = ml_dtypes.bfloat16
bfloat16: type[np.generic] = ml_dtypes.bfloat16
_bfloat16_dtype: np.dtype = np.dtype(bfloat16)
_custom_float_scalar_types = [
@ -74,9 +73,9 @@ if hasattr(ml_dtypes, "float8_e4m3b11fnuz"):
_custom_float_scalar_types.insert(0, float8_e4m3b11fnuz) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e4m3b11fnuz_dtype) # type: ignore[arg-type]
int4: Optional[Type[np.generic]] = None
int4: Optional[type[np.generic]] = None
_int4_dtype: Optional[np.dtype] = None
uint4: Optional[Type[np.generic]] = None
uint4: Optional[type[np.generic]] = None
_uint4_dtype: Optional[np.dtype] = None
if hasattr(ml_dtypes, "int4"):
@ -91,12 +90,12 @@ int_: type = np.int32 if config.jax_default_dtype_bits == '32' else np.int64
uint: type = np.uint32 if config.jax_default_dtype_bits == '32' else np.uint64
float_: type = np.float32 if config.jax_default_dtype_bits == '32' else np.float64
complex_: type = np.complex64 if config.jax_default_dtype_bits == '32' else np.complex128
_default_types: Dict[str, type] = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
_default_types: dict[str, type] = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0: np.dtype = np.dtype([('float0', np.void, 0)])
_dtype_to_32bit_dtype: Dict[DType, DType] = {
_dtype_to_32bit_dtype: dict[DType, DType] = {
np.dtype('int64'): np.dtype('int32'),
np.dtype('uint64'): np.dtype('uint32'),
np.dtype('float64'): np.dtype('float32'),
@ -106,7 +105,7 @@ _dtype_to_32bit_dtype: Dict[DType, DType] = {
# Note: we promote narrow types to float32 here for backward compatibility
# with earlier approaches. We might consider revisiting this, or perhaps
# tying the logic more closely to the type promotion lattice.
_dtype_to_inexact: Dict[DType, DType] = {
_dtype_to_inexact: dict[DType, DType] = {
np.dtype(k): np.dtype(v) for k, v in [
('bool', 'float32'),
('uint8', 'float32'), ('int8', 'float32'),
@ -163,7 +162,7 @@ def canonicalize_dtype(dtype: Any, allow_opaque_dtype: bool = False) -> Union[DT
return _canonicalize_dtype(config.x64_enabled, allow_opaque_dtype, dtype)
# Default dtypes corresponding to Python scalars.
python_scalar_dtypes : Dict[type, DType] = {
python_scalar_dtypes : dict[type, DType] = {
bool: np.dtype('bool'),
int: np.dtype('int64'),
float: np.dtype('float64'),
@ -304,9 +303,9 @@ issubsctype = np.issubsctype
JAXType = Union[type, DType]
# Enumeration of all valid JAX types in order.
_weak_types: List[JAXType] = [int, float, complex]
_bool_types: List[JAXType] = [np.dtype(bool)]
_int_types: List[JAXType]
_weak_types: list[JAXType] = [int, float, complex]
_bool_types: list[JAXType] = [np.dtype(bool)]
_int_types: list[JAXType]
if int4 is not None:
_int_types = [
np.dtype(uint4),
@ -332,13 +331,13 @@ else:
np.dtype('int64'),
]
_float_types: List[JAXType] = [
_float_types: list[JAXType] = [
*_custom_float_dtypes,
np.dtype('float16'),
np.dtype('float32'),
np.dtype('float64'),
]
_complex_types: List[JAXType] = [
_complex_types: list[JAXType] = [
np.dtype('complex64'),
np.dtype('complex128'),
]
@ -355,11 +354,11 @@ def _jax_type(dtype: DType, weak_type: bool) -> JAXType:
return type(dtype.type(0).item())
return dtype
def _dtype_and_weaktype(value: Any) -> Tuple[DType, bool]:
def _dtype_and_weaktype(value: Any) -> tuple[DType, bool]:
"""Return a (dtype, weak_type) tuple for the given input."""
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> Dict[JAXType, List[JAXType]]:
def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, list[JAXType]]:
"""
Return the type promotion lattice in the form of a DAG.
This DAG maps each type to its immediately higher type on the lattice.
@ -373,7 +372,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> Dict[JAXType, Lis
c4, c8 = _complex_types
i_, f_, c_ = _weak_types
if jax_numpy_dtype_promotion == 'standard':
out: Dict[JAXType, List[JAXType]]
out: dict[JAXType, list[JAXType]]
out = {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
@ -400,7 +399,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> Dict[JAXType, Lis
raise ValueError(
f"Unexpected value of jax_numpy_dtype_promotion={jax_numpy_dtype_promotion!r}")
def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> Dict[JAXType, Set[JAXType]]:
def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> dict[JAXType, set[JAXType]]:
lattice = _type_promotion_lattice(jax_numpy_dtype_promotion)
upper_bounds = {node: {node} for node in lattice}
for n in lattice:
@ -413,7 +412,7 @@ def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> Dict[JAXType,
upper_bounds[n] |= new_upper_bounds
return upper_bounds
_lattice_upper_bounds: Dict[str, Dict[JAXType, Set[JAXType]]] = {
_lattice_upper_bounds: dict[str, dict[JAXType, set[JAXType]]] = {
'standard': _make_lattice_upper_bounds('standard'),
'strict': _make_lattice_upper_bounds('strict'),
}
@ -533,7 +532,7 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType:
"type. Only arrays of numeric types are supported by JAX.")
return canonicalize_dtype(dt, allow_opaque_dtype=True) if canonicalize else dt
def _lattice_result_type(*args: Any) -> Tuple[DType, bool]:
def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
if len(dtypes) == 1:
out_dtype = dtypes[0]
@ -559,15 +558,15 @@ def _lattice_result_type(*args: Any) -> Tuple[DType, bool]:
return out_dtype, (out_dtype != bool_) and out_weak_type
@overload
def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> Tuple[DType, bool]: ...
def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> tuple[DType, bool]: ...
@overload
def result_type(*args: Any, return_weak_type_flag: Literal[False] = False) -> DType: ...
@overload
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, Tuple[DType, bool]]: ...
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]: ...
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, Tuple[DType, bool]]:
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]:
"""Convenience function to apply JAX argument dtype promotion.
Args:
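
The `result_type` stubs above lean on `typing.overload` plus `Literal` flag values so that a type checker knows whether the call returns a bare dtype or a `(dtype, weak_type)` pair. A compressed, self-contained version of that overload pattern, with toy promotion logic standing in for JAX's so the example stays runnable:

```python
from typing import Literal, Union, overload
import numpy as np

@overload
def result_type(*args, return_weak_type_flag: Literal[True]) -> tuple[np.dtype, bool]: ...
@overload
def result_type(*args, return_weak_type_flag: Literal[False] = False) -> np.dtype: ...

def result_type(*args, return_weak_type_flag: bool = False
                ) -> Union[np.dtype, tuple[np.dtype, bool]]:
    # Toy rule: defer to NumPy for the dtype, and call the result "weak" when every
    # input was a bare Python scalar (a very rough stand-in for JAX's weak typing).
    dt = np.result_type(*args)
    weak = all(type(a) in (bool, int, float, complex) for a in args)
    return (dt, weak) if return_weak_type_flag else dt

print(result_type(np.int32(1), 2.0))                      # float64 under NumPy's rules
print(result_type(1, 2.0, return_weak_type_flag=True))    # (dtype('float64'), True)
```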

View File

@ -13,12 +13,12 @@
# limitations under the License.
from __future__ import annotations
from typing import Any, Iterable, Optional, Type, Set
from typing import Any, Iterable, Optional
class Effect:
"""A generic side-effect."""
Effects = Set[Effect]
Effects = set[Effect]
class JaxprInputEffect(Effect):
"""A side-effect associated with the input of a jaxpr.
@ -48,9 +48,9 @@ class JaxprInputEffect(Effect):
class EffectTypeSet:
def __init__(self):
self._effect_types: Set[Type[Effect]] = set()
self._effect_types: set[type[Effect]] = set()
def add_type(self, effect_type: Type[Effect]):
def add_type(self, effect_type: type[Effect]):
self._effect_types.add(effect_type)
def contains(self, eff: Effect) -> bool:
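
Worth noting for this file: the alias `Effects = set[Effect]` is a plain assignment, so the lowercase spelling is subscripted at runtime and relies on PEP 585 runtime support (Python 3.9+), whereas annotations in a module that imports `from __future__ import annotations` (as this one does) are stored as strings and never evaluated. A small sketch of that distinction, with a toy `Effect` class:

```python
from __future__ import annotations   # PEP 563: annotations are stored as strings, unevaluated

class Effect:
    """A toy side-effect marker."""

# A plain assignment, evaluated at import time: this line is what needs
# Python 3.9+ runtime support for PEP 585 generics.
Effects = set[Effect]

# Pure annotations, by contrast, are never evaluated under the future import.
def merge(a: Effects, b: Effects) -> Effects:
    return a | b

print(merge({Effect()}, set()))
```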

View File

@ -16,7 +16,7 @@ import contextlib
import functools
import itertools as it
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
from typing import Any, Callable, Sequence, Optional, Union
import jax
from jax._src import linear_util as lu
@ -45,8 +45,8 @@ def identity(x): return x
def _update_annotation(
f: lu.WrappedFun,
orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
explicit_nonzeros: List[bool]
orig_type: Optional[tuple[tuple[core.AbstractValue, bool], ...]],
explicit_nonzeros: list[bool]
) -> lu.WrappedFun:
if orig_type is None:
return f
@ -221,13 +221,13 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
if not is_undefined_primal(val):
primal_env[v] = val
primal_env: Dict[Any, Any] = {}
primal_env: dict[Any, Any] = {}
map(write_primal, jaxpr.constvars, consts)
# FIXME: invars can contain both primal and tangent values, and this line
# forces primal_in to contain UndefinedPrimals for tangent values!
map(write_primal, jaxpr.invars, primals_in)
ct_env: Dict[Any, Any] = {}
ct_env: dict[Any, Any] = {}
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
else contextlib.nullcontext())
with ctx:
@ -481,17 +481,17 @@ def _primal_tangent_shapes_match(primal, tangent):
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
call_param_updaters: Dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}
call_param_updaters: dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
# -------------------- Primitives --------------------
primitive_jvps : Dict[core.Primitive, Callable] = {}
primitive_jvps : dict[core.Primitive, Callable] = {}
primitive_transposes: Dict[core.Primitive, Callable] = {}
primitive_transposes: dict[core.Primitive, Callable] = {}
# transpose rules that internally perform reductions over the given named axes
reducing_transposes: Dict[core.Primitive, Callable] = {}
reducing_transposes: dict[core.Primitive, Callable] = {}
def deflinear(primitive, transpose_rule):
@ -693,7 +693,7 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
instantiate: Union[bool, Sequence[bool]]
) -> Tuple[core.ClosedJaxpr, List[bool]]:
) -> tuple[core.ClosedJaxpr, list[bool]]:
if type(instantiate) is bool:
instantiate = (instantiate,) * len(jaxpr.out_avals)
return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))

View File

@ -15,8 +15,7 @@ from __future__ import annotations
import dataclasses
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Optional,
Sequence, Set, Tuple, Type, Union)
from typing import Any, Callable, Iterable, Optional, Sequence, Union
import numpy as np
@ -115,7 +114,7 @@ class RaggedAxis:
# For each axis, we store its index and the corresponding segment lengths.
# For example, the pile i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# would be represented with ragged_axes = [(1, lens1), (3, lens2)]
ragged_axes: List[Tuple[int, Array]]
ragged_axes: list[tuple[int, Array]]
@property
def size(self):
@ -139,7 +138,7 @@ class RaggedAxis:
return RaggedAxis(self.stacked_axis, new_ragged_axes)
def make_batch_axis(
ndim: int, stacked_axis: int, ragged_axes: List[Tuple[int, Array]]
ndim: int, stacked_axis: int, ragged_axes: list[tuple[int, Array]]
) -> Union[int, RaggedAxis]:
if ragged_axes:
canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
@ -241,7 +240,7 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
if spec is not None else x)
else:
assert False
to_elt_handlers: Dict[Type, ToEltHandler] = {}
to_elt_handlers: dict[type, ToEltHandler] = {}
def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
) -> Vmappable:
@ -257,7 +256,7 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
return _pile_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
else:
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
from_elt_handlers: Dict[Type, FromEltHandler] = {}
from_elt_handlers: dict[type, FromEltHandler] = {}
def make_iota(axis_size: AxisSize) -> Array:
handler = make_iota_handlers.get(type(axis_size))
@ -265,9 +264,9 @@ def make_iota(axis_size: AxisSize) -> Array:
return handler(axis_size)
else:
return jax.lax.iota('int32', int(axis_size))
make_iota_handlers: Dict[Type, MakeIotaHandler] = {}
make_iota_handlers: dict[type, MakeIotaHandler] = {}
def register_vmappable(data_type: Type, spec_type: Type, axis_size_type: Type,
def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
to_elt: Callable, from_elt: Callable,
make_iota: Optional[Callable]):
vmappables[data_type] = (spec_type, axis_size_type)
@ -275,10 +274,10 @@ def register_vmappable(data_type: Type, spec_type: Type, axis_size_type: Type,
to_elt_handlers[data_type] = to_elt
from_elt_handlers[data_type] = from_elt
if make_iota: make_iota_handlers[axis_size_type] = make_iota
vmappables: Dict[Type, Tuple[Type, Type]] = {}
spec_types: Set[Type] = {PileAxis}
vmappables: dict[type, tuple[type, type]] = {}
spec_types: set[type] = {PileAxis}
def unregister_vmappable(data_type: Type) -> None:
def unregister_vmappable(data_type: type) -> None:
spec_type, axis_size_type = vmappables.pop(data_type)
spec_types.remove(spec_type)
del to_elt_handlers[data_type]
@ -591,8 +590,8 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,
### API for batching callables with vmappable inputs and outputs
def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size,
in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
spmd_axis_name: Optional[Tuple[AxisName, ...]] = None
in_dims, out_dim_dests, main_type: type[BatchTrace] = BatchTrace,
spmd_axis_name: Optional[tuple[AxisName, ...]] = None
) -> lu.WrappedFun:
# we split up _batch_inner and _batch_outer for the leak checker
f = _batch_inner(fun, axis_size, out_dim_dests)
@ -624,11 +623,11 @@ def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals):
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
in_axes_flat: Tuple[Optional[int], ...],
out_axes_flat: Tuple[Optional[int], ...],
in_axes_flat: tuple[Optional[int], ...],
out_axes_flat: tuple[Optional[int], ...],
tile_size: Optional[int],
axis_name: AxisName,
main_type: Type[BatchTrace] = BatchTrace):
main_type: type[BatchTrace] = BatchTrace):
@curry
def tile_axis(arg, axis: Optional[int], tile_size):
if axis is None:
@ -673,22 +672,22 @@ def batch_subtrace(main, in_dims, *in_vals):
def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
axis_size: core.AxisSize,
in_axes: Tuple[Union[int, NotMapped], ...],
in_axes: tuple[Union[int, NotMapped], ...],
axis_name: AxisName,
spmd_axis_name: AxisName,
main_type: Type[BatchTrace],
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
main_type: type[BatchTrace],
) -> tuple[core.ClosedJaxpr, tuple[Union[int, NotMapped], ...]]:
return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name,
spmd_axis_name, main_type)
@weakref_lru_cache
def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
axis_size: core.AxisSize,
in_axes: Tuple[Union[int, NotMapped], ...],
in_axes: tuple[Union[int, NotMapped], ...],
axis_name: AxisName,
spmd_axis_name: AxisName,
main_type: Type[BatchTrace],
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
main_type: type[BatchTrace],
) -> tuple[core.ClosedJaxpr, tuple[Union[int, NotMapped], ...]]:
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f, out_axes = _batch_jaxpr_inner(f, axis_size)
f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes,
@ -862,10 +861,10 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
### utilities for defining primitives' batching rules
BatchingRule = Callable[..., Tuple[Any, Union[None, int, Tuple[Union[None, int], ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
BatchingRule = Callable[..., tuple[Any, Union[None, int, tuple[Union[None, int], ...]]]]
primitive_batchers : dict[core.Primitive, BatchingRule] = {}
axis_primitive_batchers: dict[core.Primitive, Callable] = {}
spmd_axis_primitive_batchers: dict[core.Primitive, Callable] = {}
def defvectorized(prim):
primitive_batchers[prim] = partial(vectorized_batcher, prim)
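
These registries (`primitive_batchers`, `axis_primitive_batchers`, ...) are the lookup tables behind `jax.vmap`'s per-primitive batching rules. As a reminder of the user-facing behaviour they implement, a tiny standard example (public API, nothing specific to this change):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.tanh(w @ x)

w = jnp.ones((3, 4))
xs = jnp.ones((8, 4))            # a batch of 8 inputs

# Map over the leading axis of `xs` only; `w` is shared across the batch.
batched_predict = jax.vmap(predict, in_axes=(None, 0), out_axes=0)
print(batched_predict(w, xs).shape)   # (8, 3)
```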

View File

@ -24,8 +24,8 @@ import itertools
import operator
import re
import typing
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
Protocol, Sequence, Set, Tuple, Type, Union)
from typing import (Any, Callable, Iterator, NamedTuple, Optional,
Protocol, Sequence, Union)
import warnings
import numpy as np
@ -114,7 +114,7 @@ def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
# IR Types
# Non-canonicalized dtype to IR type mapping.
_dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
_dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
@ -165,7 +165,7 @@ def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
shape = [d if type(d) is int else dyn_size for d in aval.shape]
return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)
ir_type_handlers: Dict[Type[core.AbstractValue],
ir_type_handlers: dict[type[core.AbstractValue],
Callable[[Any], Sequence[ir.Type]]] = {}
def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
@ -203,7 +203,7 @@ class ConstantHandler(Protocol):
A JAX value is represented by zero or more IR values."""
_constant_handlers : Dict[type, ConstantHandler] = {}
_constant_handlers : dict[type, ConstantHandler] = {}
def register_constant_handler(type_: type, handler_fun: ConstantHandler):
_constant_handlers[type_] = handler_fun
@ -338,7 +338,7 @@ def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
return ir.Location.callsite(frame_locs[-1], frame_locs[-2::-1])
def _source_info_to_location(
primitive: core.Primitive, params: Dict,
primitive: core.Primitive, params: dict,
source_info: source_info_util.SourceInfo,
name_stack: source_info_util.NameStack) -> ir.Location:
eqn_str = (f'{str(source_info.name_stack)}/'
@ -410,15 +410,15 @@ class ModuleContext:
platform: str
axis_context: AxisContext
name_stack: source_info_util.NameStack
keepalives: List[Any]
keepalives: list[Any]
channel_iterator: Iterator[int]
host_callbacks: List[Any]
host_callbacks: list[Any]
# Keep state for the lowering of shape polymorphism
shape_poly_state: ShapePolyLoweringState
# Cached primitive lowerings.
cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: Dict[Any, func_dialect.FuncOp]
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]
@property
@ -431,16 +431,16 @@ class ModuleContext:
platform: str,
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
keepalives: List[Any],
keepalives: list[Any],
channel_iterator: Iterator[int],
host_callbacks: List[Any],
host_callbacks: list[Any],
context: Optional[ir.Context] = None,
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
symbol_table: Optional[ir.SymbolTable] = None,
cached_primitive_lowerings: Optional[Dict[Any,
cached_primitive_lowerings: Optional[dict[Any,
func_dialect.FuncOp]] = None,
cached_call_jaxpr_lowerings: Optional[Dict[Any,
cached_call_jaxpr_lowerings: Optional[dict[Any,
func_dialect.FuncOp]] = None,
shape_poly_state = None):
assert platform is not None
@ -489,7 +489,7 @@ class LoweringRuleContext:
avals_out: Any # Usually Sequence[core.AbstractValue], but sometimes None.
tokens_in: TokenSet
tokens_out: Optional[TokenSet] # Mutable store for output containers
axis_size_env: Optional[Dict[core.Var, ir.Value]] = None # Dynamic axis sizes
axis_size_env: Optional[dict[core.Var, ir.Value]] = None # Dynamic axis sizes
dim_var_values: Sequence[ir.Value] = () # The values for the dimension variables
# in same order as module_context.shape_poly_state.dim_vars
@ -509,8 +509,8 @@ if not MYPY:
else:
LoweringRule = Any
_lowerings: Dict[core.Primitive, LoweringRule] = {}
_platform_specific_lowerings: Dict[str, Dict[core.Primitive, LoweringRule]]
_lowerings: dict[core.Primitive, LoweringRule] = {}
_platform_specific_lowerings: dict[str, dict[core.Primitive, LoweringRule]]
_platform_specific_lowerings = collections.defaultdict(dict)
def register_lowering(prim: core.Primitive, rule: LoweringRule,
@ -553,7 +553,7 @@ def sharded_aval(aval: core.AbstractValue,
def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Union[int, Value], ...]:
shape: core.Shape) -> tuple[Union[int, Value], ...]:
if config.jax_dynamic_shapes:
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
else:
@ -569,7 +569,7 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
# TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Value, ...]:
shape: core.Shape) -> tuple[Value, ...]:
"""Evaluates the dynamic shapes as int32 values."""
def convert_dim(d: Union[int, Value]):
if type(d) is int:
@ -585,7 +585,7 @@ def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
def eval_dynamic_shape_as_ivals(
ctx: LoweringRuleContext, shape: core.Shape
) -> Tuple[Union[int, Value], ...]:
) -> tuple[Union[int, Value], ...]:
"""Evaluates the dynamic shapes as int or ir.int32 values."""
def convert_dim(d: Union[int, Value]) -> Union[int, ir.Value]:
if type(d) is int:
@ -602,7 +602,7 @@ def eval_dynamic_shape_as_ivals(
class LoweringResult(NamedTuple):
module: ir.Module
keepalive: Optional[Any]
host_callbacks: List[Any]
host_callbacks: list[Any]
shape_poly_state: ShapePolyLoweringState
@ -622,7 +622,7 @@ def _to_logical_op_sharding(
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
ordered_effects: List[core.Effect],
ordered_effects: list[core.Effect],
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
axis_context: AxisContext,
@ -665,8 +665,8 @@ def lower_jaxpr_to_module(
# HLO channels need to start at 1
channel_iter = itertools.count(1)
# Create a keepalives list that will be mutated during the lowering.
keepalives: List[Any] = []
host_callbacks: List[Any] = []
keepalives: list[Any] = []
host_callbacks: list[Any] = []
dim_vars: Sequence[str]
if not config.jax_dynamic_shapes:
@ -786,7 +786,7 @@ class TokenSet:
tokens = [create_token() for _ in effects]
return TokenSet(zip(effects, tokens))
def items(self) -> Sequence[Tuple[core.Effect, Token]]:
def items(self) -> Sequence[tuple[core.Effect, Token]]:
return tuple(self._tokens.items())
def effects(self) -> set[core.Effect]:
@ -934,7 +934,7 @@ def lower_jaxpr_to_fun(
or arg_names is not None
or num_tokens > 0
):
arg_attrs: List[Dict[str, ir.Attribute]] = [
arg_attrs: list[dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_input_types))]
if replicated_args is not None:
@ -952,7 +952,7 @@ def lower_jaxpr_to_fun(
if input_output_aliases is not None:
output_ids = util.unflatten(list(range(len(flat_output_types))),
map(len, output_types))
aliases: List[Optional[int]] = []
aliases: list[Optional[int]] = []
for types, alias in zip(input_types, input_output_aliases):
if alias is None:
aliases.extend([None] * len(types))
@ -977,7 +977,7 @@ def lower_jaxpr_to_fun(
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
result_attrs: List[Dict[str, ir.Attribute]] = [
result_attrs: list[dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_output_types))]
if num_tokens > 0:
@ -1020,7 +1020,7 @@ def lower_jaxpr_to_fun(
tokens_in = TokenSet.create(effects)
else:
tokens_in = TokenSet(zip(effects, token_args))
args: List[List[ir.Value]] = []
args: list[list[ir.Value]] = []
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
if replace_tokens_with_dummy and aval is core.abstract_token:
args.append(hlo.CreateTokenOp().results)
@ -1104,7 +1104,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
consts: Sequence[Sequence[ir.Value]],
*args: Sequence[ir.Value],
dim_var_values: Sequence[ir.Value]
) -> Tuple[Sequence[Sequence[ir.Value]], TokenSet]:
) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]:
"""Lowers a jaxpr into MLIR, inlined into an existing function.
Assumes that an MLIR context, location, and insertion point are set.
@ -1131,7 +1131,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
env[v] = tuple(node)
env: Dict[core.Var, Tuple[ir.Value, ...]] = {}
env: dict[core.Var, tuple[ir.Value, ...]] = {}
assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
@ -1547,7 +1547,7 @@ def _wrap_with_spmd_op(name: str,
x: ir.Value,
aval_out: core.AbstractValue,
sharding_proto: xc.OpSharding,
unspecified_dims: Optional[Set[int]] = None):
unspecified_dims: Optional[set[int]] = None):
# unspecified_dims indicate dimensions whose shardings are not specified and
# XLA sharding propagation can change them.
if unspecified_dims:
@ -1810,13 +1810,13 @@ def _emit_tpu_python_callback(
callback,
token: Optional[Any],
operands: Sequence[ir.Value],
operand_avals: List[core.ShapedArray],
operand_shapes: List[xc.Shape],
result_avals: List[core.ShapedArray],
result_shapes: List[xc.Shape],
operand_avals: list[core.ShapedArray],
operand_shapes: list[xc.Shape],
result_avals: list[core.ShapedArray],
result_shapes: list[xc.Shape],
*,
sharding: Optional[xc.OpSharding] = None
) -> Tuple[List[ir.Value], Any, Any]:
) -> tuple[list[ir.Value], Any, Any]:
token = token or hlo.CreateTokenOp().result
_wrapped_callback = callback
@ -1886,12 +1886,12 @@ def _aval_to_default_layouts(aval):
def emit_python_callback(
ctx: LoweringRuleContext, callback, token: Optional[Any],
operands: Sequence[ir.Value], operand_avals: List[core.ShapedArray],
result_avals: List[core.ShapedArray],
operands: Sequence[ir.Value], operand_avals: list[core.ShapedArray],
result_avals: list[core.ShapedArray],
has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None,
operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
result_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
) -> Tuple[List[ir.Value], Any, Any]:
) -> tuple[list[ir.Value], Any, Any]:
"""Emits MLIR that calls back to a provided Python function."""
platform = ctx.module_context.platform
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
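
`emit_python_callback` is the lowering-side plumbing for JAX's host-callback features. From the user's side, the most direct way to exercise this path is `jax.pure_callback` (a standard public API; the exact routing through this function is an implementation detail and may differ):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_log(x):
    # Ordinary Python/NumPy, executed on the host rather than compiled by XLA.
    return np.log(np.asarray(x))

@jax.jit
def f(x):
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log, out_shape, x)

print(f(jnp.array([1.0, jnp.e])))   # approximately [0., 1.]
```
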
@ -2012,7 +2012,7 @@ def custom_call(
result_shapes: Optional[Sequence[ir.Value]] = None,
called_computations: Sequence[str] = (),
api_version: int = 2,
extra_attributes: Dict[str, ir.Attribute] = {},
extra_attributes: dict[str, ir.Attribute] = {},
) -> ir.Operation:
"""Wraps a hlo.CustomCall.
@ -2063,7 +2063,7 @@ def reduce_window(
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
# spatial dimension.
int2d = aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
def prep_one_pad(pad_lo_hi: Tuple[core.DimSize, core.DimSize]):
def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
pads = shape_tensor(eval_dynamic_shape(ctx, pad_lo_hi)) # i32[2]
return hlo.ReshapeOp(int2d, pads)
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),

View File

@ -20,8 +20,8 @@ from functools import partial
import inspect
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
List, Union, Hashable, Set)
from typing import (Any, Callable, NamedTuple, Optional, Sequence,
Union, Hashable)
from weakref import ref
import numpy as np
@ -62,7 +62,7 @@ ConstId = int
def _update_annotation_known(
f: lu.WrappedFun,
orig_type: Optional[InputType],
in_knowns: List[bool]
in_knowns: list[bool]
) -> lu.WrappedFun:
if orig_type is None: return f
# orig_type might contain DBIdx, but we're tossing out some args so we have to
@ -104,7 +104,7 @@ class PartialVal(tuple):
* `(<AbstractValue>, None)` indicates an unknown value characterized by an
abstract value.
"""
def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
def __new__(cls, xs: tuple[Optional[AbstractValue], core.Value]):
pv, const = xs
if config.jax_enable_checks:
# type checks
@ -549,8 +549,8 @@ class JaxprTrace(Trace['JaxprTracer']):
raise NotImplementedError # TODO(mattjj)
def partition_pvals(
pvals: List[PartialVal]
) -> Tuple[List[bool], List[AbstractValue], List[Any]]:
pvals: list[PartialVal]
) -> tuple[list[bool], list[AbstractValue], list[Any]]:
knowns = [pval.is_known() for pval in pvals ]
avals = [pval.get_aval() for pval in pvals if not pval.is_known()]
consts = [pval.get_known() for pval in pvals if pval.is_known()]
@ -581,7 +581,7 @@ def trace_to_subjaxpr_nounits_dyn(
# for all axis sizes, so that we can then use those tracers in the shapes of
# avals for unknown inputs' tracers. We use ConstVar recipes for on-the-fly
# type agreement checking via get_referent.
in_consts_full: List[Optional[JaxprTracer]] = [None] * len(in_type)
in_consts_full: list[Optional[JaxprTracer]] = [None] * len(in_type)
in_consts_iter, in_knowns_iter = iter(in_consts), iter(in_knowns)
for idx, (aval, explicit) in enumerate(in_type):
if explicit and next(in_knowns_iter):
@ -631,8 +631,8 @@ def trace_to_subjaxpr_nounits_dyn(
for inst, t in zip(instantiate, out_tracers)]
# Collect known outputs.
out_knowns: List[bool] = [t.is_known() for t in out_tracers]
out_consts: List[Any] = [t.pval.get_known() for t in out_tracers
out_knowns: list[bool] = [t.is_known() for t in out_tracers]
out_consts: list[Any] = [t.pval.get_known() for t in out_tracers
if t.is_known()]
# Build the jaxpr.
@ -647,7 +647,7 @@ def trace_to_subjaxpr_nounits_dyn(
# Which residuals are just forwarded inputs? Check obj id, then prune.
id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore
if c is not None}
fwds: List[Optional[int]] = [id_map.get(id(c)) for c in res]
fwds: list[Optional[int]] = [id_map.get(id(c)) for c in res]
res = tuple([c for c, fwd in zip(res, fwds) if fwd is None])
del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \
@ -655,9 +655,9 @@ def trace_to_subjaxpr_nounits_dyn(
yield (*out_consts, *res), (fwds, out_knowns, tuple(out_type), jaxpr, env)
custom_partial_eval_rules: Dict[Primitive, Callable] = {}
call_partial_eval_rules: Dict[Primitive, Callable] = {}
call_param_updaters: Dict[Primitive, Callable] = {}
custom_partial_eval_rules: dict[Primitive, Callable] = {}
call_partial_eval_rules: dict[Primitive, Callable] = {}
call_param_updaters: dict[Primitive, Callable] = {}
def _closed_call_param_updater(params, _, __):
jaxpr = params.get('call_jaxpr')
@ -733,7 +733,7 @@ class JaxprTracer(Tracer):
def trace_to_jaxpr(
fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]:
) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]:
"""
Partially evaluate a function, building a jaxpr for un-evaluated computation.
@ -770,7 +770,7 @@ def trace_to_jaxpr(
def trace_to_jaxpr_nounits(
fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]:
) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]:
current_name_stack = source_info_util.current_name_stack()
with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
@ -828,7 +828,7 @@ def trace_to_subjaxpr_nounits_fwd(
# Which out_consts (aka residuals) are just forwarded inputs? Check obj id.
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
id_map = {id(c): i for i, c in enumerate(in_consts)}
fwds: List[Optional[int]] = [id_map.get(id(c)) for c in out_consts]
fwds: list[Optional[int]] = [id_map.get(id(c)) for c in out_consts]
pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None]
del out_tracers
@ -844,14 +844,14 @@ class JaxprEqnRecipe(NamedTuple):
out_tracer_refs: Sequence[ref[JaxprTracer]]
out_avals: Sequence[core.AbstractValue]
primitive: Primitive
params: Dict[str, Any]
params: dict[str, Any]
effects: core.Effects
source_info: source_info_util.SourceInfo
def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer],
primitive: Primitive,
params: Dict[str, Any],
params: dict[str, Any],
effects: core.Effects,
source_info: source_info_util.SourceInfo
) -> JaxprEqnRecipe:
@ -882,7 +882,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer]
) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
"""Constructs Jaxpr given tracers for inputs and outputs.
Params:
@ -896,10 +896,10 @@ def tracers_to_jaxpr(
"""
gensym = core.gensym()
t_to_var: Dict[TracerId, Var] = {}
consts: Dict[Var, Any] = {}
env: Dict[Var, JaxprTracer] = {}
constid_to_var: Dict[ConstId, Var] = {} # for deduplication
t_to_var: dict[TracerId, Var] = {}
consts: dict[Var, Any] = {}
env: dict[Var, JaxprTracer] = {}
constid_to_var: dict[ConstId, Var] = {} # for deduplication
def get_atom(t: JaxprTracer) -> Atom:
return t.recipe if type(t.recipe) is Literal else t_to_var[id(t)]
@ -920,7 +920,7 @@ def tracers_to_jaxpr(
return aval
processed_eqn_ids = set()
eqns: List[core.JaxprEqn] = []
eqns: list[core.JaxprEqn] = []
for t in toposort([*in_tracers, *out_tracers]):
r = t.recipe
if isinstance(r, JaxprEqnRecipe):
@ -1004,7 +1004,7 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
def partial_eval_jaxpr_nounits(
jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
instantiate: Union[bool, Sequence[bool]],
) -> Tuple[ClosedJaxpr, ClosedJaxpr, List[bool], List[AbstractValue]]:
) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue]]:
"""Unzip a jaxpr in two by data dependence into 'known' and 'unknown' parts.
That is, given a jaxpr and a sequence of booleans indicating which jaxpr
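
Everything in this file manipulates jaxprs, so it helps to have one in view; `jax.make_jaxpr` (public API) prints the staged-out representation that `partial_eval_jaxpr_nounits` and friends then split into known and unknown halves:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y + 1.0

print(jax.make_jaxpr(f)(2.0, 3.0))
# Roughly:
# { lambda ; a:f32[] b:f32[]. let
#     c:f32[] = sin a
#     d:f32[] = mul c b
#     e:f32[] = add d 1.0
#   in (e,) }
```
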
@ -1122,7 +1122,7 @@ def partial_eval_jaxpr_custom(
ensure_out_unknowns: Union[bool, Sequence[bool]],
ensure_out_inst: Union[bool, Sequence[bool]],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
if type(ensure_out_unknowns) is bool:
@ -1146,7 +1146,7 @@ def partial_eval_jaxpr_stateful(
ensure_out_unknowns: Union[bool, Sequence[bool]],
ensure_out_inst: Union[bool, Sequence[bool]],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]:
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
if type(ensure_out_unknowns) is bool:
@ -1163,17 +1163,17 @@ def partial_eval_jaxpr_stateful(
@weakref_lru_cache
def _partial_eval_jaxpr_custom_cached(
jaxpr: Jaxpr,
in_unknowns: Tuple[bool, ...],
in_inst: Tuple[bool, ...],
ensure_out_unknowns: Tuple[bool, ...],
ensure_out_inst: Tuple[bool, ...],
in_unknowns: tuple[bool, ...],
in_inst: tuple[bool, ...],
ensure_out_unknowns: tuple[bool, ...],
ensure_out_inst: tuple[bool, ...],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]:
env: Dict[Var, Tuple[bool, bool]] = {}
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
env: dict[Var, tuple[bool, bool]] = {}
residuals: OrderedSet[Var] = OrderedSet()
residual_refs: OrderedSet[Var] = OrderedSet()
def read(x: Atom) -> Tuple[bool, bool]:
def read(x: Atom) -> tuple[bool, bool]:
if type(x) is Var:
return env[x]
return (False, True)
@ -1259,12 +1259,12 @@ def _partial_eval_jaxpr_custom_cached(
# * a list of Var instances representing residuals to be added (i.e. to be
# plumbed as outputs of the 'known' side jaxpr and added as input binders to
# the 'unknown' jaxpr).
PartialEvalCustomResult = Tuple[Optional[JaxprEqn], Optional[JaxprEqn],
Sequence[bool], Sequence[bool], List[Var]]
PartialEvalCustomResult = tuple[Optional[JaxprEqn], Optional[JaxprEqn],
Sequence[bool], Sequence[bool], list[Var]]
PartialEvalCustomRule = Callable[
[Callable[..., bool], Sequence[bool], Sequence[bool], JaxprEqn],
PartialEvalCustomResult]
partial_eval_jaxpr_custom_rules: Dict[Primitive, PartialEvalCustomRule] = {}
partial_eval_jaxpr_custom_rules: dict[Primitive, PartialEvalCustomRule] = {}
def partial_eval_jaxpr_custom_rule_not_implemented(
name: str, saveable: Callable[..., bool], unks_in: Sequence[bool],
@ -1276,10 +1276,10 @@ def partial_eval_jaxpr_custom_rule_not_implemented(
ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
Sequence[bool], int, dict, dict],
Tuple[dict, dict]]
ResAvalUpdater = Callable[[Dict[str, Any], AbstractValue], AbstractValue]
tuple[dict, dict]]
ResAvalUpdater = Callable[[dict[str, Any], AbstractValue], AbstractValue]
def _default_res_aval_updater(
params: Dict[str, Any], aval: AbstractValue) -> AbstractValue:
params: dict[str, Any], aval: AbstractValue) -> AbstractValue:
return aval
@contextmanager
@ -1287,10 +1287,10 @@ def trivial_ctx(_): yield
def call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
saveable: Callable[..., bool], unks_in: list[bool], inst_in: list[bool],
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
ctx: Callable[[core.ParamDict], AbstractContextManager[None]] = trivial_ctx,
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
jaxpr = eqn.params[jaxpr_param_name]
with ctx(eqn.params):
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
@ -1319,9 +1319,9 @@ def call_partial_eval_custom_rule(
def closed_call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
saveable: Callable[..., bool], unks_in: list[bool], inst_in: list[bool],
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
# TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
closed_jaxpr = eqn.params[jaxpr_param_name]
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_out, num_res_ref = \
@ -1369,9 +1369,9 @@ partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
lambda _, __, ___, ____, _____, x, y: (x, y))
def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[Optional[int]]:
# Compute which inputs are just forwarded to outputs.
fwds: Dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars))
fwds: dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars))
for eqn in jaxpr.eqns:
if eqn.primitive in forwarding_rules:
eqn = eqn.replace(invars=[a if type(a) is Literal else fwds.get(a, a) # type: ignore
@ -1380,14 +1380,14 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
for v_orig, v_new in zip(eqn.outvars, fwd_vars):
if v_new is not None:
fwds[v_orig] = v_new
idxs: Dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)}
idxs: dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)}
return [None if type(v) is Literal else idxs.get(fwds.get(v)) # type: ignore
for v in jaxpr.outvars]
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[bool]]:
) -> tuple[Jaxpr, list[bool]]:
if type(instantiate) is bool:
instantiate = (instantiate,) * len(jaxpr.invars)
return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate))
@ -1395,7 +1395,7 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[bool], List[bool]]:
) -> tuple[Jaxpr, list[bool], list[bool]]:
jaxpr_ = convert_constvars_jaxpr(jaxpr)
new_jaxpr_, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs)
used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)])
@ -1404,10 +1404,10 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
instantiate: Tuple[bool, ...]
) -> Tuple[Jaxpr, List[bool]]:
env: Dict[Var, bool] = {}
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
instantiate: tuple[bool, ...]
) -> tuple[Jaxpr, list[bool]]:
env: dict[Var, bool] = {}
def read(v: Var) -> bool:
return env.get(v, False)
@ -1448,18 +1448,18 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
return new_jaxpr, used_inputs
DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], Optional[JaxprEqn]]]
DCERule = Callable[[list[bool], JaxprEqn], tuple[list[bool], Optional[JaxprEqn]]]
def _default_dce_rule(
used_outs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], JaxprEqn]:
used_outs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn]:
return [True] * len(eqn.invars), eqn
dce_rules: Dict[Primitive, DCERule] = {}
dce_rules: dict[Primitive, DCERule] = {}
def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], Optional[JaxprEqn]]:
def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], Optional[JaxprEqn]]:
new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
update_params = call_param_updaters.get(eqn.primitive)
@ -1476,8 +1476,8 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], JaxprEqn]:
def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn]:
# TODO(mattjj): de-duplicate with above rule?
jaxpr_ = eqn.params['call_jaxpr']
jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
@ -1500,7 +1500,7 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
return _move_binders_to_front(closed_jaxpr, tuple(to_move))
@weakref_lru_cache
def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Tuple[bool, ...]
def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]
) -> ClosedJaxpr:
assert len(closed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
@ -1611,12 +1611,12 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
class JaxprStackFrame:
gensym: Callable[[AbstractValue], Var]
tracer_to_var: Dict[TracerId, Var]
constid_to_tracer: Dict[ConstId, Tracer]
constvar_to_val: Dict[Var, Any]
tracers: List[DynamicJaxprTracer] # hold onto strong refs for all tracers
eqns: List[JaxprEqn]
invars: List[Var]
tracer_to_var: dict[TracerId, Var]
constid_to_tracer: dict[ConstId, Tracer]
constvar_to_val: dict[Var, Any]
tracers: list[DynamicJaxprTracer] # hold onto strong refs for all tracers
eqns: list[JaxprEqn]
invars: list[Var]
effects: core.Effects
debug_info: Optional[DebugInfo]
@ -1634,7 +1634,7 @@ class JaxprStackFrame:
def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)
def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> Tuple[Jaxpr, List[Any]]:
def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> tuple[Jaxpr, list[Any]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
@ -1687,9 +1687,9 @@ class JaxprStackFrame:
return invar_positions, const_eqns
def _const_folding_and_forwarding(
jaxpr: Jaxpr, constvals: Sequence[Any]) -> Tuple[Jaxpr, Tuple[Any, ...]]:
consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals))
var_subs: Dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined
jaxpr: Jaxpr, constvals: Sequence[Any]) -> tuple[Jaxpr, tuple[Any, ...]]:
consts: dict[Var, Any] = dict(zip(jaxpr.constvars, constvals))
var_subs: dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined
new_eqns = []
for eqn in jaxpr.eqns:
# always apply invar substitutions
@ -1723,18 +1723,18 @@ def _const_folding_and_forwarding(
jaxpr_effects, jaxpr.debug_info)
return new_jaxpr, new_constvals
ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
Tuple[List[Optional[Any]], Optional[JaxprEqn]]]
const_fold_rules: Dict[Primitive, ConstFoldRule] = {}
ConstFoldRule = Callable[[list[Optional[Any]], JaxprEqn],
tuple[list[Optional[Any]], Optional[JaxprEqn]]]
const_fold_rules: dict[Primitive, ConstFoldRule] = {}
ForwardingRule = Callable[[JaxprEqn],
Tuple[List[Optional[Var]], Optional[JaxprEqn]]]
forwarding_rules: Dict[Primitive, ForwardingRule] = {}
tuple[list[Optional[Var]], Optional[JaxprEqn]]]
forwarding_rules: dict[Primitive, ForwardingRule] = {}
def _inline_literals(
jaxpr: Jaxpr, constvals: Sequence[Any]
) -> Tuple[Jaxpr, List[Any]]:
) -> tuple[Jaxpr, list[Any]]:
# This function also prunes unused constants and inserts `dropvar` symbols.
input_effects = {eff for eff in jaxpr.effects
if isinstance(eff, effects.JaxprInputEffect)}
@ -1746,7 +1746,7 @@ def _inline_literals(
if type(c) in core.literalable_types and not np.shape(c) and not e}
lit: Callable[[Var], Optional[Literal]] = lits.get
newname: Callable[[AbstractValue], Var] = core.gensym()
newvars: Dict[Var, Var] = {}
newvars: dict[Var, Var] = {}
newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))
@ -2052,7 +2052,7 @@ class DynamicJaxprTrace(core.Trace):
return out_tracers
custom_staging_rules: Dict[Primitive, Callable] = {}
custom_staging_rules: dict[Primitive, Callable] = {}
@lu.transformation
def _interleave_fun(every_others, *args, **kwargs):
@ -2113,7 +2113,7 @@ def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]:
def arg_info_all(dbg: DebugInfo) -> Optional[list[tuple[str, KeyPath]]]:
ba = None if dbg.in_tree is None else sig_info(dbg)
if ba is None: return None
return [(name, key_path) for name, dummy_arg in ba.arguments.items()
@ -2131,7 +2131,7 @@ def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]:
except (TypeError, ValueError):
return None
def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]:
def result_info(dbg: DebugInfo) -> Optional[list[KeyPath]]:
if dbg.out_tree is None: return None
try:
num_leaves = dbg.out_tree().num_leaves
@ -2148,8 +2148,8 @@ def trace_to_jaxpr_dynamic(
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None,
*,
keep_inputs: Optional[List[bool]] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
keep_inputs: Optional[list[bool]] = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
@ -2165,7 +2165,7 @@ def trace_to_subjaxpr_dynamic(
*,
keep_inputs: Optional[Sequence[bool]] = None,
debug_info: Optional[DebugInfo] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
@ -2185,7 +2185,7 @@ def trace_to_subjaxpr_dynamic(
@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
) -> tuple[Jaxpr, OutputType, list[Any]]:
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
@ -2195,7 +2195,7 @@ def trace_to_jaxpr_dynamic2(
def trace_to_subjaxpr_dynamic2(
fun: lu.WrappedFun, main: core.MainTrace,
debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
) -> tuple[Jaxpr, OutputType, list[Any]]:
in_avals, keep_inputs = unzip2(fun.in_type)
frame = JaxprStackFrame()
frame.debug_info = debug_info
@ -2226,7 +2226,7 @@ def trace_to_jaxpr_final(
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None,
keep_inputs: Optional[Sequence[bool]] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
@ -2239,7 +2239,7 @@ def trace_to_jaxpr_final(
@profiler.annotate_function
def trace_to_jaxpr_final2(
fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
) -> tuple[Jaxpr, OutputType, list[Any]]:
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
@ -2249,8 +2249,8 @@ def trace_to_jaxpr_final2(
AbstractedAxisName = Hashable
AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName],
Tuple[AbstractedAxisName, ...]]
AbstractedAxesSpec = Union[dict[int, AbstractedAxisName],
tuple[AbstractedAxisName, ...]]
def infer_lambda_input_type(
axes_specs: Optional[Sequence[AbstractedAxesSpec]],
args: Sequence[Any]
@ -2265,7 +2265,7 @@ def infer_lambda_input_type(
lu._check_input_type(input_type)
return input_type
def _spec_to_dict(spec: AbstractedAxesSpec) -> Dict[int, AbstractedAxisName]:
def _spec_to_dict(spec: AbstractedAxesSpec) -> dict[int, AbstractedAxisName]:
if isinstance(spec, tuple):
return {i: d for i, d in enumerate(spec) if d is not None}
else:
@ -2273,15 +2273,15 @@ def _spec_to_dict(spec: AbstractedAxesSpec) -> Dict[int, AbstractedAxisName]:
def _canonicalize_specs(
ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]]
) -> List[Dict[int, AbstractedAxisName]]:
) -> list[dict[int, AbstractedAxisName]]:
if specs is None:
return [{}] * len(ndims)
else:
return [_spec_to_dict(s) for n, s in zip(ndims, specs)]
def _complete_specs(
args: Sequence[Any], partial_specs: List[Dict[int, AbstractedAxisName]]
) -> List[Dict[int, AbstractedAxisName]]:
args: Sequence[Any], partial_specs: list[dict[int, AbstractedAxisName]]
) -> list[dict[int, AbstractedAxisName]]:
# The abstracted axes specification in `partial_specs` is partial in the sense
# that there could be additional axis abstraction represented in `args` due to
# Tracers existing in the shapes of elements of `args`. The purpose of this
@ -2291,16 +2291,16 @@ def _complete_specs(
# names (with one new name per unique Tracer object id).
# Identify each user-supplied name in partial_specs with a size.
sizes: Dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {}
sizes: dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {}
for x, spec in zip(args, partial_specs):
for i, name in spec.items():
d = sizes.setdefault(name, x.shape[i])
if d is not x.shape[i] and d != x.shape[i]: raise TypeError
# Introduce new names as needed for Tracers in shapes.
named_tracers: Dict[TracerId, AbstractedAxisName] = {
named_tracers: dict[TracerId, AbstractedAxisName] = {
id(d): name for name, d in sizes.items() if isinstance(d, Tracer)}
specs: List[Dict[int, AbstractedAxisName]] = []
specs: list[dict[int, AbstractedAxisName]] = []
for x, spec in zip(args, partial_specs):
if isinstance(get_aval(x), DShapedArray):
spec = dict(spec)
@ -2318,15 +2318,15 @@ def _complete_specs(
def _collect_implicit(
args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]]
) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractValue]]:
args: Sequence[Any], specs: list[dict[int, AbstractedAxisName]]
) -> tuple[dict[AbstractedAxisName, DBIdx], list[AbstractValue]]:
# Given an explicit argument list and a specification of abstracted axes, we
# want to produce an InputType by identifying AbstractedAxisNames with DBIdxs
# and figuring out which AbstractedAxisNames correspond to implicit arguments.
idxs: Dict[AbstractedAxisName, DBIdx] = {}
implicit_types: List[AbstractValue] = []
explicit_tracers: Dict[TracerId, int] = {}
idxs: dict[AbstractedAxisName, DBIdx] = {}
implicit_types: list[AbstractValue] = []
explicit_tracers: dict[TracerId, int] = {}
counter = it.count()
# Add implicit arguments to idxs.
@ -2348,32 +2348,32 @@ def _collect_implicit(
return idxs, implicit_types
def _arg_type(
idxs: Dict[AbstractedAxisName, DBIdx], x: Any,
spec: Dict[int, AbstractedAxisName]
idxs: dict[AbstractedAxisName, DBIdx], x: Any,
spec: dict[int, AbstractedAxisName]
) -> AbstractValue:
# Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames.
aval = get_aval(x) # aval.shape could contain Tracers
if not spec: return core.raise_to_shaped(aval)
shape: List[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d
shape: list[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d
for i, d in enumerate(aval.shape)]
assert not any(isinstance(d, Tracer) for d in shape)
return DShapedArray(tuple(shape), aval.dtype, False)
def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]:
def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]:
invars = [*jaxpr.constvars, *jaxpr.invars]
expl_outvars = jaxpr.outvars
# First do a pass to collect implicit outputs, meaning variables which occur
# in explicit_outvars types but not in invars or to the left in outvars.
seen: Set[Var] = set(invars)
seen: set[Var] = set(invars)
impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore
(seen.add(x) or type(x.aval) is DShapedArray) # type: ignore
for d in x.aval.shape if type(d) is Var and d not in seen]
outvars = [*impl_outvars, *expl_outvars]
# Now assemble an OutputType by mapping vars in shapes to InDBIdx/OutDBIdx.
in_map : Dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: Dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
if type(x) is Var}
out_avals_ = (x.aval for x in outvars)
out_avals = [a.update(shape=tuple(in_map.get(d, out_map.get(d))
@ -2398,7 +2398,7 @@ class TracerAsName:
return id(self.ref)
def _extract_implicit_args(
trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]],
trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]],
explicit_tracers: Sequence[DynamicJaxprTracer]
) -> Sequence[DynamicJaxprTracer]:
# First, construct a list to represent the full argument list, leaving the
@ -2430,7 +2430,7 @@ def _input_type_to_tracers(
# DeBruijn indices which refer to positions in the input argument list. That
# is, each element `a` of `in_avals` can have DBIdx instances in its shape,
# which must refer to positions left of `a`'s.
in_tracers: List[Tracer] = []
in_tracers: list[Tracer] = []
def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue:
if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape):
@ -2443,7 +2443,7 @@ def _input_type_to_tracers(
return in_tracers
def _substitute_vars_in_type(
consts: Dict[Var, Literal], env: Dict[Var, Var], a: AbstractValue
consts: dict[Var, Literal], env: dict[Var, Var], a: AbstractValue
) -> AbstractValue:
if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape):
shape = [consts[d].val if d in consts else env[d] # type: ignore
@ -2494,7 +2494,7 @@ Const = Any
Val = Any
def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
) -> Tuple[Jaxpr, List[Const]]:
) -> tuple[Jaxpr, list[Const]]:
bounds = {v: v.aval.dtype.bound for v in jaxpr.invars
if isinstance(v.aval, core.UnshapedArray) and
type(v.aval.dtype) is core.bint and not v.aval.shape}
@ -2521,9 +2521,9 @@ class BoundedAxisSize(NamedTuple):
bound: int
def _eval_jaxpr_padded(
jaxpr: Jaxpr, consts: List[Const], *args: DynamicJaxprTracer
) -> List[Union[Const, DynamicJaxprTracer]]:
env: Dict[Var, Val] = {}
jaxpr: Jaxpr, consts: list[Const], *args: DynamicJaxprTracer
) -> list[Union[Const, DynamicJaxprTracer]]:
env: dict[Var, Val] = {}
def read(x):
return x.val if type(x) is Literal else env[x]
@ -2543,7 +2543,7 @@ def _eval_jaxpr_padded(
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars)
def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
def _substitute_axis_sizes(env: dict, aval: AbstractValue) -> AbstractValue:
if isinstance(aval, DShapedArray):
shp = []
for d in aval.shape:
@ -2570,7 +2570,7 @@ def _is_bint_axis_size(d: Union[int, core.DArray, core.Var]) -> bool:
return False
padding_rules: Dict[Primitive, Callable] = {}
padding_rules: dict[Primitive, Callable] = {}
def def_trivial_padding(prim: Primitive) -> None:
if prim.multiple_results:
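The partial_eval.py hunks above are mostly mechanical Tuple/List/Dict swaps, but the DCERule alias and the dce_rules registry show the shape of the change well. Below is a minimal, self-contained sketch of that registry pattern written with the lowercase generics; Eqn, DCERuleLike, and default_rule are invented stand-ins rather than JAX code, and an eagerly evaluated alias that subscripts builtins like this assumes Python 3.9+.

```python
from typing import Callable, Optional

class Eqn:
    """Stand-in for JaxprEqn, purely for illustration."""
    def __init__(self, invars: list[str], outvars: list[str]):
        self.invars = invars
        self.outvars = outvars

# Same shape as the DCERule alias above, using builtin generics (Python 3.9+).
DCERuleLike = Callable[[list[bool], Eqn], tuple[list[bool], Optional[Eqn]]]
rules: dict[str, DCERuleLike] = {}

def default_rule(used_outs: list[bool], eqn: Eqn) -> tuple[list[bool], Optional[Eqn]]:
    # Keep every input while any output is needed; otherwise drop the equation.
    if not any(used_outs):
        return [False] * len(eqn.invars), None
    return [True] * len(eqn.invars), eqn

rules["default"] = default_rule
used_ins, kept = rules["default"]([True, False], Eqn(["a", "b"], ["c", "d"]))
print(used_ins)  # [True, True]
```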

View File

@ -23,9 +23,8 @@ from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable,
TYPE_CHECKING, cast, TypeVar)
from typing import (Any, Callable, NamedTuple, Optional, Sequence, Union,
Iterable, TYPE_CHECKING, cast, TypeVar)
import numpy as np
@ -81,7 +80,7 @@ unsafe_map, map = map, safe_map # type: ignore
logger = logging.getLogger(__name__)
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
Index = Union[int, slice, tuple[Union[int, slice], ...]]
NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
@ -140,7 +139,7 @@ def shard_args(
return [shard_arg(arg, devices, indices[i], shardings[i])
for i, arg in enumerate(args)]
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}
shard_arg_handlers: dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}
def _shard_token(x, devices, indices, sharding):
zeros = np.zeros((), dtype=np.dtype(np.bool_))
@ -190,8 +189,8 @@ def batched_device_put(aval: core.ShapedArray,
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def as_slice_indices(arr: Any, idx: Index) -> Tuple[
Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
def as_slice_indices(arr: Any, idx: Index) -> tuple[
tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""Returns start_indices, limit_indices, removed_dims"""
start_indices = [0] * arr.ndim
limit_indices = list(arr.shape)
@ -220,7 +219,7 @@ def shard_aval(size, axis: int, aval):
return shard_aval_handlers[type(aval)](size, axis, aval)
except KeyError as err:
raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
shard_aval_handlers: Dict[Type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
def _shard_abstract_array(size, axis: int, x):
try:
if x.shape[axis] != size:
@ -235,8 +234,8 @@ shard_aval_handlers[ShapedArray] = _shard_abstract_array
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding: sharding_impls.XLACompatibleSharding,
indices: Optional[Tuple[Index, ...]],
) -> Callable[[List[xc.ArrayImpl]], Any]:
indices: Optional[tuple[Index, ...]],
) -> Callable[[list[xc.ArrayImpl]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
@ -258,7 +257,7 @@ def local_aval_to_result_handler(
f"No pxla_result_handler for type: {type(aval)}") from err
PxlaResultHandler = Callable[..., Callable[[Any], Any]]
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
def global_aval_to_result_handler(
@ -288,7 +287,7 @@ def global_aval_to_result_handler(
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err
global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
### lazy device-memory persistence and result handling
@ -297,8 +296,8 @@ def make_sharded_device_array(
aval: ShapedArray,
sharding_spec: Optional[ShardingSpec],
# Any is for JAX extensions implementing their own buffer.
device_buffers: List[Any],
indices: Optional[Tuple[Index, ...]] = None,
device_buffers: list[Any],
indices: Optional[tuple[Index, ...]] = None,
):
"""Returns a ShardedDeviceArray implementation based on arguments.
@ -482,7 +481,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
new_outvals.append(out)
return new_outvals
def _map_schedule(idx: Tuple[Optional[int], ...]) -> Tuple[Optional[int], ...]:
def _map_schedule(idx: tuple[Optional[int], ...]) -> tuple[Optional[int], ...]:
# In order to do a multi-map (a simultaneous map over several axes), we will
# nest several maps. Each time we do a map, we "remove" an input axis so we
# need to update the remaining map axes. For example, if we are to map over
@ -498,9 +497,9 @@ def _map_schedule(idx: Tuple[Optional[int], ...]) -> Tuple[Optional[int], ...]:
# _function object_. Adding this annotation here lets us reuse the same pmap
# callable for all equivalent primitive pmaps.
@lru_cache()
def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
all_axes: List[Tuple[Optional[int], ...]]
) -> Tuple[Callable, Dict[core.AxisName, int]]:
def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName],
all_axes: list[tuple[Optional[int], ...]]
) -> tuple[Callable, dict[core.AxisName, int]]:
used_names = []
for i, name in reversed(list(enumerate(names))):
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
@ -609,9 +608,9 @@ def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
return [i for i in range(ndim) if i not in mapped_axes_][annotation]
def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
shard_axis_src: Dict[core.AxisName, int],
shard_axis_src: dict[core.AxisName, int],
dst_annotation: Optional[int]
) -> Tuple[Any, Dict[core.AxisName, int]]:
) -> tuple[Any, dict[core.AxisName, int]]:
shard_axis_out = dict(shard_axis_src)
src = shard_axis_out.pop(axis_name, None)
dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
@ -629,9 +628,9 @@ def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
raise NotImplementedError
return outval, shard_axis_out
def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
src: int, dst: int) -> Dict[core.AxisName, int]:
lst: List[Optional[core.AxisName]] = [None] * ndim
def _moveaxis(ndim: int, shard_axes: dict[core.AxisName, int],
src: int, dst: int) -> dict[core.AxisName, int]:
lst: list[Optional[core.AxisName]] = [None] * ndim
for k, v in shard_axes.items():
lst[v] = k
name = lst.pop(src)
@ -641,7 +640,7 @@ def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
class MapTracer(core.Tracer):
__slots__ = ["val", "shard_axes"]
def __init__(self, trace: MapTrace, val, shard_axes: Dict[core.AxisName, int]):
def __init__(self, trace: MapTrace, val, shard_axes: dict[core.AxisName, int]):
self._trace = trace
self.val = val
self.shard_axes = shard_axes
@ -736,7 +735,7 @@ def find_replicas(
def stage_parallel_callable(
pci: ParallelCallableInfo, fun: lu.WrappedFun
) -> Tuple[core.Jaxpr, List[Any], ReplicaInfo, ShardInfo]:
) -> tuple[core.Jaxpr, list[Any], ReplicaInfo, ShardInfo]:
sharded_avals = tuple(
shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(pci.in_axes, pci.avals))
@ -940,8 +939,8 @@ class UnloadedPmapExecutable:
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
local_output_avals: Sequence[ShapedArray]
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
unordered_effects: list[core.Effect]
ordered_effects: list[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
jaxpr_debug_info: core.JaxprDebugInfo
@ -979,9 +978,9 @@ class UnloadedPmapExecutable:
replicas: ReplicaInfo,
shards: ShardInfo,
tuple_args: bool,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect],
host_callbacks: list[Any],
keepalive: Any,
jaxpr_debug_info: core.JaxprDebugInfo,
compiler_options=None):
@ -1156,7 +1155,7 @@ def _get_pmap_sharding(devices, specs):
return [sharding_impls.PmapSharding(devices, spec) for spec in specs]
multi_host_supported_collectives: Set[core.Primitive] = set()
multi_host_supported_collectives: set[core.Primitive] = set()
def check_multihost_collective_allowlist(jaxpr):
@ -1287,9 +1286,9 @@ class ExecuteReplicated:
def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
out_handler: ResultsHandler,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect], keepalive: Any,
has_host_callbacks: bool, kept_var_idx: Set[int]):
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect], keepalive: Any,
has_host_callbacks: bool, kept_var_idx: set[int]):
self.xla_executable = xla_executable
self.name = name
self.backend = backend
@ -1582,7 +1581,7 @@ class SPMDBatchTrace(batching.BatchTrace):
return super().get_axis_primitive_batcher(primitive, frame)
spmd_primitive_batchers: Dict[core.Primitive, Callable] = {}
spmd_primitive_batchers: dict[core.Primitive, Callable] = {}
def vtile_by_mesh(fun: lu.WrappedFun,
@ -1611,7 +1610,7 @@ def _full_to_shard_abstract_eval(x, axes, mesh, **_):
def manual_proto(
aval: core.ShapedArray,
manual_axes_set: FrozenSet[sharding_impls.MeshAxisName], mesh: Mesh):
manual_axes_set: frozenset[sharding_impls.MeshAxisName], mesh: Mesh):
"""Create an OpSharding proto that declares all mesh axes from `axes` as manual
and all others as replicated.
"""
@ -1638,7 +1637,7 @@ def manual_proto(
@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
manual_axes: frozenset[sharding_impls.MeshAxisName]):
# TODO: Can we short-circuit for replicated values? Probably not.
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
@ -1658,7 +1657,7 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_):
@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
manual_axes: frozenset[sharding_impls.MeshAxisName]):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
proto = manual_proto(aval_in, manual_axes, mesh) # type: ignore
@ -1669,7 +1668,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto, unspecified_dims),
@lu.transformation
def vtile_manual(manual_axes: FrozenSet[sharding_impls.MeshAxisName],
def vtile_manual(manual_axes: frozenset[sharding_impls.MeshAxisName],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
@ -1688,7 +1687,7 @@ class TileVectorize:
@dataclasses.dataclass(frozen=True)
class TileManual:
manual_axes: FrozenSet[sharding_impls.MeshAxisName]
manual_axes: frozenset[sharding_impls.MeshAxisName]
TilingMethod = Union[TileVectorize, TileManual]
@ -1756,7 +1755,7 @@ class DeviceAssignmentMismatchError(Exception):
pass
ShardingInfo = Tuple[
ShardingInfo = tuple[
Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports
@ -1768,7 +1767,7 @@ def _get_default_device() -> xc.Device:
def _get_and_check_device_assignment(
shardings: Iterable[ShardingInfo],
devices: Optional[Sequence[xc.Device]],
) -> Tuple[xc.Client, Tuple[xc.Device, ...]]:
) -> tuple[xc.Client, tuple[xc.Device, ...]]:
first_sharding_info = None
if devices is None:
devices = ()
@ -1851,7 +1850,7 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
@dataclasses.dataclass(frozen=True)
class SemanticallyEqualShardings:
shardings: Tuple[Union[sharding_impls.GSPMDSharding, UnspecifiedValue], ...]
shardings: tuple[Union[sharding_impls.GSPMDSharding, UnspecifiedValue], ...]
def __hash__(self):
return hash(tuple(
@ -1895,8 +1894,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
dispatch.raise_warnings_or_errors_for_jit_of_pmap(
nreps, backend, fun_name, jaxpr)
in_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
out_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
in_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
out_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext
if nreps == 1:
@ -1952,7 +1951,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
@dataclasses.dataclass(frozen=True)
class _DeviceAssignment:
device_assignment: Tuple[xc.Device, ...]
device_assignment: tuple[xc.Device, ...]
@cached_property
def _hash(self):
@ -1980,7 +1979,7 @@ class _DeviceAssignment:
@lru_cache(maxsize=2048)
def _create_da_object(
device_assignment: Tuple[xc.Device, ...]) -> _DeviceAssignment:
device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:
return _DeviceAssignment(device_assignment)
@ -2161,7 +2160,7 @@ def lower_mesh_computation(
# 1. Trace to jaxpr and preprocess/verify it
if spmd_lowering:
manual_axes: FrozenSet[MeshAxisName] = frozenset()
manual_axes: frozenset[MeshAxisName] = frozenset()
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
if tiling_method is not None:
if isinstance(tiling_method, TileVectorize):
@ -2217,8 +2216,8 @@ def lower_mesh_computation(
# 2. Build up the HLO
tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
in_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
out_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
in_partitions: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
out_partitions: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext
if spmd_lowering:
in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
@ -2330,7 +2329,7 @@ class MeshComputation(stages.XlaLowering):
return executable
return self._executable
def cost_analysis(self) -> Dict[str, float]:
def cost_analysis(self) -> dict[str, float]:
backend = self.compile_args["backend"]
if xb.using_pjrt_c_api(backend):
raise NotImplementedError(
@ -2352,7 +2351,7 @@ def _get_input_indices(
avals: Sequence[ShapedArray],
shardings: Sequence[sharding_impls.XLACompatibleSharding],
da_object: Union[_DeviceAssignment, Sequence[xc.Device]],
) -> Sequence[Tuple[Optional[Index], ...]]:
) -> Sequence[tuple[Optional[Index], ...]]:
input_indices = []
if isinstance(da_object, _DeviceAssignment):
@ -2378,7 +2377,7 @@ def _get_input_indices(
def get_gspmd_shardings_from_executable(
xla_executable, device_assignment: Sequence[xc.Device],
num_in_avals: int, num_out_avals: int
) -> Tuple[Sequence[sharding_impls.XLACompatibleSharding],
) -> tuple[Sequence[sharding_impls.XLACompatibleSharding],
Sequence[sharding_impls.XLACompatibleSharding]]:
from jax.experimental import pjit
@ -2411,7 +2410,7 @@ def get_gspmd_shardings_from_executable(
# without mesh.
def _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh: Mesh
) -> Tuple[Sequence[sharding_impls.NamedSharding],
) -> tuple[Sequence[sharding_impls.NamedSharding],
Sequence[sharding_impls.NamedSharding]]:
from jax.experimental import pjit
@ -2421,7 +2420,7 @@ def _get_mesh_pspec_shardings_from_executable(
SubClassT = TypeVar("SubClassT", bound=sharding_impls.XLACompatibleSharding)
OrigHandlerType = Dict[Type[SubClassT],
OrigHandlerType = dict[type[SubClassT],
Callable[[xc.OpSharding, SubClassT], SubClassT]]
orig_out_sharding_handlers: OrigHandlerType = {}
@ -2569,11 +2568,11 @@ class UnloadedMeshExecutable:
committed: bool
are_out_shardings_from_xla: Sequence[bool]
name: str
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
unordered_effects: list[core.Effect]
ordered_effects: list[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
kept_var_idx: Set[int]
kept_var_idx: set[int]
auto_spmd_lowering: bool
jaxpr_debug_info: Optional[core.JaxprDebugInfo]
@ -2611,11 +2610,11 @@ class UnloadedMeshExecutable:
spmd_lowering: bool,
tuple_args: bool,
auto_spmd_lowering: bool,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect],
host_callbacks: list[Any],
keepalive: Any,
kept_var_idx: Set[int],
kept_var_idx: set[int],
backend: xb.XlaBackend,
device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]],
committed: bool,
@ -2875,7 +2874,7 @@ def _out_shardings_for_trivial(
jaxpr: core.Jaxpr, consts: Sequence[Any],
in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
device_assignment: Sequence[xc.Device],
) -> List[sharding_impls.XLACompatibleSharding]:
) -> list[sharding_impls.XLACompatibleSharding]:
# For each jaxpr output, compute a Sharding by:
# * if the output is a forwarded input, get the corresponding in_sharding;
# * if the output is a constant Array, get its .sharding attribute;
@ -2893,7 +2892,7 @@ def _out_shardings_for_trivial(
rep = sharding_impls.SingleDeviceSharding(dev)
in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings)
shardings: Dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
shardings: dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.ArrayImpl):
shardings[constvar] = constval.sharding
@ -2903,7 +2902,7 @@ def _out_shardings_for_trivial(
def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
env: Dict[core.Var, Any] = {}
env: dict[core.Var, Any] = {}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
@ -3037,7 +3036,7 @@ def _sanitize_mesh_jaxpr(jaxpr):
core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)
custom_resource_typing_rules: Dict[core.Primitive, Callable] = {}
custom_resource_typing_rules: dict[core.Primitive, Callable] = {}
def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
if isinstance(jaxpr, core.ClosedJaxpr):
@ -3113,7 +3112,7 @@ def maybe_extend_axis_env(*args, **kwargs):
def device_put(x, devices: Sequence[xc.ArrayImpl],
replicate: bool=False) -> List[xc.ArrayImpl]:
replicate: bool=False) -> list[xc.ArrayImpl]:
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
if replicate:
return [jax.device_put(x, device) for device in devices]
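Among the pxla.py hunks above, the Index alias and the as_slice_indices signature show how nested generics read after the change. Here is a hedged, standalone sketch of what that alias expresses; the normalization below (step-1 slices, no negative indices) is my own simplification, not the real pxla logic.

```python
from typing import Union

Index = Union[int, slice, tuple[Union[int, slice], ...]]

def slice_bounds(arr_shape: tuple[int, ...], idx: Index
                 ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    if not isinstance(idx, tuple):
        idx = (idx,)
    starts = [0] * len(arr_shape)
    limits = list(arr_shape)
    removed = []
    for dim, sub in enumerate(idx):
        if isinstance(sub, int):
            starts[dim], limits[dim] = sub, sub + 1
            removed.append(dim)                 # integer indexing drops the axis
        else:                                   # assume step-1 slices
            starts[dim] = 0 if sub.start is None else sub.start
            limits[dim] = arr_shape[dim] if sub.stop is None else sub.stop
    return tuple(starts), tuple(limits), tuple(removed)

print(slice_bounds((4, 8), (2, slice(0, 4))))  # ((2, 0), (3, 4), (0,))
```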

View File

@ -22,8 +22,7 @@ import itertools as it
import math
import operator
import re
from typing import (Any, Callable, Dict, Optional, Protocol,
Sequence, Set, Type, Tuple, Union)
from typing import Any, Callable, Optional, Protocol, Sequence, Union
import numpy as np
@ -91,7 +90,7 @@ def parameter(builder, num, shape, name=None, replicated=None):
# nesting in this type definition.
SpatialSharding = Union[Shape,
None,
Tuple[Optional[Shape], ...]]
tuple[Optional[Shape], ...]]
def sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.
@ -145,7 +144,7 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
except KeyError as err:
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
xla_shape_handlers: Dict[Type[core.AbstractValue],
xla_shape_handlers: dict[type[core.AbstractValue],
Callable[[Any], Sequence[xc.Shape]]] = {
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
@ -179,7 +178,7 @@ def _canonicalize_python_scalar_dtype(typ, x):
return np.asarray(
x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x)))
canonicalize_dtype_handlers: Dict[Any, Callable] = {}
canonicalize_dtype_handlers: dict[Any, Callable] = {}
canonicalize_dtype_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types)
canonicalize_dtype_handlers[np.ndarray] = _canonicalize_ndarray_dtype
@ -217,7 +216,7 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
@ -291,7 +290,7 @@ def axis_read(axis_env, axis_name):
except ValueError:
raise NameError(f"unbound axis name: {axis_name}") from None
def axis_groups(axis_env: AxisEnv, name) -> Tuple[Tuple[int, ...]]:
def axis_groups(axis_env: AxisEnv, name) -> tuple[tuple[int, ...]]:
if not isinstance(name, (list, tuple)):
name = (name,)
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
@ -374,12 +373,12 @@ if not MYPY:
else:
TranslationRule = Any
_translations: Dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: Dict[str, Dict[core.Primitive, TranslationRule]]
_translations: dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: dict[str, dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)
_collective_primitives: Set[core.Primitive] = set()
initial_style_primitives: Set[core.Primitive] = set()
_collective_primitives: set[core.Primitive] = set()
initial_style_primitives: set[core.Primitive] = set()
def register_initial_style_primitive(prim: core.Primitive):
initial_style_primitives.add(prim)
@ -441,5 +440,5 @@ class _BackendSpecificTranslationsAdapter(defaultdict):
translation_tables, _wrap_old_translation)
return ret
backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations: dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
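The xla.py hunks replace Dict[Type[...], ...] with dict[type[...], ...] in handler registries such as xla_shape_handlers and pytype_aval_mappings. A small illustrative registry keyed by type objects, with invented names, showing that the lowercase form works the same way:

```python
from typing import Any, Callable

handlers: dict[type, Callable[[Any], str]] = {}

def register(cls: type):
    def deco(fn: Callable[[Any], str]) -> Callable[[Any], str]:
        handlers[cls] = fn
        return fn
    return deco

@register(int)
def describe_int(x: Any) -> str:
    return f"int of value {x}"

def describe(x: Any) -> str:
    try:
        return handlers[type(x)](x)
    except KeyError as err:
        raise TypeError(f"no handler for type: {type(x)}") from err

print(describe(3))  # int of value 3
```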

View File

@ -24,7 +24,7 @@ from __future__ import annotations
import os
import platform
from typing import Any, List, Sequence, Optional
from typing import Any, Sequence, Optional
import iree.compiler
import iree.runtime
@ -73,7 +73,7 @@ class IreeDevice:
def transfer_from_outfeed(self, shape: xla_client.Shape):
raise NotImplementedError("transfer_to_outfeed")
def live_buffers(self) -> List[IreeBuffer]:
def live_buffers(self) -> list[IreeBuffer]:
raise NotImplementedError("live_buffers")
@ -123,10 +123,10 @@ class IreeExecutable:
self.module_object = module_object
self.function_name = function_name
def local_devices(self) -> List[IreeDevice]:
def local_devices(self) -> list[IreeDevice]:
return self._devices
def execute(self, arguments: Sequence[IreeBuffer]) -> List[IreeBuffer]:
def execute(self, arguments: Sequence[IreeBuffer]) -> list[IreeBuffer]:
inputs = [arg.to_iree() for arg in arguments]
outputs = self.module_object[self.function_name](*inputs)
# TODO(phawkins): Have a way to just have it always return the list,
@ -159,10 +159,10 @@ class IreeClient:
def device_count(self) -> int:
return len(self._devices)
def devices(self) -> List[IreeDevice]:
def devices(self) -> list[IreeDevice]:
return self._devices
def local_devices(self) -> List[IreeDevice]:
def local_devices(self) -> list[IreeDevice]:
return self._devices
def local_device_count(self) -> int:
@ -170,7 +170,7 @@ class IreeClient:
def get_default_device_assignment(
self,
num_replicas: int) -> List[IreeDevice]:
num_replicas: int) -> list[IreeDevice]:
if num_replicas != 1:
raise NotImplementedError("Only single-device computations implemented")
return [self._devices[0]]
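A note on runtime behavior that the iree.py hunk context hints at: this module starts with `from __future__ import annotations`, so annotations like `-> list[IreeDevice]` are stored as strings and never evaluated, which keeps them harmless even on interpreters where builtins are not yet subscriptable. A tiny demonstration, unrelated to JAX itself:

```python
from __future__ import annotations

def first_two(xs: list[int]) -> tuple[int, int]:  # annotation is never evaluated
    return xs[0], xs[1]

print(first_two([1, 2, 3]))        # (1, 2)
print(first_two.__annotations__)   # {'xs': 'list[int]', 'return': 'tuple[int, int]'}
```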

View File

@ -70,7 +70,7 @@ Todos::
"""
from functools import partial
from typing import (Any, Tuple)
from typing import Any
import numpy as np
@ -98,7 +98,7 @@ def approx_max_k(operand: Array,
reduction_dimension: int = -1,
recall_target: float = 0.95,
reduction_input_size_override: int = -1,
aggregate_to_topk: bool = True) -> Tuple[Array, Array]:
aggregate_to_topk: bool = True) -> tuple[Array, Array]:
"""Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.
See https://arxiv.org/abs/2206.14286 for the algorithm details.
@ -157,7 +157,7 @@ def approx_min_k(operand: Array,
reduction_dimension: int = -1,
recall_target: float = 0.95,
reduction_input_size_override: int = -1,
aggregate_to_topk: bool = True) -> Tuple[Array, Array]:
aggregate_to_topk: bool = True) -> tuple[Array, Array]:
"""Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.
See https://arxiv.org/abs/2206.14286 for the algorithm details.
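The two signatures above now return tuple[Array, Array]. For orientation, a short usage sketch of the public jax.lax.approx_max_k entry point, which yields a (values, indices) pair:

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.1, 0.9, 0.3, 0.7])
values, indices = jax.lax.approx_max_k(x, k=2)
print(values, indices)  # roughly the two largest entries (0.9, 0.7) and their positions
```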

View File

@ -15,7 +15,7 @@
import os
from functools import partial
from typing import Any, Callable, List, Optional, Sequence
from typing import Any, Callable, Optional, Sequence
from jax._src import core
from jax._src import linear_util as lu
@ -134,7 +134,7 @@ def _initial_style_jaxprs_with_common_consts(
# b[] <- 2.0
# in () }
canonical_ref_indices = []
canonical_refs: List[Any] = []
canonical_refs: list[Any] = []
tracer_id_to_canonical_id = {}
all_nonref_consts = []
canonical_ref_avals = []
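Not the real helper, but a generic sketch of the canonicalization idea that the names above (canonical_refs, tracer_id_to_canonical_id, canonical_ref_indices) suggest: give each distinct object one slot, keyed by id(), so duplicates share an index. The function below is illustrative only.

```python
from typing import Any

def canonicalize(consts: list[Any]) -> tuple[list[Any], list[int]]:
    canonical: list[Any] = []
    id_to_idx: dict[int, int] = {}
    indices: list[int] = []
    for c in consts:
        idx = id_to_idx.setdefault(id(c), len(canonical))
        if idx == len(canonical):          # first time we see this object
            canonical.append(c)
        indices.append(idx)
    return canonical, indices

shared = object()
print(canonicalize([shared, 1, shared])[1])  # [0, 1, 0]
```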

View File

@ -19,7 +19,7 @@ import inspect
import itertools
import operator
from typing import Callable, Sequence, List, Tuple
from typing import Callable, Sequence
from jax import config
from jax.tree_util import tree_flatten, tree_unflatten
@ -511,7 +511,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
# First, compute output unknowns (unks_out), where an output of the cond is
# unknown if it would be unknown on any of the branches.
unks_out: List[bool] = [False] * len(eqn.outvars)
unks_out: list[bool] = [False] * len(eqn.outvars)
for jaxpr in branches:
_, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
@ -520,9 +520,9 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
# Next, use the computed output unknowns to build a known jaxpr and a staged
# jaxpr for each branch.
branches_known_ : List[core.ClosedJaxpr] = []
branches_staged_: List[core.ClosedJaxpr] = []
branch_res_avals: List[core.AbstractValue] = []
branches_known_ : list[core.ClosedJaxpr] = []
branches_staged_: list[core.ClosedJaxpr] = []
branch_res_avals: list[core.AbstractValue] = []
for jaxpr in branches:
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
@ -653,13 +653,13 @@ def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())
def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
) -> Tuple[List[bool], core.JaxprEqn]:
def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
) -> tuple[list[bool], core.JaxprEqn]:
closed_branches = eqn.params['branches']
branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]
# First, compute which inputs are used in any branch (not including `pred`).
used_inputs: List[bool] = [False] * (len(eqn.invars) - 1) # -1 for pred
used_inputs: list[bool] = [False] * (len(eqn.invars) - 1) # -1 for pred
for jaxpr in branches:
_, used_inputs_ = pe.dce_jaxpr(jaxpr, used_outputs, instantiate=False)
used_inputs = map(operator.or_, used_inputs, used_inputs_)
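The comment in the _cond_partial_eval_custom hunk above explains that a cond output is unknown if it is unknown on any branch, so the per-branch masks are OR-ed together, much as the DCE rule just above ORs used_inputs. A tiny sketch with invented branch data:

```python
import operator

branch_unks_out = [
    [False, True, False],   # branch 0: which outputs are unknown
    [False, False, True],   # branch 1
]
unks_out = [False] * 3
for unks in branch_unks_out:
    unks_out = list(map(operator.or_, unks_out, unks))
print(unks_out)  # [False, True, True]
```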

View File

@ -15,7 +15,7 @@
import functools
import operator
from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union
import jax.numpy as jnp
from jax import lax
@ -92,7 +92,7 @@ def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
f, out_tree_thunk = flatten_fun_nokwargs(
lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
@ -172,12 +172,12 @@ Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
length: Optional[int] = None,
reverse: bool = False,
unroll: int = 1) -> Tuple[Carry, Y]:
unroll: int = 1) -> tuple[Carry, Y]:
if not callable(f):
raise TypeError("scan: f argument should be a callable.")
if unroll < 1:
@ -253,7 +253,7 @@ def _for_abstract_eval(*avals, jaxpr, **__):
def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
reverse: bool, which_linear: Sequence[bool],
nsteps: int, unroll: int
) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
) -> tuple[Sequence[Optional[Any]], Sequence[Any]]:
out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse,
which_linear=which_linear, nsteps=nsteps,
unroll=unroll)
@ -371,7 +371,7 @@ def _partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy):
_save_everything = lambda *_, **__: True
def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
def _is_read_only(ref_effects: set[StateEffect]) -> bool:
assert len(ref_effects) > 0
if len(ref_effects) > 1:
# Means we must have a write or accum effect so not read-only
@ -379,7 +379,7 @@ def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
eff, = ref_effects
return isinstance(eff, ReadEffect)
def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> list[bool]:
# Get effects for each of the jaxpr inputs and remove the loop index.
ref_effects = state_types.get_ref_state_effects(
[v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
@ -406,8 +406,8 @@ def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
jaxpr: core.Jaxpr, nsteps: int, reverse: bool,
which_linear: Tuple[bool, ...],
unroll: int) -> List[pe.JaxprTracer]:
which_linear: tuple[bool, ...],
unroll: int) -> list[pe.JaxprTracer]:
num_inputs = len(tracers)
assert num_inputs == len(jaxpr.invars) - 1
in_unknowns = [not t.pval.is_known() for t in tracers]
@ -636,7 +636,7 @@ pe.partial_eval_jaxpr_custom_rules[for_p] = _for_partial_eval_custom
def _convert_outputs_to_writes(
nsteps: int, jaxpr: core.Jaxpr, loop_invar_res: Sequence[bool]
) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
) -> tuple[core.Jaxpr, list[core.ShapedArray]]:
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
in_avals = [v.aval for v in jaxpr.invars] # [i, *orig_ref_avals]
@ -654,7 +654,7 @@ def _convert_outputs_to_writes(
res_ref[i] = res_val
return []
# TODO(mattjj, sharadmv): better handling of tokens, which don't have shape/dtype
res_ref_avals: List[core.AbstractValue] = [
res_ref_avals: list[core.AbstractValue] = [
AbstractRef(v.aval) if loop_invar else # pytype: disable=attribute-error
AbstractRef(core.ShapedArray((nsteps, *v.aval.shape), # pytype: disable=attribute-error
v.aval.dtype)) # pytype: disable=attribute-error
@ -679,7 +679,7 @@ def _convert_inputs_to_reads(
res_val_avals, (i_aval,), orig_ref_avals = \
split_list([v.aval for v in jaxpr.invars], [num_res, 1])
res_ref_avals: List[core.AbstractValue] = [
res_ref_avals: list[core.AbstractValue] = [
AbstractRef(aval) if loop_invar else # pytype: disable=attribute-error
AbstractRef(core.ShapedArray((nsteps, *aval.shape), # pytype: disable=attribute-error
aval.dtype)) # pytype: disable=attribute-error
@ -689,7 +689,7 @@ def _convert_inputs_to_reads(
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
return jaxpr
def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: List[bool]) -> core.Jaxpr:
def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
def trans(i, *args):
# First we want to run the computation to read all the residual refs. We can
# do that by using partial evaluation with all linear inputs unknown.
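The for_loop.py scan signature above uses Callable[[Carry, X], tuple[Carry, Y]]. Purely to show how those generics read in isolation, here is a plain-Python scan with the same typing shape; it is not JAX's implementation and skips the structure and length checks the real one performs.

```python
from typing import Callable, TypeVar

Carry = TypeVar("Carry")
X = TypeVar("X")
Y = TypeVar("Y")

def py_scan(f: Callable[[Carry, X], tuple[Carry, Y]],
            init: Carry, xs: list[X]) -> tuple[Carry, list[Y]]:
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

total, partials = py_scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3])
print(total, partials)  # 6 [1, 3, 6]
```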

View File

@ -16,7 +16,7 @@ from functools import partial
import inspect
import itertools
import operator
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Callable, Optional, Sequence, TypeVar
import jax
import weakref
@ -96,12 +96,12 @@ X = TypeVar('X')
Y = TypeVar('Y')
@api_boundary
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
length: Optional[int] = None,
reverse: bool = False,
unroll: int = 1) -> Tuple[Carry, Y]:
unroll: int = 1) -> tuple[Carry, Y]:
"""Scan a function over leading array axes while carrying along state.
The `Haskell-like type signature`_ in brief is
@ -820,8 +820,8 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)
def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], core.JaxprEqn]:
def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn]:
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_xs = len(jaxpr.in_avals) - num_consts - num_carry
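The loops.py hunk above touches the public jax.lax.scan signature. A brief usage example of that API, with the carry/output pair that the Callable[[Carry, X], tuple[Carry, Y]] annotation describes:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry          # (new carry, per-step output)

total, partials = jax.lax.scan(step, 0.0, jnp.arange(1.0, 4.0))
print(total)     # 6.0
print(partials)  # [1. 3. 6.]
```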

View File

@ -15,7 +15,7 @@
import builtins
from functools import partial
import operator
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Any, NamedTuple, Optional, Sequence, Union
import numpy as np
@ -51,11 +51,11 @@ class ConvDimensionNumbers(NamedTuple):
out_spec: Sequence[int]
ConvGeneralDilatedDimensionNumbers = Union[
None, ConvDimensionNumbers, Tuple[str, str, str]]
None, ConvDimensionNumbers, tuple[str, str, str]]
def conv_general_dilated(
lhs: Array, rhs: Array, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
lhs_dilation: Optional[Sequence[int]] = None,
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
@ -199,7 +199,7 @@ def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
def conv_with_general_padding(lhs: Array, rhs: Array,
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
lhs_dilation: Optional[Sequence[int]],
rhs_dilation: Optional[Sequence[int]],
precision: lax.PrecisionLike = None,
@ -271,7 +271,7 @@ def _flip_axes(x, axes):
def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
transpose_kernel: bool = False,
@ -330,7 +330,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
k_shape = np.take(rhs.shape, dn.rhs_spec)
k_sdims = k_shape[2:] # type: ignore[index]
# Calculate correct output shape given padding and strides.
pads: Union[str, Sequence[Tuple[int, int]]]
pads: Union[str, Sequence[tuple[int, int]]]
if isinstance(padding, str) and padding in {'SAME', 'VALID'}:
if rhs_dilation is None:
rhs_dilation = (1,) * (rhs.ndim - 2)
@ -351,7 +351,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
def _conv_general_dilated_shape_rule(
lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
batch_group_count, **unused_kwargs) -> tuple[int, ...]:
assert type(dimension_numbers) is ConvDimensionNumbers
if len(lhs.shape) != len(rhs.shape):
msg = ("conv_general_dilated lhs and rhs must have the same number of "
@ -730,7 +730,7 @@ def _conv_general_dilated_lower(
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
# spatial dimension.
int2d = mlir.aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
def prep_one_pad(pad_lo_hi: Tuple[core.DimSize, core.DimSize]):
def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
pad1 = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, pad_lo_hi)) # i32[2]
return hlo.ReshapeOp(int2d, pad1)
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
@ -910,7 +910,7 @@ def conv_general_permutations(dimension_numbers):
def _conv_general_vjp_lhs_padding(
in_shape, window_dimensions, window_strides, out_shape, padding,
lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
lhs_dilation, rhs_dilation) -> list[tuple[int, int]]:
lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
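Several convolution signatures above take padding as Union[str, Sequence[tuple[int, int]]], i.e. either 'SAME'/'VALID' or one (low, high) pair per spatial dimension. A small usage sketch of the public jax.lax.conv_general_dilated with the explicit form; the layouts here rely on the default dimension numbers.

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 1, 8))   # batch, features, spatial (default layout)
rhs = jnp.ones((1, 1, 3))   # out features, in features, spatial
out = lax.conv_general_dilated(lhs, rhs,
                               window_strides=(1,),
                               padding=[(1, 1)])   # one (low, high) pair per spatial dim
print(out.shape)  # (1, 1, 8)
```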

View File

@ -19,8 +19,8 @@ from functools import partial
import itertools
import math
import operator
from typing import (Any, Callable, Optional, Sequence, Tuple, List,
TypeVar, Union, cast as type_cast, overload)
from typing import (Any, Callable, Optional, Sequence, TypeVar, Union,
cast as type_cast, overload)
import warnings
import numpy as np
@ -102,7 +102,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
map(_check_static_shape, shapes)
def _try_broadcast_shapes(
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
shapes: Sequence[tuple[int, ...]]) -> Optional[tuple[int, ...]]:
if len(shapes) == 1: return shapes[0]
rank, *others = {len(shape) for shape in shapes}
if others: return None # must have consistent rank
@ -134,11 +134,11 @@ def asarray(x: ArrayLike) -> Array:
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
@overload
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...
@overload
def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...]
) -> Tuple[Union[int, core.Tracer], ...]: ...
def broadcast_shapes(*shapes: tuple[Union[int, core.Tracer], ...]
) -> tuple[Union[int, core.Tracer], ...]: ...
def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
@ -149,7 +149,7 @@ def broadcast_shapes(*shapes):
return _broadcast_shapes_uncached(*shapes)
@cache()
def _broadcast_shapes_cached(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
return _broadcast_shapes_uncached(*shapes)
def _broadcast_shapes_uncached(*shapes):
@ -181,7 +181,7 @@ def _identity(x): return x
def _extract_tracers_dyn_shape(
shape: Sequence[Union[int, core.Tracer]]
) -> Tuple[List[core.Tracer], List[Optional[int]]]:
) -> tuple[list[core.Tracer], list[Optional[int]]]:
# Given a sequence representing a shape, pull out Tracers, replacing with None
if config.jax_dynamic_shapes:
# We must gate this behavior under a flag because otherwise the errors
@ -195,7 +195,7 @@ def _extract_tracers_dyn_shape(
def _merge_dyn_shape(
static_shape: Sequence[Optional[int]],
dyn_shape: Sequence[Any],
) -> Tuple[Union[int, mlir.Value, core.Tracer], ...]:
) -> tuple[Union[int, mlir.Value, core.Tracer], ...]:
# Replace Nones in static_shape with elements of dyn_shape, in order
dyn_shape_it = iter(dyn_shape)
shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
@ -663,8 +663,8 @@ class Precision(xla_client.PrecisionConfig.Precision): # type: ignore
PrecisionType = Precision
PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
Tuple[PrecisionType, PrecisionType]]
PrecisionLike = Union[None, str, PrecisionType, tuple[str, str],
tuple[PrecisionType, PrecisionType]]
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None) -> Array:
@ -699,8 +699,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
lhs.shape, rhs.shape))
DotDimensionNumbers = Tuple[Tuple[Sequence[int], Sequence[int]],
Tuple[Sequence[int], Sequence[int]]]
DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
tuple[Sequence[int], Sequence[int]]]
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
@ -860,7 +860,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
dimensions=None if dims is None or same_dims else dims)
def pad(operand: ArrayLike, padding_value: ArrayLike,
padding_config: Sequence[Tuple[int, int, int]]) -> Array:
padding_config: Sequence[tuple[int, int, int]]) -> Array:
"""Applies low, high, and/or interior padding to an array.
Wraps XLA's `Pad
@ -1111,14 +1111,14 @@ def _reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:
@overload
def sort(operand: Sequence[Array], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Tuple[Array, ...]: ...
is_stable: bool = True, num_keys: int = 1) -> tuple[Array, ...]: ...
@overload
def sort(operand: Array, dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Array: ...
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
is_stable: bool = True, num_keys: int = 1) -> Union[Array, tuple[Array, ...]]:
"""Wraps XLA's `Sort
<https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.
@ -1154,13 +1154,13 @@ def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
return sort_p.bind(operand, dimension=dimension, is_stable=is_stable, num_keys=1)[0]
def sort_key_val(keys: Array, values: ArrayLike, dimension: int = -1,
is_stable: bool = True) -> Tuple[Array, Array]:
is_stable: bool = True) -> tuple[Array, Array]:
"""Sorts ``keys`` along ``dimension`` and applies the same permutation to ``values``."""
dimension = canonicalize_axis(dimension, len(keys.shape))
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, num_keys=1)
return k, v
def top_k(operand: ArrayLike, k: int) -> Tuple[Array, Array]:
def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
"""Returns top ``k`` values and their indices along the last axis of ``operand``.
Args:
@ -4833,7 +4833,7 @@ def remaining(original, *removed_lists):
return [i for i in original if i not in removed]
def canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[PrecisionType, PrecisionType]]:
def canonicalize_precision(precision: PrecisionLike) -> Optional[tuple[PrecisionType, PrecisionType]]:
"""Turns an API precision specification, into a pair of enumeration values.
The API can take the precision as a string, or int, and either as a single
@ -4844,7 +4844,7 @@ def canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[Precision
return None
try:
return type_cast(
Tuple[PrecisionType, PrecisionType],
tuple[PrecisionType, PrecisionType],
(Precision(config.jax_default_matmul_precision),
Precision(config.jax_default_matmul_precision)))
except TypeError:
@ -4853,18 +4853,18 @@ def canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[Precision
f"{list(Precision._strings)}, but got {config.jax_default_matmul_precision}"
) from None
elif isinstance(precision, str) and precision in Precision._strings:
return type_cast(Tuple[PrecisionType, PrecisionType],
return type_cast(tuple[PrecisionType, PrecisionType],
(Precision(precision), Precision(precision)))
elif isinstance(precision, xla_client.PrecisionConfig.Precision):
return type_cast(Tuple[PrecisionType, PrecisionType], (precision, precision))
return type_cast(tuple[PrecisionType, PrecisionType], (precision, precision))
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(p, xla_client.PrecisionConfig.Precision) for p in precision)):
return type_cast(Tuple[PrecisionType, PrecisionType], precision)
return type_cast(tuple[PrecisionType, PrecisionType], precision)
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(s, str) for s in precision)):
s1, s2 = precision
p1 = type_cast(Tuple[PrecisionType, PrecisionType], canonicalize_precision(s1))[0]
p2 = type_cast(Tuple[PrecisionType, PrecisionType], canonicalize_precision(s2))[0]
p1 = type_cast(tuple[PrecisionType, PrecisionType], canonicalize_precision(s1))[0]
p2 = type_cast(tuple[PrecisionType, PrecisionType], canonicalize_precision(s2))[0]
return (p1, p2)
else:
raise ValueError(
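The hunks above only retype existing signatures, but they read more easily with concrete calls in mind. A minimal usage sketch of `pad`, `top_k`, and a `dot_general` precision pair (array values and shapes are illustrative, not taken from this diff):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)

# padding_config is one (low, high, interior) triple per dimension.
padded = lax.pad(x, 0.0, [(1, 1, 0), (0, 0, 1)])  # shape (4, 5)

# top_k returns a (values, indices) pair along the last axis.
values, indices = lax.top_k(x, k=2)

# dot_general takes ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch));
# precision may be given as a per-operand pair of strings.
y = lax.dot_general(x, x.T, dimension_numbers=(((1,), (0,)), ((), ())),
                    precision=("highest", "highest"))
```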


@ -16,7 +16,7 @@ import inspect
import functools
from functools import partial
import math
from typing import cast, Any, Callable, List, Literal, Optional, Tuple, TypeVar, Union, overload
from typing import cast, Any, Callable, Literal, Optional, TypeVar, Union, overload
import warnings
import numpy as np
@ -137,7 +137,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
@_warn_on_positional_kwargs
def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True) -> List[Array]:
compute_right_eigenvectors: bool = True) -> list[Array]:
"""Eigendecomposition of a general matrix.
Nonsymmetric eigendecomposition is at present only implemented on CPU.
@ -147,7 +147,7 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
@_warn_on_positional_kwargs
def eigh(x: Array, *, lower: bool = True, symmetrize_input: bool = True,
sort_eigenvalues: bool = True) -> Tuple[Array, Array]:
sort_eigenvalues: bool = True) -> tuple[Array, Array]:
r"""Eigendecomposition of a Hermitian matrix.
Computes the eigenvectors and eigenvalues of a complex Hermitian or real
@ -200,7 +200,7 @@ def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
return permutation
def lu(x: ArrayLike) -> Tuple[Array, Array, Array]:
def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
"""LU decomposition with partial pivoting.
Computes the matrix decomposition:
@ -234,7 +234,7 @@ def lu(x: ArrayLike) -> Tuple[Array, Array, Array]:
return lu, pivots, permutation
@_warn_on_positional_kwargs
def qr(x: ArrayLike, *, full_matrices: bool = True) -> Tuple[Array, Array]:
def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
"""QR decomposition.
Computes the QR decomposition
@ -265,17 +265,17 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> Tuple[Array, Array]:
return q, r
@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> Tuple[Array, Array, Array]: ...
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...
@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ...
@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, Tuple[Array, Array, Array]]: ...
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, tuple[Array, Array, Array]]: ...
# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
@_warn_on_positional_kwargs
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, Tuple[Array, Array, Array]]:
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, tuple[Array, Array, Array]]:
"""Singular value decomposition.
Returns the singular values if compute_uv is False, otherwise returns a triple
@ -586,7 +586,7 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule
def eigh_jacobi(x: ArrayLike, *, lower: bool = True,
sort_eigenvalues: bool = True) -> Tuple[Array, Array]:
sort_eigenvalues: bool = True) -> tuple[Array, Array]:
"""Helper Jacobi eigendecomposition implemented by XLA.
Used as a subroutine of QDWH-eig on TPU."""
@ -1389,7 +1389,7 @@ def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
# geqrf and orgqr. The names, while cryptic Fortran alphabet soup, are LAPACK's
# names for the primitives, and we stick with them for consistency.
def geqrf(a: ArrayLike) -> Tuple[Array, Array]:
def geqrf(a: ArrayLike) -> tuple[Array, Array]:
"""Computes the QR decomposition of a matrix.
Args:
@ -2008,7 +2008,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
def schur(x: ArrayLike, *,
compute_schur_vectors: bool = True,
sort_eig_vals: bool = False,
select_callable: Optional[Callable[..., Any]] = None) -> Tuple[Array, Array]:
select_callable: Optional[Callable[..., Any]] = None) -> tuple[Array, Array]:
return schur_p.bind(
x,
compute_schur_vectors=compute_schur_vectors,
@ -2115,7 +2115,7 @@ ad.primitive_jvps[schur_p] = _schur_jvp_rule
# hessenberg: Upper Hessenberg reduction
def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
def hessenberg(a: ArrayLike) -> tuple[Array, Array]:
"""Reduces a square matrix to upper Hessenberg form.
Currently implemented on CPU only.
@ -2187,7 +2187,7 @@ mlir.register_lowering(hessenberg_p, _hessenberg_cpu_hlo, platform='cpu')
# tridiagonal: Upper Hessenberg reduction
def tridiagonal(a: ArrayLike, *, lower=True
) -> Tuple[Array, Array, Array, Array]:
) -> tuple[Array, Array, Array, Array]:
"""Reduces a symmetric/Hermitian matrix to tridiagonal form.
Currently implemented on CPU and GPU only.
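For reference, a small sketch of how the tuple-returning factorizations annotated above are consumed, via the public `jax.lax.linalg` wrappers (the input matrix is arbitrary):

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[4.0, 3.0], [6.0, 3.0]])

# lu returns the packed LU factors, the pivot indices, and the permutation.
lu, pivots, permutation = lax_linalg.lu(a)

# qr returns a (q, r) pair; full_matrices controls the shape of q.
q, r = lax_linalg.qr(a, full_matrices=False)
```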


@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Any, Optional, Sequence, Tuple, Union, cast as type_cast
from typing import Any, Optional, Sequence, Union, cast as type_cast
import jax
from jax._src.numpy import lax_numpy as jnp
@ -26,7 +26,7 @@ def conv_general_dilated_patches(
lhs: jax.typing.ArrayLike,
filter_shape: Sequence[int],
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
lhs_dilation: Optional[Sequence[int]] = None,
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: Optional[convolution.ConvGeneralDilatedDimensionNumbers] = None,
@ -122,7 +122,7 @@ def conv_general_dilated_local(
lhs: jax.typing.ArrayLike,
rhs: jax.typing.ArrayLike,
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
filter_shape: Sequence[int],
lhs_dilation: Optional[Sequence[int]] = None,
rhs_dilation: Optional[Sequence[int]] = None,

View File

@ -25,7 +25,7 @@ https://epubs.siam.org/doi/abs/10.1137/090774999
"""
import functools
from typing import Optional, Tuple
from typing import Optional
import jax
import jax.numpy as jnp
@ -196,7 +196,7 @@ def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
# TODO: Add pivoting.
@functools.partial(jax.jit, static_argnames=('is_hermitian',))
def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None,
dynamic_shape: Optional[Tuple[int, int]] = None):
dynamic_shape: Optional[tuple[int, int]] = None):
"""QR-based dynamically weighted Halley iteration for polar decomposition.
Args:
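The QDWH iteration here backs the `qdwh` method of `jax.scipy.linalg.polar`; a hedged usage sketch (the matrix is arbitrary, and going through the public `polar` wrapper rather than the private helper is an illustrative choice, not part of this hunk):

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Polar decomposition a = u @ p, with u unitary and p positive semidefinite,
# computed with the QDWH iteration.
u, p = polar(a, method="qdwh")
```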


@ -15,7 +15,7 @@
import enum
from functools import partial
import math
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Callable, NamedTuple, Optional, Sequence, Union
import weakref
import numpy as np
@ -178,9 +178,9 @@ class GatherDimensionNumbers(NamedTuple):
implicit; there is always an index vector dimension and it must always be the
last dimension. To gather scalar indices, add a trailing dimension of size 1.
"""
offset_dims: Tuple[int, ...]
collapsed_slice_dims: Tuple[int, ...]
start_index_map: Tuple[int, ...]
offset_dims: tuple[int, ...]
collapsed_slice_dims: tuple[int, ...]
start_index_map: tuple[int, ...]
class GatherScatterMode(enum.Enum):
@ -692,7 +692,7 @@ def dynamic_slice_in_dim(operand: Union[Array, np.ndarray],
start_index: ArrayLike,
slice_size: int, axis: int = 0) -> Array:
"""Convenience wrapper around dynamic_slice applying to one dimension."""
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
slice_sizes = list(operand.shape)
axis = int(axis)
@ -719,7 +719,7 @@ def dynamic_update_slice_in_dim(operand: Union[Array, np.ndarray],
in a single ``axis``.
"""
axis = int(axis)
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
start_indices[axis] = start_index
return dynamic_update_slice(operand, update, start_indices)
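A minimal sketch of the two convenience wrappers above (array contents are illustrative):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

# Read a slice of length 3 starting at index 4 along axis 0.
window = lax.dynamic_slice_in_dim(x, start_index=4, slice_size=3, axis=0)

# Write an update of length 3 back, starting at the same index.
y = lax.dynamic_update_slice_in_dim(x, window + 1, start_index=4, axis=0)
```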
@ -2149,7 +2149,7 @@ mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")
def _dynamic_slice_indices(
operand: Union[Array, np.ndarray],
start_indices: Union[Union[Array, np.ndarray], Sequence[ArrayLike]]
) -> List[ArrayLike]:
) -> list[ArrayLike]:
# Normalize the start_indices w.r.t. operand.shape
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
@ -2160,7 +2160,7 @@ def _dynamic_slice_indices(
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape)) # type: ignore[union-attr]
start_indices = list(start_indices)
result: List[ArrayLike] = []
result: list[ArrayLike] = []
for i, d in zip(start_indices, operand.shape):
# We test whether i and d are static to avoid unnecessary staging.
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):


@ -20,7 +20,7 @@ Eigendecomposition on TPU.
from __future__ import annotations
from typing import Any, Tuple
from typing import Any
import jax
from jax import lax
@ -60,7 +60,7 @@ class Stack:
lambda x, y: lax.dynamic_update_index_in_dim(x, y, self._size, 0),
self._data, elem))
def pop(self) -> Tuple[Any, Stack]:
def pop(self) -> tuple[Any, Stack]:
"""Pops from the stack, returning an (elem, updated stack) pair."""
elem = jax.tree_util.tree_map(
lambda x: lax.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False),


@ -13,7 +13,7 @@
# limitations under the License.
from functools import partial
from typing import (Any, Callable, Optional, Sequence, Union, Tuple)
from typing import Any, Callable, Optional, Sequence, Union
import warnings
import numpy as np
@ -44,7 +44,7 @@ Array = Any
def reduce_window(operand, init_value, computation: Callable,
window_dimensions: core.Shape, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
"""Wraps XLA's `ReduceWindowWithGeneralPadding
@ -112,7 +112,7 @@ def _get_monoid_window_reducer(monoid_op: Callable,
def _reduce_window_sum(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
@ -127,7 +127,7 @@ def _reduce_window_sum(operand: Array, window_dimensions: core.Shape,
def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
init_value = lax._const(operand, 1)
@ -146,7 +146,7 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
def _reduce_window_max(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
@ -161,7 +161,7 @@ def _reduce_window_max(operand: Array, window_dimensions: core.Shape,
def _reduce_window_min(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
@ -177,7 +177,7 @@ def _reduce_window_min(operand: Array, window_dimensions: core.Shape,
def _reduce_window_logaddexp(
operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
init_value = lax._const(operand, -np.inf)
@ -197,7 +197,7 @@ def _reduce_window_logaddexp(
def _select_and_scatter(operand: Array, select: Callable,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]], source: Array,
padding: Sequence[tuple[int, int]], source: Array,
init_value: Array, scatter: Callable) -> Array:
select_jaxpr, select_consts = lax._reduction_jaxpr(
select, lax._abstractify(init_value))
@ -213,7 +213,7 @@ def _select_and_scatter_add(source: Array, operand: Array,
select_prim: core.Primitive,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[tuple[int, int]]) -> Array:
return select_and_scatter_add_p.bind(
source, operand, select_prim=select_prim,
window_dimensions=tuple(window_dimensions),
@ -223,7 +223,7 @@ def _select_and_gather_add(tangents: Array, operand: Array,
select_prim: core.Primitive,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Sequence[int],
window_dilation: Sequence[int]) -> Array:
"""Extracts the tangent corresponding to the minimum or maximum element in


@ -15,13 +15,13 @@
"""A LazyLoader class."""
import importlib
from typing import Any, Callable, List, Sequence, Tuple
from typing import Any, Callable, Sequence
def attach(package_name: str, submodules: Sequence[str]) -> Tuple[
def attach(package_name: str, submodules: Sequence[str]) -> tuple[
Callable[[str], Any],
Callable[[], List[str]],
List[str],
Callable[[], list[str]],
list[str],
]:
"""Lazily loads submodules of a package.
@ -31,14 +31,14 @@ def attach(package_name: str, submodules: Sequence[str]) -> Tuple[
```
"""
__all__: List[str] = list(submodules)
__all__: list[str] = list(submodules)
def __getattr__(name: str) -> Any:
if name in submodules:
return importlib.import_module(f"{package_name}.{name}")
raise AttributeError(f"module '{package_name}' has no attribute '{name}")
def __dir__() -> List[str]:
def __dir__() -> list[str]:
return __all__
return __getattr__, __dir__, __all__
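A sketch of how `attach` is meant to be used at the top of a package's `__init__.py`; the import path and submodule names below are assumptions for illustration only:

```python
# Hypothetical package __init__.py using the lazy loader.
from jax._src.lazy_loader import attach  # assumed module path

# "linalg" and "fft" are imported only on first attribute access.
__getattr__, __dir__, __all__ = attach(__name__, ["linalg", "fft"])
del attach
```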


@ -18,7 +18,7 @@
import gc
import pathlib
import re
from typing import Optional, Tuple
from typing import Optional
try:
import jaxlib as jaxlib
@ -41,12 +41,12 @@ except Exception as err:
# Checks the jaxlib version before importing anything else from jaxlib.
# Returns the jaxlib version string.
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
minimum_jaxlib_version: str) -> Tuple[int, ...]:
minimum_jaxlib_version: str) -> tuple[int, ...]:
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
# PEP440 allows a number of non-numeric suffixes, which we allow also.
# We currently do not allow an epoch.
version_regex = re.compile(r"[0-9]+(?:\.[0-9]+)*")
def _parse_version(v: str) -> Tuple[int, ...]:
def _parse_version(v: str) -> tuple[int, ...]:
m = version_regex.match(v)
if m is None:
raise ValueError(f"Unable to parse jaxlib version '{v}'")


@ -65,7 +65,7 @@ from __future__ import annotations
from functools import partial
import operator
from typing import Any, Tuple, Callable, Optional, NamedTuple
from typing import Any, Callable, Optional, NamedTuple
import weakref
from jax._src.tree_util import tree_map
@ -242,7 +242,7 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
@curry
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args,
use_eq_store=False) -> Tuple[WrappedFun, Any]:
use_eq_store=False) -> tuple[WrappedFun, Any]:
"""Adds one more transformation with auxiliary output to a WrappedFun."""
out_store = Store() if not use_eq_store else EqualStore()
out_thunk = lambda: out_store.val
@ -303,8 +303,8 @@ class TracingDebugInfo(NamedTuple):
# TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: Tuple[str, ...] # e.g. ('args[0]', ... )
result_paths: Optional[Callable[[], Tuple[str, ...]]]
arg_names: tuple[str, ...] # e.g. ('args[0]', ... )
result_paths: Optional[Callable[[], tuple[str, ...]]]
def add_debug_info(f: WrappedFun, debug_info: Optional[TracingDebugInfo]
) -> WrappedFun:


@ -16,7 +16,7 @@ import contextlib
import numpy as np
import itertools as it
from collections import OrderedDict, abc
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
from typing import (Callable, Iterable, Optional, Any,
NamedTuple, Union, Sequence, Mapping)
from functools import wraps, partial, partialmethod, lru_cache
import math
@ -268,7 +268,7 @@ def _prepare_axes(axes, arg_name):
return tree_unflatten(treedef, entries), entries, treedef
Resource = Union[ResourceAxisName, SerialLoop]
ResourceSet = Union[Resource, Tuple[Resource, ...]]
ResourceSet = Union[Resource, tuple[Resource, ...]]
# TODO: Some syntactic sugar to make the API more usable in a single-axis case?
# TODO: Are the resource axes scoped lexically or dynamically? Dynamically for now!
@ -491,7 +491,7 @@ def xmap(fun: Callable,
f"in_axes or axis_sizes, but the following are missing: "
f"{out_axes_names - defined_names}")
normalized_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]] = {}
normalized_axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]] = {}
for axis in defined_names:
resources = axis_resources.get(axis, ())
if not isinstance(resources, tuple):
@ -708,10 +708,10 @@ def make_xmap_callable(fun: lu.WrappedFun,
class EvaluationPlan(NamedTuple):
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
resource_env: ResourceEnv
physical_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
loop_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_subst_dict: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_vmap_size: Dict[AxisName, Optional[int]]
physical_axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]]
loop_axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]]
axis_subst_dict: dict[AxisName, tuple[ResourceAxisName, ...]]
axis_vmap_size: dict[AxisName, Optional[int]]
@property
def axis_subst(self) -> core.AxisSubst:
@ -729,15 +729,15 @@ class EvaluationPlan(NamedTuple):
@classmethod
def from_axis_resources(cls,
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]],
resource_env: ResourceEnv,
global_axis_sizes: Dict[AxisName, int]):
global_axis_sizes: dict[AxisName, int]):
physical_axis_resources, loop_axis_resources = _unzip_axis_resources(
axis_resources, resource_env)
axis_resource_count = _get_axis_resource_count(
axis_resources, resource_env)
axis_subst_dict = dict(axis_resources)
axis_vmap_size: Dict[AxisName, Optional[int]] = {}
axis_vmap_size: dict[AxisName, Optional[int]] = {}
for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])):
num_resources = axis_resource_count[naxis]
assert global_axis_sizes[naxis] % num_resources.nglobal == 0
@ -1038,7 +1038,7 @@ def _xmap_partial_eval_custom_params_updater(
unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
) -> tuple[dict, dict]:
assert params_known['spmd_in_axes'] is None is params_known['spmd_out_axes']
assert params_staged['spmd_in_axes'] is None is params_staged['spmd_out_axes']
@ -1573,7 +1573,7 @@ class ResourceCount(NamedTuple):
def _get_axis_resource_count(
axis_resources, resource_env) -> Dict[ResourceAxisName, ResourceCount]:
axis_resources, resource_env) -> dict[ResourceAxisName, ResourceCount]:
global_res_shape = resource_env.shape
local_res_shape = None
@ -1593,8 +1593,8 @@ def _get_axis_resource_count(
def _get_axis_sizes(args_flat: Iterable[Any],
in_axes_flat: Iterable[AxisNamePos],
global_axis_sizes: Dict[AxisName, int],
axis_resource_count: Dict[AxisName, ResourceCount]):
global_axis_sizes: dict[AxisName, int],
axis_resource_count: dict[AxisName, ResourceCount]):
global_axis_sizes = dict(global_axis_sizes)
for arg, in_axes in zip(args_flat, in_axes_flat):
for name, dim in in_axes.items():
@ -1643,7 +1643,7 @@ def hide_mapped_axes(flat_in_axes, flat_out_axes, *flat_args):
yield map(_unsqueeze_mapped_axes, flat_outputs, flat_out_axes)
def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]:
def _jaxpr_resources(jaxpr, resource_env) -> set[ResourceAxisName]:
if isinstance(jaxpr, core.ClosedJaxpr):
jaxpr = jaxpr.jaxpr
assert isinstance(jaxpr, core.Jaxpr)
@ -1661,7 +1661,7 @@ def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]:
def _to_resource_axes(axes_specs: Sequence[AxisNamePos],
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]):
axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]]):
"""
Convert in/out_axes parameters ranging over logical dimensions to
ones that range over resource dimensions.
@ -1695,7 +1695,7 @@ def _slice_tile(x, dim: Optional[int], i, n: int):
return lax.dynamic_slice_in_dim(x, i * tile_size, slice_size=tile_size, axis=dim)
def _unzip_axis_resources(axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
def _unzip_axis_resources(axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]],
resource_env: ResourceEnv):
"""Splits axis_resources into separate dicts for physical and loop resources."""
physical_axis_resources = {}
@ -1718,7 +1718,7 @@ def _unzip_axis_resources(axis_resources: Dict[AxisName, Tuple[ResourceAxisName,
def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
out_axes: Sequence[AxisNamePos],
global_axis_sizes: Dict[AxisName, int]):
global_axis_sizes: dict[AxisName, int]):
defined_axes = set(global_axis_sizes)
for aval, axes in zip(out_avals, out_axes):
if not isinstance(aval, core.ShapedArray):


@ -20,7 +20,7 @@ import contextlib
import functools
import math
import threading
from typing import Any, Hashable, NamedTuple, Set, Sequence, Tuple, Union
from typing import Any, Hashable, NamedTuple, Sequence, Union
import numpy as np
@ -43,7 +43,7 @@ def show_axes(axes):
class ResourceEnv(NamedTuple):
physical_mesh: Mesh
loops: Tuple[Loop, ...]
loops: tuple[Loop, ...]
def with_mesh(self, mesh: Mesh):
overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
@ -60,15 +60,15 @@ class ResourceEnv(NamedTuple):
return self._replace(loops=self.loops + (loop,))
@property
def physical_resource_axes(self) -> Set[ResourceAxisName]:
def physical_resource_axes(self) -> set[ResourceAxisName]:
return set(self.physical_mesh.axis_names)
@property
def loop_resource_axes(self) -> Set[ResourceAxisName]:
def loop_resource_axes(self) -> set[ResourceAxisName]:
return {loop.name for loop in self.loops}
@property
def resource_axes(self) -> Set[ResourceAxisName]:
def resource_axes(self) -> set[ResourceAxisName]:
return self.physical_resource_axes | self.loop_resource_axes
@property
@ -138,7 +138,7 @@ class Mesh(contextlib.ContextDecorator):
"""
devices: np.ndarray
axis_names: Tuple[MeshAxisName, ...]
axis_names: tuple[MeshAxisName, ...]
def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]],
axis_names: Union[str, Sequence[MeshAxisName]]):
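The retyped `Mesh` fields correspond to the public constructor; a minimal sketch of building a one-axis mesh (the axis name is arbitrary and the device count depends on the host):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Arrange all local devices along a single named axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
print(mesh.shape)  # e.g. OrderedDict([('data', 1)]) on a single-device host
```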


@ -20,10 +20,10 @@ during program execution, the registered listeners will be invoked.
A typical listener callback is to send an event to a metrics collector for
aggregation/exporting.
"""
from typing import Callable, List
from typing import Callable
_event_listeners: List[Callable[[str], None]] = []
_event_duration_secs_listeners: List[Callable[[str, float], None]] = []
_event_listeners: list[Callable[[str], None]] = []
_event_duration_secs_listeners: list[Callable[[str, float], None]] = []
def record_event(event: str) -> None:
"""Record an event."""


@ -18,7 +18,7 @@ from functools import partial
import operator
import warnings
import numpy as np
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Union
import jax
import jax.numpy as jnp
@ -293,7 +293,7 @@ logsumexp = _logsumexp
@partial(jax.jit, static_argnames=("axis",))
def log_softmax(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
r"""Log-Softmax function.
@ -326,7 +326,7 @@ def log_softmax(x: Array,
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
r"""Softmax function.
@ -355,7 +355,7 @@ def softmax(x: Array,
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _softmax(
x,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
@ -382,7 +382,7 @@ def _softmax_deprecated(x, axis, where, initial):
@partial(jax.jit, static_argnames=("axis",))
def standardize(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
mean: Optional[Array] = None,
variance: Optional[Array] = None,
epsilon: Array = 1e-5,
@ -400,7 +400,7 @@ def standardize(x: Array,
return (x - mean) * lax.rsqrt(variance + epsilon)
def normalize(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
mean: Optional[Array] = None,
variance: Optional[Array] = None,
epsilon: Array = 1e-5,
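The `axis` parameter retyped throughout this file accepts either an int or a tuple of ints. A minimal usage sketch of the public wrappers (values are illustrative):

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0],
                    [1.0, 1.0, 1.0]])

probs = jax.nn.softmax(logits, axis=-1)          # each row sums to 1
log_probs = jax.nn.log_softmax(logits, axis=-1)  # log of the same quantity
z = jax.nn.standardize(logits, axis=(0, 1))      # ~zero mean, unit variance overall
```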


@ -18,7 +18,7 @@ used in Keras and Sonnet.
"""
import math
from typing import Any, Literal, Protocol, Sequence, Tuple, Union
from typing import Any, Literal, Protocol, Sequence, Union
import numpy as np
@ -159,7 +159,7 @@ def _compute_fans(shape: core.NamedShape,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Union[int, Sequence[int]] = ()
) -> Tuple[Array, Array]:
) -> tuple[Array, Array]:
"""
Compute effective input and output sizes for a linear or convolutional layer.
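The fan sizes computed here feed the variance-scaling initializers; a small sketch of a public entry point built on them (the shape and PRNG seed are arbitrary):

```python
import jax
import jax.numpy as jnp

# Glorot/Xavier uniform is variance scaling with scale 1 and mode "fan_avg";
# the fans are derived from in_axis/out_axis as in _compute_fans above.
init = jax.nn.initializers.glorot_uniform()
w = init(jax.random.PRNGKey(0), (128, 256), jnp.float32)
```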


@ -24,7 +24,7 @@ __all__ = ['register_jax_array_methods']
import abc
from functools import partial, wraps
import inspect
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import warnings
import numpy as np
@ -335,15 +335,15 @@ def _deprecated_split(*args, **kwargs):
@core.stash_axis_env()
@partial(jax.jit, static_argnums=(1,2,3))
def _multi_slice(arr: ArrayLike,
start_indices: Tuple[Tuple[int, ...]],
limit_indices: Tuple[Tuple[int, ...]],
removed_dims: Tuple[Tuple[int, ...]]) -> List[Array]:
start_indices: tuple[tuple[int, ...]],
limit_indices: tuple[tuple[int, ...]],
removed_dims: tuple[tuple[int, ...]]) -> list[Array]:
"""Extracts multiple slices from `arr`.
This is used to shard DeviceArray arguments to pmap. It's implemented as a
DeviceArray method here to avoid circular imports.
"""
results: List[Array] = []
results: list[Array] = []
for starts, limits, removed in zip(start_indices, limit_indices, removed_dims):
sliced = lax.slice(arr, starts, limits)
if removed:
@ -354,7 +354,7 @@ def _multi_slice(arr: ArrayLike,
# The next two functions are related to iter(device_array), implemented here to
# avoid circular imports.
@jax.jit
def _unstack(x: Array) -> List[Array]:
def _unstack(x: Array) -> list[Array]:
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
def _chunk_iter(x, size):


@ -13,7 +13,7 @@
# limitations under the License.
import abc
from typing import Any, Iterable, List, Tuple, Union
from typing import Any, Iterable, Union
import jax
from jax._src import core
@ -73,7 +73,7 @@ class _Mgrid:
[0, 1, 2]]], dtype=int32)
"""
def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Array:
def __getitem__(self, key: Union[slice, tuple[slice, ...]]) -> Array:
if isinstance(key, slice):
return _make_1d_grid_from_slice(key, op_name="mgrid")
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="mgrid") for k in key)
@ -117,8 +117,8 @@ class _Ogrid:
"""
def __getitem__(
self, key: Union[slice, Tuple[slice, ...]]
) -> Union[Array, List[Array]]:
self, key: Union[slice, tuple[slice, ...]]
) -> Union[Array, list[Array]]:
if isinstance(key, slice):
return _make_1d_grid_from_slice(key, op_name="ogrid")
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key)
@ -140,8 +140,8 @@ class _AxisConcat(abc.ABC):
trans1d: int
op_name: str
def __getitem__(self, key: Union[_IndexType, Tuple[_IndexType, ...]]) -> Array:
key_tup: Tuple[_IndexType, ...] = key if isinstance(key, tuple) else (key,)
def __getitem__(self, key: Union[_IndexType, tuple[_IndexType, ...]]) -> Array:
key_tup: tuple[_IndexType, ...] = key if isinstance(key, tuple) else (key,)
params = [self.axis, self.ndmin, self.trans1d, -1]
@ -154,7 +154,7 @@ class _AxisConcat(abc.ABC):
elif directive == "c":
params[-1] = 1
else:
vec: List[Any] = directive.split(",")
vec: list[Any] = directive.split(",")
k = len(vec)
if k < 4:
vec += params[k:]
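The index helpers whose key types are retyped above behave like their NumPy counterparts. A minimal sketch (values are illustrative):

```python
import jax.numpy as jnp

grid = jnp.mgrid[0:2, 0:3]        # dense grid, shape (2, 2, 3)
open_grid = jnp.ogrid[0:2, 0:3]   # open grid: a list of broadcastable arrays
row = jnp.r_[jnp.array([1, 2, 3]), 0, 0, jnp.array([4, 5, 6])]  # 1-D concatenation
```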


@ -31,8 +31,8 @@ import math
import operator
import types
from typing import (
overload, Any, Callable, Dict, FrozenSet, List, Literal,
NamedTuple, Optional, Protocol, Sequence, Tuple, TypeVar, Union)
overload, Any, Callable, Literal,
NamedTuple, Optional, Protocol, Sequence, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings
@ -230,7 +230,7 @@ def _jnp_dtype(obj: Optional[DTypeLike], *, align: bool = False,
### utility functions
_DEFAULT_TYPEMAP: Dict[type, _ScalarMeta] = {
_DEFAULT_TYPEMAP: dict[type, _ScalarMeta] = {
np.bool_: bool_,
np.int_: int_,
np.float_: float_,
@ -448,7 +448,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
def histogram(a: ArrayLike, bins: ArrayLike = 10,
range: Optional[Sequence[ArrayLike]] = None,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, Array]:
density: Optional[bool] = None) -> tuple[Array, Array]:
if weights is None:
util.check_arraylike("histogram", a, bins)
a = ravel(*util.promote_dtypes_inexact(a))
@ -469,10 +469,10 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10,
return counts, bin_edges
@util._wraps(np.histogram2d)
def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, list[ArrayLike]] = 10,
range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]]=None,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, Array, Array]:
density: Optional[bool] = None) -> tuple[Array, Array, Array]:
util.check_arraylike("histogram2d", x, y)
try:
N = len(bins) # type: ignore[arg-type]
@ -488,10 +488,10 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLik
return hist, edges[0], edges[1]
@util._wraps(np.histogramdd)
def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, list[ArrayLike]] = 10,
range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = None,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, List[Array]]:
density: Optional[bool] = None) -> tuple[Array, list[Array]]:
if weights is None:
util.check_arraylike("histogramdd", sample)
sample, = util.promote_dtypes_inexact(sample)
@ -511,14 +511,14 @@ def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
num_bins = len(bins) # type: ignore[arg-type]
except TypeError:
# when bin_size is integer, the same bin is used for each dimension
bins_per_dimension: List[ArrayLike] = D * [bins] # type: ignore[assignment]
bins_per_dimension: list[ArrayLike] = D * [bins] # type: ignore[assignment]
else:
if num_bins != D:
raise ValueError("should be a bin for each dimension.")
bins_per_dimension = list(bins) # type: ignore[arg-type]
bin_idx_by_dim: List[Array] = []
bin_edges_by_dim: List[Array] = []
bin_idx_by_dim: list[Array] = []
bin_edges_by_dim: list[Array] = []
for i in builtins.range(D):
range_i = None if range is None else range[i]
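A minimal sketch of the tuple-returning histogram wrappers above (sample values are arbitrary):

```python
import jax.numpy as jnp

samples = jnp.array([0.1, 0.4, 0.4, 0.9, 1.5])

counts, edges = jnp.histogram(samples, bins=4)             # counts.shape == (4,)
hist2d, xedges, yedges = jnp.histogram2d(samples, samples, bins=4)
```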
@ -583,7 +583,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
@util._wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('k', 'axes'))
def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array:
def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
util.check_arraylike("rot90", m)
ax1, ax2 = axes
ax1 = _canonicalize_axis(ax1, ndim(m))
@ -605,12 +605,12 @@ def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array:
@util._wraps(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
def flip(m: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
util.check_arraylike("flip", m)
return _flip(asarray(m), reductions._ensure_optional_axes(axis))
@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
def _flip(m: Array, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
if axis is None:
return lax.rev(m, list(range(len(shape(m)))))
axis = _ensure_index_tuple(axis)
@ -674,7 +674,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
nd = arr.ndim
axis = _canonicalize_axis(axis, nd)
combined: List[Array] = []
combined: list[Array] = []
if prepend is not None:
util.check_arraylike("diff", prepend)
if isscalar(prepend):
@ -734,8 +734,8 @@ def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = None,
@util._wraps(np.gradient, skip_params=['edge_order'])
@partial(jit, static_argnames=('axis', 'edge_order'))
def gradient(f: ArrayLike, *varargs: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
edge_order: Optional[int] = None) -> Union[Array, List[Array]]:
axis: Optional[Union[int, tuple[int, ...]]] = None,
edge_order: Optional[int] = None) -> Union[Array, list[Array]]:
if edge_order is not None:
raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")
a, *spacing = util.promote_args_inexact("gradient", f, *varargs)
@ -805,7 +805,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:
@util._wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index: Tuple[ArrayLike, ...], dims: Tuple[int, ...],
def ravel_multi_index(multi_index: tuple[ArrayLike, ...], dims: tuple[int, ...],
mode: str = 'raise', order: str = 'C') -> Array:
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
@ -848,7 +848,7 @@ and out-of-bounds indices are clipped into the valid range.
"""
@util._wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]:
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
util.check_arraylike("unravel_index", indices)
indices_arr = asarray(indices)
# Note: we do not convert shape to an array, because it may be passed as a
@ -888,12 +888,12 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array:
return reshape(arr, new_shape)
@util._wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC)
def squeeze(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
def squeeze(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
util.check_arraylike("squeeze", a)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)
@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: Tuple[int]) -> Array:
def _squeeze(a: Array, axis: tuple[int]) -> Array:
if axis is None:
a_shape = shape(a)
if not core.is_constant_shape(a_shape):
@ -927,7 +927,7 @@ def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]],
_ensure_index_tuple(destination))
@partial(jit, static_argnames=('source', 'destination'), inline=True)
def _moveaxis(a: Array, source: Tuple[int, ...], destination: Tuple[int, ...]) -> Array:
def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -> Array:
source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
if len(source) != len(destination):
@ -1064,20 +1064,20 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
@overload
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Tuple[Array, ...]: ...
fill_value: Union[None, Array, tuple[ArrayLike]] = None
) -> tuple[Array, ...]: ...
@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
fill_value: Union[None, Array, tuple[ArrayLike]] = None
) -> Array: ...
@overload
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Union[Array, Tuple[Array, ...]]: ...
fill_value: Union[None, Array, tuple[ArrayLike]] = None
) -> Union[Array, tuple[Array, ...]]: ...
@util._wraps(np.where,
lax_description=_dedent("""
@ -1105,8 +1105,8 @@ def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Union[Array, Tuple[Array, ...]]:
fill_value: Union[None, Array, tuple[ArrayLike]] = None
) -> Union[Array, tuple[Array, ...]]:
if x is None and y is None:
util.check_arraylike("where", condition)
return nonzero(condition, size=size, fill_value=fill_value)
@ -1165,11 +1165,11 @@ def bincount(x: ArrayLike, weights: Optional[ArrayLike] = None,
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)
@overload
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...
@overload
def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...]
) -> Tuple[Union[int, core.Tracer], ...]: ...
def broadcast_shapes(*shapes: tuple[Union[int, core.Tracer], ...]
) -> tuple[Union[int, core.Tracer], ...]: ...
@util._wraps(getattr(np, "broadcast_shapes", None))
def broadcast_shapes(*shapes):
@ -1182,7 +1182,7 @@ def broadcast_shapes(*shapes):
@util._wraps(np.broadcast_arrays, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_arrays(*args: ArrayLike) -> List[Array]:
def broadcast_arrays(*args: ArrayLike) -> list[Array]:
return util._broadcast_arrays(*args)
@ -1195,7 +1195,7 @@ def broadcast_to(array: ArrayLike, shape: Shape) -> Array:
def _split(op: str, ary: ArrayLike,
indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
axis: int = 0) -> list[Array]:
util.check_arraylike(op, ary)
ary = asarray(ary)
axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
@ -1234,12 +1234,12 @@ def _split(op: str, ary: ArrayLike,
@util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
def split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
axis: int = 0) -> list[Array]:
return _split("split", ary, indices_or_sections, axis=axis)
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], list[Array]]:
@util._wraps(getattr(np, op), update_doc=False)
def f(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike]) -> List[Array]:
def f(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike]) -> list[Array]:
# for 1-D array, hsplit becomes vsplit
nonlocal axis
util.check_arraylike(op, ary)
@ -1255,7 +1255,7 @@ dsplit = _split_on_axis("dsplit", axis=2)
@util._wraps(np.array_split)
def array_split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)
@util._wraps(np.clip, skip_params=['out'])
@ -1367,8 +1367,8 @@ fill_value : array_like, optional
@util._wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def nonzero(a: ArrayLike, *, size: Optional[int] = None,
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None
) -> Tuple[Array, ...]:
fill_value: Union[None, ArrayLike, tuple[ArrayLike]] = None
) -> tuple[Array, ...]:
util.check_arraylike("nonzero", a)
arr = atleast_1d(a)
del a
@ -1394,7 +1394,7 @@ def nonzero(a: ArrayLike, *, size: Optional[int] = None,
@util._wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def flatnonzero(a: ArrayLike, *, size: Optional[int] = None,
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None) -> Array:
fill_value: Union[None, ArrayLike, tuple[ArrayLike]] = None) -> Array:
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]
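The JAX-specific `size`/`fill_value` parameters threaded through these signatures make the output shape static, so the call remains usable under `jit`. A minimal sketch:

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5, 0, 7])

# Without `size` the output shape would depend on the data, which cannot be
# traced; `size` pads or truncates, and `fill_value` fills any padding.
idx, = jnp.nonzero(x, size=4, fill_value=-1)   # [1, 3, 5, -1]

@jax.jit
def first_two_nonzero(v):
    return jnp.flatnonzero(v, size=2)

first_two_nonzero(x)  # [1, 3]
```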
@ -1428,7 +1428,7 @@ def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = None,
### Padding
PadValueLike = Union[T, Sequence[T], Sequence[Sequence[T]]]
PadValue = Tuple[Tuple[T, T], ...]
PadValue = tuple[tuple[T, T], ...]
class PadStatFunc(Protocol):
def __call__(self, array: ArrayLike, /, *,
@ -1477,7 +1477,7 @@ def _broadcast_to_pairs(nvals: PadValueLike, nd: int, name: str) -> PadValue:
f"Valid shapes are ({nd}, 2), (1, 2), (2,), (1,), or ().")
def _check_no_padding(axis_padding: Tuple[Any, Any], mode: str):
def _check_no_padding(axis_padding: tuple[Any, Any], mode: str):
if (axis_padding[0] > 0 or axis_padding[1] > 0):
msg = "Cannot apply '{}' padding to empty axis"
raise ValueError(msg.format(mode))
@ -1689,7 +1689,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
if nd == 0:
return array
stat_funcs: Dict[str, PadStatFunc] = {
stat_funcs: dict[str, PadStatFunc] = {
"maximum": reductions.amax,
"minimum": reductions.amin,
"mean": reductions.mean,
@ -1806,7 +1806,7 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array:
try:
iter(reps) # type: ignore[arg-type]
except TypeError:
reps_tup: Tuple[DimSize, ...] = (reps,)
reps_tup: tuple[DimSize, ...] = (reps,)
else:
reps_tup = tuple(reps) # type: ignore[assignment,arg-type]
reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
@ -1932,7 +1932,7 @@ def _atleast_nd(x: ArrayLike, n: int) -> Array:
m = ndim(x)
return lax.broadcast(x, (1,) * (n - m)) if m < n else asarray(x)
def _block(xs: Union[ArrayLike, List[ArrayLike]]) -> Tuple[Array, int]:
def _block(xs: Union[ArrayLike, list[ArrayLike]]) -> tuple[Array, int]:
if isinstance(xs, tuple):
raise ValueError("jax.numpy.block does not allow tuples, got {}"
.format(xs))
@ -1950,13 +1950,13 @@ def _block(xs: Union[ArrayLike, List[ArrayLike]]) -> Tuple[Array, int]:
@util._wraps(np.block)
@jit
def block(arrays: Union[ArrayLike, List[ArrayLike]]) -> Array:
def block(arrays: Union[ArrayLike, list[ArrayLike]]) -> Array:
out, _ = _block(arrays)
return out
@util._wraps(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_1d(*arys: ArrayLike) -> Union[Array, List[Array]]:
def atleast_1d(*arys: ArrayLike) -> Union[Array, list[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
return arr if ndim(arr) >= 1 else reshape(arr, -1)
@ -1966,7 +1966,7 @@ def atleast_1d(*arys: ArrayLike) -> Union[Array, List[Array]]:
@util._wraps(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_2d(*arys: ArrayLike) -> Union[Array, List[Array]]:
def atleast_2d(*arys: ArrayLike) -> Union[Array, list[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
if ndim(arr) >= 2:
@ -1981,7 +1981,7 @@ def atleast_2d(*arys: ArrayLike) -> Union[Array, List[Array]]:
@util._wraps(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_3d(*arys: ArrayLike) -> Union[Array, List[Array]]:
def atleast_3d(*arys: ArrayLike) -> Union[Array, list[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
if ndim(arr) == 0:
@ -2389,22 +2389,22 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: bool, retstep: Literal[True],
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Tuple[Array, Array]: ...
axis: int = 0) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, *, retstep: Literal[True],
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Tuple[Array, Array]: ...
axis: int = 0) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Union[Array, Tuple[Array, Array]]: ...
axis: int = 0) -> Union[Array, tuple[Array, Array]]: ...
@util._wraps(np.linspace)
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Union[Array, Tuple[Array, Array]]:
axis: int = 0) -> Union[Array, tuple[Array, Array]]:
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)
@ -2413,7 +2413,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Union[Array, Tuple[Array, Array]]:
axis: int = 0) -> Union[Array, tuple[Array, Array]]:
"""Implementation of linspace differentiable in start and stop args."""
dtypes.check_user_dtype_supported(dtype, "linspace")
if num < 0:
@ -2526,7 +2526,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
@util._wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
indexing: str = 'xy') -> List[Array]:
indexing: str = 'xy') -> list[Array]:
util.check_arraylike("meshgrid", *xi)
args = [asarray(x) for x in xi]
if not copy:
@ -2557,7 +2557,7 @@ def i0(x: ArrayLike) -> Array:
@util._wraps(np.ix_)
def ix_(*args: ArrayLike) -> Tuple[Array, ...]:
def ix_(*args: ArrayLike) -> tuple[Array, ...]:
util.check_arraylike("ix", *args)
n = len(args)
output = []
@ -2584,13 +2584,13 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: Literal[False] = False) -> Array: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
*, sparse: Literal[True]) -> Tuple[Array, ...]: ...
*, sparse: Literal[True]) -> tuple[Array, ...]: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: ...
sparse: bool = False) -> Union[Array, tuple[Array, ...]]: ...
@util._wraps(np.indices)
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: bool = False) -> Union[Array, Tuple[Array, ...]]:
sparse: bool = False) -> Union[Array, tuple[Array, ...]]:
dimensions = tuple(
core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
for d in dimensions)
@ -2649,10 +2649,10 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *,
input_shape = shape(a)
aux_axis = axis if axis < 0 else axis + 1
a = expand_dims(a, aux_axis)
reps: List[DimSize] = [1] * len(shape(a))
reps: list[DimSize] = [1] * len(shape(a))
reps[aux_axis] = repeats
a = tile(a, reps)
result_shape: List[DimSize] = list(input_shape)
result_shape: list[DimSize] = list(input_shape)
result_shape[axis] *= repeats
return reshape(a, result_shape)
@ -2781,7 +2781,7 @@ def _triu_size(n, m, k):
@util._wraps(np.triu_indices)
def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Array]:
def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> tuple[Array, Array]:
n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices")
k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices")
m = n if m is None else core.concrete_or_error(operator.index, m, "m argument of jnp.triu_indices")
@ -2790,7 +2790,7 @@ def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Ar
@util._wraps(np.tril_indices)
def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Array]:
def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> tuple[Array, Array]:
n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices")
k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices")
m = n if m is None else core.concrete_or_error(operator.index, m, "m argument of jnp.triu_indices")
@ -2799,13 +2799,13 @@ def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Ar
@util._wraps(np.triu_indices_from)
def triu_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array, Array]:
def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
arr_shape = shape(arr)
return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1])
@util._wraps(np.tril_indices_from)
def tril_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array, Array]:
def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
arr_shape = shape(arr)
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
@ -3294,7 +3294,7 @@ def _removechars(s, chars):
@partial(jit, static_argnums=(1, 2, 3, 4), inline=True)
def _einsum(
operands: Sequence,
contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
precision,
preferred_element_type,
_dot_general=lax.dot_general,
@ -4315,8 +4315,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# Pairs of (array, start_dim) values. These will be broadcast into
# gather_indices_shape, with the array dimensions aligned to start_dim, and
# then concatenated.
gather_indices: List[Tuple[Array, int]] = []
gather_indices_shape: List[int] = []
gather_indices: list[tuple[Array, int]] = []
gather_indices_shape: list[int] = []
# We perform three transformations to y before the scatter op, in order:
# First, y is broadcast to slice_shape. In general `y` only need broadcast to
@ -4687,11 +4687,11 @@ def kaiser(M: int, beta: ArrayLike) -> Array:
return i0(beta * ufuncs.sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta)
def _gcd_cond_fn(xs: Tuple[Array, Array]) -> Array:
def _gcd_cond_fn(xs: tuple[Array, Array]) -> Array:
x1, x2 = xs
return reductions.any(x2 != 0)
def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]:
def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]:
x1, x2 = xs
x1, x2 = (where(x2 != 0, x2, x1),
where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0)))
@ -4927,7 +4927,7 @@ See the :func:`jax.lax.switch` documentation for more information.
@util._wraps(np.piecewise, lax_description=_PIECEWISE_DOC)
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
funclist: List[Union[ArrayLike, Callable[..., Array]]],
funclist: list[Union[ArrayLike, Callable[..., Array]]],
*args, **kw) -> Array:
util.check_arraylike("piecewise", x)
nc, nf = len(condlist), len(funclist)
@ -4944,8 +4944,8 @@ def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
*args, **kw)
@partial(jit, static_argnames=['funcs'])
def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike],
funcs: FrozenSet[Tuple[int, Callable[..., Array]]],
def _piecewise(x: Array, condlist: Array, consts: dict[int, ArrayLike],
funcs: frozenset[tuple[int, Callable[..., Array]]],
*args, **kw) -> Array:
funcdict = dict(funcs)
funclist = [consts.get(i, funcdict.get(i)) for i in range(len(condlist) + 1)]
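`piecewise` dispatches each element through `lax.switch`, so every entry of `funclist` must be a traceable callable or a constant. A minimal sketch (absolute value as a two-branch piecewise function):

```python
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.5, 2.0])

y = jnp.piecewise(x, [x < 0, x >= 0], [lambda v: -v, lambda v: v])  # [2., .5, .5, 2.]
```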


@ -18,7 +18,7 @@ from functools import partial
import numpy as np
import textwrap
import operator
from typing import Literal, Optional, Tuple, Union, cast, overload
from typing import Literal, Optional, Union, cast, overload
import jax
from jax import jit, custom_jvp
@ -49,10 +49,10 @@ def cholesky(a: ArrayLike) -> Array:
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True],
hermitian: bool = False) -> Tuple[Array, Array, Array]: ...
hermitian: bool = False) -> tuple[Array, Array, Array]: ...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True],
hermitian: bool = False) -> Tuple[Array, Array, Array]: ...
hermitian: bool = False) -> tuple[Array, Array, Array]: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
@ -61,12 +61,12 @@ def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]: ...
hermitian: bool = False) -> Union[Array, tuple[Array, Array, Array]]: ...
@_wraps(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
hermitian: bool = False) -> Union[Array, tuple[Array, Array, Array]]:
check_arraylike("jnp.linalg.svd", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if hermitian:
@ -142,7 +142,7 @@ def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
@custom_jvp
def _slogdet_lu(a: Array) -> Tuple[Array, Array]:
def _slogdet_lu(a: Array) -> tuple[Array, Array]:
dtype = lax.dtype(a)
lu, pivot, _ = lax_linalg.lu(a)
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
@ -164,7 +164,7 @@ def _slogdet_lu(a: Array) -> Tuple[Array, Array]:
return sign, ufuncs.real(logdet)
@custom_jvp
def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
def _slogdet_qr(a: Array) -> tuple[Array, Array]:
# Implementation of slogdet using QR decomposition. One reason we might prefer
# QR decomposition is that it is more amenable to a fast batched
# implementation on TPU because of the lack of row pivoting.
@ -192,7 +192,7 @@ def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
LU decomposition if ``None``.
"""))
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> tuple[Array, Array]:
check_arraylike("jnp.linalg.slogdet", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
@ -223,7 +223,7 @@ def _slogdet_jvp(primals, tangents):
_slogdet_lu.defjvp(_slogdet_jvp)
_slogdet_qr.defjvp(_slogdet_jvp)
def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> Tuple[Array, Array]:
def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> tuple[Array, Array]:
"""Equivalent to det(a)*solve(a, b) for nonsingular mat.
Intermediate function used for jvp and vjp of det.
@ -364,7 +364,7 @@ At present, non-symmetric eigendecomposition is only implemented on the CPU
backend. However eigendecomposition for symmetric/Hermitian matrices is
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
""")
def eig(a: ArrayLike) -> Tuple[Array, Array]:
def eig(a: ArrayLike) -> tuple[Array, Array]:
check_arraylike("jnp.linalg.eig", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
@ -382,7 +382,7 @@ def eigvals(a: ArrayLike) -> Array:
@_wraps(np.linalg.eigh)
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
symmetrize_input: bool = True) -> Tuple[Array, Array]:
symmetrize_input: bool = True) -> tuple[Array, Array]:
check_arraylike("jnp.linalg.eigh", a)
if UPLO is None or UPLO == "L":
lower = True
@ -481,7 +481,7 @@ def inv(a: ArrayLike) -> Array:
@_wraps(np.linalg.norm)
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
axis: Union[None, Tuple[int, ...], int] = None,
axis: Union[None, tuple[int, ...], int] = None,
keepdims: bool = False) -> Array:
check_arraylike("jnp.linalg.norm", x)
x, = promote_dtypes_inexact(jnp.asarray(x))
@ -532,7 +532,7 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None,
return ufuncs.power(out, ord_inv)
elif num_axes == 2:
row_axis, col_axis = cast(Tuple[int, ...], axis)
row_axis, col_axis = cast(tuple[int, ...], axis)
if ord is None or ord in ('f', 'fro'):
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
@ -578,11 +578,11 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None,
@overload
def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
@overload
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]: ...
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]]: ...
@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]]:
check_arraylike("jnp.linalg.qr", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
@ -609,7 +609,7 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array:
def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]:
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
# TODO: add custom jvp rule for more robust lstsq differentiation
a, b = promote_dtypes_inexact(a, b)
@ -669,7 +669,7 @@ _jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
poorly behaved for some inputs, particularly for low-rank `a`.
"""))
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]:
check_arraylike("jnp.linalg.lstsq", a, b)
if numpy_resid:
return _lstsq(a, b, rcond, numpy_resid=True)
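
A minimal usage sketch of the return annotations above, assuming Python 3.9+ (where PEP 585 builtin generics such as tuple[...] are valid at runtime) and the public jax.numpy.linalg API:

import jax.numpy as jnp
from jax import Array

def svd_parts(a: Array) -> tuple[Array, Array, Array]:
    # With compute_uv=True (the default), jnp.linalg.svd returns (U, S, Vh).
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    return u, s, vh

def signed_logdet(a: Array) -> tuple[Array, Array]:
    # jnp.linalg.slogdet returns a (sign, logabsdet) pair.
    return jnp.linalg.slogdet(a)

u, s, vh = svd_parts(jnp.eye(3))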

View File

@ -15,7 +15,7 @@
from functools import partial
import operator
from typing import Optional, Tuple, Union
from typing import Optional, Union
import numpy as np
@ -110,7 +110,7 @@ Also, it works best on rcond <= 10e-3 values.
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = None,
full: bool = False, w: Optional[Array] = None, cov: bool = False
) -> Union[Array, Tuple[Array, ...]]:
) -> Union[Array, tuple[Array, ...]]:
check_arraylike("polyfit", x, y)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
@ -301,7 +301,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
return convolve(a1_arr, a2_arr, mode='full')
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> Tuple[Array, Array]:
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
check_arraylike("polydiv", u, v)
u_arr, v_arr = promote_dtypes_inexact(u, v)
m = len(u_arr) - 1
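
For reference, a small sketch of the (quotient, remainder) pair that the polydiv annotation above describes, assuming Python 3.9+:

import jax.numpy as jnp
from jax import Array

def divide_polys(u: Array, v: Array) -> tuple[Array, Array]:
    return jnp.polydiv(u, v)

# (x**2 - 1) / (x + 1): quotient ~ [1, -1], remainder numerically zero.
q, r = divide_polys(jnp.array([1.0, 0.0, -1.0]), jnp.array([1.0, 1.0]))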

View File

@ -17,7 +17,7 @@ from functools import partial
import math
import operator
from typing import (
overload, Any, Callable, Literal, Optional, Protocol, Sequence, Tuple, Union)
overload, Any, Callable, Literal, Optional, Protocol, Sequence, Union)
import warnings
import numpy as np
@ -349,15 +349,15 @@ def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None
returned: Literal[True], keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]: ...
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]: ...
@_wraps(np.average)
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]:
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]:
if weights is None: # Treat all weights as 1
check_arraylike("average", a)
a, = promote_dtypes_inexact(a)
@ -450,7 +450,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
return lax.div(result, normalizer).astype(dtype)
def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike) -> Tuple[DType, DType]:
def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike) -> tuple[DType, DType]:
if dtype:
if (not dtypes.issubdtype(dtype, np.complexfloating) and
dtypes.issubdtype(a_dtype, np.complexfloating)):
@ -689,7 +689,7 @@ nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("quantile", a, q)
@ -705,7 +705,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ..
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("nanquantile", a, q)
@ -718,7 +718,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int,
"Use 'method=' instead.", DeprecationWarning)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True)
def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
def _quantile(a: Array, q: Array, axis: Optional[Union[int, tuple[int, ...]]],
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
@ -843,7 +843,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("percentile", a, q)
@ -855,7 +855,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("nanpercentile", a, q)
@ -866,7 +866,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def median(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
check_arraylike("median", a)
@ -875,7 +875,7 @@ def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def nanmedian(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
check_arraylike("nanmedian", a)

View File

@ -16,7 +16,7 @@ from functools import partial
import math
import operator
from textwrap import dedent as _dedent
from typing import Optional, Tuple, Union, cast
from typing import Optional, Union, cast
import numpy as np
@ -155,7 +155,7 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr
@partial(jit, static_argnames=['return_indices'])
def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> Tuple[Array, ...]:
def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> tuple[Array, ...]:
"""
Helper function for intersect1d which is jit-able
"""
@ -175,7 +175,7 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo
@_wraps(np.intersect1d)
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return_indices: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
return_indices: bool = False) -> Union[Array, tuple[Array, Array, Array]]:
check_arraylike("intersect1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
@ -226,7 +226,7 @@ UNIQUE_SIZE_HINT = (
"a concrete value for the size argument, which will determine the output size.")
@partial(jit, static_argnums=1)
def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]:
def _unique_sorted_mask(ar: Array, axis: int) -> tuple[Array, Array, Array]:
aux = moveaxis(ar, axis, 0)
if np.issubdtype(aux.dtype, np.complexfloating):
# Work around issue in sorting of complex numbers with Nan only in the
@ -255,7 +255,7 @@ def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]:
def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bool = False,
return_counts: bool = False, size: Optional[int] = None,
fill_value: Optional[ArrayLike] = None, return_true_size: bool = False
) -> Union[Array, Tuple[Array, ...]]:
) -> Union[Array, tuple[Array, ...]]:
"""
Find the unique elements of an array along a particular axis.
"""
@ -280,7 +280,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
result = moveaxis(result, 0, axis)
ret: Tuple[Array, ...] = (result,)
ret: tuple[Array, ...] = (result,)
if return_index:
if aux.size:
ret += (perm[ind],)
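
A usage sketch of the tuple-returning set operations retyped above, assuming Python 3.9+ (note that jnp.unique needs concrete inputs, or a static size argument, under jit):

import jax.numpy as jnp

a = jnp.array([1, 3, 4, 3])
b = jnp.array([3, 1, 2, 1])

common, ia, ib = jnp.intersect1d(a, b, return_indices=True)  # values plus indices into a and b
vals, counts = jnp.unique(a, return_counts=True)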

View File

@ -19,7 +19,7 @@ Implements ufuncs for jax.numpy.
from functools import partial
import operator
from textwrap import dedent
from typing import Any, Callable, Tuple, Union, overload
from typing import Any, Callable, Union, overload
import numpy as np
@ -285,7 +285,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.divmod, module='numpy')
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
x1, x2 = promote_args_numeric("divmod", x1, x2)
if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
return floor_divide(x1, x2), remainder(x1, x2)
@ -293,7 +293,7 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
return _float_divmod(x1, x2)
def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> Tuple[Array, Array]:
def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
# see float_divmod in floatobject.c of CPython
mod = lax.rem(x1, x2)
div = lax.div(lax.sub(x1, mod), x2)
@ -521,7 +521,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@_wraps(np.frexp, module='numpy')
@jit
def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
check_arraylike("frexp", x)
x, = promote_dtypes_inexact(x)
if dtypes.issubdtype(x.dtype, np.complexfloating):
@ -616,7 +616,7 @@ def real(val: ArrayLike, /) -> Array:
@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
check_arraylike("modf", x)
x, = promote_dtypes_inexact(x)
if out is not None:
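
A sketch of the two-element results annotated as tuple[Array, Array] above, assuming Python 3.9+:

import jax.numpy as jnp

q, r = jnp.divmod(jnp.array([7, -7]), 3)    # floor quotient and remainder
m, e = jnp.frexp(jnp.array([8.0, 0.5]))     # mantissa and integer exponent
frac, whole = jnp.modf(jnp.array([3.75]))   # fractional and integral parts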

View File

@ -15,9 +15,8 @@
from functools import partial
import re
import textwrap
from typing import (
Any, Callable, Dict, List, NamedTuple, Optional, Sequence, TypeVar
)
from typing import Any, Callable, NamedTuple, Optional, Sequence, TypeVar
import warnings
from jax._src import dtypes
@ -53,7 +52,7 @@ class ParsedDoc(NamedTuple):
signature: str = ""
summary: str = ""
front_matter: str = ""
sections: Dict[str, str] = {}
sections: dict[str, str] = {}
def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
@ -101,7 +100,7 @@ def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
front_matter=front_matter, sections=sections)
def _parse_parameters(body: str) -> Dict[str, str]:
def _parse_parameters(body: str) -> dict[str, str]:
"""Parse the Parameters section of a docstring."""
title, underline, content = body.split('\n', 2)
assert title == 'Parameters'
@ -110,7 +109,7 @@ def _parse_parameters(body: str) -> Dict[str, str]:
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}
def _parse_extra_params(extra_params: str) -> Dict[str, str]:
def _parse_extra_params(extra_params: str) -> dict[str, str]:
"""Parse the extra parameters passed to _wraps()"""
parameters = _parameter_break.split(extra_params.strip('\n'))
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}
@ -223,7 +222,7 @@ def _wraps(
_dtype = partial(dtypes.dtype, canonicalize=True)
def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return [lax.asarray(arg) for arg in args]
@ -264,7 +263,7 @@ def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
def promote_dtypes(*args: ArrayLike) -> List[Array]:
def promote_dtypes(*args: ArrayLike) -> list[Array]:
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
@ -275,7 +274,7 @@ def promote_dtypes(*args: ArrayLike) -> List[Array]:
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
@ -286,7 +285,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
for x in args]
def promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
def promote_dtypes_numeric(*args: ArrayLike) -> list[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a numeric (non-bool) type."""
@ -297,7 +296,7 @@ def promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
for x in args]
def promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
def promote_dtypes_complex(*args: ArrayLike) -> list[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a complex type."""
@ -350,20 +349,20 @@ def _check_no_float0s(fun_name: str, *args: Any):
"taken a gradient with respect to an integer argument.")
def promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion."""
check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return promote_shapes(fun_name, *promote_dtypes(*args))
def promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_args_numeric(fun_name: str, *args: ArrayLike) -> list[Array]:
check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return promote_shapes(fun_name, *promote_dtypes_numeric(*args))
def promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion.
Promotes non-inexact types to an inexact type."""
@ -373,7 +372,7 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
@partial(api.jit, inline=True)
def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [np.shape(arg) for arg in args]
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
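
The promotion helpers above are private, but the annotation style carries over to user code. A hypothetical analogue (promote_all is not a JAX function) built on the public jnp.result_type, assuming Python 3.9+:

import jax.numpy as jnp
from jax import Array

def promote_all(*args) -> list[Array]:
    # Hypothetical helper: cast every argument to the common result dtype.
    dtype = jnp.result_type(*args)
    return [jnp.asarray(a, dtype=dtype) for a in args]

xs = promote_all(1, jnp.array([2.0, 3.0]), True)  # three arrays sharing one dtype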

View File

@ -14,7 +14,7 @@
import functools
import re
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable
from jax._src import api
from jax import lax
@ -30,13 +30,13 @@ _ARGUMENT_LIST = '{0:}(?:,{0:})*'.format(_ARGUMENT)
_SIGNATURE = '^{0:}->{0:}$'.format(_ARGUMENT_LIST)
CoreDims = Tuple[str, ...]
CoreDims = tuple[str, ...]
NDArray = Any
def _parse_gufunc_signature(
signature: str,
) -> Tuple[List[CoreDims], List[CoreDims]]:
) -> tuple[list[CoreDims], list[CoreDims]]:
"""Parse string signatures for a generalized universal function.
Args:
@ -56,8 +56,8 @@ def _parse_gufunc_signature(
def _update_dim_sizes(
dim_sizes: Dict[str, int],
shape: Tuple[int, ...],
dim_sizes: dict[str, int],
shape: tuple[int, ...],
core_dims: CoreDims,
error_context: str = "",
*,
@ -94,10 +94,10 @@ def _update_dim_sizes(
def _parse_input_dimensions(
args: Tuple[NDArray, ...],
input_core_dims: List[CoreDims],
args: tuple[NDArray, ...],
input_core_dims: list[CoreDims],
error_context: str = "",
) -> Tuple[Tuple[int, ...], Dict[str, int]]:
) -> tuple[tuple[int, ...], dict[str, int]]:
"""Parse broadcast and core dimensions for vectorize with a signature.
Args:
@ -114,7 +114,7 @@ def _parse_input_dimensions(
'wrong number of positional arguments: expected %r, got %r %s'
% (len(input_core_dims), len(args), error_context))
shapes = []
dim_sizes: Dict[str, int] = {}
dim_sizes: dict[str, int] = {}
for arg, core_dims in zip(args, input_core_dims):
_update_dim_sizes(dim_sizes, arg.shape, core_dims, error_context,
is_input=True)
@ -127,8 +127,8 @@ def _parse_input_dimensions(
def _check_output_dims(
func: Callable,
dim_sizes: Dict[str, int],
expected_output_core_dims: List[CoreDims],
dim_sizes: dict[str, int],
expected_output_core_dims: list[CoreDims],
error_context: str = "",
) -> Callable:
"""Check that output core dimensions match the signature."""

View File

@ -14,7 +14,7 @@
"""Sharding utilities"""
import itertools
from typing import List, Sequence, Tuple, Union
from typing import Sequence, Union
import numpy as np
@ -22,7 +22,7 @@ from jax._src.lib import xla_client as xc
def get_num_ways_dim_sharded(
hlo_sharding: xc.HloSharding) -> Tuple[Sequence[int], int]:
hlo_sharding: xc.HloSharding) -> tuple[Sequence[int], int]:
if hlo_sharding.is_replicated(): # type: ignore
return [], 1
partitions = hlo_sharding.tile_assignment_dimensions()
@ -61,7 +61,7 @@ def are_op_shardings_equal(op1: Union[xc.OpSharding, xc.HloSharding],
return hc1 == hc2
_Index = Union[int, slice, Tuple[Union[int, slice], ...]]
_Index = Union[int, slice, tuple[Union[int, slice], ...]]
def op_sharding_to_numpy_indices(
@ -81,7 +81,7 @@ def op_sharding_to_numpy_indices(
partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding)
assert len(partitions) == len(shape), (len(partitions), len(shape))
axis_indices: List[Sequence[_Index]] = []
axis_indices: list[Sequence[_Index]] = []
for dim, n_shards in zip(shape, partitions):
if n_shards == 1:
axis_indices.append([slice(None)])
@ -103,6 +103,6 @@ def op_sharding_to_numpy_indices(
def op_sharding_to_indices(
op_sharding: xc.HloSharding, shape: Sequence[int],
num_devices: int) -> Tuple[Tuple[slice, ...], ...]:
num_devices: int) -> tuple[tuple[slice, ...], ...]:
indices = op_sharding_to_numpy_indices(op_sharding, shape, num_devices)
return tuple(indices.flat)

View File

@ -15,7 +15,7 @@
# Helpers for indexed updates.
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Union
import warnings
import numpy as np
@ -37,7 +37,7 @@ if sys.version_info >= (3, 10):
SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType]
else:
SingleIndex = Union[None, int, slice, Sequence[int], Array]
Index = Union[SingleIndex, Tuple[SingleIndex, ...]]
Index = Union[SingleIndex, tuple[SingleIndex, ...]]
Scalar = Union[complex, float, int, np.number]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import overload, Literal, Optional, Tuple, Union
from typing import overload, Literal, Optional, Union
import jax
from jax import lax
@ -32,14 +32,14 @@ def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] =
@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, *, return_sign: Literal[True]) -> Tuple[Array, Array]: ...
keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ...
@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, return_sign: bool = False) -> Union[Array, Tuple[Array, Array]]: ...
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: ...
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
keepdims: bool = False, return_sign: bool = False) -> Union[Array, Tuple[Array, Array]]:
keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]:
r"""Log-sum-exp reduction.
Computes
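
A usage sketch of the return_sign overloads above: with return_sign=True the result is a (log magnitude, sign) pair, otherwise a single array. Assumes Python 3.9+:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1.0, 2.0, 3.0])
lse = logsumexp(a)  # single Array
out, sign = logsumexp(a, b=jnp.array([1.0, -1.0, 1.0]), return_sign=True)  # handles negative weights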

View File

@ -17,8 +17,8 @@ import inspect
import logging
import weakref
import numpy as np
from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
Iterable, NamedTuple, Any)
from typing import (Callable, Sequence, Union, cast, Optional, Iterable,
NamedTuple, Any)
import itertools as it
from functools import partial, lru_cache
import threading
@ -399,9 +399,9 @@ class PjitInfo(NamedTuple):
fun: Callable
in_shardings: Any
out_shardings: Any
static_argnums: Tuple[int, ...]
static_argnames: Tuple[str, ...]
donate_argnums: Tuple[int, ...]
static_argnums: tuple[int, ...]
static_argnames: tuple[str, ...]
donate_argnums: tuple[int, ...]
device: Optional[xc.Device]
backend: Optional[str]
keep_unused: bool
@ -537,7 +537,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
donate_argnums)
def _extract_implicit_args(
in_type: Sequence[Tuple[core.AbstractValue, bool]],
in_type: Sequence[tuple[core.AbstractValue, bool]],
explicit_args: Sequence[Any]
) -> Sequence[core.Tracer]:
"""
@ -567,7 +567,7 @@ def _extract_implicit_args(
return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore
def _flat_axes_specs(abstracted_axes, *args, **kwargs
) -> Optional[List[pe.AbstractedAxesSpec]]:
) -> Optional[list[pe.AbstractedAxesSpec]]:
if abstracted_axes is None: return None
if kwargs: raise NotImplementedError
def ax_leaf(l):
@ -978,7 +978,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info,
def pjit_check_aval_sharding(
shardings, flat_avals, names: Optional[Tuple[str, ...]],
shardings, flat_avals, names: Optional[tuple[str, ...]],
what_aval: str, allow_uneven_sharding: bool):
new_names = [''] * len(shardings) if names is None else names
for aval, s, name in zip(flat_avals, shardings, new_names):
@ -1208,7 +1208,7 @@ pjit_p.def_impl(_pjit_call_impl)
@dataclasses.dataclass(frozen=True)
class SameDeviceAssignmentTuple:
shardings: Tuple[PjitSharding, ...]
shardings: tuple[PjitSharding, ...]
# device_assignment is Optional because shardings can contain `AUTO` and in
# that case `mesh` is compulsory to be used. So in that case
# `_pjit_lower_cached` cache, resource_env will check against the devices.
@ -1262,9 +1262,9 @@ def _pjit_lower_cached(
always_lower: bool,
*,
lowering_platform: Optional[str]):
in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
if resource_env is not None:
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
@ -1323,7 +1323,7 @@ pe.custom_staging_rules[pjit_p] = pjit_staging_rule
# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
# since it's actually not possible in general to infer the type from the term
def _out_type(jaxpr: core.ClosedJaxpr) -> List[core.AbstractValue]:
def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:
out = []
in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)}
out_idx = {x: i for i, x in enumerate(jaxpr.jaxpr.invars)
@ -1430,7 +1430,7 @@ pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
def _pjit_batcher_for_sharding(
s: Union[GSPMDSharding, UnspecifiedValue],
dim: int, val: Tuple[str, ...], mesh, ndim: int):
dim: int, val: tuple[str, ...], mesh, ndim: int):
if is_unspecified(s):
return s
if not val:
@ -1490,7 +1490,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
@weakref_lru_cache
def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
fwds_known: Tuple[Optional[int]]) -> core.ClosedJaxpr:
fwds_known: tuple[Optional[int]]) -> core.ClosedJaxpr:
updated_jaxpr = known_jaxpr.jaxpr.replace(
outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, fwds_known)
if i is None])
@ -1603,7 +1603,7 @@ def _pjit_partial_eval_custom_params_updater(
unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
) -> tuple[dict, dict]:
# prune inputs to jaxpr_known according to unks_in
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
@ -1688,14 +1688,14 @@ ad.reducing_transposes[pjit_p] = _pjit_transpose
@weakref_lru_cache
def _dce_jaxpr_pjit(
jaxpr: core.ClosedJaxpr, used_outputs: Tuple[bool]
) -> Tuple[core.ClosedJaxpr, List[bool]]:
jaxpr: core.ClosedJaxpr, used_outputs: tuple[bool]
) -> tuple[core.ClosedJaxpr, list[bool]]:
new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, used_outputs)
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts), used_inputs
def dce_jaxpr_pjit_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], Optional[core.JaxprEqn]]:
dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
eqn.params['jaxpr'], tuple(used_outputs))
@ -1965,13 +1965,13 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[Parti
def _get_op_sharding_from_executable(
executable) -> Tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
in_op_shardings: List[xc.OpSharding] = []
executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
in_op_shardings: list[xc.OpSharding] = []
parameter_shardings_from_xla = executable.get_parameter_shardings()
if parameter_shardings_from_xla is not None:
in_op_shardings = parameter_shardings_from_xla
out_op_shardings: List[xc.OpSharding] = []
out_op_shardings: list[xc.OpSharding] = []
output_shardings_from_xla = executable.get_output_shardings()
if output_shardings_from_xla is not None:
out_op_shardings = output_shardings_from_xla
@ -1979,10 +1979,10 @@ def _get_op_sharding_from_executable(
return in_op_shardings, out_op_shardings
def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
input_op_shardings: Sequence[xc.OpSharding] = executable.hlo_modules()[0].spmd_parameters_shardings
output_op_sharding: xc.OpSharding = executable.hlo_modules()[0].spmd_output_sharding
in_ppspec: List[ParsedPartitionSpec] = []
in_ppspec: list[ParsedPartitionSpec] = []
for s in input_op_shardings:
in_ppspec.extend(parse_flatten_op_sharding(s, mesh))
out_ppspec = parse_flatten_op_sharding(output_op_sharding, mesh)
@ -1991,7 +1991,7 @@ def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartit
def _get_pspec_from_executable(
executable, mesh: pxla.Mesh
) -> Tuple[Tuple[PartitionSpec, ...], Tuple[PartitionSpec, ...]]:
) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]:
in_ppspec, out_ppspec = _get_ppspec_from_executable(executable, mesh)
out_partition_spec = _get_partition_spec(out_ppspec)
in_partition_spec = _get_partition_spec(in_ppspec)

View File

@ -29,7 +29,7 @@ import abc
import enum
import sys
from functools import partial
from typing import List, NamedTuple, Optional, Sequence, Tuple, Union
from typing import NamedTuple, Optional, Sequence, Union
from jax._src.config import config
try:
@ -95,7 +95,7 @@ class _TextDoc(Doc):
class _ConcatDoc(Doc):
__slots__ = ("children",)
children: List[Doc]
children: list[Doc]
def __init__(self, children: Sequence[Doc]):
self.children = list(children)
@ -164,7 +164,7 @@ _BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"])
# non-recursive formulation using an explicit stack, necessary because Python
# doesn't have a tail recursion optimization.
def _fits(doc: Doc, width: int, agenda: List[Tuple[int, _BreakMode, Doc]]
def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]]
) -> bool:
while width >= 0 and len(agenda) > 0:
i, m, doc = agenda.pop()
@ -234,11 +234,11 @@ class _State(NamedTuple):
class _Line(NamedTuple):
text: str
width: int
annotations: Union[Optional[str], List[str]]
annotations: Union[Optional[str], list[str]]
def _update_color(use_color: bool, state: _ColorState, update: _ColorState
) -> Tuple[_ColorState, str]:
) -> tuple[_ColorState, str]:
if not use_color or colorama is None:
return update, ""
color_str = ""

View File

@ -17,8 +17,8 @@ import abc
from functools import partial, reduce
import math
import operator as op
from typing import (Any, Callable, Hashable, Iterator, List, NamedTuple,
Set, Sequence, Tuple, Union)
from typing import (Any, Callable, Hashable, Iterator, NamedTuple,
Sequence, Union)
import numpy as np
@ -154,7 +154,7 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
@property
@abc.abstractmethod
def shape(self) -> Tuple[int, ...]: ...
def shape(self) -> tuple[int, ...]: ...
@property
@abc.abstractmethod
@ -211,7 +211,7 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
@abc.abstractmethod
def device(self) -> Device: ...
@abc.abstractmethod
def devices(self) -> Set[Device]: ...
def devices(self) -> set[Device]: ...
@abc.abstractmethod
def delete(self) -> None: ...
@abc.abstractmethod
@ -220,10 +220,10 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
def on_device_size_in_bytes(self) -> int: ...
@property
@abc.abstractmethod
def addressable_shards(self) -> List[Shard]: ...
def addressable_shards(self) -> list[Shard]: ...
@property
@abc.abstractmethod
def global_shards(self) -> List[Shard]: ...
def global_shards(self) -> list[Shard]: ...
@abc.abstractmethod
def addressable_data(self, index: int) -> PRNGKeyArray: ...
@ -305,7 +305,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))
@property
def addressable_shards(self) -> List[Shard]:
def addressable_shards(self) -> list[Shard]:
return [
type(s)(
device=s._device,
@ -317,7 +317,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
]
@property
def global_shards(self) -> List[Shard]:
def global_shards(self) -> list[Shard]:
return [
type(s)(
device=s._device,

View File

@ -16,7 +16,7 @@
from functools import partial
import math
from operator import index
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Union
import warnings
import numpy as np
@ -67,7 +67,7 @@ def _isnan(x: ArrayLike) -> Array:
return lax.ne(x, x)
def _check_prng_key(key) -> Tuple[prng.PRNGKeyArray, bool]:
def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
# TODO(frostig): remove once we always enable_custom_prng
if isinstance(key, prng.PRNGKeyArray):
return key, False

View File

@ -19,7 +19,7 @@ import numpy as np
import scipy.linalg
import textwrap
import warnings
from typing import cast, overload, Any, Literal, Optional, Tuple, Union
from typing import cast, overload, Any, Literal, Optional, Union
import jax
import jax.numpy as jnp
@ -56,7 +56,7 @@ def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
@_wraps(scipy.linalg.cho_factor,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> Tuple[Array, bool]:
check_finite: bool = True) -> tuple[Array, bool]:
del overwrite_a, check_finite # Unused
return (cholesky(a, lower=lower), lower)
@ -72,30 +72,30 @@ def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array:
@_wraps(scipy.linalg.cho_solve, update_doc=False,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite'))
def cho_solve(c_and_lower: Tuple[ArrayLike, bool], b: ArrayLike,
def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike,
overwrite_b: bool = False, check_finite: bool = True) -> Array:
del overwrite_b, check_finite # Unused
c, lower = c_and_lower
return _cho_solve(c, b, lower)
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[True]) -> Tuple[Array, Array, Array]: ...
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Array: ...
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, Tuple[Array, Array, Array]]: ...
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, tuple[Array, Array, Array]]: ...
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, Tuple[Array, Array, Array]]:
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, tuple[Array, Array, Array]]:
a, = promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: Literal[True] = True,
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> Tuple[Array, Array, Array]: ...
lapack_driver: str = 'gesdd') -> tuple[Array, Array, Array]: ...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
@ -110,13 +110,13 @@ def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> Union[Array, Tuple[Array, Array, Array]]: ...
lapack_driver: str = 'gesdd') -> Union[Array, tuple[Array, Array, Array]]: ...
@_wraps(scipy.linalg.svd,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
overwrite_a: bool = False, check_finite: bool = True,
lapack_driver: str = 'gesdd') -> Union[Array, Tuple[Array, Array, Array]]:
lapack_driver: str = 'gesdd') -> Union[Array, tuple[Array, Array, Array]]:
del overwrite_a, check_finite, lapack_driver # unused
return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
@ -133,15 +133,15 @@ def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: Liter
@overload
def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: Literal[False],
eigvals: None, type: int) -> Tuple[Array, Array]: ...
eigvals: None, type: int) -> tuple[Array, Array]: ...
@overload
def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool,
eigvals: None, type: int) -> Union[Array, Tuple[Array, Array]]: ...
eigvals: None, type: int) -> Union[Array, tuple[Array, Array]]: ...
@partial(jit, static_argnames=('lower', 'eigvals_only', 'eigvals', 'type'))
def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool,
eigvals: None, type: int) -> Union[Array, Tuple[Array, Array]]:
eigvals: None, type: int) -> Union[Array, tuple[Array, Array]]:
if b is not None:
raise NotImplementedError("Only the b=None case of eigh is implemented")
if type != 1:
@ -162,7 +162,7 @@ def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool,
def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True,
eigvals_only: Literal[False] = False, overwrite_a: bool = False,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> Tuple[Array, Array]: ...
type: int = 1, check_finite: bool = True) -> tuple[Array, Array]: ...
@overload
def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True, *,
@ -180,7 +180,7 @@ def eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool,
def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True,
eigvals_only: bool = False, overwrite_a: bool = False,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]: ...
type: int = 1, check_finite: bool = True) -> Union[Array, tuple[Array, Array]]: ...
@_wraps(scipy.linalg.eigh,
lax_description=_no_overwrite_and_chkfinite_doc,
@ -188,18 +188,18 @@ def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True,
def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True,
eigvals_only: bool = False, overwrite_a: bool = False,
overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
type: int = 1, check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]:
type: int = 1, check_finite: bool = True) -> Union[Array, tuple[Array, Array]]:
del overwrite_a, overwrite_b, turbo, check_finite # unused
return _eigh(a, b, lower, eigvals_only, eigvals, type)
@partial(jit, static_argnames=('output',))
def _schur(a: Array, output: str) -> Tuple[Array, Array]:
def _schur(a: Array, output: str) -> tuple[Array, Array]:
if output == "complex":
a = a.astype(dtypes.to_complex_dtype(a.dtype))
return lax_linalg.schur(a)
@_wraps(scipy.linalg.schur)
def schur(a: ArrayLike, output: str = 'real') -> Tuple[Array, Array]:
def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
if output not in ('real', 'complex'):
raise ValueError(
f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
@ -215,7 +215,7 @@ def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> A
@_wraps(scipy.linalg.lu_factor,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
@partial(jit, static_argnames=('overwrite_a', 'check_finite'))
def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Tuple[Array, Array]:
def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]:
del overwrite_a, check_finite # unused
a, = promote_dtypes_inexact(jnp.asarray(a))
lu, pivots, _ = lax_linalg.lu(a)
@ -225,7 +225,7 @@ def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True
@_wraps(scipy.linalg.lu_solve,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite'))
@partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite'))
def lu_solve(lu_and_piv: Tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
overwrite_b: bool = False, check_finite: bool = True) -> Array:
del overwrite_b, check_finite # unused
lu, pivots = lu_and_piv
@ -234,16 +234,16 @@ def lu_solve(lu_and_piv: Tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
return lax_linalg.lu_solve(lu, perm, b, trans)
@overload
def _lu(a: ArrayLike, permute_l: Literal[True]) -> Tuple[Array, Array]: ...
def _lu(a: ArrayLike, permute_l: Literal[True]) -> tuple[Array, Array]: ...
@overload
def _lu(a: ArrayLike, permute_l: Literal[False]) -> Tuple[Array, Array, Array]: ...
def _lu(a: ArrayLike, permute_l: Literal[False]) -> tuple[Array, Array, Array]: ...
@overload
def _lu(a: ArrayLike, permute_l: bool) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]: ...
def _lu(a: ArrayLike, permute_l: bool) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]: ...
@partial(jit, static_argnums=(1,))
def _lu(a: ArrayLike, permute_l: bool) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]:
def _lu(a: ArrayLike, permute_l: bool) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]:
a, = promote_dtypes_inexact(jnp.asarray(a))
lu, _, permutation = lax_linalg.lu(a)
dtype = lax.dtype(a)
@ -259,35 +259,35 @@ def _lu(a: ArrayLike, permute_l: bool) -> Union[Tuple[Array, Array], Tuple[Array
@overload
def lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False,
check_finite: bool = True) -> Tuple[Array, Array, Array]: ...
check_finite: bool = True) -> tuple[Array, Array, Array]: ...
@overload
def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False,
check_finite: bool = True) -> Tuple[Array, Array]: ...
check_finite: bool = True) -> tuple[Array, Array]: ...
@overload
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]: ...
check_finite: bool = True) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]: ...
@_wraps(scipy.linalg.lu, update_doc=False,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
@partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite'))
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]:
check_finite: bool = True) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]:
del overwrite_a, check_finite # unused
return _lu(a, permute_l)
@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: bool) -> Tuple[Array]: ...
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: bool) -> tuple[Array]: ...
@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: bool) -> Tuple[Array, Array]: ...
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: bool) -> tuple[Array, Array]: ...
@overload
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[Tuple[Array], Tuple[Array, Array]]: ...
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[tuple[Array], tuple[Array, Array]]: ...
@partial(jit, static_argnames=('mode', 'pivoting'))
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[Tuple[Array], Tuple[Array, Array]]:
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[tuple[Array], tuple[Array, Array]]:
if pivoting:
raise NotImplementedError(
"The pivoting=True case of qr is not implemented.")
@ -306,24 +306,24 @@ def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[Tuple[Array], Tuple[Ar
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: Literal["full", "economic"] = "full",
pivoting: bool = False, check_finite: bool = True) -> Tuple[Array, Array]: ...
pivoting: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool, lwork: Any, mode: Literal["r"],
pivoting: bool = False, check_finite: bool = True) -> Tuple[Array]: ...
pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["r"],
pivoting: bool = False, check_finite: bool = True) -> Tuple[Array]: ...
pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ...
@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True) -> Union[Tuple[Array], Tuple[Array, Array]]: ...
pivoting: bool = False, check_finite: bool = True) -> Union[tuple[Array], tuple[Array, Array]]: ...
@_wraps(scipy.linalg.qr,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lwork'))
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True) -> Union[Tuple[Array], Tuple[Array, Array]]:
pivoting: bool = False, check_finite: bool = True) -> Union[tuple[Array], tuple[Array, Array]]:
del overwrite_a, lwork, check_finite # unused
return _qr(a, mode, pivoting)
@ -458,7 +458,7 @@ def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 1
return R
@jit
def _calc_P_Q(A: ArrayLike) -> Tuple[Array, Array, Array]:
def _calc_P_Q(A: ArrayLike) -> tuple[Array, Array, Array]:
A = jnp.asarray(A)
if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError('expected A to be a square matrix')
@ -513,7 +513,7 @@ def _squaring(R: Array, n_squarings: Array, max_squarings: int) -> Array:
return res
def _pade3(A: Array) -> Tuple[Array, Array]:
def _pade3(A: Array) -> tuple[Array, Array]:
b = (120., 60., 12., 1.)
M, N = A.shape
ident = jnp.eye(M, N, dtype=A.dtype)
@ -522,7 +522,7 @@ def _pade3(A: Array) -> Tuple[Array, Array]:
V: Array = b[2]*A2 + b[0]*ident
return U, V
def _pade5(A: Array) -> Tuple[Array, Array]:
def _pade5(A: Array) -> tuple[Array, Array]:
b = (30240., 15120., 3360., 420., 30., 1.)
M, N = A.shape
ident = jnp.eye(M, N, dtype=A.dtype)
@ -532,7 +532,7 @@ def _pade5(A: Array) -> Tuple[Array, Array]:
V: Array = b[4]*A4 + b[2]*A2 + b[0]*ident
return U, V
def _pade7(A: Array) -> Tuple[Array, Array]:
def _pade7(A: Array) -> tuple[Array, Array]:
b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
M, N = A.shape
ident = jnp.eye(M, N, dtype=A.dtype)
@ -543,7 +543,7 @@ def _pade7(A: Array) -> Tuple[Array, Array]:
V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V
def _pade9(A: Array) -> Tuple[Array, Array]:
def _pade9(A: Array) -> tuple[Array, Array]:
b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
2162160., 110880., 3960., 90., 1.)
M, N = A.shape
@ -556,7 +556,7 @@ def _pade9(A: Array) -> Tuple[Array, Array]:
V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V
def _pade13(A: Array) -> Tuple[Array, Array]:
def _pade13(A: Array) -> tuple[Array, Array]:
b = (64764752532480000., 32382376266240000., 7771770303897600.,
1187353796428800., 129060195264000., 10559470521600., 670442572800.,
33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
@ -578,7 +578,7 @@ support the ``method='blockEnlarge'`` argument.
@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
compute_expm: Literal[True] = True) -> Tuple[Array, Array]: ...
compute_expm: Literal[True] = True) -> tuple[Array, Array]: ...
@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
@ -586,12 +586,12 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
compute_expm: bool = True) -> Union[Array, Tuple[Array, Array]]: ...
compute_expm: bool = True) -> Union[Array, tuple[Array, Array]]: ...
@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
@partial(jit, static_argnames=('method', 'compute_expm'))
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
compute_expm: bool = True) -> Union[Array, Tuple[Array, Array]]:
compute_expm: bool = True) -> Union[Array, tuple[Array, Array]]:
A = jnp.asarray(A)
E = jnp.asarray(E)
if A.ndim != 2 or A.shape[0] != A.shape[1]:
@ -617,8 +617,8 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
@jit
def block_diag(*arrs: ArrayLike) -> Array:
if len(arrs) == 0:
arrs = cast(Tuple[ArrayLike], (jnp.zeros((1, 0)),))
arrs = cast(Tuple[ArrayLike], promote_dtypes(*arrs))
arrs = cast(tuple[ArrayLike], (jnp.zeros((1, 0)),))
arrs = cast(tuple[ArrayLike], promote_dtypes(*arrs))
bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
if bad_shapes:
raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
@ -638,7 +638,7 @@ def block_diag(*arrs: ArrayLike) -> Array:
@_wraps(scipy.linalg.eigh_tridiagonal)
@partial(jit, static_argnames=("eigvals_only", "select", "select_range"))
def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
select: str = 'a', select_range: Optional[Tuple[float, float]] = None,
select: str = 'a', select_range: Optional[tuple[float, float]] = None,
tol: Optional[float] = None) -> Array:
if not eigvals_only:
raise NotImplementedError("Calculation of eigenvectors is not implemented")
@ -794,7 +794,7 @@ def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
@partial(jit, static_argnames=('side', 'method'))
@jax.default_matmul_precision("float32")
def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: Optional[float] = None,
max_iterations: Optional[int] = None) -> Tuple[Array, Array]:
max_iterations: Optional[int] = None) -> tuple[Array, Array]:
r"""Computes the polar decomposition.
Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar
@ -936,7 +936,7 @@ def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array:
@_wraps(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc)
@partial(jit, static_argnames=('check_finite',))
def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Array, Array]:
def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
del check_finite # unused
T = jnp.asarray(T)
@ -1001,12 +1001,12 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = Fals
@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
check_finite: bool = True) -> Tuple[Array, Array]: ...
check_finite: bool = True) -> tuple[Array, Array]: ...
@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]:
check_finite: bool = True) -> Union[Array, tuple[Array, Array]]:
del overwrite_a, check_finite
n = jnp.shape(a)[-1]
if n == 0:
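
A usage sketch of the factor/solve pairs retyped above, assuming Python 3.9+ and the public jax.scipy.linalg API:

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve, lu_factor, lu_solve

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

lu_piv = lu_factor(a)      # (LU, pivots) tuple
x = lu_solve(lu_piv, b)

c_low = cho_factor(a)      # (Cholesky factor, lower flag) tuple
y = cho_solve(c_low, b)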

View File

@ -17,7 +17,7 @@ import functools
import itertools
import operator
import textwrap
from typing import Callable, Dict, List, Sequence, Tuple
from typing import Callable, Sequence
import scipy.ndimage
@ -44,7 +44,7 @@ def _mirror_index_fixer(index: Array, size: int) -> Array:
def _reflect_index_fixer(index: Array, size: int) -> Array:
return jnp.floor_divide(_mirror_index_fixer(2*index+1, 2*size+1) - 1, 2)
_INDEX_FIXERS: Dict[str, Callable[[Array, int], Array]] = {
_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = {
'constant': lambda index, size: index,
'nearest': lambda index, size: jnp.clip(index, 0, size - 1),
'wrap': lambda index, size: index % size,
@ -57,13 +57,13 @@ def _round_half_away_from_zero(a: Array) -> Array:
return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)
def _nearest_indices_and_weights(coordinate: Array) -> List[Tuple[Array, ArrayLike]]:
def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
index = _round_half_away_from_zero(coordinate).astype(jnp.int32)
weight = coordinate.dtype.type(1)
return [(index, weight)]
def _linear_indices_and_weights(coordinate: Array) -> List[Tuple[Array, ArrayLike]]:
def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
lower = jnp.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
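
These interpolation helpers feed the public map_coordinates wrapper. A small sketch, assuming Python 3.9+:

import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.arange(16.0).reshape(4, 4)
rows = jnp.array([0.5, 1.5])
cols = jnp.array([0.5, 2.0])
# order=1 uses the linear indices-and-weights path; order=0 the nearest one.
vals = map_coordinates(image, [rows, cols], order=1)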

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Mapping, Optional, Union
import jax
from jax._src.scipy.optimize.bfgs import minimize_bfgs
@ -50,7 +50,7 @@ class OptimizeResults(NamedTuple):
def minimize(
fun: Callable,
x0: jax.Array,
args: Tuple = (),
args: tuple = (),
*,
method: str,
tol: Optional[float] = None,
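
A usage sketch of the public minimize entry point whose args parameter is retyped above as a builtin tuple, assuming Python 3.9+:

import jax.numpy as jnp
from jax.scipy.optimize import minimize

def quadratic(x, offset):
    return jnp.sum((x - offset) ** 2)

res = minimize(quadratic, jnp.zeros(3), args=(jnp.ones(3),), method="BFGS")
# res.x approaches the offset; res.success reports convergence.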

View File

@ -15,7 +15,7 @@
from functools import partial
import math
import operator
from typing import Callable, Optional, Tuple, Union, Sequence
from typing import Callable, Optional, Union, Sequence
import warnings
import numpy as np
@ -289,7 +289,7 @@ def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0,
detrend_type: Union[bool, str, Callable[[Array], Array]] = 'constant',
return_onesided: bool = True, scaling: str = 'density',
axis: int = -1, mode: str = 'psd', boundary: Optional[str] = None,
padded: bool = False) -> Tuple[Array, Array, Array]:
padded: bool = False) -> tuple[Array, Array, Array]:
"""LAX-backend implementation of `scipy.signal._spectral_helper`.
Unlike the original helper function, `y` can be None for explicitly
@ -500,7 +500,7 @@ def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0,
def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256,
noverlap: Optional[int] = None, nfft: Optional[int] = None,
detrend: bool = False, return_onesided: bool = True, boundary: Optional[str] = 'zeros',
padded: bool = True, axis: int = -1) -> Tuple[Array, Array, Array]:
padded: bool = True, axis: int = -1) -> tuple[Array, Array, Array]:
return _spectral_helper(x, None, fs, window, nperseg, noverlap,
nfft, detrend, return_onesided,
scaling='spectrum', axis=axis,
@ -520,7 +520,7 @@ def csd(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0, window: str = 'ha
nperseg: Optional[int] = None, noverlap: Optional[int] = None,
nfft: Optional[int] = None, detrend: str = 'constant',
return_onesided: bool = True, scaling: str = 'density',
axis: int = -1, average: str = 'mean') -> Tuple[Array, Array]:
axis: int = -1, average: str = 'mean') -> tuple[Array, Array]:
freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft,
detrend, return_onesided, scaling, axis,
mode='psd')
@ -553,7 +553,7 @@ def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann',
nperseg: Optional[int] = None, noverlap: Optional[int] = None,
nfft: Optional[int] = None, detrend: str = 'constant',
return_onesided: bool = True, scaling: str = 'density',
axis: int = -1, average: str = 'mean') -> Tuple[Array, Array]:
axis: int = -1, average: str = 'mean') -> tuple[Array, Array]:
freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg,
noverlap=noverlap, nfft=nfft, detrend=detrend,
return_onesided=return_onesided, scaling=scaling,
@ -615,7 +615,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
nperseg: Optional[int] = None, noverlap: Optional[int] = None,
nfft: Optional[int] = None, input_onesided: bool = True,
boundary: bool = True, time_axis: int = -1,
freq_axis: int = -2) -> Tuple[Array, Array]:
freq_axis: int = -2) -> tuple[Array, Array]:
# Input validation
check_arraylike("istft", Zxx)
if Zxx.ndim < 2:

View File

@ -14,7 +14,7 @@
from functools import partial
import operator
from typing import cast, Any, List, Optional, Tuple
from typing import cast, Any, Optional
import numpy as np
import scipy.special as osp_special
@ -765,7 +765,7 @@ def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array:
def _gen_recurrence_mask(
l_max: int, is_normalized: bool, dtype: Any
) -> Tuple[Array, Array]:
) -> tuple[Array, Array]:
"""Generates mask for recurrence relation on the remaining entries.
The remaining entries are with respect to the diagonal and offdiagonal
@ -1012,7 +1012,7 @@ def _gen_associated_legendre(l_max: int,
return p
def lpmn(m: int, n: int, z: Array) -> Tuple[Array, Array]:
def lpmn(m: int, n: int, z: Array) -> tuple[Array, Array]:
"""The associated Legendre functions (ALFs) of the first kind.
Args:
@ -1215,7 +1215,7 @@ def _expint1(x: Array) -> Array:
return x * f + jnp.euler_gamma + jnp.log(x)
def _eval_expint_k(A: List[float], B: List[float], x: Array) -> Array:
def _eval_expint_k(A: list[float], B: list[float], x: Array) -> Array:
# helper function for all subsequent intervals
A_arr = jnp.array(A, dtype=x.dtype)
B_arr = jnp.array(B, dtype=x.dtype)

View File

@ -15,7 +15,7 @@
from collections import namedtuple
from functools import partial
import math
from typing import Optional, Tuple
from typing import Optional
import jax
import jax.numpy as jnp
@ -70,7 +70,7 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
axis = 0
x = x.ravel()
def _mode_helper(x: jax.Array) -> Tuple[jax.Array, jax.Array]:
def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Helper function to return mode and count of a given array."""
if x.size == 0:
return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))

View File

@ -15,15 +15,15 @@
from __future__ import annotations
import functools
from typing import (Mapping, Optional, Sequence, Set, Tuple)
from typing import Mapping, Optional, Sequence
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
Shape = Tuple[int, ...]
Shape = tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
Index = tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]
@ -44,7 +44,7 @@ class Sharding:
# Abstract methods below that subclasses should implement.
@property
def device_set(self) -> Set[Device]:
def device_set(self) -> set[Device]:
"""A ``set`` of global devices that this ``Sharding`` spans.
In multi-controller JAX, the set of devices is global, i.e., includes
@ -89,7 +89,7 @@ class Sharding:
# Default implementations below that all subclasses will inherit.
@functools.cached_property
def addressable_devices(self) -> Set[Device]:
def addressable_devices(self) -> set[Device]:
"""A set of devices that are addressable by the current process."""
# Add a fast path for single controller runtimes.
if xb.process_count() == 1:

View File

@ -22,8 +22,8 @@ import itertools
import math
import operator as op
import sys
from typing import (Any, Dict, FrozenSet, List, Mapping, Optional, OrderedDict,
NamedTuple, Sequence, Set, Tuple, Union, cast)
from typing import (Any, Mapping, Optional, OrderedDict, NamedTuple, Sequence,
Union, cast)
from jax._src import mesh as mesh_lib
from jax._src.op_shardings import (
@ -44,13 +44,13 @@ import numpy as np
if sys.version_info >= (3, 9):
OrderedDictType = OrderedDict
else:
OrderedDictType = Dict
OrderedDictType = dict
Shape = Tuple[int, ...]
Shape = tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
XLADeviceAssignment = Tuple[Device, ...]
Index = tuple[slice, ...]
XLADeviceAssignment = tuple[Device, ...]
# Shardings that inherit from XLACompatibleSharding should implement the
@ -146,7 +146,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
'create a device to index mapping for your sharding from which replica '
'ids will be calculated.') from None
index_to_replica: Dict[int, int] = collections.Counter()
index_to_replica: dict[int, int] = collections.Counter()
out = {}
for device, index in device_indices_map_fn(global_shape).items():
h_index = hashed_index(index)
@ -255,7 +255,7 @@ class NamedSharding(XLACompatibleSharding):
return cls(mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
@property
def device_set(self) -> Set[Device]:
def device_set(self) -> set[Device]:
return self.mesh._flat_devices_set
@property
@ -269,7 +269,7 @@ class NamedSharding(XLACompatibleSharding):
return not self.mesh.is_multi_process
@property
def addressable_devices(self) -> Set[Device]:
def addressable_devices(self) -> set[Device]:
# Override addressable devices because there is a high chance that the mesh
# across multiple NamedSharding objects will be the same.
return self.mesh._local_devices_set
@ -355,7 +355,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
return self._device == other._device
@property
def device_set(self) -> Set[Device]:
def device_set(self) -> set[Device]:
return {self._device}
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: # type: ignore
@ -453,7 +453,7 @@ class PmapSharding(XLACompatibleSharding):
return cls(pmap_devices, sharding_spec)
@functools.cached_property
def device_set(self) -> Set[Device]:
def device_set(self) -> set[Device]:
return set(self.devices.flat)
@functools.lru_cache(maxsize=4096)
@ -524,7 +524,7 @@ def _op_sharding_to_pos_sharding(
class PositionalSharding(XLACompatibleSharding):
_devices: Tuple[xc.Device, ...]
_devices: tuple[xc.Device, ...]
_ids: np.ndarray # dtype DeviceIdSet
def __init__(self, devices: Union[Sequence[xc.Device], np.ndarray]):
@ -564,7 +564,7 @@ class PositionalSharding(XLACompatibleSharding):
@classmethod
def remake(
cls, devices: Tuple[xc.Device, ...], ids: np.ndarray) -> PositionalSharding:
cls, devices: tuple[xc.Device, ...], ids: np.ndarray) -> PositionalSharding:
self = cls.__new__(cls)
self._devices = devices
self._ids = ids
@ -626,7 +626,7 @@ class PositionalSharding(XLACompatibleSharding):
class DeviceIdSet:
_name: str
_ids: FrozenSet[int]
_ids: frozenset[int]
def __init__(self, name, *ids):
self._name = name
self._ids = frozenset(ids)
@ -655,7 +655,7 @@ class DeviceIdSet:
@use_cpp_class(xc.GSPMDSharding)
class GSPMDSharding(XLACompatibleSharding):
_devices: Tuple[Device, ...]
_devices: tuple[Device, ...]
_hlo_sharding: xc.HloSharding
@use_cpp_method()
@ -706,7 +706,7 @@ class GSPMDSharding(XLACompatibleSharding):
f"{len(aval_shape)}")
@functools.cached_property
def device_set(self) -> Set[Device]:
def device_set(self) -> set[Device]:
return set(self._devices)
@functools.lru_cache(maxsize=4096)
@ -997,8 +997,8 @@ def _check_unique_resources(axis_resources, arg_name):
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: Tuple[Any, ...]
sizes: Tuple[int, ...]
names: tuple[Any, ...]
sizes: tuple[int, ...]
@dataclasses.dataclass(frozen=True)
@ -1010,7 +1010,7 @@ class SPMDAxisContext:
is invoked inside an xmap) lowered in the MANUAL sharding mode.
"""
mesh: mesh_lib.Mesh
manual_axes: FrozenSet[MeshAxisName] = frozenset()
manual_axes: frozenset[MeshAxisName] = frozenset()
@property
def axis_env(self):
@ -1030,7 +1030,7 @@ class SPMDAxisContext:
names=self.mesh.axis_names,
sizes=tuple(self.mesh.shape.values()))
def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext:
def extend_manual(self, axes: frozenset[MeshAxisName]) -> SPMDAxisContext:
return SPMDAxisContext(self.mesh, self.manual_axes | axes)
@ -1162,7 +1162,7 @@ def parse_flatten_op_sharding(op_sharding: Union[xc.OpSharding, xc.HloSharding],
if isinstance(op_sharding, xc.HloSharding):
op_sharding = op_sharding.to_proto() # type: ignore
if op_sharding.type == xc.OpSharding.Type.TUPLE:
out: List[ParsedPartitionSpec] = []
out: list[ParsedPartitionSpec] = []
for s in op_sharding.tuple_shardings:
out.extend(parse_flatten_op_sharding(s, mesh))
return out

View File

@ -31,7 +31,7 @@ import collections
import functools
import itertools
import math
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
from typing import Any, Mapping, Optional, Sequence, Union, cast
import numpy as np
@ -91,7 +91,7 @@ def sharding_spec_sharding_proto(
the code here might help:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
"""
mesh_shape = cast(Tuple[int, ...], self.mesh_shape)
mesh_shape = cast(tuple[int, ...], self.mesh_shape)
sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped
replicated_maxes = [] # lists mesh axis identifiers to replicate over
@ -128,8 +128,8 @@ def sharding_spec_sharding_proto(
# specially over some mesh axes.
last_tile_dims = []
if replicated_maxes:
axes_by_type: Dict[OpShardingType, List[_MeshAxisName]] = {}
size_by_type: Dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
axes_by_type: dict[OpShardingType, list[_MeshAxisName]] = {}
size_by_type: dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
for axis, size in replicated_maxes:
ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
@ -145,7 +145,7 @@ def sharding_spec_sharding_proto(
transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)
def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
"""Returns NumPy-style indices corresponding to a sharding spec.
Args:
@ -167,7 +167,7 @@ def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
hlo_sharding, shape, math.prod(self.mesh_shape)
).reshape(self.mesh_shape)
axis_indices: List[Sequence[Index]] = []
axis_indices: list[Sequence[Index]] = []
shard_indices_shape = []
for dim, sharding in enumerate(self.sharding):
axis_size = shape[dim]
@ -223,10 +223,10 @@ ShardingSpec.indices = _sharding_spec_indices
ShardingSpec.__repr__ = _sharding_spec_repr # type: ignore
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
Index = Union[int, slice, tuple[Union[int, slice], ...]]
def spec_to_indices(shape: Sequence[int],
spec: ShardingSpec) -> Tuple[Index, ...]:
spec: ShardingSpec) -> tuple[Index, ...]:
"""Returns numpy-style indices corresponding to a sharding spec.
Each index describes a shard of the array. The order of the indices is the
@ -306,7 +306,7 @@ def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
def create_pmap_sharding_spec(shape: tuple[int, ...], sharded_dim: int = 0,
sharded_dim_size: Optional[int] = None):
if sharded_dim is not None:
sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:]

View File

@ -20,7 +20,7 @@ import os.path
import sysconfig
import threading
import types
from typing import List, Optional, Iterator, NamedTuple, Union, Tuple
from typing import Optional, Iterator, NamedTuple, Union
import jax.version
from jax._src.lib import xla_client
@ -40,7 +40,7 @@ class Frame(NamedTuple):
end_column: int
_exclude_paths: List[str] = [
_exclude_paths: list[str] = [
os.path.dirname(jax.version.__file__),
# Also exclude stdlib as user frames. In a non-standard Python runtime,
# the following two may be different.
@ -54,13 +54,13 @@ def register_exclusion(path: str):
class Scope(NamedTuple):
name: str
def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]:
return (self.name, *stack)
class Transform(NamedTuple):
name: str
def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]:
if stack:
return (f'{self.name}({stack[0]})', *stack[1:])
else:
@ -68,9 +68,9 @@ class Transform(NamedTuple):
@dataclasses.dataclass(frozen=True)
class NameStack:
stack: Tuple[Union[Scope, Transform], ...] = ()
stack: tuple[Union[Scope, Transform], ...] = ()
def extend(self, name: Union[Tuple[str, ...], str]) -> 'NameStack':
def extend(self, name: Union[tuple[str, ...], str]) -> 'NameStack':
if not isinstance(name, tuple):
name = (name,)
scopes = tuple(map(Scope, name))
@ -97,7 +97,7 @@ class NameStack:
return NameStack(other.stack + self.stack)
def __str__(self) -> str:
scope: Tuple[str, ...] = ()
scope: tuple[str, ...] = ()
for elem in self.stack[::-1]:
scope = elem.wrap(scope)
return '/'.join(scope)

View File

@ -33,8 +33,7 @@ from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import (
Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple, Union)
from typing import Any, NamedTuple, Optional, Protocol, Sequence, Union
import jax
@ -56,7 +55,7 @@ xla_extension = xc._xla
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
CompilerOptions = Dict[str, Union[str, bool]]
CompilerOptions = dict[str, Union[str, bool]]
# -- Internal protocols
@ -232,9 +231,9 @@ class XlaExecutable(Executable):
else:
raise
# TODO(skyewm): this should return a single Dict (I think returning a list
# TODO(skyewm): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed)
def cost_analysis(self) -> List[Dict[str, float]]:
def cost_analysis(self) -> list[dict[str, float]]:
xla_ext_exe = self.xla_extension_executable()
# TODO(b/259255524): Unify/merge the two cost_analysis calls below.
@ -284,7 +283,7 @@ class XlaExecutable(Executable):
class XlaLowering(Lowering):
"""Adapts our various internal XLA-backed computations into a ``Lowering``."""
compile_args: Dict[str, Any]
compile_args: dict[str, Any]
def hlo(self) -> xc.XlaComputation:
"""Return an HLO representation of this computation."""
@ -331,7 +330,7 @@ class XlaLowering(Lowering):
else:
raise ValueError(f"unknown dialect: {dialect}")
def cost_analysis(self) -> Dict[str, float]:
def cost_analysis(self) -> dict[str, float]:
raise NotImplementedError("must override")
@ -578,7 +577,7 @@ class Lowered(Stage):
lowering: XlaLowering,
in_tree: tree_util.PyTreeDef,
in_avals,
donate_argnums: Tuple[int, ...],
donate_argnums: tuple[int, ...],
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False):
"""Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.
@ -599,7 +598,7 @@ class Lowered(Stage):
def compile(
self, compiler_options: Optional[CompilerOptions] = None) -> Compiled:
"""Compile, returning a corresponding ``Compiled`` instance."""
kw: Dict[str, Any] = {"compiler_options": compiler_options}
kw: dict[str, Any] = {"compiler_options": compiler_options}
return Compiled(
self._lowering.compile(**kw), # pytype: disable=wrong-keyword-args
self.args_info,

View File

@ -17,7 +17,7 @@ import dataclasses
from functools import partial
import operator
from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Protocol, Sequence, Union
import numpy as np
@ -55,7 +55,7 @@ PyTreeDef = tree_util.PyTreeDef
def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
should_discharge: Union[bool, Sequence[bool]] = True
) -> Tuple[core.Jaxpr, List[Any]]:
) -> tuple[core.Jaxpr, list[Any]]:
"""Converts a jaxpr that takes in `Ref`s into one that doesn't."""
if isinstance(should_discharge, bool):
should_discharge = [should_discharge] * len(jaxpr.invars)
@ -69,7 +69,7 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
@dataclasses.dataclass
class Environment:
env: Dict[core.Var, Any]
env: dict[core.Var, Any]
def read(self, v: core.Atom) -> Any:
if type(v) is core.Literal:
@ -84,7 +84,7 @@ class DischargeRule(Protocol):
def __call__(self, in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], *args: Any,
**params: Any) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
**params: Any) -> tuple[Sequence[Optional[Any]], Sequence[Any]]:
...
_discharge_rules: dict[core.Primitive, DischargeRule] = {}
@ -327,7 +327,7 @@ ad.primitive_jvps[run_state_p] = _run_state_jvp
_save_everything = lambda *_, **__: True
def _convert_outputs_to_writes(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
jaxpr: core.Jaxpr) -> tuple[core.Jaxpr, list[core.ShapedArray]]:
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
in_avals = [v.aval for v in jaxpr.invars]
@ -677,7 +677,7 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
def initial_style_jaxpr(
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))
@weakref_lru_cache

View File

@ -14,7 +14,7 @@
"""Module for state primitives."""
from functools import partial
from typing import Any, List, Tuple, Union
from typing import Any, Union
import numpy as np
@ -55,7 +55,7 @@ def _get_impl(ref: AbstractRef, *idx: int, **_):
raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl)
Indexer = Tuple[Union[int, slice, Array], ...]
Indexer = tuple[Union[int, slice, Array], ...]
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
def _is_trivial_indexer(idx: Indexer) -> bool:
@ -68,7 +68,7 @@ def _is_trivial_indexer(idx: Indexer) -> bool:
return False
def _unpack_idx(idx: Indexer, ndim: int
) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]:
) -> tuple[tuple[Array, ...], tuple[bool, ...]]:
if _is_trivial_indexer(idx):
idx = tuple(slice(None) for _ in range(ndim))
indexed_dims_ = []
@ -85,9 +85,9 @@ def _unpack_idx(idx: Indexer, ndim: int
import jax.numpy as jnp
return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))
def _get_slice_output_shape(in_shape: Tuple[int, ...],
idx_shapes: Tuple[Tuple[int, ...], ...],
indexed_dims: Tuple[bool, ...]) -> Tuple[int, ...]:
def _get_slice_output_shape(in_shape: tuple[int, ...],
idx_shapes: tuple[tuple[int, ...], ...],
indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
shape_suffix = [d for i, d in zip(indexed_dims, in_shape) if not i]
shape_prefix, = set(idx_shapes) or [()] # tie fighter
# Move shape prefix dimensions to the front
@ -95,7 +95,7 @@ def _get_slice_output_shape(in_shape: Tuple[int, ...],
return shape
def _get_indexer(ref: AbstractRef, idx: Indexer
) -> Tuple[Indexer, Tuple[bool, ...]]:
) -> tuple[Indexer, tuple[bool, ...]]:
if isinstance(ref.inner_aval, core.ShapedArray):
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
else:
@ -201,7 +201,7 @@ get_p.def_effectful_abstract_eval(_get_abstract_eval)
def _swap_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*idx: core.ShapedArray, indexed_dims: Tuple[bool]):
*idx: core.ShapedArray, indexed_dims: tuple[bool]):
out_aval: core.AbstractValue
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
@ -236,7 +236,7 @@ swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _addupdate_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*idx: core.ShapedArray, indexed_dims: Tuple[bool]):
*idx: core.ShapedArray, indexed_dims: tuple[bool]):
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
@ -327,7 +327,7 @@ core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
## get/swap/addupdate JVP rules
def _get_jvp(primals: List[Any], tangents: List[Any], **params: Any):
def _get_jvp(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, *_ = tangents
@ -336,7 +336,7 @@ def _get_jvp(primals: List[Any], tangents: List[Any], **params: Any):
get_p.bind(ref_tangent, *idx, **params)) # type: ignore[arg-type]
ad.primitive_jvps[get_p] = _get_jvp
def _swap_jvp(primals: List[Any], tangents: List[Any], **params: Any):
def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, x_tangent, *_ = tangents
@ -346,7 +346,7 @@ def _swap_jvp(primals: List[Any], tangents: List[Any], **params: Any):
swap_p.bind(ref_tangent, x_tangent, *idx, **params)) # type: ignore[arg-type]
ad.primitive_jvps[swap_p] = _swap_jvp
def addupdate_jvp_rule(primals: List[Any], tangents: List[Any], **params: Any):
def addupdate_jvp_rule(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, x_primal, *idx = primals
ref_tangent, x_tangent, *_ = tangents
x_tangent = ad_util.instantiate(x_tangent)
@ -397,8 +397,8 @@ pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
## get/swap/addupdate batching rules
def _output_bdim(indexed_dims: Tuple[bool, ...], ref_dim: int,
idxs_shape: Tuple[int, ...]):
def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int,
idxs_shape: tuple[int, ...]):
num_idxs_to_left = sum(indexed_dims[:ref_dim])
return ref_dim - num_idxs_to_left + len(idxs_shape)

View File

@ -15,7 +15,7 @@
from __future__ import annotations
import math
from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
from typing import Any, Generic, Sequence, TypeVar, Union
from jax._src import core
from jax._src import effects
@ -150,12 +150,12 @@ core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)
def get_ref_state_effects(
avals: Sequence[core.AbstractValue],
effects: core.Effects) -> List[Set[StateEffect]]:
effects: core.Effects) -> list[set[StateEffect]]:
return [{eff for eff in effects
if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
and eff.input_index == i} for i, _ in enumerate(avals)]
def shaped_array_ref(shape: Tuple[int, ...], dtype,
def shaped_array_ref(shape: tuple[int, ...], dtype,
weak_type: bool = False,
named_shape = None) -> AbstractRef[core.AbstractValue]:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,

View File

@ -22,7 +22,7 @@ import re
import os
import tempfile
import textwrap
from typing import Any, Callable, Dict, List, Generator, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Generator, Optional, Sequence, Union
import unittest
import warnings
import zlib
@ -1057,7 +1057,7 @@ def ignore_warning(**kw):
# -------------------- Mesh parametrization helpers --------------------
MeshSpec = List[Tuple[str, int]]
MeshSpec = list[tuple[str, int]]
@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
@ -1199,7 +1199,7 @@ def strict_promotion_if_dtypes_match(dtypes):
return jax.numpy_dtype_promotion('standard')
_version_regex = re.compile(r"([0-9]+(?:\.[0-9]+)*)(?:(rc|dev).*)?")
def _parse_version(v: str) -> Tuple[int, ...]:
def _parse_version(v: str) -> tuple[int, ...]:
m = _version_regex.match(v)
if m is None:
raise ValueError(f"Unable to parse version '{v}'")
@ -1209,8 +1209,8 @@ def numpy_version():
return _parse_version(np.__version__)
def parameterized_filterable(*,
kwargs: Sequence[Dict[str, Any]],
testcase_name: Optional[Callable[[Dict[str, Any]], str]] = None,
kwargs: Sequence[dict[str, Any]],
testcase_name: Optional[Callable[[dict[str, Any]], str]] = None,
one_containing: Optional[str] = None,
):
"""
@ -1229,7 +1229,7 @@ def parameterized_filterable(*,
only one `kwargs` whose `testcase_name` includes `one_containing`.
"""
# Ensure that all kwargs contain a testcase_name
kwargs_with_testcase_name: Sequence[Dict[str, Any]]
kwargs_with_testcase_name: Sequence[dict[str, Any]]
if testcase_name is not None:
kwargs_with_testcase_name = [dict(testcase_name=testcase_name(kw), **kw)
for kw in kwargs]

View File

@ -1,4 +1,4 @@
from typing import Callable, Tuple
from typing import Callable
import scipy.linalg
@ -11,7 +11,7 @@ from jax._src.typing import ArrayLike, Array
@jit
def _algorithm_11_1_1(F: Array, T: Array) -> Tuple[Array, Array]:
def _algorithm_11_1_1(F: Array, T: Array) -> tuple[Array, Array]:
# Algorithm 11.1.1 from Golub and Van Loan "Matrix Computations"
N = T.shape[0]
minden = jnp.abs(T[0, 0])
@ -50,7 +50,7 @@ will be printed if the error in the array output is estimated to be large.
"""
@_wraps(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION)
def funm(A: ArrayLike, func: Callable[[Array], Array], disp: bool = True) -> Tuple[Array, Array]:
def funm(A: ArrayLike, func: Callable[[Array], Array], disp: bool = True) -> tuple[Array, Array]:
A = jnp.asarray(A)
if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError('expected square array_like input')

View File

@ -1,15 +1,15 @@
"""Utility functions adopted from scipy.signal."""
import scipy.signal as osp_signal
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Union
import warnings
import jax.numpy as jnp
from jax._src.typing import Array, ArrayLike, DTypeLike
def _triage_segments(window: Union[ArrayLike, str, Tuple[Any, ...]], nperseg: Optional[int],
input_length: int, dtype: DTypeLike) -> Tuple[Array, int]:
def _triage_segments(window: Union[ArrayLike, str, tuple[Any, ...]], nperseg: Optional[int],
input_length: int, dtype: DTypeLike) -> tuple[Array, int]:
"""
Parses window and nperseg arguments for spectrogram and _spectral_helper.
This is a helper function, not meant to be called externally.

View File

@ -16,7 +16,7 @@ import functools
import os
import traceback
import types
from typing import Any, Callable, List, Optional, TypeVar, cast
from typing import Any, Callable, Optional, TypeVar, cast
from jax._src.config import config
from jax._src.lib import xla_extension
@ -25,7 +25,7 @@ from jax._src import util
C = TypeVar("C", bound=Callable[..., Any])
_exclude_paths: List[str] = [__file__, util.__file__]
_exclude_paths: list[str] = [__file__, util.__file__]
def register_exclusion(path: str):
_exclude_paths.append(path)

View File

@ -20,8 +20,8 @@ import functools
from functools import partial
import operator as op
import textwrap
from typing import (Any, Callable, Hashable, Iterable, List, NamedTuple,
Optional, Tuple, Type, TypeVar, Union, overload)
from typing import (Any, Callable, Hashable, Iterable, NamedTuple,
Optional, TypeVar, Union, overload)
import warnings
from jax._src import traceback_util
@ -33,7 +33,7 @@ from jax._src.util import unzip2
traceback_util.register_exclusion(__file__)
T = TypeVar("T")
U = TypeVar("U", bound=Type[Any])
U = TypeVar("U", bound=type[Any])
Leaf = Any
PyTreeDef = pytree.PyTreeDef
@ -41,7 +41,7 @@ PyTreeDef = pytree.PyTreeDef
def tree_flatten(tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None
) -> Tuple[List[Leaf], PyTreeDef]:
) -> tuple[list[Leaf], PyTreeDef]:
"""Flattens a pytree.
The flattening order (i.e. the order of elements in the output list)
@ -79,7 +79,7 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
def tree_leaves(tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Leaf]:
) -> list[Leaf]:
"""Gets the leaves of a pytree."""
return pytree.flatten(tree, is_leaf)[0]
@ -92,7 +92,7 @@ def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
"""Makes a tuple treedef from an iterable of child treedefs."""
return pytree.tuple(list(treedefs))
def treedef_children(treedef: PyTreeDef) -> List[PyTreeDef]:
def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
return treedef.children()
def treedef_is_leaf(treedef: PyTreeDef) -> bool:
@ -129,8 +129,8 @@ def all_leaves(iterable: Iterable[Any],
_Children = TypeVar("_Children", bound=Iterable[Any])
_AuxData = TypeVar("_AuxData", bound=Hashable)
def register_pytree_node(nodetype: Type[T],
flatten_func: Callable[[T], Tuple[_Children, _AuxData]],
def register_pytree_node(nodetype: type[T],
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T]):
"""Extends the set of types that are considered internal nodes in pytrees.
@ -393,7 +393,7 @@ register_pytree_node(
def broadcast_prefix(prefix_tree: Any, full_tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Any]:
) -> list[Any]:
# If prefix_tree is not a tree prefix of full_tree, this code can raise a
# ValueError; use prefix_errors to find disagreements and raise more precise
# error messages.
@ -403,7 +403,7 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
return result
def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]:
def flatten_one_level(pytree: Any) -> tuple[list[Any], Hashable]:
"""Flatten the given pytree node by one level.
Args:
@ -429,12 +429,12 @@ def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]:
def prefix_errors(prefix_tree: Any, full_tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> List[Callable[[str], ValueError]]:
) -> list[Callable[[str], ValueError]]:
return list(_prefix_error((), prefix_tree, full_tree, is_leaf))
def equality_errors(
tree1: Any, tree2: Any, is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Iterable[Tuple[KeyPath, str, str, str]]:
) -> Iterable[tuple[KeyPath, str, str, str]]:
"""Helper to describe structural differences between two pytrees.
Args:
@ -562,7 +562,7 @@ class FlattenedIndexKey():
BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
KeyPath = Tuple[KeyEntry, ...]
KeyPath = tuple[KeyEntry, ...]
def keystr(keys: KeyPath):
"""Helper to pretty-print a tuple of keys.
@ -582,7 +582,7 @@ class _RegistryWithKeypathsEntry(NamedTuple):
def register_keypaths(
ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]]
ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]]
) -> None:
"""[Deprecated] Register the method to get keypaths for type.
@ -603,7 +603,7 @@ def register_keypaths(
def _register_keypaths(
ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]]
ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]]
) -> None:
def flatten_with_keys(xs):
children, treedef = _registry[ty].to_iter(xs)
@ -634,13 +634,13 @@ _register_keypaths(
def register_pytree_with_keys(
nodetype: Type[T],
nodetype: type[T],
flatten_with_keys: Callable[
[T], Tuple[Iterable[Tuple[KeyEntry, Any]], _AuxData]
[T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData]
],
unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
flatten_func: Optional[
Callable[[T], Tuple[Iterable[Any], _AuxData]]
Callable[[T], tuple[Iterable[Any], _AuxData]]
] = None,
):
"""Extends the set of types that are considered internal nodes in pytrees.
@ -708,7 +708,7 @@ def register_pytree_with_keys_class(cls: U) -> U:
def tree_flatten_with_path(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> Tuple[List[Tuple[KeyPath, Any]], PyTreeDef]:
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
Args:
@ -725,7 +725,7 @@ def tree_flatten_with_path(
def tree_leaves_with_path(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[KeyPath, Any]]:
) -> list[tuple[KeyPath, Any]]:
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
Args:
@ -739,7 +739,7 @@ def tree_leaves_with_path(
def generate_key_paths(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[KeyPath, Any]]:
) -> list[tuple[KeyPath, Any]]:
return list(_generate_key_paths_((), tree, is_leaf))
_generate_key_paths = generate_key_paths # alias for backward compat
@ -749,7 +749,7 @@ def _generate_key_paths_(
key_path: KeyPath,
tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Iterable[Tuple[KeyPath, Any]]:
) -> Iterable[tuple[KeyPath, Any]]:
if is_leaf and is_leaf(tree):
yield key_path, tree
return

View File

@ -17,9 +17,8 @@ from functools import partial
import itertools as it
import logging
import operator
from typing import (Any, Callable, Generic, Iterable, Iterator, List,
Optional, Sequence, Set, Tuple, TypeVar, overload,
TYPE_CHECKING, cast)
from typing import (Any, Callable, Generic, Iterable, Iterator, Optional,
Sequence, TypeVar, overload, TYPE_CHECKING, cast)
import weakref
import numpy as np
@ -43,13 +42,13 @@ if TYPE_CHECKING:
# to that used for builtins.zip in python/typeshed. This supports
# return types matching input types for up to three arguments.
@overload
def safe_zip(__arg1: Iterable[T1]) -> List[Tuple[T1]]: ...
def safe_zip(__arg1: Iterable[T1]) -> list[tuple[T1]]: ...
@overload
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[Tuple[T1, T2]]: ...
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[tuple[T1, T2]]: ...
@overload
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[Tuple[T1, T2, T3]]: ...
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[tuple[T1, T2, T3]]: ...
@overload
def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[Tuple[Any, ...]]: ...
def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[tuple[Any, ...]]: ...
def safe_zip(*args):
args = list(map(list, args))
@ -77,16 +76,16 @@ if TYPE_CHECKING:
# to that used for builtins.map in python/typeshed. This supports
# checking input types for the callable with up to three arguments.
@overload
def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ...
def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> list[T]: ...
@overload
def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[T]: ...
def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[T]: ...
@overload
def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[T]: ...
def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[T]: ...
@overload
def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[T]: ...
def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[T]: ...
def safe_map(f, *args):
args = list(map(list, args))
@ -108,26 +107,26 @@ else:
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
return list(map(f, *args))
def unzip2(xys: Iterable[Tuple[T1, T2]]
) -> Tuple[Tuple[T1, ...], Tuple[T2, ...]]:
def unzip2(xys: Iterable[tuple[T1, T2]]
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
"""Unzip sequence of length-2 tuples into two tuples."""
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
# is too permissive about inputs, and does not guarantee a length-2 output.
xs: List[T1] = []
ys: List[T2] = []
xs: list[T1] = []
ys: list[T2] = []
for x, y in xys:
xs.append(x)
ys.append(y)
return tuple(xs), tuple(ys)
def unzip3(xyzs: Iterable[Tuple[T1, T2, T3]]
) -> Tuple[Tuple[T1, ...], Tuple[T2, ...], Tuple[T3, ...]]:
def unzip3(xyzs: Iterable[tuple[T1, T2, T3]]
) -> tuple[tuple[T1, ...], tuple[T2, ...], tuple[T3, ...]]:
"""Unzip sequence of length-3 tuples into three tuples."""
# Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated,
# is too permissive about inputs, and does not guarantee a length-3 output.
xs: List[T1] = []
ys: List[T2] = []
zs: List[T3] = []
xs: list[T1] = []
ys: list[T2] = []
zs: list[T3] = []
for x, y, z in xyzs:
xs.append(x)
ys.append(y)
@ -140,7 +139,7 @@ def subvals(lst, replace):
lst[i] = v
return tuple(lst)
def split_list(args: Sequence[T], ns: Sequence[int]) -> List[List[T]]:
def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
args = list(args)
lists = []
for n in ns:
@ -149,14 +148,14 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> List[List[T]]:
lists.append(args)
return lists
def partition_list(bs: Sequence[bool], l: Sequence[T]) -> Tuple[List[T], List[T]]:
def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]:
assert len(bs) == len(l)
lists = [], [] # type: ignore
for b, x in zip(bs, l):
lists[b].append(x)
return lists
def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> List[T]:
def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]:
assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
i0, i1 = iter(l0), iter(l1)
out = [next(i1) if b else next(i0) for b in bs]
@ -170,7 +169,7 @@ def split_dict(dct, names):
assert not dct
return lst
def concatenate(xs: Iterable[Sequence[T]]) -> List[T]:
def concatenate(xs: Iterable[Sequence[T]]) -> list[T]:
"""Concatenates/flattens a list of lists."""
return list(it.chain.from_iterable(xs))
@ -178,7 +177,7 @@ flatten = concatenate
_unflatten_done = object()
def unflatten(xs: Iterable[T], ns: Sequence[int]) -> List[List[T]]:
def unflatten(xs: Iterable[T], ns: Sequence[int]) -> list[list[T]]:
"""Splits `xs` into subsequences of lengths `ns`.
Unlike `split_list`, the `sum(ns)` must be equal to `len(xs)`."""
@ -499,8 +498,8 @@ def distributed_debug_log(*pairs):
class OrderedSet(Generic[T]):
elts_set: Set[T]
elts_list: List[T]
elts_set: set[T]
elts_list: list[T]
def __init__(self):
self.elts_set = set()

View File

@ -30,7 +30,7 @@ import platform as py_platform
import pkgutil
import sys
import threading
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Union
from typing import Any, Callable, Mapping, Optional, Union
import warnings
import numpy as np
@ -109,7 +109,7 @@ def get_compile_options(
use_auto_spmd_partitioning: bool = False,
auto_spmd_partitioning_mesh_shape=[],
auto_spmd_partitioning_mesh_ids=[],
env_options_overrides: Optional[Dict[str, str]] = None,
env_options_overrides: Optional[dict[str, str]] = None,
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
@ -231,10 +231,10 @@ class BackendRegistration:
# a buggy plugin.
experimental: bool = False
_backend_factories: Dict[str, BackendRegistration] = {}
_backend_factories: dict[str, BackendRegistration] = {}
_default_backend: Optional[xla_client.Client] = None
_backends : Dict[str, xla_client.Client] = {}
_backends_errors : Dict[str, str] = {}
_backends : dict[str, xla_client.Client] = {}
_backends_errors : dict[str, str] = {}
_backend_lock = threading.Lock()
# The set of known non-experimental plugins.
@ -245,7 +245,7 @@ _backend_lock = threading.Lock()
# It is fine for a plugin not to implement every feature that JAX uses, provided
# that a reasonable feature set is implemented and the plugin fails gracefully
# for unimplemented features. Wrong outputs are not acceptable.
_nonexperimental_plugins: Set[str] = set()
_nonexperimental_plugins: set[str] = set()
def register_backend_factory(name: str, factory: BackendFactory, *,
priority: int = 0,
@ -314,7 +314,7 @@ if hasattr(xla_client, "make_tpu_client"):
def _get_pjrt_plugin_names_and_library_paths(
plugins_from_env: str,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Gets the names and library paths of PJRT plugins to load from env var.
Args:
@ -343,7 +343,7 @@ def _get_pjrt_plugin_names_and_library_paths(
def _get_pjrt_plugin_config(
json_path: str,
) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]:
) -> tuple[str, Optional[Mapping[str, Union[str, int, list[int], float]]]]:
"""Gets PJRT plugin configuration from a json file.
The json file needs to have a "library_path" field for the plugin library
@ -440,7 +440,7 @@ def register_plugin(
*,
priority: int = 400,
library_path: Optional[str] = None,
options: Optional[Mapping[str, Union[str, int, List[int], float]]] = None,
options: Optional[Mapping[str, Union[str, int, list[int], float]]] = None,
) -> None:
"""Registers a backend factory for the PJRT plugin.
@ -516,7 +516,7 @@ _platform_aliases = {
"rocm": "gpu",
}
_alias_to_platforms: Dict[str, List[str]] = {}
_alias_to_platforms: dict[str, list[str]] = {}
for _platform, _alias in _platform_aliases.items():
_alias_to_platforms.setdefault(_alias, []).append(_platform)
@ -550,7 +550,7 @@ def canonicalize_platform(platform: str) -> str:
"Platforms present are: " + ",".join(b.keys()))
def expand_platform_alias(platform: str) -> List[str]:
def expand_platform_alias(platform: str) -> list[str]:
"""Expands, e.g., "gpu" to ["cuda", "rocm"].
This is used for convenience reasons: we expect cuda and rocm to act similarly
@ -561,7 +561,7 @@ def expand_platform_alias(platform: str) -> List[str]:
def is_gpu(platform):
return platform in ("cuda", "rocm")
def backends() -> Dict[str, xla_client.Client]:
def backends() -> dict[str, xla_client.Client]:
global _backends
global _backends_errors
global _default_backend
@ -732,7 +732,7 @@ def local_device_count(
def devices(
backend: Optional[Union[str, xla_client.Client]] = None
) -> List[xla_client.Device]:
) -> list[xla_client.Device]:
"""Returns a list of all devices for a given backend.
.. currentmodule:: jaxlib.xla_extension
@ -765,7 +765,7 @@ def default_backend() -> str:
@lru_cache
def local_devices(process_index: Optional[int] = None,
backend: Optional[Union[str, xla_client.Client]] = None,
host_id: Optional[int] = None) -> List[xla_client.Device]:
host_id: Optional[int] = None) -> list[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
@ -839,7 +839,7 @@ def host_count(backend: Optional[Union[str, xla_client.Client]] = None) -> int:
# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(
backend: Optional[Union[str, xla_client.Client]] = None
) -> List[int]:
) -> list[int]:
warnings.warn(
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
"instead. jax.host_ids will eventually be removed; please update your "

View File

@ -89,7 +89,7 @@ Example Usage:
.. _Optax: https://github.com/deepmind/optax
"""
from typing import Any, Callable, NamedTuple, Tuple, Union
from typing import Any, Callable, NamedTuple, Union
from collections import namedtuple
import functools
@ -138,7 +138,7 @@ class Optimizer(NamedTuple):
Schedule = Callable[[Step], float]
def optimizer(opt_maker: Callable[...,
Tuple[Callable[[Params], State],
tuple[Callable[[Params], State],
Callable[[Step, Updates, Params], Params],
Callable[[State], Params]]]) -> Callable[..., Optimizer]:
"""Decorator to make an optimizer defined for arrays generalize to containers.

View File

@ -22,7 +22,7 @@ import os
import re
import time
import threading
from typing import Awaitable, Any, Callable, Dict, Optional, Sequence, Union
from typing import Awaitable, Any, Callable, Optional, Sequence, Union
import jax
from jax._src import distributed
@ -236,7 +236,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore,
async def async_deserialize(
in_sharding: sharding_impls.XLACompatibleSharding,
tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
tensorstore_spec: Union[ts.Spec, dict[str, Any]],
global_shape: Optional[Sequence[int]] = None,
dtype=None,
byte_limiter: Optional[_LimitInFlightBytes] = None,
@ -288,7 +288,7 @@ async def async_deserialize(
def run_deserialization(shardings: Sequence[sharding.Sharding],
tensorstore_specs: Sequence[Dict[str, Any]],
tensorstore_specs: Sequence[dict[str, Any]],
global_shapes: Optional[Sequence[array.Shape]] = None,
dtypes: Optional[Sequence[typing.DTypeLike]] = None,
concurrent_gb: int = 32):
@ -370,7 +370,7 @@ class GlobalAsyncCheckpointManagerBase(metaclass=abc.ABCMeta):
@abc.abstractmethod
def deserialize(self, shardings: Sequence[sharding.Sharding],
tensorstore_specs: Sequence[Dict[str, Any]],
tensorstore_specs: Sequence[dict[str, Any]],
global_shapes: Optional[Sequence[array.Shape]] = None,
dtypes: Optional[Sequence[typing.DTypeLike]] = None):
"""Deserializes GDAs from TensorStore."""
@ -519,7 +519,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
self._start_async_commit(on_commit_callback)
def deserialize(self, shardings: Sequence[sharding.Sharding],
tensorstore_specs: Sequence[Dict[str, Any]],
tensorstore_specs: Sequence[dict[str, Any]],
global_shapes: Optional[Sequence[array.Shape]] = None,
dtypes: Optional[Sequence[typing.DTypeLike]] = None,
concurrent_gb: int = 32):

Some files were not shown because too many files have changed in this diff.
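
A minimal sketch of the annotation style the updated modules use, with hypothetical names not taken from any file in this diff: built-in generics such as `list[str]`, `tuple[int, ...]`, and `dict[str, int]` replace `typing.List`, `typing.Tuple`, and `typing.Dict`, while `Optional` and `Union` are still imported from `typing`.

```python
# Illustrative only; the function and parameter names are hypothetical.
from typing import Optional

def split_names(names: list[str], sizes: tuple[int, ...],
                index: Optional[dict[str, int]] = None) -> list[list[str]]:
    """Split `names` into consecutive groups of the given sizes.

    Built-in generic types (PEP 585) are usable in annotations at runtime
    on Python 3.9+, so no typing.List/Tuple/Dict imports are needed.
    """
    out: list[list[str]] = []
    start = 0
    for n in sizes:
        out.append(names[start:start + n])
        start += n
    if index is not None:
        # Optionally record each name's position in the original sequence.
        index.update({name: i for i, name in enumerate(names)})
    return out
```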