diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 763959254..27ccc6d83 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,7 +31,7 @@ repos:
   - id: ruff
 
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: '102bbee94061ff02fd361ec29c27b7cb26582f5f'  # frozen: v1.12.2
+  rev: 'bbc3dc1f890007061f18f17e2334f216ea9e5df7'  # frozen: v1.14.1
   hooks:
   - id: mypy
     files: (jax/|tests/typing_test\.py)
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index a77a6456c..831c4488f 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -260,7 +260,7 @@ class Error:
       cur_effect = None
       for error_effect, code in self._code.items():
         if self._pred[error_effect][idx]:  # type: ignore
-          if min_code is None or code[idx] < min_code:
+          if min_code is None or code[idx] < min_code:  # type: ignore
            min_code = code[idx]  # type: ignore
            cur_effect = error_effect
 
diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index 7685ac2bf..13fe70c54 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -597,8 +597,12 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
       else:
         color = None
         text_color = None
-      padding = (top_padding, right_padding, bottom_padding, left_padding)
-      padding = tuple(max(x, 0) for x in padding)  # type: ignore
+      padding = (
+          max(top_padding, 0),
+          max(right_padding, 0),
+          max(bottom_padding, 0),
+          max(left_padding, 0),
+      )
       col.append(
           rich.padding.Padding(
               rich.align.Align(entry, "center", vertical="middle"), padding,
diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py
index aa9910555..faaee6a54 100644
--- a/jax/_src/image/scale.py
+++ b/jax/_src/image/scale.py
@@ -17,6 +17,7 @@ from __future__ import annotations
 
 from collections.abc import Callable, Sequence
 from functools import partial
 import enum
+from typing import Any
 
 import numpy as np
@@ -263,7 +264,7 @@ def _resize_nearest(x, output_shape: core.Shape):
     # TODO(b/206898375): this computation produces the wrong result on
     # CPU and GPU when using float64. Use float32 until the bug is fixed.
     offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
-    indices = [slice(None)] * len(input_shape)
+    indices: list[Any] = [slice(None)] * len(input_shape)
     indices[d] = offsets
     x = x[tuple(indices)]
   return x
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index e4b528136..e408e4c0e 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2003,7 +2003,9 @@ def jaxpr_transfer_mem_kinds(
   return out
 
 
-def are_all_shardings_default_mem_kind(da_object: xc.DeviceList, shardings):
+def are_all_shardings_default_mem_kind(
+    da_object: xc.DeviceList | None, shardings
+):
   if da_object is None:
     return True
   try:
@@ -2339,8 +2341,8 @@ def lower_sharding_computation(
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
        semantic_out_shardings, in_layouts, out_layouts, num_devices,
-       tuple(da_object) if prim_requires_devices else None, donated_invars,
-       name_stack, all_default_mem_kind, inout_aliases,
+       tuple(da_object) if prim_requires_devices else None,  # type: ignore[arg-type]
+       donated_invars, name_stack, all_default_mem_kind, inout_aliases,
        propagated_out_mem_kinds, platforms,
        lowering_parameters=lowering_parameters,
        abstract_mesh=abstract_mesh)
diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py
index 49c119954..81d320cb7 100644
--- a/jax/_src/numpy/polynomial.py
+++ b/jax/_src/numpy/polynomial.py
@@ -437,7 +437,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
   del p, x
   shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape)
   y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype)
-  y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)
+  y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)  # type: ignore[misc]
   return y
 
 
diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py
index 5dbd67e62..9d1f0840c 100644
--- a/jax/_src/numpy/ufunc_api.py
+++ b/jax/_src/numpy/ufunc_api.py
@@ -461,7 +461,7 @@ class ufunc:
       idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
       a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args)))
       return (i + 1, a), x
-    carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
+    carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))  # type: ignore[arg-type]
     return carry[1]
 
   @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py
