Migrated to mypy 1.14.1 with --allow_redefinition

I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
This commit is contained in:
Sergei Lebedev 2025-02-12 13:27:29 +00:00
parent 3ec7a67e51
commit 194884d311
15 changed files with 42 additions and 36 deletions

View File

@ -31,7 +31,7 @@ repos:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
rev: '102bbee94061ff02fd361ec29c27b7cb26582f5f' # frozen: v1.12.2
rev: 'bbc3dc1f890007061f18f17e2334f216ea9e5df7' # frozen: v1.14.1
hooks:
- id: mypy
files: (jax/|tests/typing_test\.py)

View File

@ -260,7 +260,7 @@ class Error:
cur_effect = None
for error_effect, code in self._code.items():
if self._pred[error_effect][idx]: # type: ignore
if min_code is None or code[idx] < min_code:
if min_code is None or code[idx] < min_code: # type: ignore
min_code = code[idx] # type: ignore
cur_effect = error_effect

View File

@ -597,8 +597,12 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
else:
color = None
text_color = None
padding = (top_padding, right_padding, bottom_padding, left_padding)
padding = tuple(max(x, 0) for x in padding) # type: ignore
padding = (
max(top_padding, 0),
max(right_padding, 0),
max(bottom_padding, 0),
max(left_padding, 0),
)
col.append(
rich.padding.Padding(
rich.align.Align(entry, "center", vertical="middle"), padding,

View File

@ -17,6 +17,7 @@ from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
import enum
from typing import Any
import numpy as np
@ -263,7 +264,7 @@ def _resize_nearest(x, output_shape: core.Shape):
# TODO(b/206898375): this computation produces the wrong result on
# CPU and GPU when using float64. Use float32 until the bug is fixed.
offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
indices = [slice(None)] * len(input_shape)
indices: list[Any] = [slice(None)] * len(input_shape)
indices[d] = offsets
x = x[tuple(indices)]
return x

View File

@ -2003,7 +2003,9 @@ def jaxpr_transfer_mem_kinds(
return out
def are_all_shardings_default_mem_kind(da_object: xc.DeviceList, shardings):
def are_all_shardings_default_mem_kind(
da_object: xc.DeviceList | None, shardings
):
if da_object is None:
return True
try:
@ -2339,8 +2341,8 @@ def lower_sharding_computation(
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, in_layouts, out_layouts, num_devices,
tuple(da_object) if prim_requires_devices else None, donated_invars,
name_stack, all_default_mem_kind, inout_aliases,
tuple(da_object) if prim_requires_devices else None, # type: ignore[arg-type]
donated_invars, name_stack, all_default_mem_kind, inout_aliases,
propagated_out_mem_kinds, platforms,
lowering_parameters=lowering_parameters,
abstract_mesh=abstract_mesh)

View File

@ -437,7 +437,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
del p, x
shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape)
y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype)
y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)
y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) # type: ignore[misc]
return y

View File

@ -461,7 +461,7 @@ class ufunc:
idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args)))
return (i + 1, a), x
carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0])) # type: ignore[arg-type]
return carry[1]
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])

View File

@ -94,7 +94,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
Promotes arguments to an inexact type."""
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment]
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) # type: ignore[arg-type]
return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
for x in args]

View File

@ -494,7 +494,7 @@ def _wgmma_lowering(
b_transforms_tree,
):
_, a_aval, *_ = ctx.avals_in
lhs_swizzle = None
lhs_swizzle: int | None = None
if a_transforms_tree is not None:
a_transforms_leaves, b_transforms_leaves = util.split_list(
transforms_leaves, [a_transforms_tree.num_leaves]

View File

@ -729,7 +729,7 @@ def _pallas_call_batching_rule(
for pos, invar in enumerate(jaxpr.invars):
ragged_axis_values[pos] = var_to_raggedness[invar]
per_input_ragged_axis_dim = []
per_input_ragged_axis_dim: list[int | None] = []
for rav in ragged_axis_values:
if rav is not None:
per_input_ragged_axis_dim.append(rav[1])

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import dataclasses
from typing import Any, Union
from typing import Any, Sequence, Union
from jax._src import core
from jax._src import tree_util
@ -198,20 +198,20 @@ class NDIndexer:
# TODO(slebedev): Consider requiring `indices` to be a Sequence.
indices = (indices,)
indices = list(indices)
if num_ellipsis := sum(idx is ... for idx in indices):
if num_ellipsis > 1:
raise ValueError("Only one ellipsis is supported.")
# Expand ... so that `indices` has the same length as `shape`.
ip = indices.index(...)
indices = list(indices)
indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1)
if len(indices) > len(shape):
indices = tuple(indices)
if len(indices) > len(shape):
raise ValueError("`indices` must not be longer than `shape`: "
f"{indices=}, {shape=}")
elif len(indices) < len(shape):
# Pad `indices` to have the same length as `shape`.
indices.extend([slice(None)] * (len(shape) - len(indices)))
indices = (*indices, *[slice(None)] * (len(shape) - len(indices)))
# Promote all builtin `slice`s to `Slice`.
indices = tuple(
@ -220,6 +220,7 @@ class NDIndexer:
is_int_indexing = [not isinstance(i, Slice) for i in indices]
if any(is_int_indexing):
int_indexers: Sequence[Any]
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers)
try:

View File

@ -84,7 +84,7 @@ TODO:
"""
from functools import partial
import math
from typing import cast, Any
from typing import Any
import jax
import numpy as np
@ -256,15 +256,15 @@ def _lstm_cudnn_allow_tf32(precision: lax.PrecisionLike) -> bool:
if precision is None or not (isinstance(precision, tuple) and len(precision) == 2):
return True
# cuDNN allows only one precision specifier per RNN op
precision, _ = cast(tuple[lax.Precision, lax.Precision], precision)
if precision == lax.Precision.HIGHEST:
return False
elif precision == lax.Precision.HIGH:
return True
elif precision == lax.Precision.DEFAULT: # bfloat16
raise NotImplementedError("bfloat16 support not implemented for LSTM")
else:
raise ValueError(f"Unexpected precision specifier value {precision}")
match precision:
case (lax.Precision.HIGHEST, _):
return False
case (lax.Precision.HIGH, _):
return True
case (lax.Precision.DEFAULT, _): # bfloat16
raise NotImplementedError("bfloat16 support not implemented for LSTM")
case _:
raise ValueError(f"Unexpected precision specifier value {precision}")
@partial(custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10))

View File

@ -13,9 +13,7 @@
# limitations under the License.
import enum
from typing import ClassVar
class ComputationMode(enum.Enum):
kComputeEigenvectors: ClassVar[ComputationMode] = ...
kNoEigenvectors: ClassVar[ComputationMode] = ...
kComputeEigenvectors = ...
kNoEigenvectors = ...

View File

@ -13,14 +13,13 @@
# limitations under the License.
import enum
from typing import ClassVar
class ComputationMode(enum.Enum):
kComputeSchurVectors: ClassVar[ComputationMode]
kNoComputeSchurVectors: ClassVar[ComputationMode]
kComputeSchurVectors = ...
kNoComputeSchurVectors = ...
class Sort(enum.Enum):
kNoSortEigenvalues: ClassVar[Sort]
kSortEigenvalues: ClassVar[Sort]
kNoSortEigenvalues = ...
kSortEigenvalues = ...

View File

@ -7,6 +7,7 @@ show_error_codes = true
disable_error_code = "attr-defined, name-defined, annotation-unchecked"
no_implicit_optional = true
warn_redundant_casts = true
allow_redefinition = true
[[tool.mypy.overrides]]
module = [