Use lower-case PEP 585 names for types.

Issue: https://github.com/google/jax/issues/16537
PiperOrigin-RevId: 542969282

parent f67acee129
commit 816ba91263
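This commit mechanically replaces the `typing` container generics (`List`, `Tuple`, `Dict`, `Set`, `FrozenSet`, `Type`), deprecated since Python 3.9, with the built-in generics standardized by PEP 585. A minimal before/after sketch of the pattern applied throughout the diff (the names here are illustrative, not taken from the commit):

```python
# Before: container generics imported from typing.
from typing import Dict, List, Tuple

def summarize(counts: Dict[str, int]) -> List[Tuple[str, int]]:
  return sorted(counts.items(), key=lambda kv: kv[1], reverse=True)

# After (PEP 585): the built-ins are subscriptable directly, so the
# typing imports can be dropped.
def summarize(counts: dict[str, int]) -> list[tuple[str, int]]:
  return sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
```

In annotation position the lower-case forms also work on older interpreters under `from __future__ import annotations`; used at runtime (e.g. in type aliases such as `Shape = tuple[int, ...]` below), they require Python 3.9+.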
.github/workflows/cat_slurm_logs.py (3 changed lines, vendored)
@@ -15,7 +15,6 @@
 import argparse
 import os
-from typing import List

 ISSUE_FORMAT = """\
 <details><summary>Failure summary {name}</summary>
@@ -27,7 +26,7 @@ ISSUE_FORMAT = """\
 </details>
 """

-def main(logfiles: List[str], outfile: str):
+def main(logfiles: list[str], outfile: str):
   print(f"extracting content of {logfiles}")
   print(f"and writing to {outfile}")
   with open(outfile, 'w') as f:
@@ -536,9 +536,9 @@
     }
    ],
    "source": [
-    "from typing import Tuple, Iterable\n",
+    "from typing import Iterable\n",
     "\n",
-    "def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:\n",
+    "def flatten_MyContainer(container) -> tuple[Iterable[int], str]:\n",
     "  \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
     "  flat_contents = [container.a, container.b, container.c]\n",
     "\n",
@@ -593,7 +593,7 @@
     "class MyKeyPathContainer(MyContainer):\n",
     "  pass\n",
     "\n",
-    "def flatten_with_keys_MyKeyPathContainer(container) -> Tuple[Iterable[int], str]:\n",
+    "def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]:\n",
     "  \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
     "  \n",
     "  # GetAttrKey is a common way to express an attribute key. Users are free\n",
@@ -277,9 +277,9 @@ except TypeError as e:
 To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:

 ```{code-cell} ipython3
-from typing import Tuple, Iterable
+from typing import Iterable

-def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
+def flatten_MyContainer(container) -> tuple[Iterable[int], str]:
   """Returns an iterable over container contents, and aux data."""
   flat_contents = [container.a, container.b, container.c]

@@ -312,7 +312,7 @@ Alternatively, using the key path API mentioned above, you can register this con
 class MyKeyPathContainer(MyContainer):
   pass

-def flatten_with_keys_MyKeyPathContainer(container) -> Tuple[Iterable[int], str]:
+def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]:
   """Returns an iterable over container contents, and aux data."""

   # GetAttrKey is a common way to express an attribute key. Users are free
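These tutorial hunks only touch annotations, but they sit inside JAX's custom-pytree example. For orientation, a minimal, self-contained sketch of the registration round trip using `jax.tree_util.register_pytree_node` (the container body is a plausible reconstruction of the tutorial's class, not copied from it):

```python
import jax

class MyContainer:
  """A named container of three ints, as in the tutorial."""
  def __init__(self, name: str, a: int, b: int, c: int):
    self.name = name
    self.a, self.b, self.c = a, b, c

def flatten_MyContainer(container) -> tuple[list[int], str]:
  # Children first, then static aux data (the name).
  return [container.a, container.b, container.c], container.name

def unflatten_MyContainer(aux_data: str, flat_contents: list[int]) -> MyContainer:
  return MyContainer(aux_data, *flat_contents)

jax.tree_util.register_pytree_node(
    MyContainer, flatten_MyContainer, unflatten_MyContainer)

print(jax.tree_util.tree_leaves([MyContainer('k', 1, 2, 3)]))  # [1, 2, 3]
```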
@@ -541,7 +541,7 @@
    },
    "outputs": [],
    "source": [
-    "from typing import NamedTuple, Tuple\n",
+    "from typing import NamedTuple\n",
     "import functools\n",
     "\n",
     "class Params(NamedTuple):\n",
@@ -571,7 +571,7 @@
     "# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it\n",
     "# 'num_devices', but could have used anything, so long as `pmean` used the same.\n",
     "@functools.partial(jax.pmap, axis_name='num_devices')\n",
-    "def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:\n",
+    "def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]:\n",
     "  \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n",
     "\n",
     "  # Compute the gradients on the given minibatch (individually on each device).\n",
@@ -219,7 +219,7 @@ If this example is too confusing, you can find the same example, but without par
 ```{code-cell} ipython3
 :id: cI8xQqzRrc-4

-from typing import NamedTuple, Tuple
+from typing import NamedTuple
 import functools

 class Params(NamedTuple):
@@ -249,7 +249,7 @@ LEARNING_RATE = 0.005
 # to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
 # 'num_devices', but could have used anything, so long as `pmean` used the same.
 @functools.partial(jax.pmap, axis_name='num_devices')
-def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
+def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]:
   """Performs one SGD update step on params using the given data."""

   # Compute the gradients on the given minibatch (individually on each device).
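The `update` signature is the only thing this commit touches here, but for context, a hedged, self-contained sketch of the axis-name plumbing the surrounding tutorial relies on (the loss function is a stand-in, not the tutorial's):

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(w, xs, ys):
  # Hypothetical scalar loss for a linear model.
  return jnp.mean((xs @ w - ys) ** 2)

@functools.partial(jax.pmap, axis_name='num_devices')
def update(w, xs, ys):
  grads = jax.grad(loss_fn)(w, xs, ys)
  # Average gradients across all devices participating in the pmap.
  grads = jax.lax.pmean(grads, axis_name='num_devices')
  return w - 0.01 * grads

# pmap maps over a leading axis of size jax.local_device_count().
n = jax.local_device_count()
w = update(jnp.zeros((n, 3)), jnp.ones((n, 8, 3)), jnp.ones((n, 8)))
```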
@@ -161,13 +161,11 @@
     }
    ],
    "source": [
-    "from typing import Tuple\n",
-    "\n",
     "CounterState = int\n",
     "\n",
     "class CounterV2:\n",
     "\n",
-    "  def count(self, n: CounterState) -> Tuple[int, CounterState]:\n",
+    "  def count(self, n: CounterState) -> tuple[int, CounterState]:\n",
     "    # You could just return n+1, but here we separate its role as \n",
     "    # the output and as the counter state for didactic purposes.\n",
     "    return n+1, n+1\n",
@@ -102,13 +102,11 @@ Part of the problem with our counter was that the returned value didn't depend o
 :id: 53pSdK4KoOEZ
 :outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79

-from typing import Tuple
-
 CounterState = int

 class CounterV2:

-  def count(self, n: CounterState) -> Tuple[int, CounterState]:
+  def count(self, n: CounterState) -> tuple[int, CounterState]:
     # You could just return n+1, but here we separate its role as
     # the output and as the counter state for didactic purposes.
     return n+1, n+1
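As a quick illustration of the pattern this tutorial is building (state threaded explicitly through a pure `count`, which is what makes it `jit`-compatible), a minimal usage sketch:

```python
import jax

CounterState = int

class CounterV2:
  def count(self, n: CounterState) -> tuple[int, CounterState]:
    # Pure function: the output and the new state both derive from the input.
    return n + 1, n + 1

counter = CounterV2()
state: CounterState = 0
fast_count = jax.jit(counter.count)
for _ in range(3):
  value, state = fast_count(state)
  print(value)  # 1, 2, 3
```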
@@ -1141,10 +1141,8 @@
    },
    "source": [
     "```python\n",
-    "from typing import Tuple, List\n",
-    "\n",
-    "LayerParam = Tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer\n",
-    "ParamsList = List[LayerParam]\n",
+    "LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer\n",
+    "ParamsList = list[LayerParam]\n",
     "\n",
     "def net(params: ParamsList, x: jnp.ndarray):\n",
     "  for W, b in params:\n",
@@ -497,10 +497,8 @@ For example, one common pattern in large [Transformer models](https://en.wikiped
 +++ {"id": "BUeqKFRS5yPU"}

 ```python
-from typing import Tuple, List
-
-LayerParam = Tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
-ParamsList = List[LayerParam]
+LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
+ParamsList = list[LayerParam]

 def net(params: ParamsList, x: jnp.ndarray):
   for W, b in params:
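One practical wrinkle worth noting: unlike annotations, these two lines are type *aliases* evaluated at runtime, so `tuple[...]` / `list[...]` here genuinely require Python 3.9+. A small sketch of the aliases in use (the `net` body is a plausible completion for illustration, not part of the diff):

```python
import jax.numpy as jnp

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray) -> jnp.ndarray:
  for W, b in params:
    x = jnp.tanh(x @ W + b)  # hypothetical activation; the doc elides the body
  return x

params: ParamsList = [(jnp.ones((3, 4)), jnp.zeros(4)),
                      (jnp.ones((4, 2)), jnp.zeros(2))]
print(net(params, jnp.ones(3)).shape)  # (2,)
```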
@@ -15,7 +15,6 @@
 from __future__ import annotations

 from functools import partial
-from typing import Set

 import numpy as np

@@ -40,7 +39,7 @@ def zeros_like_array(x):
   aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
   return ad_util.zeros_like_aval(aval)

-numpy_scalar_types: Set[type] = {  # pylint: disable=g-bare-generic
+numpy_scalar_types: set[type] = {  # pylint: disable=g-bare-generic
   np.int8, np.int16, np.int32, np.int64,
   np.uint8, np.uint16, np.uint32, np.uint64,
   np.complex64, np.complex128,
@@ -52,7 +51,7 @@ if dtypes.int4 is not None:
 if dtypes.uint4 is not None:
   numpy_scalar_types.add(dtypes.uint4)

-array_types: Set[type] = {np.ndarray} | numpy_scalar_types  # pylint: disable=g-bare-generic
+array_types: set[type] = {np.ndarray} | numpy_scalar_types  # pylint: disable=g-bare-generic

 def canonical_concrete_aval(val, weak_type=None):
   return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
@@ -15,8 +15,7 @@
 import functools
 from functools import partial
 import logging
-from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
-                    Union)
+from typing import Any, Callable, Optional, Sequence, Union
 import types

 import numpy as np
@@ -134,7 +133,7 @@ checkpoint_policies = types.SimpleNamespace(
 @api_boundary
 def checkpoint(fun: Callable, *, prevent_cse: bool = True,
                policy: Optional[Callable[..., bool]] = None,
-               static_argnums: Union[int, Tuple[int, ...]] = (),
+               static_argnums: Union[int, tuple[int, ...]] = (),
                ) -> Callable:
   """Make ``fun`` recompute internal linearization points when differentiated.

@@ -348,14 +347,14 @@ class WrapHashably:
 # See api_benchmark.py:bench_remat_eager_retracing_overheads_static_argnums.
 # On that benchmark, including this caching makes a ~10x difference (which can
 # be made arbitrary large by involving larger functions to be traced).
-def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
-                  static_args: Tuple[WrapHashably, ...], nargs: int):
+def _dyn_args_fun(fun: Callable, static_argnums: frozenset[int],
+                  static_args: tuple[WrapHashably, ...], nargs: int):
   if any(isinstance(x.val, core.Tracer) for x in static_args):
     return _dyn_args_fun_uncached(fun, static_argnums, static_args, nargs)
   return _dyn_args_fun_cached(fun, static_argnums, static_args, nargs)

-def _dyn_args_fun_uncached(fun: Callable, static_argnums: FrozenSet[int],
-                           static_args: Tuple[WrapHashably, ...], nargs: int):
+def _dyn_args_fun_uncached(fun: Callable, static_argnums: frozenset[int],
+                           static_args: tuple[WrapHashably, ...], nargs: int):
   def new_fun(*dyn_args, **kwargs):
     static_args_, dyn_args_ = iter(static_args), iter(dyn_args)
     full_args = [next(static_args_).val if i in static_argnums
@@ -391,7 +390,7 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):

 ### Utilities

-def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
+def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
   in_leaves, in_tree = tree_flatten((args, kwargs))

   def f_(*args):
@@ -409,7 +408,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
   arg_info = pe.arg_info_all(dbg)
   return _saved_residuals(jaxpr, arg_info)

-def _saved_residuals(jaxpr, arg_info) -> List[Tuple[core.AbstractValue, str]]:
+def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
   res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
   res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

@@ -579,7 +578,7 @@ ad.reducing_transposes[remat_p] = remat_transpose
 def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: Union[bool, Sequence[bool]],
                     out_zeros: Union[bool, Sequence[bool]],
                     reduce_axes: Sequence[core.AxisName],
-                    ) -> Tuple[core.ClosedJaxpr, List[bool]]:
+                    ) -> tuple[core.ClosedJaxpr, list[bool]]:
   if type(in_linear) is bool:
     in_linear = (in_linear,) * len(jaxpr.in_avals)
   if type(out_zeros) is bool:
@@ -640,8 +639,8 @@ batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None)
 batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap

 # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
-def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
-              ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
+def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
+              ) -> tuple[list[bool], Optional[core.JaxprEqn]]:
   new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
   new_params = dict(eqn.params, jaxpr=new_jaxpr)
   if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
@@ -781,7 +780,7 @@ def checkpoint_wrapper(
     *,
     concrete: bool = False,
     prevent_cse: bool = True,
-    static_argnums: Union[int, Tuple[int, ...]] = (),
+    static_argnums: Union[int, tuple[int, ...]] = (),
     policy: Optional[Callable[..., bool]] = None,
     ) -> Callable:
   if concrete:
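The `static_argnums` parameter re-typed above belongs to the public `jax.checkpoint` (a.k.a. `jax.remat`) API. A hedged usage sketch, with a made-up function for illustration:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint
def block(x):
  # Intermediate activations are recomputed, not stored, when differentiated.
  return jnp.sin(jnp.sin(x))

grad_fn = jax.grad(lambda x: block(x).sum())
print(grad_fn(jnp.ones(4)))
```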
@@ -14,7 +14,7 @@
 from __future__ import annotations

 import types
-from typing import Any, Callable, Dict, TypeVar, Union, cast
+from typing import Any, Callable, TypeVar, Union, cast

 from jax._src import core
 from jax._src import traceback_util
@@ -30,7 +30,7 @@ T = TypeVar('T')

 map = safe_map

-jaxval_adders: Dict[type, Callable[[ArrayLike, ArrayLike], Array]] = {}
+jaxval_adders: dict[type, Callable[[ArrayLike, ArrayLike], Array]] = {}

 def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:
   return add_jaxvals_p.bind(x, y)
@@ -46,7 +46,7 @@ def add_impl(xs, ys):
 def add_abstract(xs, ys):
   return lattice_join(xs, ys)

-jaxval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}
+jaxval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

 def instantiate(z: Union[Zero, Array]) -> Array:
   if type(z) is Zero:
@@ -56,7 +56,7 @@ def instantiate(z: Union[Zero, Array]) -> Array:
 def zeros_like_aval(aval: core.AbstractValue) -> Array:
   return aval_zeros_likers[type(aval)](aval)

-aval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}
+aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

 def zeros_like_jaxval(val: ArrayLike) -> Array:
   return zeros_like_p.bind(val)
@@ -27,8 +27,8 @@ from functools import partial
 import inspect
 import math
 import typing
-from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
-                    NamedTuple, Optional, Sequence, Tuple, TypeVar, Union,
+from typing import (Any, Callable, Generator, Hashable, Iterable, Literal,
+                    NamedTuple, Optional, Sequence, TypeVar, Union,
                     overload, cast)
 import weakref

@@ -363,7 +363,7 @@ def disable_jit(disable: bool = True):

 def xla_computation(fun: Callable,
                     static_argnums: Union[int, Iterable[int]] = (),
-                    axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
+                    axis_env: Optional[Sequence[tuple[AxisName, int]]] = None,
                     in_parts=None, out_parts=None,
                     backend: Optional[str] = None,
                     tuple_args: bool = False,
@@ -658,7 +658,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
 def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
                    has_aux: bool = False, holomorphic: bool = False,
                    allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
-                   ) -> Callable[..., Tuple[Any, Any]]:
+                   ) -> Callable[..., tuple[Any, Any]]:
   """Create a function that evaluates both ``fun`` and the gradient of ``fun``.

   Args:
@@ -1069,7 +1069,7 @@ def vmap(fun: F,
          out_axes: Any = 0,
          axis_name: Optional[AxisName] = None,
          axis_size: Optional[int] = None,
-         spmd_axis_name: Optional[Union[AxisName, Tuple[AxisName, ...]]] = None
+         spmd_axis_name: Optional[Union[AxisName, tuple[AxisName, ...]]] = None
          ) -> F:
   """Vectorizing map. Creates a function which maps ``fun`` over argument axes.

@@ -1254,7 +1254,7 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
       f"containing an array, got empty *args={args} and **kwargs={kwargs}"
     )

-  def _get_axis_size(name: str, shape: Tuple[core.AxisSize, ...], axis: int
+  def _get_axis_size(name: str, shape: tuple[core.AxisSize, ...], axis: int
                      ) -> core.AxisSize:
     try:
       return shape[axis]
@@ -1328,7 +1328,7 @@ def pmap(
     backend: Optional[str] = None,
     axis_size: Optional[int] = None,
     donate_argnums: Union[int, Iterable[int]] = (),
-    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
+    global_arg_shapes: Optional[tuple[tuple[int, ...], ...]] = None,
     ) -> Any:
   """Parallel map with support for collective operations.

@@ -1888,7 +1888,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,

 def jvp(
     fun: Callable, primals, tangents, has_aux: bool = False
-  ) -> Tuple[Any, ...]:
+  ) -> tuple[Any, ...]:
   """Computes a (forward-mode) Jacobian-vector product of ``fun``.

   Args:
@@ -1971,16 +1971,16 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):

 @overload
 def linearize(fun: Callable, *primals, has_aux: Literal[False] = False
-              ) -> Tuple[Any, Callable]:
+              ) -> tuple[Any, Callable]:
   ...

 @overload
 def linearize(fun: Callable, *primals, has_aux: Literal[True]
-              ) -> Tuple[Any, Callable, Any]:
+              ) -> tuple[Any, Callable, Any]:
   ...

 def linearize(fun: Callable, *primals, has_aux: bool = False
-              ) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
+              ) -> Union[tuple[Any, Callable], tuple[Any, Callable, Any]]:
   """Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.

   Args:
@@ -2137,17 +2137,17 @@ def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree,
 def vjp(fun: Callable[..., T],
         *primals: Any,
         has_aux: Literal[False] = False,
-        reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]:
+        reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable]:
   ...

 @overload
-def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any,
+def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
         has_aux: Literal[True],
-        reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]:
+        reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
   ...
 def vjp(  # type: ignore
     fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
-  ) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
+  ) -> Union[tuple[Any, Callable], tuple[Any, Callable, Any]]:
   """Compute a (reverse-mode) vector-Jacobian product of ``fun``.

   :py:func:`grad` is implemented as a special case of :py:func:`vjp`.
@@ -2310,7 +2310,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:


 def _flat_axes_specs(abstracted_axes, *args, **kwargs
-                     ) -> List[pe.AbstractedAxesSpec]:
+                     ) -> list[pe.AbstractedAxesSpec]:
   if kwargs: raise NotImplementedError
   def ax_leaf(l):
     return (isinstance(l, dict) and all_leaves(l.values()) or
@@ -2323,7 +2323,7 @@ def _flat_axes_specs(abstracted_axes, *args, **kwargs
 @overload
 def make_jaxpr(fun: Callable,  # type: ignore
                static_argnums: Union[int, Iterable[int]] = (),
-               axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
+               axis_env: Optional[Sequence[tuple[AxisName, int]]] = None,
                return_shape: Literal[False] = ...,
               abstracted_axes: Optional[Any] = None,
               ) -> Callable[..., core.ClosedJaxpr]:
@@ -2332,19 +2332,19 @@ def make_jaxpr(fun: Callable,  # type: ignore
 @overload
 def make_jaxpr(fun: Callable,  # type: ignore
                static_argnums: Union[int, Iterable[int]] = (),
-               axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
+               axis_env: Optional[Sequence[tuple[AxisName, int]]] = None,
                return_shape: Literal[True] = ...,
                abstracted_axes: Optional[Any] = None,
-               ) -> Callable[..., Tuple[core.ClosedJaxpr, Any]]:
+               ) -> Callable[..., tuple[core.ClosedJaxpr, Any]]:
   ...

 def make_jaxpr(fun: Callable,
                static_argnums: Union[int, Iterable[int]] = (),
-               axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
+               axis_env: Optional[Sequence[tuple[AxisName, int]]] = None,
                return_shape: bool = False,
                abstracted_axes: Optional[Any] = None,
                ) -> Callable[..., Union[core.ClosedJaxpr,
-                                        Tuple[core.ClosedJaxpr, Any]]]:
+                                        tuple[core.ClosedJaxpr, Any]]]:
   """Creates a function that produces its jaxpr given example args.

   Args:
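Several of the signatures above (`jvp`, `linearize`, `vjp`, `value_and_grad`) return plain tuples, which is why they dominate this rename. A quick, hedged reminder of those return shapes:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x
x = jnp.float32(2.0)

y, y_dot = jax.jvp(f, (x,), (jnp.float32(1.0),))  # tuple: (primal out, tangent out)
y, f_vjp = jax.vjp(f, x)                          # tuple[T, Callable]
(x_bar,) = f_vjp(jnp.float32(1.0))                # cotangents come back as a tuple too
val, grad = jax.value_and_grad(f)(x)              # tuple[Any, Any]
```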
@@ -15,8 +15,7 @@
 import inspect
 import operator
 from functools import partial
-from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
-                    Set, Tuple, Union)
+from typing import Any, Callable, Iterable, Optional, Sequence, Union
 import warnings

 import numpy as np
@@ -39,7 +38,7 @@ traceback_util.register_exclusion(__file__)

 map = safe_map

-def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
+def _ensure_index(x: Any) -> Union[int, tuple[int, ...]]:
   """Ensure x is either an index or a tuple of indices."""
   x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
   try:
@@ -47,7 +46,7 @@ def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
   except TypeError:
     return tuple(map(operator.index, x))

-def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
+def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
   """Convert x to a tuple of indices."""
   x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
   try:
@@ -60,7 +59,7 @@ def _ensure_str(x: str) -> str:
     raise TypeError(f"argument is not a string: {x}")
   return x

-def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> Tuple[str, ...]:
+def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> tuple[str, ...]:
   """Convert x to a tuple of strings."""
   if isinstance(x, str):
     return (x,)
@@ -97,7 +96,7 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):

 def flattened_fun_in_tree(
     fn: lu.WrappedFun
-  ) -> Optional[Tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]:
+  ) -> Optional[tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]:
   # This implementation relies on internal details of linear_util.py's
   # WrappedFun, but it's for the worthy cause of better user error messages.
   # It can fail (i.e. return None) if its WrappedFun argument is not transformed
@@ -149,7 +148,7 @@ _POSITIONAL_ARGUMENTS = (
   inspect.Parameter.POSITIONAL_OR_KEYWORD
 )

-def validate_argnums(sig: inspect.Signature, argnums: Tuple[int, ...], argnums_name: str) -> None:
+def validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None:
   """
   Validate that the argnums are sensible for a given function.

@@ -183,7 +182,7 @@ _KEYWORD_ARGUMENTS = (
   inspect.Parameter.POSITIONAL_OR_KEYWORD,
   inspect.Parameter.KEYWORD_ONLY,
 )
-def validate_argnames(sig: inspect.Signature, argnames: Tuple[str, ...], argnames_name: str) -> None:
+def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str) -> None:
   """
   Validate that the argnames are sensible for a given function.

@@ -192,8 +191,8 @@ def validate_argnames(sig: inspect.Signature, argnames: Tuple[str, ...], argname
   marked as position-only (`f(pos_only, /, ...)`).
   """
   var_kwargs = False
-  valid_kwargs: Set[str] = set()
-  invalid_kwargs: Set[str] = set()
+  valid_kwargs: set[str] = set()
+  invalid_kwargs: set[str] = set()
   for param_name, param in sig.parameters.items():
     if param.kind in _KEYWORD_ARGUMENTS:
       valid_kwargs.add(param_name)
@@ -253,7 +252,7 @@ def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
   return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args

 def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int]
-                     ) -> Tuple[int, ...]:
+                     ) -> tuple[int, ...]:
   """Ensure argnum is within bounds. Also resolves negative argnums."""
   result = []
   for i in argnums:
@@ -267,8 +266,8 @@ def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int]
   return tuple(result)


-def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
-                           args: Tuple[Any, ...], *, allow_invalid: bool):
+def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
+                           args: tuple[Any, ...], *, allow_invalid: bool):
   "Version of ``argnums_partial`` that checks hashability of static_argnums."
   if not static_argnums:
     return f, args
@@ -305,13 +304,13 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
   yield ans


-def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
-                            kwargs: Dict[str, Any]):
+def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
+                            kwargs: dict[str, Any]):
   if not static_argnames:
     return f, kwargs
   dyn_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}

-  fixed_kwargs: Dict[str, Any] = {}
+  fixed_kwargs: dict[str, Any] = {}
   for k, arg in kwargs.items():
     if k not in dyn_kwargs:
       try:
@@ -333,16 +332,16 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
   yield ans


-def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
+def donation_vector(donate_argnums, args, kwargs) -> tuple[bool, ...]:
   """Returns a tuple with a boolean value for each leaf in args."""
-  res: List[bool] = []
+  res: list[bool] = []
   for i, arg in enumerate(args):
     donate = bool(i in donate_argnums)
     res.extend((donate,) * tree_structure(arg).num_leaves)
   res.extend((False,) * tree_structure(kwargs).num_leaves)
   return tuple(res)

-def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
+def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]:
   """Shifts donate to account for static.

   >>> rebase_donate_argnums((3, 4), (0, 1))
@@ -426,7 +425,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):

 def flat_out_axes(
     f: lu.WrappedFun, out_spec: Any
-  ) -> Tuple[lu.WrappedFun, Callable]:
+  ) -> tuple[lu.WrappedFun, Callable]:
   leaves, treedef = tree_flatten(out_spec)
   f, out_axes = _flat_out_axes(f, tuple(leaves), treedef)
   return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
@@ -473,7 +472,7 @@ def infer_argnums_and_argnames(
     sig: inspect.Signature,
     argnums: Union[int, Iterable[int], None],
     argnames: Union[str, Iterable[str], None],
-  ) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
+  ) -> tuple[tuple[int, ...], tuple[str, ...]]:
   """Infer missing argnums and argnames for a function with inspect."""
   if argnums is None and argnames is None:
     return (), ()
@@ -504,7 +503,7 @@ def infer_argnums_and_argnames(

 def resolve_argnums(
     fun, donate_argnums, static_argnums, static_argnames
-  ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[str, ...]]:
+  ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[str, ...]]:
   # Coerce input
   donate_argnums = _ensure_index_tuple(donate_argnums)

@@ -563,7 +562,7 @@ def shaped_abstractify(x):
     return _shaped_abstractify_handlers[type(x)](x)
   except KeyError:
     return _shaped_abstractify_slow(x)
-_shaped_abstractify_handlers: Dict[Any, Callable[[Any], core.ShapedArray]] = {}
+_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {}


 def _str_abstractify(x):
@@ -591,9 +590,9 @@ def api_hook(fun, tag: str):
   return fun


-def debug_info(traced_for: str, fun: Callable, args: Tuple[Any],
-               kwargs: Dict[str, Any], static_argnums: Tuple[int, ...],
-               static_argnames: Tuple[str, ...]) -> Optional[TracingDebugInfo]:
+def debug_info(traced_for: str, fun: Callable, args: tuple[Any],
+               kwargs: dict[str, Any], static_argnums: tuple[int, ...],
+               static_argnames: tuple[str, ...]) -> Optional[TracingDebugInfo]:
   """Try to build trace-time debug info for fun when applied to args/kwargs."""
   src = fun_sourceinfo(fun)
   arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames)
@@ -613,7 +612,7 @@ def fun_sourceinfo(fun: Callable) -> Optional[str]:
   return None

 def _arg_names(fn, args, kwargs, static_argnums, static_argnames,
-               ) -> Optional[Tuple[str, ...]]:
+               ) -> Optional[tuple[str, ...]]:
   static = object()
   static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
   static_argnames_ = set(static_argnames)
@@ -633,7 +632,7 @@ def result_paths(*args, **kwargs):
   yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]

 def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
-                     result_paths: Optional[Tuple[Optional[str], ...]] = None,
+                     result_paths: Optional[tuple[Optional[str], ...]] = None,
                      ) -> core.Jaxpr:
   """Add debug info to jaxpr, given trace-time debug info and result paths."""
   if trace_debug is None:
@@ -647,7 +646,7 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
   return jaxpr.replace(debug_info=debug_info)

 def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
-                     res_paths: Callable[[], Tuple[str, ...]]) -> lu.WrappedFun:
+                     res_paths: Callable[[], tuple[str, ...]]) -> lu.WrappedFun:
   "Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
   if dbg is None: return f
   assert dbg.result_paths is None
@@ -18,8 +18,8 @@ import math
 import operator as op
 import numpy as np
 import functools
-from typing import (Any, Callable, List, Optional, Sequence, Set, Tuple,
-                    Union, cast, TYPE_CHECKING)
+from typing import (Any, Callable, Optional, Sequence, Union, cast,
+                    TYPE_CHECKING)

 from jax._src import abstract_arrays
 from jax._src import api
@@ -42,9 +42,9 @@ from jax._src.sharding_impls import (
 from jax._src.typing import ArrayLike
 from jax._src.util import use_cpp_class, use_cpp_method

-Shape = Tuple[int, ...]
+Shape = tuple[int, ...]
 Device = xc.Device
-Index = Tuple[slice, ...]
+Index = tuple[slice, ...]
 PRNGKeyArrayImpl = Any  # TODO(jakevdp): fix cycles and import this.


@@ -149,7 +149,7 @@ class ArrayImpl(basearray.Array):

   aval: core.ShapedArray
   _sharding: Sharding
-  _arrays: List[ArrayImpl]
+  _arrays: list[ArrayImpl]
   _committed: bool
   _skip_checks: bool
   _npy_value: Optional[np.ndarray]
@@ -402,7 +402,7 @@ class ArrayImpl(basearray.Array):
       raise ValueError('Length of devices is greater than 1. '
                        'Please use `.devices()`.')

-  def devices(self) -> Set[Device]:
+  def devices(self) -> set[Device]:
     self._check_if_deleted()
     return self.sharding.device_set

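The `shape` and `devices()` signatures touched in this file and the next are part of the public `jax.Array` surface. A quick sketch of what those return types look like in practice (device names vary by backend):

```python
import jax.numpy as jnp

x = jnp.arange(4)
print(x.shape)      # tuple[int, ...], e.g. (4,)
print(x.devices())  # set[Device], e.g. {CpuDevice(id=0)}
```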
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Set
+from typing import Any, Callable, Optional, Sequence, Union
 import numpy as np

 from jax._src.sharding import Sharding
@@ -33,7 +33,7 @@ class Array(abc.ABC):
   aval: Any

   @property
-  def shape(self) -> Tuple[int, ...]: ...
+  def shape(self) -> tuple[int, ...]: ...

   @property
   def sharding(self) -> Sharding: ...
@@ -110,9 +110,9 @@ class Array(abc.ABC):
   def __index__(self) -> int: ...

   # np.ndarray methods:
-  def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
+  def all(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
           keepdims=None) -> Array: ...
-  def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
+  def any(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
           keepdims=None) -> Array: ...
   def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
   def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ...
@@ -125,9 +125,9 @@ class Array(abc.ABC):
   def conj(self) -> Array: ...
   def conjugate(self) -> Array: ...
   def copy(self) -> Array: ...
-  def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
+  def cumprod(self, axis: Optional[Union[int, tuple[int, ...]]] = None,
               dtype=None, out=None) -> Array: ...
-  def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
+  def cumsum(self, axis: Optional[Union[int, tuple[int, ...]]] = None,
              dtype=None, out=None) -> Array: ...
   def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ...
   def dot(self, b, *, precision=None) -> Array: ...
@@ -135,18 +135,18 @@ class Array(abc.ABC):
   @property
   def imag(self) -> Array: ...
   def item(self, *args) -> Any: ...
-  def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
+  def max(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Array: ...
-  def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
+  def mean(self, axis: Optional[Union[int, tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=False, *, where=None,) -> Array: ...
-  def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
+  def min(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Array: ...
   @property
   def nbytes(self) -> int: ...
   def nonzero(self, *, size=None, fill_value=None) -> Array: ...
-  def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
+  def prod(self, axis: Optional[Union[int, tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=None, initial=None, where=None) -> Array: ...
-  def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
+  def ptp(self, axis: Optional[Union[int, tuple[int, ...]]] = None, out=None,
          keepdims=False,) -> Array: ...
   def ravel(self, order='C') -> Array: ...
   @property
@@ -157,16 +157,16 @@ class Array(abc.ABC):
   def round(self, decimals=0, out=None) -> Array: ...
   def searchsorted(self, v, side='left', sorter=None) -> Array: ...
   def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
-  def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
-  def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
+  def squeeze(self, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array: ...
+  def std(self, axis: Optional[Union[int, tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
-  def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
+  def sum(self, axis: Optional[Union[int, tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=None, initial=None, where=None) -> Array: ...
   def swapaxes(self, axis1: int, axis2: int) -> Array: ...
   def take(self, indices, axis: Optional[int] = None, out=None,
            mode=None) -> Array: ...
   def tobytes(self, order='C') -> bytes: ...
-  def tolist(self) -> List[Any]: ...
+  def tolist(self) -> list[Any]: ...
   def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
             out=None) -> Array: ...
   def transpose(self, *args) -> Array: ...
@@ -174,7 +174,7 @@ class Array(abc.ABC):
   def T(self) -> Array: ...
   @property
   def mT(self) -> Array: ...
-  def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
+  def var(self, axis: Optional[Union[int, tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
   def view(self, dtype=None, type=None) -> Array: ...

@@ -196,7 +196,7 @@ class Array(abc.ABC):
   def copy_to_host_async(self) -> None: ...
   def delete(self) -> None: ...
   def device(self) -> Device: ...
-  def devices(self) -> Set[Device]: ...
+  def devices(self) -> set[Device]: ...
   @property
   def global_shards(self) -> Sequence[Shard]: ...
   def is_deleted(self) -> bool: ...
@@ -16,8 +16,7 @@ from __future__ import annotations
 import dataclasses
 import functools
 import itertools as it
-from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar,
-                    FrozenSet, Type, Set, List, Sequence, Any)
+from typing import Union, Optional, Callable, TypeVar, Sequence, Any

 import numpy as np

@@ -57,8 +56,8 @@ zip, unsafe_zip = safe_zip, zip

 Bool = Union[bool, Array]
 Int = Union[int, Array]
-ErrorCategory = Type['JaxException']
-Payload = List[Union[np.ndarray, Array]]
+ErrorCategory = type['JaxException']
+Payload = list[Union[np.ndarray, Array]]
 PyTreeDef = jtu.PyTreeDef
 Out = TypeVar('Out')

@@ -102,8 +101,8 @@ class JaxException(Exception):
 @functools.total_ordering
 @dataclasses.dataclass(eq=True, frozen=True)
 class ErrorEffect(effects.Effect):
-  error_type: Type[JaxException]
-  shape_dtypes: Tuple[api.ShapeDtypeStruct, ...]
+  error_type: type[JaxException]
+  shape_dtypes: tuple[api.ShapeDtypeStruct, ...]

   def __lt__(self, other: 'ErrorEffect'):
     shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype))  # dtype is not comparable
@@ -195,7 +194,7 @@ class FailedCheckError(JaxException):

 @dataclasses.dataclass
 class BatchedError(JaxException):
-  error_mapping: Dict[Tuple[int, ...], JaxException]
+  error_mapping: dict[tuple[int, ...], JaxException]

   def __post_init__(self):
     traceback_info = list(self.error_mapping.values())[0].traceback_info
@@ -212,10 +211,10 @@ class BatchedError(JaxException):
 @jtu.register_pytree_node_class
 @dataclasses.dataclass(frozen=True)
 class Error:
-  _pred: Dict[ErrorEffect, Bool]
-  _code: Dict[ErrorEffect, Int]
-  _metadata: Dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
-  _payload: Dict[ErrorEffect, Payload]
+  _pred: dict[ErrorEffect, Bool]
+  _code: dict[ErrorEffect, Int]
+  _metadata: dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
+  _payload: dict[ErrorEffect, Payload]

   def get(self) -> Optional[str]:
     """Returns error message if error happened, None if no error happened."""
@@ -278,7 +277,7 @@ class Error:
     new_metadata = {**self._metadata, **metadata}
     return Error(new_errs, new_codes, new_metadata, new_payload)

-  def _add_placeholder_effects(self, effects: Set[ErrorEffect]):
+  def _add_placeholder_effects(self, effects: set[ErrorEffect]):
     """Fill out Error with `effects` and np.ones arrays of their payloads."""
     new_err = self._pred.copy()
     new_code = self._code.copy()
@@ -337,7 +336,7 @@ def _flatten_and_get_error_metadata_thunk(*invals):

 def default_checkify_rule(primitive: core.Primitive, error: Error,
                           enabled_errors, *invals: core.Value,
-                          **params: Any) -> Tuple[Error, Sequence[core.Value]]:
+                          **params: Any) -> tuple[Error, Sequence[core.Value]]:
   """Default rule for primitives in `checkify` interpreter."""
   if 'call_jaxpr' not in params:
     # Default non-HOP case: just call primitive and don't update error.
@@ -389,15 +388,15 @@ def get_shaped_aval(val):
   return core.raise_to_shaped(core.get_aval(val))

 def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
-                   error: Error, *args) -> Tuple[Error, List[core.Value]]:
+                   error: Error, *args) -> tuple[Error, list[core.Value]]:
   err_vals, err_tree = jtu.tree_flatten(error)
   return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts,
                              enabled_errors, err_tree, *err_vals, *args)

 def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
                         enabled_errors, err_tree: PyTreeDef,
-                        *args: core.Value) -> Tuple[Error, List[Any]]:
-  env: Dict[core.Var, Any] = {}
+                        *args: core.Value) -> tuple[Error, list[Any]]:
+  env: dict[core.Var, Any] = {}
   err_vals, in_args = split_list(args, [err_tree.num_leaves])
   error = jtu.tree_unflatten(err_tree, err_vals)

@@ -558,7 +557,7 @@ ad.primitive_jvps[check_p] = check_jvp_rule
 ## checkify rules

 ErrorCheckRule = Callable  # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
-error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
+error_checks: dict[core.Primitive, ErrorCheckRule] = {}


 def summary() -> str:
@@ -740,7 +739,7 @@ error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
 @weakref_lru_cache
 def jaxpr_to_checkify_jaxpr(
     jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
-    *flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
+    *flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
   checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
                                              jaxpr.consts, enabled_errors,
                                              err_tree)
@@ -823,7 +822,7 @@ error_checks[lax.scan_p] = scan_error_check
 def checkify_while_body_jaxpr(
     cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
     enabled_errors, error: Error,
-    c_consts_num: int) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
+    c_consts_num: int) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
   cond_f = core.jaxpr_as_fun(cond_jaxpr)
   body_f = core.jaxpr_as_fun(body_jaxpr)
   def new_body_f(*c_consts_and_vals):
@@ -1062,8 +1061,8 @@ all_checks = automatic_checks | user_checks


 def checkify(f: Callable[..., Out],
-             errors: FrozenSet[ErrorCategory] = user_checks
-             ) -> Callable[..., Tuple[Error, Out]]:
+             errors: frozenset[ErrorCategory] = user_checks
+             ) -> Callable[..., tuple[Error, Out]]:
   """Functionalize `check` calls in `fun`, and optionally add run-time error checks.

   Run-time errors are either user-added :func:`~check` assertions, or
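`checkify` is the public entry point whose `errors` set is re-typed here: it returns a function producing an `(Error, output)` pair rather than raising. A hedged usage sketch of that functionalized-error pattern:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  checkify.check(jnp.all(x > 0), "x must be positive")
  return jnp.log(x)

checked_f = checkify.checkify(f, errors=checkify.user_checks)
err, out = jax.jit(checked_f)(jnp.array([1.0, 2.0]))
err.throw()  # no-op here; raises if the check failed
```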
@@ -13,7 +13,7 @@
 # limitations under the License.

 import logging
-from typing import List, Optional, Type, Sequence, Tuple
+from typing import Optional, Sequence
 from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

 logger = logging.getLogger(__name__)
@@ -28,7 +28,7 @@ class ClusterEnv:
   :class:`ClusterEnv` subclasses are automatically detected when imported.
   """

-  _cluster_types: List[Type['ClusterEnv']] = []
+  _cluster_types: list[type['ClusterEnv']] = []

   def __init_subclass__(cls, **kwargs):
     super().__init_subclass__(**kwargs)
@@ -41,7 +41,7 @@ class ClusterEnv:
                          num_processes: Optional[int],
                          process_id: Optional[int],
                          local_device_ids: Optional[Sequence[int]]
-                         ) -> Tuple[Optional[str], Optional[int], Optional[int],
+                         ) -> tuple[Optional[str], Optional[int], Optional[int],
                                     Optional[Sequence[int]]]:
     if all(p is not None for p in (coordinator_address, num_processes,
                                    process_id, local_device_ids)):
@@ -18,7 +18,7 @@ import logging
 import os
 import re
 import sys
-from typing import Any, List, Optional
+from typing import Any, Optional
 import zlib

 import numpy as np
@@ -310,7 +310,7 @@ _xla_flags_to_exclude_from_cache_key = [
     "--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
 ]

-extra_flag_prefixes_to_include_in_cache_key: List[str] = []
+extra_flag_prefixes_to_include_in_cache_key: list[str] = []


 def _hash_xla_flags(hash_obj):
@@ -19,7 +19,7 @@ import logging
 import os
 import sys
 import threading
-from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
+from typing import Any, Callable, Hashable, NamedTuple, Iterator, Optional

 from jax._src import lib
 from jax._src.lib import jax_jit
@@ -246,7 +246,7 @@ class Config:
       default_value=True)

   def define_enum_state(
-      self, name: str, enum_values: List[str], default: Optional[str],
+      self, name: str, enum_values: list[str], default: Optional[str],
       help: str, update_global_hook: Optional[Callable[[str], None]] = None,
       update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
           = None):
jax/_src/core.py (137 changed lines)
@ -27,10 +27,9 @@ import operator
|
||||
from operator import attrgetter
|
||||
import threading
|
||||
import types
|
||||
from typing import (Any, Callable, ClassVar, DefaultDict, Dict, FrozenSet,
|
||||
Generator, Generic, Hashable, Iterable, Iterator, List,
|
||||
NamedTuple, Optional, Sequence, Set, Tuple, Type, TypeVar,
|
||||
Union, cast, overload)
|
||||
from typing import (Any, Callable, ClassVar, DefaultDict, Generator, Generic,
|
||||
Hashable, Iterable, Iterator, NamedTuple, Optional,
|
||||
Sequence, TypeVar, Union, cast, overload)
|
||||
import warnings
|
||||
from weakref import ref
|
||||
|
||||
@ -71,17 +70,17 @@ no_effects: Effects = effects.no_effects
|
||||
class JaxprDebugInfo(NamedTuple):
|
||||
traced_for: str # e.g. 'jit', 'scan', etc
|
||||
func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}'
|
||||
arg_names: Tuple[Optional[str], ...] # e.g. ('args[0]', ... )
|
||||
result_paths: Tuple[Optional[str], ...] # e.g. ('[0]', '[1]', ...)
|
||||
arg_names: tuple[Optional[str], ...] # e.g. ('args[0]', ... )
|
||||
result_paths: tuple[Optional[str], ...] # e.g. ('[0]', '[1]', ...)
|
||||
|
||||
class Jaxpr:
|
||||
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
|
||||
'_effects', '_debug_info']
|
||||
|
||||
_constvars: List[Var]
|
||||
_invars: List[Var]
|
||||
_outvars: List[Atom]
|
||||
_eqns: List[JaxprEqn]
|
||||
_constvars: list[Var]
|
||||
_invars: list[Var]
|
||||
_outvars: list[Atom]
|
||||
_eqns: list[JaxprEqn]
|
||||
_effects: Effects
|
||||
_debug_info: Optional[JaxprDebugInfo]
|
||||
|
||||
@ -171,7 +170,7 @@ class ClosedJaxpr:
|
||||
__slots__ = ['__weakref__', '_jaxpr', '_consts']
|
||||
|
||||
_jaxpr: Jaxpr
|
||||
_consts: List[Any]
|
||||
_consts: list[Any]
|
||||
|
||||
jaxpr = property(lambda self: self._jaxpr)
|
||||
consts = property(lambda self: self._consts)
|
||||
@ -230,10 +229,10 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
|
||||
|
||||
|
||||
class JaxprEqn(NamedTuple):
|
||||
invars: List[Atom]
|
||||
outvars: List[Var]
|
||||
invars: list[Atom]
|
||||
outvars: list[Var]
|
||||
primitive: Primitive
|
||||
params: Dict[str, Any]
|
||||
params: dict[str, Any]
|
||||
effects: Effects
|
||||
source_info: source_info_util.SourceInfo
|
||||
|
||||
@ -242,10 +241,10 @@ class JaxprEqn(NamedTuple):
|
||||
|
||||
def replace(
|
||||
self,
|
||||
invars: Optional[List[Atom]] = None,
|
||||
outvars: Optional[List[Var]] = None,
|
||||
invars: Optional[list[Atom]] = None,
|
||||
outvars: Optional[list[Var]] = None,
|
||||
primitive: Optional[Primitive] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
effects: Optional[Effects] = None,
|
||||
source_info: Optional[source_info_util.SourceInfo] = None,
|
||||
):
|
||||
@ -355,7 +354,7 @@ class Literal:
|
||||
else:
|
||||
return f'Literal(val={self.val})'
|
||||
|
||||
literalable_types: Set[type] = set()
|
||||
literalable_types: set[type] = set()
|
||||
|
||||
Atom = Union[Var, Literal]
|
||||
|
||||
@ -436,7 +435,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True):
|
||||
assert typecheck(v.aval, val), (v.aval, val)
|
||||
env[v] = val
|
||||
|
||||
env: Dict[Var, Any] = {}
|
||||
env: dict[Var, Any] = {}
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
lu = last_used(jaxpr)
|
||||
@ -836,8 +835,8 @@ class EvalTrace(Trace):
|
||||
|
||||
class MainTrace:
|
||||
level: int
|
||||
trace_type: Type[Trace]
|
||||
payload: Dict[str, Any]
|
||||
trace_type: type[Trace]
|
||||
payload: dict[str, Any]
|
||||
|
||||
def __init__(self, level, trace_type, **payload) -> None:
|
||||
self.level = level
|
||||
@ -861,7 +860,7 @@ class MainTrace:
|
||||
|
||||
class TraceStack:
|
||||
# See comments in https://github.com/google/jax/pull/3370
|
||||
stack: List[MainTrace]
|
||||
stack: list[MainTrace]
|
||||
dynamic: MainTrace
|
||||
|
||||
def __init__(self):
|
||||
@ -912,8 +911,8 @@ no_axis_name = object()
|
||||
|
||||
class TraceState:
|
||||
trace_stack: TraceStack
|
||||
substack: List[Sublevel]
|
||||
axis_env: List[AxisEnvFrame]
|
||||
substack: list[Sublevel]
|
||||
axis_env: list[AxisEnvFrame]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.trace_stack = TraceStack()
|
||||
@ -995,7 +994,7 @@ the following:
|
||||
"""
|
||||
|
||||
def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
|
||||
) -> List[Tracer]:
|
||||
) -> list[Tracer]:
|
||||
"""Find the leaked tracers holding a reference to the MainTrace or SubLevel.
|
||||
|
||||
It's possible there's none! eg. there's some cases where JAX itself holds a
|
||||
@ -1012,14 +1011,14 @@ def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
|
||||
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
|
||||
return tracers
|
||||
|
||||
def leaked_tracer_error(name: str, t, tracers: List[Tracer]) -> Exception:
|
||||
def leaked_tracer_error(name: str, t, tracers: list[Tracer]) -> Exception:
|
||||
assert tracers
|
||||
why = partial(_why_alive, {id(tracers)})
|
||||
msgs = '\n\n'.join(f'{tracers[i]}{tracers[i]._origin_msg()}{why(tracers[i])}'
|
||||
for i in range(len(tracers)))
|
||||
return Exception(f'Leaked {name} {t}. Leaked tracer(s):\n\n{msgs}\n')
|
||||
|
||||
def _why_alive(ignore_ids: Set[int], x: Any) -> str:
|
||||
def _why_alive(ignore_ids: set[int], x: Any) -> str:
|
||||
parents = lambda x: [r for r in gc.get_referrers(x) if id(r) not in ignore_ids]
|
||||
child, lines, seen = x, [], set()
|
||||
while (id(child) not in seen and type(child) is not types.ModuleType
|
||||
@ -1078,7 +1077,7 @@ def _why_alive_container_info(container, obj_id) -> str:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def new_main(trace_type: Type[Trace],
|
||||
def new_main(trace_type: type[Trace],
|
||||
dynamic: bool = False,
|
||||
**payload) -> Generator[MainTrace, None, None]:
|
||||
# See comments in https://github.com/google/jax/pull/3370
|
||||
@ -1106,7 +1105,7 @@ def new_main(trace_type: Type[Trace],
|
||||
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
|
||||
|
||||
@contextmanager
|
||||
def new_base_main(trace_type: Type[Trace],
|
||||
def new_base_main(trace_type: type[Trace],
|
||||
**payload) -> Generator[MainTrace, None, None]:
|
||||
# See comments in https://github.com/google/jax/pull/3370
|
||||
stack = thread_local_state.trace_state.trace_stack
|
||||
@ -1234,7 +1233,7 @@ def get_referent(x: Any) -> Any:
|
||||
def same_referent(x: Any, y: Any) -> bool:
|
||||
return get_referent(x) is get_referent(y)
|
||||
|
||||
def dedup_referents(itr: Iterable[Any]) -> List[Any]:
|
||||
def dedup_referents(itr: Iterable[Any]) -> list[Any]:
|
||||
return list({HashableWrapper(get_referent(x)):x for x in itr}.values())
|
||||
|
||||
def definitely_equal(x, y):
|
||||
@ -1247,7 +1246,7 @@ def definitely_equal(x, y):
|
||||
# -------------------- abstract values --------------------
|
||||
|
||||
class AbstractValue:
|
||||
__slots__: List[str] = []
|
||||
__slots__: list[str] = []
|
||||
|
||||
def at_least_vspace(self):
|
||||
raise NotImplementedError("must override")
|
||||
@ -1292,13 +1291,13 @@ class OutDBIdx:
|
||||
# a sequence of pairs where the first element of each pair is an AbstractValue
|
||||
# (possibly containing DBIdx instances in its shape) and the second is a boolean
|
||||
# indicating whether that argument is explicit (i.e. passed to the callable).
|
||||
InputType = Tuple[Tuple[AbstractValue, bool], ...] # DBIdx in shapes
|
||||
InputType = tuple[tuple[AbstractValue, bool], ...] # DBIdx in shapes
|
||||
|
||||
# For annotating jaxpr output types, we use a sequence of pairs where the first
|
||||
# element of each pair is an AbstractValue (possibly containing InDBIdx and/or
|
||||
# OutDBIdx instances in its shape) and the second is a boolean indicating
|
||||
# whether that argument is explicit (i.e. returned by the callable).
|
||||
OutputType = Tuple[Tuple[AbstractValue, bool], ...] # InDBIdx / OutDBIdx shapes
|
||||
OutputType = tuple[tuple[AbstractValue, bool], ...] # InDBIdx / OutDBIdx shapes
|
||||
|
||||
|
||||
def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
|
||||
@ -1669,7 +1668,7 @@ def primal_dtype_to_tangent_dtype(primal_dtype):
# it's kind of convenient!
class DShapedArray(UnshapedArray):
__slots__ = ['shape']
shape: Tuple[AxisSize, ...] # noqa: F821
shape: tuple[AxisSize, ...] # noqa: F821
array_abstraction_level: int = 3

def __init__(self, shape, dtype, weak_type=False):
@ -1731,7 +1730,7 @@ class DConcreteArray(DShapedArray):
self.val = val


pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}


class DArray:
@ -1813,7 +1812,7 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
if handler: return handler(aval, weak_type)
raise TypeError(type(aval))

raise_to_shaped_mappings : Dict[type, Callable] = {
raise_to_shaped_mappings : dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
@ -1907,7 +1906,7 @@ class DimensionHandler:
return d

_dimension_handler_int = DimensionHandler()
_SPECIAL_DIMENSION_HANDLERS: Dict[type, DimensionHandler] = {}
_SPECIAL_DIMENSION_HANDLERS: dict[type, DimensionHandler] = {}
DArrayDimHandler = type('DArrayDimHandler', (DimensionHandler,), {})()

def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
@ -1917,7 +1916,7 @@ def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
return DArrayDimHandler
return _SPECIAL_DIMENSION_HANDLERS.get(type(dim))

def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple[DimSize, ...]]:
def _dim_handler_and_canonical(*dlist: DimSize) -> tuple[DimensionHandler, tuple[DimSize, ...]]:
"""Finds the handler for the given dimensions; also returns the canonical dimensions.

A dimension is canonical if it is a Python integer scalar, or has a type
@ -2081,7 +2080,7 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
else:
raise type_error

def canonicalize_shape(shape: Shape, context: str="") -> Tuple[Any, ...]:
def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
"""Canonicalizes and checks for errors in a user-provided shape value.

Args:
@ -2321,7 +2320,7 @@ closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)


outfeed_primitives: Set[Primitive] = set()
outfeed_primitives: set[Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
@ -2336,7 +2335,7 @@ def _param_uses_outfeed(param):
return True
return False

def primitive_uses_outfeed(prim: Primitive, params: Dict) -> bool:
def primitive_uses_outfeed(prim: Primitive, params: dict) -> bool:
if prim in outfeed_primitives:
return True
for param in params.values():
@ -2476,8 +2475,8 @@ def _unmap_dshaped_array(
else:
raise TypeError(axis)

AvalMapHandlerPair = Tuple[Callable, Callable]
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
AvalMapHandlerPair = tuple[Callable, Callable]
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
@ -2501,7 +2500,7 @@ def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
if f.name is not no_axis_name))

@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]], tag: Any = None):
def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None):
frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
ts = thread_local_state.trace_state
ts.axis_env.extend(frames)
@ -2573,8 +2572,8 @@ def axis_frame(axis_name: AxisName, main_trace: Optional[MainTrace] = None
f'by pmap) are available to collective operations: {named_axes}')


ParamDict = Dict[str, Any]
AxisSubst = Callable[[AxisName], Tuple[AxisName, ...]]
ParamDict = dict[str, Any]
AxisSubst = Callable[[AxisName], tuple[AxisName, ...]]

class NameGatheringSubst:
def __init__(self):
@ -2583,7 +2582,7 @@ class NameGatheringSubst:
self.axis_names.add(axis_name)
return (axis_name,)

def used_axis_names(primitive: Primitive, params: ParamDict) -> Set[AxisName]:
def used_axis_names(primitive: Primitive, params: ParamDict) -> set[AxisName]:
subst = NameGatheringSubst()
subst_axis_names(primitive, params, subst)
return subst.axis_names
@ -2612,7 +2611,7 @@ class DuplicateAxisNameError(Exception):
self.var = var
self.eqn = None

def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> Var:
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var:
# Var identity is load-bearing, so we can't have duplicates!
if isinstance(v, DropVar): return v
assert v not in var_map
@ -2627,8 +2626,8 @@ def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> V
var_map[v] = new_v
return new_v

def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: Dict[Var, Var]) -> JaxprEqn:
invars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]
def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn:
invars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]
try:
outvars = [subst_axis_names_var(v, subst, var_map) for v in eqn.outvars]
except DuplicateAxisNameError as e:
@ -2642,11 +2641,11 @@ def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst
if isinstance(jaxpr, ClosedJaxpr):
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
var_map: Dict[Var, Var] = {}
var_map: dict[Var, Var] = {}
invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars] # type: ignore[union-attr]
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr]
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] # type: ignore[union-attr]
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr]
outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr]
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
if consts is not None:
return ClosedJaxpr(new_jaxpr, consts)
@ -2668,11 +2667,11 @@ def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects):
return _replace_jaxpr_effects(jaxpr, frozenset(effects))

@weakref_lru_cache
def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: FrozenSet[Effect]):
def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: frozenset[Effect]):
return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(effects=set(effects)))


axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
axis_substitution_rules: dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}

# ------------------- AxisPrimitive -------------------
# Primitives that store axis names in params and want those axis names to
@ -2722,7 +2721,7 @@ def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:

class JaxprTypeError(TypeError): pass

custom_typechecks: Dict[Primitive, Callable] = {}
custom_typechecks: dict[Primitive, Callable] = {}

def _check_closed_call(_, *in_atoms, call_jaxpr):
in_avals = [x.aval for x in in_atoms]
@ -2765,11 +2764,11 @@ def check_jaxpr(jaxpr: Jaxpr):
raise JaxprTypeError(msg) from None

def _check_jaxpr(
ctx_factory: Callable[[], Tuple[JaxprPpContext, JaxprPpSettings]],
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
jaxpr: Jaxpr
) -> None:
# Use set of variables to types to check that variables are in scope.
env: Set[Var] = set()
env: set[Var] = set()

def read(x: Atom) -> Atom:
# Check the type annotation is itself well-typed.
@ -2874,8 +2873,8 @@ def _check_jaxpr(
map(read, jaxpr.outvars)

def check_type(
ctx_factory: Callable[[], Tuple[JaxprPpContext, JaxprPpSettings]],
env: Set[Var],
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
env: set[Var],
ty: AbstractValue,
) -> None:
if isinstance(ty, DShapedArray):
@ -2912,7 +2911,7 @@ def substitute_vars_in_output_ty(
out_type: Sequence[AbstractValue], # shapes may contain InDBIdx / OutDBIdx
in_atoms: Sequence[Atom],
out_binders: Sequence[Var],
) -> List[AbstractValue]: # shapes may contain Vars
) -> list[AbstractValue]: # shapes may contain Vars
in_atoms = [x.val if type(x) is Literal else x for x in in_atoms]
result = []
for aval in out_type:
@ -2945,7 +2944,7 @@ def _check_call(ctx_factory, prim, in_atoms, params):
f"{len(call_jaxpr.invars)} inputs")

# Check `call_jaxpr` can be applied to in_atoms.
env: Dict[Var, Atom] = {}
env: dict[Var, Atom] = {}
def substitute(aval: AbstractValue):
if isinstance(aval, DShapedArray):
aval = aval.update(shape=tuple([env.get(d, d) for d in aval.shape])) # type: ignore
@ -2961,8 +2960,8 @@ def _check_call(ctx_factory, prim, in_atoms, params):
_check_jaxpr(ctx_factory, call_jaxpr)

invars, outvars = call_jaxpr.invars, call_jaxpr.outvars
in_map : Dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: Dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
if type(x) is Var}
out_avals = [x.aval for x in call_jaxpr.outvars]
out_type = [a.update(shape=tuple(in_map.get(d, out_map.get(d))
@ -3094,7 +3093,7 @@ def _pp_eqn(eqn, context, settings) -> pp.Doc:
pp.text(" ") + pp_vars(eqn.invars, context)]
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext, JaxprPpSettings], pp.Doc]
pp_eqn_rules: Dict[Primitive, CustomPpEqnRule] = {}
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}

def pp_eqns(eqns, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
return pp.join(
@ -3109,7 +3108,7 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool:
return False
return True

def str_eqn_compact(primitive_name: str, params: Dict) -> str:
def str_eqn_compact(primitive_name: str, params: dict) -> str:
"Compact equation to string conversion used in HLO metadata."
kvs = " ".join(f"{k}={v}" for k, v in params.items()
if _compact_eqn_should_include(k, v))
@ -3187,9 +3186,9 @@ def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:

# ------------------- Jaxpr util -------------------

def last_used(jaxpr: Jaxpr) -> Dict[Var, Optional[JaxprEqn]]:
def last_used(jaxpr: Jaxpr) -> dict[Var, Optional[JaxprEqn]]:
"""Returns a mapping from every var in jaxpr to what equation uses it last."""
last_used: Dict[Var, Optional[JaxprEqn]] = {
last_used: dict[Var, Optional[JaxprEqn]] = {
v: None for v in jaxpr.outvars if not isinstance(v, Literal)}
for eqn in reversed(jaxpr.eqns):
for v in eqn.invars:
@ -3197,8 +3196,8 @@ def last_used(jaxpr: Jaxpr) -> Dict[Var, Optional[JaxprEqn]]:
last_used[v] = eqn
return last_used

def clean_up_dead_vars(eqn: JaxprEqn, env: Dict[Var, Any],
last_used: Dict[Var, Optional[JaxprEqn]]):
def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any],
last_used: dict[Var, Optional[JaxprEqn]]):
"""Remove all eqn.invars from env if eqn is the last time they were used."""
for v in set(v for v in eqn.invars if not isinstance(v, Literal)):
if last_used[v] is eqn:
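A hedged sketch of how `last_used` and `clean_up_dead_vars` work together: an interpreter in the style of `eval_jaxpr` can drop intermediate values from its environment as soon as their last consumer has run. The sketch handles simple primitives only and ignores call primitives and `get_bind_params`:

```python
from jax._src import core

def eval_jaxpr_with_cleanup(jaxpr: core.Jaxpr, consts, *args):
  env: dict = {}
  def read(a):
    return a.val if isinstance(a, core.Literal) else env[a]

  lu = core.last_used(jaxpr)
  env.update(zip(jaxpr.constvars, consts))
  env.update(zip(jaxpr.invars, args))
  for eqn in jaxpr.eqns:
    in_vals = [read(x) for x in eqn.invars]
    out = eqn.primitive.bind(*in_vals, **eqn.params)
    out = out if eqn.primitive.multiple_results else [out]
    env.update(zip(eqn.outvars, out))
    core.clean_up_dead_vars(eqn, env, lu)  # drop values used for the last time
  return [read(v) for v in jaxpr.outvars]
```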
@ -15,8 +15,7 @@
import dataclasses
from functools import update_wrapper, reduce, partial
import inspect
from typing import (
Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar)
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar

from jax._src import core
from jax._src import custom_api_util
@ -137,13 +136,13 @@ class custom_jvp(Generic[ReturnValue]):
.. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
"""
fun: Callable[..., ReturnValue]
nondiff_argnums: Tuple[int, ...]
jvp: Optional[Callable[..., Tuple[ReturnValue, ReturnValue]]] = None
nondiff_argnums: tuple[int, ...]
jvp: Optional[Callable[..., tuple[ReturnValue, ReturnValue]]] = None
symbolic_zeros: bool = False

def __init__(self,
fun: Callable[..., ReturnValue],
nondiff_argnums: Tuple[int, ...] = (),
nondiff_argnums: tuple[int, ...] = (),
):
update_wrapper(self, fun)
self.fun = fun
@ -152,9 +151,9 @@ class custom_jvp(Generic[ReturnValue]):
__getattr__ = custom_api_util.forward_attr

def defjvp(self,
jvp: Callable[..., Tuple[ReturnValue, ReturnValue]],
jvp: Callable[..., tuple[ReturnValue, ReturnValue]],
symbolic_zeros: bool = False,
) -> Callable[..., Tuple[ReturnValue, ReturnValue]]:
) -> Callable[..., tuple[ReturnValue, ReturnValue]]:
"""Define a custom JVP rule for the function represented by this instance.

Args:
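The canonical `custom_jvp` usage from the JAX documentation, which these signature changes leave untouched:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

print(jax.grad(f)(2., 3.))  # 3 * cos(2), computed via the rule above
```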
@ -491,19 +490,19 @@ class custom_vjp(Generic[ReturnValue]):

def __init__(self,
fun: Callable[..., ReturnValue],
nondiff_argnums: Tuple[int, ...] = ()):
nondiff_argnums: tuple[int, ...] = ()):
update_wrapper(self, fun)
self.fun = fun
self.nondiff_argnums = nondiff_argnums
self.fwd: Optional[Callable[..., Tuple[ReturnValue, Any]]] = None
self.bwd: Optional[Callable[..., Tuple[Any, ...]]] = None
self.fwd: Optional[Callable[..., tuple[ReturnValue, Any]]] = None
self.bwd: Optional[Callable[..., tuple[Any, ...]]] = None
self.symbolic_zeros = False

__getattr__ = custom_api_util.forward_attr

def defvjp(self,
fwd: Callable[..., Tuple[ReturnValue, Any]],
bwd: Callable[..., Tuple[Any, ...]],
fwd: Callable[..., tuple[ReturnValue, Any]],
bwd: Callable[..., tuple[Any, ...]],
symbolic_zeros: bool = False,
) -> None:
"""Define a custom VJP rule for the function represented by this instance.
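Likewise for `custom_vjp`: `fwd` returns the output plus residuals, and `bwd` maps residuals and an output cotangent to one cotangent per primal input:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)

def f_fwd(x):
  return f(x), jnp.cos(x)   # (output, residuals saved for the backward pass)

def f_bwd(cos_x, g):
  return (cos_x * g,)       # a tuple with one cotangent per primal input

f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(1.0))     # cos(1.0)
```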
@ -831,7 +830,7 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(

def _custom_vjp_call_jaxpr_jvp(
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
_, args = split_list(primals, [num_consts])
consts_dot, args_dot = split_list(tangents, [num_consts])
@ -857,7 +856,7 @@ ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
def _custom_vjp_call_jaxpr_vmap(
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]
@ -1006,7 +1005,7 @@ class Residuals:
return cls(jaxpr, in_tree, out_tree, consts)


def closure_convert(fun: Callable, *example_args) -> Tuple[Callable, List[Any]]:
def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
"""Closure conversion utility, for use with higher-order custom derivatives.

To define custom derivatives such as with ``jax.custom_vjp(f)``, the target
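A hedged usage sketch of the public `jax.closure_convert` wrapper (the names below are invented for the example): array values captured by the closure are hoisted out and handed back as explicit residual arguments.

```python
import jax
import jax.numpy as jnp

c = jnp.array(2.0)                      # captured by the closure below

def objective(x):
  return jnp.sum((x - c) ** 2)

x0 = jnp.zeros(3)
converted, consts = jax.closure_convert(objective, x0)
# `converted` takes the hoisted captures as explicit trailing arguments:
assert jnp.allclose(converted(x0, *consts), objective(x0))
```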
@ -13,7 +13,7 @@
# limitations under the License.

import functools
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional

from jax._src import ad_util
from jax._src import api_util
@ -51,7 +51,7 @@ class StoreEqual(lu.Store):

@util.curry
def transformation_with_aux(
gen, fun: lu.WrappedFun, *gen_static_args) -> Tuple[lu.WrappedFun, Any]:
gen, fun: lu.WrappedFun, *gen_static_args) -> tuple[lu.WrappedFun, Any]:
out_store = StoreEqual()
out_thunk = lambda: out_store.val
return fun.wrap(gen, gen_static_args, out_store), out_thunk

@ -18,7 +18,7 @@ import pprint
import sys
import traceback

from typing import Any, IO, List, Optional
from typing import Any, IO, Optional

from jax._src.debugger import core as debugger_core

@ -28,7 +28,7 @@ class CliDebugger(cmd.Cmd):
"""A text-based debugger."""
prompt = '(jdb) '

def __init__(self, frames: List[DebuggerFrame], thread_id,
def __init__(self, frames: list[DebuggerFrame], thread_id,
stdin: Optional[IO[str]] = None, stdout: Optional[IO[str]] = None,
completekey: str = "tab"):
super().__init__(stdin=stdin, stdout=stdout, completekey=completekey)
@ -163,7 +163,7 @@ class CliDebugger(cmd.Cmd):
except KeyboardInterrupt:
print('--KeyboardInterrupt--', file=sys.stdout)

def run_debugger(frames: List[DebuggerFrame], thread_id: Optional[int],
def run_debugger(frames: list[DebuggerFrame], thread_id: Optional[int],
**kwargs: Any):
CliDebugger(frames, thread_id, **kwargs).run()
debugger_core.register_debugger("cli", run_debugger, -1)

@ -18,8 +18,6 @@ import html
import inspect
import traceback

from typing import List

import uuid

from jax._src.debugger import colab_lib
@ -42,7 +40,7 @@ except ImportError:
class CodeViewer(colab_lib.DynamicDOMElement):
"""A mutable DOM element that displays code as HTML."""

def __init__(self, code_: str, highlights: List[int], linenostart: int = 1):
def __init__(self, code_: str, highlights: list[int], linenostart: int = 1):
self._code = code_
self._highlights = highlights
self._view = colab_lib.dynamic(colab_lib.div())
@ -226,7 +224,7 @@ class ColabDebugger(cli_debugger.CliDebugger):
"""A JAX debugger for a Colab environment."""

def __init__(self,
frames: List[debugger_core.DebuggerFrame],
frames: list[debugger_core.DebuggerFrame],
thread_id: int):
super().__init__(frames, thread_id)
self._debugger_view = DebuggerView(self.current_frame())

@ -20,7 +20,7 @@ import functools
import sys
import uuid

from typing import Any, Dict, List, Union
from typing import Any, Union

IS_COLAB_ENABLED = "google.colab" in sys.modules
if IS_COLAB_ENABLED:
@ -106,8 +106,8 @@ class StaticDOMElement(DOMElement):
"""An immutable DOM element."""
_uuid: str = dataclasses.field(init=False)
name: str
children: List[Union[str, DOMElement]]
attrs: Dict[str, str]
children: list[Union[str, DOMElement]]
attrs: dict[str, str]

def html(self):
attr_str = ""
@ -137,7 +137,7 @@ class StaticDOMElement(DOMElement):
return dataclasses.replace(self, **kwargs)


def _style_dict_to_str(style_dict: Dict[str, Any]) -> str:
def _style_dict_to_str(style_dict: dict[str, Any]) -> str:
return " ".join([f"{k}: {v};" for k, v in style_dict.items()])


@ -16,7 +16,7 @@ from __future__ import annotations
import dataclasses
import inspect
import threading
from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
from typing import Any, Hashable, Optional, Protocol

import numpy as np

@ -72,10 +72,10 @@ def _safe_flatten_dict(dct: dict[Any, Any]
class DebuggerFrame:
"""Encapsulates Python frame information."""
filename: str
locals: Dict[str, Any]
globals: Dict[str, Any]
locals: dict[str, Any]
globals: dict[str, Any]
code_context: str
source: List[str]
source: list[str]
lineno: int
offset: Optional[int]

@ -131,10 +131,10 @@ class DebuggerFrame:

class Debugger(Protocol):

def __call__(self, frames: List[DebuggerFrame], thread_id: Optional[int],
def __call__(self, frames: list[DebuggerFrame], thread_id: Optional[int],
**kwargs: Any) -> None:
...
_debugger_registry: Dict[str, Tuple[int, Debugger]] = {}
_debugger_registry: dict[str, tuple[int, Debugger]] = {}


def get_debugger(backend: Optional[str] = None) -> Debugger:

@ -15,12 +15,12 @@ from __future__ import annotations
import os
import weakref

from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

from jax._src.debugger import cli_debugger
from jax._src.debugger import core as debugger_core

web_pdb_version: Optional[Tuple[int, ...]] = None
web_pdb_version: Optional[tuple[int, ...]] = None
try:
import web_pdb # pytype: disable=import-error
web_pdb_version = tuple(map(int, web_pdb.__version__.split(".")))
@ -29,14 +29,14 @@ except:
WEB_PDB_ENABLED = False


_web_consoles: Dict[Tuple[str, int], web_pdb.WebConsole] = {}
_web_consoles: dict[tuple[str, int], web_pdb.WebConsole] = {}

class WebDebugger(cli_debugger.CliDebugger):
"""A web-based debugger."""
prompt = '(jdb) '
use_rawinput: bool = False

def __init__(self, frames: List[debugger_core.DebuggerFrame], thread_id,
def __init__(self, frames: list[debugger_core.DebuggerFrame], thread_id,
completekey: str = "tab", host: str = "", port: int = 5555):
if (host, port) not in _web_consoles:
_web_consoles[host, port] = web_pdb.WebConsole(host, port, self)
@ -87,7 +87,7 @@ class WebDebugger(cli_debugger.CliDebugger):
def run(self):
return self.cmdloop()

def run_debugger(frames: List[debugger_core.DebuggerFrame],
def run_debugger(frames: list[debugger_core.DebuggerFrame],
thread_id: Optional[int], **kwargs: Any):
WebDebugger(frames, thread_id, **kwargs).run()

@ -16,7 +16,7 @@
import functools
import string
import sys
from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Union
import weakref

import numpy as np
@ -413,8 +413,8 @@ def _raise_to_slice(slc: Union[slice, int]):
return slice(slc, slc + 1)
return slc

Color = Union[Tuple[float, float, float], str]
ColorMap = Callable[[float], Tuple[float, float, float, float]]
Color = Union[tuple[float, float, float], str]
ColorMap = Callable[[float], tuple[float, float, float, float]]

def _canonicalize_color(color: Color) -> str:
if isinstance(color, str):
@ -464,9 +464,9 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
device_kind = next(iter(sharding.device_set)).platform.upper()

device_indices_map = sharding.devices_indices_map(tuple(shape))
slices: Dict[Tuple[int, ...], Set[int]] = {}
heights: Dict[Tuple[int, ...], Optional[float]] = {}
widths: Dict[Tuple[int, ...], float] = {}
slices: dict[tuple[int, ...], set[int]] = {}
heights: dict[tuple[int, ...], Optional[float]] = {}
widths: dict[tuple[int, ...], float] = {}

for i, (dev, slcs) in enumerate(device_indices_map.items()):
assert slcs is not None
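`visualize_sharding` is the engine behind the public `jax.debug.visualize_array_sharding` helper; a hedged usage sketch:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8.0)
# Prints a rich-rendered table of device tiles; on a single device this is
# one tile, labeled with that device, covering all of x.
jax.debug.visualize_array_sharding(x)
```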
@ -21,8 +21,8 @@ import dataclasses
from functools import partial
import itertools
import time
from typing import (Any, Callable, Dict, Iterator, Optional,
Set, Tuple, List, Union, NamedTuple, Sequence)
from typing import (Any, Callable, Iterator, Optional, Union, NamedTuple,
Sequence)
import logging
import os
import re
@ -94,7 +94,7 @@ _on_exit = False

### op-by-op execution

ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
ArgSpec = tuple[core.AbstractValue, Optional[Device]]

def arg_spec(x: Any) -> ArgSpec:
from jax._src import pjit
@ -150,9 +150,9 @@ def simple_impl(prim):
RuntimeToken = Any

class RuntimeTokenSet(threading.local):
tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
output_tokens: Dict[Device, RuntimeToken]
output_runtime_tokens: Dict[Device, RuntimeToken]
tokens: dict[core.Effect, tuple[RuntimeToken, Device]]
output_tokens: dict[Device, RuntimeToken]
output_runtime_tokens: dict[Device, RuntimeToken]

def __init__(self):
self.tokens = {}
@ -324,7 +324,7 @@ class SourceInfo(NamedTuple):


def jaxpr_shardings(
jaxpr) -> Iterator[Tuple[XLACompatibleSharding, SourceInfo]]:
jaxpr) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
from jax._src import pjit
from jax.experimental import shard_map

@ -368,7 +368,7 @@ def _is_bint_axis_size(d: core.AxisSize) -> bool:
return False

def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
jaxpr: core.Jaxpr) -> tuple[core.Jaxpr, set[int], set[int]]:
used_outputs = [True] * len(jaxpr.outvars)
new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
kept_const_idx = {i for i, b in enumerate(used_consts) if b}
@ -534,7 +534,7 @@ def _cache_write(cache_key: str,
compile_time_secs: float,
module_name: str,
backend: Backend, executable: xc.LoadedExecutable,
host_callbacks: List[Any]):
host_callbacks: list[Any]):
"""Writes `serialized_computation` to the persistent compilation cache."""
if host_callbacks:
logger.info(

@ -22,8 +22,7 @@

import builtins
import functools
from typing import (cast, overload, Any, Dict, List, Literal, Optional, Set,
Tuple, Type, Union)
from typing import cast, overload, Any, Literal, Optional, Union
import warnings

import ml_dtypes
@ -38,23 +37,23 @@ traceback_util.register_exclusion(__file__)
FLAGS = flags.FLAGS

# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
opaque_dtypes: Set[OpaqueDType] = set()
opaque_dtypes: set[OpaqueDType] = set()

def is_opaque_dtype(dtype: Any) -> bool:
return type(dtype) in opaque_dtypes

# fp8 support
# TODO(jakevdp): remove this if statement when minimum ml_dtypes version > 0.1
float8_e4m3b11fnuz: Optional[Type[np.generic]] = None
float8_e4m3fn: Type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e5m2: Type[np.generic] = ml_dtypes.float8_e5m2
float8_e4m3b11fnuz: Optional[type[np.generic]] = None
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2

_float8_e4m3b11fnuz_dtype: Optional[np.dtype] = None
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)

# bfloat16 support
bfloat16: Type[np.generic] = ml_dtypes.bfloat16
bfloat16: type[np.generic] = ml_dtypes.bfloat16
_bfloat16_dtype: np.dtype = np.dtype(bfloat16)

_custom_float_scalar_types = [
@ -74,9 +73,9 @@ if hasattr(ml_dtypes, "float8_e4m3b11fnuz"):
_custom_float_scalar_types.insert(0, float8_e4m3b11fnuz) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e4m3b11fnuz_dtype) # type: ignore[arg-type]

int4: Optional[Type[np.generic]] = None
int4: Optional[type[np.generic]] = None
_int4_dtype: Optional[np.dtype] = None
uint4: Optional[Type[np.generic]] = None
uint4: Optional[type[np.generic]] = None
_uint4_dtype: Optional[np.dtype] = None

if hasattr(ml_dtypes, "int4"):
@ -91,12 +90,12 @@ int_: type = np.int32 if config.jax_default_dtype_bits == '32' else np.int64
uint: type = np.uint32 if config.jax_default_dtype_bits == '32' else np.uint64
float_: type = np.float32 if config.jax_default_dtype_bits == '32' else np.float64
complex_: type = np.complex64 if config.jax_default_dtype_bits == '32' else np.complex128
_default_types: Dict[str, type] = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
_default_types: dict[str, type] = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}

# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0: np.dtype = np.dtype([('float0', np.void, 0)])

_dtype_to_32bit_dtype: Dict[DType, DType] = {
_dtype_to_32bit_dtype: dict[DType, DType] = {
np.dtype('int64'): np.dtype('int32'),
np.dtype('uint64'): np.dtype('uint32'),
np.dtype('float64'): np.dtype('float32'),
@ -106,7 +105,7 @@ _dtype_to_32bit_dtype: Dict[DType, DType] = {
# Note: we promote narrow types to float32 here for backward compatibility
# with earlier approaches. We might consider revisiting this, or perhaps
# tying the logic more closely to the type promotion lattice.
_dtype_to_inexact: Dict[DType, DType] = {
_dtype_to_inexact: dict[DType, DType] = {
np.dtype(k): np.dtype(v) for k, v in [
('bool', 'float32'),
('uint8', 'float32'), ('int8', 'float32'),
@ -163,7 +162,7 @@ def canonicalize_dtype(dtype: Any, allow_opaque_dtype: bool = False) -> Union[DT
return _canonicalize_dtype(config.x64_enabled, allow_opaque_dtype, dtype)

# Default dtypes corresponding to Python scalars.
python_scalar_dtypes : Dict[type, DType] = {
python_scalar_dtypes : dict[type, DType] = {
bool: np.dtype('bool'),
int: np.dtype('int64'),
float: np.dtype('float64'),
@ -304,9 +303,9 @@ issubsctype = np.issubsctype
JAXType = Union[type, DType]

# Enumeration of all valid JAX types in order.
_weak_types: List[JAXType] = [int, float, complex]
_bool_types: List[JAXType] = [np.dtype(bool)]
_int_types: List[JAXType]
_weak_types: list[JAXType] = [int, float, complex]
_bool_types: list[JAXType] = [np.dtype(bool)]
_int_types: list[JAXType]
if int4 is not None:
_int_types = [
np.dtype(uint4),
@ -332,13 +331,13 @@ else:
np.dtype('int64'),
]

_float_types: List[JAXType] = [
_float_types: list[JAXType] = [
*_custom_float_dtypes,
np.dtype('float16'),
np.dtype('float32'),
np.dtype('float64'),
]
_complex_types: List[JAXType] = [
_complex_types: list[JAXType] = [
np.dtype('complex64'),
np.dtype('complex128'),
]
@ -355,11 +354,11 @@ def _jax_type(dtype: DType, weak_type: bool) -> JAXType:
return type(dtype.type(0).item())
return dtype

def _dtype_and_weaktype(value: Any) -> Tuple[DType, bool]:
def _dtype_and_weaktype(value: Any) -> tuple[DType, bool]:
"""Return a (dtype, weak_type) tuple for the given input."""
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)

def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> Dict[JAXType, List[JAXType]]:
def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, list[JAXType]]:
"""
Return the type promotion lattice in the form of a DAG.
This DAG maps each type to its immediately higher type on the lattice.
@ -373,7 +372,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> Dict[JAXType, Lis
c4, c8 = _complex_types
i_, f_, c_ = _weak_types
if jax_numpy_dtype_promotion == 'standard':
out: Dict[JAXType, List[JAXType]]
out: dict[JAXType, list[JAXType]]
out = {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
@ -400,7 +399,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> Dict[JAXType, Lis
raise ValueError(
f"Unexpected value of jax_numpy_dtype_promotion={jax_numpy_dtype_promotion!r}")

def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> Dict[JAXType, Set[JAXType]]:
def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> dict[JAXType, set[JAXType]]:
lattice = _type_promotion_lattice(jax_numpy_dtype_promotion)
upper_bounds = {node: {node} for node in lattice}
for n in lattice:
@ -413,7 +412,7 @@ def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> Dict[JAXType,
upper_bounds[n] |= new_upper_bounds
return upper_bounds

_lattice_upper_bounds: Dict[str, Dict[JAXType, Set[JAXType]]] = {
_lattice_upper_bounds: dict[str, dict[JAXType, set[JAXType]]] = {
'standard': _make_lattice_upper_bounds('standard'),
'strict': _make_lattice_upper_bounds('strict'),
}
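The lattice is observable through the public promotion helpers; for instance, in 'standard' mode integers reach the float branch through the weak-float node, so a concrete `float16` wins over `int32` rather than being widened:

```python
import jax.numpy as jnp

print(jnp.promote_types(jnp.int32, jnp.float16))  # float16
print(jnp.result_type(3, jnp.bfloat16))           # bfloat16: weak int defers
```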
@ -533,7 +532,7 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType:
"type. Only arrays of numeric types are supported by JAX.")
return canonicalize_dtype(dt, allow_opaque_dtype=True) if canonicalize else dt

def _lattice_result_type(*args: Any) -> Tuple[DType, bool]:
def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
if len(dtypes) == 1:
out_dtype = dtypes[0]
@ -559,15 +558,15 @@ def _lattice_result_type(*args: Any) -> Tuple[DType, bool]:
return out_dtype, (out_dtype != bool_) and out_weak_type

@overload
def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> Tuple[DType, bool]: ...
def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> tuple[DType, bool]: ...

@overload
def result_type(*args: Any, return_weak_type_flag: Literal[False] = False) -> DType: ...

@overload
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, Tuple[DType, bool]]: ...
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]: ...

def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, Tuple[DType, bool]]:
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]:
"""Convenience function to apply JAX argument dtype promotion.

Args:
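A hedged illustration of the weak-type flag: a Python scalar is weakly typed, so a concrete `float32` argument drives the result and the combination is no longer weak.

```python
import numpy as np
from jax._src import dtypes

dt, weak = dtypes.result_type(3.0, np.float32, return_weak_type_flag=True)
print(dt, weak)  # float32 False
```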
@ -13,12 +13,12 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Iterable, Optional, Type, Set
from typing import Any, Iterable, Optional

class Effect:
"""A generic side-effect."""

Effects = Set[Effect]
Effects = set[Effect]

class JaxprInputEffect(Effect):
"""A side-effect associated with the input of a jaxpr.
@ -48,9 +48,9 @@ class JaxprInputEffect(Effect):
class EffectTypeSet:

def __init__(self):
self._effect_types: Set[Type[Effect]] = set()
self._effect_types: set[type[Effect]] = set()

def add_type(self, effect_type: Type[Effect]):
def add_type(self, effect_type: type[Effect]):
self._effect_types.add(effect_type)

def contains(self, eff: Effect) -> bool:
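A hedged sketch of the intended pattern (the `LogEffect` name is invented, and `lowerable_effects` is assumed to be one of the module's predefined `EffectTypeSet` registries): membership is tested by type, so registering the class covers every instance.

```python
from jax._src import effects

class LogEffect(effects.Effect):
  """A hypothetical effect marking equations that log values."""

effects.lowerable_effects.add_type(LogEffect)
assert effects.lowerable_effects.contains(LogEffect())
```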
@ -16,7 +16,7 @@ import contextlib
import functools
import itertools as it
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
from typing import Any, Callable, Sequence, Optional, Union

import jax
from jax._src import linear_util as lu
@ -45,8 +45,8 @@ def identity(x): return x

def _update_annotation(
f: lu.WrappedFun,
orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
explicit_nonzeros: List[bool]
orig_type: Optional[tuple[tuple[core.AbstractValue, bool], ...]],
explicit_nonzeros: list[bool]
) -> lu.WrappedFun:
if orig_type is None:
return f
@ -221,13 +221,13 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
if not is_undefined_primal(val):
primal_env[v] = val

primal_env: Dict[Any, Any] = {}
primal_env: dict[Any, Any] = {}
map(write_primal, jaxpr.constvars, consts)
# FIXME: invars can contain both primal and tangent values, and this line
# forces primal_in to contain UndefinedPrimals for tangent values!
map(write_primal, jaxpr.invars, primals_in)

ct_env: Dict[Any, Any] = {}
ct_env: dict[Any, Any] = {}
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
else contextlib.nullcontext())
with ctx:
@ -481,17 +481,17 @@ def _primal_tangent_shapes_match(primal, tangent):
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)

call_param_updaters: Dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}
call_param_updaters: dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: dict[core.Primitive, Callable] = {}


# -------------------- Primitives --------------------

primitive_jvps : Dict[core.Primitive, Callable] = {}
primitive_jvps : dict[core.Primitive, Callable] = {}

primitive_transposes: Dict[core.Primitive, Callable] = {}
primitive_transposes: dict[core.Primitive, Callable] = {}
# transpose rules that internally perform reductions over the given named axes
reducing_transposes: Dict[core.Primitive, Callable] = {}
reducing_transposes: dict[core.Primitive, Callable] = {}


def deflinear(primitive, transpose_rule):
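A hedged sketch of how these registries get populated for a new primitive (every name here is invented for the example; `defjvp` takes one rule per positional argument, mapping an input tangent to an output tangent):

```python
from jax._src import core
from jax._src.interpreters import ad

double_p = core.Primitive('double')
double_p.def_impl(lambda x: 2 * x)
double_p.def_abstract_eval(lambda aval: aval)  # shape- and dtype-preserving

# The operation is linear, so its JVP just reapplies it to the tangent.
ad.defjvp(double_p, lambda g, x: double_p.bind(g))
```

With these in place, `jax.jvp(double_p.bind, (3.,), (1.,))` should give `(6., 2.)` (hedged).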
@ -693,7 +692,7 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):

def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
instantiate: Union[bool, Sequence[bool]]
) -> Tuple[core.ClosedJaxpr, List[bool]]:
) -> tuple[core.ClosedJaxpr, list[bool]]:
if type(instantiate) is bool:
instantiate = (instantiate,) * len(jaxpr.out_avals)
return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))

@ -15,8 +15,7 @@ from __future__ import annotations

import dataclasses
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Optional,
Sequence, Set, Tuple, Type, Union)
from typing import Any, Callable, Iterable, Optional, Sequence, Union

import numpy as np


@ -115,7 +114,7 @@ class RaggedAxis:
# For each axis, we store its index and the corresponding segment lengths.
# For example, the pile i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# would be represented with ragged_axes = [(1, lens1), (3, lens2)]
ragged_axes: List[Tuple[int, Array]]
ragged_axes: list[tuple[int, Array]]

@property
def size(self):
@ -139,7 +138,7 @@ class RaggedAxis:
return RaggedAxis(self.stacked_axis, new_ragged_axes)

def make_batch_axis(
ndim: int, stacked_axis: int, ragged_axes: List[Tuple[int, Array]]
ndim: int, stacked_axis: int, ragged_axes: list[tuple[int, Array]]
) -> Union[int, RaggedAxis]:
if ragged_axes:
canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
@ -241,7 +240,7 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
if spec is not None else x)
else:
assert False
to_elt_handlers: Dict[Type, ToEltHandler] = {}
to_elt_handlers: dict[type, ToEltHandler] = {}

def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
) -> Vmappable:
@ -257,7 +256,7 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
return _pile_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
else:
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
from_elt_handlers: Dict[Type, FromEltHandler] = {}
from_elt_handlers: dict[type, FromEltHandler] = {}

def make_iota(axis_size: AxisSize) -> Array:
handler = make_iota_handlers.get(type(axis_size))
@ -265,9 +264,9 @@ def make_iota(axis_size: AxisSize) -> Array:
return handler(axis_size)
else:
return jax.lax.iota('int32', int(axis_size))
make_iota_handlers: Dict[Type, MakeIotaHandler] = {}
make_iota_handlers: dict[type, MakeIotaHandler] = {}

def register_vmappable(data_type: Type, spec_type: Type, axis_size_type: Type,
def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
to_elt: Callable, from_elt: Callable,
make_iota: Optional[Callable]):
vmappables[data_type] = (spec_type, axis_size_type)
@ -275,10 +274,10 @@ def register_vmappable(data_type: Type, spec_type: Type, axis_size_type: Type,
to_elt_handlers[data_type] = to_elt
from_elt_handlers[data_type] = from_elt
if make_iota: make_iota_handlers[axis_size_type] = make_iota
vmappables: Dict[Type, Tuple[Type, Type]] = {}
spec_types: Set[Type] = {PileAxis}
vmappables: dict[type, tuple[type, type]] = {}
spec_types: set[type] = {PileAxis}

def unregister_vmappable(data_type: Type) -> None:
def unregister_vmappable(data_type: type) -> None:
spec_type, axis_size_type = vmappables.pop(data_type)
spec_types.remove(spec_type)
del to_elt_handlers[data_type]
@ -591,8 +590,8 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,
### API for batching callables with vmappable inputs and outputs

def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size,
in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
spmd_axis_name: Optional[Tuple[AxisName, ...]] = None
in_dims, out_dim_dests, main_type: type[BatchTrace] = BatchTrace,
spmd_axis_name: Optional[tuple[AxisName, ...]] = None
) -> lu.WrappedFun:
# we split up _batch_inner and _batch_outer for the leak checker
f = _batch_inner(fun, axis_size, out_dim_dests)
@ -624,11 +623,11 @@ def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals):

# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
in_axes_flat: Tuple[Optional[int], ...],
out_axes_flat: Tuple[Optional[int], ...],
in_axes_flat: tuple[Optional[int], ...],
out_axes_flat: tuple[Optional[int], ...],
tile_size: Optional[int],
axis_name: AxisName,
main_type: Type[BatchTrace] = BatchTrace):
main_type: type[BatchTrace] = BatchTrace):
@curry
def tile_axis(arg, axis: Optional[int], tile_size):
if axis is None:
@ -673,22 +672,22 @@ def batch_subtrace(main, in_dims, *in_vals):

def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
axis_size: core.AxisSize,
in_axes: Tuple[Union[int, NotMapped], ...],
in_axes: tuple[Union[int, NotMapped], ...],
axis_name: AxisName,
spmd_axis_name: AxisName,
main_type: Type[BatchTrace],
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
main_type: type[BatchTrace],
) -> tuple[core.ClosedJaxpr, tuple[Union[int, NotMapped], ...]]:
return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name,
spmd_axis_name, main_type)

@weakref_lru_cache
def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
axis_size: core.AxisSize,
in_axes: Tuple[Union[int, NotMapped], ...],
in_axes: tuple[Union[int, NotMapped], ...],
axis_name: AxisName,
spmd_axis_name: AxisName,
main_type: Type[BatchTrace],
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
main_type: type[BatchTrace],
) -> tuple[core.ClosedJaxpr, tuple[Union[int, NotMapped], ...]]:
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f, out_axes = _batch_jaxpr_inner(f, axis_size)
f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes,
@ -862,10 +861,10 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)

### utilities for defining primitives' batching rules

BatchingRule = Callable[..., Tuple[Any, Union[None, int, Tuple[Union[None, int], ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
BatchingRule = Callable[..., tuple[Any, Union[None, int, tuple[Union[None, int], ...]]]]
primitive_batchers : dict[core.Primitive, BatchingRule] = {}
axis_primitive_batchers: dict[core.Primitive, Callable] = {}
spmd_axis_primitive_batchers: dict[core.Primitive, Callable] = {}

def defvectorized(prim):
primitive_batchers[prim] = partial(vectorized_batcher, prim)
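A hedged sketch of registering a `BatchingRule` for a hypothetical elementwise primitive: the rule receives the batched arguments and their batch dimensions, and returns the batched result along with the dimension its output is batched over.

```python
from jax._src import core
from jax._src.interpreters import batching

scale2_p = core.Primitive('scale2')
scale2_p.def_impl(lambda x: 2 * x)
scale2_p.def_abstract_eval(lambda aval: aval)

def _scale2_batcher(batched_args, batch_dims):
  (x,), (bdim,) = batched_args, batch_dims
  return scale2_p.bind(x), bdim  # elementwise: the batch dim passes through

batching.primitive_batchers[scale2_p] = _scale2_batcher
```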
@ -24,8 +24,8 @@ import itertools
import operator
import re
import typing
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
Protocol, Sequence, Set, Tuple, Type, Union)
from typing import (Any, Callable, Iterator, NamedTuple, Optional,
Protocol, Sequence, Union)
import warnings

import numpy as np
@ -114,7 +114,7 @@ def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
# IR Types

# Non-canonicalized dtype to IR type mapping.
_dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
_dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
@ -165,7 +165,7 @@ def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
shape = [d if type(d) is int else dyn_size for d in aval.shape]
return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)

ir_type_handlers: Dict[Type[core.AbstractValue],
ir_type_handlers: dict[type[core.AbstractValue],
Callable[[Any], Sequence[ir.Type]]] = {}

def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
@ -203,7 +203,7 @@ class ConstantHandler(Protocol):

A JAX value is represented by zero or more IR values."""

_constant_handlers : Dict[type, ConstantHandler] = {}
_constant_handlers : dict[type, ConstantHandler] = {}

def register_constant_handler(type_: type, handler_fun: ConstantHandler):
_constant_handlers[type_] = handler_fun
@ -338,7 +338,7 @@ def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
return ir.Location.callsite(frame_locs[-1], frame_locs[-2::-1])

def _source_info_to_location(
primitive: core.Primitive, params: Dict,
primitive: core.Primitive, params: dict,
source_info: source_info_util.SourceInfo,
name_stack: source_info_util.NameStack) -> ir.Location:
eqn_str = (f'{str(source_info.name_stack)}/'
@ -410,15 +410,15 @@ class ModuleContext:
platform: str
axis_context: AxisContext
name_stack: source_info_util.NameStack
keepalives: List[Any]
keepalives: list[Any]
channel_iterator: Iterator[int]
host_callbacks: List[Any]
host_callbacks: list[Any]
# Keep state for the lowering of shape polymorphism
shape_poly_state: ShapePolyLoweringState

# Cached primitive lowerings.
cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: Dict[Any, func_dialect.FuncOp]
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]


@property
@ -431,16 +431,16 @@ class ModuleContext:
platform: str,
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
keepalives: List[Any],
keepalives: list[Any],
channel_iterator: Iterator[int],
host_callbacks: List[Any],
host_callbacks: list[Any],
context: Optional[ir.Context] = None,
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
symbol_table: Optional[ir.SymbolTable] = None,
cached_primitive_lowerings: Optional[Dict[Any,
cached_primitive_lowerings: Optional[dict[Any,
func_dialect.FuncOp]] = None,
cached_call_jaxpr_lowerings: Optional[Dict[Any,
cached_call_jaxpr_lowerings: Optional[dict[Any,
func_dialect.FuncOp]] = None,
shape_poly_state = None):
assert platform is not None
@ -489,7 +489,7 @@ class LoweringRuleContext:
avals_out: Any # Usually Sequence[core.AbstractValue], but sometimes None.
tokens_in: TokenSet
tokens_out: Optional[TokenSet] # Mutable store for output containers
axis_size_env: Optional[Dict[core.Var, ir.Value]] = None # Dynamic axis sizes
axis_size_env: Optional[dict[core.Var, ir.Value]] = None # Dynamic axis sizes
dim_var_values: Sequence[ir.Value] = () # The values for the dimension variables
# in same order as module_context.shape_poly_state.dim_vars

@ -509,8 +509,8 @@ if not MYPY:
else:
LoweringRule = Any

_lowerings: Dict[core.Primitive, LoweringRule] = {}
_platform_specific_lowerings: Dict[str, Dict[core.Primitive, LoweringRule]]
_lowerings: dict[core.Primitive, LoweringRule] = {}
_platform_specific_lowerings: dict[str, dict[core.Primitive, LoweringRule]]
_platform_specific_lowerings = collections.defaultdict(dict)

def register_lowering(prim: core.Primitive, rule: LoweringRule,
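A hedged sketch of registering a lowering: `lower_fun` (used elsewhere in this diff) turns a JAX-traceable function into a `LoweringRule`, which is the common way to add rules without writing MLIR by hand. The primitive is invented for the example:

```python
import jax.numpy as jnp
from jax._src import core
from jax._src.interpreters import mlir

square_p = core.Primitive('square')
square_p.def_impl(lambda x: x * x)
square_p.def_abstract_eval(lambda aval: aval)

# Lower by tracing an equivalent jnp expression rather than emitting IR.
mlir.register_lowering(
    square_p,
    mlir.lower_fun(lambda x: jnp.multiply(x, x), multiple_results=False))
```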
@ -553,7 +553,7 @@ def sharded_aval(aval: core.AbstractValue,
|
||||
|
||||
|
||||
def eval_dynamic_shape(ctx: LoweringRuleContext,
|
||||
shape: core.Shape) -> Tuple[Union[int, Value], ...]:
|
||||
shape: core.Shape) -> tuple[Union[int, Value], ...]:
|
||||
if config.jax_dynamic_shapes:
|
||||
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
|
||||
else:
|
||||
@ -569,7 +569,7 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
|
||||
|
||||
# TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
|
||||
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
|
||||
shape: core.Shape) -> Tuple[Value, ...]:
|
||||
shape: core.Shape) -> tuple[Value, ...]:
|
||||
"""Evaluates the dynamic shapes as int32 values."""
|
||||
def convert_dim(d: Union[int, Value]):
|
||||
if type(d) is int:
|
||||
@ -585,7 +585,7 @@ def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
|
||||
|
||||
def eval_dynamic_shape_as_ivals(
|
||||
ctx: LoweringRuleContext, shape: core.Shape
|
||||
) -> Tuple[Union[int, Value], ...]:
|
||||
) -> tuple[Union[int, Value], ...]:
|
||||
"""Evaluates the dynamic shapes as int or ir.int32 values."""
|
||||
def convert_dim(d: Union[int, Value]) -> Union[int, ir.Value]:
|
||||
if type(d) is int:
|
||||
@ -602,7 +602,7 @@ def eval_dynamic_shape_as_ivals(
|
||||
class LoweringResult(NamedTuple):
|
||||
module: ir.Module
|
||||
keepalive: Optional[Any]
|
||||
host_callbacks: List[Any]
|
||||
host_callbacks: list[Any]
|
||||
shape_poly_state: ShapePolyLoweringState
|
||||
|
||||
|
||||
@ -622,7 +622,7 @@ def _to_logical_op_sharding(
|
||||
def lower_jaxpr_to_module(
|
||||
module_name: str,
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
ordered_effects: List[core.Effect],
|
||||
ordered_effects: list[core.Effect],
|
||||
backend_or_name: Optional[Union[str, xb.XlaBackend]],
|
||||
platform: str,
|
||||
axis_context: AxisContext,
|
||||
@ -665,8 +665,8 @@ def lower_jaxpr_to_module(
|
||||
# HLO channels need to start at 1
|
||||
channel_iter = itertools.count(1)
|
||||
# Create a keepalives list that will be mutated during the lowering.
|
||||
keepalives: List[Any] = []
|
||||
host_callbacks: List[Any] = []
|
||||
keepalives: list[Any] = []
|
||||
host_callbacks: list[Any] = []
|
||||
|
||||
dim_vars: Sequence[str]
|
||||
if not config.jax_dynamic_shapes:
|
||||
@ -786,7 +786,7 @@ class TokenSet:
|
||||
tokens = [create_token() for _ in effects]
|
||||
return TokenSet(zip(effects, tokens))
|
||||
|
||||
def items(self) -> Sequence[Tuple[core.Effect, Token]]:
|
||||
def items(self) -> Sequence[tuple[core.Effect, Token]]:
|
||||
return tuple(self._tokens.items())

def effects(self) -> set[core.Effect]:
@ -934,7 +934,7 @@ def lower_jaxpr_to_fun(
or arg_names is not None
or num_tokens > 0
):
arg_attrs: List[Dict[str, ir.Attribute]] = [
arg_attrs: list[dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_input_types))]

if replicated_args is not None:
@ -952,7 +952,7 @@ def lower_jaxpr_to_fun(
if input_output_aliases is not None:
output_ids = util.unflatten(list(range(len(flat_output_types))),
map(len, output_types))
aliases: List[Optional[int]] = []
aliases: list[Optional[int]] = []
for types, alias in zip(input_types, input_output_aliases):
if alias is None:
aliases.extend([None] * len(types))
@ -977,7 +977,7 @@ def lower_jaxpr_to_fun(
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])

result_attrs: List[Dict[str, ir.Attribute]] = [
result_attrs: list[dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_output_types))]

if num_tokens > 0:
@ -1020,7 +1020,7 @@ def lower_jaxpr_to_fun(
tokens_in = TokenSet.create(effects)
else:
tokens_in = TokenSet(zip(effects, token_args))
args: List[List[ir.Value]] = []
args: list[list[ir.Value]] = []
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
if replace_tokens_with_dummy and aval is core.abstract_token:
args.append(hlo.CreateTokenOp().results)
@ -1104,7 +1104,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
consts: Sequence[Sequence[ir.Value]],
*args: Sequence[ir.Value],
dim_var_values: Sequence[ir.Value]
) -> Tuple[Sequence[Sequence[ir.Value]], TokenSet]:
) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]:
"""Lowers a jaxpr into MLIR, inlined into an existing function.

Assumes that an MLIR context, location, and insertion point are set.
@ -1131,7 +1131,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
env[v] = tuple(node)


env: Dict[core.Var, Tuple[ir.Value, ...]] = {}
env: dict[core.Var, tuple[ir.Value, ...]] = {}

assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
@ -1547,7 +1547,7 @@ def _wrap_with_spmd_op(name: str,
x: ir.Value,
aval_out: core.AbstractValue,
sharding_proto: xc.OpSharding,
unspecified_dims: Optional[Set[int]] = None):
unspecified_dims: Optional[set[int]] = None):
# unspecified_dims indicate dimensions whose shardings are not specified and
# XLA sharding propagation can change them.
if unspecified_dims:
@ -1810,13 +1810,13 @@ def _emit_tpu_python_callback(
callback,
token: Optional[Any],
operands: Sequence[ir.Value],
operand_avals: List[core.ShapedArray],
operand_shapes: List[xc.Shape],
result_avals: List[core.ShapedArray],
result_shapes: List[xc.Shape],
operand_avals: list[core.ShapedArray],
operand_shapes: list[xc.Shape],
result_avals: list[core.ShapedArray],
result_shapes: list[xc.Shape],
*,
sharding: Optional[xc.OpSharding] = None
) -> Tuple[List[ir.Value], Any, Any]:
) -> tuple[list[ir.Value], Any, Any]:
token = token or hlo.CreateTokenOp().result
_wrapped_callback = callback

@ -1886,12 +1886,12 @@ def _aval_to_default_layouts(aval):

def emit_python_callback(
ctx: LoweringRuleContext, callback, token: Optional[Any],
operands: Sequence[ir.Value], operand_avals: List[core.ShapedArray],
result_avals: List[core.ShapedArray],
operands: Sequence[ir.Value], operand_avals: list[core.ShapedArray],
result_avals: list[core.ShapedArray],
has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None,
operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
result_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
) -> Tuple[List[ir.Value], Any, Any]:
) -> tuple[list[ir.Value], Any, Any]:
"""Emits MLIR that calls back to a provided Python function."""
platform = ctx.module_context.platform
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
@ -2012,7 +2012,7 @@ def custom_call(
result_shapes: Optional[Sequence[ir.Value]] = None,
called_computations: Sequence[str] = (),
api_version: int = 2,
extra_attributes: Dict[str, ir.Attribute] = {},
extra_attributes: dict[str, ir.Attribute] = {},
) -> ir.Operation:
"""Wraps a hlo.CustomCall.

@ -2063,7 +2063,7 @@ def reduce_window(
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
# spatial dimension.
int2d = aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
def prep_one_pad(pad_lo_hi: Tuple[core.DimSize, core.DimSize]):
def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
pads = shape_tensor(eval_dynamic_shape(ctx, pad_lo_hi)) # i32[2]
return hlo.ReshapeOp(int2d, pads)
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
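The hunks above and below all apply one mechanical rewrite: since Python 3.9 (PEP 585), the built-in containers are subscriptable, so the deprecated `typing.List`, `typing.Dict`, `typing.Tuple`, `typing.Set`, `typing.FrozenSet`, and `typing.Type` spellings become `list`, `dict`, `tuple`, `set`, `frozenset`, and `type`. A minimal before/after sketch of the pattern; the function and names here are illustrative and do not appear in this commit:

```python
# Before: capitalized aliases imported from typing (deprecated since 3.9).
from typing import Dict, List, Optional, Tuple

def count_tokens_old(lines: List[str]) -> Tuple[int, Dict[str, int]]:
    counts: Dict[str, int] = {}
    for line in lines:
        for tok in line.split():
            counts[tok] = counts.get(tok, 0) + 1
    return sum(counts.values()), counts

# After: PEP 585 built-in generics; Optional (and Union, Callable) stay in typing.
def count_tokens(lines: list[str]) -> tuple[int, dict[str, int]]:
    counts: dict[str, int] = {}
    for line in lines:
        for tok in line.split():
            counts[tok] = counts.get(tok, 0) + 1
    return sum(counts.values()), counts

def first_line(lines: list[str]) -> Optional[str]:
    return lines[0] if lines else None

print(count_tokens(["a b a"]))  # (3, {'a': 2, 'b': 1})
```

Subscripting the built-ins works at runtime from Python 3.9 onward; on older interpreters the lower-case forms are only legal inside annotations under `from __future__ import annotations`, which matters for module-level aliases like the ones rewritten below.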
@ -20,8 +20,8 @@ from functools import partial
import inspect
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
List, Union, Hashable, Set)
from typing import (Any, Callable, NamedTuple, Optional, Sequence,
Union, Hashable)
from weakref import ref

import numpy as np
@ -62,7 +62,7 @@ ConstId = int
def _update_annotation_known(
f: lu.WrappedFun,
orig_type: Optional[InputType],
in_knowns: List[bool]
in_knowns: list[bool]
) -> lu.WrappedFun:
if orig_type is None: return f
# orig_type might contain DBIdx, but we're tossing out some args so we have to
@ -104,7 +104,7 @@ class PartialVal(tuple):
* `(<AbstractValue>, None)` indicates an unknown value characterized by an
abstract value.
"""
def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
def __new__(cls, xs: tuple[Optional[AbstractValue], core.Value]):
pv, const = xs
if config.jax_enable_checks:
# type checks
@ -549,8 +549,8 @@ class JaxprTrace(Trace['JaxprTracer']):
raise NotImplementedError # TODO(mattjj)

def partition_pvals(
pvals: List[PartialVal]
) -> Tuple[List[bool], List[AbstractValue], List[Any]]:
pvals: list[PartialVal]
) -> tuple[list[bool], list[AbstractValue], list[Any]]:
knowns = [pval.is_known() for pval in pvals ]
avals = [pval.get_aval() for pval in pvals if not pval.is_known()]
consts = [pval.get_known() for pval in pvals if pval.is_known()]
@ -581,7 +581,7 @@ def trace_to_subjaxpr_nounits_dyn(
# for all axis sizes, so that we can then use those tracers in the shapes of
# avals for unknown inputs' tracers. We use ConstVar recipes for on-the-fly
# type agreement checking via get_referent.
in_consts_full: List[Optional[JaxprTracer]] = [None] * len(in_type)
in_consts_full: list[Optional[JaxprTracer]] = [None] * len(in_type)
in_consts_iter, in_knowns_iter = iter(in_consts), iter(in_knowns)
for idx, (aval, explicit) in enumerate(in_type):
if explicit and next(in_knowns_iter):
@ -631,8 +631,8 @@ def trace_to_subjaxpr_nounits_dyn(
for inst, t in zip(instantiate, out_tracers)]

# Collect known outputs.
out_knowns: List[bool] = [t.is_known() for t in out_tracers]
out_consts: List[Any] = [t.pval.get_known() for t in out_tracers
out_knowns: list[bool] = [t.is_known() for t in out_tracers]
out_consts: list[Any] = [t.pval.get_known() for t in out_tracers
if t.is_known()]

# Build the jaxpr.
@ -647,7 +647,7 @@ def trace_to_subjaxpr_nounits_dyn(
# Which residuals are just forwarded inputs? Check obj id, then prune.
id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore
if c is not None}
fwds: List[Optional[int]] = [id_map.get(id(c)) for c in res]
fwds: list[Optional[int]] = [id_map.get(id(c)) for c in res]
res = tuple([c for c, fwd in zip(res, fwds) if fwd is None])

del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \
@ -655,9 +655,9 @@ def trace_to_subjaxpr_nounits_dyn(
yield (*out_consts, *res), (fwds, out_knowns, tuple(out_type), jaxpr, env)


custom_partial_eval_rules: Dict[Primitive, Callable] = {}
call_partial_eval_rules: Dict[Primitive, Callable] = {}
call_param_updaters: Dict[Primitive, Callable] = {}
custom_partial_eval_rules: dict[Primitive, Callable] = {}
call_partial_eval_rules: dict[Primitive, Callable] = {}
call_param_updaters: dict[Primitive, Callable] = {}

def _closed_call_param_updater(params, _, __):
jaxpr = params.get('call_jaxpr')
@ -733,7 +733,7 @@ class JaxprTracer(Tracer):
def trace_to_jaxpr(
fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]:
) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]:
"""
Partially evaluate a function, building a jaxpr for un-evaluated computation.

@ -770,7 +770,7 @@ def trace_to_jaxpr(
def trace_to_jaxpr_nounits(
fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]:
) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]:
current_name_stack = source_info_util.current_name_stack()
with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
@ -828,7 +828,7 @@ def trace_to_subjaxpr_nounits_fwd(
# Which out_consts (aka residuals) are just forwarded inputs? Check obj id.
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
id_map = {id(c): i for i, c in enumerate(in_consts)}
fwds: List[Optional[int]] = [id_map.get(id(c)) for c in out_consts]
fwds: list[Optional[int]] = [id_map.get(id(c)) for c in out_consts]
pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None]

del out_tracers
@ -844,14 +844,14 @@ class JaxprEqnRecipe(NamedTuple):
out_tracer_refs: Sequence[ref[JaxprTracer]]
out_avals: Sequence[core.AbstractValue]
primitive: Primitive
params: Dict[str, Any]
params: dict[str, Any]
effects: core.Effects
source_info: source_info_util.SourceInfo

def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer],
primitive: Primitive,
params: Dict[str, Any],
params: dict[str, Any],
effects: core.Effects,
source_info: source_info_util.SourceInfo
) -> JaxprEqnRecipe:
@ -882,7 +882,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer]
) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
"""Constructs Jaxpr given tracers for inputs and outputs.

Params:
@ -896,10 +896,10 @@ def tracers_to_jaxpr(
"""
gensym = core.gensym()

t_to_var: Dict[TracerId, Var] = {}
consts: Dict[Var, Any] = {}
env: Dict[Var, JaxprTracer] = {}
constid_to_var: Dict[ConstId, Var] = {} # for deduplication
t_to_var: dict[TracerId, Var] = {}
consts: dict[Var, Any] = {}
env: dict[Var, JaxprTracer] = {}
constid_to_var: dict[ConstId, Var] = {} # for deduplication

def get_atom(t: JaxprTracer) -> Atom:
return t.recipe if type(t.recipe) is Literal else t_to_var[id(t)]
@ -920,7 +920,7 @@ def tracers_to_jaxpr(
return aval

processed_eqn_ids = set()
eqns: List[core.JaxprEqn] = []
eqns: list[core.JaxprEqn] = []
for t in toposort([*in_tracers, *out_tracers]):
r = t.recipe
if isinstance(r, JaxprEqnRecipe):
@ -1004,7 +1004,7 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
def partial_eval_jaxpr_nounits(
jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
instantiate: Union[bool, Sequence[bool]],
) -> Tuple[ClosedJaxpr, ClosedJaxpr, List[bool], List[AbstractValue]]:
) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue]]:
"""Unzip a jaxpr in two by data dependence into 'known' and 'unknown' parts.

That is, given a jaxpr and a sequence of booleans indicating which jaxpr
@ -1122,7 +1122,7 @@ def partial_eval_jaxpr_custom(
ensure_out_unknowns: Union[bool, Sequence[bool]],
ensure_out_inst: Union[bool, Sequence[bool]],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
if type(ensure_out_unknowns) is bool:
@ -1146,7 +1146,7 @@ def partial_eval_jaxpr_stateful(
ensure_out_unknowns: Union[bool, Sequence[bool]],
ensure_out_inst: Union[bool, Sequence[bool]],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]:
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
if type(ensure_out_unknowns) is bool:
@ -1163,17 +1163,17 @@ def partial_eval_jaxpr_stateful(
@weakref_lru_cache
def _partial_eval_jaxpr_custom_cached(
jaxpr: Jaxpr,
in_unknowns: Tuple[bool, ...],
in_inst: Tuple[bool, ...],
ensure_out_unknowns: Tuple[bool, ...],
ensure_out_inst: Tuple[bool, ...],
in_unknowns: tuple[bool, ...],
in_inst: tuple[bool, ...],
ensure_out_unknowns: tuple[bool, ...],
ensure_out_inst: tuple[bool, ...],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int, int]:
env: Dict[Var, Tuple[bool, bool]] = {}
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
env: dict[Var, tuple[bool, bool]] = {}
residuals: OrderedSet[Var] = OrderedSet()
residual_refs: OrderedSet[Var] = OrderedSet()

def read(x: Atom) -> Tuple[bool, bool]:
def read(x: Atom) -> tuple[bool, bool]:
if type(x) is Var:
return env[x]
return (False, True)
@ -1259,12 +1259,12 @@ def _partial_eval_jaxpr_custom_cached(
# * a list of Var instances representing residuals to be added (i.e. to be
# plumbed as outputs of the 'known' side jaxpr and added as input binders to
# the 'unknown' jaxpr).
PartialEvalCustomResult = Tuple[Optional[JaxprEqn], Optional[JaxprEqn],
Sequence[bool], Sequence[bool], List[Var]]
PartialEvalCustomResult = tuple[Optional[JaxprEqn], Optional[JaxprEqn],
Sequence[bool], Sequence[bool], list[Var]]
PartialEvalCustomRule = Callable[
[Callable[..., bool], Sequence[bool], Sequence[bool], JaxprEqn],
PartialEvalCustomResult]
partial_eval_jaxpr_custom_rules: Dict[Primitive, PartialEvalCustomRule] = {}
partial_eval_jaxpr_custom_rules: dict[Primitive, PartialEvalCustomRule] = {}

def partial_eval_jaxpr_custom_rule_not_implemented(
name: str, saveable: Callable[..., bool], unks_in: Sequence[bool],
@ -1276,10 +1276,10 @@ def partial_eval_jaxpr_custom_rule_not_implemented(

ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
Sequence[bool], int, dict, dict],
Tuple[dict, dict]]
ResAvalUpdater = Callable[[Dict[str, Any], AbstractValue], AbstractValue]
tuple[dict, dict]]
ResAvalUpdater = Callable[[dict[str, Any], AbstractValue], AbstractValue]
def _default_res_aval_updater(
params: Dict[str, Any], aval: AbstractValue) -> AbstractValue:
params: dict[str, Any], aval: AbstractValue) -> AbstractValue:
return aval

@contextmanager
@ -1287,10 +1287,10 @@ def trivial_ctx(_): yield

def call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
saveable: Callable[..., bool], unks_in: list[bool], inst_in: list[bool],
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
ctx: Callable[[core.ParamDict], AbstractContextManager[None]] = trivial_ctx,
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
jaxpr = eqn.params[jaxpr_param_name]
with ctx(eqn.params):
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
@ -1319,9 +1319,9 @@ def call_partial_eval_custom_rule(

def closed_call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
saveable: Callable[..., bool], unks_in: list[bool], inst_in: list[bool],
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
# TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
closed_jaxpr = eqn.params[jaxpr_param_name]
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_out, num_res_ref = \
@ -1369,9 +1369,9 @@ partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
lambda _, __, ___, ____, _____, x, y: (x, y))


def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[Optional[int]]:
# Compute which inputs are just forwarded to outputs.
fwds: Dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars))
fwds: dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars))
for eqn in jaxpr.eqns:
if eqn.primitive in forwarding_rules:
eqn = eqn.replace(invars=[a if type(a) is Literal else fwds.get(a, a) # type: ignore
@ -1380,14 +1380,14 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
for v_orig, v_new in zip(eqn.outvars, fwd_vars):
if v_new is not None:
fwds[v_orig] = v_new
idxs: Dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)}
idxs: dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)}
return [None if type(v) is Literal else idxs.get(fwds.get(v)) # type: ignore
for v in jaxpr.outvars]


def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[bool]]:
) -> tuple[Jaxpr, list[bool]]:
if type(instantiate) is bool:
instantiate = (instantiate,) * len(jaxpr.invars)
return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate))
@ -1395,7 +1395,7 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],

def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[bool], List[bool]]:
) -> tuple[Jaxpr, list[bool], list[bool]]:
jaxpr_ = convert_constvars_jaxpr(jaxpr)
new_jaxpr_, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs)
used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)])
@ -1404,10 +1404,10 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],


@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
instantiate: Tuple[bool, ...]
) -> Tuple[Jaxpr, List[bool]]:
env: Dict[Var, bool] = {}
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
instantiate: tuple[bool, ...]
) -> tuple[Jaxpr, list[bool]]:
env: dict[Var, bool] = {}

def read(v: Var) -> bool:
return env.get(v, False)
@ -1448,18 +1448,18 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],

return new_jaxpr, used_inputs

DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], Optional[JaxprEqn]]]
DCERule = Callable[[list[bool], JaxprEqn], tuple[list[bool], Optional[JaxprEqn]]]

def _default_dce_rule(
used_outs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], JaxprEqn]:
used_outs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn]:
return [True] * len(eqn.invars), eqn

dce_rules: Dict[Primitive, DCERule] = {}
dce_rules: dict[Primitive, DCERule] = {}


def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], Optional[JaxprEqn]]:
def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], Optional[JaxprEqn]]:
new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
update_params = call_param_updaters.get(eqn.primitive)
@ -1476,8 +1476,8 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
dce_rules[core.call_p] = dce_jaxpr_call_rule


def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], JaxprEqn]:
def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn]:
# TODO(mattjj): de-duplicate with above rule?
jaxpr_ = eqn.params['call_jaxpr']
jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
@ -1500,7 +1500,7 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
return _move_binders_to_front(closed_jaxpr, tuple(to_move))

@weakref_lru_cache
def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Tuple[bool, ...]
def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]
) -> ClosedJaxpr:
assert len(closed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
@ -1611,12 +1611,12 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:

class JaxprStackFrame:
gensym: Callable[[AbstractValue], Var]
tracer_to_var: Dict[TracerId, Var]
constid_to_tracer: Dict[ConstId, Tracer]
constvar_to_val: Dict[Var, Any]
tracers: List[DynamicJaxprTracer] # hold onto strong refs for all tracers
eqns: List[JaxprEqn]
invars: List[Var]
tracer_to_var: dict[TracerId, Var]
constid_to_tracer: dict[ConstId, Tracer]
constvar_to_val: dict[Var, Any]
tracers: list[DynamicJaxprTracer] # hold onto strong refs for all tracers
eqns: list[JaxprEqn]
invars: list[Var]
effects: core.Effects
debug_info: Optional[DebugInfo]

@ -1634,7 +1634,7 @@ class JaxprStackFrame:
def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)

def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> Tuple[Jaxpr, List[Any]]:
def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> tuple[Jaxpr, list[Any]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
@ -1687,9 +1687,9 @@ class JaxprStackFrame:
return invar_positions, const_eqns

def _const_folding_and_forwarding(
jaxpr: Jaxpr, constvals: Sequence[Any]) -> Tuple[Jaxpr, Tuple[Any, ...]]:
consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals))
var_subs: Dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined
jaxpr: Jaxpr, constvals: Sequence[Any]) -> tuple[Jaxpr, tuple[Any, ...]]:
consts: dict[Var, Any] = dict(zip(jaxpr.constvars, constvals))
var_subs: dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined
new_eqns = []
for eqn in jaxpr.eqns:
# always apply invar substitutions
@ -1723,18 +1723,18 @@ def _const_folding_and_forwarding(
jaxpr_effects, jaxpr.debug_info)
return new_jaxpr, new_constvals

ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
Tuple[List[Optional[Any]], Optional[JaxprEqn]]]
const_fold_rules: Dict[Primitive, ConstFoldRule] = {}
ConstFoldRule = Callable[[list[Optional[Any]], JaxprEqn],
tuple[list[Optional[Any]], Optional[JaxprEqn]]]
const_fold_rules: dict[Primitive, ConstFoldRule] = {}

ForwardingRule = Callable[[JaxprEqn],
Tuple[List[Optional[Var]], Optional[JaxprEqn]]]
forwarding_rules: Dict[Primitive, ForwardingRule] = {}
tuple[list[Optional[Var]], Optional[JaxprEqn]]]
forwarding_rules: dict[Primitive, ForwardingRule] = {}


def _inline_literals(
jaxpr: Jaxpr, constvals: Sequence[Any]
) -> Tuple[Jaxpr, List[Any]]:
) -> tuple[Jaxpr, list[Any]]:
# This function also prunes unused constants and inserts `dropvar` symbols.
input_effects = {eff for eff in jaxpr.effects
if isinstance(eff, effects.JaxprInputEffect)}
@ -1746,7 +1746,7 @@ def _inline_literals(
if type(c) in core.literalable_types and not np.shape(c) and not e}
lit: Callable[[Var], Optional[Literal]] = lits.get
newname: Callable[[AbstractValue], Var] = core.gensym()
newvars: Dict[Var, Var] = {}
newvars: dict[Var, Var] = {}
newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))
@ -2052,7 +2052,7 @@ class DynamicJaxprTrace(core.Trace):
return out_tracers


custom_staging_rules: Dict[Primitive, Callable] = {}
custom_staging_rules: dict[Primitive, Callable] = {}

@lu.transformation
def _interleave_fun(every_others, *args, **kwargs):
@ -2113,7 +2113,7 @@ def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)

def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]:
def arg_info_all(dbg: DebugInfo) -> Optional[list[tuple[str, KeyPath]]]:
ba = None if dbg.in_tree is None else sig_info(dbg)
if ba is None: return None
return [(name, key_path) for name, dummy_arg in ba.arguments.items()
@ -2131,7 +2131,7 @@ def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]:
except (TypeError, ValueError):
return None

def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]:
def result_info(dbg: DebugInfo) -> Optional[list[KeyPath]]:
if dbg.out_tree is None: return None
try:
num_leaves = dbg.out_tree().num_leaves
@ -2148,8 +2148,8 @@ def trace_to_jaxpr_dynamic(
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None,
*,
keep_inputs: Optional[List[bool]] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
keep_inputs: Optional[list[bool]] = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
@ -2165,7 +2165,7 @@ def trace_to_subjaxpr_dynamic(
*,
keep_inputs: Optional[Sequence[bool]] = None,
debug_info: Optional[DebugInfo] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs

frame = JaxprStackFrame()
@ -2185,7 +2185,7 @@ def trace_to_subjaxpr_dynamic(
@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
) -> tuple[Jaxpr, OutputType, list[Any]]:
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
@ -2195,7 +2195,7 @@ def trace_to_jaxpr_dynamic2(
def trace_to_subjaxpr_dynamic2(
fun: lu.WrappedFun, main: core.MainTrace,
debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
) -> tuple[Jaxpr, OutputType, list[Any]]:
in_avals, keep_inputs = unzip2(fun.in_type)
frame = JaxprStackFrame()
frame.debug_info = debug_info
@ -2226,7 +2226,7 @@ def trace_to_jaxpr_final(
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None,
keep_inputs: Optional[Sequence[bool]] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
@ -2239,7 +2239,7 @@ def trace_to_jaxpr_final(
@profiler.annotate_function
def trace_to_jaxpr_final2(
fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
) -> tuple[Jaxpr, OutputType, list[Any]]:
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
@ -2249,8 +2249,8 @@ def trace_to_jaxpr_final2(


AbstractedAxisName = Hashable
AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName],
Tuple[AbstractedAxisName, ...]]
AbstractedAxesSpec = Union[dict[int, AbstractedAxisName],
tuple[AbstractedAxisName, ...]]
def infer_lambda_input_type(
axes_specs: Optional[Sequence[AbstractedAxesSpec]],
args: Sequence[Any]
@ -2265,7 +2265,7 @@ def infer_lambda_input_type(
lu._check_input_type(input_type)
return input_type

def _spec_to_dict(spec: AbstractedAxesSpec) -> Dict[int, AbstractedAxisName]:
def _spec_to_dict(spec: AbstractedAxesSpec) -> dict[int, AbstractedAxisName]:
if isinstance(spec, tuple):
return {i: d for i, d in enumerate(spec) if d is not None}
else:
@ -2273,15 +2273,15 @@ def _spec_to_dict(spec: AbstractedAxesSpec) -> Dict[int, AbstractedAxisName]:

def _canonicalize_specs(
ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]]
) -> List[Dict[int, AbstractedAxisName]]:
) -> list[dict[int, AbstractedAxisName]]:
if specs is None:
return [{}] * len(ndims)
else:
return [_spec_to_dict(s) for n, s in zip(ndims, specs)]

def _complete_specs(
args: Sequence[Any], partial_specs: List[Dict[int, AbstractedAxisName]]
) -> List[Dict[int, AbstractedAxisName]]:
args: Sequence[Any], partial_specs: list[dict[int, AbstractedAxisName]]
) -> list[dict[int, AbstractedAxisName]]:
# The abstracted axes specification in `partial_specs` is partial in the sense
# that there could be additional axis abstraction represented in `args` due to
# Tracers existing in the shapes of elements of `args`. The purpose of this
@ -2291,16 +2291,16 @@ def _complete_specs(
# names (with one new name per unique Tracer object id).

# Identify each user-supplied name in partial_specs with a size.
sizes: Dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {}
sizes: dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {}
for x, spec in zip(args, partial_specs):
for i, name in spec.items():
d = sizes.setdefault(name, x.shape[i])
if d is not x.shape[i] and d != x.shape[i]: raise TypeError

# Introduce new names as needed for Tracers in shapes.
named_tracers: Dict[TracerId, AbstractedAxisName] = {
named_tracers: dict[TracerId, AbstractedAxisName] = {
id(d): name for name, d in sizes.items() if isinstance(d, Tracer)}
specs: List[Dict[int, AbstractedAxisName]] = []
specs: list[dict[int, AbstractedAxisName]] = []
for x, spec in zip(args, partial_specs):
if isinstance(get_aval(x), DShapedArray):
spec = dict(spec)
@ -2318,15 +2318,15 @@ def _complete_specs(


def _collect_implicit(
args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]]
) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractValue]]:
args: Sequence[Any], specs: list[dict[int, AbstractedAxisName]]
) -> tuple[dict[AbstractedAxisName, DBIdx], list[AbstractValue]]:
# Given an explicit argument list and a specification of abstracted axes, we
# want to produce an InputType by identifying AbstractedAxisNames with DBIdxs
# and figuring out which AbstractedAxisNames correspond to implicit arguments.

idxs: Dict[AbstractedAxisName, DBIdx] = {}
implicit_types: List[AbstractValue] = []
explicit_tracers: Dict[TracerId, int] = {}
idxs: dict[AbstractedAxisName, DBIdx] = {}
implicit_types: list[AbstractValue] = []
explicit_tracers: dict[TracerId, int] = {}
counter = it.count()

# Add implicit arguments to idxs.
@ -2348,32 +2348,32 @@ _collect_implicit(
return idxs, implicit_types

def _arg_type(
idxs: Dict[AbstractedAxisName, DBIdx], x: Any,
spec: Dict[int, AbstractedAxisName]
idxs: dict[AbstractedAxisName, DBIdx], x: Any,
spec: dict[int, AbstractedAxisName]
) -> AbstractValue:
# Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames.
aval = get_aval(x) # aval.shape could contain Tracers
if not spec: return core.raise_to_shaped(aval)
shape: List[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d
shape: list[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d
for i, d in enumerate(aval.shape)]
assert not any(isinstance(d, Tracer) for d in shape)
return DShapedArray(tuple(shape), aval.dtype, False)

def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]:
def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]:
invars = [*jaxpr.constvars, *jaxpr.invars]
expl_outvars = jaxpr.outvars

# First do a pass to collect implicit outputs, meaning variables which occurr
# in explicit_outvars types but not in invars or to the left in outvars.
seen: Set[Var] = set(invars)
seen: set[Var] = set(invars)
impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore
(seen.add(x) or type(x.aval) is DShapedArray) # type: ignore
for d in x.aval.shape if type(d) is Var and d not in seen]
outvars = [*impl_outvars, *expl_outvars]

# Now assemble an OutputType by mapping vars in shapes to InDBIdx/OutDBIdx.
in_map : Dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: Dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
if type(x) is Var}
out_avals_ = (x.aval for x in outvars)
out_avals = [a.update(shape=tuple(in_map.get(d, out_map.get(d))
@ -2398,7 +2398,7 @@ class TracerAsName:
return id(self.ref)

def _extract_implicit_args(
trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]],
trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]],
explicit_tracers: Sequence[DynamicJaxprTracer]
) -> Sequence[DynamicJaxprTracer]:
# First, construct a list to represent the full argument list, leaving the
@ -2430,7 +2430,7 @@ def _input_type_to_tracers(
# DeBruijn indices which refer to positions in the input argument list. That
# is, each element `a` of `in_avals` can have DBIdx instances in its shape,
# which must refer to positions left of `a`'s.
in_tracers: List[Tracer] = []
in_tracers: list[Tracer] = []

def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue:
if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape):
@ -2443,7 +2443,7 @@ def _input_type_to_tracers(
return in_tracers

def _substitute_vars_in_type(
consts: Dict[Var, Literal], env: Dict[Var, Var], a: AbstractValue
consts: dict[Var, Literal], env: dict[Var, Var], a: AbstractValue
) -> AbstractValue:
if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape):
shape = [consts[d].val if d in consts else env[d] # type: ignore
@ -2494,7 +2494,7 @@ Const = Any
Val = Any

def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
) -> Tuple[Jaxpr, List[Const]]:
) -> tuple[Jaxpr, list[Const]]:
bounds = {v: v.aval.dtype.bound for v in jaxpr.invars
if isinstance(v.aval, core.UnshapedArray) and
type(v.aval.dtype) is core.bint and not v.aval.shape}
@ -2521,9 +2521,9 @@ class BoundedAxisSize(NamedTuple):
bound: int

def _eval_jaxpr_padded(
jaxpr: Jaxpr, consts: List[Const], *args: DynamicJaxprTracer
) -> List[Union[Const, DynamicJaxprTracer]]:
env: Dict[Var, Val] = {}
jaxpr: Jaxpr, consts: list[Const], *args: DynamicJaxprTracer
) -> list[Union[Const, DynamicJaxprTracer]]:
env: dict[Var, Val] = {}

def read(x):
return x.val if type(x) is Literal else env[x]
@ -2543,7 +2543,7 @@ def _eval_jaxpr_padded(
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars)

def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
def _substitute_axis_sizes(env: dict, aval: AbstractValue) -> AbstractValue:
if isinstance(aval, DShapedArray):
shp = []
for d in aval.shape:
@ -2570,7 +2570,7 @@ def _is_bint_axis_size(d: Union[int, core.DArray, core.Var]) -> bool:
return False


padding_rules: Dict[Primitive, Callable] = {}
padding_rules: dict[Primitive, Callable] = {}

def def_trivial_padding(prim: Primitive) -> None:
if prim.multiple_results:
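Note that the rewritten import hunks keep `Any`, `Callable`, `NamedTuple`, `Optional`, `Sequence`, `Union`, and `Hashable` in the `typing` import: PEP 585 only covers the container classes, so those names still have to come from `typing` (or, for `Union`/`Optional`, from PEP 604's `X | Y` syntax on 3.10+). A self-contained sketch of that mixed style, using a hypothetical registry rather than any of the rule tables in the diff:

```python
from typing import Any, Callable, Optional

# Containers use the PEP 585 built-ins; Callable/Optional still come from typing.
Rule = Callable[[list[bool]], tuple[list[bool], Optional[dict[str, Any]]]]

rules: dict[str, Rule] = {}

def register_rule(name: str, rule: Rule) -> None:
    rules[name] = rule

# A trivial rule: mark every input as used and report no extra parameters.
register_rule("keep_all", lambda used: ([True] * len(used), None))

print(rules["keep_all"]([False, True]))  # ([True, True], None)
```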
@ -23,9 +23,8 @@ from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable,
TYPE_CHECKING, cast, TypeVar)
from typing import (Any, Callable, NamedTuple, Optional, Sequence, Union,
Iterable, TYPE_CHECKING, cast, TypeVar)

import numpy as np

@ -81,7 +80,7 @@ unsafe_map, map = map, safe_map # type: ignore

logger = logging.getLogger(__name__)

Index = Union[int, slice, Tuple[Union[int, slice], ...]]
Index = Union[int, slice, tuple[Union[int, slice], ...]]

NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
@ -140,7 +139,7 @@ def shard_args(
return [shard_arg(arg, devices, indices[i], shardings[i])
for i, arg in enumerate(args)]

shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}
shard_arg_handlers: dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}

def _shard_token(x, devices, indices, sharding):
zeros = np.zeros((), dtype=np.dtype(np.bool_))
@ -190,8 +189,8 @@ def batched_device_put(aval: core.ShapedArray,
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def as_slice_indices(arr: Any, idx: Index) -> Tuple[
Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
def as_slice_indices(arr: Any, idx: Index) -> tuple[
tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""Returns start_indices, limit_indices, removed_dims"""
start_indices = [0] * arr.ndim
limit_indices = list(arr.shape)
@ -220,7 +219,7 @@ def shard_aval(size, axis: int, aval):
return shard_aval_handlers[type(aval)](size, axis, aval)
except KeyError as err:
raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
shard_aval_handlers: Dict[Type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
def _shard_abstract_array(size, axis: int, x):
try:
if x.shape[axis] != size:
@ -235,8 +234,8 @@ shard_aval_handlers[ShapedArray] = _shard_abstract_array
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding: sharding_impls.XLACompatibleSharding,
indices: Optional[Tuple[Index, ...]],
) -> Callable[[List[xc.ArrayImpl]], Any]:
indices: Optional[tuple[Index, ...]],
) -> Callable[[list[xc.ArrayImpl]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.

Args:
@ -258,7 +257,7 @@ def local_aval_to_result_handler(
f"No pxla_result_handler for type: {type(aval)}") from err

PxlaResultHandler = Callable[..., Callable[[Any], Any]]
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}


def global_aval_to_result_handler(
@ -288,7 +287,7 @@ def global_aval_to_result_handler(
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err

global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}

### lazy device-memory persistence and result handling

@ -297,8 +296,8 @@ def make_sharded_device_array(
aval: ShapedArray,
sharding_spec: Optional[ShardingSpec],
# Any is for JAX extensions implementing their own buffer.
device_buffers: List[Any],
indices: Optional[Tuple[Index, ...]] = None,
device_buffers: list[Any],
indices: Optional[tuple[Index, ...]] = None,
):
"""Returns a ShardedDeviceArray implementation based on arguments.

@ -482,7 +481,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
new_outvals.append(out)
return new_outvals

def _map_schedule(idx: Tuple[Optional[int], ...]) -> Tuple[Optional[int], ...]:
def _map_schedule(idx: tuple[Optional[int], ...]) -> tuple[Optional[int], ...]:
# In order to do a multi-map (a simultaneous map over several axes), we will
# nest several maps. Each time we do a map, we "remove" an input axis so we
# need to update the remaining map axes. For example, if we are to map over
@ -498,9 +497,9 @@ def _map_schedule(idx: Tuple[Optional[int], ...]) -> Tuple[Optional[int], ...]:
# _function object_. Adding this annotation here lets us reuse the same pmap
# callable for all equivalent primitive pmaps.
@lru_cache()
def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
all_axes: List[Tuple[Optional[int], ...]]
) -> Tuple[Callable, Dict[core.AxisName, int]]:
def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName],
all_axes: list[tuple[Optional[int], ...]]
) -> tuple[Callable, dict[core.AxisName, int]]:
used_names = []
for i, name in reversed(list(enumerate(names))):
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
@ -609,9 +608,9 @@ def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
return [i for i in range(ndim) if i not in mapped_axes_][annotation]

def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
shard_axis_src: Dict[core.AxisName, int],
shard_axis_src: dict[core.AxisName, int],
dst_annotation: Optional[int]
) -> Tuple[Any, Dict[core.AxisName, int]]:
) -> tuple[Any, dict[core.AxisName, int]]:
shard_axis_out = dict(shard_axis_src)
src = shard_axis_out.pop(axis_name, None)
dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
@ -629,9 +628,9 @@ def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
raise NotImplementedError
return outval, shard_axis_out

def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
src: int, dst: int) -> Dict[core.AxisName, int]:
lst: List[Optional[core.AxisName]] = [None] * ndim
def _moveaxis(ndim: int, shard_axes: dict[core.AxisName, int],
src: int, dst: int) -> dict[core.AxisName, int]:
lst: list[Optional[core.AxisName]] = [None] * ndim
for k, v in shard_axes.items():
lst[v] = k
name = lst.pop(src)
@ -641,7 +640,7 @@ def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
class MapTracer(core.Tracer):
__slots__ = ["val", "shard_axes"]

def __init__(self, trace: MapTrace, val, shard_axes: Dict[core.AxisName, int]):
def __init__(self, trace: MapTrace, val, shard_axes: dict[core.AxisName, int]):
self._trace = trace
self.val = val
self.shard_axes = shard_axes
@ -736,7 +735,7 @@ def find_replicas(

def stage_parallel_callable(
pci: ParallelCallableInfo, fun: lu.WrappedFun
) -> Tuple[core.Jaxpr, List[Any], ReplicaInfo, ShardInfo]:
) -> tuple[core.Jaxpr, list[Any], ReplicaInfo, ShardInfo]:
sharded_avals = tuple(
shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(pci.in_axes, pci.avals))
@ -940,8 +939,8 @@ class UnloadedPmapExecutable:
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
local_output_avals: Sequence[ShapedArray]
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
unordered_effects: list[core.Effect]
ordered_effects: list[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
jaxpr_debug_info: core.JaxprDebugInfo
@ -979,9 +978,9 @@ class UnloadedPmapExecutable:
replicas: ReplicaInfo,
shards: ShardInfo,
tuple_args: bool,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect],
host_callbacks: list[Any],
keepalive: Any,
jaxpr_debug_info: core.JaxprDebugInfo,
compiler_options=None):
@ -1156,7 +1155,7 @@ def _get_pmap_sharding(devices, specs):
return [sharding_impls.PmapSharding(devices, spec) for spec in specs]


multi_host_supported_collectives: Set[core.Primitive] = set()
multi_host_supported_collectives: set[core.Primitive] = set()


def check_multihost_collective_allowlist(jaxpr):
@ -1287,9 +1286,9 @@ class ExecuteReplicated:

def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
out_handler: ResultsHandler,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect], keepalive: Any,
has_host_callbacks: bool, kept_var_idx: Set[int]):
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect], keepalive: Any,
has_host_callbacks: bool, kept_var_idx: set[int]):
self.xla_executable = xla_executable
self.name = name
self.backend = backend
@ -1582,7 +1581,7 @@ class SPMDBatchTrace(batching.BatchTrace):
return super().get_axis_primitive_batcher(primitive, frame)


spmd_primitive_batchers: Dict[core.Primitive, Callable] = {}
spmd_primitive_batchers: dict[core.Primitive, Callable] = {}


def vtile_by_mesh(fun: lu.WrappedFun,
@ -1611,7 +1610,7 @@ def _full_to_shard_abstract_eval(x, axes, mesh, **_):

def manual_proto(
aval: core.ShapedArray,
manual_axes_set: FrozenSet[sharding_impls.MeshAxisName], mesh: Mesh):
manual_axes_set: frozenset[sharding_impls.MeshAxisName], mesh: Mesh):
"""Create an OpSharding proto that declares all mesh axes from `axes` as manual
and all others as replicated.
"""
@ -1638,7 +1637,7 @@ def manual_proto(

@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
manual_axes: frozenset[sharding_impls.MeshAxisName]):
# TODO: Can we short-circuit for replicated values? Probably not.
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
@ -1658,7 +1657,7 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_):

@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
manual_axes: frozenset[sharding_impls.MeshAxisName]):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
proto = manual_proto(aval_in, manual_axes, mesh) # type: ignore
@ -1669,7 +1668,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto, unspecified_dims),

@lu.transformation
def vtile_manual(manual_axes: FrozenSet[sharding_impls.MeshAxisName],
def vtile_manual(manual_axes: frozenset[sharding_impls.MeshAxisName],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
@ -1688,7 +1687,7 @@ class TileVectorize:

@dataclasses.dataclass(frozen=True)
class TileManual:
manual_axes: FrozenSet[sharding_impls.MeshAxisName]
manual_axes: frozenset[sharding_impls.MeshAxisName]

TilingMethod = Union[TileVectorize, TileManual]

@ -1756,7 +1755,7 @@ class DeviceAssignmentMismatchError(Exception):
pass


ShardingInfo = Tuple[
ShardingInfo = tuple[
Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports

@ -1768,7 +1767,7 @@ def _get_default_device() -> xc.Device:
def _get_and_check_device_assignment(
shardings: Iterable[ShardingInfo],
devices: Optional[Sequence[xc.Device]],
) -> Tuple[xc.Client, Tuple[xc.Device, ...]]:
) -> tuple[xc.Client, tuple[xc.Device, ...]]:
first_sharding_info = None
if devices is None:
devices = ()
@ -1851,7 +1850,7 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,

@dataclasses.dataclass(frozen=True)
class SemanticallyEqualShardings:
shardings: Tuple[Union[sharding_impls.GSPMDSharding, UnspecifiedValue], ...]
shardings: tuple[Union[sharding_impls.GSPMDSharding, UnspecifiedValue], ...]

def __hash__(self):
return hash(tuple(
@ -1895,8 +1894,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
dispatch.raise_warnings_or_errors_for_jit_of_pmap(
nreps, backend, fun_name, jaxpr)

in_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
out_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
in_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
out_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext

if nreps == 1:
@ -1952,7 +1951,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,

@dataclasses.dataclass(frozen=True)
class _DeviceAssignment:
device_assignment: Tuple[xc.Device, ...]
device_assignment: tuple[xc.Device, ...]

@cached_property
def _hash(self):
@ -1980,7 +1979,7 @@ class _DeviceAssignment:

@lru_cache(maxsize=2048)
def _create_da_object(
device_assignment: Tuple[xc.Device, ...]) -> _DeviceAssignment:
device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:
return _DeviceAssignment(device_assignment)


@ -2161,7 +2160,7 @@ def lower_mesh_computation(

# 1. Trace to jaxpr and preprocess/verify it
if spmd_lowering:
manual_axes: FrozenSet[MeshAxisName] = frozenset()
manual_axes: frozenset[MeshAxisName] = frozenset()
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
if tiling_method is not None:
if isinstance(tiling_method, TileVectorize):
@ -2217,8 +2216,8 @@ def lower_mesh_computation(
# 2. Build up the HLO
tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)

in_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
out_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
in_partitions: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
out_partitions: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext
if spmd_lowering:
in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
@ -2330,7 +2329,7 @@ class MeshComputation(stages.XlaLowering):
return executable
return self._executable

def cost_analysis(self) -> Dict[str, float]:
def cost_analysis(self) -> dict[str, float]:
backend = self.compile_args["backend"]
if xb.using_pjrt_c_api(backend):
raise NotImplementedError(
@ -2352,7 +2351,7 @@ def _get_input_indices(
avals: Sequence[ShapedArray],
shardings: Sequence[sharding_impls.XLACompatibleSharding],
da_object: Union[_DeviceAssignment, Sequence[xc.Device]],
) -> Sequence[Tuple[Optional[Index], ...]]:
) -> Sequence[tuple[Optional[Index], ...]]:

input_indices = []
if isinstance(da_object, _DeviceAssignment):
@ -2378,7 +2377,7 @@ def _get_input_indices(
def get_gspmd_shardings_from_executable(
xla_executable, device_assignment: Sequence[xc.Device],
num_in_avals: int, num_out_avals: int
) -> Tuple[Sequence[sharding_impls.XLACompatibleSharding],
) -> tuple[Sequence[sharding_impls.XLACompatibleSharding],
Sequence[sharding_impls.XLACompatibleSharding]]:
from jax.experimental import pjit

@ -2411,7 +2410,7 @@ def get_gspmd_shardings_from_executable(
# without mesh.
def _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh: Mesh
) -> Tuple[Sequence[sharding_impls.NamedSharding],
) -> tuple[Sequence[sharding_impls.NamedSharding],
Sequence[sharding_impls.NamedSharding]]:
from jax.experimental import pjit

@ -2421,7 +2420,7 @@ def _get_mesh_pspec_shardings_from_executable(


SubClassT = TypeVar("SubClassT", bound=sharding_impls.XLACompatibleSharding)
OrigHandlerType = Dict[Type[SubClassT],
OrigHandlerType = dict[type[SubClassT],
Callable[[xc.OpSharding, SubClassT], SubClassT]]

orig_out_sharding_handlers: OrigHandlerType = {}
@ -2569,11 +2568,11 @@ class UnloadedMeshExecutable:
committed: bool
are_out_shardings_from_xla: Sequence[bool]
name: str
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
unordered_effects: list[core.Effect]
ordered_effects: list[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
kept_var_idx: Set[int]
kept_var_idx: set[int]
auto_spmd_lowering: bool
jaxpr_debug_info: Optional[core.JaxprDebugInfo]

@ -2611,11 +2610,11 @@ class UnloadedMeshExecutable:
spmd_lowering: bool,
tuple_args: bool,
auto_spmd_lowering: bool,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect],
host_callbacks: list[Any],
keepalive: Any,
kept_var_idx: Set[int],
kept_var_idx: set[int],
backend: xb.XlaBackend,
device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]],
committed: bool,
@ -2875,7 +2874,7 @@ def _out_shardings_for_trivial(
jaxpr: core.Jaxpr, consts: Sequence[Any],
in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
device_assignment: Sequence[xc.Device],
) -> List[sharding_impls.XLACompatibleSharding]:
) -> list[sharding_impls.XLACompatibleSharding]:
# For each jaxpr output, compute a Sharding by:
# * if the output is a forwarded input, get the corresponding in_sharding;
# * if the output is a constant Array, get its .sharding attribute;
@ -2893,7 +2892,7 @@ def _out_shardings_for_trivial(
rep = sharding_impls.SingleDeviceSharding(dev)
in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings)

shardings: Dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
shardings: dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.ArrayImpl):
shardings[constvar] = constval.sharding
@ -2903,7 +2902,7 @@ def _out_shardings_for_trivial(


def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
env: Dict[core.Var, Any] = {}
env: dict[core.Var, Any] = {}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
@ -3037,7 +3036,7 @@ def _sanitize_mesh_jaxpr(jaxpr):
core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)


custom_resource_typing_rules: Dict[core.Primitive, Callable] = {}
custom_resource_typing_rules: dict[core.Primitive, Callable] = {}

def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
if isinstance(jaxpr, core.ClosedJaxpr):
@ -3113,7 +3112,7 @@ def maybe_extend_axis_env(*args, **kwargs):


def device_put(x, devices: Sequence[xc.ArrayImpl],
replicate: bool=False) -> List[xc.ArrayImpl]:
replicate: bool=False) -> list[xc.ArrayImpl]:
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
if replicate:
return [jax.device_put(x, device) for device in devices]
@ -22,8 +22,7 @@ import itertools as it
|
||||
import math
|
||||
import operator
|
||||
import re
|
||||
from typing import (Any, Callable, Dict, Optional, Protocol,
|
||||
Sequence, Set, Type, Tuple, Union)
|
||||
from typing import Any, Callable, Optional, Protocol, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -91,7 +90,7 @@ def parameter(builder, num, shape, name=None, replicated=None):
|
||||
# nesting in this type definition.
|
||||
SpatialSharding = Union[Shape,
|
||||
None,
|
||||
Tuple[Optional[Shape], ...]]
|
||||
tuple[Optional[Shape], ...]]
|
||||
|
||||
def sharding_to_proto(sharding: SpatialSharding):
|
||||
"""Converts a SpatialSharding to an OpSharding.
|
||||
@ -145,7 +144,7 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
|
||||
except KeyError as err:
|
||||
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
|
||||
|
||||
xla_shape_handlers: Dict[Type[core.AbstractValue],
|
||||
xla_shape_handlers: dict[type[core.AbstractValue],
|
||||
Callable[[Any], Sequence[xc.Shape]]] = {
|
||||
ShapedArray: _make_array_shape,
|
||||
ConcreteArray: _make_array_shape,
|
||||
@ -179,7 +178,7 @@ def _canonicalize_python_scalar_dtype(typ, x):
|
||||
return np.asarray(
|
||||
x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x)))
|
||||
|
||||
canonicalize_dtype_handlers: Dict[Any, Callable] = {}
|
||||
canonicalize_dtype_handlers: dict[Any, Callable] = {}
|
||||
canonicalize_dtype_handlers.update(
|
||||
(t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types)
|
||||
canonicalize_dtype_handlers[np.ndarray] = _canonicalize_ndarray_dtype
|
||||
@ -217,7 +216,7 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
|
||||
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
|
||||
pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
|
||||
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
|
||||
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
|
||||
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
|
||||
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
|
||||
@ -291,7 +290,7 @@ def axis_read(axis_env, axis_name):
|
||||
except ValueError:
|
||||
raise NameError(f"unbound axis name: {axis_name}") from None
|
||||
|
||||
def axis_groups(axis_env: AxisEnv, name) -> Tuple[Tuple[int, ...]]:
|
||||
def axis_groups(axis_env: AxisEnv, name) -> tuple[tuple[int, ...]]:
|
||||
if not isinstance(name, (list, tuple)):
|
||||
name = (name,)
|
||||
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
|
||||
@ -374,12 +373,12 @@ if not MYPY:
|
||||
else:
|
||||
TranslationRule = Any
|
||||
|
||||
_translations: Dict[core.Primitive, TranslationRule] = {}
|
||||
_backend_specific_translations: Dict[str, Dict[core.Primitive, TranslationRule]]
|
||||
_translations: dict[core.Primitive, TranslationRule] = {}
|
||||
_backend_specific_translations: dict[str, dict[core.Primitive, TranslationRule]]
|
||||
_backend_specific_translations = defaultdict(dict)
|
||||
|
||||
_collective_primitives: Set[core.Primitive] = set()
|
||||
initial_style_primitives: Set[core.Primitive] = set()
|
||||
_collective_primitives: set[core.Primitive] = set()
|
||||
initial_style_primitives: set[core.Primitive] = set()
|
||||
|
||||
def register_initial_style_primitive(prim: core.Primitive):
|
||||
initial_style_primitives.add(prim)
|
||||
@ -441,5 +440,5 @@ class _BackendSpecificTranslationsAdapter(defaultdict):
|
||||
translation_tables, _wrap_old_translation)
|
||||
return ret
|
||||
|
||||
backend_specific_translations: Dict[str, _TranslationRuleAdapter]
|
||||
backend_specific_translations: dict[str, _TranslationRuleAdapter]
|
||||
backend_specific_translations = _BackendSpecificTranslationsAdapter()
|
||||
|
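
Every hunk in this commit follows the same mechanical rewrite: container generics imported from `typing` (Tuple, List, Dict, Set, Type) become the lower-case builtin generics standardized by PEP 585. A minimal before/after sketch of the pattern, not taken from the diff (`summarize` is a made-up name):

# Before: typing aliases, now soft-deprecated.
# from typing import Dict, List, Tuple
# def summarize(names: List[str]) -> Dict[str, Tuple[int, int]]: ...

# After: builtin generics. These work in annotations on Python 3.9+
# (or earlier with `from __future__ import annotations`); uses outside
# annotations, e.g. inside a cast(), require a 3.9+ runtime.
def summarize(names: list[str]) -> dict[str, tuple[int, int]]:
    return {name: (len(name), names.count(name)) for name in names}
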
@ -24,7 +24,7 @@ from __future__ import annotations

import os
import platform
from typing import Any, List, Sequence, Optional
from typing import Any, Sequence, Optional

import iree.compiler
import iree.runtime
@ -73,7 +73,7 @@ class IreeDevice:
def transfer_from_outfeed(self, shape: xla_client.Shape):
raise NotImplementedError("transfer_to_outfeed")

def live_buffers(self) -> List[IreeBuffer]:
def live_buffers(self) -> list[IreeBuffer]:
raise NotImplementedError("live_buffers")


@ -123,10 +123,10 @@ class IreeExecutable:
self.module_object = module_object
self.function_name = function_name

def local_devices(self) -> List[IreeDevice]:
def local_devices(self) -> list[IreeDevice]:
return self._devices

def execute(self, arguments: Sequence[IreeBuffer]) -> List[IreeBuffer]:
def execute(self, arguments: Sequence[IreeBuffer]) -> list[IreeBuffer]:
inputs = [arg.to_iree() for arg in arguments]
outputs = self.module_object[self.function_name](*inputs)
# TODO(phawkins): Have a way to just have it always return the list,
@ -159,10 +159,10 @@ class IreeClient:
def device_count(self) -> int:
return len(self._devices)

def devices(self) -> List[IreeDevice]:
def devices(self) -> list[IreeDevice]:
return self._devices

def local_devices(self) -> List[IreeDevice]:
def local_devices(self) -> list[IreeDevice]:
return self._devices

def local_device_count(self) -> int:
@ -170,7 +170,7 @@ class IreeClient:

def get_default_device_assignment(
self,
num_replicas: int) -> List[IreeDevice]:
num_replicas: int) -> list[IreeDevice]:
if num_replicas != 1:
raise NotImplementedError("Only single-device computations implemented")
return [self._devices[0]]
@ -70,7 +70,7 @@ Todos::
"""

from functools import partial
from typing import (Any, Tuple)
from typing import Any

import numpy as np

@ -98,7 +98,7 @@ def approx_max_k(operand: Array,
reduction_dimension: int = -1,
recall_target: float = 0.95,
reduction_input_size_override: int = -1,
aggregate_to_topk: bool = True) -> Tuple[Array, Array]:
aggregate_to_topk: bool = True) -> tuple[Array, Array]:
"""Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.

See https://arxiv.org/abs/2206.14286 for the algorithm details.
@ -157,7 +157,7 @@ def approx_min_k(operand: Array,
reduction_dimension: int = -1,
recall_target: float = 0.95,
reduction_input_size_override: int = -1,
aggregate_to_topk: bool = True) -> Tuple[Array, Array]:
aggregate_to_topk: bool = True) -> tuple[Array, Array]:
"""Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.

See https://arxiv.org/abs/2206.14286 for the algorithm details.
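
For context on the `approx_max_k`/`approx_min_k` signatures above: both return a (values, indices) pair, which is what the `tuple[Array, Array]` annotation records. A small usage sketch (array contents are illustrative only):

import jax
import jax.numpy as jnp

x = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
# Approximate top-2: returns the two largest values and their positions in x.
values, indices = jax.lax.approx_max_k(x, k=2)
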
@ -15,7 +15,7 @@
import os
from functools import partial

from typing import Any, Callable, List, Optional, Sequence
from typing import Any, Callable, Optional, Sequence

from jax._src import core
from jax._src import linear_util as lu
@ -134,7 +134,7 @@ def _initial_style_jaxprs_with_common_consts(
# b[] <- 2.0
# in () }
canonical_ref_indices = []
canonical_refs: List[Any] = []
canonical_refs: list[Any] = []
tracer_id_to_canonical_id = {}
all_nonref_consts = []
canonical_ref_avals = []
@ -19,7 +19,7 @@ import inspect
import itertools
import operator

from typing import Callable, Sequence, List, Tuple
from typing import Callable, Sequence

from jax import config
from jax.tree_util import tree_flatten, tree_unflatten
@ -511,7 +511,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):

# First, compute output unknowns (unks_out), where an output of the cond is
# unknown if it would be unknown on any of the branches.
unks_out: List[bool] = [False] * len(eqn.outvars)
unks_out: list[bool] = [False] * len(eqn.outvars)
for jaxpr in branches:
_, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
@ -520,9 +520,9 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):

# Next, use the computed output unknowns to build a known jaxpr and a staged
# jaxpr for each branch.
branches_known_ : List[core.ClosedJaxpr] = []
branches_staged_: List[core.ClosedJaxpr] = []
branch_res_avals: List[core.AbstractValue] = []
branches_known_ : list[core.ClosedJaxpr] = []
branches_staged_: list[core.ClosedJaxpr] = []
branch_res_avals: list[core.AbstractValue] = []
for jaxpr in branches:
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
@ -653,13 +653,13 @@ def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())

def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
) -> Tuple[List[bool], core.JaxprEqn]:
def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
) -> tuple[list[bool], core.JaxprEqn]:
closed_branches = eqn.params['branches']
branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]

# First, compute which inputs are used in any branch (not including `pred`).
used_inputs: List[bool] = [False] * (len(eqn.invars) - 1)  # -1 for pred
used_inputs: list[bool] = [False] * (len(eqn.invars) - 1)  # -1 for pred
for jaxpr in branches:
_, used_inputs_ = pe.dce_jaxpr(jaxpr, used_outputs, instantiate=False)
used_inputs = map(operator.or_, used_inputs, used_inputs_)
@ -15,7 +15,7 @@
import functools
import operator

from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union

import jax.numpy as jnp
from jax import lax
@ -92,7 +92,7 @@ def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:

def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
f, out_tree_thunk = flatten_fun_nokwargs(
lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
@ -172,12 +172,12 @@ Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')

def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
length: Optional[int] = None,
reverse: bool = False,
unroll: int = 1) -> Tuple[Carry, Y]:
unroll: int = 1) -> tuple[Carry, Y]:
if not callable(f):
raise TypeError("scan: f argument should be a callable.")
if unroll < 1:
@ -253,7 +253,7 @@ def _for_abstract_eval(*avals, jaxpr, **__):
def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
reverse: bool, which_linear: Sequence[bool],
nsteps: int, unroll: int
) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
) -> tuple[Sequence[Optional[Any]], Sequence[Any]]:
out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse,
which_linear=which_linear, nsteps=nsteps,
unroll=unroll)
@ -371,7 +371,7 @@ def _partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy):

_save_everything = lambda *_, **__: True

def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
def _is_read_only(ref_effects: set[StateEffect]) -> bool:
assert len(ref_effects) > 0
if len(ref_effects) > 1:
# Means we must have a write or accum effect so not read-only
@ -379,7 +379,7 @@ def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
eff, = ref_effects
return isinstance(eff, ReadEffect)

def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> list[bool]:
# Get effects for each of the jaxpr inputs and remove the loop index.
ref_effects = state_types.get_ref_state_effects(
[v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
@ -406,8 +406,8 @@ def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:

def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
jaxpr: core.Jaxpr, nsteps: int, reverse: bool,
which_linear: Tuple[bool, ...],
unroll: int) -> List[pe.JaxprTracer]:
which_linear: tuple[bool, ...],
unroll: int) -> list[pe.JaxprTracer]:
num_inputs = len(tracers)
assert num_inputs == len(jaxpr.invars) - 1
in_unknowns = [not t.pval.is_known() for t in tracers]
@ -636,7 +636,7 @@ pe.partial_eval_jaxpr_custom_rules[for_p] = _for_partial_eval_custom

def _convert_outputs_to_writes(
nsteps: int, jaxpr: core.Jaxpr, loop_invar_res: Sequence[bool]
) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
) -> tuple[core.Jaxpr, list[core.ShapedArray]]:
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."

in_avals = [v.aval for v in jaxpr.invars]  # [i, *orig_ref_avals]
@ -654,7 +654,7 @@ def _convert_outputs_to_writes(
res_ref[i] = res_val
return []
# TODO(mattjj, sharadmv): better handling of tokens, which don't have shape/dtype
res_ref_avals: List[core.AbstractValue] = [
res_ref_avals: list[core.AbstractValue] = [
AbstractRef(v.aval) if loop_invar else  # pytype: disable=attribute-error
AbstractRef(core.ShapedArray((nsteps, *v.aval.shape),  # pytype: disable=attribute-error
v.aval.dtype))  # pytype: disable=attribute-error
@ -679,7 +679,7 @@ def _convert_inputs_to_reads(

res_val_avals, (i_aval,), orig_ref_avals = \
split_list([v.aval for v in jaxpr.invars], [num_res, 1])
res_ref_avals: List[core.AbstractValue] = [
res_ref_avals: list[core.AbstractValue] = [
AbstractRef(aval) if loop_invar else  # pytype: disable=attribute-error
AbstractRef(core.ShapedArray((nsteps, *aval.shape),  # pytype: disable=attribute-error
aval.dtype))  # pytype: disable=attribute-error
@ -689,7 +689,7 @@ def _convert_inputs_to_reads(
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
return jaxpr

def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: List[bool]) -> core.Jaxpr:
def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
def trans(i, *args):
# First we want to run the computation to read all the residual refs. We can
# do that by using partial evaluation with all linear inputs unknown.
@ -16,7 +16,7 @@ from functools import partial
import inspect
import itertools
import operator
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Callable, Optional, Sequence, TypeVar

import jax
import weakref
@ -96,12 +96,12 @@ X = TypeVar('X')
Y = TypeVar('Y')

@api_boundary
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
length: Optional[int] = None,
reverse: bool = False,
unroll: int = 1) -> Tuple[Carry, Y]:
unroll: int = 1) -> tuple[Carry, Y]:
"""Scan a function over leading array axes while carrying along state.

The `Haskell-like type signature`_ in brief is
@ -820,8 +820,8 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)

def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], core.JaxprEqn]:
def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn]:
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_xs = len(jaxpr.in_avals) - num_consts - num_carry
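
The `scan` signature above encodes its contract in the types: the stepping function maps (Carry, X) to tuple[Carry, Y], and `scan` itself returns tuple[Carry, Y] holding the final carry and the stacked per-step outputs. A minimal running-sum sketch using the public `jax.lax.scan`:

import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x  # (Carry, X) -> tuple[Carry, Y]
    return new_carry, new_carry

# final == 10.0; ys == [1., 3., 6., 10.]
final, ys = jax.lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
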
@ -15,7 +15,7 @@
import builtins
from functools import partial
import operator
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Any, NamedTuple, Optional, Sequence, Union

import numpy as np

@ -51,11 +51,11 @@ class ConvDimensionNumbers(NamedTuple):
out_spec: Sequence[int]

ConvGeneralDilatedDimensionNumbers = Union[
None, ConvDimensionNumbers, Tuple[str, str, str]]
None, ConvDimensionNumbers, tuple[str, str, str]]

def conv_general_dilated(
lhs: Array, rhs: Array, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
lhs_dilation: Optional[Sequence[int]] = None,
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
@ -199,7 +199,7 @@ def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],

def conv_with_general_padding(lhs: Array, rhs: Array,
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
lhs_dilation: Optional[Sequence[int]],
rhs_dilation: Optional[Sequence[int]],
precision: lax.PrecisionLike = None,
@ -271,7 +271,7 @@ def _flip_axes(x, axes):


def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
transpose_kernel: bool = False,
@ -330,7 +330,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
k_shape = np.take(rhs.shape, dn.rhs_spec)
k_sdims = k_shape[2:]  # type: ignore[index]
# Calculate correct output shape given padding and strides.
pads: Union[str, Sequence[Tuple[int, int]]]
pads: Union[str, Sequence[tuple[int, int]]]
if isinstance(padding, str) and padding in {'SAME', 'VALID'}:
if rhs_dilation is None:
rhs_dilation = (1,) * (rhs.ndim - 2)
@ -351,7 +351,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
def _conv_general_dilated_shape_rule(
lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
batch_group_count, **unused_kwargs) -> tuple[int, ...]:
assert type(dimension_numbers) is ConvDimensionNumbers
if len(lhs.shape) != len(rhs.shape):
msg = ("conv_general_dilated lhs and rhs must have the same number of "
@ -730,7 +730,7 @@ def _conv_general_dilated_lower(
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
# spatial dimension.
int2d = mlir.aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
def prep_one_pad(pad_lo_hi: Tuple[core.DimSize, core.DimSize]):
def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
pad1 = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, pad_lo_hi))  # i32[2]
return hlo.ReshapeOp(int2d, pad1)
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
@ -910,7 +910,7 @@ def conv_general_permutations(dimension_numbers):

def _conv_general_vjp_lhs_padding(
in_shape, window_dimensions, window_strides, out_shape, padding,
lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
lhs_dilation, rhs_dilation) -> list[tuple[int, int]]:
lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
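
The recurring `padding: Union[str, Sequence[tuple[int, int]]]` parameter in the convolution signatures above accepts either a named scheme ('SAME'/'VALID') or explicit (low, high) pairs per spatial dimension. A sketch with the default NCHW/OIHW layout; the shapes are illustrative only:

import jax
import jax.numpy as jnp

lhs = jnp.ones((1, 3, 8, 8))  # batch, features, height, width
rhs = jnp.ones((4, 3, 3, 3))  # out features, in features, kernel h, kernel w
# Named padding scheme...
out_named = jax.lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='SAME')
# ...or one explicit (low, high) pair per spatial dimension.
out_pairs = jax.lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding=((1, 1), (1, 1)))
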
@ -19,8 +19,8 @@ from functools import partial
import itertools
import math
import operator
from typing import (Any, Callable, Optional, Sequence, Tuple, List,
TypeVar, Union, cast as type_cast, overload)
from typing import (Any, Callable, Optional, Sequence, TypeVar, Union,
cast as type_cast, overload)
import warnings

import numpy as np
@ -102,7 +102,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
map(_check_static_shape, shapes)

def _try_broadcast_shapes(
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
shapes: Sequence[tuple[int, ...]]) -> Optional[tuple[int, ...]]:
if len(shapes) == 1: return shapes[0]
rank, *others = {len(shape) for shape in shapes}
if others: return None  # must have consistent rank
@ -134,11 +134,11 @@ def asarray(x: ArrayLike) -> Array:
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")

@overload
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...

@overload
def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...]
) -> Tuple[Union[int, core.Tracer], ...]: ...
def broadcast_shapes(*shapes: tuple[Union[int, core.Tracer], ...]
) -> tuple[Union[int, core.Tracer], ...]: ...

def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
@ -149,7 +149,7 @@ def broadcast_shapes(*shapes):
return _broadcast_shapes_uncached(*shapes)

@cache()
def _broadcast_shapes_cached(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
return _broadcast_shapes_uncached(*shapes)

def _broadcast_shapes_uncached(*shapes):
@ -181,7 +181,7 @@ def _identity(x): return x

def _extract_tracers_dyn_shape(
shape: Sequence[Union[int, core.Tracer]]
) -> Tuple[List[core.Tracer], List[Optional[int]]]:
) -> tuple[list[core.Tracer], list[Optional[int]]]:
# Given a sequence representing a shape, pull out Tracers, replacing with None
if config.jax_dynamic_shapes:
# We must gate this behavior under a flag because otherwise the errors
@ -195,7 +195,7 @@ def _extract_tracers_dyn_shape(
def _merge_dyn_shape(
static_shape: Sequence[Optional[int]],
dyn_shape: Sequence[Any],
) -> Tuple[Union[int, mlir.Value, core.Tracer], ...]:
) -> tuple[Union[int, mlir.Value, core.Tracer], ...]:
# Replace Nones in static_shape with elements of dyn_shape, in order
dyn_shape_it = iter(dyn_shape)
shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
@ -663,8 +663,8 @@ class Precision(xla_client.PrecisionConfig.Precision):  # type: ignore


PrecisionType = Precision
PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
Tuple[PrecisionType, PrecisionType]]
PrecisionLike = Union[None, str, PrecisionType, tuple[str, str],
tuple[PrecisionType, PrecisionType]]

def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None) -> Array:
@ -699,8 +699,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
lhs.shape, rhs.shape))


DotDimensionNumbers = Tuple[Tuple[Sequence[int], Sequence[int]],
Tuple[Sequence[int], Sequence[int]]]
DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
tuple[Sequence[int], Sequence[int]]]

def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
@ -860,7 +860,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
dimensions=None if dims is None or same_dims else dims)

def pad(operand: ArrayLike, padding_value: ArrayLike,
padding_config: Sequence[Tuple[int, int, int]]) -> Array:
padding_config: Sequence[tuple[int, int, int]]) -> Array:
"""Applies low, high, and/or interior padding to an array.

Wraps XLA's `Pad
@ -1111,14 +1111,14 @@ def _reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:

@overload
def sort(operand: Sequence[Array], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Tuple[Array, ...]: ...
is_stable: bool = True, num_keys: int = 1) -> tuple[Array, ...]: ...

@overload
def sort(operand: Array, dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Array: ...

def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
is_stable: bool = True, num_keys: int = 1) -> Union[Array, tuple[Array, ...]]:
"""Wraps XLA's `Sort
<https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.

@ -1154,13 +1154,13 @@ def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
return sort_p.bind(operand, dimension=dimension, is_stable=is_stable, num_keys=1)[0]

def sort_key_val(keys: Array, values: ArrayLike, dimension: int = -1,
is_stable: bool = True) -> Tuple[Array, Array]:
is_stable: bool = True) -> tuple[Array, Array]:
"""Sorts ``keys`` along ``dimension`` and applies the same permutation to ``values``."""
dimension = canonicalize_axis(dimension, len(keys.shape))
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, num_keys=1)
return k, v

def top_k(operand: ArrayLike, k: int) -> Tuple[Array, Array]:
def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
"""Returns top ``k`` values and their indices along the last axis of ``operand``.

Args:
@ -4833,7 +4833,7 @@ def remaining(original, *removed_lists):
return [i for i in original if i not in removed]


def canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[PrecisionType, PrecisionType]]:
def canonicalize_precision(precision: PrecisionLike) -> Optional[tuple[PrecisionType, PrecisionType]]:
"""Turns an API precision specification, into a pair of enumeration values.

The API can take the precision as a string, or int, and either as a single
@ -4844,7 +4844,7 @@ def canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[Precision
return None
try:
return type_cast(
Tuple[PrecisionType, PrecisionType],
tuple[PrecisionType, PrecisionType],
(Precision(config.jax_default_matmul_precision),
Precision(config.jax_default_matmul_precision)))
except TypeError:
@ -4853,18 +4853,18 @@ def canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[Precision
f"{list(Precision._strings)}, but got {config.jax_default_matmul_precision}"
) from None
elif isinstance(precision, str) and precision in Precision._strings:
return type_cast(Tuple[PrecisionType, PrecisionType],
return type_cast(tuple[PrecisionType, PrecisionType],
(Precision(precision), Precision(precision)))
elif isinstance(precision, xla_client.PrecisionConfig.Precision):
return type_cast(Tuple[PrecisionType, PrecisionType], (precision, precision))
return type_cast(tuple[PrecisionType, PrecisionType], (precision, precision))
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(p, xla_client.PrecisionConfig.Precision) for p in precision)):
return type_cast(Tuple[PrecisionType, PrecisionType], precision)
return type_cast(tuple[PrecisionType, PrecisionType], precision)
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(s, str) for s in precision)):
s1, s2 = precision
p1 = type_cast(Tuple[PrecisionType, PrecisionType], canonicalize_precision(s1))[0]
p2 = type_cast(Tuple[PrecisionType, PrecisionType], canonicalize_precision(s2))[0]
p1 = type_cast(tuple[PrecisionType, PrecisionType], canonicalize_precision(s1))[0]
p2 = type_cast(tuple[PrecisionType, PrecisionType], canonicalize_precision(s2))[0]
return (p1, p2)
else:
raise ValueError(
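
Several of the sorting signatures above now advertise their pair results with the builtin syntax; the behavior is unchanged. A short sketch of the public entry points involved (values are illustrative):

import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 4, 1, 5])
values, indices = jax.lax.top_k(x, 2)                # tuple[Array, Array]
keys, vals = jax.lax.sort_key_val(x, jnp.arange(5))  # also a pair
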
@ -16,7 +16,7 @@ import inspect
import functools
from functools import partial
import math
from typing import cast, Any, Callable, List, Literal, Optional, Tuple, TypeVar, Union, overload
from typing import cast, Any, Callable, Literal, Optional, TypeVar, Union, overload
import warnings

import numpy as np
@ -137,7 +137,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:

@_warn_on_positional_kwargs
def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True) -> List[Array]:
compute_right_eigenvectors: bool = True) -> list[Array]:
"""Eigendecomposition of a general matrix.

Nonsymmetric eigendecomposition is at present only implemented on CPU.
@ -147,7 +147,7 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,

@_warn_on_positional_kwargs
def eigh(x: Array, *, lower: bool = True, symmetrize_input: bool = True,
sort_eigenvalues: bool = True) -> Tuple[Array, Array]:
sort_eigenvalues: bool = True) -> tuple[Array, Array]:
r"""Eigendecomposition of a Hermitian matrix.

Computes the eigenvectors and eigenvalues of a complex Hermitian or real
@ -200,7 +200,7 @@ def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
return permutation


def lu(x: ArrayLike) -> Tuple[Array, Array, Array]:
def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
"""LU decomposition with partial pivoting.

Computes the matrix decomposition:
@ -234,7 +234,7 @@ def lu(x: ArrayLike) -> Tuple[Array, Array, Array]:
return lu, pivots, permutation

@_warn_on_positional_kwargs
def qr(x: ArrayLike, *, full_matrices: bool = True) -> Tuple[Array, Array]:
def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
"""QR decomposition.

Computes the QR decomposition
@ -265,17 +265,17 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> Tuple[Array, Array]:
return q, r

@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> Tuple[Array, Array, Array]: ...
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...

@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ...

@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, Tuple[Array, Array, Array]]: ...
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, tuple[Array, Array, Array]]: ...

# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
@_warn_on_positional_kwargs
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, Tuple[Array, Array, Array]]:
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, tuple[Array, Array, Array]]:
"""Singular value decomposition.

Returns the singular values if compute_uv is False, otherwise returns a triple
@ -586,7 +586,7 @@ ad.primitive_jvps[eig_p] = eig_jvp_rule


def eigh_jacobi(x: ArrayLike, *, lower: bool = True,
sort_eigenvalues: bool = True) -> Tuple[Array, Array]:
sort_eigenvalues: bool = True) -> tuple[Array, Array]:
"""Helper Jacobi eigendecomposition implemented by XLA.

Used as a subroutine of QDWH-eig on TPU."""
@ -1389,7 +1389,7 @@ def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
# geqrf and orgqr. The names, while cryptic Fortran alphabet soup, are LAPACK's
# names for the primitives, and we stick with them for consistency.

def geqrf(a: ArrayLike) -> Tuple[Array, Array]:
def geqrf(a: ArrayLike) -> tuple[Array, Array]:
"""Computes the QR decomposition of a matrix.

Args:
@ -2008,7 +2008,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
def schur(x: ArrayLike, *,
compute_schur_vectors: bool = True,
sort_eig_vals: bool = False,
select_callable: Optional[Callable[..., Any]] = None) -> Tuple[Array, Array]:
select_callable: Optional[Callable[..., Any]] = None) -> tuple[Array, Array]:
return schur_p.bind(
x,
compute_schur_vectors=compute_schur_vectors,
@ -2115,7 +2115,7 @@ ad.primitive_jvps[schur_p] = _schur_jvp_rule

# hessenberg: Upper Hessenberg reduction

def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
def hessenberg(a: ArrayLike) -> tuple[Array, Array]:
"""Reduces a square matrix to upper Hessenberg form.

Currently implemented on CPU only.
@ -2187,7 +2187,7 @@ mlir.register_lowering(hessenberg_p, _hessenberg_cpu_hlo, platform='cpu')
# tridiagonal: Upper Hessenberg reduction

def tridiagonal(a: ArrayLike, *, lower=True
) -> Tuple[Array, Array, Array, Array]:
) -> tuple[Array, Array, Array, Array]:
"""Reduces a symmetric/Hermitian matrix to tridiagonal form.

Currently implemented on CPU and GPU only.
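
The decomposition helpers above return fixed-arity tuples, which the annotations now spell with builtin `tuple`. A sketch against the public `jax.lax.linalg` module; the identity input is just a placeholder:

import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.eye(3)
lu, pivots, permutation = lax_linalg.lu(a)  # tuple[Array, Array, Array]
q, r = lax_linalg.qr(a)                     # tuple[Array, Array]
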
@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import Any, Optional, Sequence, Tuple, Union, cast as type_cast
from typing import Any, Optional, Sequence, Union, cast as type_cast

import jax
from jax._src.numpy import lax_numpy as jnp
@ -26,7 +26,7 @@ def conv_general_dilated_patches(
lhs: jax.typing.ArrayLike,
filter_shape: Sequence[int],
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
lhs_dilation: Optional[Sequence[int]] = None,
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: Optional[convolution.ConvGeneralDilatedDimensionNumbers] = None,
@ -122,7 +122,7 @@ def conv_general_dilated_local(
lhs: jax.typing.ArrayLike,
rhs: jax.typing.ArrayLike,
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
filter_shape: Sequence[int],
lhs_dilation: Optional[Sequence[int]] = None,
rhs_dilation: Optional[Sequence[int]] = None,
@ -25,7 +25,7 @@ https://epubs.siam.org/doi/abs/10.1137/090774999
"""

import functools
from typing import Optional, Tuple
from typing import Optional

import jax
import jax.numpy as jnp
@ -196,7 +196,7 @@ def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
# TODO: Add pivoting.
@functools.partial(jax.jit, static_argnames=('is_hermitian',))
def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None,
dynamic_shape: Optional[Tuple[int, int]] = None):
dynamic_shape: Optional[tuple[int, int]] = None):
"""QR-based dynamically weighted Halley iteration for polar decomposition.

Args:
@ -15,7 +15,7 @@
import enum
from functools import partial
import math
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Callable, NamedTuple, Optional, Sequence, Union
import weakref

import numpy as np
@ -178,9 +178,9 @@ class GatherDimensionNumbers(NamedTuple):
implicit; there is always an index vector dimension and it must always be the
last dimension. To gather scalar indices, add a trailing dimension of size 1.
"""
offset_dims: Tuple[int, ...]
collapsed_slice_dims: Tuple[int, ...]
start_index_map: Tuple[int, ...]
offset_dims: tuple[int, ...]
collapsed_slice_dims: tuple[int, ...]
start_index_map: tuple[int, ...]


class GatherScatterMode(enum.Enum):
@ -692,7 +692,7 @@ def dynamic_slice_in_dim(operand: Union[Array, np.ndarray],
start_index: ArrayLike,
slice_size: int, axis: int = 0) -> Array:
"""Convenience wrapper around dynamic_slice applying to one dimension."""
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
slice_sizes = list(operand.shape)

axis = int(axis)
@ -719,7 +719,7 @@ def dynamic_update_slice_in_dim(operand: Union[Array, np.ndarray],
in a single ``axis``.
"""
axis = int(axis)
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
start_indices[axis] = start_index
return dynamic_update_slice(operand, update, start_indices)

@ -2149,7 +2149,7 @@ mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")
def _dynamic_slice_indices(
operand: Union[Array, np.ndarray],
start_indices: Union[Union[Array, np.ndarray], Sequence[ArrayLike]]
) -> List[ArrayLike]:
) -> list[ArrayLike]:
# Normalize the start_indices w.r.t. operand.shape
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
@ -2160,7 +2160,7 @@ def _dynamic_slice_indices(
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))  # type: ignore[union-attr]
start_indices = list(start_indices)
result: List[ArrayLike] = []
result: list[ArrayLike] = []
for i, d in zip(start_indices, operand.shape):
# We test whether i and d are static to avoid unnecessary staging.
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
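
`dynamic_slice_in_dim` above builds a full start-index list (`list[ArrayLike]`) with zeros everywhere except the sliced axis. A usage sketch:

import jax
import jax.numpy as jnp

x = jnp.arange(10)
# Slice of length 3 starting at index 2 along axis 0 -> [2, 3, 4].
y = jax.lax.dynamic_slice_in_dim(x, 2, 3)
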
@ -20,7 +20,7 @@ Eigendecomposition on TPU.

from __future__ import annotations

from typing import Any, Tuple
from typing import Any

import jax
from jax import lax
@ -60,7 +60,7 @@ class Stack:
lambda x, y: lax.dynamic_update_index_in_dim(x, y, self._size, 0),
self._data, elem))

def pop(self) -> Tuple[Any, Stack]:
def pop(self) -> tuple[Any, Stack]:
"""Pops from the stack, returning an (elem, updated stack) pair."""
elem = jax.tree_util.tree_map(
lambda x: lax.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False),
@ -13,7 +13,7 @@
# limitations under the License.

from functools import partial
from typing import (Any, Callable, Optional, Sequence, Union, Tuple)
from typing import Any, Callable, Optional, Sequence, Union
import warnings

import numpy as np
@ -44,7 +44,7 @@ Array = Any

def reduce_window(operand, init_value, computation: Callable,
window_dimensions: core.Shape, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: Union[str, Sequence[tuple[int, int]]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
"""Wraps XLA's `ReduceWindowWithGeneralPadding
@ -112,7 +112,7 @@ def _get_monoid_window_reducer(monoid_op: Callable,

def _reduce_window_sum(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
@ -127,7 +127,7 @@ def _reduce_window_sum(operand: Array, window_dimensions: core.Shape,

def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
init_value = lax._const(operand, 1)
@ -146,7 +146,7 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,

def _reduce_window_max(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
@ -161,7 +161,7 @@ def _reduce_window_max(operand: Array, window_dimensions: core.Shape,

def _reduce_window_min(operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
@ -177,7 +177,7 @@ def _reduce_window_min(operand: Array, window_dimensions: core.Shape,
def _reduce_window_logaddexp(
operand: Array, window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
init_value = lax._const(operand, -np.inf)
@ -197,7 +197,7 @@ def _reduce_window_logaddexp(
def _select_and_scatter(operand: Array, select: Callable,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]], source: Array,
padding: Sequence[tuple[int, int]], source: Array,
init_value: Array, scatter: Callable) -> Array:
select_jaxpr, select_consts = lax._reduction_jaxpr(
select, lax._abstractify(init_value))
@ -213,7 +213,7 @@ def _select_and_scatter_add(source: Array, operand: Array,
select_prim: core.Primitive,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[tuple[int, int]]) -> Array:
return select_and_scatter_add_p.bind(
source, operand, select_prim=select_prim,
window_dimensions=tuple(window_dimensions),
@ -223,7 +223,7 @@ def _select_and_gather_add(tangents: Array, operand: Array,
select_prim: core.Primitive,
window_dimensions: core.Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]],
padding: Sequence[tuple[int, int]],
base_dilation: Sequence[int],
window_dilation: Sequence[int]) -> Array:
"""Extracts the tangent corresponding to the minimum or maximum element in
@ -15,13 +15,13 @@
"""A LazyLoader class."""

import importlib
from typing import Any, Callable, List, Sequence, Tuple
from typing import Any, Callable, Sequence


def attach(package_name: str, submodules: Sequence[str]) -> Tuple[
def attach(package_name: str, submodules: Sequence[str]) -> tuple[
Callable[[str], Any],
Callable[[], List[str]],
List[str],
Callable[[], list[str]],
list[str],
]:
"""Lazily loads submodules of a package.

@ -31,14 +31,14 @@ def attach(package_name: str, submodules: Sequence[str]) -> Tuple[
```
"""

__all__: List[str] = list(submodules)
__all__: list[str] = list(submodules)

def __getattr__(name: str) -> Any:
if name in submodules:
return importlib.import_module(f"{package_name}.{name}")
raise AttributeError(f"module '{package_name}' has no attribute '{name}")

def __dir__() -> List[str]:
def __dir__() -> list[str]:
return __all__

return __getattr__, __dir__, __all__
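
The triple returned by `attach` above is meant to be unpacked into the module-level hooks that PEP 562 consults. A sketch of the intended call site; the package and submodule names are placeholders:

# In some package's __init__.py:
__getattr__, __dir__, __all__ = attach(__name__, ["foo", "bar"])
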
@ -18,7 +18,7 @@
import gc
import pathlib
import re
from typing import Optional, Tuple
from typing import Optional

try:
import jaxlib as jaxlib
@ -41,12 +41,12 @@ except Exception as err:
# Checks the jaxlib version before importing anything else from jaxlib.
# Returns the jaxlib version string.
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
minimum_jaxlib_version: str) -> Tuple[int, ...]:
minimum_jaxlib_version: str) -> tuple[int, ...]:
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
# PEP440 allows a number of non-numeric suffixes, which we allow also.
# We currently do not allow an epoch.
version_regex = re.compile(r"[0-9]+(?:\.[0-9]+)*")
def _parse_version(v: str) -> Tuple[int, ...]:
def _parse_version(v: str) -> tuple[int, ...]:
m = version_regex.match(v)
if m is None:
raise ValueError(f"Unable to parse jaxlib version '{v}'")
@ -65,7 +65,7 @@ from __future__ import annotations

from functools import partial
import operator
from typing import Any, Tuple, Callable, Optional, NamedTuple
from typing import Any, Callable, Optional, NamedTuple
import weakref

from jax._src.tree_util import tree_map
@ -242,7 +242,7 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:

@curry
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args,
use_eq_store=False) -> Tuple[WrappedFun, Any]:
use_eq_store=False) -> tuple[WrappedFun, Any]:
"""Adds one more transformation with auxiliary output to a WrappedFun."""
out_store = Store() if not use_eq_store else EqualStore()
out_thunk = lambda: out_store.val
@ -303,8 +303,8 @@ class TracingDebugInfo(NamedTuple):
# TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
traced_for: str  # e.g. 'jit', 'scan', etc
func_src_info: str  # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: Tuple[str, ...]  # e.g. ('args[0]', ... )
result_paths: Optional[Callable[[], Tuple[str, ...]]]
arg_names: tuple[str, ...]  # e.g. ('args[0]', ... )
result_paths: Optional[Callable[[], tuple[str, ...]]]

def add_debug_info(f: WrappedFun, debug_info: Optional[TracingDebugInfo]
) -> WrappedFun:
@ -16,7 +16,7 @@ import contextlib
import numpy as np
import itertools as it
from collections import OrderedDict, abc
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
from typing import (Callable, Iterable, Optional, Any,
NamedTuple, Union, Sequence, Mapping)
from functools import wraps, partial, partialmethod, lru_cache
import math
@ -268,7 +268,7 @@ def _prepare_axes(axes, arg_name):
return tree_unflatten(treedef, entries), entries, treedef

Resource = Union[ResourceAxisName, SerialLoop]
ResourceSet = Union[Resource, Tuple[Resource, ...]]
ResourceSet = Union[Resource, tuple[Resource, ...]]

# TODO: Some syntactic sugar to make the API more usable in a single-axis case?
# TODO: Are the resource axes scoped lexically or dynamically? Dynamically for now!
@ -491,7 +491,7 @@ def xmap(fun: Callable,
f"in_axes or axis_sizes, but the following are missing: "
f"{out_axes_names - defined_names}")

normalized_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]] = {}
normalized_axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]] = {}
for axis in defined_names:
resources = axis_resources.get(axis, ())
if not isinstance(resources, tuple):
@ -708,10 +708,10 @@ def make_xmap_callable(fun: lu.WrappedFun,
class EvaluationPlan(NamedTuple):
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
resource_env: ResourceEnv
physical_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
loop_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_subst_dict: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_vmap_size: Dict[AxisName, Optional[int]]
physical_axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]]
loop_axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]]
axis_subst_dict: dict[AxisName, tuple[ResourceAxisName, ...]]
axis_vmap_size: dict[AxisName, Optional[int]]

@property
def axis_subst(self) -> core.AxisSubst:
@ -729,15 +729,15 @@ class EvaluationPlan(NamedTuple):

@classmethod
def from_axis_resources(cls,
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]],
resource_env: ResourceEnv,
global_axis_sizes: Dict[AxisName, int]):
global_axis_sizes: dict[AxisName, int]):
physical_axis_resources, loop_axis_resources = _unzip_axis_resources(
axis_resources, resource_env)
axis_resource_count = _get_axis_resource_count(
axis_resources, resource_env)
axis_subst_dict = dict(axis_resources)
axis_vmap_size: Dict[AxisName, Optional[int]] = {}
axis_vmap_size: dict[AxisName, Optional[int]] = {}
for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])):
num_resources = axis_resource_count[naxis]
assert global_axis_sizes[naxis] % num_resources.nglobal == 0
@ -1038,7 +1038,7 @@ def _xmap_partial_eval_custom_params_updater(
unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
) -> tuple[dict, dict]:
assert params_known['spmd_in_axes'] is None is params_known['spmd_out_axes']
assert params_staged['spmd_in_axes'] is None is params_staged['spmd_out_axes']

@ -1573,7 +1573,7 @@ class ResourceCount(NamedTuple):


def _get_axis_resource_count(
axis_resources, resource_env) -> Dict[ResourceAxisName, ResourceCount]:
axis_resources, resource_env) -> dict[ResourceAxisName, ResourceCount]:
global_res_shape = resource_env.shape
local_res_shape = None

@ -1593,8 +1593,8 @@ def _get_axis_resource_count(

def _get_axis_sizes(args_flat: Iterable[Any],
in_axes_flat: Iterable[AxisNamePos],
global_axis_sizes: Dict[AxisName, int],
axis_resource_count: Dict[AxisName, ResourceCount]):
global_axis_sizes: dict[AxisName, int],
axis_resource_count: dict[AxisName, ResourceCount]):
global_axis_sizes = dict(global_axis_sizes)
for arg, in_axes in zip(args_flat, in_axes_flat):
for name, dim in in_axes.items():
@ -1643,7 +1643,7 @@ def hide_mapped_axes(flat_in_axes, flat_out_axes, *flat_args):
yield map(_unsqueeze_mapped_axes, flat_outputs, flat_out_axes)


def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]:
def _jaxpr_resources(jaxpr, resource_env) -> set[ResourceAxisName]:
if isinstance(jaxpr, core.ClosedJaxpr):
jaxpr = jaxpr.jaxpr
assert isinstance(jaxpr, core.Jaxpr)
@ -1661,7 +1661,7 @@ def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]:


def _to_resource_axes(axes_specs: Sequence[AxisNamePos],
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]):
axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]]):
"""
Convert in/out_axes parameters ranging over logical dimensions to
ones that range over resource dimensions.
@ -1695,7 +1695,7 @@ def _slice_tile(x, dim: Optional[int], i, n: int):
return lax.dynamic_slice_in_dim(x, i * tile_size, slice_size=tile_size, axis=dim)


def _unzip_axis_resources(axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
def _unzip_axis_resources(axis_resources: dict[AxisName, tuple[ResourceAxisName, ...]],
resource_env: ResourceEnv):
"""Splits axis_resources into separate dicts for physical and loop resources."""
physical_axis_resources = {}
@ -1718,7 +1718,7 @@ def _unzip_axis_resources(axis_resources: Dict[AxisName, Tuple[ResourceAxisName,

def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
out_axes: Sequence[AxisNamePos],
global_axis_sizes: Dict[AxisName, int]):
global_axis_sizes: dict[AxisName, int]):
defined_axes = set(global_axis_sizes)
for aval, axes in zip(out_avals, out_axes):
if not isinstance(aval, core.ShapedArray):
@ -20,7 +20,7 @@ import contextlib
import functools
import math
import threading
from typing import Any, Hashable, NamedTuple, Set, Sequence, Tuple, Union
from typing import Any, Hashable, NamedTuple, Sequence, Union

import numpy as np

@ -43,7 +43,7 @@ def show_axes(axes):

class ResourceEnv(NamedTuple):
physical_mesh: Mesh
loops: Tuple[Loop, ...]
loops: tuple[Loop, ...]

def with_mesh(self, mesh: Mesh):
overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
@ -60,15 +60,15 @@ class ResourceEnv(NamedTuple):
return self._replace(loops=self.loops + (loop,))

@property
def physical_resource_axes(self) -> Set[ResourceAxisName]:
def physical_resource_axes(self) -> set[ResourceAxisName]:
return set(self.physical_mesh.axis_names)

@property
def loop_resource_axes(self) -> Set[ResourceAxisName]:
def loop_resource_axes(self) -> set[ResourceAxisName]:
return {loop.name for loop in self.loops}

@property
def resource_axes(self) -> Set[ResourceAxisName]:
def resource_axes(self) -> set[ResourceAxisName]:
return self.physical_resource_axes | self.loop_resource_axes

@property
@ -138,7 +138,7 @@ class Mesh(contextlib.ContextDecorator):
"""

devices: np.ndarray
axis_names: Tuple[MeshAxisName, ...]
axis_names: tuple[MeshAxisName, ...]

def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]],
axis_names: Union[str, Sequence[MeshAxisName]]):
@ -20,10 +20,10 @@ during program execution, the registered listeners will be invoked.
A typical listener callback is to send an event to a metrics collector for
aggregation/exporting.
"""
from typing import Callable, List
from typing import Callable

_event_listeners: List[Callable[[str], None]] = []
_event_duration_secs_listeners: List[Callable[[str, float], None]] = []
_event_listeners: list[Callable[[str], None]] = []
_event_duration_secs_listeners: list[Callable[[str, float], None]] = []

def record_event(event: str) -> None:
"""Record an event."""

@ -18,7 +18,7 @@ from functools import partial
import operator
import warnings
import numpy as np
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
@ -293,7 +293,7 @@ logsumexp = _logsumexp

@partial(jax.jit, static_argnames=("axis",))
def log_softmax(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
r"""Log-Softmax function.
@ -326,7 +326,7 @@ def log_softmax(x: Array,
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
r"""Softmax function.
@ -355,7 +355,7 @@ def softmax(x: Array,
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _softmax(
x,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
@ -382,7 +382,7 @@ def _softmax_deprecated(x, axis, where, initial):

@partial(jax.jit, static_argnames=("axis",))
def standardize(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
mean: Optional[Array] = None,
variance: Optional[Array] = None,
epsilon: Array = 1e-5,
@ -400,7 +400,7 @@ def standardize(x: Array,
return (x - mean) * lax.rsqrt(variance + epsilon)

def normalize(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
mean: Optional[Array] = None,
variance: Optional[Array] = None,
epsilon: Array = 1e-5,

@ -18,7 +18,7 @@ used in Keras and Sonnet.
"""

import math
from typing import Any, Literal, Protocol, Sequence, Tuple, Union
from typing import Any, Literal, Protocol, Sequence, Union

import numpy as np

@ -159,7 +159,7 @@ def _compute_fans(shape: core.NamedShape,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Union[int, Sequence[int]] = ()
) -> Tuple[Array, Array]:
) -> tuple[Array, Array]:
"""
Compute effective input and output sizes for a linear or convolutional layer.

@ -24,7 +24,7 @@ __all__ = ['register_jax_array_methods']
import abc
from functools import partial, wraps
import inspect
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import warnings

import numpy as np
@ -335,15 +335,15 @@ def _deprecated_split(*args, **kwargs):
@core.stash_axis_env()
@partial(jax.jit, static_argnums=(1,2,3))
def _multi_slice(arr: ArrayLike,
start_indices: Tuple[Tuple[int, ...]],
limit_indices: Tuple[Tuple[int, ...]],
removed_dims: Tuple[Tuple[int, ...]]) -> List[Array]:
start_indices: tuple[tuple[int, ...]],
limit_indices: tuple[tuple[int, ...]],
removed_dims: tuple[tuple[int, ...]]) -> list[Array]:
"""Extracts multiple slices from `arr`.

This is used to shard DeviceArray arguments to pmap. It's implemented as a
DeviceArray method here to avoid circular imports.
"""
results: List[Array] = []
results: list[Array] = []
for starts, limits, removed in zip(start_indices, limit_indices, removed_dims):
sliced = lax.slice(arr, starts, limits)
if removed:
@ -354,7 +354,7 @@ def _multi_slice(arr: ArrayLike,
# The next two functions are related to iter(device_array), implemented here to
# avoid circular imports.
@jax.jit
def _unstack(x: Array) -> List[Array]:
def _unstack(x: Array) -> list[Array]:
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]

def _chunk_iter(x, size):

@ -13,7 +13,7 @@
# limitations under the License.

import abc
from typing import Any, Iterable, List, Tuple, Union
from typing import Any, Iterable, Union

import jax
from jax._src import core
@ -73,7 +73,7 @@ class _Mgrid:
[0, 1, 2]]], dtype=int32)
"""

def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Array:
def __getitem__(self, key: Union[slice, tuple[slice, ...]]) -> Array:
if isinstance(key, slice):
return _make_1d_grid_from_slice(key, op_name="mgrid")
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="mgrid") for k in key)
@ -117,8 +117,8 @@ class _Ogrid:
"""

def __getitem__(
self, key: Union[slice, Tuple[slice, ...]]
) -> Union[Array, List[Array]]:
self, key: Union[slice, tuple[slice, ...]]
) -> Union[Array, list[Array]]:
if isinstance(key, slice):
return _make_1d_grid_from_slice(key, op_name="ogrid")
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key)
@ -140,8 +140,8 @@ class _AxisConcat(abc.ABC):
trans1d: int
op_name: str

def __getitem__(self, key: Union[_IndexType, Tuple[_IndexType, ...]]) -> Array:
key_tup: Tuple[_IndexType, ...] = key if isinstance(key, tuple) else (key,)
def __getitem__(self, key: Union[_IndexType, tuple[_IndexType, ...]]) -> Array:
key_tup: tuple[_IndexType, ...] = key if isinstance(key, tuple) else (key,)

params = [self.axis, self.ndmin, self.trans1d, -1]

@ -154,7 +154,7 @@ class _AxisConcat(abc.ABC):
elif directive == "c":
params[-1] = 1
else:
vec: List[Any] = directive.split(",")
vec: list[Any] = directive.split(",")
k = len(vec)
if k < 4:
vec += params[k:]

@ -31,8 +31,8 @@ import math
import operator
import types
from typing import (
overload, Any, Callable, Dict, FrozenSet, List, Literal,
NamedTuple, Optional, Protocol, Sequence, Tuple, TypeVar, Union)
overload, Any, Callable, Literal,
NamedTuple, Optional, Protocol, Sequence, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings

@ -230,7 +230,7 @@ def _jnp_dtype(obj: Optional[DTypeLike], *, align: bool = False,

### utility functions

_DEFAULT_TYPEMAP: Dict[type, _ScalarMeta] = {
_DEFAULT_TYPEMAP: dict[type, _ScalarMeta] = {
np.bool_: bool_,
np.int_: int_,
np.float_: float_,
@ -448,7 +448,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
def histogram(a: ArrayLike, bins: ArrayLike = 10,
range: Optional[Sequence[ArrayLike]] = None,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, Array]:
density: Optional[bool] = None) -> tuple[Array, Array]:
if weights is None:
util.check_arraylike("histogram", a, bins)
a = ravel(*util.promote_dtypes_inexact(a))
@ -469,10 +469,10 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10,
return counts, bin_edges

@util._wraps(np.histogram2d)
def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, list[ArrayLike]] = 10,
range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]]=None,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, Array, Array]:
density: Optional[bool] = None) -> tuple[Array, Array, Array]:
util.check_arraylike("histogram2d", x, y)
try:
N = len(bins) # type: ignore[arg-type]
@ -488,10 +488,10 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: Union[ArrayLike, List[ArrayLik
return hist, edges[0], edges[1]

@util._wraps(np.histogramdd)
def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, list[ArrayLike]] = 10,
range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = None,
weights: Optional[ArrayLike] = None,
density: Optional[bool] = None) -> Tuple[Array, List[Array]]:
density: Optional[bool] = None) -> tuple[Array, list[Array]]:
if weights is None:
util.check_arraylike("histogramdd", sample)
sample, = util.promote_dtypes_inexact(sample)
@ -511,14 +511,14 @@ def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
num_bins = len(bins) # type: ignore[arg-type]
except TypeError:
# when bin_size is integer, the same bin is used for each dimension
bins_per_dimension: List[ArrayLike] = D * [bins] # type: ignore[assignment]
bins_per_dimension: list[ArrayLike] = D * [bins] # type: ignore[assignment]
else:
if num_bins != D:
raise ValueError("should be a bin for each dimension.")
bins_per_dimension = list(bins) # type: ignore[arg-type]

bin_idx_by_dim: List[Array] = []
bin_edges_by_dim: List[Array] = []
bin_idx_by_dim: list[Array] = []
bin_edges_by_dim: list[Array] = []

for i in builtins.range(D):
range_i = None if range is None else range[i]
@ -583,7 +583,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:

@util._wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('k', 'axes'))
def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array:
def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
util.check_arraylike("rot90", m)
ax1, ax2 = axes
ax1 = _canonicalize_axis(ax1, ndim(m))
@ -605,12 +605,12 @@ def rot90(m: ArrayLike, k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Array:


@util._wraps(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
def flip(m: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
util.check_arraylike("flip", m)
return _flip(asarray(m), reductions._ensure_optional_axes(axis))

@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
def _flip(m: Array, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
if axis is None:
return lax.rev(m, list(range(len(shape(m)))))
axis = _ensure_index_tuple(axis)
@ -674,7 +674,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
nd = arr.ndim
axis = _canonicalize_axis(axis, nd)

combined: List[Array] = []
combined: list[Array] = []
if prepend is not None:
util.check_arraylike("diff", prepend)
if isscalar(prepend):
@ -734,8 +734,8 @@ def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = None,
@util._wraps(np.gradient, skip_params=['edge_order'])
@partial(jit, static_argnames=('axis', 'edge_order'))
def gradient(f: ArrayLike, *varargs: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
edge_order: Optional[int] = None) -> Union[Array, List[Array]]:
axis: Optional[Union[int, tuple[int, ...]]] = None,
edge_order: Optional[int] = None) -> Union[Array, list[Array]]:
if edge_order is not None:
raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")
a, *spacing = util.promote_args_inexact("gradient", f, *varargs)
@ -805,7 +805,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:


@util._wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index: Tuple[ArrayLike, ...], dims: Tuple[int, ...],
def ravel_multi_index(multi_index: tuple[ArrayLike, ...], dims: tuple[int, ...],
mode: str = 'raise', order: str = 'C') -> Array:
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
@ -848,7 +848,7 @@ and out-of-bounds indices are clipped into the valid range.
"""

@util._wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]:
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
util.check_arraylike("unravel_index", indices)
indices_arr = asarray(indices)
# Note: we do not convert shape to an array, because it may be passed as a
@ -888,12 +888,12 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array:
return reshape(arr, new_shape)

@util._wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC)
def squeeze(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
def squeeze(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
util.check_arraylike("squeeze", a)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)

@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: Tuple[int]) -> Array:
def _squeeze(a: Array, axis: tuple[int]) -> Array:
if axis is None:
a_shape = shape(a)
if not core.is_constant_shape(a_shape):
@ -927,7 +927,7 @@ def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]],
_ensure_index_tuple(destination))

@partial(jit, static_argnames=('source', 'destination'), inline=True)
def _moveaxis(a: Array, source: Tuple[int, ...], destination: Tuple[int, ...]) -> Array:
def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -> Array:
source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
if len(source) != len(destination):
@ -1064,20 +1064,20 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
@overload
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Tuple[Array, ...]: ...
fill_value: Union[None, Array, tuple[ArrayLike]] = None
) -> tuple[Array, ...]: ...

@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
fill_value: Union[None, Array, tuple[ArrayLike]] = None
) -> Array: ...

@overload
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Union[Array, Tuple[Array, ...]]: ...
fill_value: Union[None, Array, tuple[ArrayLike]] = None
) -> Union[Array, tuple[Array, ...]]: ...

@util._wraps(np.where,
lax_description=_dedent("""
@ -1105,8 +1105,8 @@ def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Union[Array, Tuple[Array, ...]]:
fill_value: Union[None, Array, tuple[ArrayLike]] = None
) -> Union[Array, tuple[Array, ...]]:
if x is None and y is None:
util.check_arraylike("where", condition)
return nonzero(condition, size=size, fill_value=fill_value)
@ -1165,11 +1165,11 @@ def bincount(x: ArrayLike, weights: Optional[ArrayLike] = None,
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)

@overload
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...

@overload
def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...]
) -> Tuple[Union[int, core.Tracer], ...]: ...
def broadcast_shapes(*shapes: tuple[Union[int, core.Tracer], ...]
) -> tuple[Union[int, core.Tracer], ...]: ...

@util._wraps(getattr(np, "broadcast_shapes", None))
def broadcast_shapes(*shapes):
@ -1182,7 +1182,7 @@ def broadcast_shapes(*shapes):
@util._wraps(np.broadcast_arrays, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_arrays(*args: ArrayLike) -> List[Array]:
def broadcast_arrays(*args: ArrayLike) -> list[Array]:
return util._broadcast_arrays(*args)


@ -1195,7 +1195,7 @@ def broadcast_to(array: ArrayLike, shape: Shape) -> Array:

def _split(op: str, ary: ArrayLike,
indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
axis: int = 0) -> list[Array]:
util.check_arraylike(op, ary)
ary = asarray(ary)
axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
@ -1234,12 +1234,12 @@ def _split(op: str, ary: ArrayLike,

@util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
def split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
axis: int = 0) -> list[Array]:
return _split("split", ary, indices_or_sections, axis=axis)

def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], list[Array]]:
@util._wraps(getattr(np, op), update_doc=False)
def f(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike]) -> List[Array]:
def f(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike]) -> list[Array]:
# for 1-D array, hsplit becomes vsplit
nonlocal axis
util.check_arraylike(op, ary)
@ -1255,7 +1255,7 @@ dsplit = _split_on_axis("dsplit", axis=2)

@util._wraps(np.array_split)
def array_split(ary: ArrayLike, indices_or_sections: Union[int, Sequence[int], ArrayLike],
axis: int = 0) -> List[Array]:
axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)

@util._wraps(np.clip, skip_params=['out'])
@ -1367,8 +1367,8 @@ fill_value : array_like, optional

@util._wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def nonzero(a: ArrayLike, *, size: Optional[int] = None,
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None
) -> Tuple[Array, ...]:
fill_value: Union[None, ArrayLike, tuple[ArrayLike]] = None
) -> tuple[Array, ...]:
util.check_arraylike("nonzero", a)
arr = atleast_1d(a)
del a
@ -1394,7 +1394,7 @@ def nonzero(a: ArrayLike, *, size: Optional[int] = None,

@util._wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def flatnonzero(a: ArrayLike, *, size: Optional[int] = None,
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None) -> Array:
fill_value: Union[None, ArrayLike, tuple[ArrayLike]] = None) -> Array:
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]


@ -1428,7 +1428,7 @@ def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = None,
### Padding

PadValueLike = Union[T, Sequence[T], Sequence[Sequence[T]]]
PadValue = Tuple[Tuple[T, T], ...]
PadValue = tuple[tuple[T, T], ...]

class PadStatFunc(Protocol):
def __call__(self, array: ArrayLike, /, *,
@ -1477,7 +1477,7 @@ def _broadcast_to_pairs(nvals: PadValueLike, nd: int, name: str) -> PadValue:
f"Valid shapes are ({nd}, 2), (1, 2), (2,), (1,), or ().")


def _check_no_padding(axis_padding: Tuple[Any, Any], mode: str):
def _check_no_padding(axis_padding: tuple[Any, Any], mode: str):
if (axis_padding[0] > 0 or axis_padding[1] > 0):
msg = "Cannot apply '{}' padding to empty axis"
raise ValueError(msg.format(mode))
@ -1689,7 +1689,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
if nd == 0:
return array

stat_funcs: Dict[str, PadStatFunc] = {
stat_funcs: dict[str, PadStatFunc] = {
"maximum": reductions.amax,
"minimum": reductions.amin,
"mean": reductions.mean,
@ -1806,7 +1806,7 @@ def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array:
try:
iter(reps) # type: ignore[arg-type]
except TypeError:
reps_tup: Tuple[DimSize, ...] = (reps,)
reps_tup: tuple[DimSize, ...] = (reps,)
else:
reps_tup = tuple(reps) # type: ignore[assignment,arg-type]
reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
@ -1932,7 +1932,7 @@ def _atleast_nd(x: ArrayLike, n: int) -> Array:
m = ndim(x)
return lax.broadcast(x, (1,) * (n - m)) if m < n else asarray(x)

def _block(xs: Union[ArrayLike, List[ArrayLike]]) -> Tuple[Array, int]:
def _block(xs: Union[ArrayLike, list[ArrayLike]]) -> tuple[Array, int]:
if isinstance(xs, tuple):
raise ValueError("jax.numpy.block does not allow tuples, got {}"
.format(xs))
@ -1950,13 +1950,13 @@ def _block(xs: Union[ArrayLike, List[ArrayLike]]) -> Tuple[Array, int]:

@util._wraps(np.block)
@jit
def block(arrays: Union[ArrayLike, List[ArrayLike]]) -> Array:
def block(arrays: Union[ArrayLike, list[ArrayLike]]) -> Array:
out, _ = _block(arrays)
return out

@util._wraps(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_1d(*arys: ArrayLike) -> Union[Array, List[Array]]:
def atleast_1d(*arys: ArrayLike) -> Union[Array, list[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
return arr if ndim(arr) >= 1 else reshape(arr, -1)
@ -1966,7 +1966,7 @@ def atleast_1d(*arys: ArrayLike) -> Union[Array, List[Array]]:

@util._wraps(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_2d(*arys: ArrayLike) -> Union[Array, List[Array]]:
def atleast_2d(*arys: ArrayLike) -> Union[Array, list[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
if ndim(arr) >= 2:
@ -1981,7 +1981,7 @@ def atleast_2d(*arys: ArrayLike) -> Union[Array, List[Array]]:

@util._wraps(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_3d(*arys: ArrayLike) -> Union[Array, List[Array]]:
def atleast_3d(*arys: ArrayLike) -> Union[Array, list[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
if ndim(arr) == 0:
@ -2389,22 +2389,22 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: bool, retstep: Literal[True],
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Tuple[Array, Array]: ...
axis: int = 0) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, *, retstep: Literal[True],
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Tuple[Array, Array]: ...
axis: int = 0) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Union[Array, Tuple[Array, Array]]: ...
axis: int = 0) -> Union[Array, tuple[Array, Array]]: ...
@util._wraps(np.linspace)
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Union[Array, Tuple[Array, Array]]:
axis: int = 0) -> Union[Array, tuple[Array, Array]]:
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)
@ -2413,7 +2413,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: Optional[DTypeLike] = None,
axis: int = 0) -> Union[Array, Tuple[Array, Array]]:
axis: int = 0) -> Union[Array, tuple[Array, Array]]:
"""Implementation of linspace differentiable in start and stop args."""
dtypes.check_user_dtype_supported(dtype, "linspace")
if num < 0:
@ -2526,7 +2526,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool

@util._wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
indexing: str = 'xy') -> List[Array]:
indexing: str = 'xy') -> list[Array]:
util.check_arraylike("meshgrid", *xi)
args = [asarray(x) for x in xi]
if not copy:
@ -2557,7 +2557,7 @@ def i0(x: ArrayLike) -> Array:


@util._wraps(np.ix_)
def ix_(*args: ArrayLike) -> Tuple[Array, ...]:
def ix_(*args: ArrayLike) -> tuple[Array, ...]:
util.check_arraylike("ix", *args)
n = len(args)
output = []
@ -2584,13 +2584,13 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: Literal[False] = False) -> Array: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
*, sparse: Literal[True]) -> Tuple[Array, ...]: ...
*, sparse: Literal[True]) -> tuple[Array, ...]: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: ...
sparse: bool = False) -> Union[Array, tuple[Array, ...]]: ...
@util._wraps(np.indices)
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: bool = False) -> Union[Array, Tuple[Array, ...]]:
sparse: bool = False) -> Union[Array, tuple[Array, ...]]:
dimensions = tuple(
core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
for d in dimensions)
@ -2649,10 +2649,10 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *,
input_shape = shape(a)
aux_axis = axis if axis < 0 else axis + 1
a = expand_dims(a, aux_axis)
reps: List[DimSize] = [1] * len(shape(a))
reps: list[DimSize] = [1] * len(shape(a))
reps[aux_axis] = repeats
a = tile(a, reps)
result_shape: List[DimSize] = list(input_shape)
result_shape: list[DimSize] = list(input_shape)
result_shape[axis] *= repeats
return reshape(a, result_shape)

@ -2781,7 +2781,7 @@ def _triu_size(n, m, k):


@util._wraps(np.triu_indices)
def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Array]:
def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> tuple[Array, Array]:
n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices")
k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices")
m = n if m is None else core.concrete_or_error(operator.index, m, "m argument of jnp.triu_indices")
@ -2790,7 +2790,7 @@ def triu_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Ar


@util._wraps(np.tril_indices)
def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Array]:
def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> tuple[Array, Array]:
n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices")
k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices")
m = n if m is None else core.concrete_or_error(operator.index, m, "m argument of jnp.triu_indices")
@ -2799,13 +2799,13 @@ def tril_indices(n: int, k: int = 0, m: Optional[int] = None) -> Tuple[Array, Ar


@util._wraps(np.triu_indices_from)
def triu_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array, Array]:
def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
arr_shape = shape(arr)
return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1])


@util._wraps(np.tril_indices_from)
def tril_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array, Array]:
def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
arr_shape = shape(arr)
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])

@ -3294,7 +3294,7 @@ def _removechars(s, chars):
@partial(jit, static_argnums=(1, 2, 3, 4), inline=True)
def _einsum(
operands: Sequence,
contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
precision,
preferred_element_type,
_dot_general=lax.dot_general,
@ -4315,8 +4315,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# Pairs of (array, start_dim) values. These will be broadcast into
# gather_indices_shape, with the array dimensions aligned to start_dim, and
# then concatenated.
gather_indices: List[Tuple[Array, int]] = []
gather_indices_shape: List[int] = []
gather_indices: list[tuple[Array, int]] = []
gather_indices_shape: list[int] = []

# We perform three transformations to y before the scatter op, in order:
# First, y is broadcast to slice_shape. In general `y` only need broadcast to
@ -4687,11 +4687,11 @@ def kaiser(M: int, beta: ArrayLike) -> Array:
return i0(beta * ufuncs.sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta)


def _gcd_cond_fn(xs: Tuple[Array, Array]) -> Array:
def _gcd_cond_fn(xs: tuple[Array, Array]) -> Array:
x1, x2 = xs
return reductions.any(x2 != 0)

def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]:
def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]:
x1, x2 = xs
x1, x2 = (where(x2 != 0, x2, x1),
where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0)))
@ -4927,7 +4927,7 @@ See the :func:`jax.lax.switch` documentation for more information.

@util._wraps(np.piecewise, lax_description=_PIECEWISE_DOC)
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
funclist: List[Union[ArrayLike, Callable[..., Array]]],
funclist: list[Union[ArrayLike, Callable[..., Array]]],
*args, **kw) -> Array:
util.check_arraylike("piecewise", x)
nc, nf = len(condlist), len(funclist)
@ -4944,8 +4944,8 @@ def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
*args, **kw)

@partial(jit, static_argnames=['funcs'])
def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike],
funcs: FrozenSet[Tuple[int, Callable[..., Array]]],
def _piecewise(x: Array, condlist: Array, consts: dict[int, ArrayLike],
funcs: frozenset[tuple[int, Callable[..., Array]]],
*args, **kw) -> Array:
funcdict = dict(funcs)
funclist = [consts.get(i, funcdict.get(i)) for i in range(len(condlist) + 1)]

@ -18,7 +18,7 @@ from functools import partial
import numpy as np
import textwrap
import operator
from typing import Literal, Optional, Tuple, Union, cast, overload
from typing import Literal, Optional, Union, cast, overload

import jax
from jax import jit, custom_jvp
@ -49,10 +49,10 @@ def cholesky(a: ArrayLike) -> Array:

@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True],
hermitian: bool = False) -> Tuple[Array, Array, Array]: ...
hermitian: bool = False) -> tuple[Array, Array, Array]: ...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True],
hermitian: bool = False) -> Tuple[Array, Array, Array]: ...
hermitian: bool = False) -> tuple[Array, Array, Array]: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
@ -61,12 +61,12 @@ def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]: ...
hermitian: bool = False) -> Union[Array, tuple[Array, Array, Array]]: ...

@_wraps(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
hermitian: bool = False) -> Union[Array, tuple[Array, Array, Array]]:
check_arraylike("jnp.linalg.svd", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if hermitian:
@ -142,7 +142,7 @@ def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:


@custom_jvp
def _slogdet_lu(a: Array) -> Tuple[Array, Array]:
def _slogdet_lu(a: Array) -> tuple[Array, Array]:
dtype = lax.dtype(a)
lu, pivot, _ = lax_linalg.lu(a)
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
@ -164,7 +164,7 @@ def _slogdet_lu(a: Array) -> Tuple[Array, Array]:
return sign, ufuncs.real(logdet)

@custom_jvp
def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
def _slogdet_qr(a: Array) -> tuple[Array, Array]:
# Implementation of slogdet using QR decomposition. One reason we might prefer
# QR decomposition is that it is more amenable to a fast batched
# implementation on TPU because of the lack of row pivoting.
@ -192,7 +192,7 @@ def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
LU decomposition if ``None``.
"""))
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> tuple[Array, Array]:
check_arraylike("jnp.linalg.slogdet", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
@ -223,7 +223,7 @@ def _slogdet_jvp(primals, tangents):
_slogdet_lu.defjvp(_slogdet_jvp)
_slogdet_qr.defjvp(_slogdet_jvp)

def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> Tuple[Array, Array]:
def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> tuple[Array, Array]:
"""Equivalent to det(a)*solve(a, b) for nonsingular mat.

Intermediate function used for jvp and vjp of det.
@ -364,7 +364,7 @@ At present, non-symmetric eigendecomposition is only implemented on the CPU
backend. However eigendecomposition for symmetric/Hermitian matrices is
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
""")
def eig(a: ArrayLike) -> Tuple[Array, Array]:
def eig(a: ArrayLike) -> tuple[Array, Array]:
check_arraylike("jnp.linalg.eig", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
@ -382,7 +382,7 @@ def eigvals(a: ArrayLike) -> Array:
@_wraps(np.linalg.eigh)
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
symmetrize_input: bool = True) -> Tuple[Array, Array]:
symmetrize_input: bool = True) -> tuple[Array, Array]:
check_arraylike("jnp.linalg.eigh", a)
if UPLO is None or UPLO == "L":
lower = True
@ -481,7 +481,7 @@ def inv(a: ArrayLike) -> Array:
@_wraps(np.linalg.norm)
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
axis: Union[None, Tuple[int, ...], int] = None,
axis: Union[None, tuple[int, ...], int] = None,
keepdims: bool = False) -> Array:
check_arraylike("jnp.linalg.norm", x)
x, = promote_dtypes_inexact(jnp.asarray(x))
@ -532,7 +532,7 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None,
return ufuncs.power(out, ord_inv)

elif num_axes == 2:
row_axis, col_axis = cast(Tuple[int, ...], axis)
row_axis, col_axis = cast(tuple[int, ...], axis)
if ord is None or ord in ('f', 'fro'):
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
@ -578,11 +578,11 @@ def norm(x: ArrayLike, ord: Union[int, str, None] = None,
@overload
def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
@overload
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]: ...
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]]: ...

@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]]:
check_arraylike("jnp.linalg.qr", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
@ -609,7 +609,7 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array:


def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]:
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
# TODO: add custom jvp rule for more robust lstsq differentiation
a, b = promote_dtypes_inexact(a, b)
@ -669,7 +669,7 @@ _jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
poorly behaved for some inputs, particularly for low-rank `a`.
"""))
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]:
check_arraylike("jnp.linalg.lstsq", a, b)
if numpy_resid:
return _lstsq(a, b, rcond, numpy_resid=True)

@ -15,7 +15,7 @@

from functools import partial
import operator
from typing import Optional, Tuple, Union
from typing import Optional, Union

import numpy as np

@ -110,7 +110,7 @@ Also, it works best on rcond <= 10e-3 values.
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = None,
full: bool = False, w: Optional[Array] = None, cov: bool = False
) -> Union[Array, Tuple[Array, ...]]:
) -> Union[Array, tuple[Array, ...]]:
check_arraylike("polyfit", x, y)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
@ -301,7 +301,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
return convolve(a1_arr, a2_arr, mode='full')

@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> Tuple[Array, Array]:
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
check_arraylike("polydiv", u, v)
u_arr, v_arr = promote_dtypes_inexact(u, v)
m = len(u_arr) - 1

@ -17,7 +17,7 @@ from functools import partial
import math
import operator
from typing import (
overload, Any, Callable, Literal, Optional, Protocol, Sequence, Tuple, Union)
overload, Any, Callable, Literal, Optional, Protocol, Sequence, Union)
import warnings

import numpy as np
@ -349,15 +349,15 @@ def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None
returned: Literal[True], keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]: ...
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]: ...
@_wraps(np.average)
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]:
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)

@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]:
if weights is None: # Treat all weights as 1
check_arraylike("average", a)
a, = promote_dtypes_inexact(a)
@ -450,7 +450,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
return lax.div(result, normalizer).astype(dtype)


def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike) -> Tuple[DType, DType]:
def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike) -> tuple[DType, DType]:
if dtype:
if (not dtypes.issubdtype(dtype, np.complexfloating) and
dtypes.issubdtype(a_dtype, np.complexfloating)):
@ -689,7 +689,7 @@ nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("quantile", a, q)
@ -705,7 +705,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ..
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("nanquantile", a, q)
@ -718,7 +718,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int,
"Use 'method=' instead.", DeprecationWarning)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True)

def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
def _quantile(a: Array, q: Array, axis: Optional[Union[int, tuple[int, ...]]],
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
@ -843,7 +843,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("percentile", a, q)
@ -855,7 +855,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("nanpercentile", a, q)
@ -866,7 +866,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,

@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def median(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
check_arraylike("median", a)
@ -875,7 +875,7 @@ def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,

@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
def nanmedian(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
check_arraylike("nanmedian", a)

@ -16,7 +16,7 @@ from functools import partial
import math
import operator
from textwrap import dedent as _dedent
from typing import Optional, Tuple, Union, cast
from typing import Optional, Union, cast

import numpy as np

@ -155,7 +155,7 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr


@partial(jit, static_argnames=['return_indices'])
def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> Tuple[Array, ...]:
def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> tuple[Array, ...]:
"""
Helper function for intersect1d which is jit-able
"""
@ -175,7 +175,7 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo

@_wraps(np.intersect1d)
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return_indices: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
return_indices: bool = False) -> Union[Array, tuple[Array, Array, Array]]:
check_arraylike("intersect1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
@ -226,7 +226,7 @@ UNIQUE_SIZE_HINT = (
"a concrete value for the size argument, which will determine the output size.")

@partial(jit, static_argnums=1)
def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]:
def _unique_sorted_mask(ar: Array, axis: int) -> tuple[Array, Array, Array]:
aux = moveaxis(ar, axis, 0)
if np.issubdtype(aux.dtype, np.complexfloating):
# Work around issue in sorting of complex numbers with Nan only in the
@ -255,7 +255,7 @@ def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]:
def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bool = False,
return_counts: bool = False, size: Optional[int] = None,
fill_value: Optional[ArrayLike] = None, return_true_size: bool = False
) -> Union[Array, Tuple[Array, ...]]:
) -> Union[Array, tuple[Array, ...]]:
"""
Find the unique elements of an array along a particular axis.
"""
@ -280,7 +280,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
result = moveaxis(result, 0, axis)

ret: Tuple[Array, ...] = (result,)
ret: tuple[Array, ...] = (result,)
if return_index:
if aux.size:
ret += (perm[ind],)

@ -19,7 +19,7 @@ Implements ufuncs for jax.numpy.
from functools import partial
import operator
from textwrap import dedent
from typing import Any, Callable, Tuple, Union, overload
from typing import Any, Callable, Union, overload

import numpy as np

@ -285,7 +285,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:

@_wraps(np.divmod, module='numpy')
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
x1, x2 = promote_args_numeric("divmod", x1, x2)
if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
return floor_divide(x1, x2), remainder(x1, x2)
@ -293,7 +293,7 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
return _float_divmod(x1, x2)


def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> Tuple[Array, Array]:
def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
# see float_divmod in floatobject.c of CPython
mod = lax.rem(x1, x2)
div = lax.div(lax.sub(x1, mod), x2)
@ -521,7 +521,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:

@_wraps(np.frexp, module='numpy')
@jit
def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
check_arraylike("frexp", x)
x, = promote_dtypes_inexact(x)
if dtypes.issubdtype(x.dtype, np.complexfloating):
@ -616,7 +616,7 @@ def real(val: ArrayLike, /) -> Array:

@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
check_arraylike("modf", x)
x, = promote_dtypes_inexact(x)
if out is not None:

@ -15,9 +15,8 @@
from functools import partial
import re
import textwrap
from typing import (
    Any, Callable, Dict, List, NamedTuple, Optional, Sequence, TypeVar
)
from typing import Any, Callable, NamedTuple, Optional, Sequence, TypeVar

import warnings

from jax._src import dtypes
@ -53,7 +52,7 @@ class ParsedDoc(NamedTuple):
  signature: str = ""
  summary: str = ""
  front_matter: str = ""
  sections: Dict[str, str] = {}
  sections: dict[str, str] = {}


def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
@ -101,7 +100,7 @@ def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
                   front_matter=front_matter, sections=sections)


def _parse_parameters(body: str) -> Dict[str, str]:
def _parse_parameters(body: str) -> dict[str, str]:
  """Parse the Parameters section of a docstring."""
  title, underline, content = body.split('\n', 2)
  assert title == 'Parameters'
@ -110,7 +109,7 @@ def _parse_parameters(body: str) -> Dict[str, str]:
  return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}


def _parse_extra_params(extra_params: str) -> Dict[str, str]:
def _parse_extra_params(extra_params: str) -> dict[str, str]:
  """Parse the extra parameters passed to _wraps()"""
  parameters = _parameter_break.split(extra_params.strip('\n'))
  return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}
@ -223,7 +222,7 @@ def _wraps(

_dtype = partial(dtypes.dtype, canonicalize=True)

def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
  """Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
  if len(args) < 2:
    return [lax.asarray(arg) for arg in args]
@ -264,7 +263,7 @@ def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
    raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))


def promote_dtypes(*args: ArrayLike) -> List[Array]:
def promote_dtypes(*args: ArrayLike) -> list[Array]:
  """Convenience function to apply Numpy argument dtype promotion."""
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) < 2:
@ -275,7 +274,7 @@ def promote_dtypes(*args: ArrayLike) -> List[Array]:
    return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]


def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to an inexact type."""
@ -286,7 +285,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
          for x in args]


def promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
def promote_dtypes_numeric(*args: ArrayLike) -> list[Array]:
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to a numeric (non-bool) type."""
@ -297,7 +296,7 @@ def promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
          for x in args]


def promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
def promote_dtypes_complex(*args: ArrayLike) -> list[Array]:
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to a complex type."""
@ -350,20 +349,20 @@ def _check_no_float0s(fun_name: str, *args: Any):
              "taken a gradient with respect to an integer argument.")


def promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]:
  """Convenience function to apply Numpy argument shape and dtype promotion."""
  check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  return promote_shapes(fun_name, *promote_dtypes(*args))


def promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_args_numeric(fun_name: str, *args: ArrayLike) -> list[Array]:
  check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  return promote_shapes(fun_name, *promote_dtypes_numeric(*args))


def promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
  """Convenience function to apply Numpy argument shape and dtype promotion.

  Promotes non-inexact types to an inexact type."""
@ -373,7 +372,7 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:


@partial(api.jit, inline=True)
def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
  """Like Numpy's broadcast_arrays but doesn't return views."""
  shapes = [np.shape(arg) for arg in args]
  if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
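For readers unfamiliar with these helpers: promote_args combines dtype promotion with NumPy-style shape broadcasting. A rough two-argument analogue built from public APIs (illustrative only; the real helpers also track weak types and rank-promotion warnings):

    import jax.numpy as jnp

    def promote_pair(a, b) -> list[jnp.ndarray]:
        # common result dtype first, then broadcast both operands to a shared shape
        dtype = jnp.promote_types(jnp.result_type(a), jnp.result_type(b))
        a, b = jnp.asarray(a, dtype=dtype), jnp.asarray(b, dtype=dtype)
        shape = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(b))
        return [jnp.broadcast_to(a, shape), jnp.broadcast_to(b, shape)]

    promote_pair(jnp.arange(3), 1.0)  # both become floating point with shape (3,)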
@ -14,7 +14,7 @@

import functools
import re
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable

from jax._src import api
from jax import lax
@ -30,13 +30,13 @@ _ARGUMENT_LIST = '{0:}(?:,{0:})*'.format(_ARGUMENT)
_SIGNATURE = '^{0:}->{0:}$'.format(_ARGUMENT_LIST)


CoreDims = Tuple[str, ...]
CoreDims = tuple[str, ...]
NDArray = Any


def _parse_gufunc_signature(
    signature: str,
) -> Tuple[List[CoreDims], List[CoreDims]]:
) -> tuple[list[CoreDims], list[CoreDims]]:
  """Parse string signatures for a generalized universal function.

  Args:
@ -56,8 +56,8 @@ def _parse_gufunc_signature(


def _update_dim_sizes(
    dim_sizes: Dict[str, int],
    shape: Tuple[int, ...],
    dim_sizes: dict[str, int],
    shape: tuple[int, ...],
    core_dims: CoreDims,
    error_context: str = "",
    *,
@ -94,10 +94,10 @@ def _update_dim_sizes(


def _parse_input_dimensions(
    args: Tuple[NDArray, ...],
    input_core_dims: List[CoreDims],
    args: tuple[NDArray, ...],
    input_core_dims: list[CoreDims],
    error_context: str = "",
) -> Tuple[Tuple[int, ...], Dict[str, int]]:
) -> tuple[tuple[int, ...], dict[str, int]]:
  """Parse broadcast and core dimensions for vectorize with a signature.

  Args:
@ -114,7 +114,7 @@ def _parse_input_dimensions(
        'wrong number of positional arguments: expected %r, got %r %s'
        % (len(input_core_dims), len(args), error_context))
  shapes = []
  dim_sizes: Dict[str, int] = {}
  dim_sizes: dict[str, int] = {}
  for arg, core_dims in zip(args, input_core_dims):
    _update_dim_sizes(dim_sizes, arg.shape, core_dims, error_context,
                      is_input=True)
@ -127,8 +127,8 @@ def _parse_input_dimensions(

def _check_output_dims(
    func: Callable,
    dim_sizes: Dict[str, int],
    expected_output_core_dims: List[CoreDims],
    dim_sizes: dict[str, int],
    expected_output_core_dims: list[CoreDims],
    error_context: str = "",
) -> Callable:
  """Check that output core dimensions match the signature."""
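These parsers back `jnp.vectorize`: a signature like '(n),(n)->()' becomes input core dims [('n',), ('n',)] and output core dims [()]. For example:

    import jax.numpy as jnp

    pairwise_dot = jnp.vectorize(jnp.dot, signature='(n),(n)->()')
    # the leading dim 4 is broadcast; the core dim 3 is consumed -> result shape (4,)
    pairwise_dot(jnp.ones((4, 3)), jnp.ones((4, 3)))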
@ -14,7 +14,7 @@
"""Sharding utilities"""

import itertools
from typing import List, Sequence, Tuple, Union
from typing import Sequence, Union

import numpy as np

@ -22,7 +22,7 @@ from jax._src.lib import xla_client as xc

def get_num_ways_dim_sharded(
    hlo_sharding: xc.HloSharding) -> Tuple[Sequence[int], int]:
    hlo_sharding: xc.HloSharding) -> tuple[Sequence[int], int]:
  if hlo_sharding.is_replicated():  # type: ignore
    return [], 1
  partitions = hlo_sharding.tile_assignment_dimensions()
@ -61,7 +61,7 @@ def are_op_shardings_equal(op1: Union[xc.OpSharding, xc.HloSharding],
    return hc1 == hc2


_Index = Union[int, slice, Tuple[Union[int, slice], ...]]
_Index = Union[int, slice, tuple[Union[int, slice], ...]]


def op_sharding_to_numpy_indices(
@ -81,7 +81,7 @@ def op_sharding_to_numpy_indices(
  partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding)
  assert len(partitions) == len(shape), (len(partitions), len(shape))

  axis_indices: List[Sequence[_Index]] = []
  axis_indices: list[Sequence[_Index]] = []
  for dim, n_shards in zip(shape, partitions):
    if n_shards == 1:
      axis_indices.append([slice(None)])
@ -103,6 +103,6 @@ def op_sharding_to_numpy_indices(

def op_sharding_to_indices(
    op_sharding: xc.HloSharding, shape: Sequence[int],
    num_devices: int) -> Tuple[Tuple[slice, ...], ...]:
    num_devices: int) -> tuple[tuple[slice, ...], ...]:
  indices = op_sharding_to_numpy_indices(op_sharding, shape, num_devices)
  return tuple(indices.flat)
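The `_Index` alias describes the per-dimension slices these functions compute. A hand-written illustration of what a two-way split of an (8, 4) array along axis 0 looks like (written against a local copy of the alias, not the private module):

    from typing import Union

    Index = Union[int, slice, tuple[Union[int, slice], ...]]  # mirrors _Index above

    shard0: Index = (slice(0, 4), slice(None))  # rows 0-3, all columns
    shard1: Index = (slice(4, 8), slice(None))  # rows 4-7, all columns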
@ -15,7 +15,7 @@
# Helpers for indexed updates.

import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Union
import warnings

import numpy as np
@ -37,7 +37,7 @@ if sys.version_info >= (3, 10):
  SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType]
else:
  SingleIndex = Union[None, int, slice, Sequence[int], Array]
Index = Union[SingleIndex, Tuple[SingleIndex, ...]]
Index = Union[SingleIndex, tuple[SingleIndex, ...]]
Scalar = Union[complex, float, int, np.number]
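These aliases spell out what `x.at[idx]` accepts. A quick illustration with the public functional-update API:

    import jax.numpy as jnp

    x = jnp.zeros(5)
    y = x.at[1:3].set(7.0)                # slice index
    z = x.at[jnp.array([0, 4])].add(1.0)  # integer-array index; x itself is unchanged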
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import overload, Literal, Optional, Tuple, Union
from typing import overload, Literal, Optional, Union

import jax
from jax import lax
@ -32,14 +32,14 @@ def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] =

@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
              keepdims: bool = False, *, return_sign: Literal[True]) -> Tuple[Array, Array]: ...
              keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ...

@overload
def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
              keepdims: bool = False, return_sign: bool = False) -> Union[Array, Tuple[Array, Array]]: ...
              keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]: ...

def logsumexp(a: ArrayLike, axis: Optional[int] = None, b: Optional[ArrayLike] = None,
              keepdims: bool = False, return_sign: bool = False) -> Union[Array, Tuple[Array, Array]]:
              keepdims: bool = False, return_sign: bool = False) -> Union[Array, tuple[Array, Array]]:
  r"""Log-sum-exp reduction.

  Computes
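The overloads encode that `return_sign=True` switches the result from one array to a pair, which a type checker can now see through the Literal overloads. For instance:

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    a = jnp.array([1.0, 2.0, 3.0])
    out = logsumexp(a)                                   # a single Array
    out, sign = logsumexp(a, b=jnp.array([1.0, -1.0, 1.0]),
                          return_sign=True)              # tuple[Array, Array]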
@ -17,8 +17,8 @@ import inspect
import logging
import weakref
import numpy as np
from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
                    Iterable, NamedTuple, Any)
from typing import (Callable, Sequence, Union, cast, Optional, Iterable,
                    NamedTuple, Any)
import itertools as it
from functools import partial, lru_cache
import threading
@ -399,9 +399,9 @@ class PjitInfo(NamedTuple):
  fun: Callable
  in_shardings: Any
  out_shardings: Any
  static_argnums: Tuple[int, ...]
  static_argnames: Tuple[str, ...]
  donate_argnums: Tuple[int, ...]
  static_argnums: tuple[int, ...]
  static_argnames: tuple[str, ...]
  donate_argnums: tuple[int, ...]
  device: Optional[xc.Device]
  backend: Optional[str]
  keep_unused: bool
@ -537,7 +537,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
          donate_argnums)

def _extract_implicit_args(
    in_type: Sequence[Tuple[core.AbstractValue, bool]],
    in_type: Sequence[tuple[core.AbstractValue, bool]],
    explicit_args: Sequence[Any]
) -> Sequence[core.Tracer]:
  """
@ -567,7 +567,7 @@ def _extract_implicit_args(
  return [x for x, (_, e) in zip(args, in_type) if not e]  # type: ignore

def _flat_axes_specs(abstracted_axes, *args, **kwargs
                     ) -> Optional[List[pe.AbstractedAxesSpec]]:
                     ) -> Optional[list[pe.AbstractedAxesSpec]]:
  if abstracted_axes is None: return None
  if kwargs: raise NotImplementedError
  def ax_leaf(l):
@ -978,7 +978,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info,

def pjit_check_aval_sharding(
    shardings, flat_avals, names: Optional[Tuple[str, ...]],
    shardings, flat_avals, names: Optional[tuple[str, ...]],
    what_aval: str, allow_uneven_sharding: bool):
  new_names = [''] * len(shardings) if names is None else names
  for aval, s, name in zip(flat_avals, shardings, new_names):
@ -1208,7 +1208,7 @@ pjit_p.def_impl(_pjit_call_impl)

@dataclasses.dataclass(frozen=True)
class SameDeviceAssignmentTuple:
  shardings: Tuple[PjitSharding, ...]
  shardings: tuple[PjitSharding, ...]
  # device_assignment is Optional because shardings can contain `AUTO` and in
  # that case `mesh` is compulsory to be used. So in that case
  # `_pjit_lower_cached` cache, resource_env will check against the devices.
@ -1262,9 +1262,9 @@ def _pjit_lower_cached(
    always_lower: bool,
    *,
    lowering_platform: Optional[str]):
  in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
      Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
  out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
  in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
      tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
  out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings

  if resource_env is not None:
    pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
@ -1323,7 +1323,7 @@ pe.custom_staging_rules[pjit_p] = pjit_staging_rule

# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
# since it's actually not possible in general to infer the type from the term
def _out_type(jaxpr: core.ClosedJaxpr) -> List[core.AbstractValue]:
def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:
  out = []
  in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)}
  out_idx = {x: i for i, x in enumerate(jaxpr.jaxpr.invars)
@ -1430,7 +1430,7 @@ pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)

def _pjit_batcher_for_sharding(
    s: Union[GSPMDSharding, UnspecifiedValue],
    dim: int, val: Tuple[str, ...], mesh, ndim: int):
    dim: int, val: tuple[str, ...], mesh, ndim: int):
  if is_unspecified(s):
    return s
  if not val:
@ -1490,7 +1490,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp

@weakref_lru_cache
def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
                     fwds_known: Tuple[Optional[int]]) -> core.ClosedJaxpr:
                     fwds_known: tuple[Optional[int]]) -> core.ClosedJaxpr:
  updated_jaxpr = known_jaxpr.jaxpr.replace(
      outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, fwds_known)
               if i is None])
@ -1603,7 +1603,7 @@ def _pjit_partial_eval_custom_params_updater(
    unks_in: Sequence[bool], inst_in: Sequence[bool],
    kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
    num_res: int, params_known: dict, params_staged: dict
  ) -> Tuple[dict, dict]:
  ) -> tuple[dict, dict]:
  # prune inputs to jaxpr_known according to unks_in
  donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
  in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
@ -1688,14 +1688,14 @@ ad.reducing_transposes[pjit_p] = _pjit_transpose

@weakref_lru_cache
def _dce_jaxpr_pjit(
    jaxpr: core.ClosedJaxpr, used_outputs: Tuple[bool]
) -> Tuple[core.ClosedJaxpr, List[bool]]:
    jaxpr: core.ClosedJaxpr, used_outputs: tuple[bool]
) -> tuple[core.ClosedJaxpr, list[bool]]:
  new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, used_outputs)
  return core.ClosedJaxpr(new_jaxpr, jaxpr.consts), used_inputs


def dce_jaxpr_pjit_rule(used_outputs: List[bool], eqn: core.JaxprEqn
                        ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
                        ) -> tuple[list[bool], Optional[core.JaxprEqn]]:
  dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
      eqn.params['jaxpr'], tuple(used_outputs))

@ -1965,13 +1965,13 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[Parti

def _get_op_sharding_from_executable(
    executable) -> Tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
  in_op_shardings: List[xc.OpSharding] = []
    executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
  in_op_shardings: list[xc.OpSharding] = []
  parameter_shardings_from_xla = executable.get_parameter_shardings()
  if parameter_shardings_from_xla is not None:
    in_op_shardings = parameter_shardings_from_xla

  out_op_shardings: List[xc.OpSharding] = []
  out_op_shardings: list[xc.OpSharding] = []
  output_shardings_from_xla = executable.get_output_shardings()
  if output_shardings_from_xla is not None:
    out_op_shardings = output_shardings_from_xla
@ -1979,10 +1979,10 @@ def _get_op_sharding_from_executable(
  return in_op_shardings, out_op_shardings


def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
  input_op_shardings: Sequence[xc.OpSharding] = executable.hlo_modules()[0].spmd_parameters_shardings
  output_op_sharding: xc.OpSharding = executable.hlo_modules()[0].spmd_output_sharding
  in_ppspec: List[ParsedPartitionSpec] = []
  in_ppspec: list[ParsedPartitionSpec] = []
  for s in input_op_shardings:
    in_ppspec.extend(parse_flatten_op_sharding(s, mesh))
  out_ppspec = parse_flatten_op_sharding(output_op_sharding, mesh)
@ -1991,7 +1991,7 @@ def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartit

def _get_pspec_from_executable(
    executable, mesh: pxla.Mesh
) -> Tuple[Tuple[PartitionSpec, ...], Tuple[PartitionSpec, ...]]:
) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]:
  in_ppspec, out_ppspec = _get_ppspec_from_executable(executable, mesh)
  out_partition_spec = _get_partition_spec(out_ppspec)
  in_partition_spec = _get_partition_spec(in_ppspec)
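The tuple-typed `static_argnums`/`static_argnames` fields are what the public entry points normalize user input into. A small usage sketch (an int argnum is widened to a one-element tuple internally):

    import jax
    import jax.numpy as jnp

    # static_argnums=1 is recorded as the tuple (1,) by the pjit machinery
    f = jax.jit(lambda x, n: x * n, static_argnums=1)
    f(jnp.arange(3), 2)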
@ -29,7 +29,7 @@ import abc
import enum
import sys
from functools import partial
from typing import List, NamedTuple, Optional, Sequence, Tuple, Union
from typing import NamedTuple, Optional, Sequence, Union
from jax._src.config import config

try:
@ -95,7 +95,7 @@ class _TextDoc(Doc):

class _ConcatDoc(Doc):
  __slots__ = ("children",)
  children: List[Doc]
  children: list[Doc]

  def __init__(self, children: Sequence[Doc]):
    self.children = list(children)
@ -164,7 +164,7 @@ _BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"])
# non-recursive formulation using an explicit stack, necessary because Python
# doesn't have a tail recursion optimization.

def _fits(doc: Doc, width: int, agenda: List[Tuple[int, _BreakMode, Doc]]
def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]]
          ) -> bool:
  while width >= 0 and len(agenda) > 0:
    i, m, doc = agenda.pop()
@ -234,11 +234,11 @@ class _State(NamedTuple):
class _Line(NamedTuple):
  text: str
  width: int
  annotations: Union[Optional[str], List[str]]
  annotations: Union[Optional[str], list[str]]


def _update_color(use_color: bool, state: _ColorState, update: _ColorState
                  ) -> Tuple[_ColorState, str]:
                  ) -> tuple[_ColorState, str]:
  if not use_color or colorama is None:
    return update, ""
  color_str = ""
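This module is the layout engine behind jaxpr rendering; `_fits` is the width check of a Wadler-style pretty printer, written with an explicit agenda stack as the comment notes. Its output is easiest to see indirectly:

    import jax
    import jax.numpy as jnp

    # the printed jaxpr text below is produced by this pretty-printer
    print(jax.make_jaxpr(lambda x: jnp.sin(x) * 2)(1.0))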
@ -17,8 +17,8 @@ import abc
from functools import partial, reduce
import math
import operator as op
from typing import (Any, Callable, Hashable, Iterator, List, NamedTuple,
                    Set, Sequence, Tuple, Union)
from typing import (Any, Callable, Hashable, Iterator, NamedTuple,
                    Sequence, Union)

import numpy as np

@ -154,7 +154,7 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):

  @property
  @abc.abstractmethod
  def shape(self) -> Tuple[int, ...]: ...
  def shape(self) -> tuple[int, ...]: ...

  @property
  @abc.abstractmethod
@ -211,7 +211,7 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
  @abc.abstractmethod
  def device(self) -> Device: ...
  @abc.abstractmethod
  def devices(self) -> Set[Device]: ...
  def devices(self) -> set[Device]: ...
  @abc.abstractmethod
  def delete(self) -> None: ...
  @abc.abstractmethod
@ -220,10 +220,10 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
  def on_device_size_in_bytes(self) -> int: ...
  @property
  @abc.abstractmethod
  def addressable_shards(self) -> List[Shard]: ...
  def addressable_shards(self) -> list[Shard]: ...
  @property
  @abc.abstractmethod
  def global_shards(self) -> List[Shard]: ...
  def global_shards(self) -> list[Shard]: ...
  @abc.abstractmethod
  def addressable_data(self, index: int) -> PRNGKeyArray: ...

@ -305,7 +305,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
    return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))

  @property
  def addressable_shards(self) -> List[Shard]:
  def addressable_shards(self) -> list[Shard]:
    return [
        type(s)(
            device=s._device,
@ -317,7 +317,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
    ]

  @property
  def global_shards(self) -> List[Shard]:
  def global_shards(self) -> list[Shard]:
    return [
        type(s)(
            device=s._device,
@ -16,7 +16,7 @@
from functools import partial
import math
from operator import index
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Union
import warnings

import numpy as np
@ -67,7 +67,7 @@ def _isnan(x: ArrayLike) -> Array:
  return lax.ne(x, x)


def _check_prng_key(key) -> Tuple[prng.PRNGKeyArray, bool]:
def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
  # TODO(frostig): remove once we always enable_custom_prng
  if isinstance(key, prng.PRNGKeyArray):
    return key, False
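`_check_prng_key` is the shim that lets `jax.random` accept both typed key arrays and the legacy uint32 key format; the standard usage pattern looks the same either way:

    import jax

    key = jax.random.PRNGKey(0)       # legacy-format key (a uint32 vector)
    k1, k2 = jax.random.split(key)    # never reuse a key; split instead
    x = jax.random.normal(k1, (2, 2))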
@ -19,7 +19,7 @@ import numpy as np
import scipy.linalg
import textwrap
import warnings
from typing import cast, overload, Any, Literal, Optional, Tuple, Union
from typing import cast, overload, Any, Literal, Optional, Union

import jax
import jax.numpy as jnp
@ -56,7 +56,7 @@ def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
@_wraps(scipy.linalg.cho_factor,
        lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
               check_finite: bool = True) -> Tuple[Array, bool]:
               check_finite: bool = True) -> tuple[Array, bool]:
  del overwrite_a, check_finite  # Unused
  return (cholesky(a, lower=lower), lower)

@ -72,30 +72,30 @@ def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array:

@_wraps(scipy.linalg.cho_solve, update_doc=False,
        lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite'))
def cho_solve(c_and_lower: Tuple[ArrayLike, bool], b: ArrayLike,
def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike,
              overwrite_b: bool = False, check_finite: bool = True) -> Array:
  del overwrite_b, check_finite  # Unused
  c, lower = c_and_lower
  return _cho_solve(c, b, lower)

@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[True]) -> Tuple[Array, Array, Array]: ...
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...

@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Array: ...

@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, Tuple[Array, Array, Array]]: ...
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, tuple[Array, Array, Array]]: ...

@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, Tuple[Array, Array, Array]]:
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Union[Array, tuple[Array, Array, Array]]:
  a, = promote_dtypes_inexact(jnp.asarray(a))
  return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: Literal[True] = True,
        overwrite_a: bool = False, check_finite: bool = True,
        lapack_driver: str = 'gesdd') -> Tuple[Array, Array, Array]: ...
        lapack_driver: str = 'gesdd') -> tuple[Array, Array, Array]: ...

@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
@ -110,13 +110,13 @@ def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
        overwrite_a: bool = False, check_finite: bool = True,
        lapack_driver: str = 'gesdd') -> Union[Array, Tuple[Array, Array, Array]]: ...
        lapack_driver: str = 'gesdd') -> Union[Array, tuple[Array, Array, Array]]: ...

@_wraps(scipy.linalg.svd,
        lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
        overwrite_a: bool = False, check_finite: bool = True,
        lapack_driver: str = 'gesdd') -> Union[Array, Tuple[Array, Array, Array]]:
        lapack_driver: str = 'gesdd') -> Union[Array, tuple[Array, Array, Array]]:
  del overwrite_a, check_finite, lapack_driver  # unused
  return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

@ -133,15 +133,15 @@ def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: Liter

@overload
def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: Literal[False],
          eigvals: None, type: int) -> Tuple[Array, Array]: ...
          eigvals: None, type: int) -> tuple[Array, Array]: ...

@overload
def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool,
          eigvals: None, type: int) -> Union[Array, Tuple[Array, Array]]: ...
          eigvals: None, type: int) -> Union[Array, tuple[Array, Array]]: ...

@partial(jit, static_argnames=('lower', 'eigvals_only', 'eigvals', 'type'))
def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool,
          eigvals: None, type: int) -> Union[Array, Tuple[Array, Array]]:
          eigvals: None, type: int) -> Union[Array, tuple[Array, Array]]:
  if b is not None:
    raise NotImplementedError("Only the b=None case of eigh is implemented")
  if type != 1:
@ -162,7 +162,7 @@ def _eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool, eigvals_only: bool,
def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True,
         eigvals_only: Literal[False] = False, overwrite_a: bool = False,
         overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
         type: int = 1, check_finite: bool = True) -> Tuple[Array, Array]: ...
         type: int = 1, check_finite: bool = True) -> tuple[Array, Array]: ...

@overload
def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True, *,
@ -180,7 +180,7 @@ def eigh(a: ArrayLike, b: Optional[ArrayLike], lower: bool,
def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True,
         eigvals_only: bool = False, overwrite_a: bool = False,
         overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
         type: int = 1, check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]: ...
         type: int = 1, check_finite: bool = True) -> Union[Array, tuple[Array, Array]]: ...

@_wraps(scipy.linalg.eigh,
        lax_description=_no_overwrite_and_chkfinite_doc,
@ -188,18 +188,18 @@ def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True,
def eigh(a: ArrayLike, b: Optional[ArrayLike] = None, lower: bool = True,
         eigvals_only: bool = False, overwrite_a: bool = False,
         overwrite_b: bool = False, turbo: bool = True, eigvals: None = None,
         type: int = 1, check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]:
         type: int = 1, check_finite: bool = True) -> Union[Array, tuple[Array, Array]]:
  del overwrite_a, overwrite_b, turbo, check_finite  # unused
  return _eigh(a, b, lower, eigvals_only, eigvals, type)

@partial(jit, static_argnames=('output',))
def _schur(a: Array, output: str) -> Tuple[Array, Array]:
def _schur(a: Array, output: str) -> tuple[Array, Array]:
  if output == "complex":
    a = a.astype(dtypes.to_complex_dtype(a.dtype))
  return lax_linalg.schur(a)

@_wraps(scipy.linalg.schur)
def schur(a: ArrayLike, output: str = 'real') -> Tuple[Array, Array]:
def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
  if output not in ('real', 'complex'):
    raise ValueError(
        f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
@ -215,7 +215,7 @@ def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> A
@_wraps(scipy.linalg.lu_factor,
        lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
@partial(jit, static_argnames=('overwrite_a', 'check_finite'))
def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Tuple[Array, Array]:
def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]:
  del overwrite_a, check_finite  # unused
  a, = promote_dtypes_inexact(jnp.asarray(a))
  lu, pivots, _ = lax_linalg.lu(a)
@ -225,7 +225,7 @@ def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True
@_wraps(scipy.linalg.lu_solve,
        lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite'))
@partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite'))
def lu_solve(lu_and_piv: Tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
             overwrite_b: bool = False, check_finite: bool = True) -> Array:
  del overwrite_b, check_finite  # unused
  lu, pivots = lu_and_piv
@ -234,16 +234,16 @@ def lu_solve(lu_and_piv: Tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0,
  return lax_linalg.lu_solve(lu, perm, b, trans)

@overload
def _lu(a: ArrayLike, permute_l: Literal[True]) -> Tuple[Array, Array]: ...
def _lu(a: ArrayLike, permute_l: Literal[True]) -> tuple[Array, Array]: ...

@overload
def _lu(a: ArrayLike, permute_l: Literal[False]) -> Tuple[Array, Array, Array]: ...
def _lu(a: ArrayLike, permute_l: Literal[False]) -> tuple[Array, Array, Array]: ...

@overload
def _lu(a: ArrayLike, permute_l: bool) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]: ...
def _lu(a: ArrayLike, permute_l: bool) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]: ...

@partial(jit, static_argnums=(1,))
def _lu(a: ArrayLike, permute_l: bool) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]:
def _lu(a: ArrayLike, permute_l: bool) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]:
  a, = promote_dtypes_inexact(jnp.asarray(a))
  lu, _, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
@ -259,35 +259,35 @@ def _lu(a: ArrayLike, permute_l: bool) -> Union[Tuple[Array, Array], Tuple[Array

@overload
def lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False,
       check_finite: bool = True) -> Tuple[Array, Array, Array]: ...
       check_finite: bool = True) -> tuple[Array, Array, Array]: ...

@overload
def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False,
       check_finite: bool = True) -> Tuple[Array, Array]: ...
       check_finite: bool = True) -> tuple[Array, Array]: ...

@overload
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
       check_finite: bool = True) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]: ...
       check_finite: bool = True) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]: ...

@_wraps(scipy.linalg.lu, update_doc=False,
        lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
@partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite'))
def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
       check_finite: bool = True) -> Union[Tuple[Array, Array], Tuple[Array, Array, Array]]:
       check_finite: bool = True) -> Union[tuple[Array, Array], tuple[Array, Array, Array]]:
  del overwrite_a, check_finite  # unused
  return _lu(a, permute_l)

@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: bool) -> Tuple[Array]: ...
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: bool) -> tuple[Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: bool) -> Tuple[Array, Array]: ...
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: bool) -> tuple[Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[Tuple[Array], Tuple[Array, Array]]: ...
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[tuple[Array], tuple[Array, Array]]: ...

@partial(jit, static_argnames=('mode', 'pivoting'))
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[Tuple[Array], Tuple[Array, Array]]:
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[tuple[Array], tuple[Array, Array]]:
  if pivoting:
    raise NotImplementedError(
        "The pivoting=True case of qr is not implemented.")
@ -306,24 +306,24 @@ def _qr(a: ArrayLike, mode: str, pivoting: bool) -> Union[Tuple[Array], Tuple[Ar

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: Literal["full", "economic"] = "full",
       pivoting: bool = False, check_finite: bool = True) -> Tuple[Array, Array]: ...
       pivoting: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool, lwork: Any, mode: Literal["r"],
       pivoting: bool = False, check_finite: bool = True) -> Tuple[Array]: ...
       pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["r"],
       pivoting: bool = False, check_finite: bool = True) -> Tuple[Array]: ...
       pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
       pivoting: bool = False, check_finite: bool = True) -> Union[Tuple[Array], Tuple[Array, Array]]: ...
       pivoting: bool = False, check_finite: bool = True) -> Union[tuple[Array], tuple[Array, Array]]: ...

@_wraps(scipy.linalg.qr,
        lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lwork'))
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
       pivoting: bool = False, check_finite: bool = True) -> Union[Tuple[Array], Tuple[Array, Array]]:
       pivoting: bool = False, check_finite: bool = True) -> Union[tuple[Array], tuple[Array, Array]]:
  del overwrite_a, lwork, check_finite  # unused
  return _qr(a, mode, pivoting)

@ -458,7 +458,7 @@ def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 1
  return R

@jit
def _calc_P_Q(A: ArrayLike) -> Tuple[Array, Array, Array]:
def _calc_P_Q(A: ArrayLike) -> tuple[Array, Array, Array]:
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
@ -513,7 +513,7 @@ def _squaring(R: Array, n_squarings: Array, max_squarings: int) -> Array:

  return res

def _pade3(A: Array) -> Tuple[Array, Array]:
def _pade3(A: Array) -> tuple[Array, Array]:
  b = (120., 60., 12., 1.)
  M, N = A.shape
  ident = jnp.eye(M, N, dtype=A.dtype)
@ -522,7 +522,7 @@ def _pade3(A: Array) -> Tuple[Array, Array]:
  V: Array = b[2]*A2 + b[0]*ident
  return U, V

def _pade5(A: Array) -> Tuple[Array, Array]:
def _pade5(A: Array) -> tuple[Array, Array]:
  b = (30240., 15120., 3360., 420., 30., 1.)
  M, N = A.shape
  ident = jnp.eye(M, N, dtype=A.dtype)
@ -532,7 +532,7 @@ def _pade5(A: Array) -> Tuple[Array, Array]:
  V: Array = b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V

def _pade7(A: Array) -> Tuple[Array, Array]:
def _pade7(A: Array) -> tuple[Array, Array]:
  b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
  M, N = A.shape
  ident = jnp.eye(M, N, dtype=A.dtype)
@ -543,7 +543,7 @@ def _pade7(A: Array) -> Tuple[Array, Array]:
  V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U,V

def _pade9(A: Array) -> Tuple[Array, Array]:
def _pade9(A: Array) -> tuple[Array, Array]:
  b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
       2162160., 110880., 3960., 90., 1.)
  M, N = A.shape
@ -556,7 +556,7 @@ def _pade9(A: Array) -> Tuple[Array, Array]:
  V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U,V

def _pade13(A: Array) -> Tuple[Array, Array]:
def _pade13(A: Array) -> tuple[Array, Array]:
  b = (64764752532480000., 32382376266240000., 7771770303897600.,
       1187353796428800., 129060195264000., 10559470521600., 670442572800.,
       33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
@ -578,7 +578,7 @@ support the ``method='blockEnlarge'`` argument.

@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
                 compute_expm: Literal[True] = True) -> Tuple[Array, Array]: ...
                 compute_expm: Literal[True] = True) -> tuple[Array, Array]: ...

@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
@ -586,12 +586,12 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,

@overload
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
                 compute_expm: bool = True) -> Union[Array, Tuple[Array, Array]]: ...
                 compute_expm: bool = True) -> Union[Array, tuple[Array, Array]]: ...

@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
@partial(jit, static_argnames=('method', 'compute_expm'))
def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
                 compute_expm: bool = True) -> Union[Array, Tuple[Array, Array]]:
                 compute_expm: bool = True) -> Union[Array, tuple[Array, Array]]:
  A = jnp.asarray(A)
  E = jnp.asarray(E)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
@ -617,8 +617,8 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: Optional[str] = None,
@jit
def block_diag(*arrs: ArrayLike) -> Array:
  if len(arrs) == 0:
    arrs = cast(Tuple[ArrayLike], (jnp.zeros((1, 0)),))
  arrs = cast(Tuple[ArrayLike], promote_dtypes(*arrs))
    arrs = cast(tuple[ArrayLike], (jnp.zeros((1, 0)),))
  arrs = cast(tuple[ArrayLike], promote_dtypes(*arrs))
  bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
  if bad_shapes:
    raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
@ -638,7 +638,7 @@ def block_diag(*arrs: ArrayLike) -> Array:
@_wraps(scipy.linalg.eigh_tridiagonal)
@partial(jit, static_argnames=("eigvals_only", "select", "select_range"))
def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
                     select: str = 'a', select_range: Optional[Tuple[float, float]] = None,
                     select: str = 'a', select_range: Optional[tuple[float, float]] = None,
                     tol: Optional[float] = None) -> Array:
  if not eigvals_only:
    raise NotImplementedError("Calculation of eigenvectors is not implemented")
@ -794,7 +794,7 @@ def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
@partial(jit, static_argnames=('side', 'method'))
@jax.default_matmul_precision("float32")
def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: Optional[float] = None,
          max_iterations: Optional[int] = None) -> Tuple[Array, Array]:
          max_iterations: Optional[int] = None) -> tuple[Array, Array]:
  r"""Computes the polar decomposition.

  Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar
@ -936,7 +936,7 @@ def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array:

@_wraps(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc)
@partial(jit, static_argnames=('check_finite',))
def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Array, Array]:
def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
  del check_finite  # unused

  T = jnp.asarray(T)
@ -1001,12 +1001,12 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = Fals

@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
               check_finite: bool = True) -> Tuple[Array, Array]: ...
               check_finite: bool = True) -> tuple[Array, Array]: ...

@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
               check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]:
               check_finite: bool = True) -> Union[Array, tuple[Array, Array]]:
  del overwrite_a, check_finite
  n = jnp.shape(a)[-1]
  if n == 0:
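Many of these factorizations return plain tuples, which is why so many signatures change in this file. A typical round trip through the public wrappers:

    import jax.numpy as jnp
    from jax.scipy.linalg import cho_factor, cho_solve, lu

    a = jnp.array([[4.0, 2.0], [2.0, 3.0]])
    c, low = cho_factor(a)                # tuple[Array, bool]
    x = cho_solve((c, low), jnp.ones(2))  # solves a @ x = [1, 1]
    p, l, u = lu(a)                       # tuple[Array, Array, Array] by default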
@ -17,7 +17,7 @@ import functools
import itertools
import operator
import textwrap
from typing import Callable, Dict, List, Sequence, Tuple
from typing import Callable, Sequence

import scipy.ndimage

@ -44,7 +44,7 @@ def _mirror_index_fixer(index: Array, size: int) -> Array:
def _reflect_index_fixer(index: Array, size: int) -> Array:
  return jnp.floor_divide(_mirror_index_fixer(2*index+1, 2*size+1) - 1, 2)

_INDEX_FIXERS: Dict[str, Callable[[Array, int], Array]] = {
_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = {
    'constant': lambda index, size: index,
    'nearest': lambda index, size: jnp.clip(index, 0, size - 1),
    'wrap': lambda index, size: index % size,
@ -57,13 +57,13 @@ def _round_half_away_from_zero(a: Array) -> Array:
  return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)


def _nearest_indices_and_weights(coordinate: Array) -> List[Tuple[Array, ArrayLike]]:
def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
  index = _round_half_away_from_zero(coordinate).astype(jnp.int32)
  weight = coordinate.dtype.type(1)
  return [(index, weight)]


def _linear_indices_and_weights(coordinate: Array) -> List[Tuple[Array, ArrayLike]]:
def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
  lower = jnp.floor(coordinate)
  upper_weight = coordinate - lower
  lower_weight = 1 - upper_weight
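The two index/weight helpers implement order-0 (nearest) and order-1 (linear) interpolation for `map_coordinates`. For example:

    import jax.numpy as jnp
    from jax.scipy.ndimage import map_coordinates

    img = jnp.arange(9.0).reshape(3, 3)
    # sample at (row, col) = (0.5, 1.5) with linear interpolation
    map_coordinates(img, [jnp.array([0.5]), jnp.array([1.5])], order=1)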
@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Mapping, Optional, Union

import jax
from jax._src.scipy.optimize.bfgs import minimize_bfgs
@ -50,7 +50,7 @@ class OptimizeResults(NamedTuple):
def minimize(
    fun: Callable,
    x0: jax.Array,
    args: Tuple = (),
    args: tuple = (),
    *,
    method: str,
    tol: Optional[float] = None,
@ -15,7 +15,7 @@
from functools import partial
import math
import operator
from typing import Callable, Optional, Tuple, Union, Sequence
from typing import Callable, Optional, Union, Sequence
import warnings

import numpy as np
@ -289,7 +289,7 @@ def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0,
                     detrend_type: Union[bool, str, Callable[[Array], Array]] = 'constant',
                     return_onesided: bool = True, scaling: str = 'density',
                     axis: int = -1, mode: str = 'psd', boundary: Optional[str] = None,
                     padded: bool = False) -> Tuple[Array, Array, Array]:
                     padded: bool = False) -> tuple[Array, Array, Array]:
  """LAX-backend implementation of `scipy.signal._spectral_helper`.

  Unlike the original helper function, `y` can be None for explicitly
@ -500,7 +500,7 @@ def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0,
def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256,
         noverlap: Optional[int] = None, nfft: Optional[int] = None,
         detrend: bool = False, return_onesided: bool = True, boundary: Optional[str] = 'zeros',
         padded: bool = True, axis: int = -1) -> Tuple[Array, Array, Array]:
         padded: bool = True, axis: int = -1) -> tuple[Array, Array, Array]:
  return _spectral_helper(x, None, fs, window, nperseg, noverlap,
                          nfft, detrend, return_onesided,
                          scaling='spectrum', axis=axis,
@ -520,7 +520,7 @@ def csd(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0, window: str = 'ha
        nperseg: Optional[int] = None, noverlap: Optional[int] = None,
        nfft: Optional[int] = None, detrend: str = 'constant',
        return_onesided: bool = True, scaling: str = 'density',
        axis: int = -1, average: str = 'mean') -> Tuple[Array, Array]:
        axis: int = -1, average: str = 'mean') -> tuple[Array, Array]:
  freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft,
                                   detrend, return_onesided, scaling, axis,
                                   mode='psd')
@ -553,7 +553,7 @@ def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann',
          nperseg: Optional[int] = None, noverlap: Optional[int] = None,
          nfft: Optional[int] = None, detrend: str = 'constant',
          return_onesided: bool = True, scaling: str = 'density',
          axis: int = -1, average: str = 'mean') -> Tuple[Array, Array]:
          axis: int = -1, average: str = 'mean') -> tuple[Array, Array]:
  freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg,
                   noverlap=noverlap, nfft=nfft, detrend=detrend,
                   return_onesided=return_onesided, scaling=scaling,
@ -615,7 +615,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
          nperseg: Optional[int] = None, noverlap: Optional[int] = None,
          nfft: Optional[int] = None, input_onesided: bool = True,
          boundary: bool = True, time_axis: int = -1,
          freq_axis: int = -2) -> Tuple[Array, Array]:
          freq_axis: int = -2) -> tuple[Array, Array]:
  # Input validation
  check_arraylike("istft", Zxx)
  if Zxx.ndim < 2:
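`_spectral_helper` underlies all four public functions, which is why the (freqs, times, spectrum) triple and (freqs, psd) pair recur in the signatures. For example:

    import jax.numpy as jnp
    from jax.scipy.signal import stft, welch

    x = jnp.sin(jnp.linspace(0.0, 100.0, 1024))
    f, t, Zxx = stft(x, fs=1.0, nperseg=256)    # tuple[Array, Array, Array]
    freqs, pxx = welch(x, fs=1.0, nperseg=256)  # tuple[Array, Array]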
@ -14,7 +14,7 @@

from functools import partial
import operator
from typing import cast, Any, List, Optional, Tuple
from typing import cast, Any, Optional

import numpy as np
import scipy.special as osp_special
@ -765,7 +765,7 @@ def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array:

def _gen_recurrence_mask(
    l_max: int, is_normalized: bool, dtype: Any
) -> Tuple[Array, Array]:
) -> tuple[Array, Array]:
  """Generates mask for recurrence relation on the remaining entries.

  The remaining entries are with respect to the diagonal and offdiagonal
@ -1012,7 +1012,7 @@ def _gen_associated_legendre(l_max: int,
  return p


def lpmn(m: int, n: int, z: Array) -> Tuple[Array, Array]:
def lpmn(m: int, n: int, z: Array) -> tuple[Array, Array]:
  """The associated Legendre functions (ALFs) of the first kind.

  Args:
@ -1215,7 +1215,7 @@ def _expint1(x: Array) -> Array:
  return x * f + jnp.euler_gamma + jnp.log(x)


def _eval_expint_k(A: List[float], B: List[float], x: Array) -> Array:
def _eval_expint_k(A: list[float], B: list[float], x: Array) -> Array:
  # helper function for all subsequent intervals
  A_arr = jnp.array(A, dtype=x.dtype)
  B_arr = jnp.array(B, dtype=x.dtype)
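`lpmn` likewise returns a pair: the associated Legendre function values and their derivatives. A quick call (shapes follow the SciPy convention; z must be a 1-D float array):

    import jax.numpy as jnp
    from jax.scipy.special import lpmn

    p, dp = lpmn(2, 2, jnp.array([0.1, 0.5]))  # tuple[Array, Array]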
@ -15,7 +15,7 @@
from collections import namedtuple
from functools import partial
import math
from typing import Optional, Tuple
from typing import Optional

import jax
import jax.numpy as jnp
@ -70,7 +70,7 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
    axis = 0
    x = x.ravel()

  def _mode_helper(x: jax.Array) -> Tuple[jax.Array, jax.Array]:
  def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Helper function to return mode and count of a given array."""
    if x.size == 0:
      return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))
@ -15,15 +15,15 @@
from __future__ import annotations

import functools
from typing import (Mapping, Optional, Sequence, Set, Tuple)
from typing import Mapping, Optional, Sequence

from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc

Shape = Tuple[int, ...]
Shape = tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
Index = tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]


@ -44,7 +44,7 @@ class Sharding:

  # Abstract methods below that subclasses should implement.
  @property
  def device_set(self) -> Set[Device]:
  def device_set(self) -> set[Device]:
    """A ``set`` of global devices that this ``Sharding`` spans.

    In multi-controller JAX, the set of devices is global, i.e., includes
@ -89,7 +89,7 @@ class Sharding:
  # Default implementations below that all subclasses will inherit.

  @functools.cached_property
  def addressable_devices(self) -> Set[Device]:
  def addressable_devices(self) -> set[Device]:
    """A set of devices that are addressable by the current process."""
    # Add a fast path for single controller runtimes.
    if xb.process_count() == 1:
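After this change the device-set properties advertise the built-in `set`, which is observable from the public API on any backend:

    import jax
    from jax.sharding import SingleDeviceSharding

    s = SingleDeviceSharding(jax.devices()[0])
    s.device_set           # a plain built-in set of Devices
    s.addressable_devices  # the subset visible to this process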
@ -22,8 +22,8 @@ import itertools
|
||||
import math
|
||||
import operator as op
|
||||
import sys
|
||||
from typing import (Any, Dict, FrozenSet, List, Mapping, Optional, OrderedDict,
|
||||
NamedTuple, Sequence, Set, Tuple, Union, cast)
|
||||
from typing import (Any, Mapping, Optional, OrderedDict, NamedTuple, Sequence,
|
||||
Union, cast)
|
||||
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.op_shardings import (
|
||||
@ -44,13 +44,13 @@ import numpy as np
|
||||
if sys.version_info >= (3, 9):
|
||||
OrderedDictType = OrderedDict
|
||||
else:
|
||||
OrderedDictType = Dict
|
||||
OrderedDictType = dict
|
||||
|
||||
|
||||
Shape = Tuple[int, ...]
|
||||
Shape = tuple[int, ...]
|
||||
Device = xc.Device
|
||||
Index = Tuple[slice, ...]
|
||||
XLADeviceAssignment = Tuple[Device, ...]
|
||||
Index = tuple[slice, ...]
|
||||
XLADeviceAssignment = tuple[Device, ...]
|
||||
|
||||
|
||||
 # Shardings that inherit from XLACompatibleSharding should implement the
@@ -146,7 +146,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
         'create a device to index mapping for your sharding from which replica '
         'ids will be calculated.') from None
 
-  index_to_replica: Dict[int, int] = collections.Counter()
+  index_to_replica: dict[int, int] = collections.Counter()
   out = {}
   for device, index in device_indices_map_fn(global_shape).items():
     h_index = hashed_index(index)
@@ -255,7 +255,7 @@ class NamedSharding(XLACompatibleSharding):
     return cls(mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
 
   @property
-  def device_set(self) -> Set[Device]:
+  def device_set(self) -> set[Device]:
     return self.mesh._flat_devices_set
 
   @property
@@ -269,7 +269,7 @@ class NamedSharding(XLACompatibleSharding):
     return not self.mesh.is_multi_process
 
   @property
-  def addressable_devices(self) -> Set[Device]:
+  def addressable_devices(self) -> set[Device]:
     # Override addressable devices because there is a high chance that the mesh
     # across multiple NamedSharding objects will be the same.
     return self.mesh._local_devices_set
@@ -355,7 +355,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
     return self._device == other._device
 
   @property
-  def device_set(self) -> Set[Device]:
+  def device_set(self) -> set[Device]:
     return {self._device}
 
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:  # type: ignore
@@ -453,7 +453,7 @@ class PmapSharding(XLACompatibleSharding):
     return cls(pmap_devices, sharding_spec)
 
   @functools.cached_property
-  def device_set(self) -> Set[Device]:
+  def device_set(self) -> set[Device]:
     return set(self.devices.flat)
 
   @functools.lru_cache(maxsize=4096)
@@ -524,7 +524,7 @@ def _op_sharding_to_pos_sharding(
 
 
 class PositionalSharding(XLACompatibleSharding):
-  _devices: Tuple[xc.Device, ...]
+  _devices: tuple[xc.Device, ...]
   _ids: np.ndarray  # dtype DeviceIdSet
 
   def __init__(self, devices: Union[Sequence[xc.Device], np.ndarray]):
@@ -564,7 +564,7 @@ class PositionalSharding(XLACompatibleSharding):
 
   @classmethod
   def remake(
-      cls, devices: Tuple[xc.Device, ...], ids: np.ndarray) -> PositionalSharding:
+      cls, devices: tuple[xc.Device, ...], ids: np.ndarray) -> PositionalSharding:
     self = cls.__new__(cls)
     self._devices = devices
     self._ids = ids
@@ -626,7 +626,7 @@ class PositionalSharding(XLACompatibleSharding):
 
 class DeviceIdSet:
   _name: str
-  _ids: FrozenSet[int]
+  _ids: frozenset[int]
   def __init__(self, name, *ids):
     self._name = name
     self._ids = frozenset(ids)
@@ -655,7 +655,7 @@ class DeviceIdSet:
 
 @use_cpp_class(xc.GSPMDSharding)
 class GSPMDSharding(XLACompatibleSharding):
-  _devices: Tuple[Device, ...]
+  _devices: tuple[Device, ...]
   _hlo_sharding: xc.HloSharding
 
   @use_cpp_method()
@@ -706,7 +706,7 @@ class GSPMDSharding(XLACompatibleSharding):
           f"{len(aval_shape)}")
 
   @functools.cached_property
-  def device_set(self) -> Set[Device]:
+  def device_set(self) -> set[Device]:
     return set(self._devices)
 
   @functools.lru_cache(maxsize=4096)
@@ -997,8 +997,8 @@ def _check_unique_resources(axis_resources, arg_name):
 class AxisEnv(NamedTuple):
   """Represents a pmap mesh (only along the replica axes)."""
   nreps: int
-  names: Tuple[Any, ...]
-  sizes: Tuple[int, ...]
+  names: tuple[Any, ...]
+  sizes: tuple[int, ...]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -1010,7 +1010,7 @@ class SPMDAxisContext:
   is invoked inside an xmap) lowered in the MANUAL sharding mode.
   """
   mesh: mesh_lib.Mesh
-  manual_axes: FrozenSet[MeshAxisName] = frozenset()
+  manual_axes: frozenset[MeshAxisName] = frozenset()
 
   @property
   def axis_env(self):
@@ -1030,7 +1030,7 @@ class SPMDAxisContext:
         names=self.mesh.axis_names,
         sizes=tuple(self.mesh.shape.values()))
 
-  def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext:
+  def extend_manual(self, axes: frozenset[MeshAxisName]) -> SPMDAxisContext:
     return SPMDAxisContext(self.mesh, self.manual_axes | axes)
 
 
@@ -1162,7 +1162,7 @@ def parse_flatten_op_sharding(op_sharding: Union[xc.OpSharding, xc.HloSharding],
   if isinstance(op_sharding, xc.HloSharding):
     op_sharding = op_sharding.to_proto()  # type: ignore
   if op_sharding.type == xc.OpSharding.Type.TUPLE:
-    out: List[ParsedPartitionSpec] = []
+    out: list[ParsedPartitionSpec] = []
     for s in op_sharding.tuple_shardings:
       out.extend(parse_flatten_op_sharding(s, mesh))
     return out
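Every hunk in the file above makes the same mechanical substitution: the capitalized aliases from typing (Dict, List, Tuple, Set, FrozenSet) are replaced by the built-in generics standardized in PEP 585. A minimal standalone before/after sketch, not part of this diff (count_tokens is a made-up name; subscripting built-ins this way requires Python 3.9+, which a change like this assumes):

    # Before: capitalized generic aliases imported from typing.
    from typing import Dict, Tuple

    def count_tokens(words: Tuple[str, ...]) -> Dict[str, int]:
      counts: Dict[str, int] = {}
      for w in words:
        counts[w] = counts.get(w, 0) + 1
      return counts

    # After: PEP 585 built-in generics; no typing imports needed for these.
    def count_tokens_585(words: tuple[str, ...]) -> dict[str, int]:
      counts: dict[str, int] = {}
      for w in words:
        counts[w] = counts.get(w, 0) + 1
      return counts

    assert count_tokens(("a", "b", "a")) == count_tokens_585(("a", "b", "a"))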
@@ -31,7 +31,7 @@ import collections
 import functools
 import itertools
 import math
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Mapping, Optional, Sequence, Union, cast
 
 import numpy as np
 
@@ -91,7 +91,7 @@ def sharding_spec_sharding_proto(
   the code here might help:
   https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
   """
-  mesh_shape = cast(Tuple[int, ...], self.mesh_shape)
+  mesh_shape = cast(tuple[int, ...], self.mesh_shape)
 
   sharded_axes = {}  # maps sharded axis identifiers to mesh axis indices to which they're mapped
   replicated_maxes = []  # lists mesh axis identifiers to replicate over
@@ -128,8 +128,8 @@ def sharding_spec_sharding_proto(
   # specially over some mesh axes.
   last_tile_dims = []
   if replicated_maxes:
-    axes_by_type: Dict[OpShardingType, List[_MeshAxisName]] = {}
-    size_by_type: Dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
+    axes_by_type: dict[OpShardingType, list[_MeshAxisName]] = {}
+    size_by_type: dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
     assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
     for axis, size in replicated_maxes:
       ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
@@ -145,7 +145,7 @@ def sharding_spec_sharding_proto(
       transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)
 
 
-def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
+def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
   """Returns NumPy-style indices corresponding to a sharding spec.
 
   Args:
@@ -167,7 +167,7 @@ def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
       hlo_sharding, shape, math.prod(self.mesh_shape)
   ).reshape(self.mesh_shape)
 
-  axis_indices: List[Sequence[Index]] = []
+  axis_indices: list[Sequence[Index]] = []
   shard_indices_shape = []
   for dim, sharding in enumerate(self.sharding):
     axis_size = shape[dim]
@@ -223,10 +223,10 @@ ShardingSpec.indices = _sharding_spec_indices
 ShardingSpec.__repr__ = _sharding_spec_repr  # type: ignore
 
 
-Index = Union[int, slice, Tuple[Union[int, slice], ...]]
+Index = Union[int, slice, tuple[Union[int, slice], ...]]
 
 def spec_to_indices(shape: Sequence[int],
-                    spec: ShardingSpec) -> Tuple[Index, ...]:
+                    spec: ShardingSpec) -> tuple[Index, ...]:
   """Returns numpy-style indices corresponding to a sharding spec.
 
   Each index describes a shard of the array. The order of the indices is the
@@ -306,7 +306,7 @@ def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
     mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
 
 
-def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
+def create_pmap_sharding_spec(shape: tuple[int, ...], sharded_dim: int = 0,
                               sharded_dim_size: Optional[int] = None):
   if sharded_dim is not None:
     sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:]
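Note the cast(tuple[int, ...], ...) hunk above: unlike annotations, arguments to cast are evaluated at runtime, so lower-case generics in that position need a 3.9+ interpreter even when annotation evaluation is deferred. A small illustrative sketch of the distinction (as_shape is a made-up helper):

    from __future__ import annotations  # defers evaluation of annotations only
    from typing import cast

    def as_shape(x) -> tuple[int, ...]:       # annotation: not evaluated at runtime
      return cast(tuple[int, ...], tuple(x))  # evaluated at call time: needs 3.9+

    print(as_shape([2, 3]))  # (2, 3)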
@@ -20,7 +20,7 @@ import os.path
 import sysconfig
 import threading
 import types
-from typing import List, Optional, Iterator, NamedTuple, Union, Tuple
+from typing import Optional, Iterator, NamedTuple, Union
 
 import jax.version
 from jax._src.lib import xla_client
@@ -40,7 +40,7 @@ class Frame(NamedTuple):
   end_column: int
 
 
-_exclude_paths: List[str] = [
+_exclude_paths: list[str] = [
   os.path.dirname(jax.version.__file__),
   # Also exclude stdlib as user frames. In a non-standard Python runtime,
   # the following two may be different.
@@ -54,13 +54,13 @@ def register_exclusion(path: str):
 class Scope(NamedTuple):
   name: str
 
-  def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
+  def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]:
     return (self.name, *stack)
 
 class Transform(NamedTuple):
   name: str
 
-  def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
+  def wrap(self, stack: tuple[str, ...]) -> tuple[str, ...]:
     if stack:
       return (f'{self.name}({stack[0]})', *stack[1:])
     else:
@@ -68,9 +68,9 @@ class Transform(NamedTuple):
 
 @dataclasses.dataclass(frozen=True)
 class NameStack:
-  stack: Tuple[Union[Scope, Transform], ...] = ()
+  stack: tuple[Union[Scope, Transform], ...] = ()
 
-  def extend(self, name: Union[Tuple[str, ...], str]) -> 'NameStack':
+  def extend(self, name: Union[tuple[str, ...], str]) -> 'NameStack':
     if not isinstance(name, tuple):
       name = (name,)
     scopes = tuple(map(Scope, name))
@@ -97,7 +97,7 @@ class NameStack:
     return NameStack(other.stack + self.stack)
 
   def __str__(self) -> str:
-    scope: Tuple[str, ...] = ()
+    scope: tuple[str, ...] = ()
     for elem in self.stack[::-1]:
       scope = elem.wrap(scope)
     return '/'.join(scope)
@@ -33,8 +33,7 @@ from __future__ import annotations
 import warnings
 
 from dataclasses import dataclass
-from typing import (
-    Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple, Union)
+from typing import Any, NamedTuple, Optional, Protocol, Sequence, Union
 
 import jax
 
@@ -56,7 +55,7 @@ xla_extension = xc._xla
 map, unsafe_map = util.safe_map, map
 zip, unsafe_zip = util.safe_zip, zip
 
-CompilerOptions = Dict[str, Union[str, bool]]
+CompilerOptions = dict[str, Union[str, bool]]
 
 
 # -- Internal protocols
@@ -232,9 +231,9 @@ class XlaExecutable(Executable):
       else:
         raise
 
-  # TODO(skyewm): this should return a single Dict (I think returning a list
+  # TODO(skyewm): this should return a single dict (I think returning a list
   # was to support MPMD executables, which never fully landed)
-  def cost_analysis(self) -> List[Dict[str, float]]:
+  def cost_analysis(self) -> list[dict[str, float]]:
     xla_ext_exe = self.xla_extension_executable()
 
     # TODO(b/259255524): Unify/merge the two cost_analysis calls below.
@@ -284,7 +283,7 @@ class XlaExecutable(Executable):
 class XlaLowering(Lowering):
   """Adapts our various internal XLA-backed computations into a ``Lowering``."""
 
-  compile_args: Dict[str, Any]
+  compile_args: dict[str, Any]
 
   def hlo(self) -> xc.XlaComputation:
     """Return an HLO representation of this computation."""
@@ -331,7 +330,7 @@ class XlaLowering(Lowering):
     else:
       raise ValueError(f"unknown dialect: {dialect}")
 
-  def cost_analysis(self) -> Dict[str, float]:
+  def cost_analysis(self) -> dict[str, float]:
     raise NotImplementedError("must override")
 
 
@@ -578,7 +577,7 @@ class Lowered(Stage):
                lowering: XlaLowering,
                in_tree: tree_util.PyTreeDef,
                in_avals,
-               donate_argnums: Tuple[int, ...],
+               donate_argnums: tuple[int, ...],
                out_tree: tree_util.PyTreeDef,
                no_kwargs: bool = False):
     """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.
@@ -599,7 +598,7 @@ class Lowered(Stage):
   def compile(
       self, compiler_options: Optional[CompilerOptions] = None) -> Compiled:
     """Compile, returning a corresponding ``Compiled`` instance."""
-    kw: Dict[str, Any] = {"compiler_options": compiler_options}
+    kw: dict[str, Any] = {"compiler_options": compiler_options}
     return Compiled(
         self._lowering.compile(**kw),  # pytype: disable=wrong-keyword-args
         self.args_info,
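The CompilerOptions hunk above is a type-alias assignment rather than an annotation, so its right-hand side is likewise evaluated at import time. A hedged sketch of defining and consuming such an alias (the option names and the compile_with function are invented for illustration):

    from typing import Optional, Union

    CompilerOptions = dict[str, Union[str, bool]]  # evaluated at import: needs 3.9+

    def compile_with(options: Optional[CompilerOptions] = None) -> None:
      # Stand-in consumer; a real lowering would forward these to the backend.
      for key, value in (options or {}).items():
        print(f"{key}={value}")

    compile_with({"hypothetical_flag": True, "hypothetical_level": "2"})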
@@ -17,7 +17,7 @@ import dataclasses
 from functools import partial
 import operator
 
-from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union
+from typing import Any, Callable, Optional, Protocol, Sequence, Union
 
 import numpy as np
 
@@ -55,7 +55,7 @@ PyTreeDef = tree_util.PyTreeDef
 
 def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
                     should_discharge: Union[bool, Sequence[bool]] = True
-                    ) -> Tuple[core.Jaxpr, List[Any]]:
+                    ) -> tuple[core.Jaxpr, list[Any]]:
   """Converts a jaxpr that takes in `Ref`s into one that doesn't."""
   if isinstance(should_discharge, bool):
     should_discharge = [should_discharge] * len(jaxpr.invars)
@@ -69,7 +69,7 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
 
 @dataclasses.dataclass
 class Environment:
-  env: Dict[core.Var, Any]
+  env: dict[core.Var, Any]
 
   def read(self, v: core.Atom) -> Any:
     if type(v) is core.Literal:
@@ -84,7 +84,7 @@ class DischargeRule(Protocol):
 
   def __call__(self, in_avals: Sequence[core.AbstractValue],
                out_avals: Sequence[core.AbstractValue], *args: Any,
-               **params: Any) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
+               **params: Any) -> tuple[Sequence[Optional[Any]], Sequence[Any]]:
     ...
 
 _discharge_rules: dict[core.Primitive, DischargeRule] = {}
@@ -327,7 +327,7 @@ ad.primitive_jvps[run_state_p] = _run_state_jvp
 _save_everything = lambda *_, **__: True
 
 def _convert_outputs_to_writes(
-    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
+    jaxpr: core.Jaxpr) -> tuple[core.Jaxpr, list[core.ShapedArray]]:
   assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
 
   in_avals = [v.aval for v in jaxpr.invars]
@@ -677,7 +677,7 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
 
 def initial_style_jaxpr(
     fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
-) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
+) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
   return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))
 
 @weakref_lru_cache
@@ -14,7 +14,7 @@
 """Module for state primitives."""
 from functools import partial
 
-from typing import Any, List, Tuple, Union
+from typing import Any, Union
 
 import numpy as np
 
@@ -55,7 +55,7 @@ def _get_impl(ref: AbstractRef, *idx: int, **_):
   raise ValueError("Cannot run stateful primitive.")
 get_p.def_impl(_get_impl)
 
-Indexer = Tuple[Union[int, slice, Array], ...]
+Indexer = tuple[Union[int, slice, Array], ...]
 # or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
 
 def _is_trivial_indexer(idx: Indexer) -> bool:
@@ -68,7 +68,7 @@ def _is_trivial_indexer(idx: Indexer) -> bool:
   return False
 
 def _unpack_idx(idx: Indexer, ndim: int
-                ) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]:
+                ) -> tuple[tuple[Array, ...], tuple[bool, ...]]:
   if _is_trivial_indexer(idx):
     idx = tuple(slice(None) for _ in range(ndim))
   indexed_dims_ = []
@@ -85,9 +85,9 @@ def _unpack_idx(idx: Indexer, ndim: int
   import jax.numpy as jnp
   return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))
 
-def _get_slice_output_shape(in_shape: Tuple[int, ...],
-                            idx_shapes: Tuple[Tuple[int, ...], ...],
-                            indexed_dims: Tuple[bool, ...]) -> Tuple[int, ...]:
+def _get_slice_output_shape(in_shape: tuple[int, ...],
+                            idx_shapes: tuple[tuple[int, ...], ...],
+                            indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
   shape_suffix = [d for i, d in zip(indexed_dims, in_shape) if not i]
   shape_prefix, = set(idx_shapes) or [()]  # tie fighter
   # Move shape prefix dimensions to the front
@@ -95,7 +95,7 @@ def _get_slice_output_shape(in_shape: Tuple[int, ...],
   return shape
 
 def _get_indexer(ref: AbstractRef, idx: Indexer
-                 ) -> Tuple[Indexer, Tuple[bool, ...]]:
+                 ) -> tuple[Indexer, tuple[bool, ...]]:
   if isinstance(ref.inner_aval, core.ShapedArray):
     non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
   else:
@@ -201,7 +201,7 @@ get_p.def_effectful_abstract_eval(_get_abstract_eval)
 
 def _swap_abstract_eval(ref_aval: AbstractRef,
                         val_aval: core.AbstractValue,
-                        *idx: core.ShapedArray, indexed_dims: Tuple[bool]):
+                        *idx: core.ShapedArray, indexed_dims: tuple[bool]):
   out_aval: core.AbstractValue
   if not isinstance(ref_aval, AbstractRef):
     raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
@@ -236,7 +236,7 @@ swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
 
 def _addupdate_abstract_eval(ref_aval: AbstractRef,
                              val_aval: core.AbstractValue,
-                             *idx: core.ShapedArray, indexed_dims: Tuple[bool]):
+                             *idx: core.ShapedArray, indexed_dims: tuple[bool]):
   if not isinstance(ref_aval, AbstractRef):
     raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
   if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
@@ -327,7 +327,7 @@ core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
 
 ## get/swap/addupdate JVP rules
 
-def _get_jvp(primals: List[Any], tangents: List[Any], **params: Any):
+def _get_jvp(primals: list[Any], tangents: list[Any], **params: Any):
   ref_primal, *idx = primals
   assert isinstance(ref_primal.aval, AbstractRef)
   ref_tangent, *_ = tangents
@@ -336,7 +336,7 @@ def _get_jvp(primals: List[Any], tangents: List[Any], **params: Any):
           get_p.bind(ref_tangent, *idx, **params))  # type: ignore[arg-type]
 ad.primitive_jvps[get_p] = _get_jvp
 
-def _swap_jvp(primals: List[Any], tangents: List[Any], **params: Any):
+def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any):
   ref_primal, x_primal, *idx = primals
   assert isinstance(ref_primal.aval, AbstractRef)
   ref_tangent, x_tangent, *_ = tangents
@@ -346,7 +346,7 @@ def _swap_jvp(primals: List[Any], tangents: List[Any], **params: Any):
           swap_p.bind(ref_tangent, x_tangent, *idx, **params))  # type: ignore[arg-type]
 ad.primitive_jvps[swap_p] = _swap_jvp
 
-def addupdate_jvp_rule(primals: List[Any], tangents: List[Any], **params: Any):
+def addupdate_jvp_rule(primals: list[Any], tangents: list[Any], **params: Any):
   ref_primal, x_primal, *idx = primals
   ref_tangent, x_tangent, *_ = tangents
   x_tangent = ad_util.instantiate(x_tangent)
@@ -397,8 +397,8 @@ pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
 
 ## get/swap/addupdate batching rules
 
-def _output_bdim(indexed_dims: Tuple[bool, ...], ref_dim: int,
-                 idxs_shape: Tuple[int, ...]):
+def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int,
+                 idxs_shape: tuple[int, ...]):
   num_idxs_to_left = sum(indexed_dims[:ref_dim])
   return ref_dim - num_idxs_to_left + len(idxs_shape)
@@ -15,7 +15,7 @@
 from __future__ import annotations
 
 import math
-from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
+from typing import Any, Generic, Sequence, TypeVar, Union
 
 from jax._src import core
 from jax._src import effects
@@ -150,12 +150,12 @@ core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)
 
 def get_ref_state_effects(
     avals: Sequence[core.AbstractValue],
-    effects: core.Effects) -> List[Set[StateEffect]]:
+    effects: core.Effects) -> list[set[StateEffect]]:
   return [{eff for eff in effects
            if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
            and eff.input_index == i} for i, _ in enumerate(avals)]
 
-def shaped_array_ref(shape: Tuple[int, ...], dtype,
+def shaped_array_ref(shape: tuple[int, ...], dtype,
                      weak_type: bool = False,
                      named_shape = None) -> AbstractRef[core.AbstractValue]:
   return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,
@@ -22,7 +22,7 @@ import re
 import os
 import tempfile
 import textwrap
-from typing import Any, Callable, Dict, List, Generator, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Generator, Optional, Sequence, Union
 import unittest
 import warnings
 import zlib
@@ -1057,7 +1057,7 @@ def ignore_warning(**kw):
 
 # -------------------- Mesh parametrization helpers --------------------
 
-MeshSpec = List[Tuple[str, int]]
+MeshSpec = list[tuple[str, int]]
 
 @contextmanager
 def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
@@ -1199,7 +1199,7 @@ def strict_promotion_if_dtypes_match(dtypes):
     return jax.numpy_dtype_promotion('standard')
 
 _version_regex = re.compile(r"([0-9]+(?:\.[0-9]+)*)(?:(rc|dev).*)?")
-def _parse_version(v: str) -> Tuple[int, ...]:
+def _parse_version(v: str) -> tuple[int, ...]:
   m = _version_regex.match(v)
   if m is None:
     raise ValueError(f"Unable to parse version '{v}'")
@@ -1209,8 +1209,8 @@ def numpy_version():
   return _parse_version(np.__version__)
 
 def parameterized_filterable(*,
-    kwargs: Sequence[Dict[str, Any]],
-    testcase_name: Optional[Callable[[Dict[str, Any]], str]] = None,
+    kwargs: Sequence[dict[str, Any]],
+    testcase_name: Optional[Callable[[dict[str, Any]], str]] = None,
     one_containing: Optional[str] = None,
 ):
   """
@@ -1229,7 +1229,7 @@ def parameterized_filterable(*,
      only one `kwargs` whose `testcase_name` includes `one_containing`.
   """
   # Ensure that all kwargs contain a testcase_name
-  kwargs_with_testcase_name: Sequence[Dict[str, Any]]
+  kwargs_with_testcase_name: Sequence[dict[str, Any]]
   if testcase_name is not None:
     kwargs_with_testcase_name = [dict(testcase_name=testcase_name(kw), **kw)
                                  for kw in kwargs]
jax/_src/third_party/scipy/linalg.py (vendored, 6 changes)
@@ -1,4 +1,4 @@
-from typing import Callable, Tuple
+from typing import Callable
 
 import scipy.linalg
 
@@ -11,7 +11,7 @@ from jax._src.typing import ArrayLike, Array
 
 
 @jit
-def _algorithm_11_1_1(F: Array, T: Array) -> Tuple[Array, Array]:
+def _algorithm_11_1_1(F: Array, T: Array) -> tuple[Array, Array]:
   # Algorithm 11.1.1 from Golub and Van Loan "Matrix Computations"
   N = T.shape[0]
   minden = jnp.abs(T[0, 0])
@@ -50,7 +50,7 @@ will be printed if the error in the array output is estimated to be large.
 """
 
 @_wraps(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION)
-def funm(A: ArrayLike, func: Callable[[Array], Array], disp: bool = True) -> Tuple[Array, Array]:
+def funm(A: ArrayLike, func: Callable[[Array], Array], disp: bool = True) -> tuple[Array, Array]:
   A = jnp.asarray(A)
   if A.ndim != 2 or A.shape[0] != A.shape[1]:
     raise ValueError('expected square array_like input')
jax/_src/third_party/scipy/signal_helper.py (vendored, 6 changes)
@@ -1,15 +1,15 @@
 """Utility functions adopted from scipy.signal."""
 
 import scipy.signal as osp_signal
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Optional, Union
 import warnings
 
 import jax.numpy as jnp
 from jax._src.typing import Array, ArrayLike, DTypeLike
 
 
-def _triage_segments(window: Union[ArrayLike, str, Tuple[Any, ...]], nperseg: Optional[int],
-                     input_length: int, dtype: DTypeLike) -> Tuple[Array, int]:
+def _triage_segments(window: Union[ArrayLike, str, tuple[Any, ...]], nperseg: Optional[int],
+                     input_length: int, dtype: DTypeLike) -> tuple[Array, int]:
   """
   Parses window and nperseg arguments for spectrogram and _spectral_helper.
   This is a helper function, not meant to be called externally.
@@ -16,7 +16,7 @@ import functools
 import os
 import traceback
 import types
-from typing import Any, Callable, List, Optional, TypeVar, cast
+from typing import Any, Callable, Optional, TypeVar, cast
 
 from jax._src.config import config
 from jax._src.lib import xla_extension
@@ -25,7 +25,7 @@ from jax._src import util
 
 C = TypeVar("C", bound=Callable[..., Any])
 
-_exclude_paths: List[str] = [__file__, util.__file__]
+_exclude_paths: list[str] = [__file__, util.__file__]
 
 def register_exclusion(path: str):
   _exclude_paths.append(path)
@@ -20,8 +20,8 @@ import functools
 from functools import partial
 import operator as op
 import textwrap
-from typing import (Any, Callable, Hashable, Iterable, List, NamedTuple,
-                    Optional, Tuple, Type, TypeVar, Union, overload)
+from typing import (Any, Callable, Hashable, Iterable, NamedTuple,
+                    Optional, TypeVar, Union, overload)
 import warnings
 
 from jax._src import traceback_util
@@ -33,7 +33,7 @@ from jax._src.util import unzip2
 traceback_util.register_exclusion(__file__)
 
 T = TypeVar("T")
-U = TypeVar("U", bound=Type[Any])
+U = TypeVar("U", bound=type[Any])
 
 Leaf = Any
 PyTreeDef = pytree.PyTreeDef
@@ -41,7 +41,7 @@ PyTreeDef = pytree.PyTreeDef
 
 def tree_flatten(tree: Any,
                  is_leaf: Optional[Callable[[Any], bool]] = None
-                 ) -> Tuple[List[Leaf], PyTreeDef]:
+                 ) -> tuple[list[Leaf], PyTreeDef]:
   """Flattens a pytree.
 
   The flattening order (i.e. the order of elements in the output list)
@@ -79,7 +79,7 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
 
 def tree_leaves(tree: Any,
                 is_leaf: Optional[Callable[[Any], bool]] = None
-                ) -> List[Leaf]:
+                ) -> list[Leaf]:
   """Gets the leaves of a pytree."""
   return pytree.flatten(tree, is_leaf)[0]
 
@@ -92,7 +92,7 @@ def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
   """Makes a tuple treedef from an iterable of child treedefs."""
   return pytree.tuple(list(treedefs))
 
-def treedef_children(treedef: PyTreeDef) -> List[PyTreeDef]:
+def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
   return treedef.children()
 
 def treedef_is_leaf(treedef: PyTreeDef) -> bool:
@@ -129,8 +129,8 @@ def all_leaves(iterable: Iterable[Any],
 _Children = TypeVar("_Children", bound=Iterable[Any])
 _AuxData = TypeVar("_AuxData", bound=Hashable)
 
-def register_pytree_node(nodetype: Type[T],
-                         flatten_func: Callable[[T], Tuple[_Children, _AuxData]],
+def register_pytree_node(nodetype: type[T],
+                         flatten_func: Callable[[T], tuple[_Children, _AuxData]],
                          unflatten_func: Callable[[_AuxData, _Children], T]):
   """Extends the set of types that are considered internal nodes in pytrees.
 
@@ -393,7 +393,7 @@ register_pytree_node(
 
 def broadcast_prefix(prefix_tree: Any, full_tree: Any,
                      is_leaf: Optional[Callable[[Any], bool]] = None
-                     ) -> List[Any]:
+                     ) -> list[Any]:
   # If prefix_tree is not a tree prefix of full_tree, this code can raise a
   # ValueError; use prefix_errors to find disagreements and raise more precise
   # error messages.
@@ -403,7 +403,7 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
   tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
   return result
 
-def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]:
+def flatten_one_level(pytree: Any) -> tuple[list[Any], Hashable]:
   """Flatten the given pytree node by one level.
 
   Args:
@@ -429,12 +429,12 @@ def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]:
 
 def prefix_errors(prefix_tree: Any, full_tree: Any,
                   is_leaf: Optional[Callable[[Any], bool]] = None,
-                  ) -> List[Callable[[str], ValueError]]:
+                  ) -> list[Callable[[str], ValueError]]:
   return list(_prefix_error((), prefix_tree, full_tree, is_leaf))
 
 def equality_errors(
     tree1: Any, tree2: Any, is_leaf: Optional[Callable[[Any], bool]] = None,
-) -> Iterable[Tuple[KeyPath, str, str, str]]:
+) -> Iterable[tuple[KeyPath, str, str, str]]:
   """Helper to describe structural differences between two pytrees.
 
   Args:
@@ -562,7 +562,7 @@ class FlattenedIndexKey():
 BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]
 
 KeyEntry = TypeVar("KeyEntry", bound=Hashable)
-KeyPath = Tuple[KeyEntry, ...]
+KeyPath = tuple[KeyEntry, ...]
 
 def keystr(keys: KeyPath):
   """Helper to pretty-print a tuple of keys.
@@ -582,7 +582,7 @@ class _RegistryWithKeypathsEntry(NamedTuple):
 
 
 def register_keypaths(
-    ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]]
+    ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]]
 ) -> None:
   """[Deprecated] Register the method to get keypaths for type.
 
@@ -603,7 +603,7 @@ def register_keypaths(
 
 
 def _register_keypaths(
-    ty: Type[T], handler: Callable[[T], Tuple[KeyEntry, ...]]
+    ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]]
 ) -> None:
   def flatten_with_keys(xs):
     children, treedef = _registry[ty].to_iter(xs)
@@ -634,13 +634,13 @@ _register_keypaths(
 
 
 def register_pytree_with_keys(
-    nodetype: Type[T],
+    nodetype: type[T],
     flatten_with_keys: Callable[
-        [T], Tuple[Iterable[Tuple[KeyEntry, Any]], _AuxData]
+        [T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData]
     ],
     unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
     flatten_func: Optional[
-        Callable[[T], Tuple[Iterable[Any], _AuxData]]
+        Callable[[T], tuple[Iterable[Any], _AuxData]]
     ] = None,
 ):
   """Extends the set of types that are considered internal nodes in pytrees.
 
@@ -708,7 +708,7 @@ def register_pytree_with_keys_class(cls: U) -> U:
 
 def tree_flatten_with_path(
     tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
-) -> Tuple[List[Tuple[KeyPath, Any]], PyTreeDef]:
+) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
   """Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
 
   Args:
@@ -725,7 +725,7 @@ def tree_flatten_with_path(
 
 def tree_leaves_with_path(
     tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
-) -> List[Tuple[KeyPath, Any]]:
+) -> list[tuple[KeyPath, Any]]:
   """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
 
   Args:
@@ -739,7 +739,7 @@ def tree_leaves_with_path(
 
 def generate_key_paths(
     tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
-) -> List[Tuple[KeyPath, Any]]:
+) -> list[tuple[KeyPath, Any]]:
   return list(_generate_key_paths_((), tree, is_leaf))
 _generate_key_paths = generate_key_paths  # alias for backward compat
 
@@ -749,7 +749,7 @@ def _generate_key_paths_(
     key_path: KeyPath,
     tree: Any,
     is_leaf: Optional[Callable[[Any], bool]] = None,
-) -> Iterable[Tuple[KeyPath, Any]]:
+) -> Iterable[tuple[KeyPath, Any]]:
   if is_leaf and is_leaf(tree):
     yield key_path, tree
     return
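The register_pytree_node signature above now spells its flatten callback as Callable[[T], tuple[_Children, _AuxData]]. A self-contained sketch of registering a custom container against that signature (Point is a hypothetical class, not part of JAX):

    from typing import Any
    import jax

    class Point:
      def __init__(self, x, y):
        self.x, self.y = x, y

    def flatten_point(p: Point) -> tuple[list[Any], None]:
      return [p.x, p.y], None  # (children, aux_data)

    def unflatten_point(aux: None, children: list[Any]) -> Point:
      return Point(*children)

    jax.tree_util.register_pytree_node(Point, flatten_point, unflatten_point)
    print(jax.tree_util.tree_leaves(Point(1.0, 2.0)))  # [1.0, 2.0]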
@@ -17,9 +17,8 @@ from functools import partial
 import itertools as it
 import logging
 import operator
-from typing import (Any, Callable, Generic, Iterable, Iterator, List,
-                    Optional, Sequence, Set, Tuple, TypeVar, overload,
-                    TYPE_CHECKING, cast)
+from typing import (Any, Callable, Generic, Iterable, Iterator, Optional,
+                    Sequence, TypeVar, overload, TYPE_CHECKING, cast)
 import weakref
 
 import numpy as np
@@ -43,13 +42,13 @@ if TYPE_CHECKING:
   # to that used for builtins.zip in python/typeshed. This supports
   # return types matching input types for up to three arguments.
   @overload
-  def safe_zip(__arg1: Iterable[T1]) -> List[Tuple[T1]]: ...
+  def safe_zip(__arg1: Iterable[T1]) -> list[tuple[T1]]: ...
   @overload
-  def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[Tuple[T1, T2]]: ...
+  def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[tuple[T1, T2]]: ...
   @overload
-  def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[Tuple[T1, T2, T3]]: ...
+  def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[tuple[T1, T2, T3]]: ...
   @overload
-  def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[Tuple[Any, ...]]: ...
+  def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[tuple[Any, ...]]: ...
 
   def safe_zip(*args):
     args = list(map(list, args))
@@ -77,16 +76,16 @@ if TYPE_CHECKING:
   # to that used for builtins.map in python/typeshed. This supports
   # checking input types for the callable with up to three arguments.
   @overload
-  def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ...
+  def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> list[T]: ...
 
   @overload
-  def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[T]: ...
+  def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[T]: ...
 
   @overload
-  def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[T]: ...
+  def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[T]: ...
 
   @overload
-  def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[T]: ...
+  def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[T]: ...
 
   def safe_map(f, *args):
     args = list(map(list, args))
@@ -108,26 +107,26 @@ else:
       assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
     return list(map(f, *args))
 
-def unzip2(xys: Iterable[Tuple[T1, T2]]
-    ) -> Tuple[Tuple[T1, ...], Tuple[T2, ...]]:
+def unzip2(xys: Iterable[tuple[T1, T2]]
+    ) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
   """Unzip sequence of length-2 tuples into two tuples."""
   # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
   # is too permissive about inputs, and does not guarantee a length-2 output.
-  xs: List[T1] = []
-  ys: List[T2] = []
+  xs: list[T1] = []
+  ys: list[T2] = []
   for x, y in xys:
     xs.append(x)
     ys.append(y)
   return tuple(xs), tuple(ys)
 
-def unzip3(xyzs: Iterable[Tuple[T1, T2, T3]]
-    ) -> Tuple[Tuple[T1, ...], Tuple[T2, ...], Tuple[T3, ...]]:
+def unzip3(xyzs: Iterable[tuple[T1, T2, T3]]
+    ) -> tuple[tuple[T1, ...], tuple[T2, ...], tuple[T3, ...]]:
   """Unzip sequence of length-3 tuples into three tuples."""
   # Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated,
   # is too permissive about inputs, and does not guarantee a length-3 output.
-  xs: List[T1] = []
-  ys: List[T2] = []
-  zs: List[T3] = []
+  xs: list[T1] = []
+  ys: list[T2] = []
+  zs: list[T3] = []
   for x, y, z in xyzs:
     xs.append(x)
     ys.append(y)
@@ -140,7 +139,7 @@ def subvals(lst, replace):
     lst[i] = v
   return tuple(lst)
 
-def split_list(args: Sequence[T], ns: Sequence[int]) -> List[List[T]]:
+def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
   args = list(args)
   lists = []
   for n in ns:
@@ -149,14 +148,14 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> List[List[T]]:
   lists.append(args)
   return lists
 
-def partition_list(bs: Sequence[bool], l: Sequence[T]) -> Tuple[List[T], List[T]]:
+def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]:
   assert len(bs) == len(l)
   lists = [], []  # type: ignore
   for b, x in zip(bs, l):
     lists[b].append(x)
   return lists
 
-def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> List[T]:
+def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]:
   assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
   i0, i1 = iter(l0), iter(l1)
   out = [next(i1) if b else next(i0) for b in bs]
@@ -170,7 +169,7 @@ def split_dict(dct, names):
   assert not dct
   return lst
 
-def concatenate(xs: Iterable[Sequence[T]]) -> List[T]:
+def concatenate(xs: Iterable[Sequence[T]]) -> list[T]:
   """Concatenates/flattens a list of lists."""
   return list(it.chain.from_iterable(xs))
 
@@ -178,7 +177,7 @@ flatten = concatenate
 
 _unflatten_done = object()
 
-def unflatten(xs: Iterable[T], ns: Sequence[int]) -> List[List[T]]:
+def unflatten(xs: Iterable[T], ns: Sequence[int]) -> list[list[T]]:
   """Splits `xs` into subsequences of lengths `ns`.
 
   Unlike `split_list`, the `sum(ns)` must be equal to `len(xs)`."""
@@ -499,8 +498,8 @@ def distributed_debug_log(*pairs):
 
 
 class OrderedSet(Generic[T]):
-  elts_set: Set[T]
-  elts_list: List[T]
+  elts_set: set[T]
+  elts_list: list[T]
 
   def __init__(self):
     self.elts_set = set()
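The safe_zip and safe_map hunks above keep the runtime implementation untyped and hide the precise overloads behind TYPE_CHECKING, so the list[tuple[...]] return types are only ever seen by the type checker. A condensed sketch of that pattern (not the module's exact code):

    from typing import TYPE_CHECKING, Iterable, TypeVar, overload

    T1 = TypeVar("T1")
    T2 = TypeVar("T2")

    if TYPE_CHECKING:
      @overload
      def safe_zip(a: Iterable[T1]) -> list[tuple[T1]]: ...
      @overload
      def safe_zip(a: Iterable[T1], b: Iterable[T2]) -> list[tuple[T1, T2]]: ...
      def safe_zip(*args): ...
    else:
      def safe_zip(*args):
        args = list(map(list, args))
        n = len(args[0])
        for arg in args[1:]:
          assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
        return list(zip(*args))

    print(safe_zip([1, 2], ['a', 'b']))  # [(1, 'a'), (2, 'b')]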
@@ -30,7 +30,7 @@ import platform as py_platform
 import pkgutil
 import sys
 import threading
-from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Union
+from typing import Any, Callable, Mapping, Optional, Union
 import warnings
 
 import numpy as np
@@ -109,7 +109,7 @@ def get_compile_options(
     use_auto_spmd_partitioning: bool = False,
     auto_spmd_partitioning_mesh_shape=[],
     auto_spmd_partitioning_mesh_ids=[],
-    env_options_overrides: Optional[Dict[str, str]] = None,
+    env_options_overrides: Optional[dict[str, str]] = None,
 ) -> xla_client.CompileOptions:
   """Returns the compile options to use, as derived from flag values.
 
@@ -231,10 +231,10 @@ class BackendRegistration:
   # a buggy plugin.
   experimental: bool = False
 
-_backend_factories: Dict[str, BackendRegistration] = {}
+_backend_factories: dict[str, BackendRegistration] = {}
 _default_backend: Optional[xla_client.Client] = None
-_backends : Dict[str, xla_client.Client] = {}
-_backends_errors : Dict[str, str] = {}
+_backends : dict[str, xla_client.Client] = {}
+_backends_errors : dict[str, str] = {}
 _backend_lock = threading.Lock()
 
 # The set of known non-experimental plugins.
@@ -245,7 +245,7 @@ _backend_lock = threading.Lock()
 # It is fine for a plugin not to implement every feature that JAX uses, provided
 # that a reasonable feature set is implemented and the plugin fails gracefully
 # for unimplemented features. Wrong outputs are not acceptable.
-_nonexperimental_plugins: Set[str] = set()
+_nonexperimental_plugins: set[str] = set()
 
 def register_backend_factory(name: str, factory: BackendFactory, *,
                              priority: int = 0,
@@ -314,7 +314,7 @@ if hasattr(xla_client, "make_tpu_client"):
 
 def _get_pjrt_plugin_names_and_library_paths(
     plugins_from_env: str,
-) -> Dict[str, str]:
+) -> dict[str, str]:
   """Gets the names and library paths of PJRT plugins to load from env var.
 
   Args:
@@ -343,7 +343,7 @@ def _get_pjrt_plugin_names_and_library_paths(
 
 def _get_pjrt_plugin_config(
     json_path: str,
-) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]:
+) -> tuple[str, Optional[Mapping[str, Union[str, int, list[int], float]]]]:
   """Gets PJRT plugin configuration from a json file.
 
   The json file needs to have a "library_path" field for the plugin library
@@ -440,7 +440,7 @@ def register_plugin(
     *,
     priority: int = 400,
     library_path: Optional[str] = None,
-    options: Optional[Mapping[str, Union[str, int, List[int], float]]] = None,
+    options: Optional[Mapping[str, Union[str, int, list[int], float]]] = None,
 ) -> None:
   """Registers a backend factory for the PJRT plugin.
 
@@ -516,7 +516,7 @@ _platform_aliases = {
   "rocm": "gpu",
 }
 
-_alias_to_platforms: Dict[str, List[str]] = {}
+_alias_to_platforms: dict[str, list[str]] = {}
 for _platform, _alias in _platform_aliases.items():
   _alias_to_platforms.setdefault(_alias, []).append(_platform)
 
@@ -550,7 +550,7 @@ def canonicalize_platform(platform: str) -> str:
                      "Platforms present are: " + ",".join(b.keys()))
 
 
-def expand_platform_alias(platform: str) -> List[str]:
+def expand_platform_alias(platform: str) -> list[str]:
   """Expands, e.g., "gpu" to ["cuda", "rocm"].
 
   This is used for convenience reasons: we expect cuda and rocm to act similarly
@@ -561,7 +561,7 @@ def expand_platform_alias(platform: str) -> List[str]:
 def is_gpu(platform):
   return platform in ("cuda", "rocm")
 
-def backends() -> Dict[str, xla_client.Client]:
+def backends() -> dict[str, xla_client.Client]:
   global _backends
   global _backends_errors
   global _default_backend
@@ -732,7 +732,7 @@ def local_device_count(
 
 def devices(
     backend: Optional[Union[str, xla_client.Client]] = None
-) -> List[xla_client.Device]:
+) -> list[xla_client.Device]:
   """Returns a list of all devices for a given backend.
 
   .. currentmodule:: jaxlib.xla_extension
@@ -765,7 +765,7 @@ def default_backend() -> str:
 @lru_cache
 def local_devices(process_index: Optional[int] = None,
                   backend: Optional[Union[str, xla_client.Client]] = None,
-                  host_id: Optional[int] = None) -> List[xla_client.Device]:
+                  host_id: Optional[int] = None) -> list[xla_client.Device]:
   """Like :py:func:`jax.devices`, but only returns devices local to a given process.
 
   If ``process_index`` is ``None``, returns devices local to this process.
@@ -839,7 +839,7 @@ def host_count(backend: Optional[Union[str, xla_client.Client]] = None) -> int:
 # TODO: remove this sometime after jax 0.2.13 is released
 def host_ids(
     backend: Optional[Union[str, xla_client.Client]] = None
-) -> List[int]:
+) -> list[int]:
   warnings.warn(
       "jax.host_ids has been deprecated; please use range(jax.process_count()) "
       "instead. jax.host_ids will eventually be removed; please update your "
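The platform-alias hunks above build a reverse map from an alias ("gpu") to its concrete platforms. A hedged, self-contained sketch of that lookup; the map construction mirrors the diff, while the .get fallback in expand_platform_alias is an assumption about the elided function body:

    _platform_aliases = {
      "cuda": "gpu",
      "rocm": "gpu",
    }

    _alias_to_platforms: dict[str, list[str]] = {}
    for _platform, _alias in _platform_aliases.items():
      _alias_to_platforms.setdefault(_alias, []).append(_platform)

    def expand_platform_alias(platform: str) -> list[str]:
      # Assumed behavior: aliases expand, concrete names map to themselves.
      return _alias_to_platforms.get(platform, [platform])

    print(expand_platform_alias("gpu"))  # ['cuda', 'rocm']
    print(expand_platform_alias("cpu"))  # ['cpu']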
@@ -89,7 +89,7 @@ Example Usage:
 .. _Optax: https://github.com/deepmind/optax
 """
 
-from typing import Any, Callable, NamedTuple, Tuple, Union
+from typing import Any, Callable, NamedTuple, Union
 
 from collections import namedtuple
 import functools
@@ -138,7 +138,7 @@ class Optimizer(NamedTuple):
 Schedule = Callable[[Step], float]
 
 def optimizer(opt_maker: Callable[...,
-  Tuple[Callable[[Params], State],
+  tuple[Callable[[Params], State],
         Callable[[Step, Updates, Params], Params],
         Callable[[State], Params]]]) -> Callable[..., Optimizer]:
   """Decorator to make an optimizer defined for arrays generalize to containers.
@@ -22,7 +22,7 @@ import os
 import re
 import time
 import threading
-from typing import Awaitable, Any, Callable, Dict, Optional, Sequence, Union
+from typing import Awaitable, Any, Callable, Optional, Sequence, Union
 
 import jax
 from jax._src import distributed
@@ -236,7 +236,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore,
 
 async def async_deserialize(
     in_sharding: sharding_impls.XLACompatibleSharding,
-    tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
+    tensorstore_spec: Union[ts.Spec, dict[str, Any]],
     global_shape: Optional[Sequence[int]] = None,
     dtype=None,
     byte_limiter: Optional[_LimitInFlightBytes] = None,
@@ -288,7 +288,7 @@ async def async_deserialize(
 
 
 def run_deserialization(shardings: Sequence[sharding.Sharding],
-                        tensorstore_specs: Sequence[Dict[str, Any]],
+                        tensorstore_specs: Sequence[dict[str, Any]],
                         global_shapes: Optional[Sequence[array.Shape]] = None,
                         dtypes: Optional[Sequence[typing.DTypeLike]] = None,
                         concurrent_gb: int = 32):
@@ -370,7 +370,7 @@ class GlobalAsyncCheckpointManagerBase(metaclass=abc.ABCMeta):
 
   @abc.abstractmethod
   def deserialize(self, shardings: Sequence[sharding.Sharding],
-                  tensorstore_specs: Sequence[Dict[str, Any]],
+                  tensorstore_specs: Sequence[dict[str, Any]],
                   global_shapes: Optional[Sequence[array.Shape]] = None,
                   dtypes: Optional[Sequence[typing.DTypeLike]] = None):
     """Deserializes GDAs from TensorStore."""
@@ -519,7 +519,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
     self._start_async_commit(on_commit_callback)
 
   def deserialize(self, shardings: Sequence[sharding.Sharding],
-                  tensorstore_specs: Sequence[Dict[str, Any]],
+                  tensorstore_specs: Sequence[dict[str, Any]],
                   global_shapes: Optional[Sequence[array.Shape]] = None,
                   dtypes: Optional[Sequence[typing.DTypeLike]] = None,
                   concurrent_gb: int = 32):
Some files were not shown because too many files have changed in this diff.