Upgrade most .py sources to 3.9

This commit was generated by running:

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
Sergei Lebedev 2023-12-08 12:09:04 +00:00
parent 7af1c149f5
commit 36f6b52e42
36 changed files with 198 additions and 194 deletions
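
For a concrete picture of the rewrites below, here is a hypothetical before/after sketch (not a file from this diff) of the three changes that dominate the commit: PEP 585 builtin generics (Tuple[...] to tuple[...]), PEP 604 unions (Optional[X] to X | None), and abstract base classes such as Sequence imported from collections.abc instead of typing. The X | Y union syntax in annotations is only native from Python 3.10, so on 3.9 it relies on from __future__ import annotations; files that cannot assume it keep Optional and Union, as some hunks below do.

A hypothetical module before the upgrade:

    from __future__ import annotations
    from typing import Dict, Optional, Sequence, Tuple

    def lookup(table: Dict[str, Tuple[int, ...]],
               keys: Sequence[str],
               default: Optional[int] = None) -> Optional[Tuple[int, ...]]:
        """Return the value for the first key present in table, else default."""
        for key in keys:
            if key in table:
                return table[key]
        return default

and the same module after the upgrade, mirroring the pattern of the hunks below:

    from __future__ import annotations
    from collections.abc import Sequence

    def lookup(table: dict[str, tuple[int, ...]],
               keys: Sequence[str],
               default: int | None = None) -> tuple[int, ...] | None:
        """Return the value for the first key present in table, else default."""
        for key in keys:
            if key in table:
                return table[key]
        return default

Runtime behaviour is unchanged: the commit only touches annotations, imports, and a few equivalent spellings (defaultdict(lambda: 0) to defaultdict(int), super(Cls, cls) to super(), a redundant "r" mode dropped from open()).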

View File

@ -29,7 +29,7 @@ import inspect
import math
import typing
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
cast, Optional)
cast)
import weakref
import numpy as np
@ -2461,7 +2461,7 @@ def make_jaxpr(fun: Callable,
make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
return make_jaxpr_f
def _infer_src_sharding(src, x) -> Optional[Sharding]:
def _infer_src_sharding(src, x) -> Sharding | None:
if src is not None:
return src
if isinstance(x, array.ArrayImpl):

View File

@ -16,7 +16,8 @@
import abc
import numpy as np
from typing import Any, Sequence, Union
from typing import Any, Union
from collections.abc import Sequence
# TODO(jakevdp): fix import cycles and define these.
Shard = Any

View File

@ -22,7 +22,7 @@ import logging
import os
import sys
import threading
from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar
import warnings
from jax._src import lib
@ -134,7 +134,7 @@ class Config:
raise AttributeError(f"Unrecognized config option: {name}")
def add_option(self, name, default, opt_type, meta_args, meta_kwargs,
update_hook: Optional[Callable[[Any], None]] = None):
update_hook: Callable[[Any], None] | None = None):
if name in self.values:
raise Exception(f"Config option {name} already defined")
self.values[name] = default
@ -238,7 +238,7 @@ _thread_local_state = threading.local()
class _StateContextManager(Generic[_T]):
def __init__(self, name, help, update_thread_local_hook,
validate_new_val_hook: Optional[Callable[[Any], None]] = None,
validate_new_val_hook: Callable[[Any], None] | None = None,
extra_description: str = "", default_value: Any = no_default):
self._name = name
self.__name__ = name[4:] if name.startswith('jax_') else name
@ -302,8 +302,8 @@ def define_bool_state(
default: bool,
help: str,
*,
update_global_hook: Optional[Callable[[bool], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None,
update_global_hook: Callable[[bool], None] | None = None,
update_thread_local_hook: Callable[[bool | None], None] | None = None,
upgrade: bool = False,
extra_description: str = '',
) -> _StateContextManager[bool]:
@ -375,11 +375,11 @@ def define_bool_state(
def define_enum_state(
name: str,
enum_values: list[str],
default: Optional[str],
default: str | None,
help: str,
*,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
"""Set up thread-local state and return a contextmanager for managing it.
@ -420,11 +420,11 @@ def define_enum_state(
def define_int_state(
name: str,
default: Optional[int],
default: int | None,
help: str,
*,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[int]:
"""Set up thread-local state and return a contextmanager for managing it.
@ -463,11 +463,11 @@ def define_int_state(
def define_float_state(
name: str,
default: Optional[float],
default: float | None,
help: str,
*,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[float]:
"""Set up thread-local state and return a contextmanager for managing it.
@ -508,11 +508,11 @@ def define_float_state(
def define_string_state(
name: str,
default: Optional[str],
default: str | None,
help: str,
*,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
"""Set up thread-local state and return a contextmanager for managing it.
@ -552,9 +552,9 @@ def define_string_or_object_state(
default: Any,
help: str,
*,
update_global_hook: Optional[Callable[[Any], None]] = None,
update_thread_local_hook: Optional[Callable[[Any], None]] = None,
validate_new_val_hook: Optional[Callable[[Any], None]] = None,
update_global_hook: Callable[[Any], None] | None = None,
update_thread_local_hook: Callable[[Any], None] | None = None,
validate_new_val_hook: Callable[[Any], None] | None = None,
) -> _StateContextManager[Any]:
"""Set up thread-local state and return a contextmanager for managing it.
@ -651,9 +651,9 @@ already_configured_with_absl = False
# a global/thread-local state. These methods allow updates to part of the
# state when a configuration value changes.
class _GlobalExtraJitContext(NamedTuple):
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
numpy_rank_promotion: str | None = None
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False
@ -675,12 +675,12 @@ class _ThreadLocalExtraJitContext(NamedTuple):
The initialization, which uses both config.py and core.py is done using
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
dynamic_trace_state: Optional[Any] = None
dynamic_trace_state: Any | None = None
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
numpy_rank_promotion: str | None = None
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False
@ -1320,7 +1320,7 @@ def transfer_guard(new_val: str) -> Iterator[None]:
yield
def _update_debug_log_modules(module_names_str: Optional[str]):
def _update_debug_log_modules(module_names_str: str | None):
logging_config.disable_all_debug_logging()
if not module_names_str:
return

View File

@ -24,7 +24,7 @@ from __future__ import annotations
import abc
import builtins
import functools
from typing import cast, overload, Any, Literal, Optional, Union
from typing import cast, overload, Any, Literal, Union
import warnings
import ml_dtypes
@ -207,7 +207,7 @@ def to_complex_dtype(dtype: DTypeLike) -> DType:
@functools.cache
def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> Union[DType, ExtendedDType]:
def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> DType | ExtendedDType:
if issubdtype(dtype, extended):
if not allow_extended_dtype:
raise ValueError(f"Internal: canonicalize_dtype called on extended dtype {dtype} "
@ -227,10 +227,10 @@ def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: An
def canonicalize_dtype(dtype: Any, allow_extended_dtype: Literal[False] = False) -> DType: ...
@overload
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> Union[DType, ExtendedDType]: ...
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType: ...
@export
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> Union[DType, ExtendedDType]:
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType:
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype) # pytype: disable=bad-return-type
@ -292,7 +292,7 @@ def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType:
return dtype
def coerce_to_array(x: Any, dtype: Optional[DTypeLike] = None) -> np.ndarray:
def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray:
"""Coerces a scalar or NumPy array to an np.array.
Handles Python scalar type promotion according to JAX's rules, not NumPy's
@ -643,10 +643,10 @@ def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> tuple[DType
def result_type(*args: Any, return_weak_type_flag: Literal[False] = False) -> DType: ...
@overload
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]: ...
def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]: ...
@export
def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]:
def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]:
"""Convenience function to apply JAX argument dtype promotion.
Args:

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Hashable
from typing import Callable
from collections.abc import Hashable
from jax import Array

View File

@ -25,7 +25,7 @@ import itertools
import operator
import re
import typing
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
from typing import Any, Callable, NamedTuple, Protocol, Union
import warnings
import numpy as np
@ -707,7 +707,7 @@ def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str |
return layout._to_xla_layout()
def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None:
if s is None:
return None
assert isinstance(s, sharding_impls.XLACompatibleSharding)
@ -1454,7 +1454,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_override_lowering_rule(eqn.primitive)
platform_rules: dict[str, LoweringRule] = {}
default_rule: Optional[LoweringRule] = None
default_rule: LoweringRule | None = None
# See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule`
if override_rule is not None:
default_rule = override_rule
@ -1525,7 +1525,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
def lower_per_platform(ctx: LoweringRuleContext,
description: str,
platform_rules: dict[str, LoweringRule],
default_rule: Optional[LoweringRule],
default_rule: LoweringRule | None,
effects: effects_lib.Effects,
*rule_args: ir.Value,
**rule_kwargs) -> ir.Value:
@ -1710,7 +1710,7 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
return func_op
def check_backend_matches(inner_backend: Optional[str],
def check_backend_matches(inner_backend: str | None,
lowering_platforms: Sequence[str]):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.

View File

@ -25,7 +25,8 @@ import itertools as it
import logging
import math
import threading
from typing import (Any, Callable, NamedTuple, Iterator, Optional, Union, cast, TypeVar)
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
from collections.abc import Iterator
import warnings
import numpy as np
@ -2914,8 +2915,8 @@ def _compile_replicated_mesh_executable_from_hlo(
@lru_cache
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,
memory_kind: Optional[str] = None) -> sharding_impls.NamedSharding:
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
memory_kind: str | None = None) -> sharding_impls.NamedSharding:
if pspec is None:
pspec, parsed_pspec = PartitionSpec(), None
return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,

View File

@ -202,7 +202,7 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes:
pprof tool for visualization.
"""
d: DefaultDict[tuple[Optional[xla_client.Traceback], core.Primitive], int]
d = collections.defaultdict(lambda: 0)
d = collections.defaultdict(int)
for _, eqn in all_eqns(jaxpr):
d[(eqn.source_info.traceback, eqn.primitive)] += 1
return _pprof_profile(d)

View File

@ -19,7 +19,7 @@ from functools import partial
import inspect
import itertools
import operator
from typing import Any, Callable, Optional, TypeVar
from typing import Any, Callable, TypeVar
import jax
import weakref
@ -104,7 +104,7 @@ Y = TypeVar('Y')
def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
length: Optional[int] = None,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1) -> tuple[Carry, Y]:
"""Scan a function over leading array axes while carrying along state.

View File

@ -22,7 +22,7 @@ from functools import partial
import itertools
import math
import operator
from typing import (Any, Callable, Optional, TypeVar, Union,
from typing import (Any, Callable, TypeVar, Union,
cast as type_cast, overload)
import warnings
@ -104,7 +104,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
map(_check_static_shape, shapes)
def _try_broadcast_shapes(
shapes: Sequence[tuple[int, ...]]) -> Optional[tuple[int, ...]]:
shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...] | None:
if len(shapes) == 1: return shapes[0]
ranks = {len(shape) for shape in shapes}
if len(ranks) > 1: return None # must have consistent rank
@ -140,8 +140,8 @@ def asarray(x: ArrayLike) -> Array:
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...
@overload
def broadcast_shapes(*shapes: tuple[Union[int, core.Tracer], ...]
) -> tuple[Union[int, core.Tracer], ...]: ...
def broadcast_shapes(*shapes: tuple[int | core.Tracer, ...]
) -> tuple[int | core.Tracer, ...]: ...
def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
@ -183,8 +183,8 @@ def _broadcast_ranks(s1, s2):
def _identity(x): return x
def _extract_tracers_dyn_shape(
shape: Sequence[Union[int, core.Tracer]]
) -> tuple[list[core.Tracer], list[Optional[int]]]:
shape: Sequence[int | core.Tracer]
) -> tuple[list[core.Tracer], list[int | None]]:
# Given a sequence representing a shape, pull out Tracers, replacing with None
if config.dynamic_shapes.value:
# We must gate this behavior under a flag because otherwise the errors
@ -196,9 +196,9 @@ def _extract_tracers_dyn_shape(
return [], list(shape) # type: ignore
def _merge_dyn_shape(
static_shape: Sequence[Optional[int]],
static_shape: Sequence[int | None],
dyn_shape: Sequence[Any],
) -> tuple[Union[int, mlir.Value, core.Tracer], ...]:
) -> tuple[int | mlir.Value | core.Tracer, ...]:
# Replace Nones in static_shape with elements of dyn_shape, in order
dyn_shape_it = iter(dyn_shape)
shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
@ -516,7 +516,7 @@ def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""
return _convert_element_type(operand, new_dtype, weak_type=False)
def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = None,
def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None,
weak_type: bool = False):
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__() # type: ignore
@ -599,7 +599,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
"""
return clamp_p.bind(min, x, max)
def concatenate(operands: Union[Array, Sequence[ArrayLike]], dimension: int) -> Array:
def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
"""Concatenates a sequence of arrays along `dimension`.
Wraps XLA's `Concatenate
@ -677,7 +677,7 @@ PrecisionLike = Union[None, str, PrecisionType, tuple[str, str],
tuple[PrecisionType, PrecisionType]]
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None) -> Array:
preferred_element_type: DTypeLike | None = None) -> Array:
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
Wraps XLA's `Dot
@ -714,7 +714,7 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None) -> Array:
preferred_element_type: DTypeLike | None = None) -> Array:
"""General dot product/contraction operator.
Wraps XLA's `DotGeneral
@ -813,7 +813,7 @@ def broadcast_to_rank(x: Array, rank: int) -> Array:
return broadcast(x, (1,) * (rank - x.ndim))
def reshape(operand: ArrayLike, new_sizes: Shape,
dimensions: Optional[Sequence[int]] = None) -> Array:
dimensions: Sequence[int] | None = None) -> Array:
"""Wraps XLA's `Reshape
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
operator.
@ -1042,7 +1042,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
return jaxpr, tuple(consts), out_tree()
def _get_monoid_reducer(monoid_op: Callable,
xs: Sequence[Array]) -> Optional[Callable]:
xs: Sequence[Array]) -> Callable | None:
if len(xs) != 1:
return None
x, = xs
@ -1128,8 +1128,8 @@ def sort(operand: Array, dimension: int = -1,
def sort(operand: Sequence[Array], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> tuple[Array, ...]: ...
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Union[Array, tuple[Array, ...]]:
def sort(operand: Array | Sequence[Array], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Array | tuple[Array, ...]:
"""Wraps XLA's `Sort
<https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.
@ -1196,7 +1196,7 @@ def tie_in(x: Any, y: T) -> T:
"""Deprecated. Ignores ``x`` and returns ``y``."""
return y
def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None) -> Array:
def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None) -> Array:
"""Returns an array of `shape` filled with `fill_value`.
Args:
@ -1303,7 +1303,7 @@ def stop_gradient(x: T) -> T:
return x
return tree_map(stop, x)
def reduce_precision(operand: Union[float, ArrayLike],
def reduce_precision(operand: float | ArrayLike,
exponent_bits: int,
mantissa_bits: int) -> Array:
"""Wraps XLA's `ReducePrecision
@ -1342,9 +1342,9 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:
### convenience wrappers around traceables
def full_like(x: Union[ArrayLike, DuckTypedArray],
fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
shape: Optional[Shape] = None) -> Array:
def full_like(x: ArrayLike | DuckTypedArray,
fill_value: ArrayLike, dtype: DTypeLike | None = None,
shape: Shape | None = None) -> Array:
"""Create a full array like np.full based on the example array `x`.
Args:
@ -1380,7 +1380,7 @@ def full_like(x: Union[ArrayLike, DuckTypedArray],
def collapse(operand: Array, start_dimension: int,
stop_dimension: Optional[int] = None) -> Array:
stop_dimension: int | None = None) -> Array:
"""Collapses dimensions of an array into a single dimension.
For example, if ``operand`` is an array with shape ``[2, 3, 4]``,
@ -1691,7 +1691,7 @@ def broadcast_hlo(
return out
def _nary_lower_hlo(op: Callable, ctx,
*args: Union[ir.Value, Sequence[ir.Value]],
*args: ir.Value | Sequence[ir.Value],
explicit_type=False, **params) -> Sequence[ir.Value]:
"""Lowers an elementwise operator to its MLIR equivalent.
@ -2516,7 +2516,7 @@ def _precision_config(precision):
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: Optional[DTypeLike]):
preferred_element_type: DTypeLike | None):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
for d in (lhs_contracting, lhs_batch)):
@ -2592,7 +2592,7 @@ def tuple_delete(tup, idx):
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: Optional[DTypeLike]):
preferred_element_type: DTypeLike | None):
# We're mostly matching XLA's logic here, namely in shape_inference.cc and
# primitive_util.h's HigherPrecisionType, e.g.
# https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329
@ -2636,7 +2636,7 @@ def _maybe_upcast(result_dtype, preferred_element_type):
return preferred_element_type
def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: Optional[DTypeLike],
preferred_element_type: DTypeLike | None,
swap_ans=False):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
x_ndim = x.aval.ndim
@ -2657,7 +2657,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
return x_bar
def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: Optional[DTypeLike]):
preferred_element_type: DTypeLike | None):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
y_bar = _dot_general_transpose_lhs(
@ -2670,7 +2670,7 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
precision,
preferred_element_type: Optional[DTypeLike]):
preferred_element_type: DTypeLike | None):
lhs, rhs = batched_args
lbd, rbd = batch_dims
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
@ -2794,7 +2794,7 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
precision, preferred_element_type: Optional[np.dtype],
precision, preferred_element_type: np.dtype | None,
platform: str = "default"):
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in
@ -4540,7 +4540,7 @@ def _array_copy(arr: ArrayLike) -> Array:
return copy_p.bind(arr)
def _which_dim_sharded(s: PmapSharding) -> Optional[int]:
def _which_dim_sharded(s: PmapSharding) -> int | None:
sharded_dim = None
for i, s in enumerate(s.sharding_spec.sharding):
if isinstance(s, pxla.Unstacked):
@ -4898,7 +4898,7 @@ def remaining(original, *removed_lists):
return [i for i in original if i not in removed]
def canonicalize_precision(precision: PrecisionLike) -> Optional[tuple[PrecisionType, PrecisionType]]:
def canonicalize_precision(precision: PrecisionLike) -> tuple[PrecisionType, PrecisionType] | None:
"""Turns an API precision specification, into a pair of enumeration values.
The API can take the precision as a string, or int, and either as a single

View File

@ -197,7 +197,7 @@ class Mesh(contextlib.ContextDecorator):
if val is not None:
return val
self = super(Mesh, cls).__new__(cls)
self = super().__new__(cls)
self.devices = devices.copy()
self.devices.flags.writeable = False
self.axis_names = axis_names

View File

@ -32,8 +32,7 @@ from functools import partial
import math
import operator
import types
from typing import (overload, Any, Callable, Literal, NamedTuple, Optional,
Protocol, TypeVar, Union)
from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings
@ -2443,7 +2442,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision
and then convert it to the desired lower precision.
""")
def arange(start: DimSize, stop: Optional[DimSize] = None,
def arange(start: DimSize, stop: DimSize | None = None,
step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "arange")
if not config.dynamic_shapes.value:
@ -2480,8 +2479,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
return lax.iota(dtype, start)
else:
if step is None and start == 0 and stop is not None:
stop = np.ceil(stop).astype(int)
return lax.iota(dtype, stop)
return lax.iota(dtype, np.ceil(stop).astype(int))
return array(np.arange(start, stop=stop, step=step, dtype=dtype))
@ -2833,7 +2831,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
@util._wraps(np.tri)
def tri(N: int, M: int | None = None, k: int = 0, dtype: Optional[DTypeLike] = None) -> Array:
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "tri")
M = M if M is not None else N
dtype = dtype or float32
@ -3933,7 +3931,7 @@ def sort_complex(a: ArrayLike) -> Array:
@util._wraps(np.lexsort)
@partial(jit, static_argnames=('axis',))
def lexsort(keys: Union[Array, np.ndarray, Sequence[ArrayLike]], axis: int = -1) -> Array:
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
key_tuple = tuple(keys)
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("lexsort", *key_tuple, emit_warning=True)

View File

@ -17,7 +17,7 @@ from collections.abc import Sequence
from functools import partial
import re
import textwrap
from typing import Any, Callable, NamedTuple, Optional, TypeVar
from typing import Any, Callable, NamedTuple, TypeVar
import warnings
@ -50,14 +50,14 @@ class ParsedDoc(NamedTuple):
front_matter: front matter before sections.
sections: dictionary of section titles to section content.
"""
docstr: Optional[str]
docstr: str | None
signature: str = ""
summary: str = ""
front_matter: str = ""
sections: dict[str, str] = {}
def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
def _parse_numpydoc(docstr: str | None) -> ParsedDoc:
"""Parse a standard numpy-style docstring.
Args:
@ -118,13 +118,13 @@ def _parse_extra_params(extra_params: str) -> dict[str, str]:
def _wraps(
fun: Optional[Callable[..., Any]],
fun: Callable[..., Any] | None,
update_doc: bool = True,
lax_description: str = "",
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
skip_params: Sequence[str] = (),
extra_params: Optional[str] = None,
module: Optional[str] = None,
extra_params: str | None = None,
module: str | None = None,
) -> Callable[[_T], _T]:
"""Specialized version of functools.wraps for wrapping numpy functions.

View File

@ -19,7 +19,8 @@ from collections.abc import Sequence
import contextlib
import dataclasses
import functools
from typing import Any, Callable, Iterator
from typing import Any, Callable
from collections.abc import Iterator
from jax._src import api_util
from jax._src import core as jax_core

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import dataclasses
from typing import Any, Tuple
from typing import Any
import jax
from jax import core as jax_core
@ -37,7 +37,7 @@ import numpy as np
# to `tl.broadcast_to`.
broadcast_to_p = jax_core.Primitive('broadcast_to')
def broadcast_to(a: jax.Array, shape: Tuple[int, ...]) -> jax.Array:
def broadcast_to(a: jax.Array, shape: tuple[int, ...]) -> jax.Array:
if a.shape == shape:
return a
return broadcast_to_p.bind(a, shape=shape)
@ -99,9 +99,9 @@ ds = dslice # Handy alias
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
indices: Tuple[int | Slice | jax.Array, ...]
shape: Tuple[int, ...]
int_indexer_shape: Tuple[int, ...]
indices: tuple[int | Slice | jax.Array, ...]
shape: tuple[int, ...]
int_indexer_shape: tuple[int, ...]
def __post_init__(self):
if len(self.indices) != len(self.shape):
@ -148,7 +148,7 @@ class NDIndexer:
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
return NDIndexer(tuple(indices), shape, bcast_shape)
def get_indexer_shape(self) -> Tuple[int, ...]:
def get_indexer_shape(self) -> tuple[int, ...]:
is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
other_indexers, _ = partition_list(is_int_indexing, self.indices)
other_shape = [s.size for s in other_indexers] # type: ignore

View File

@ -17,7 +17,8 @@ from __future__ import annotations
import dataclasses
import functools
from typing import Any, Callable, Sequence
from typing import Any, Callable
from collections.abc import Sequence
from jax import core as jax_core
from jax import lax

View File

@ -18,7 +18,8 @@ from __future__ import annotations
from functools import partial
import itertools as it
from typing import Any, Callable, Dict, Sequence, Tuple
from typing import Any, Callable
from collections.abc import Sequence
import jax
from jax import api_util
@ -89,7 +90,7 @@ def _uninitialized_value(shape, dtype):
def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
interpret, debug: bool,
in_shapes,
input_output_aliases: Tuple[Tuple[int, int], ...],
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
**compiler_params: Any):
if interpret:
@ -189,7 +190,7 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_):
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
input_output_aliases: Tuple[Tuple[int, int], ...],
input_output_aliases: tuple[tuple[int, int], ...],
in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any):
if grid_mapping.num_index_operands:
raise NotImplementedError
@ -236,7 +237,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping | None) -> BlockMapping:
def _block_map_function(new_idx, *args):
@ -271,13 +272,13 @@ def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
def _pallas_call_batching_rule(args, dims, *,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
grid_mapping: GridMapping,
input_output_aliases: Tuple[Tuple[int, int], ...],
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
which_linear: Tuple[bool, ...],
which_linear: tuple[bool, ...],
**compiler_params: Any):
if grid_mapping.num_index_operands:
scalar_batch_dims = dims[:grid_mapping.num_index_operands]
@ -428,7 +429,7 @@ def pallas_call(
in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec,
out_specs: BlockSpec | NoBlockSpec
| Sequence[BlockSpec | NoBlockSpec] = no_block_spec,
input_output_aliases: Dict[int, int] = {},
input_output_aliases: dict[int, int] = {},
interpret: bool = False,
name: str | None = None,
**compiler_params: Any,

View File

@ -18,7 +18,8 @@ from __future__ import annotations
import dataclasses
import functools
import operator
from typing import Any, Callable, Dict, Sequence, Tuple
from typing import Any, Callable
from collections.abc import Sequence
import zlib
import jax
@ -67,7 +68,7 @@ import triton.language as tl
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
partial = functools.partial
Grid = Tuple[int, ...]
Grid = tuple[int, ...]
NDIndexer = indexing.NDIndexer
GridMapping = pallas_core.GridMapping
BlockMapping = pallas_core.BlockMapping
@ -88,7 +89,7 @@ class TritonModuleContext:
class BlockInfo:
full_shape_dtype: jax.ShapeDtypeStruct
start_indices: Sequence[Any]
block_shape: Tuple[int, ...]
block_shape: tuple[int, ...]
@dataclasses.dataclass
@ -111,7 +112,7 @@ class TritonLoweringResult:
ir_context: tl_ir.context
module: tl_ir.module
builder: tl_ir.builder
grid: Tuple[int, ...]
grid: tuple[int, ...]
@dataclasses.dataclass
@ -610,7 +611,7 @@ def _compute_pointers_from_indices(
root_ptr: tl.core.tensor,
block_info: BlockInfo | None,
nd_indexer: NDIndexer,
array_shape: Tuple[int, ...],
array_shape: tuple[int, ...],
builder: tl_ir.builder,
) -> tl.core.tensor:
if block_info is None:
@ -1512,14 +1513,14 @@ def pallas_call_lowering(
*in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
which_linear: Tuple[bool, ...],
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
which_linear: tuple[bool, ...],
interpret: bool,
debug: bool,
input_output_aliases: Tuple[Tuple[int, int], ...],
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
triton_params: Dict[str, Any] | None = None,
triton_params: dict[str, Any] | None = None,
**compiler_params: Any,
):
if interpret:

View File

@ -15,7 +15,6 @@
"""Pallas utility functions."""
import math
import numpy as np
from typing import Tuple
from jax import lax
from jax._src import core as jax_core
@ -36,7 +35,7 @@ def cdiv(a: int, b: int) -> int:
return (a + b - 1) // b
def strides_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
size = np.prod(shape)
strides = []
for s in shape:

View File

@ -19,7 +19,8 @@ from functools import partial
import math
from operator import index
import typing
from typing import Hashable, Optional, Union
from typing import Union
from collections.abc import Hashable
import warnings
import numpy as np
@ -150,7 +151,7 @@ class PRNGSpec:
PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl]
def resolve_prng_impl(impl_spec: Optional[PRNGSpecDesc]) -> PRNGImpl:
def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl:
if impl_spec is None:
return default_prng_impl()
if type(impl_spec) is PRNGImpl:
@ -174,7 +175,7 @@ def resolve_prng_impl(impl_spec: Optional[PRNGSpecDesc]) -> PRNGImpl:
def _key(ctor_name: str, seed: int | ArrayLike,
impl_spec: Optional[PRNGSpecDesc]) -> KeyArray:
impl_spec: PRNGSpecDesc | None) -> KeyArray:
impl = resolve_prng_impl(impl_spec)
if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key):
raise TypeError(
@ -186,7 +187,7 @@ def _key(ctor_name: str, seed: int | ArrayLike,
return prng.random_seed(seed, impl=impl)
def key(seed: int | ArrayLike, *,
impl: Optional[PRNGSpecDesc] = None) -> KeyArray:
impl: PRNGSpecDesc | None = None) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The result is a scalar array with a key that indicates the default PRNG
@ -205,7 +206,7 @@ def key(seed: int | ArrayLike, *,
return _key('key', seed, impl)
def PRNGKey(seed: int | ArrayLike, *,
impl: Optional[PRNGSpecDesc] = None) -> KeyArray:
impl: PRNGSpecDesc | None = None) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The resulting key carries the default PRNG implementation, as
@ -277,7 +278,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
return _return_prng_keys(wrapped, key_out)
def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray:
# Alternative to split() to use within random samplers.
# TODO(frostig): remove and use split(); we no longer need to wait
# to always enable_custom_prng
@ -288,7 +289,7 @@ def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
shape = tuple(num) if isinstance(num, Sequence) else (num,)
return prng.random_split(key, shape=shape)
def split(key: KeyArrayLike, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
@ -324,7 +325,7 @@ def key_data(keys: KeyArrayLike) -> Array:
def wrap_key_data(key_bits_array: Array, *,
impl: Optional[PRNGSpecDesc] = None):
impl: PRNGSpecDesc | None = None):
"""Wrap an array of key data bits into a PRNG key array.
Args:
@ -344,7 +345,7 @@ def wrap_key_data(key_bits_array: Array, *,
### random samplers
def _check_shape(name: str, shape: Union[Shape, NamedShape], *param_shapes) -> None:
def _check_shape(name: str, shape: Shape | NamedShape, *param_shapes) -> None:
shape = core.as_named_shape(shape)
if param_shapes:
@ -358,7 +359,7 @@ def _check_shape(name: str, shape: Union[Shape, NamedShape], *param_shapes) -> N
def bits(key: KeyArrayLike,
shape: Shape = (),
dtype: Optional[DTypeLikeUInt] = None) -> Array:
dtype: DTypeLikeUInt | None = None) -> Array:
"""Sample uniform bits in the form of unsigned integers.
Args:
@ -386,7 +387,7 @@ def bits(key: KeyArrayLike,
def uniform(key: KeyArrayLike,
shape: Union[Shape, NamedShape] = (),
shape: Shape | NamedShape = (),
dtype: DTypeLikeFloat = float,
minval: RealArray = 0.,
maxval: RealArray = 1.) -> Array:
@ -561,7 +562,7 @@ def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array:
def permutation(key: KeyArrayLike,
x: Union[int, ArrayLike],
x: int | ArrayLike,
axis: int = 0,
independent: bool = False) -> Array:
"""Returns a randomly permuted array or range.
@ -620,10 +621,10 @@ def _shuffle(key, x, axis) -> Array:
def choice(key: KeyArrayLike,
a: Union[int, ArrayLike],
a: int | ArrayLike,
shape: Shape = (),
replace: bool = True,
p: Optional[RealArray] = None,
p: RealArray | None = None,
axis: int = 0) -> Array:
"""Generates a random sample from a given array.
@ -697,7 +698,7 @@ def choice(key: KeyArrayLike,
def normal(key: KeyArrayLike,
shape: Union[Shape, NamedShape] = (),
shape: Shape | NamedShape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample standard normal random values with given shape and float dtype.
@ -752,8 +753,8 @@ def _normal_real(key, shape, dtype) -> Array:
def multivariate_normal(key: KeyArrayLike,
mean: RealArray,
cov: RealArray,
shape: Optional[Shape] = None,
dtype: Optional[DTypeLikeFloat] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat | None = None,
method: str = 'cholesky') -> Array:
r"""Sample multivariate normal random values with given mean and covariance.
@ -835,7 +836,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
def truncated_normal(key: KeyArrayLike,
lower: RealArray,
upper: RealArray,
shape: Optional[Union[Shape, NamedShape]] = None,
shape: Shape | NamedShape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample truncated standard normal random values with given shape and dtype.
@ -900,7 +901,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array:
def bernoulli(key: KeyArrayLike,
p: RealArray = np.float32(0.5),
shape: Optional[Union[Shape, NamedShape]] = None) -> Array:
shape: Shape | NamedShape | None = None) -> Array:
r"""Sample Bernoulli random values with given shape and mean.
The values are distributed according to the probability mass function:
@ -946,7 +947,7 @@ def _bernoulli(key, p, shape) -> Array:
def beta(key: KeyArrayLike,
a: RealArray,
b: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Beta random values with given shape and float dtype.
@ -1045,7 +1046,7 @@ def _cauchy(key, shape, dtype) -> Array:
def dirichlet(key: KeyArrayLike,
alpha: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Dirichlet random values with given shape and float dtype.
@ -1284,7 +1285,7 @@ batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
def gamma(key: KeyArrayLike,
a: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Gamma random values with given shape and float dtype.
@ -1331,7 +1332,7 @@ def gamma(key: KeyArrayLike,
def loggamma(key: KeyArrayLike,
a: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
"""Sample log-gamma random values with given shape and float dtype.
@ -1473,7 +1474,7 @@ def _poisson(key, lam, shape, dtype) -> Array:
def poisson(key: KeyArrayLike,
lam: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeInt = int) -> Array:
r"""Sample Poisson random values with given shape and integer dtype.
@ -1555,7 +1556,7 @@ def _gumbel(key, shape, dtype) -> Array:
def categorical(key: KeyArrayLike,
logits: RealArray,
axis: int = -1,
shape: Optional[Shape] = None) -> Array:
shape: Shape | None = None) -> Array:
"""Sample random values from categorical distributions.
Args:
@ -1669,7 +1670,7 @@ def _logistic(key, shape, dtype):
def pareto(key: KeyArrayLike,
b: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Pareto random values with given shape and float dtype.
@ -1770,7 +1771,7 @@ def _t(key, df, shape, dtype) -> Array:
def chisquare(key: KeyArrayLike,
df: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Chisquare random values with given shape and float dtype.
@ -1823,7 +1824,7 @@ def _chisquare(key, df, shape, dtype) -> Array:
def f(key: KeyArrayLike,
dfnum: RealArray,
dfden: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample F-distribution random values with given shape and float dtype.
@ -2153,7 +2154,7 @@ def ball(
def rayleigh(key: KeyArrayLike,
scale: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Rayleigh random values with given shape and float dtype.
@ -2206,7 +2207,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
def wald(key: KeyArrayLike,
mean: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Wald random values with given shape and float dtype.
@ -2264,7 +2265,7 @@ def _wald(key, mean, shape, dtype) -> Array:
def geometric(key: KeyArrayLike,
p: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeInt = int) -> Array:
r"""Sample Geometric random values with given shape and float dtype.
@ -2319,7 +2320,7 @@ def triangular(key: KeyArrayLike,
left: RealArray,
mode: RealArray,
right: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Triangular random values with given shape and float dtype.
@ -2382,7 +2383,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array:
def lognormal(key: KeyArrayLike,
sigma: RealArray = np.float32(1),
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r""" Sample lognormal random values with given shape and float dtype.
@ -2588,7 +2589,7 @@ def binomial(
key: KeyArray,
n: RealArray,
p: RealArray,
shape: Optional[Shape] = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float,
) -> Array:
r"""Sample Binomial random values with given shape and float dtype.

View File

@ -22,7 +22,7 @@ import enum
import functools
import itertools
import math
from typing import Any, NamedTuple, Union, cast, Optional
from typing import Any, NamedTuple, Union, cast
from jax._src import mesh as mesh_lib
from jax._src.op_shardings import (
@ -566,9 +566,9 @@ class PmapSharding(XLACompatibleSharding):
def _op_sharding_to_pos_sharding(
op_sharding: Union[xc.OpSharding, xc.HloSharding],
op_sharding: xc.OpSharding | xc.HloSharding,
device_assignment: Sequence[xc.Device],
memory_kind: Optional[str] = None) -> PositionalSharding:
memory_kind: str | None = None) -> PositionalSharding:
if isinstance(op_sharding, xc.OpSharding):
op_sharding = xc.HloSharding.from_proto(op_sharding) # type: ignore

View File

@ -24,7 +24,7 @@ import re
import os
import tempfile
import textwrap
from typing import Any, Callable, Optional
from typing import Any, Callable
import unittest
import warnings
import zlib
@ -919,7 +919,7 @@ class JaxTestCase(parameterized.TestCase):
'jax_legacy_prng_key': 'error',
}
_compilation_cache_exit_stack: Optional[ExitStack] = None
_compilation_cache_exit_stack: ExitStack | None = None
# TODO(mattjj): this obscures the error messages from failures, figure out how
# to re-enable it
@ -1275,8 +1275,8 @@ def numpy_version():
def parameterized_filterable(*,
kwargs: Sequence[dict[str, Any]],
testcase_name: Optional[Callable[[dict[str, Any]], str]] = None,
one_containing: Optional[str] = None,
testcase_name: Callable[[dict[str, Any]], str] | None = None,
one_containing: str | None = None,
):
"""
Decorator for named parameterized tests, with filtering.

View File

@ -21,7 +21,7 @@ import functools
from functools import partial
import operator as op
import textwrap
from typing import Any, Callable, NamedTuple, Type, TypeVar, Union, overload
from typing import Any, Callable, NamedTuple, TypeVar, Union, overload
from jax._src import traceback_util
from jax._src.lib import pytree
@ -737,7 +737,7 @@ def register_pytree_with_keys_class(cls: U) -> U:
return cls
def register_static(cls: Type[H]) -> Type[H]:
def register_static(cls: type[H]) -> type[H]:
"""Registers `cls` as a pytree with no leaves.
Instances are treated as static by `jax.jit`, `jax.pmap`, etc. This can be an

View File

@ -32,7 +32,7 @@ import pkgutil
import platform as py_platform
import sys
import threading
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import warnings
from jax._src import config
@ -855,7 +855,7 @@ def default_backend() -> str:
return get_backend(None).platform
def backend_pjrt_c_api_version(platform=None) -> Optional[Tuple[int, int]]:
def backend_pjrt_c_api_version(platform=None) -> Optional[tuple[int, int]]:
"""Returns the PJRT C API version of the backend.
Returns None if the backend does not use PJRT C API.

View File

@ -31,7 +31,7 @@ def _array_namespace(self, /, *, api_version: None | str = None):
def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
stream: Optional[Union[int, Any]] = None):
stream: int | Any | None = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
# The type of device is defined by Array.device. In JAX, this is a callable that

View File

@ -20,17 +20,17 @@ from jax import Array
from jax.experimental.array_api._data_type_functions import result_type as _result_type
def broadcast_arrays(*arrays: Array) -> List[Array]:
def broadcast_arrays(*arrays: Array) -> list[Array]:
"""Broadcasts one or more arrays against one another."""
return jax.numpy.broadcast_arrays(*arrays)
def broadcast_to(x: Array, /, shape: Tuple[int]) -> Array:
def broadcast_to(x: Array, /, shape: tuple[int]) -> Array:
"""Broadcasts an array to a specified shape."""
return jax.numpy.broadcast_to(x, shape=shape)
def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
def concat(arrays: Union[tuple[Array, ...], list[Array]], /, *, axis: Optional[int] = 0) -> Array:
"""Joins a sequence of arrays along an existing axis."""
dtype = _result_type(*arrays)
if axis is None:
@ -46,34 +46,34 @@ def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
return jax.numpy.expand_dims(x, axis=axis)
def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
def flip(x: Array, /, *, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
"""Reverses the order of elements in an array along the given axis."""
return jax.numpy.flip(x, axis=axis)
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
"""Permutes the axes (dimensions) of an array x."""
return jax.lax.transpose(x, axes)
def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array:
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: Optional[bool] = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused
return jax.numpy.reshape(x, shape)
def roll(x: Array, /, shift: Union[int, Tuple[int]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
def roll(x: Array, /, shift: Union[int, tuple[int]], *, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
"""Rolls array elements along a specified axis."""
return jax.numpy.roll(x, shift=shift, axis=axis)
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
def squeeze(x: Array, /, axis: Union[int, tuple[int, ...]]) -> Array:
"""Removes singleton dimensions (axes) from x."""
dimensions = axis if isinstance(axis, tuple) else (axis,)
return jax.lax.squeeze(x, dimensions=dimensions)
def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
def stack(arrays: Union[tuple[Array, ...], list[Array]], /, *, axis: int = 0) -> Array:
"""Joins a sequence of arrays along a new axis."""
dtype = _result_type(*arrays)
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)

View File

@ -2163,7 +2163,7 @@ tf_impl_with_avals[lax.dot_general_p] = _dot_general
def _dot_general_convert_to_common_dtype(
lhs: TfVal, lhs_aval: core.ShapedArray,
rhs: TfVal, rhs_aval: core.ShapedArray,
out_aval: core.ShapedArray) -> Tuple[TfVal, TfVal, Callable[[TfVal], TfVal]]:
out_aval: core.ShapedArray) -> tuple[TfVal, TfVal, Callable[[TfVal], TfVal]]:
# Returns the converted lhs, rhs, and the converter for the result.
# tfxla.dot_general does not handle arguments of different types.
# We convert the arguments and the result.

View File

@ -76,7 +76,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
cls.tf_devices.append(tf_device)
super(CallTfTest, cls).setUpClass()
super().setUpClass()
def setUp(self):
if tf is None:

View File

@ -69,7 +69,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
cls.tf_devices.append(tf_device)
super(Jax2TfTest, cls).setUpClass()
super().setUpClass()
def test_empty(self):
f_jax = lambda x, y: x

View File

@ -173,7 +173,7 @@ def mha(
block_q: int = 128,
block_k: int = 128,
backward_pass_impl: str = "triton",
num_warps: Optional[int] = None,
num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
@ -239,7 +239,7 @@ def _mha_forward(
block_q: int,
block_k: int,
backward_pass_impl: str,
num_warps: Optional[int],
num_warps: int | None,
num_stages: int,
grid: Any,
interpret: bool,
@ -451,7 +451,7 @@ def mha_backward_kernel(
def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
backward_pass_impl: str, num_warps: Optional[int],
backward_pass_impl: str, num_warps: int | None,
num_stages: int, grid: Any, interpret: bool,
debug: bool, res, do):
del num_warps, num_stages, grid

View File

@ -25,7 +25,7 @@ chunk, splits it in two, and sends each of the half-chunks in each direction
from __future__ import annotations
import functools
from typing import Sequence
from collections.abc import Sequence
import jax
from jax import lax

View File

@ -23,7 +23,7 @@ import shutil
import sys
import subprocess
import glob
from typing import Sequence
from collections.abc import Sequence
def is_windows() -> bool:
@ -88,7 +88,7 @@ def build_editable(
def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str):
src_file = file_dir / "setup.py"
with open(src_file, "r") as f:
with open(src_file) as f:
content = f.read()
content = content.replace(
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"

View File

@ -72,7 +72,7 @@ class PrimitiveTest(jtu.JaxTestCase):
if d.platform not in cls.platforms:
cls.platforms.append(d.platform)
cls.devices.append(d)
super(PrimitiveTest, cls).setUpClass()
super().setUpClass()
# For each primitive we export for all platforms that are available and
# compare the results of running the exported code and running the native

View File

@ -18,7 +18,6 @@ import functools
import logging
import math
import re
from typing import Optional
import unittest
from absl.testing import absltest
@ -98,7 +97,7 @@ mlir.register_lowering(testing_primitive_with_effect_p,
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)
def _testing_multi_platform_func(x, *,
effect_class_name: Optional[str] = None):
effect_class_name: str | None = None):
# Behaves like x + 2 * _testing_multi_platform_to_add[platform]
def for_platform(platform: str):
if effect_class_name is None:
@ -142,7 +141,7 @@ class JaxExportTest(jtu.JaxTestCase):
except RuntimeError:
continue
cls.platforms.append(backend)
super(JaxExportTest, cls).setUpClass()
super().setUpClass()
def setUp(self):
super().setUp()

View File

@ -266,7 +266,7 @@ class MemoriesComputationTest(jtu.BufferDonationTestCase):
expected_out = np_inp * 2
self.assertArraysEqual(out, expected_out)
self.assertEqual(out.sharding, s.with_memory_kind(("unpinned_host")))
self.assertEqual(out.sharding, s.with_memory_kind("unpinned_host"))
self._check_mem_kind(executable_mk[0], out.sharding, "unpinned_host")
for s in out.addressable_shards:
self.assertArraysEqual(s.data, expected_out[s.index])

View File

@ -65,7 +65,7 @@ def mock_tpu_devices(x, y, z, dev_kind, one_device_per_chip, num_slices=1,
# If this function raises, it's a bug in the test code!
def _validate_mocked_process_indices(devices, one_device_per_chip):
process_to_devices = collections.defaultdict(lambda: [])
process_to_devices = collections.defaultdict(list)
for d in devices:
process_to_devices[d.process_index].append(d)