Mirror of https://github.com/ROCm/jax.git
Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how ternary expressions are type checked. For example:

    def f(x: int) -> str: ...
    def g(x: int) -> str: ...

    callback = f if ... else g  # has type object!
This commit is contained in: parent 3ec7a67e51 · commit 194884d311
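The headline change is turning on mypy's allow_redefinition option (added to the mypy config in the last hunk of this diff). A minimal sketch, not taken from the JAX sources, of the rebinding pattern that option permits:

    def parse_port(raw: str) -> int:
        value = raw.strip()   # value: str
        value = int(value)    # rebinding with an unrelated type: accepted under
                              # allow_redefinition, an error under default rules
        return value

    print(parse_port(" 8080 "))  # 8080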
@@ -31,7 +31,7 @@ repos:
   - id: ruff
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: '102bbee94061ff02fd361ec29c27b7cb26582f5f'  # frozen: v1.12.2
+  rev: 'bbc3dc1f890007061f18f17e2334f216ea9e5df7'  # frozen: v1.14.1
   hooks:
   - id: mypy
     files: (jax/|tests/typing_test\.py)
@@ -260,7 +260,7 @@ class Error:
     cur_effect = None
     for error_effect, code in self._code.items():
       if self._pred[error_effect][idx]:  # type: ignore
-        if min_code is None or code[idx] < min_code:
+        if min_code is None or code[idx] < min_code:  # type: ignore
           min_code = code[idx]  # type: ignore
           cur_effect = error_effect
@@ -597,8 +597,12 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
       else:
         color = None
         text_color = None
-      padding = (top_padding, right_padding, bottom_padding, left_padding)
-      padding = tuple(max(x, 0) for x in padding)  # type: ignore
+      padding = (
+          max(top_padding, 0),
+          max(right_padding, 0),
+          max(bottom_padding, 0),
+          max(left_padding, 0),
+      )
       col.append(
           rich.padding.Padding(
               rich.align.Align(entry, "center", vertical="middle"), padding,
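The visualize_sharding change replaces a generator-built tuple with an explicit 4-tuple, presumably so that `padding` keeps a fixed-length tuple type rather than `tuple[int, ...]` and no longer needs a `type: ignore`. A hedged sketch of that distinction; `clamp_padding` is a made-up name, not JAX code:

    def clamp_padding(
        top: int, right: int, bottom: int, left: int
    ) -> tuple[int, int, int, int]:
        # Building the tuple element by element preserves the fixed-length type;
        # tuple(max(x, 0) for x in ...) would only be typed as tuple[int, ...].
        return (max(top, 0), max(right, 0), max(bottom, 0), max(left, 0))

    print(clamp_padding(1, -2, 3, -4))  # (1, 0, 3, 0)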
@@ -17,6 +17,7 @@ from __future__ import annotations

 from collections.abc import Callable, Sequence
 from functools import partial
 import enum
+from typing import Any

 import numpy as np
@@ -263,7 +264,7 @@ def _resize_nearest(x, output_shape: core.Shape):
     # TODO(b/206898375): this computation produces the wrong result on
     # CPU and GPU when using float64. Use float32 until the bug is fixed.
     offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
-    indices = [slice(None)] * len(input_shape)
+    indices: list[Any] = [slice(None)] * len(input_shape)
     indices[d] = offsets
     x = x[tuple(indices)]
   return x
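The `list[Any]` annotation in `_resize_nearest` is the usual fix when a list is created with one element type (slices) and later stores another (an index array). A minimal sketch with assumed names; `take_along_dim` is not a JAX or NumPy function:

    from typing import Any
    import numpy as np

    def take_along_dim(x: np.ndarray, idx: np.ndarray, d: int) -> np.ndarray:
        # Without the annotation mypy infers list[slice] from the initializer
        # and then rejects storing an integer array at position d.
        indices: list[Any] = [slice(None)] * x.ndim
        indices[d] = idx
        return x[tuple(indices)]

    print(take_along_dim(np.arange(12).reshape(3, 4), np.array([0, 2]), d=0))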
@@ -2003,7 +2003,9 @@ def jaxpr_transfer_mem_kinds(
   return out


-def are_all_shardings_default_mem_kind(da_object: xc.DeviceList, shardings):
+def are_all_shardings_default_mem_kind(
+    da_object: xc.DeviceList | None, shardings
+):
   if da_object is None:
     return True
   try:
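The signature of are_all_shardings_default_mem_kind is widened to accept None, which the body already handled. A minimal sketch of the same pattern with made-up names:

    def device_count(devices: list[str] | None) -> int:
        if devices is None:   # the body already tolerated None; the annotation
            return 0          # now tells mypy that callers may pass it
        return len(devices)

    print(device_count(None), device_count(["gpu:0", "gpu:1"]))  # 0 2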
@@ -2339,8 +2341,8 @@ def lower_sharding_computation(
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
        semantic_out_shardings, in_layouts, out_layouts, num_devices,
-       tuple(da_object) if prim_requires_devices else None, donated_invars,
-       name_stack, all_default_mem_kind, inout_aliases,
+       tuple(da_object) if prim_requires_devices else None,  # type: ignore[arg-type]
+       donated_invars, name_stack, all_default_mem_kind, inout_aliases,
        propagated_out_mem_kinds, platforms,
        lowering_parameters=lowering_parameters,
        abstract_mesh=abstract_mesh)
@@ -437,7 +437,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
   del p, x
   shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape)
   y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype)
-  y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)
+  y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)  # type: ignore[misc]
   return y

@@ -461,7 +461,7 @@ class ufunc:
       idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
       a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args)))
       return (i + 1, a), x
-    carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
+    carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))  # type: ignore[arg-type]
     return carry[1]

   @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
@@ -94,7 +94,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
   Promotes arguments to an inexact type."""
   to_dtype, weak_type = dtypes._lattice_result_type(*args)
   to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-  to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
+  to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)  # type: ignore[arg-type]
   return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
           for x in args]
@@ -494,7 +494,7 @@ def _wgmma_lowering(
     b_transforms_tree,
 ):
   _, a_aval, *_ = ctx.avals_in
-  lhs_swizzle = None
+  lhs_swizzle: int | None = None
   if a_transforms_tree is not None:
     a_transforms_leaves, b_transforms_leaves = util.split_list(
         transforms_leaves, [a_transforms_tree.num_leaves]
@@ -729,7 +729,7 @@ def _pallas_call_batching_rule(
     for pos, invar in enumerate(jaxpr.invars):
       ragged_axis_values[pos] = var_to_raggedness[invar]

-  per_input_ragged_axis_dim = []
+  per_input_ragged_axis_dim: list[int | None] = []
   for rav in ragged_axis_values:
     if rav is not None:
       per_input_ragged_axis_dim.append(rav[1])
@@ -17,7 +17,7 @@
 from __future__ import annotations

 import dataclasses
-from typing import Any, Union
+from typing import Any, Sequence, Union

 from jax._src import core
 from jax._src import tree_util
@@ -198,20 +198,20 @@ class NDIndexer:
       # TODO(slebedev): Consider requiring `indices` to be a Sequence.
       indices = (indices,)

-    indices = list(indices)
     if num_ellipsis := sum(idx is ... for idx in indices):
       if num_ellipsis > 1:
         raise ValueError("Only one ellipsis is supported.")
       # Expand ... so that `indices` has the same length as `shape`.
       ip = indices.index(...)
+      indices = list(indices)
       indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1)
+      indices = tuple(indices)
     if len(indices) > len(shape):
       raise ValueError("`indices` must not be longer than `shape`: "
                        f"{indices=}, {shape=}")
     elif len(indices) < len(shape):
       # Pad `indices` to have the same length as `shape`.
-      indices.extend([slice(None)] * (len(shape) - len(indices)))
+      indices = (*indices, *[slice(None)] * (len(shape) - len(indices)))

     # Promote all builtin `slice`s to `Slice`.
     indices = tuple(
@@ -220,6 +220,7 @@ class NDIndexer:

     is_int_indexing = [not isinstance(i, Slice) for i in indices]
     if any(is_int_indexing):
+      int_indexers: Sequence[Any]
       other_indexers, int_indexers = partition_list(is_int_indexing, indices)
       indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers)
       try:
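Both NDIndexer hunks rely on patterns the new setup supports: `indices` is rebound between tuple and list forms, and `int_indexers` gets a bare `Sequence[Any]` annotation before the unpacking assignment so mypy knows what type to check later uses against. A minimal sketch of the bare-annotation pattern; `partition` here is a hypothetical stand-in for JAX's `partition_list`:

    from collections.abc import Sequence
    from typing import Any

    def partition(mask: list[bool], xs: Sequence[Any]) -> tuple[list[Any], list[Any]]:
        kept = [x for m, x in zip(mask, xs) if not m]
        taken = [x for m, x in zip(mask, xs) if m]
        return kept, taken

    int_indexers: Sequence[int]   # declared before the unpacking assignment
    others, int_indexers = partition([True, False, True], [1, "a", 3])
    print(others, int_indexers)   # ['a'] [1, 3]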
@@ -84,7 +84,7 @@ TODO:
 """
 from functools import partial
 import math
-from typing import cast, Any
+from typing import Any

 import jax
 import numpy as np
@@ -256,15 +256,15 @@ def _lstm_cudnn_allow_tf32(precision: lax.PrecisionLike) -> bool:
   if precision is None or not (isinstance(precision, tuple) and len(precision) == 2):
     return True
   # cuDNN allows only one precision specifier per RNN op
-  precision, _ = cast(tuple[lax.Precision, lax.Precision], precision)
-  if precision == lax.Precision.HIGHEST:
-    return False
-  elif precision == lax.Precision.HIGH:
-    return True
-  elif precision == lax.Precision.DEFAULT:  # bfloat16
-    raise NotImplementedError("bfloat16 support not implemented for LSTM")
-  else:
-    raise ValueError(f"Unexpected precision specifier value {precision}")
+  match precision:
+    case (lax.Precision.HIGHEST, _):
+      return False
+    case (lax.Precision.HIGH, _):
+      return True
+    case (lax.Precision.DEFAULT, _):  # bfloat16
+      raise NotImplementedError("bfloat16 support not implemented for LSTM")
+    case _:
+      raise ValueError(f"Unexpected precision specifier value {precision}")


 @partial(custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10))
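The LSTM precision check becomes a match statement over the precision pair, which drops the typing.cast and narrows the value in each arm. A hedged sketch of the same shape, using a stand-in Precision enum rather than lax.Precision (requires Python 3.10+):

    import enum

    class Precision(enum.Enum):   # stand-in for lax.Precision
        DEFAULT = 0
        HIGH = 1
        HIGHEST = 2

    def allow_tf32(precision: tuple[Precision, Precision] | None) -> bool:
        if precision is None:
            return True
        match precision:
            case (Precision.HIGHEST, _):
                return False
            case (Precision.HIGH, _):
                return True
            case (Precision.DEFAULT, _):   # bfloat16
                raise NotImplementedError("bfloat16 support not implemented")
            case _:
                raise ValueError(f"Unexpected precision value {precision}")

    print(allow_tf32((Precision.HIGH, Precision.DEFAULT)))  # True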
@@ -13,9 +13,7 @@
 # limitations under the License.

 import enum
-from typing import ClassVar


 class ComputationMode(enum.Enum):
-  kComputeEigenvectors: ClassVar[ComputationMode] = ...
-  kNoEigenvectors: ClassVar[ComputationMode] = ...
+  kComputeEigenvectors = ...
+  kNoEigenvectors = ...
@@ -13,14 +13,13 @@
 # limitations under the License.

 import enum
-from typing import ClassVar


 class ComputationMode(enum.Enum):
-  kComputeSchurVectors: ClassVar[ComputationMode]
-  kNoComputeSchurVectors: ClassVar[ComputationMode]
+  kComputeSchurVectors = ...
+  kNoComputeSchurVectors = ...


 class Sort(enum.Enum):
-  kNoSortEigenvalues: ClassVar[Sort]
-  kSortEigenvalues: ClassVar[Sort]
+  kNoSortEigenvalues = ...
+  kSortEigenvalues = ...
@@ -7,6 +7,7 @@ show_error_codes = true
 disable_error_code = "attr-defined, name-defined, annotation-unchecked"
 no_implicit_optional = true
 warn_redundant_casts = true
+allow_redefinition = true

 [[tool.mypy.overrides]]
 module = [