index b5a1a3bb6..e902950e7 100644
--- a/jax/_src/numpy/util.py
+++ b/jax/_src/numpy/util.py
@@ -94,7 +94,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
   Promotes arguments to an inexact type."""
   to_dtype, weak_type = dtypes._lattice_result_type(*args)
   to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-  to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
+  to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)  # type: ignore[arg-type]
   return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
           for x in args]
 
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 4fd63fb9f..01adc42ea 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -494,7 +494,7 @@ def _wgmma_lowering(
     b_transforms_tree,
 ):
   _, a_aval, *_ = ctx.avals_in
-  lhs_swizzle = None
+  lhs_swizzle: int | None = None
   if a_transforms_tree is not None:
     a_transforms_leaves, b_transforms_leaves = util.split_list(
         transforms_leaves, [a_transforms_tree.num_leaves]
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 583754cde..7e0bf5830 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -729,7 +729,7 @@ def _pallas_call_batching_rule(
     for pos, invar in enumerate(jaxpr.invars):
      ragged_axis_values[pos] = var_to_raggedness[invar]
 
-    per_input_ragged_axis_dim = []
+    per_input_ragged_axis_dim: list[int | None] = []
    for rav in ragged_axis_values:
      if rav is not None:
        per_input_ragged_axis_dim.append(rav[1])
diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py
index 2da93e3d8..7abaa3185 100644
--- a/jax/_src/state/indexing.py
+++ b/jax/_src/state/indexing.py
@@ -17,7 +17,7 @@ from __future__ import annotations
 
 import dataclasses
-from typing import Any, Union
+from typing import Any, Sequence, Union
 
 from jax._src import core
 from jax._src import tree_util
@@ -198,20 +198,20 @@ class NDIndexer:
       # TODO(slebedev): Consider requiring `indices` to be a Sequence.
       indices = (indices,)
 
-    indices = list(indices)
     if num_ellipsis := sum(idx is ... for idx in indices):
       if num_ellipsis > 1:
         raise ValueError("Only one ellipsis is supported.")
       # Expand ... so that `indices` has the same length as `shape`.
       ip = indices.index(...)
+      indices = list(indices)
       indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1)
-    if len(indices) > len(shape):
       indices = tuple(indices)
+    if len(indices) > len(shape):
       raise ValueError("`indices` must not be longer than `shape`: "
                        f"{indices=}, {shape=}")
     elif len(indices) < len(shape):
       # Pad `indices` to have the same length as `shape`.
-      indices.extend([slice(None)] * (len(shape) - len(indices)))
+      indices = (*indices, *[slice(None)] * (len(shape) - len(indices)))
 
     # Promote all builtin `slice`s to `Slice`.
     indices = tuple(
@@ -220,6 +220,7 @@ class NDIndexer:
 
     is_int_indexing = [not isinstance(i, Slice) for i in indices]
     if any(is_int_indexing):
+      int_indexers: Sequence[Any]
       other_indexers, int_indexers = partition_list(is_int_indexing, indices)
       indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers)
       try:
diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py
index 1629e7204..55cf2b3ba 100644
--- a/jax/experimental/rnn.py
+++ b/jax/experimental/rnn.py
@@ -84,7 +84,7 @@ TODO:
 """
 from functools import partial
 import math
-from typing import cast, Any
+from typing import Any
 
 import jax
 import numpy as np
@@ -256,15 +256,15 @@ def _lstm_cudnn_allow_tf32(precision: lax.PrecisionLike) -> bool:
   if precision is None or not (isinstance(precision, tuple) and len(precision) == 2):
     return True
   # cuDNN allows only one precision specifier per RNN op
-  precision, _ = cast(tuple[lax.Precision, lax.Precision], precision)
-  if precision == lax.Precision.HIGHEST:
-    return False
-  elif precision == lax.Precision.HIGH:
-    return True
-  elif precision == lax.Precision.DEFAULT:  # bfloat16
-    raise NotImplementedError("bfloat16 support not implemented for LSTM")
-  else:
-    raise ValueError(f"Unexpected precision specifier value {precision}")
+  match precision:
+    case (lax.Precision.HIGHEST, _):
+      return False
+    case (lax.Precision.HIGH, _):
+      return True
+    case (lax.Precision.DEFAULT, _):  # bfloat16
+      raise NotImplementedError("bfloat16 support not implemented for LSTM")
+    case _:
+      raise ValueError(f"Unexpected precision specifier value {precision}")
 
 
 @partial(custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10))
diff --git a/jaxlib/cpu/_lapack/eig.pyi b/jaxlib/cpu/_lapack/eig.pyi
index 338c15402..75fea332b 100644
--- a/jaxlib/cpu/_lapack/eig.pyi
+++ b/jaxlib/cpu/_lapack/eig.pyi
@@ -13,9 +13,7 @@
 # limitations under the License.
 
 import enum
-from typing import ClassVar
-
 
 class ComputationMode(enum.Enum):
-  kComputeEigenvectors: ClassVar[ComputationMode] = ...
-  kNoEigenvectors: ClassVar[ComputationMode] = ...
+  kComputeEigenvectors = ...
+  kNoEigenvectors = ...
diff --git a/jaxlib/cpu/_lapack/schur.pyi b/jaxlib/cpu/_lapack/schur.pyi
index add01b049..5153ad864 100644
--- a/jaxlib/cpu/_lapack/schur.pyi
+++ b/jaxlib/cpu/_lapack/schur.pyi
@@ -13,14 +13,13 @@
 # limitations under the License.
 
 import enum
-from typing import ClassVar
 
 
 class ComputationMode(enum.Enum):
-  kComputeSchurVectors: ClassVar[ComputationMode]
-  kNoComputeSchurVectors: ClassVar[ComputationMode]
+  kComputeSchurVectors = ...
+  kNoComputeSchurVectors = ...
 
 
 class Sort(enum.Enum):
-  kNoSortEigenvalues: ClassVar[Sort]
-  kSortEigenvalues: ClassVar[Sort]
+  kNoSortEigenvalues = ...
+  kSortEigenvalues = ...
diff --git a/pyproject.toml b/pyproject.toml
index c804fed2a..e32b14a89 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,6 +7,7 @@ show_error_codes = true
 disable_error_code = "attr-defined, name-defined, annotation-unchecked"
 no_implicit_optional = true
 warn_redundant_casts = true
+allow_redefinition = true
 
 [[tool.mypy.overrides]]
 module = [