From 36f6b52e4203766f79ec22326e2dce2fdbcb310c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 8 Dec 2023 12:09:04 +0000 Subject: [PATCH] Upgrade most .py sources to 3.9 This commit was generated by running pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py --- jax/_src/api.py | 4 +- jax/_src/basearray.py | 3 +- jax/_src/config.py | 56 +++++++-------- jax/_src/dtypes.py | 14 ++-- jax/_src/extend/random.py | 3 +- jax/_src/interpreters/mlir.py | 10 +-- jax/_src/interpreters/pxla.py | 7 +- jax/_src/jaxpr_util.py | 2 +- jax/_src/lax/control_flow/loops.py | 4 +- jax/_src/lax/lax.py | 62 ++++++++--------- jax/_src/mesh.py | 2 +- jax/_src/numpy/lax_numpy.py | 12 ++-- jax/_src/numpy/util.py | 12 ++-- jax/_src/pallas/core.py | 3 +- jax/_src/pallas/indexing.py | 12 ++-- jax/_src/pallas/mosaic/lowering.py | 3 +- jax/_src/pallas/pallas_call.py | 19 ++--- jax/_src/pallas/triton/lowering.py | 21 +++--- jax/_src/pallas/utils.py | 3 +- jax/_src/random.py | 69 ++++++++++--------- jax/_src/sharding_impls.py | 6 +- jax/_src/test_util.py | 8 +-- jax/_src/tree_util.py | 4 +- jax/_src/xla_bridge.py | 4 +- jax/experimental/array_api/_array_methods.py | 2 +- .../array_api/_manipulation_functions.py | 18 ++--- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/experimental/jax2tf/tests/call_tf_test.py | 2 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 2 +- jax/experimental/pallas/ops/attention.py | 6 +- jax/experimental/pallas/ops/tpu/all_gather.py | 2 +- jax/tools/build_utils.py | 4 +- tests/export_harnesses_multi_platform_test.py | 2 +- tests/export_test.py | 5 +- tests/memories_test.py | 2 +- tests/mesh_utils_test.py | 2 +- 36 files changed, 198 insertions(+), 194 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 616036d21..59970b413 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -29,7 +29,7 @@ import inspect import math import typing from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload, - cast, Optional) + cast) import weakref import numpy as np @@ -2461,7 +2461,7 @@ def make_jaxpr(fun: Callable, make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})" return make_jaxpr_f -def _infer_src_sharding(src, x) -> Optional[Sharding]: +def _infer_src_sharding(src, x) -> Sharding | None: if src is not None: return src if isinstance(x, array.ArrayImpl): diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index c96374636..ea881a723 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -16,7 +16,8 @@ import abc import numpy as np -from typing import Any, Sequence, Union +from typing import Any, Union +from collections.abc import Sequence # TODO(jakevdp): fix import cycles and define these. 
Shard = Any diff --git a/jax/_src/config.py b/jax/_src/config.py index 6fb88b173..1c858f27f 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -22,7 +22,7 @@ import logging import os import sys import threading -from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar +from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar import warnings from jax._src import lib @@ -134,7 +134,7 @@ class Config: raise AttributeError(f"Unrecognized config option: {name}") def add_option(self, name, default, opt_type, meta_args, meta_kwargs, - update_hook: Optional[Callable[[Any], None]] = None): + update_hook: Callable[[Any], None] | None = None): if name in self.values: raise Exception(f"Config option {name} already defined") self.values[name] = default @@ -238,7 +238,7 @@ _thread_local_state = threading.local() class _StateContextManager(Generic[_T]): def __init__(self, name, help, update_thread_local_hook, - validate_new_val_hook: Optional[Callable[[Any], None]] = None, + validate_new_val_hook: Callable[[Any], None] | None = None, extra_description: str = "", default_value: Any = no_default): self._name = name self.__name__ = name[4:] if name.startswith('jax_') else name @@ -302,8 +302,8 @@ def define_bool_state( default: bool, help: str, *, - update_global_hook: Optional[Callable[[bool], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None, + update_global_hook: Callable[[bool], None] | None = None, + update_thread_local_hook: Callable[[bool | None], None] | None = None, upgrade: bool = False, extra_description: str = '', ) -> _StateContextManager[bool]: @@ -375,11 +375,11 @@ def define_bool_state( def define_enum_state( name: str, enum_values: list[str], - default: Optional[str], + default: str | None, help: str, *, - update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, + update_global_hook: Callable[[str], None] | None = None, + update_thread_local_hook: Callable[[str | None], None] | None = None, ) -> _StateContextManager[str]: """Set up thread-local state and return a contextmanager for managing it. @@ -420,11 +420,11 @@ def define_enum_state( def define_int_state( name: str, - default: Optional[int], + default: int | None, help: str, *, - update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, + update_global_hook: Callable[[str], None] | None = None, + update_thread_local_hook: Callable[[str | None], None] | None = None, ) -> _StateContextManager[int]: """Set up thread-local state and return a contextmanager for managing it. @@ -463,11 +463,11 @@ def define_int_state( def define_float_state( name: str, - default: Optional[float], + default: float | None, help: str, *, - update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, + update_global_hook: Callable[[str], None] | None = None, + update_thread_local_hook: Callable[[str | None], None] | None = None, ) -> _StateContextManager[float]: """Set up thread-local state and return a contextmanager for managing it. 
@@ -508,11 +508,11 @@ def define_float_state( def define_string_state( name: str, - default: Optional[str], + default: str | None, help: str, *, - update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, + update_global_hook: Callable[[str], None] | None = None, + update_thread_local_hook: Callable[[str | None], None] | None = None, ) -> _StateContextManager[str]: """Set up thread-local state and return a contextmanager for managing it. @@ -552,9 +552,9 @@ def define_string_or_object_state( default: Any, help: str, *, - update_global_hook: Optional[Callable[[Any], None]] = None, - update_thread_local_hook: Optional[Callable[[Any], None]] = None, - validate_new_val_hook: Optional[Callable[[Any], None]] = None, + update_global_hook: Callable[[Any], None] | None = None, + update_thread_local_hook: Callable[[Any], None] | None = None, + validate_new_val_hook: Callable[[Any], None] | None = None, ) -> _StateContextManager[Any]: """Set up thread-local state and return a contextmanager for managing it. @@ -651,9 +651,9 @@ already_configured_with_absl = False # a global/thread-local state. These methods allow updates to part of the # state when a configuration value changes. class _GlobalExtraJitContext(NamedTuple): - numpy_rank_promotion: Optional[str] = None - numpy_dtype_promotion: Optional[str] = None - default_matmul_precision: Optional[Any] = None + numpy_rank_promotion: str | None = None + numpy_dtype_promotion: str | None = None + default_matmul_precision: Any | None = None dynamic_shapes: bool = False threefry_partitionable: bool = False softmax_custom_jvp: bool = False @@ -675,12 +675,12 @@ class _ThreadLocalExtraJitContext(NamedTuple): The initialization, which uses both config.py and core.py is done using `_update_thread_local_jit_state` in core.py to prevent circular imports. 
""" - dynamic_trace_state: Optional[Any] = None + dynamic_trace_state: Any | None = None axis_env_state: Hashable = () mesh_context_manager: Hashable = () - numpy_rank_promotion: Optional[str] = None - numpy_dtype_promotion: Optional[str] = None - default_matmul_precision: Optional[Any] = None + numpy_rank_promotion: str | None = None + numpy_dtype_promotion: str | None = None + default_matmul_precision: Any | None = None dynamic_shapes: bool = False threefry_partitionable: bool = False softmax_custom_jvp: bool = False @@ -1320,7 +1320,7 @@ def transfer_guard(new_val: str) -> Iterator[None]: yield -def _update_debug_log_modules(module_names_str: Optional[str]): +def _update_debug_log_modules(module_names_str: str | None): logging_config.disable_all_debug_logging() if not module_names_str: return diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 10c53942c..2356a61f3 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -24,7 +24,7 @@ from __future__ import annotations import abc import builtins import functools -from typing import cast, overload, Any, Literal, Optional, Union +from typing import cast, overload, Any, Literal, Union import warnings import ml_dtypes @@ -207,7 +207,7 @@ def to_complex_dtype(dtype: DTypeLike) -> DType: @functools.cache -def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> Union[DType, ExtendedDType]: +def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> DType | ExtendedDType: if issubdtype(dtype, extended): if not allow_extended_dtype: raise ValueError(f"Internal: canonicalize_dtype called on extended dtype {dtype} " @@ -227,10 +227,10 @@ def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: An def canonicalize_dtype(dtype: Any, allow_extended_dtype: Literal[False] = False) -> DType: ... @overload -def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> Union[DType, ExtendedDType]: ... +def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType: ... @export -def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> Union[DType, ExtendedDType]: +def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType: """Convert from a dtype to a canonical dtype based on config.x64_enabled.""" return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype) # pytype: disable=bad-return-type @@ -292,7 +292,7 @@ def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType: return dtype -def coerce_to_array(x: Any, dtype: Optional[DTypeLike] = None) -> np.ndarray: +def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray: """Coerces a scalar or NumPy array to an np.array. Handles Python scalar type promotion according to JAX's rules, not NumPy's @@ -643,10 +643,10 @@ def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> tuple[DType def result_type(*args: Any, return_weak_type_flag: Literal[False] = False) -> DType: ... @overload -def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]: ... +def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]: ... @export -def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]: +def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]: """Convenience function to apply JAX argument dtype promotion. 
Args: diff --git a/jax/_src/extend/random.py b/jax/_src/extend/random.py index 250567a9c..ffc7f6307 100644 --- a/jax/_src/extend/random.py +++ b/jax/_src/extend/random.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Hashable +from typing import Callable +from collections.abc import Hashable from jax import Array diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1078b51bd..1604e6539 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -25,7 +25,7 @@ import itertools import operator import re import typing -from typing import Any, Callable, NamedTuple, Optional, Protocol, Union +from typing import Any, Callable, NamedTuple, Protocol, Union import warnings import numpy as np @@ -707,7 +707,7 @@ def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | return layout._to_xla_layout() -def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]: +def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None: if s is None: return None assert isinstance(s, sharding_impls.XLACompatibleSharding) @@ -1454,7 +1454,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) platform_rules: dict[str, LoweringRule] = {} - default_rule: Optional[LoweringRule] = None + default_rule: LoweringRule | None = None # See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule` if override_rule is not None: default_rule = override_rule @@ -1525,7 +1525,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, def lower_per_platform(ctx: LoweringRuleContext, description: str, platform_rules: dict[str, LoweringRule], - default_rule: Optional[LoweringRule], + default_rule: LoweringRule | None, effects: effects_lib.Effects, *rule_args: ir.Value, **rule_kwargs) -> ir.Value: @@ -1710,7 +1710,7 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, return func_op -def check_backend_matches(inner_backend: Optional[str], +def check_backend_matches(inner_backend: str | None, lowering_platforms: Sequence[str]): # For nested calls, the outermost call sets the backend for all inner calls; # it's an error if the inner call has a conflicting explicit backend spec. 
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a8bcdac2b..ea59f923d 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -25,7 +25,8 @@ import itertools as it import logging import math import threading -from typing import (Any, Callable, NamedTuple, Iterator, Optional, Union, cast, TypeVar) +from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar) +from collections.abc import Iterator import warnings import numpy as np @@ -2914,8 +2915,8 @@ def _compile_replicated_mesh_executable_from_hlo( @lru_cache def create_mesh_pspec_sharding( - mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None, - memory_kind: Optional[str] = None) -> sharding_impls.NamedSharding: + mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None, + memory_kind: str | None = None) -> sharding_impls.NamedSharding: if pspec is None: pspec, parsed_pspec = PartitionSpec(), None return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec, diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index 7fa08e6d3..d7b30e878 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -202,7 +202,7 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes: pprof tool for visualization. """ d: DefaultDict[tuple[Optional[xla_client.Traceback], core.Primitive], int] - d = collections.defaultdict(lambda: 0) + d = collections.defaultdict(int) for _, eqn in all_eqns(jaxpr): d[(eqn.source_info.traceback, eqn.primitive)] += 1 return _pprof_profile(d) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 86eaa8004..9eaa37577 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -19,7 +19,7 @@ from functools import partial import inspect import itertools import operator -from typing import Any, Callable, Optional, TypeVar +from typing import Any, Callable, TypeVar import jax import weakref @@ -104,7 +104,7 @@ Y = TypeVar('Y') def scan(f: Callable[[Carry, X], tuple[Carry, Y]], init: Carry, xs: X, - length: Optional[int] = None, + length: int | None = None, reverse: bool = False, unroll: int | bool = 1) -> tuple[Carry, Y]: """Scan a function over leading array axes while carrying along state. diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9ca1d4318..b734bf8a5 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -22,7 +22,7 @@ from functools import partial import itertools import math import operator -from typing import (Any, Callable, Optional, TypeVar, Union, +from typing import (Any, Callable, TypeVar, Union, cast as type_cast, overload) import warnings @@ -104,7 +104,7 @@ def _validate_shapes(shapes: Sequence[Shape]): map(_check_static_shape, shapes) def _try_broadcast_shapes( - shapes: Sequence[tuple[int, ...]]) -> Optional[tuple[int, ...]]: + shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...] | None: if len(shapes) == 1: return shapes[0] ranks = {len(shape) for shape in shapes} if len(ranks) > 1: return None # must have consistent rank @@ -140,8 +140,8 @@ def asarray(x: ArrayLike) -> Array: def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ... @overload -def broadcast_shapes(*shapes: tuple[Union[int, core.Tracer], ...] - ) -> tuple[Union[int, core.Tracer], ...]: ... +def broadcast_shapes(*shapes: tuple[int | core.Tracer, ...] + ) -> tuple[int | core.Tracer, ...]: ... 
def broadcast_shapes(*shapes): """Returns the shape that results from NumPy broadcasting of `shapes`.""" @@ -183,8 +183,8 @@ def _broadcast_ranks(s1, s2): def _identity(x): return x def _extract_tracers_dyn_shape( - shape: Sequence[Union[int, core.Tracer]] - ) -> tuple[list[core.Tracer], list[Optional[int]]]: + shape: Sequence[int | core.Tracer] + ) -> tuple[list[core.Tracer], list[int | None]]: # Given a sequence representing a shape, pull out Tracers, replacing with None if config.dynamic_shapes.value: # We must gate this behavior under a flag because otherwise the errors @@ -196,9 +196,9 @@ def _extract_tracers_dyn_shape( return [], list(shape) # type: ignore def _merge_dyn_shape( - static_shape: Sequence[Optional[int]], + static_shape: Sequence[int | None], dyn_shape: Sequence[Any], - ) -> tuple[Union[int, mlir.Value, core.Tracer], ...]: + ) -> tuple[int | mlir.Value | core.Tracer, ...]: # Replace Nones in static_shape with elements of dyn_shape, in order dyn_shape_it = iter(dyn_shape) shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape) @@ -516,7 +516,7 @@ def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: """ return _convert_element_type(operand, new_dtype, weak_type=False) -def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = None, +def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None, weak_type: bool = False): if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() # type: ignore @@ -599,7 +599,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: """ return clamp_p.bind(min, x, max) -def concatenate(operands: Union[Array, Sequence[ArrayLike]], dimension: int) -> Array: +def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: """Concatenates a sequence of arrays along `dimension`. Wraps XLA's `Concatenate @@ -677,7 +677,7 @@ PrecisionLike = Union[None, str, PrecisionType, tuple[str, str], tuple[PrecisionType, PrecisionType]] def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, - preferred_element_type: Optional[DTypeLike] = None) -> Array: + preferred_element_type: DTypeLike | None = None) -> Array: """Vector/vector, matrix/vector, and matrix/matrix multiplication. Wraps XLA's `Dot @@ -714,7 +714,7 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]], def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, - preferred_element_type: Optional[DTypeLike] = None) -> Array: + preferred_element_type: DTypeLike | None = None) -> Array: """General dot product/contraction operator. Wraps XLA's `DotGeneral @@ -813,7 +813,7 @@ def broadcast_to_rank(x: Array, rank: int) -> Array: return broadcast(x, (1,) * (rank - x.ndim)) def reshape(operand: ArrayLike, new_sizes: Shape, - dimensions: Optional[Sequence[int]] = None) -> Array: + dimensions: Sequence[int] | None = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -1042,7 +1042,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree): return jaxpr, tuple(consts), out_tree() def _get_monoid_reducer(monoid_op: Callable, - xs: Sequence[Array]) -> Optional[Callable]: + xs: Sequence[Array]) -> Callable | None: if len(xs) != 1: return None x, = xs @@ -1128,8 +1128,8 @@ def sort(operand: Array, dimension: int = -1, def sort(operand: Sequence[Array], dimension: int = -1, is_stable: bool = True, num_keys: int = 1) -> tuple[Array, ...]: ... 
-def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1, - is_stable: bool = True, num_keys: int = 1) -> Union[Array, tuple[Array, ...]]: +def sort(operand: Array | Sequence[Array], dimension: int = -1, + is_stable: bool = True, num_keys: int = 1) -> Array | tuple[Array, ...]: """Wraps XLA's `Sort `_ operator. @@ -1196,7 +1196,7 @@ def tie_in(x: Any, y: T) -> T: """Deprecated. Ignores ``x`` and returns ``y``.""" return y -def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None) -> Array: +def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None) -> Array: """Returns an array of `shape` filled with `fill_value`. Args: @@ -1303,7 +1303,7 @@ def stop_gradient(x: T) -> T: return x return tree_map(stop, x) -def reduce_precision(operand: Union[float, ArrayLike], +def reduce_precision(operand: float | ArrayLike, exponent_bits: int, mantissa_bits: int) -> Array: """Wraps XLA's `ReducePrecision @@ -1342,9 +1342,9 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array: ### convenience wrappers around traceables -def full_like(x: Union[ArrayLike, DuckTypedArray], - fill_value: ArrayLike, dtype: Optional[DTypeLike] = None, - shape: Optional[Shape] = None) -> Array: +def full_like(x: ArrayLike | DuckTypedArray, + fill_value: ArrayLike, dtype: DTypeLike | None = None, + shape: Shape | None = None) -> Array: """Create a full array like np.full based on the example array `x`. Args: @@ -1380,7 +1380,7 @@ def full_like(x: Union[ArrayLike, DuckTypedArray], def collapse(operand: Array, start_dimension: int, - stop_dimension: Optional[int] = None) -> Array: + stop_dimension: int | None = None) -> Array: """Collapses dimensions of an array into a single dimension. For example, if ``operand`` is an array with shape ``[2, 3, 4]``, @@ -1691,7 +1691,7 @@ def broadcast_hlo( return out def _nary_lower_hlo(op: Callable, ctx, - *args: Union[ir.Value, Sequence[ir.Value]], + *args: ir.Value | Sequence[ir.Value], explicit_type=False, **params) -> Sequence[ir.Value]: """Lowers an elementwise operator to its MLIR equivalent. @@ -2516,7 +2516,7 @@ def _precision_config(precision): def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: Optional[DTypeLike]): + preferred_element_type: DTypeLike | None): (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): @@ -2592,7 +2592,7 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: Optional[DTypeLike]): + preferred_element_type: DTypeLike | None): # We're mostly matching XLA's logic here, namely in shape_inference.cc and # primitive_util.h's HigherPrecisionType, e.g. 
# https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329 @@ -2636,7 +2636,7 @@ def _maybe_upcast(result_dtype, preferred_element_type): return preferred_element_type def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, - preferred_element_type: Optional[DTypeLike], + preferred_element_type: DTypeLike | None, swap_ans=False): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim @@ -2657,7 +2657,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, return x_bar def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, - preferred_element_type: Optional[DTypeLike]): + preferred_element_type: DTypeLike | None): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) y_bar = _dot_general_transpose_lhs( @@ -2670,7 +2670,7 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, precision, - preferred_element_type: Optional[DTypeLike]): + preferred_element_type: DTypeLike | None): lhs, rhs = batched_args lbd, rbd = batch_dims (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers @@ -2794,7 +2794,7 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr: def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, - precision, preferred_element_type: Optional[np.dtype], + precision, preferred_element_type: np.dtype | None, platform: str = "default"): del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in @@ -4540,7 +4540,7 @@ def _array_copy(arr: ArrayLike) -> Array: return copy_p.bind(arr) -def _which_dim_sharded(s: PmapSharding) -> Optional[int]: +def _which_dim_sharded(s: PmapSharding) -> int | None: sharded_dim = None for i, s in enumerate(s.sharding_spec.sharding): if isinstance(s, pxla.Unstacked): @@ -4898,7 +4898,7 @@ def remaining(original, *removed_lists): return [i for i in original if i not in removed] -def canonicalize_precision(precision: PrecisionLike) -> Optional[tuple[PrecisionType, PrecisionType]]: +def canonicalize_precision(precision: PrecisionLike) -> tuple[PrecisionType, PrecisionType] | None: """Turns an API precision specification, into a pair of enumeration values. 
The API can take the precision as a string, or int, and either as a single diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 55a1c68f7..a5c9d1f2c 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -197,7 +197,7 @@ class Mesh(contextlib.ContextDecorator): if val is not None: return val - self = super(Mesh, cls).__new__(cls) + self = super().__new__(cls) self.devices = devices.copy() self.devices.flags.writeable = False self.axis_names = axis_names diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index fe2626701..280263eea 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -32,8 +32,7 @@ from functools import partial import math import operator import types -from typing import (overload, Any, Callable, Literal, NamedTuple, Optional, - Protocol, TypeVar, Union) +from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol, TypeVar, Union) from textwrap import dedent as _dedent import warnings @@ -2443,7 +2442,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: ``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision and then convert it to the desired lower precision. """) -def arange(start: DimSize, stop: Optional[DimSize] = None, +def arange(start: DimSize, stop: DimSize | None = None, step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "arange") if not config.dynamic_shapes.value: @@ -2480,8 +2479,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None, return lax.iota(dtype, start) else: if step is None and start == 0 and stop is not None: - stop = np.ceil(stop).astype(int) - return lax.iota(dtype, stop) + return lax.iota(dtype, np.ceil(stop).astype(int)) return array(np.arange(start, stop=stop, step=step, dtype=dtype)) @@ -2833,7 +2831,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, @util._wraps(np.tri) -def tri(N: int, M: int | None = None, k: int = 0, dtype: Optional[DTypeLike] = None) -> Array: +def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "tri") M = M if M is not None else N dtype = dtype or float32 @@ -3933,7 +3931,7 @@ def sort_complex(a: ArrayLike) -> Array: @util._wraps(np.lexsort) @partial(jit, static_argnames=('axis',)) -def lexsort(keys: Union[Array, np.ndarray, Sequence[ArrayLike]], axis: int = -1) -> Array: +def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array: key_tuple = tuple(keys) # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("lexsort", *key_tuple, emit_warning=True) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 616648120..3061de85a 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -17,7 +17,7 @@ from collections.abc import Sequence from functools import partial import re import textwrap -from typing import Any, Callable, NamedTuple, Optional, TypeVar +from typing import Any, Callable, NamedTuple, TypeVar import warnings @@ -50,14 +50,14 @@ class ParsedDoc(NamedTuple): front_matter: front matter before sections. sections: dictionary of section titles to section content. 
""" - docstr: Optional[str] + docstr: str | None signature: str = "" summary: str = "" front_matter: str = "" sections: dict[str, str] = {} -def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc: +def _parse_numpydoc(docstr: str | None) -> ParsedDoc: """Parse a standard numpy-style docstring. Args: @@ -118,13 +118,13 @@ def _parse_extra_params(extra_params: str) -> dict[str, str]: def _wraps( - fun: Optional[Callable[..., Any]], + fun: Callable[..., Any] | None, update_doc: bool = True, lax_description: str = "", sections: Sequence[str] = ('Parameters', 'Returns', 'References'), skip_params: Sequence[str] = (), - extra_params: Optional[str] = None, - module: Optional[str] = None, + extra_params: str | None = None, + module: str | None = None, ) -> Callable[[_T], _T]: """Specialized version of functools.wraps for wrapping numpy functions. diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 0413df0e5..563779aef 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -19,7 +19,8 @@ from collections.abc import Sequence import contextlib import dataclasses import functools -from typing import Any, Callable, Iterator +from typing import Any, Callable +from collections.abc import Iterator from jax._src import api_util from jax._src import core as jax_core diff --git a/jax/_src/pallas/indexing.py b/jax/_src/pallas/indexing.py index f444b0e3a..390db1cbd 100644 --- a/jax/_src/pallas/indexing.py +++ b/jax/_src/pallas/indexing.py @@ -17,7 +17,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Tuple +from typing import Any import jax from jax import core as jax_core @@ -37,7 +37,7 @@ import numpy as np # to `tl.broadcast_to`. broadcast_to_p = jax_core.Primitive('broadcast_to') -def broadcast_to(a: jax.Array, shape: Tuple[int, ...]) -> jax.Array: +def broadcast_to(a: jax.Array, shape: tuple[int, ...]) -> jax.Array: if a.shape == shape: return a return broadcast_to_p.bind(a, shape=shape) @@ -99,9 +99,9 @@ ds = dslice # Handy alias @tree_util.register_pytree_node_class @dataclasses.dataclass class NDIndexer: - indices: Tuple[int | Slice | jax.Array, ...] - shape: Tuple[int, ...] - int_indexer_shape: Tuple[int, ...] + indices: tuple[int | Slice | jax.Array, ...] + shape: tuple[int, ...] + int_indexer_shape: tuple[int, ...] 
def __post_init__(self): if len(self.indices) != len(self.shape): @@ -148,7 +148,7 @@ class NDIndexer: indices = merge_lists(is_int_indexing, other_indexers, int_indexers) return NDIndexer(tuple(indices), shape, bcast_shape) - def get_indexer_shape(self) -> Tuple[int, ...]: + def get_indexer_shape(self) -> tuple[int, ...]: is_int_indexing = [not isinstance(i, Slice) for i in self.indices] other_indexers, _ = partition_list(is_int_indexing, self.indices) other_shape = [s.size for s in other_indexers] # type: ignore diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 286fdef8b..35020d1be 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -17,7 +17,8 @@ from __future__ import annotations import dataclasses import functools -from typing import Any, Callable, Sequence +from typing import Any, Callable +from collections.abc import Sequence from jax import core as jax_core from jax import lax diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index f248063ee..9ed084e72 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -18,7 +18,8 @@ from __future__ import annotations from functools import partial import itertools as it -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import Any, Callable +from collections.abc import Sequence import jax from jax import api_util @@ -89,7 +90,7 @@ def _uninitialized_value(shape, dtype): def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, interpret, debug: bool, in_shapes, - input_output_aliases: Tuple[Tuple[int, int], ...], + input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, **compiler_params: Any): if interpret: @@ -189,7 +190,7 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_): pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, - input_output_aliases: Tuple[Tuple[int, int], ...], + input_output_aliases: tuple[tuple[int, int], ...], in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any): if grid_mapping.num_index_operands: raise NotImplementedError @@ -236,7 +237,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, return out_primals, out_tangents ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule -def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray, +def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray, dim: int | batching.NotMapped, block_mapping: BlockMapping | None) -> BlockMapping: def _block_map_function(new_idx, *args): @@ -271,13 +272,13 @@ def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray, def _pallas_call_batching_rule(args, dims, *, jaxpr: jax_core.Jaxpr, name: str, - in_shapes: Tuple[jax.ShapeDtypeStruct, ...], - out_shapes: Tuple[jax.ShapeDtypeStruct, ...], + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], grid_mapping: GridMapping, - input_output_aliases: Tuple[Tuple[int, int], ...], + input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, - which_linear: Tuple[bool, ...], + which_linear: tuple[bool, ...], **compiler_params: Any): if grid_mapping.num_index_operands: scalar_batch_dims = dims[:grid_mapping.num_index_operands] @@ -428,7 +429,7 @@ def pallas_call( in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec, out_specs: BlockSpec | 
NoBlockSpec | Sequence[BlockSpec | NoBlockSpec] = no_block_spec, - input_output_aliases: Dict[int, int] = {}, + input_output_aliases: dict[int, int] = {}, interpret: bool = False, name: str | None = None, **compiler_params: Any, diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 6b1a95d17..1931bc25d 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -18,7 +18,8 @@ from __future__ import annotations import dataclasses import functools import operator -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import Any, Callable +from collections.abc import Sequence import zlib import jax @@ -67,7 +68,7 @@ import triton.language as tl map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip partial = functools.partial -Grid = Tuple[int, ...] +Grid = tuple[int, ...] NDIndexer = indexing.NDIndexer GridMapping = pallas_core.GridMapping BlockMapping = pallas_core.BlockMapping @@ -88,7 +89,7 @@ class TritonModuleContext: class BlockInfo: full_shape_dtype: jax.ShapeDtypeStruct start_indices: Sequence[Any] - block_shape: Tuple[int, ...] + block_shape: tuple[int, ...] @dataclasses.dataclass @@ -111,7 +112,7 @@ class TritonLoweringResult: ir_context: tl_ir.context module: tl_ir.module builder: tl_ir.builder - grid: Tuple[int, ...] + grid: tuple[int, ...] @dataclasses.dataclass @@ -610,7 +611,7 @@ def _compute_pointers_from_indices( root_ptr: tl.core.tensor, block_info: BlockInfo | None, nd_indexer: NDIndexer, - array_shape: Tuple[int, ...], + array_shape: tuple[int, ...], builder: tl_ir.builder, ) -> tl.core.tensor: if block_info is None: @@ -1512,14 +1513,14 @@ def pallas_call_lowering( *in_nodes, jaxpr: jax_core.Jaxpr, name: str, - in_shapes: Tuple[jax.ShapeDtypeStruct, ...], - out_shapes: Tuple[jax.ShapeDtypeStruct, ...], - which_linear: Tuple[bool, ...], + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], + which_linear: tuple[bool, ...], interpret: bool, debug: bool, - input_output_aliases: Tuple[Tuple[int, int], ...], + input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, - triton_params: Dict[str, Any] | None = None, + triton_params: dict[str, Any] | None = None, **compiler_params: Any, ): if interpret: diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 54e2fb395..a14f3871b 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -15,7 +15,6 @@ """Pallas utility functions.""" import math import numpy as np -from typing import Tuple from jax import lax from jax._src import core as jax_core @@ -36,7 +35,7 @@ def cdiv(a: int, b: int) -> int: return (a + b - 1) // b -def strides_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]: +def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]: size = np.prod(shape) strides = [] for s in shape: diff --git a/jax/_src/random.py b/jax/_src/random.py index 4853390c3..fa23fecb2 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -19,7 +19,8 @@ from functools import partial import math from operator import index import typing -from typing import Hashable, Optional, Union +from typing import Union +from collections.abc import Hashable import warnings import numpy as np @@ -150,7 +151,7 @@ class PRNGSpec: PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl] -def resolve_prng_impl(impl_spec: Optional[PRNGSpecDesc]) -> PRNGImpl: +def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl: if impl_spec is None: return 
default_prng_impl() if type(impl_spec) is PRNGImpl: @@ -174,7 +175,7 @@ def resolve_prng_impl(impl_spec: Optional[PRNGSpecDesc]) -> PRNGImpl: def _key(ctor_name: str, seed: int | ArrayLike, - impl_spec: Optional[PRNGSpecDesc]) -> KeyArray: + impl_spec: PRNGSpecDesc | None) -> KeyArray: impl = resolve_prng_impl(impl_spec) if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key): raise TypeError( @@ -186,7 +187,7 @@ def _key(ctor_name: str, seed: int | ArrayLike, return prng.random_seed(seed, impl=impl) def key(seed: int | ArrayLike, *, - impl: Optional[PRNGSpecDesc] = None) -> KeyArray: + impl: PRNGSpecDesc | None = None) -> KeyArray: """Create a pseudo-random number generator (PRNG) key given an integer seed. The result is a scalar array with a key that indicates the default PRNG @@ -205,7 +206,7 @@ def key(seed: int | ArrayLike, *, return _key('key', seed, impl) def PRNGKey(seed: int | ArrayLike, *, - impl: Optional[PRNGSpecDesc] = None) -> KeyArray: + impl: PRNGSpecDesc | None = None) -> KeyArray: """Create a pseudo-random number generator (PRNG) key given an integer seed. The resulting key carries the default PRNG implementation, as @@ -277,7 +278,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: return _return_prng_keys(wrapped, key_out) -def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: +def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: # Alternative to split() to use within random samplers. # TODO(frostig): remove and use split(); we no longer need to wait # to always enable_custom_prng @@ -288,7 +289,7 @@ def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: shape = tuple(num) if isinstance(num, Sequence) else (num,) return prng.random_split(key, shape=shape) -def split(key: KeyArrayLike, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: +def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: """Splits a PRNG key into `num` new keys by adding a leading axis. Args: @@ -324,7 +325,7 @@ def key_data(keys: KeyArrayLike) -> Array: def wrap_key_data(key_bits_array: Array, *, - impl: Optional[PRNGSpecDesc] = None): + impl: PRNGSpecDesc | None = None): """Wrap an array of key data bits into a PRNG key array. Args: @@ -344,7 +345,7 @@ def wrap_key_data(key_bits_array: Array, *, ### random samplers -def _check_shape(name: str, shape: Union[Shape, NamedShape], *param_shapes) -> None: +def _check_shape(name: str, shape: Shape | NamedShape, *param_shapes) -> None: shape = core.as_named_shape(shape) if param_shapes: @@ -358,7 +359,7 @@ def _check_shape(name: str, shape: Union[Shape, NamedShape], *param_shapes) -> N def bits(key: KeyArrayLike, shape: Shape = (), - dtype: Optional[DTypeLikeUInt] = None) -> Array: + dtype: DTypeLikeUInt | None = None) -> Array: """Sample uniform bits in the form of unsigned integers. Args: @@ -386,7 +387,7 @@ def bits(key: KeyArrayLike, def uniform(key: KeyArrayLike, - shape: Union[Shape, NamedShape] = (), + shape: Shape | NamedShape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., maxval: RealArray = 1.) -> Array: @@ -561,7 +562,7 @@ def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array: def permutation(key: KeyArrayLike, - x: Union[int, ArrayLike], + x: int | ArrayLike, axis: int = 0, independent: bool = False) -> Array: """Returns a randomly permuted array or range. 
@@ -620,10 +621,10 @@ def _shuffle(key, x, axis) -> Array: def choice(key: KeyArrayLike, - a: Union[int, ArrayLike], + a: int | ArrayLike, shape: Shape = (), replace: bool = True, - p: Optional[RealArray] = None, + p: RealArray | None = None, axis: int = 0) -> Array: """Generates a random sample from a given array. @@ -697,7 +698,7 @@ def choice(key: KeyArrayLike, def normal(key: KeyArrayLike, - shape: Union[Shape, NamedShape] = (), + shape: Shape | NamedShape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample standard normal random values with given shape and float dtype. @@ -752,8 +753,8 @@ def _normal_real(key, shape, dtype) -> Array: def multivariate_normal(key: KeyArrayLike, mean: RealArray, cov: RealArray, - shape: Optional[Shape] = None, - dtype: Optional[DTypeLikeFloat] = None, + shape: Shape | None = None, + dtype: DTypeLikeFloat | None = None, method: str = 'cholesky') -> Array: r"""Sample multivariate normal random values with given mean and covariance. @@ -835,7 +836,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: def truncated_normal(key: KeyArrayLike, lower: RealArray, upper: RealArray, - shape: Optional[Union[Shape, NamedShape]] = None, + shape: Shape | NamedShape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample truncated standard normal random values with given shape and dtype. @@ -900,7 +901,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: def bernoulli(key: KeyArrayLike, p: RealArray = np.float32(0.5), - shape: Optional[Union[Shape, NamedShape]] = None) -> Array: + shape: Shape | NamedShape | None = None) -> Array: r"""Sample Bernoulli random values with given shape and mean. The values are distributed according to the probability mass function: @@ -946,7 +947,7 @@ def _bernoulli(key, p, shape) -> Array: def beta(key: KeyArrayLike, a: RealArray, b: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Beta random values with given shape and float dtype. @@ -1045,7 +1046,7 @@ def _cauchy(key, shape, dtype) -> Array: def dirichlet(key: KeyArrayLike, alpha: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Dirichlet random values with given shape and float dtype. @@ -1284,7 +1285,7 @@ batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule def gamma(key: KeyArrayLike, a: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Gamma random values with given shape and float dtype. @@ -1331,7 +1332,7 @@ def gamma(key: KeyArrayLike, def loggamma(key: KeyArrayLike, a: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: """Sample log-gamma random values with given shape and float dtype. @@ -1473,7 +1474,7 @@ def _poisson(key, lam, shape, dtype) -> Array: def poisson(key: KeyArrayLike, lam: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeInt = int) -> Array: r"""Sample Poisson random values with given shape and integer dtype. @@ -1555,7 +1556,7 @@ def _gumbel(key, shape, dtype) -> Array: def categorical(key: KeyArrayLike, logits: RealArray, axis: int = -1, - shape: Optional[Shape] = None) -> Array: + shape: Shape | None = None) -> Array: """Sample random values from categorical distributions. 
Args: @@ -1669,7 +1670,7 @@ def _logistic(key, shape, dtype): def pareto(key: KeyArrayLike, b: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Pareto random values with given shape and float dtype. @@ -1770,7 +1771,7 @@ def _t(key, df, shape, dtype) -> Array: def chisquare(key: KeyArrayLike, df: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Chisquare random values with given shape and float dtype. @@ -1823,7 +1824,7 @@ def _chisquare(key, df, shape, dtype) -> Array: def f(key: KeyArrayLike, dfnum: RealArray, dfden: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample F-distribution random values with given shape and float dtype. @@ -2153,7 +2154,7 @@ def ball( def rayleigh(key: KeyArrayLike, scale: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Rayleigh random values with given shape and float dtype. @@ -2206,7 +2207,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array: def wald(key: KeyArrayLike, mean: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Wald random values with given shape and float dtype. @@ -2264,7 +2265,7 @@ def _wald(key, mean, shape, dtype) -> Array: def geometric(key: KeyArrayLike, p: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeInt = int) -> Array: r"""Sample Geometric random values with given shape and float dtype. @@ -2319,7 +2320,7 @@ def triangular(key: KeyArrayLike, left: RealArray, mode: RealArray, right: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Triangular random values with given shape and float dtype. @@ -2382,7 +2383,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array: def lognormal(key: KeyArrayLike, sigma: RealArray = np.float32(1), - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r""" Sample lognormal random values with given shape and float dtype. @@ -2588,7 +2589,7 @@ def binomial( key: KeyArray, n: RealArray, p: RealArray, - shape: Optional[Shape] = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float, ) -> Array: r"""Sample Binomial random values with given shape and float dtype. 
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 6e7eb49ce..16eab2d4b 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -22,7 +22,7 @@ import enum import functools import itertools import math -from typing import Any, NamedTuple, Union, cast, Optional +from typing import Any, NamedTuple, Union, cast from jax._src import mesh as mesh_lib from jax._src.op_shardings import ( @@ -566,9 +566,9 @@ class PmapSharding(XLACompatibleSharding): def _op_sharding_to_pos_sharding( - op_sharding: Union[xc.OpSharding, xc.HloSharding], + op_sharding: xc.OpSharding | xc.HloSharding, device_assignment: Sequence[xc.Device], - memory_kind: Optional[str] = None) -> PositionalSharding: + memory_kind: str | None = None) -> PositionalSharding: if isinstance(op_sharding, xc.OpSharding): op_sharding = xc.HloSharding.from_proto(op_sharding) # type: ignore diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index d1338fec0..9befc8a36 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -24,7 +24,7 @@ import re import os import tempfile import textwrap -from typing import Any, Callable, Optional +from typing import Any, Callable import unittest import warnings import zlib @@ -919,7 +919,7 @@ class JaxTestCase(parameterized.TestCase): 'jax_legacy_prng_key': 'error', } - _compilation_cache_exit_stack: Optional[ExitStack] = None + _compilation_cache_exit_stack: ExitStack | None = None # TODO(mattjj): this obscures the error messages from failures, figure out how # to re-enable it @@ -1275,8 +1275,8 @@ def numpy_version(): def parameterized_filterable(*, kwargs: Sequence[dict[str, Any]], - testcase_name: Optional[Callable[[dict[str, Any]], str]] = None, - one_containing: Optional[str] = None, + testcase_name: Callable[[dict[str, Any]], str] | None = None, + one_containing: str | None = None, ): """ Decorator for named parameterized tests, with filtering. diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index d149ee02d..210c819a4 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -21,7 +21,7 @@ import functools from functools import partial import operator as op import textwrap -from typing import Any, Callable, NamedTuple, Type, TypeVar, Union, overload +from typing import Any, Callable, NamedTuple, TypeVar, Union, overload from jax._src import traceback_util from jax._src.lib import pytree @@ -737,7 +737,7 @@ def register_pytree_with_keys_class(cls: U) -> U: return cls -def register_static(cls: Type[H]) -> Type[H]: +def register_static(cls: type[H]) -> type[H]: """Registers `cls` as a pytree with no leaves. Instances are treated as static by `jax.jit`, `jax.pmap`, etc. This can be an diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index dc9db3459..64da2b020 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -32,7 +32,7 @@ import pkgutil import platform as py_platform import sys import threading -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import warnings from jax._src import config @@ -855,7 +855,7 @@ def default_backend() -> str: return get_backend(None).platform -def backend_pjrt_c_api_version(platform=None) -> Optional[Tuple[int, int]]: +def backend_pjrt_c_api_version(platform=None) -> Optional[tuple[int, int]]: """Returns the PJRT C API version of the backend. Returns None if the backend does not use PJRT C API. 
diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py index 5eedd6a15..567094290 100644 --- a/jax/experimental/array_api/_array_methods.py +++ b/jax/experimental/array_api/_array_methods.py @@ -31,7 +31,7 @@ def _array_namespace(self, /, *, api_version: None | str = None): def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *, - stream: Optional[Union[int, Any]] = None): + stream: int | Any | None = None): if stream is not None: raise NotImplementedError("stream argument of array.to_device()") # The type of device is defined by Array.device. In JAX, this is a callable that diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index 411476f22..d405b846c 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -20,17 +20,17 @@ from jax import Array from jax.experimental.array_api._data_type_functions import result_type as _result_type -def broadcast_arrays(*arrays: Array) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> list[Array]: """Broadcasts one or more arrays against one another.""" return jax.numpy.broadcast_arrays(*arrays) -def broadcast_to(x: Array, /, shape: Tuple[int]) -> Array: +def broadcast_to(x: Array, /, shape: tuple[int]) -> Array: """Broadcasts an array to a specified shape.""" return jax.numpy.broadcast_to(x, shape=shape) -def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array: +def concat(arrays: Union[tuple[Array, ...], list[Array]], /, *, axis: Optional[int] = 0) -> Array: """Joins a sequence of arrays along an existing axis.""" dtype = _result_type(*arrays) if axis is None: @@ -46,34 +46,34 @@ def expand_dims(x: Array, /, *, axis: int = 0) -> Array: return jax.numpy.expand_dims(x, axis=axis) -def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: +def flip(x: Array, /, *, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array: """Reverses the order of elements in an array along the given axis.""" return jax.numpy.flip(x, axis=axis) -def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: """Permutes the axes (dimensions) of an array x.""" return jax.lax.transpose(x, axes) -def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array: +def reshape(x: Array, /, shape: tuple[int, ...], *, copy: Optional[bool] = None) -> Array: """Reshapes an array without changing its data.""" del copy # unused return jax.numpy.reshape(x, shape) -def roll(x: Array, /, shift: Union[int, Tuple[int]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: +def roll(x: Array, /, shift: Union[int, tuple[int]], *, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array: """Rolls array elements along a specified axis.""" return jax.numpy.roll(x, shift=shift, axis=axis) -def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: +def squeeze(x: Array, /, axis: Union[int, tuple[int, ...]]) -> Array: """Removes singleton dimensions (axes) from x.""" dimensions = axis if isinstance(axis, tuple) else (axis,) return jax.lax.squeeze(x, dimensions=dimensions) -def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: +def stack(arrays: Union[tuple[Array, ...], list[Array]], /, *, axis: int = 0) -> Array: """Joins a sequence of arrays along a new axis.""" 
dtype = _result_type(*arrays) return jax.numpy.stack(arrays, axis=axis, dtype=dtype) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 5dfad2df0..616a5e954 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2163,7 +2163,7 @@ tf_impl_with_avals[lax.dot_general_p] = _dot_general def _dot_general_convert_to_common_dtype( lhs: TfVal, lhs_aval: core.ShapedArray, rhs: TfVal, rhs_aval: core.ShapedArray, - out_aval: core.ShapedArray) -> Tuple[TfVal, TfVal, Callable[[TfVal], TfVal]]: + out_aval: core.ShapedArray) -> tuple[TfVal, TfVal, Callable[[TfVal], TfVal]]: # Returns the converted lhs, rhs, and the converter for the result. # tfxla.dot_general does not handle arguments of different types. # We convert the arguments and the result. diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 8e07262a2..b729c1193 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -76,7 +76,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): if all(tf_device.device_type != d.device_type for d in cls.tf_devices): cls.tf_devices.append(tf_device) - super(CallTfTest, cls).setUpClass() + super().setUpClass() def setUp(self): if tf is None: diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index a4adadf0f..def5af73b 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -69,7 +69,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): if all(tf_device.device_type != d.device_type for d in cls.tf_devices): cls.tf_devices.append(tf_device) - super(Jax2TfTest, cls).setUpClass() + super().setUpClass() def test_empty(self): f_jax = lambda x, y: x diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 02e6e622c..f6b5a773a 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -173,7 +173,7 @@ def mha( block_q: int = 128, block_k: int = 128, backward_pass_impl: str = "triton", - num_warps: Optional[int] = None, + num_warps: int | None = None, num_stages: int = 2, grid: tuple[int, ...] 
| None = None, interpret: bool = False, @@ -239,7 +239,7 @@ def _mha_forward( block_q: int, block_k: int, backward_pass_impl: str, - num_warps: Optional[int], + num_warps: int | None, num_stages: int, grid: Any, interpret: bool, @@ -451,7 +451,7 @@ def mha_backward_kernel( def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, - backward_pass_impl: str, num_warps: Optional[int], + backward_pass_impl: str, num_warps: int | None, num_stages: int, grid: Any, interpret: bool, debug: bool, res, do): del num_warps, num_stages, grid diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index 828f8d7c1..979513d0f 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -25,7 +25,7 @@ chunk, splits it in two, and sends each of the half-chunks in each direction from __future__ import annotations import functools -from typing import Sequence +from collections.abc import Sequence import jax from jax import lax diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index ecf6669c4..22aa87771 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -23,7 +23,7 @@ import shutil import sys import subprocess import glob -from typing import Sequence +from collections.abc import Sequence def is_windows() -> bool: @@ -88,7 +88,7 @@ def build_editable( def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str): src_file = file_dir / "setup.py" - with open(src_file, "r") as f: + with open(src_file) as f: content = f.read() content = content.replace( "cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}" diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 6c81a7c3e..85083eb8d 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -72,7 +72,7 @@ class PrimitiveTest(jtu.JaxTestCase): if d.platform not in cls.platforms: cls.platforms.append(d.platform) cls.devices.append(d) - super(PrimitiveTest, cls).setUpClass() + super().setUpClass() # For each primitive we export for all platforms that are available and # compare the results of running the exported code and running the native diff --git a/tests/export_test.py b/tests/export_test.py index 65904dcdf..c7e5b618d 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -18,7 +18,6 @@ import functools import logging import math import re -from typing import Optional import unittest from absl.testing import absltest @@ -98,7 +97,7 @@ mlir.register_lowering(testing_primitive_with_effect_p, _testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.) 
def _testing_multi_platform_func(x, *, - effect_class_name: Optional[str] = None): + effect_class_name: str | None = None): # Behaves like x + 2 * _testing_multi_platform_to_add[platform] def for_platform(platform: str): if effect_class_name is None: @@ -142,7 +141,7 @@ class JaxExportTest(jtu.JaxTestCase): except RuntimeError: continue cls.platforms.append(backend) - super(JaxExportTest, cls).setUpClass() + super().setUpClass() def setUp(self): super().setUp() diff --git a/tests/memories_test.py b/tests/memories_test.py index 2fb14d0d8..cfe7c95e6 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -266,7 +266,7 @@ class MemoriesComputationTest(jtu.BufferDonationTestCase): expected_out = np_inp * 2 self.assertArraysEqual(out, expected_out) - self.assertEqual(out.sharding, s.with_memory_kind(("unpinned_host"))) + self.assertEqual(out.sharding, s.with_memory_kind("unpinned_host")) self._check_mem_kind(executable_mk[0], out.sharding, "unpinned_host") for s in out.addressable_shards: self.assertArraysEqual(s.data, expected_out[s.index]) diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 48504ed71..a4b1fda5c 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -65,7 +65,7 @@ def mock_tpu_devices(x, y, z, dev_kind, one_device_per_chip, num_slices=1, # If this function raises, it's a bug in the test code! def _validate_mocked_process_indices(devices, one_device_per_chip): - process_to_devices = collections.defaultdict(lambda: []) + process_to_devices = collections.defaultdict(list) for d in devices: process_to_devices[d.process_index].append(d)
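
Editor's note (not part of the commit): for readers skimming the hunks above, the following is a minimal before/after sketch of the rewrite patterns that pyupgrade --py39-plus applies in this patch. The module, class, and function names (configure, count, Registry, read_lines) are hypothetical and do not appear in the JAX sources; the sketch only assumes, as the touched files do, that `from __future__ import annotations` (PEP 563) makes the `X | None` spelling legal in annotations on Python 3.9, while builtin generics like `tuple[int, ...]` are native 3.9 syntax (PEP 585).

# Sketch of the pyupgrade --py39-plus rewrites seen in the patch above.
# Comments show the pre-rewrite form; the code shows the post-rewrite form.
from __future__ import annotations

import collections
from collections.abc import Iterator, Sequence  # was: from typing import Iterator, Sequence
from typing import Any, Callable                 # Optional/Union/Tuple/Dict no longer needed


# Before: def configure(name: str, default: Optional[int] = None,
#                       hook: Optional[Callable[[Any], None]] = None) -> Tuple[str, int]:
def configure(name: str, default: int | None = None,
              hook: Callable[[Any], None] | None = None) -> tuple[str, int]:
    value = 0 if default is None else default
    if hook is not None:
        hook(value)
    return name, value


# Before: d = collections.defaultdict(lambda: 0)
def count(items: Sequence[str]) -> dict[str, int]:
    d: collections.defaultdict[str, int] = collections.defaultdict(int)
    for item in items:
        d[item] += 1
    return dict(d)


class Registry:
    def __new__(cls) -> Registry:
        # Before: self = super(Registry, cls).__new__(cls)
        return super().__new__(cls)


def read_lines(path: str) -> Iterator[str]:
    # Before: with open(path, "r") as f:
    with open(path) as f:
        yield from f

Running the same command as in the commit message (pyupgrade --py39-plus --keep-percent-format <files>) over a file containing the "before" forms would produce the "after" forms shown here; --keep-percent-format only suppresses the %-format to str.format rewrite and does not affect the typing changes above.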