diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index cd65b702f..c8b8d203a 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -62,6 +62,9 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client from jax._src import random as jax_random +# mypy generates a lot of false positive due to re-assigned variables. +# mypy: disable-error-code="assignment, no-redef" + # The code in this file relies on the values of some flags that are defined by # jtu. Note that the following can not always be moved to a test file since # then the test file has to import jtu first (to define the flags) which is not @@ -172,9 +175,9 @@ class Harness: self.group_name = jtu.sanitize_test_name(group_name) self.name = jtu.sanitize_test_name(name) self.fullname = self.name if self.group_name is None else f"{self.group_name}_{self.name}" - self.fun = fun # type: ignore[assignment] + self.fun = fun self.arg_descriptors = arg_descriptors - self.rng_factory = rng_factory # type: ignore[assignment] + self.rng_factory = rng_factory self.jax_unimplemented = jax_unimplemented self.dtype = dtype self.params = params @@ -2060,18 +2063,17 @@ def _make_slice_harness(name, define( lax.slice_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_{strides=}", - # type: ignore lax.slice, [ - RandArg(shape, dtype), # type: ignore - StaticArg(start_indices), # type: ignore - StaticArg(limit_indices), # type: ignore + RandArg(shape, dtype), + StaticArg(start_indices), + StaticArg(limit_indices), StaticArg(strides) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - limit_indices=limit_indices) # type: ignore + shape=shape, + start_indices=start_indices, + limit_indices=limit_indices) # Test first all dtypes @@ -2161,17 +2163,16 @@ def _make_dynamic_slice_harness(name, define( lax.dynamic_slice_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_enablexla={enable_xla}", - # type: ignore lax.dynamic_slice, [ - RandArg(shape, dtype), # type: ignore + RandArg(shape, dtype), np.array(list(start_indices)), StaticArg(tuple(map(operator.sub, limit_indices, start_indices))) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - limit_indices=limit_indices, # type: ignore + shape=shape, + start_indices=start_indices, + limit_indices=limit_indices, enable_xla=enable_xla) @@ -2218,19 +2219,19 @@ def _make_dynamic_update_slice_harness(name, define( lax.dynamic_update_slice_p, ( - f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore + f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}" f"_{start_indices=}_{enable_xla=}"), lax.dynamic_update_slice, [ - RandArg(shape, dtype), # type: ignore - RandArg(update_shape, dtype), # type: ignore + RandArg(shape, dtype), + RandArg(update_shape, dtype), np.array(start_indices) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - update_shape=update_shape, # type: ignore + shape=shape, + start_indices=start_indices, + update_shape=update_shape, enable_xla=enable_xla) @@ -2261,12 +2262,12 @@ def _make_squeeze_harness(name, dtype=np.float32): define( lax.squeeze_p, - f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", # type: ignore + f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", lax.squeeze, - [RandArg(shape, dtype), StaticArg(dimensions)], # type: ignore[has-type] + [RandArg(shape, dtype), StaticArg(dimensions)], dtype=dtype, arg_shape=shape, - dimensions=dimensions) # type: ignore[has-type] + dimensions=dimensions) # Test first all dtypes @@ -3312,6 +3313,7 @@ for padding, lhs_dilation, rhs_dilation in [ lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation) +key_types: list[tuple[tuple[int, ...], jax.typing.DTypeLike]] key_types = [((4,), np.uint32)] if config.enable_x64.value: key_types.append(((2,), np.uint64)) diff --git a/pyproject.toml b/pyproject.toml index 8132c8773..630f25835 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,14 +44,6 @@ module = [ ] ignore_missing_imports = true -[[tool.mypy.overrides]] -module = [ - "jax.interpreters.autospmd", - "jax.lax.lax_parallel", - "jax._src.internal_test_util.test_harnesses", -] -ignore_errors = true - [tool.pytest.ini_options] markers = [ "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",