
Migrated to mypy 1.14.1 with --allow_redefinition

I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
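
A slightly fuller sketch of the same symptom (the function bodies, the `pick` helper, and the annotated workaround are illustrative additions, not part of the commit):

   from typing import Callable

   def f(x: int) -> str:
     return str(x)

   def g(x: int) -> str:
     return str(-x)

   def pick(flag: bool) -> None:
     callback = f if flag else g
     # mypy 1.14.1 infers a callable type for `callback`; per the commit
     # message, 1.15 infers `object`, which makes this call an error.
     print(callback(42))
     # Illustrative workaround: annotate the ternary result explicitly.
     cb: Callable[[int], str] = f if flag else g
     print(cb(7))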
Author: Sergei Lebedev, 2025-02-12 13:27:29 +00:00
Parent: 3ec7a67e51
Commit: 194884d311
15 changed files with 42 additions and 36 deletions

@@ -31,7 +31,7 @@ repos:
   - id: ruff
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: '102bbee94061ff02fd361ec29c27b7cb26582f5f'  # frozen: v1.12.2
+  rev: 'bbc3dc1f890007061f18f17e2334f216ea9e5df7'  # frozen: v1.14.1
   hooks:
   - id: mypy
     files: (jax/|tests/typing_test\.py)

@@ -260,7 +260,7 @@ class Error:
     cur_effect = None
     for error_effect, code in self._code.items():
       if self._pred[error_effect][idx]:  # type: ignore
-        if min_code is None or code[idx] < min_code:
+        if min_code is None or code[idx] < min_code:  # type: ignore
           min_code = code[idx]  # type: ignore
           cur_effect = error_effect

@@ -597,8 +597,12 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
       else:
         color = None
         text_color = None
-      padding = (top_padding, right_padding, bottom_padding, left_padding)
-      padding = tuple(max(x, 0) for x in padding)  # type: ignore
+      padding = (
+          max(top_padding, 0),
+          max(right_padding, 0),
+          max(bottom_padding, 0),
+          max(left_padding, 0),
+      )
       col.append(
           rich.padding.Padding(
               rich.align.Align(entry, "center", vertical="middle"), padding,

@@ -17,6 +17,7 @@ from __future__ import annotations
 from collections.abc import Callable, Sequence
 from functools import partial
 import enum
+from typing import Any

 import numpy as np

@@ -263,7 +264,7 @@ def _resize_nearest(x, output_shape: core.Shape):
     # TODO(b/206898375): this computation produces the wrong result on
     # CPU and GPU when using float64. Use float32 until the bug is fixed.
     offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
-    indices = [slice(None)] * len(input_shape)
+    indices: list[Any] = [slice(None)] * len(input_shape)
     indices[d] = offsets
     x = x[tuple(indices)]
   return x

@@ -2003,7 +2003,9 @@ def jaxpr_transfer_mem_kinds(
   return out


-def are_all_shardings_default_mem_kind(da_object: xc.DeviceList, shardings):
+def are_all_shardings_default_mem_kind(
+    da_object: xc.DeviceList | None, shardings
+):
   if da_object is None:
     return True
   try:

@@ -2339,8 +2341,8 @@ def lower_sharding_computation(
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
        semantic_out_shardings, in_layouts, out_layouts, num_devices,
-       tuple(da_object) if prim_requires_devices else None, donated_invars,
-       name_stack, all_default_mem_kind, inout_aliases,
+       tuple(da_object) if prim_requires_devices else None,  # type: ignore[arg-type]
+       donated_invars, name_stack, all_default_mem_kind, inout_aliases,
        propagated_out_mem_kinds, platforms,
        lowering_parameters=lowering_parameters,
        abstract_mesh=abstract_mesh)

@@ -437,7 +437,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
   del p, x
   shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape)
   y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype)
-  y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)
+  y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)  # type: ignore[misc]
   return y

@@ -461,7 +461,7 @@ class ufunc:
       idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
       a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args)))
       return (i + 1, a), x
-    carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
+    carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))  # type: ignore[arg-type]
     return carry[1]

   @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])

@@ -94,7 +94,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
   Promotes arguments to an inexact type."""
   to_dtype, weak_type = dtypes._lattice_result_type(*args)
   to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-  to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
+  to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)  # type: ignore[arg-type]
   return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
           for x in args]

@@ -494,7 +494,7 @@ def _wgmma_lowering(
     b_transforms_tree,
 ):
   _, a_aval, *_ = ctx.avals_in
-  lhs_swizzle = None
+  lhs_swizzle: int | None = None
   if a_transforms_tree is not None:
     a_transforms_leaves, b_transforms_leaves = util.split_list(
         transforms_leaves, [a_transforms_tree.num_leaves]

@@ -729,7 +729,7 @@ def _pallas_call_batching_rule(
   for pos, invar in enumerate(jaxpr.invars):
     ragged_axis_values[pos] = var_to_raggedness[invar]

-  per_input_ragged_axis_dim = []
+  per_input_ragged_axis_dim: list[int | None] = []
   for rav in ragged_axis_values:
     if rav is not None:
       per_input_ragged_axis_dim.append(rav[1])

@@ -17,7 +17,7 @@
 # limitations under the License.

 import dataclasses
-from typing import Any, Union
+from typing import Any, Sequence, Union

 from jax._src import core
 from jax._src import tree_util
@@ -198,20 +198,20 @@ class NDIndexer:
       # TODO(slebedev): Consider requiring `indices` to be a Sequence.
       indices = (indices,)

-    indices = list(indices)
     if num_ellipsis := sum(idx is ... for idx in indices):
       if num_ellipsis > 1:
         raise ValueError("Only one ellipsis is supported.")
       # Expand ... so that `indices` has the same length as `shape`.
       ip = indices.index(...)
+      indices = list(indices)
       indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1)
+      indices = tuple(indices)
     if len(indices) > len(shape):
       raise ValueError("`indices` must not be longer than `shape`: "
                        f"{indices=}, {shape=}")
     elif len(indices) < len(shape):
       # Pad `indices` to have the same length as `shape`.
-      indices.extend([slice(None)] * (len(shape) - len(indices)))
+      indices = (*indices, *[slice(None)] * (len(shape) - len(indices)))

     # Promote all builtin `slice`s to `Slice`.
     indices = tuple(
@@ -220,6 +220,7 @@ class NDIndexer:
     is_int_indexing = [not isinstance(i, Slice) for i in indices]
     if any(is_int_indexing):
+      int_indexers: Sequence[Any]
       other_indexers, int_indexers = partition_list(is_int_indexing, indices)
       indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers)
       try:

@@ -84,7 +84,7 @@ TODO:
 """
 from functools import partial
 import math
-from typing import cast, Any
+from typing import Any

 import jax
 import numpy as np
@@ -256,14 +256,14 @@ def _lstm_cudnn_allow_tf32(precision: lax.PrecisionLike) -> bool:
   if precision is None or not (isinstance(precision, tuple) and len(precision) == 2):
     return True
   # cuDNN allows only one precision specifier per RNN op
-  precision, _ = cast(tuple[lax.Precision, lax.Precision], precision)
-  if precision == lax.Precision.HIGHEST:
-    return False
-  elif precision == lax.Precision.HIGH:
-    return True
-  elif precision == lax.Precision.DEFAULT:  # bfloat16
-    raise NotImplementedError("bfloat16 support not implemented for LSTM")
-  else:
-    raise ValueError(f"Unexpected precision specifier value {precision}")
+  match precision:
+    case (lax.Precision.HIGHEST, _):
+      return False
+    case (lax.Precision.HIGH, _):
+      return True
+    case (lax.Precision.DEFAULT, _):  # bfloat16
+      raise NotImplementedError("bfloat16 support not implemented for LSTM")
+    case _:
+      raise ValueError(f"Unexpected precision specifier value {precision}")

@@ -13,9 +13,7 @@
 # limitations under the License.

 import enum
-from typing import ClassVar
-

 class ComputationMode(enum.Enum):
-  kComputeEigenvectors: ClassVar[ComputationMode] = ...
-  kNoEigenvectors: ClassVar[ComputationMode] = ...
+  kComputeEigenvectors = ...
+  kNoEigenvectors = ...

@@ -13,14 +13,13 @@
 # limitations under the License.

 import enum
-from typing import ClassVar


 class ComputationMode(enum.Enum):
-  kComputeSchurVectors: ClassVar[ComputationMode]
-  kNoComputeSchurVectors: ClassVar[ComputationMode]
+  kComputeSchurVectors = ...
+  kNoComputeSchurVectors = ...


 class Sort(enum.Enum):
-  kNoSortEigenvalues: ClassVar[Sort]
-  kSortEigenvalues: ClassVar[Sort]
+  kNoSortEigenvalues = ...
+  kSortEigenvalues = ...

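The final hunk, just below, turns on allow_redefinition in pyproject.toml; roughly, it lets mypy accept a name being rebound to an unrelated type where it would otherwise report an assignment error, the pattern used by, e.g., `indices` in the NDIndexer hunk above. A minimal illustrative sketch of the documented case (a hypothetical helper, not code from this diff):

   def normalize(items: tuple[int, ...]) -> tuple[int, ...]:
     # Rebinding to a different type: rejected without allow_redefinition.
     items = list(items)
     items.append(0)
     # Rebound back to a tuple, matching the return annotation.
     items = tuple(items)
     return items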
@@ -7,6 +7,7 @@ show_error_codes = true
 disable_error_code = "attr-defined, name-defined, annotation-unchecked"
 no_implicit_optional = true
 warn_redundant_casts = true
+allow_redefinition = true

 [[tool.mypy.overrides]]
 module = [