mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Removed unnecessary mypy exclusions from pyproject.toml
* 2/3 files type check just fine now * the remaining one could be handled via a file-level directive
This commit is contained in:
parent
af90464b53
commit
0786da8fd8
@ -62,6 +62,9 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import random as jax_random
|
||||
|
||||
# mypy generates a lot of false positive due to re-assigned variables.
|
||||
# mypy: disable-error-code="assignment, no-redef"
|
||||
|
||||
# The code in this file relies on the values of some flags that are defined by
|
||||
# jtu. Note that the following can not always be moved to a test file since
|
||||
# then the test file has to import jtu first (to define the flags) which is not
|
||||
@ -172,9 +175,9 @@ class Harness:
|
||||
self.group_name = jtu.sanitize_test_name(group_name)
|
||||
self.name = jtu.sanitize_test_name(name)
|
||||
self.fullname = self.name if self.group_name is None else f"{self.group_name}_{self.name}"
|
||||
self.fun = fun # type: ignore[assignment]
|
||||
self.fun = fun
|
||||
self.arg_descriptors = arg_descriptors
|
||||
self.rng_factory = rng_factory # type: ignore[assignment]
|
||||
self.rng_factory = rng_factory
|
||||
self.jax_unimplemented = jax_unimplemented
|
||||
self.dtype = dtype
|
||||
self.params = params
|
||||
@ -2060,18 +2063,17 @@ def _make_slice_harness(name,
|
||||
define(
|
||||
lax.slice_p,
|
||||
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_{strides=}",
|
||||
# type: ignore
|
||||
lax.slice,
|
||||
[
|
||||
RandArg(shape, dtype), # type: ignore
|
||||
StaticArg(start_indices), # type: ignore
|
||||
StaticArg(limit_indices), # type: ignore
|
||||
RandArg(shape, dtype),
|
||||
StaticArg(start_indices),
|
||||
StaticArg(limit_indices),
|
||||
StaticArg(strides)
|
||||
], # type: ignore
|
||||
],
|
||||
dtype=dtype,
|
||||
shape=shape, # type: ignore
|
||||
start_indices=start_indices, # type: ignore
|
||||
limit_indices=limit_indices) # type: ignore
|
||||
shape=shape,
|
||||
start_indices=start_indices,
|
||||
limit_indices=limit_indices)
|
||||
|
||||
|
||||
# Test first all dtypes
|
||||
@ -2161,17 +2163,16 @@ def _make_dynamic_slice_harness(name,
|
||||
define(
|
||||
lax.dynamic_slice_p,
|
||||
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_enablexla={enable_xla}",
|
||||
# type: ignore
|
||||
lax.dynamic_slice,
|
||||
[
|
||||
RandArg(shape, dtype), # type: ignore
|
||||
RandArg(shape, dtype),
|
||||
np.array(list(start_indices)),
|
||||
StaticArg(tuple(map(operator.sub, limit_indices, start_indices)))
|
||||
], # type: ignore
|
||||
],
|
||||
dtype=dtype,
|
||||
shape=shape, # type: ignore
|
||||
start_indices=start_indices, # type: ignore
|
||||
limit_indices=limit_indices, # type: ignore
|
||||
shape=shape,
|
||||
start_indices=start_indices,
|
||||
limit_indices=limit_indices,
|
||||
enable_xla=enable_xla)
|
||||
|
||||
|
||||
@ -2218,19 +2219,19 @@ def _make_dynamic_update_slice_harness(name,
|
||||
define(
|
||||
lax.dynamic_update_slice_p,
|
||||
(
|
||||
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore
|
||||
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}"
|
||||
f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
|
||||
f"_{start_indices=}_{enable_xla=}"),
|
||||
lax.dynamic_update_slice,
|
||||
[
|
||||
RandArg(shape, dtype), # type: ignore
|
||||
RandArg(update_shape, dtype), # type: ignore
|
||||
RandArg(shape, dtype),
|
||||
RandArg(update_shape, dtype),
|
||||
np.array(start_indices)
|
||||
], # type: ignore
|
||||
],
|
||||
dtype=dtype,
|
||||
shape=shape, # type: ignore
|
||||
start_indices=start_indices, # type: ignore
|
||||
update_shape=update_shape, # type: ignore
|
||||
shape=shape,
|
||||
start_indices=start_indices,
|
||||
update_shape=update_shape,
|
||||
enable_xla=enable_xla)
|
||||
|
||||
|
||||
@ -2261,12 +2262,12 @@ def _make_squeeze_harness(name,
|
||||
dtype=np.float32):
|
||||
define(
|
||||
lax.squeeze_p,
|
||||
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", # type: ignore
|
||||
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}",
|
||||
lax.squeeze,
|
||||
[RandArg(shape, dtype), StaticArg(dimensions)], # type: ignore[has-type]
|
||||
[RandArg(shape, dtype), StaticArg(dimensions)],
|
||||
dtype=dtype,
|
||||
arg_shape=shape,
|
||||
dimensions=dimensions) # type: ignore[has-type]
|
||||
dimensions=dimensions)
|
||||
|
||||
|
||||
# Test first all dtypes
|
||||
@ -3312,6 +3313,7 @@ for padding, lhs_dilation, rhs_dilation in [
|
||||
lhs_dilation=lhs_dilation,
|
||||
rhs_dilation=rhs_dilation)
|
||||
|
||||
key_types: list[tuple[tuple[int, ...], jax.typing.DTypeLike]]
|
||||
key_types = [((4,), np.uint32)]
|
||||
if config.enable_x64.value:
|
||||
key_types.append(((2,), np.uint64))
|
||||
|
@ -44,14 +44,6 @@ module = [
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"jax.interpreters.autospmd",
|
||||
"jax.lax.lax_parallel",
|
||||
"jax._src.internal_test_util.test_harnesses",
|
||||
]
|
||||
ignore_errors = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
|
||||
|
Loading…
x
Reference in New Issue
Block a user