Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Upgrade most .py sources to 3.9
This commit was generated by running pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
parent 7af1c149f5
commit 36f6b52e42
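For reference, a minimal illustrative sketch of the kind of rewrite pyupgrade --py39-plus performs (the file below is hypothetical, not one touched by this commit); the X | None union spelling is applied to annotations here because the affected modules use from __future__ import annotations (pyupgrade otherwise reserves it for 3.10+ targets):

# Illustrative example only -- not a file from this commit.
# Before:
#   from typing import Dict, Optional, Tuple
#   def lookup(table: Dict[str, Tuple[int, ...]],
#              key: Optional[str] = None) -> Optional[Tuple[int, ...]]:
#       return table.get(key) if key is not None else None
#
# After pyupgrade --py39-plus (PEP 585 builtin generics, PEP 604 unions):

from __future__ import annotations  # lets Python 3.9 accept X | None in annotations


def lookup(table: dict[str, tuple[int, ...]],
           key: str | None = None) -> tuple[int, ...] | None:
    # Same runtime behavior; only the annotation syntax changes.
    return table.get(key) if key is not None else None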
@@ -29,7 +29,7 @@ import inspect
import math
import typing
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
-cast, Optional)
+cast)
import weakref

import numpy as np

@@ -2461,7 +2461,7 @@ def make_jaxpr(fun: Callable,
make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
return make_jaxpr_f

-def _infer_src_sharding(src, x) -> Optional[Sharding]:
+def _infer_src_sharding(src, x) -> Sharding | None:
if src is not None:
return src
if isinstance(x, array.ArrayImpl):

@@ -16,7 +16,8 @@

import abc
import numpy as np
-from typing import Any, Sequence, Union
+from typing import Any, Union
+from collections.abc import Sequence

# TODO(jakevdp): fix import cycles and define these.
Shard = Any

@@ -22,7 +22,7 @@ import logging
import os
import sys
import threading
-from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar
+from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar
import warnings

from jax._src import lib

@@ -134,7 +134,7 @@ class Config:
raise AttributeError(f"Unrecognized config option: {name}")

def add_option(self, name, default, opt_type, meta_args, meta_kwargs,
-update_hook: Optional[Callable[[Any], None]] = None):
+update_hook: Callable[[Any], None] | None = None):
if name in self.values:
raise Exception(f"Config option {name} already defined")
self.values[name] = default

@@ -238,7 +238,7 @@ _thread_local_state = threading.local()

class _StateContextManager(Generic[_T]):
def __init__(self, name, help, update_thread_local_hook,
-validate_new_val_hook: Optional[Callable[[Any], None]] = None,
+validate_new_val_hook: Callable[[Any], None] | None = None,
extra_description: str = "", default_value: Any = no_default):
self._name = name
self.__name__ = name[4:] if name.startswith('jax_') else name
@@ -302,8 +302,8 @@ def define_bool_state(
default: bool,
help: str,
*,
-update_global_hook: Optional[Callable[[bool], None]] = None,
-update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None,
+update_global_hook: Callable[[bool], None] | None = None,
+update_thread_local_hook: Callable[[bool | None], None] | None = None,
upgrade: bool = False,
extra_description: str = '',
) -> _StateContextManager[bool]:

@@ -375,11 +375,11 @@ def define_bool_state(
def define_enum_state(
name: str,
enum_values: list[str],
-default: Optional[str],
+default: str | None,
help: str,
*,
-update_global_hook: Optional[Callable[[str], None]] = None,
-update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
+update_global_hook: Callable[[str], None] | None = None,
+update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
"""Set up thread-local state and return a contextmanager for managing it.

@@ -420,11 +420,11 @@ def define_enum_state(

def define_int_state(
name: str,
-default: Optional[int],
+default: int | None,
help: str,
*,
-update_global_hook: Optional[Callable[[str], None]] = None,
-update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
+update_global_hook: Callable[[str], None] | None = None,
+update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[int]:
"""Set up thread-local state and return a contextmanager for managing it.

@@ -463,11 +463,11 @@ def define_int_state(

def define_float_state(
name: str,
-default: Optional[float],
+default: float | None,
help: str,
*,
-update_global_hook: Optional[Callable[[str], None]] = None,
-update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
+update_global_hook: Callable[[str], None] | None = None,
+update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[float]:
"""Set up thread-local state and return a contextmanager for managing it.

@@ -508,11 +508,11 @@ def define_float_state(

def define_string_state(
name: str,
-default: Optional[str],
+default: str | None,
help: str,
*,
-update_global_hook: Optional[Callable[[str], None]] = None,
-update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
+update_global_hook: Callable[[str], None] | None = None,
+update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
"""Set up thread-local state and return a contextmanager for managing it.

@@ -552,9 +552,9 @@ def define_string_or_object_state(
default: Any,
help: str,
*,
-update_global_hook: Optional[Callable[[Any], None]] = None,
-update_thread_local_hook: Optional[Callable[[Any], None]] = None,
-validate_new_val_hook: Optional[Callable[[Any], None]] = None,
+update_global_hook: Callable[[Any], None] | None = None,
+update_thread_local_hook: Callable[[Any], None] | None = None,
+validate_new_val_hook: Callable[[Any], None] | None = None,
) -> _StateContextManager[Any]:
"""Set up thread-local state and return a contextmanager for managing it.
@@ -651,9 +651,9 @@ already_configured_with_absl = False
# a global/thread-local state. These methods allow updates to part of the
# state when a configuration value changes.
class _GlobalExtraJitContext(NamedTuple):
-numpy_rank_promotion: Optional[str] = None
-numpy_dtype_promotion: Optional[str] = None
-default_matmul_precision: Optional[Any] = None
+numpy_rank_promotion: str | None = None
+numpy_dtype_promotion: str | None = None
+default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False

@@ -675,12 +675,12 @@ class _ThreadLocalExtraJitContext(NamedTuple):
The initialization, which uses both config.py and core.py is done using
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
-dynamic_trace_state: Optional[Any] = None
+dynamic_trace_state: Any | None = None
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
-numpy_rank_promotion: Optional[str] = None
-numpy_dtype_promotion: Optional[str] = None
-default_matmul_precision: Optional[Any] = None
+numpy_rank_promotion: str | None = None
+numpy_dtype_promotion: str | None = None
+default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False

@@ -1320,7 +1320,7 @@ def transfer_guard(new_val: str) -> Iterator[None]:
yield


-def _update_debug_log_modules(module_names_str: Optional[str]):
+def _update_debug_log_modules(module_names_str: str | None):
logging_config.disable_all_debug_logging()
if not module_names_str:
return

@@ -24,7 +24,7 @@ from __future__ import annotations
import abc
import builtins
import functools
-from typing import cast, overload, Any, Literal, Optional, Union
+from typing import cast, overload, Any, Literal, Union
import warnings

import ml_dtypes
@@ -207,7 +207,7 @@ def to_complex_dtype(dtype: DTypeLike) -> DType:


@functools.cache
-def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> Union[DType, ExtendedDType]:
+def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> DType | ExtendedDType:
if issubdtype(dtype, extended):
if not allow_extended_dtype:
raise ValueError(f"Internal: canonicalize_dtype called on extended dtype {dtype} "

@@ -227,10 +227,10 @@ def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: An
def canonicalize_dtype(dtype: Any, allow_extended_dtype: Literal[False] = False) -> DType: ...

@overload
-def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> Union[DType, ExtendedDType]: ...
+def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType: ...

@export
-def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> Union[DType, ExtendedDType]:
+def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType | ExtendedDType:
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype) # pytype: disable=bad-return-type

@@ -292,7 +292,7 @@ def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType:
return dtype


-def coerce_to_array(x: Any, dtype: Optional[DTypeLike] = None) -> np.ndarray:
+def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray:
"""Coerces a scalar or NumPy array to an np.array.

Handles Python scalar type promotion according to JAX's rules, not NumPy's

@@ -643,10 +643,10 @@ def result_type(*args: Any, return_weak_type_flag: Literal[True]) -> tuple[DType
def result_type(*args: Any, return_weak_type_flag: Literal[False] = False) -> DType: ...

@overload
-def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]: ...
+def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]: ...

@export
-def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, tuple[DType, bool]]:
+def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tuple[DType, bool]:
"""Convenience function to apply JAX argument dtype promotion.

Args:

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Callable, Hashable
+from typing import Callable
+from collections.abc import Hashable

from jax import Array

@@ -25,7 +25,7 @@ import itertools
import operator
import re
import typing
-from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
+from typing import Any, Callable, NamedTuple, Protocol, Union
import warnings

import numpy as np
@@ -707,7 +707,7 @@ def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str |
return layout._to_xla_layout()


-def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
+def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None:
if s is None:
return None
assert isinstance(s, sharding_impls.XLACompatibleSharding)

@@ -1454,7 +1454,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_override_lowering_rule(eqn.primitive)
platform_rules: dict[str, LoweringRule] = {}
-default_rule: Optional[LoweringRule] = None
+default_rule: LoweringRule | None = None
# See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule`
if override_rule is not None:
default_rule = override_rule

@@ -1525,7 +1525,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
def lower_per_platform(ctx: LoweringRuleContext,
description: str,
platform_rules: dict[str, LoweringRule],
-default_rule: Optional[LoweringRule],
+default_rule: LoweringRule | None,
effects: effects_lib.Effects,
*rule_args: ir.Value,
**rule_kwargs) -> ir.Value:

@@ -1710,7 +1710,7 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
return func_op


-def check_backend_matches(inner_backend: Optional[str],
+def check_backend_matches(inner_backend: str | None,
lowering_platforms: Sequence[str]):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.

@@ -25,7 +25,8 @@ import itertools as it
import logging
import math
import threading
-from typing import (Any, Callable, NamedTuple, Iterator, Optional, Union, cast, TypeVar)
+from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
+from collections.abc import Iterator
import warnings

import numpy as np

@@ -2914,8 +2915,8 @@ def _compile_replicated_mesh_executable_from_hlo(

@lru_cache
def create_mesh_pspec_sharding(
-mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,
-memory_kind: Optional[str] = None) -> sharding_impls.NamedSharding:
+mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
+memory_kind: str | None = None) -> sharding_impls.NamedSharding:
if pspec is None:
pspec, parsed_pspec = PartitionSpec(), None
return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,

@@ -202,7 +202,7 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes:
pprof tool for visualization.
"""
d: DefaultDict[tuple[Optional[xla_client.Traceback], core.Primitive], int]
-d = collections.defaultdict(lambda: 0)
+d = collections.defaultdict(int)
for _, eqn in all_eqns(jaxpr):
d[(eqn.source_info.traceback, eqn.primitive)] += 1
return _pprof_profile(d)
@@ -19,7 +19,7 @@ from functools import partial
import inspect
import itertools
import operator
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Callable, TypeVar

import jax
import weakref

@@ -104,7 +104,7 @@ Y = TypeVar('Y')
def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
-length: Optional[int] = None,
+length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1) -> tuple[Carry, Y]:
"""Scan a function over leading array axes while carrying along state.

@@ -22,7 +22,7 @@ from functools import partial
import itertools
import math
import operator
-from typing import (Any, Callable, Optional, TypeVar, Union,
+from typing import (Any, Callable, TypeVar, Union,
cast as type_cast, overload)
import warnings

@@ -104,7 +104,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
map(_check_static_shape, shapes)

def _try_broadcast_shapes(
-shapes: Sequence[tuple[int, ...]]) -> Optional[tuple[int, ...]]:
+shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...] | None:
if len(shapes) == 1: return shapes[0]
ranks = {len(shape) for shape in shapes}
if len(ranks) > 1: return None # must have consistent rank

@@ -140,8 +140,8 @@ def asarray(x: ArrayLike) -> Array:
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...

@overload
-def broadcast_shapes(*shapes: tuple[Union[int, core.Tracer], ...]
-) -> tuple[Union[int, core.Tracer], ...]: ...
+def broadcast_shapes(*shapes: tuple[int | core.Tracer, ...]
+) -> tuple[int | core.Tracer, ...]: ...

def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""

@@ -183,8 +183,8 @@ def _broadcast_ranks(s1, s2):
def _identity(x): return x

def _extract_tracers_dyn_shape(
-shape: Sequence[Union[int, core.Tracer]]
-) -> tuple[list[core.Tracer], list[Optional[int]]]:
+shape: Sequence[int | core.Tracer]
+) -> tuple[list[core.Tracer], list[int | None]]:
# Given a sequence representing a shape, pull out Tracers, replacing with None
if config.dynamic_shapes.value:
# We must gate this behavior under a flag because otherwise the errors

@@ -196,9 +196,9 @@ def _extract_tracers_dyn_shape(
return [], list(shape) # type: ignore

def _merge_dyn_shape(
-static_shape: Sequence[Optional[int]],
+static_shape: Sequence[int | None],
dyn_shape: Sequence[Any],
-) -> tuple[Union[int, mlir.Value, core.Tracer], ...]:
+) -> tuple[int | mlir.Value | core.Tracer, ...]:
# Replace Nones in static_shape with elements of dyn_shape, in order
dyn_shape_it = iter(dyn_shape)
shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
@@ -516,7 +516,7 @@ def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""
return _convert_element_type(operand, new_dtype, weak_type=False)

-def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = None,
+def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None,
weak_type: bool = False):
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__() # type: ignore

@@ -599,7 +599,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
"""
return clamp_p.bind(min, x, max)

-def concatenate(operands: Union[Array, Sequence[ArrayLike]], dimension: int) -> Array:
+def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
"""Concatenates a sequence of arrays along `dimension`.

Wraps XLA's `Concatenate

@@ -677,7 +677,7 @@ PrecisionLike = Union[None, str, PrecisionType, tuple[str, str],
tuple[PrecisionType, PrecisionType]]

def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
-preferred_element_type: Optional[DTypeLike] = None) -> Array:
+preferred_element_type: DTypeLike | None = None) -> Array:
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.

Wraps XLA's `Dot

@@ -714,7 +714,7 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],

def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
-preferred_element_type: Optional[DTypeLike] = None) -> Array:
+preferred_element_type: DTypeLike | None = None) -> Array:
"""General dot product/contraction operator.

Wraps XLA's `DotGeneral

@@ -813,7 +813,7 @@ def broadcast_to_rank(x: Array, rank: int) -> Array:
return broadcast(x, (1,) * (rank - x.ndim))

def reshape(operand: ArrayLike, new_sizes: Shape,
-dimensions: Optional[Sequence[int]] = None) -> Array:
+dimensions: Sequence[int] | None = None) -> Array:
"""Wraps XLA's `Reshape
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
operator.

@@ -1042,7 +1042,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
return jaxpr, tuple(consts), out_tree()

def _get_monoid_reducer(monoid_op: Callable,
-xs: Sequence[Array]) -> Optional[Callable]:
+xs: Sequence[Array]) -> Callable | None:
if len(xs) != 1:
return None
x, = xs

@@ -1128,8 +1128,8 @@ def sort(operand: Array, dimension: int = -1,
def sort(operand: Sequence[Array], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> tuple[Array, ...]: ...

-def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
-is_stable: bool = True, num_keys: int = 1) -> Union[Array, tuple[Array, ...]]:
+def sort(operand: Array | Sequence[Array], dimension: int = -1,
+is_stable: bool = True, num_keys: int = 1) -> Array | tuple[Array, ...]:
"""Wraps XLA's `Sort
<https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.

@@ -1196,7 +1196,7 @@ def tie_in(x: Any, y: T) -> T:
"""Deprecated. Ignores ``x`` and returns ``y``."""
return y

-def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None) -> Array:
+def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None) -> Array:
"""Returns an array of `shape` filled with `fill_value`.

Args:

@@ -1303,7 +1303,7 @@ def stop_gradient(x: T) -> T:
return x
return tree_map(stop, x)

-def reduce_precision(operand: Union[float, ArrayLike],
+def reduce_precision(operand: float | ArrayLike,
exponent_bits: int,
mantissa_bits: int) -> Array:
"""Wraps XLA's `ReducePrecision
@@ -1342,9 +1342,9 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:

### convenience wrappers around traceables

-def full_like(x: Union[ArrayLike, DuckTypedArray],
-fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
-shape: Optional[Shape] = None) -> Array:
+def full_like(x: ArrayLike | DuckTypedArray,
+fill_value: ArrayLike, dtype: DTypeLike | None = None,
+shape: Shape | None = None) -> Array:
"""Create a full array like np.full based on the example array `x`.

Args:

@@ -1380,7 +1380,7 @@ def full_like(x: Union[ArrayLike, DuckTypedArray],


def collapse(operand: Array, start_dimension: int,
-stop_dimension: Optional[int] = None) -> Array:
+stop_dimension: int | None = None) -> Array:
"""Collapses dimensions of an array into a single dimension.

For example, if ``operand`` is an array with shape ``[2, 3, 4]``,

@@ -1691,7 +1691,7 @@ def broadcast_hlo(
return out

def _nary_lower_hlo(op: Callable, ctx,
-*args: Union[ir.Value, Sequence[ir.Value]],
+*args: ir.Value | Sequence[ir.Value],
explicit_type=False, **params) -> Sequence[ir.Value]:
"""Lowers an elementwise operator to its MLIR equivalent.

@@ -2516,7 +2516,7 @@ def _precision_config(precision):


def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
-preferred_element_type: Optional[DTypeLike]):
+preferred_element_type: DTypeLike | None):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
for d in (lhs_contracting, lhs_batch)):

@@ -2592,7 +2592,7 @@ def tuple_delete(tup, idx):


def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
-preferred_element_type: Optional[DTypeLike]):
+preferred_element_type: DTypeLike | None):
# We're mostly matching XLA's logic here, namely in shape_inference.cc and
# primitive_util.h's HigherPrecisionType, e.g.
# https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329

@@ -2636,7 +2636,7 @@ def _maybe_upcast(result_dtype, preferred_element_type):
return preferred_element_type

def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
-preferred_element_type: Optional[DTypeLike],
+preferred_element_type: DTypeLike | None,
swap_ans=False):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
x_ndim = x.aval.ndim

@@ -2657,7 +2657,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
return x_bar

def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
-preferred_element_type: Optional[DTypeLike]):
+preferred_element_type: DTypeLike | None):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
y_bar = _dot_general_transpose_lhs(

@@ -2670,7 +2670,7 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,

def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
precision,
-preferred_element_type: Optional[DTypeLike]):
+preferred_element_type: DTypeLike | None):
lhs, rhs = batched_args
lbd, rbd = batch_dims
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

@@ -2794,7 +2794,7 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:


def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
-precision, preferred_element_type: Optional[np.dtype],
+precision, preferred_element_type: np.dtype | None,
platform: str = "default"):
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in
@@ -4540,7 +4540,7 @@ def _array_copy(arr: ArrayLike) -> Array:
return copy_p.bind(arr)


-def _which_dim_sharded(s: PmapSharding) -> Optional[int]:
+def _which_dim_sharded(s: PmapSharding) -> int | None:
sharded_dim = None
for i, s in enumerate(s.sharding_spec.sharding):
if isinstance(s, pxla.Unstacked):

@@ -4898,7 +4898,7 @@ def remaining(original, *removed_lists):
return [i for i in original if i not in removed]


-def canonicalize_precision(precision: PrecisionLike) -> Optional[tuple[PrecisionType, PrecisionType]]:
+def canonicalize_precision(precision: PrecisionLike) -> tuple[PrecisionType, PrecisionType] | None:
"""Turns an API precision specification, into a pair of enumeration values.

The API can take the precision as a string, or int, and either as a single

@@ -197,7 +197,7 @@ class Mesh(contextlib.ContextDecorator):
if val is not None:
return val

-self = super(Mesh, cls).__new__(cls)
+self = super().__new__(cls)
self.devices = devices.copy()
self.devices.flags.writeable = False
self.axis_names = axis_names

@@ -32,8 +32,7 @@ from functools import partial
import math
import operator
import types
-from typing import (overload, Any, Callable, Literal, NamedTuple, Optional,
-Protocol, TypeVar, Union)
+from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings

@@ -2443,7 +2442,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision
and then convert it to the desired lower precision.
""")
-def arange(start: DimSize, stop: Optional[DimSize] = None,
+def arange(start: DimSize, stop: DimSize | None = None,
step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "arange")
if not config.dynamic_shapes.value:

@@ -2480,8 +2479,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
return lax.iota(dtype, start)
else:
if step is None and start == 0 and stop is not None:
-stop = np.ceil(stop).astype(int)
-return lax.iota(dtype, stop)
+return lax.iota(dtype, np.ceil(stop).astype(int))
return array(np.arange(start, stop=stop, step=step, dtype=dtype))

@@ -2833,7 +2831,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,


@util._wraps(np.tri)
-def tri(N: int, M: int | None = None, k: int = 0, dtype: Optional[DTypeLike] = None) -> Array:
+def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "tri")
M = M if M is not None else N
dtype = dtype or float32

@@ -3933,7 +3931,7 @@ def sort_complex(a: ArrayLike) -> Array:

@util._wraps(np.lexsort)
@partial(jit, static_argnames=('axis',))
-def lexsort(keys: Union[Array, np.ndarray, Sequence[ArrayLike]], axis: int = -1) -> Array:
+def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
key_tuple = tuple(keys)
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("lexsort", *key_tuple, emit_warning=True)
@@ -17,7 +17,7 @@ from collections.abc import Sequence
from functools import partial
import re
import textwrap
-from typing import Any, Callable, NamedTuple, Optional, TypeVar
+from typing import Any, Callable, NamedTuple, TypeVar

import warnings

@@ -50,14 +50,14 @@ class ParsedDoc(NamedTuple):
front_matter: front matter before sections.
sections: dictionary of section titles to section content.
"""
-docstr: Optional[str]
+docstr: str | None
signature: str = ""
summary: str = ""
front_matter: str = ""
sections: dict[str, str] = {}


-def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
+def _parse_numpydoc(docstr: str | None) -> ParsedDoc:
"""Parse a standard numpy-style docstring.

Args:

@@ -118,13 +118,13 @@ def _parse_extra_params(extra_params: str) -> dict[str, str]:


def _wraps(
-fun: Optional[Callable[..., Any]],
+fun: Callable[..., Any] | None,
update_doc: bool = True,
lax_description: str = "",
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
skip_params: Sequence[str] = (),
-extra_params: Optional[str] = None,
-module: Optional[str] = None,
+extra_params: str | None = None,
+module: str | None = None,
) -> Callable[[_T], _T]:
"""Specialized version of functools.wraps for wrapping numpy functions.

@@ -19,7 +19,8 @@ from collections.abc import Sequence
import contextlib
import dataclasses
import functools
-from typing import Any, Callable, Iterator
+from typing import Any, Callable
+from collections.abc import Iterator

from jax._src import api_util
from jax._src import core as jax_core

@@ -17,7 +17,7 @@
from __future__ import annotations

import dataclasses
-from typing import Any, Tuple
+from typing import Any

import jax
from jax import core as jax_core

@@ -37,7 +37,7 @@ import numpy as np
# to `tl.broadcast_to`.
broadcast_to_p = jax_core.Primitive('broadcast_to')

-def broadcast_to(a: jax.Array, shape: Tuple[int, ...]) -> jax.Array:
+def broadcast_to(a: jax.Array, shape: tuple[int, ...]) -> jax.Array:
if a.shape == shape:
return a
return broadcast_to_p.bind(a, shape=shape)

@@ -99,9 +99,9 @@ ds = dslice # Handy alias
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
-indices: Tuple[int | Slice | jax.Array, ...]
-shape: Tuple[int, ...]
-int_indexer_shape: Tuple[int, ...]
+indices: tuple[int | Slice | jax.Array, ...]
+shape: tuple[int, ...]
+int_indexer_shape: tuple[int, ...]

def __post_init__(self):
if len(self.indices) != len(self.shape):
@@ -148,7 +148,7 @@ class NDIndexer:
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
return NDIndexer(tuple(indices), shape, bcast_shape)

-def get_indexer_shape(self) -> Tuple[int, ...]:
+def get_indexer_shape(self) -> tuple[int, ...]:
is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
other_indexers, _ = partition_list(is_int_indexing, self.indices)
other_shape = [s.size for s in other_indexers] # type: ignore

@@ -17,7 +17,8 @@ from __future__ import annotations

import dataclasses
import functools
-from typing import Any, Callable, Sequence
+from typing import Any, Callable
+from collections.abc import Sequence

from jax import core as jax_core
from jax import lax

@@ -18,7 +18,8 @@ from __future__ import annotations
from functools import partial
import itertools as it

-from typing import Any, Callable, Dict, Sequence, Tuple
+from typing import Any, Callable
+from collections.abc import Sequence

import jax
from jax import api_util

@@ -89,7 +90,7 @@ def _uninitialized_value(shape, dtype):
def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
interpret, debug: bool,
in_shapes,
-input_output_aliases: Tuple[Tuple[int, int], ...],
+input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
**compiler_params: Any):
if interpret:

@@ -189,7 +190,7 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_):
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)

def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
-input_output_aliases: Tuple[Tuple[int, int], ...],
+input_output_aliases: tuple[tuple[int, int], ...],
in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any):
if grid_mapping.num_index_operands:
raise NotImplementedError

@@ -236,7 +237,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule

-def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
+def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping | None) -> BlockMapping:
def _block_map_function(new_idx, *args):

@@ -271,13 +272,13 @@ def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
def _pallas_call_batching_rule(args, dims, *,
jaxpr: jax_core.Jaxpr,
name: str,
-in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
-out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+in_shapes: tuple[jax.ShapeDtypeStruct, ...],
+out_shapes: tuple[jax.ShapeDtypeStruct, ...],
grid_mapping: GridMapping,
-input_output_aliases: Tuple[Tuple[int, int], ...],
+input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
-which_linear: Tuple[bool, ...],
+which_linear: tuple[bool, ...],
**compiler_params: Any):
if grid_mapping.num_index_operands:
scalar_batch_dims = dims[:grid_mapping.num_index_operands]

@@ -428,7 +429,7 @@ def pallas_call(
in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec,
out_specs: BlockSpec | NoBlockSpec
| Sequence[BlockSpec | NoBlockSpec] = no_block_spec,
-input_output_aliases: Dict[int, int] = {},
+input_output_aliases: dict[int, int] = {},
interpret: bool = False,
name: str | None = None,
**compiler_params: Any,
@@ -18,7 +18,8 @@ from __future__ import annotations
import dataclasses
import functools
import operator
-from typing import Any, Callable, Dict, Sequence, Tuple
+from typing import Any, Callable
+from collections.abc import Sequence
import zlib

import jax

@@ -67,7 +68,7 @@ import triton.language as tl
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
partial = functools.partial
-Grid = Tuple[int, ...]
+Grid = tuple[int, ...]
NDIndexer = indexing.NDIndexer
GridMapping = pallas_core.GridMapping
BlockMapping = pallas_core.BlockMapping

@@ -88,7 +89,7 @@ class TritonModuleContext:
class BlockInfo:
full_shape_dtype: jax.ShapeDtypeStruct
start_indices: Sequence[Any]
-block_shape: Tuple[int, ...]
+block_shape: tuple[int, ...]


@dataclasses.dataclass

@@ -111,7 +112,7 @@ class TritonLoweringResult:
ir_context: tl_ir.context
module: tl_ir.module
builder: tl_ir.builder
-grid: Tuple[int, ...]
+grid: tuple[int, ...]


@dataclasses.dataclass

@@ -610,7 +611,7 @@ def _compute_pointers_from_indices(
root_ptr: tl.core.tensor,
block_info: BlockInfo | None,
nd_indexer: NDIndexer,
-array_shape: Tuple[int, ...],
+array_shape: tuple[int, ...],
builder: tl_ir.builder,
) -> tl.core.tensor:
if block_info is None:

@@ -1512,14 +1513,14 @@ def pallas_call_lowering(
*in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
-in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
-out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
-which_linear: Tuple[bool, ...],
+in_shapes: tuple[jax.ShapeDtypeStruct, ...],
+out_shapes: tuple[jax.ShapeDtypeStruct, ...],
+which_linear: tuple[bool, ...],
interpret: bool,
debug: bool,
-input_output_aliases: Tuple[Tuple[int, int], ...],
+input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
-triton_params: Dict[str, Any] | None = None,
+triton_params: dict[str, Any] | None = None,
**compiler_params: Any,
):
if interpret:

@@ -15,7 +15,6 @@
"""Pallas utility functions."""
import math
import numpy as np
-from typing import Tuple

from jax import lax
from jax._src import core as jax_core

@@ -36,7 +35,7 @@ def cdiv(a: int, b: int) -> int:
return (a + b - 1) // b


-def strides_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
+def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
size = np.prod(shape)
strides = []
for s in shape:
@@ -19,7 +19,8 @@ from functools import partial
import math
from operator import index
import typing
-from typing import Hashable, Optional, Union
+from typing import Union
+from collections.abc import Hashable
import warnings

import numpy as np

@@ -150,7 +151,7 @@ class PRNGSpec:
PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl]


-def resolve_prng_impl(impl_spec: Optional[PRNGSpecDesc]) -> PRNGImpl:
+def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl:
if impl_spec is None:
return default_prng_impl()
if type(impl_spec) is PRNGImpl:

@@ -174,7 +175,7 @@ def resolve_prng_impl(impl_spec: Optional[PRNGSpecDesc]) -> PRNGImpl:


def _key(ctor_name: str, seed: int | ArrayLike,
-impl_spec: Optional[PRNGSpecDesc]) -> KeyArray:
+impl_spec: PRNGSpecDesc | None) -> KeyArray:
impl = resolve_prng_impl(impl_spec)
if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key):
raise TypeError(

@@ -186,7 +187,7 @@ def _key(ctor_name: str, seed: int | ArrayLike,
return prng.random_seed(seed, impl=impl)

def key(seed: int | ArrayLike, *,
-impl: Optional[PRNGSpecDesc] = None) -> KeyArray:
+impl: PRNGSpecDesc | None = None) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.

The result is a scalar array with a key that indicates the default PRNG

@@ -205,7 +206,7 @@ def key(seed: int | ArrayLike, *,
return _key('key', seed, impl)

def PRNGKey(seed: int | ArrayLike, *,
-impl: Optional[PRNGSpecDesc] = None) -> KeyArray:
+impl: PRNGSpecDesc | None = None) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.

The resulting key carries the default PRNG implementation, as

@@ -277,7 +278,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
return _return_prng_keys(wrapped, key_out)


-def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
+def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray:
# Alternative to split() to use within random samplers.
# TODO(frostig): remove and use split(); we no longer need to wait
# to always enable_custom_prng

@@ -288,7 +289,7 @@ def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
shape = tuple(num) if isinstance(num, Sequence) else (num,)
return prng.random_split(key, shape=shape)

-def split(key: KeyArrayLike, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
+def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.

Args:

@@ -324,7 +325,7 @@ def key_data(keys: KeyArrayLike) -> Array:


def wrap_key_data(key_bits_array: Array, *,
-impl: Optional[PRNGSpecDesc] = None):
+impl: PRNGSpecDesc | None = None):
"""Wrap an array of key data bits into a PRNG key array.

Args:

@@ -344,7 +345,7 @@ def wrap_key_data(key_bits_array: Array, *,
### random samplers


-def _check_shape(name: str, shape: Union[Shape, NamedShape], *param_shapes) -> None:
+def _check_shape(name: str, shape: Shape | NamedShape, *param_shapes) -> None:
shape = core.as_named_shape(shape)

if param_shapes:

@@ -358,7 +359,7 @@ def _check_shape(name: str, shape: Union[Shape, NamedShape], *param_shapes) -> N

def bits(key: KeyArrayLike,
shape: Shape = (),
-dtype: Optional[DTypeLikeUInt] = None) -> Array:
+dtype: DTypeLikeUInt | None = None) -> Array:
"""Sample uniform bits in the form of unsigned integers.

Args:

@@ -386,7 +387,7 @@ def bits(key: KeyArrayLike,


def uniform(key: KeyArrayLike,
-shape: Union[Shape, NamedShape] = (),
+shape: Shape | NamedShape = (),
dtype: DTypeLikeFloat = float,
minval: RealArray = 0.,
maxval: RealArray = 1.) -> Array:
@@ -561,7 +562,7 @@ def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array:


def permutation(key: KeyArrayLike,
-x: Union[int, ArrayLike],
+x: int | ArrayLike,
axis: int = 0,
independent: bool = False) -> Array:
"""Returns a randomly permuted array or range.

@@ -620,10 +621,10 @@ def _shuffle(key, x, axis) -> Array:


def choice(key: KeyArrayLike,
-a: Union[int, ArrayLike],
+a: int | ArrayLike,
shape: Shape = (),
replace: bool = True,
-p: Optional[RealArray] = None,
+p: RealArray | None = None,
axis: int = 0) -> Array:
"""Generates a random sample from a given array.

@@ -697,7 +698,7 @@ def choice(key: KeyArrayLike,


def normal(key: KeyArrayLike,
-shape: Union[Shape, NamedShape] = (),
+shape: Shape | NamedShape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample standard normal random values with given shape and float dtype.

@@ -752,8 +753,8 @@ def _normal_real(key, shape, dtype) -> Array:
def multivariate_normal(key: KeyArrayLike,
mean: RealArray,
cov: RealArray,
-shape: Optional[Shape] = None,
-dtype: Optional[DTypeLikeFloat] = None,
+shape: Shape | None = None,
+dtype: DTypeLikeFloat | None = None,
method: str = 'cholesky') -> Array:
r"""Sample multivariate normal random values with given mean and covariance.

@@ -835,7 +836,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
def truncated_normal(key: KeyArrayLike,
lower: RealArray,
upper: RealArray,
-shape: Optional[Union[Shape, NamedShape]] = None,
+shape: Shape | NamedShape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample truncated standard normal random values with given shape and dtype.

@@ -900,7 +901,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array:

def bernoulli(key: KeyArrayLike,
p: RealArray = np.float32(0.5),
-shape: Optional[Union[Shape, NamedShape]] = None) -> Array:
+shape: Shape | NamedShape | None = None) -> Array:
r"""Sample Bernoulli random values with given shape and mean.

The values are distributed according to the probability mass function:

@@ -946,7 +947,7 @@ def _bernoulli(key, p, shape) -> Array:
def beta(key: KeyArrayLike,
a: RealArray,
b: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Beta random values with given shape and float dtype.

@@ -1045,7 +1046,7 @@ def _cauchy(key, shape, dtype) -> Array:

def dirichlet(key: KeyArrayLike,
alpha: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Dirichlet random values with given shape and float dtype.

@@ -1284,7 +1285,7 @@ batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule

def gamma(key: KeyArrayLike,
a: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Gamma random values with given shape and float dtype.

@@ -1331,7 +1332,7 @@ def gamma(key: KeyArrayLike,

def loggamma(key: KeyArrayLike,
a: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
"""Sample log-gamma random values with given shape and float dtype.

@@ -1473,7 +1474,7 @@ def _poisson(key, lam, shape, dtype) -> Array:

def poisson(key: KeyArrayLike,
lam: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeInt = int) -> Array:
r"""Sample Poisson random values with given shape and integer dtype.
@@ -1555,7 +1556,7 @@ def _gumbel(key, shape, dtype) -> Array:
def categorical(key: KeyArrayLike,
logits: RealArray,
axis: int = -1,
-shape: Optional[Shape] = None) -> Array:
+shape: Shape | None = None) -> Array:
"""Sample random values from categorical distributions.

Args:

@@ -1669,7 +1670,7 @@ def _logistic(key, shape, dtype):

def pareto(key: KeyArrayLike,
b: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Pareto random values with given shape and float dtype.

@@ -1770,7 +1771,7 @@ def _t(key, df, shape, dtype) -> Array:

def chisquare(key: KeyArrayLike,
df: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Chisquare random values with given shape and float dtype.

@@ -1823,7 +1824,7 @@ def _chisquare(key, df, shape, dtype) -> Array:
def f(key: KeyArrayLike,
dfnum: RealArray,
dfden: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample F-distribution random values with given shape and float dtype.

@@ -2153,7 +2154,7 @@ def ball(

def rayleigh(key: KeyArrayLike,
scale: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Rayleigh random values with given shape and float dtype.

@@ -2206,7 +2207,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array:

def wald(key: KeyArrayLike,
mean: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Wald random values with given shape and float dtype.

@@ -2264,7 +2265,7 @@ def _wald(key, mean, shape, dtype) -> Array:

def geometric(key: KeyArrayLike,
p: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeInt = int) -> Array:
r"""Sample Geometric random values with given shape and float dtype.

@@ -2319,7 +2320,7 @@ def triangular(key: KeyArrayLike,
left: RealArray,
mode: RealArray,
right: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Triangular random values with given shape and float dtype.

@@ -2382,7 +2383,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array:

def lognormal(key: KeyArrayLike,
sigma: RealArray = np.float32(1),
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r""" Sample lognormal random values with given shape and float dtype.

@@ -2588,7 +2589,7 @@ def binomial(
key: KeyArray,
n: RealArray,
p: RealArray,
-shape: Optional[Shape] = None,
+shape: Shape | None = None,
dtype: DTypeLikeFloat = float,
) -> Array:
r"""Sample Binomial random values with given shape and float dtype.

@@ -22,7 +22,7 @@ import enum
import functools
import itertools
import math
-from typing import Any, NamedTuple, Union, cast, Optional
+from typing import Any, NamedTuple, Union, cast

from jax._src import mesh as mesh_lib
from jax._src.op_shardings import (

@@ -566,9 +566,9 @@ class PmapSharding(XLACompatibleSharding):


def _op_sharding_to_pos_sharding(
-op_sharding: Union[xc.OpSharding, xc.HloSharding],
+op_sharding: xc.OpSharding | xc.HloSharding,
device_assignment: Sequence[xc.Device],
-memory_kind: Optional[str] = None) -> PositionalSharding:
+memory_kind: str | None = None) -> PositionalSharding:
if isinstance(op_sharding, xc.OpSharding):
op_sharding = xc.HloSharding.from_proto(op_sharding) # type: ignore
@@ -24,7 +24,7 @@ import re
import os
import tempfile
import textwrap
-from typing import Any, Callable, Optional
+from typing import Any, Callable
import unittest
import warnings
import zlib

@@ -919,7 +919,7 @@ class JaxTestCase(parameterized.TestCase):
'jax_legacy_prng_key': 'error',
}

-_compilation_cache_exit_stack: Optional[ExitStack] = None
+_compilation_cache_exit_stack: ExitStack | None = None

# TODO(mattjj): this obscures the error messages from failures, figure out how
# to re-enable it

@@ -1275,8 +1275,8 @@ def numpy_version():

def parameterized_filterable(*,
kwargs: Sequence[dict[str, Any]],
-testcase_name: Optional[Callable[[dict[str, Any]], str]] = None,
-one_containing: Optional[str] = None,
+testcase_name: Callable[[dict[str, Any]], str] | None = None,
+one_containing: str | None = None,
):
"""
Decorator for named parameterized tests, with filtering.

@@ -21,7 +21,7 @@ import functools
from functools import partial
import operator as op
import textwrap
-from typing import Any, Callable, NamedTuple, Type, TypeVar, Union, overload
+from typing import Any, Callable, NamedTuple, TypeVar, Union, overload

from jax._src import traceback_util
from jax._src.lib import pytree

@@ -737,7 +737,7 @@ def register_pytree_with_keys_class(cls: U) -> U:
return cls


-def register_static(cls: Type[H]) -> Type[H]:
+def register_static(cls: type[H]) -> type[H]:
"""Registers `cls` as a pytree with no leaves.

Instances are treated as static by `jax.jit`, `jax.pmap`, etc. This can be an

@@ -32,7 +32,7 @@ import pkgutil
import platform as py_platform
import sys
import threading
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
import warnings

from jax._src import config

@@ -855,7 +855,7 @@ def default_backend() -> str:
return get_backend(None).platform


-def backend_pjrt_c_api_version(platform=None) -> Optional[Tuple[int, int]]:
+def backend_pjrt_c_api_version(platform=None) -> Optional[tuple[int, int]]:
"""Returns the PJRT C API version of the backend.

Returns None if the backend does not use PJRT C API.

@@ -31,7 +31,7 @@ def _array_namespace(self, /, *, api_version: None | str = None):


def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
-stream: Optional[Union[int, Any]] = None):
+stream: int | Any | None = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
# The type of device is defined by Array.device. In JAX, this is a callable that

@@ -20,17 +20,17 @@ from jax import Array
from jax.experimental.array_api._data_type_functions import result_type as _result_type


-def broadcast_arrays(*arrays: Array) -> List[Array]:
+def broadcast_arrays(*arrays: Array) -> list[Array]:
"""Broadcasts one or more arrays against one another."""
return jax.numpy.broadcast_arrays(*arrays)


-def broadcast_to(x: Array, /, shape: Tuple[int]) -> Array:
+def broadcast_to(x: Array, /, shape: tuple[int]) -> Array:
"""Broadcasts an array to a specified shape."""
return jax.numpy.broadcast_to(x, shape=shape)


-def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
+def concat(arrays: Union[tuple[Array, ...], list[Array]], /, *, axis: Optional[int] = 0) -> Array:
"""Joins a sequence of arrays along an existing axis."""
dtype = _result_type(*arrays)
if axis is None:
@@ -46,34 +46,34 @@ def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
return jax.numpy.expand_dims(x, axis=axis)


-def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
+def flip(x: Array, /, *, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
"""Reverses the order of elements in an array along the given axis."""
return jax.numpy.flip(x, axis=axis)


-def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
+def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
"""Permutes the axes (dimensions) of an array x."""
return jax.lax.transpose(x, axes)


-def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array:
+def reshape(x: Array, /, shape: tuple[int, ...], *, copy: Optional[bool] = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused
return jax.numpy.reshape(x, shape)


-def roll(x: Array, /, shift: Union[int, Tuple[int]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
+def roll(x: Array, /, shift: Union[int, tuple[int]], *, axis: Optional[Union[int, tuple[int, ...]]] = None) -> Array:
"""Rolls array elements along a specified axis."""
return jax.numpy.roll(x, shift=shift, axis=axis)


-def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
+def squeeze(x: Array, /, axis: Union[int, tuple[int, ...]]) -> Array:
"""Removes singleton dimensions (axes) from x."""
dimensions = axis if isinstance(axis, tuple) else (axis,)
return jax.lax.squeeze(x, dimensions=dimensions)


-def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
+def stack(arrays: Union[tuple[Array, ...], list[Array]], /, *, axis: int = 0) -> Array:
"""Joins a sequence of arrays along a new axis."""
dtype = _result_type(*arrays)
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)

@@ -2163,7 +2163,7 @@ tf_impl_with_avals[lax.dot_general_p] = _dot_general
def _dot_general_convert_to_common_dtype(
lhs: TfVal, lhs_aval: core.ShapedArray,
rhs: TfVal, rhs_aval: core.ShapedArray,
-out_aval: core.ShapedArray) -> Tuple[TfVal, TfVal, Callable[[TfVal], TfVal]]:
+out_aval: core.ShapedArray) -> tuple[TfVal, TfVal, Callable[[TfVal], TfVal]]:
# Returns the converted lhs, rhs, and the converter for the result.
# tfxla.dot_general does not handle arguments of different types.
# We convert the arguments and the result.

@@ -76,7 +76,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
cls.tf_devices.append(tf_device)

-super(CallTfTest, cls).setUpClass()
+super().setUpClass()

def setUp(self):
if tf is None:

@@ -69,7 +69,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
cls.tf_devices.append(tf_device)

-super(Jax2TfTest, cls).setUpClass()
+super().setUpClass()

def test_empty(self):
f_jax = lambda x, y: x

@@ -173,7 +173,7 @@ def mha(
block_q: int = 128,
block_k: int = 128,
backward_pass_impl: str = "triton",
-num_warps: Optional[int] = None,
+num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,

@@ -239,7 +239,7 @@ def _mha_forward(
block_q: int,
block_k: int,
backward_pass_impl: str,
-num_warps: Optional[int],
+num_warps: int | None,
num_stages: int,
grid: Any,
interpret: bool,
@@ -451,7 +451,7 @@ def mha_backward_kernel(


def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
-backward_pass_impl: str, num_warps: Optional[int],
+backward_pass_impl: str, num_warps: int | None,
num_stages: int, grid: Any, interpret: bool,
debug: bool, res, do):
del num_warps, num_stages, grid

@@ -25,7 +25,7 @@ chunk, splits it in two, and sends each of the half-chunks in each direction
from __future__ import annotations
import functools

-from typing import Sequence
+from collections.abc import Sequence

import jax
from jax import lax

@@ -23,7 +23,7 @@ import shutil
import sys
import subprocess
import glob
-from typing import Sequence
+from collections.abc import Sequence


def is_windows() -> bool:

@@ -88,7 +88,7 @@ def build_editable(

def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str):
src_file = file_dir / "setup.py"
-with open(src_file, "r") as f:
+with open(src_file) as f:
content = f.read()
content = content.replace(
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"

@@ -72,7 +72,7 @@ class PrimitiveTest(jtu.JaxTestCase):
if d.platform not in cls.platforms:
cls.platforms.append(d.platform)
cls.devices.append(d)
-super(PrimitiveTest, cls).setUpClass()
+super().setUpClass()

# For each primitive we export for all platforms that are available and
# compare the results of running the exported code and running the native

@@ -18,7 +18,6 @@ import functools
import logging
import math
import re
-from typing import Optional
import unittest

from absl.testing import absltest

@@ -98,7 +97,7 @@ mlir.register_lowering(testing_primitive_with_effect_p,
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)

def _testing_multi_platform_func(x, *,
-effect_class_name: Optional[str] = None):
+effect_class_name: str | None = None):
# Behaves like x + 2 * _testing_multi_platform_to_add[platform]
def for_platform(platform: str):
if effect_class_name is None:

@@ -142,7 +141,7 @@ class JaxExportTest(jtu.JaxTestCase):
except RuntimeError:
continue
cls.platforms.append(backend)
-super(JaxExportTest, cls).setUpClass()
+super().setUpClass()

def setUp(self):
super().setUp()

@@ -266,7 +266,7 @@ class MemoriesComputationTest(jtu.BufferDonationTestCase):

expected_out = np_inp * 2
self.assertArraysEqual(out, expected_out)
-self.assertEqual(out.sharding, s.with_memory_kind(("unpinned_host")))
+self.assertEqual(out.sharding, s.with_memory_kind("unpinned_host"))
self._check_mem_kind(executable_mk[0], out.sharding, "unpinned_host")
for s in out.addressable_shards:
self.assertArraysEqual(s.data, expected_out[s.index])

@@ -65,7 +65,7 @@ def mock_tpu_devices(x, y, z, dev_kind, one_device_per_chip, num_slices=1,

# If this function raises, it's a bug in the test code!
def _validate_mocked_process_indices(devices, one_device_per_chip):
-process_to_devices = collections.defaultdict(lambda: [])
+process_to_devices = collections.defaultdict(list)
for d in devices:
process_to_devices[d.process_index].append(d)