Removed unnecessary mypy exclusions from pyproject.toml

* 2/3 files type check just fine now
* the remaining one could be handled via a file-level directive
This commit is contained in:
Sergei Lebedev 2024-06-07 20:07:42 +01:00
parent af90464b53
commit 0786da8fd8
2 changed files with 28 additions and 34 deletions

View File

@ -62,6 +62,9 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src import random as jax_random
# mypy generates a lot of false positive due to re-assigned variables.
# mypy: disable-error-code="assignment, no-redef"
# The code in this file relies on the values of some flags that are defined by
# jtu. Note that the following can not always be moved to a test file since
# then the test file has to import jtu first (to define the flags) which is not
@ -172,9 +175,9 @@ class Harness:
self.group_name = jtu.sanitize_test_name(group_name)
self.name = jtu.sanitize_test_name(name)
self.fullname = self.name if self.group_name is None else f"{self.group_name}_{self.name}"
self.fun = fun # type: ignore[assignment]
self.fun = fun
self.arg_descriptors = arg_descriptors
self.rng_factory = rng_factory # type: ignore[assignment]
self.rng_factory = rng_factory
self.jax_unimplemented = jax_unimplemented
self.dtype = dtype
self.params = params
@ -2060,18 +2063,17 @@ def _make_slice_harness(name,
define(
lax.slice_p,
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_{strides=}",
# type: ignore
lax.slice,
[
RandArg(shape, dtype), # type: ignore
StaticArg(start_indices), # type: ignore
StaticArg(limit_indices), # type: ignore
RandArg(shape, dtype),
StaticArg(start_indices),
StaticArg(limit_indices),
StaticArg(strides)
], # type: ignore
],
dtype=dtype,
shape=shape, # type: ignore
start_indices=start_indices, # type: ignore
limit_indices=limit_indices) # type: ignore
shape=shape,
start_indices=start_indices,
limit_indices=limit_indices)
# Test first all dtypes
@ -2161,17 +2163,16 @@ def _make_dynamic_slice_harness(name,
define(
lax.dynamic_slice_p,
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_enablexla={enable_xla}",
# type: ignore
lax.dynamic_slice,
[
RandArg(shape, dtype), # type: ignore
RandArg(shape, dtype),
np.array(list(start_indices)),
StaticArg(tuple(map(operator.sub, limit_indices, start_indices)))
], # type: ignore
],
dtype=dtype,
shape=shape, # type: ignore
start_indices=start_indices, # type: ignore
limit_indices=limit_indices, # type: ignore
shape=shape,
start_indices=start_indices,
limit_indices=limit_indices,
enable_xla=enable_xla)
@ -2218,19 +2219,19 @@ def _make_dynamic_update_slice_harness(name,
define(
lax.dynamic_update_slice_p,
(
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}"
f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
f"_{start_indices=}_{enable_xla=}"),
lax.dynamic_update_slice,
[
RandArg(shape, dtype), # type: ignore
RandArg(update_shape, dtype), # type: ignore
RandArg(shape, dtype),
RandArg(update_shape, dtype),
np.array(start_indices)
], # type: ignore
],
dtype=dtype,
shape=shape, # type: ignore
start_indices=start_indices, # type: ignore
update_shape=update_shape, # type: ignore
shape=shape,
start_indices=start_indices,
update_shape=update_shape,
enable_xla=enable_xla)
@ -2261,12 +2262,12 @@ def _make_squeeze_harness(name,
dtype=np.float32):
define(
lax.squeeze_p,
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", # type: ignore
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}",
lax.squeeze,
[RandArg(shape, dtype), StaticArg(dimensions)], # type: ignore[has-type]
[RandArg(shape, dtype), StaticArg(dimensions)],
dtype=dtype,
arg_shape=shape,
dimensions=dimensions) # type: ignore[has-type]
dimensions=dimensions)
# Test first all dtypes
@ -3312,6 +3313,7 @@ for padding, lhs_dilation, rhs_dilation in [
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation)
key_types: list[tuple[tuple[int, ...], jax.typing.DTypeLike]]
key_types = [((4,), np.uint32)]
if config.enable_x64.value:
key_types.append(((2,), np.uint64))

View File

@ -44,14 +44,6 @@ module = [
]
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = [
"jax.interpreters.autospmd",
"jax.lax.lax_parallel",
"jax._src.internal_test_util.test_harnesses",
]
ignore_errors = true
[tool.pytest.ini_options]
markers = [
"multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",