From 9d1f3b4dd2481f543a498c727803f611e65b63ca Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 12 Jul 2023 10:41:51 -0700 Subject: [PATCH] pre-commit: update mypy to most recent version --- .pre-commit-config.yaml | 4 ++-- .../jax2tf/examples/tf_js/quickdraw/quickdraw.py | 2 +- jax/experimental/sparse/test_util.py | 14 ++++++++------ pyproject.toml | 2 +- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c76e2d41..a6c85b18a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,12 +27,12 @@ repos: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v0.982' + rev: 'v1.4.1' hooks: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py # Use pyi instead - additional_dependencies: [types-requests==2.29.0, jaxlib==0.4.7, ml_dtypes==0.2.0, numpy==1.24.3, scipy==1.10.1] + additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.13, ml_dtypes==0.2.0, numpy==1.24.3, scipy==1.10.1] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext diff --git a/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py b/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py index 151610728..94f21724d 100644 --- a/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py +++ b/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py @@ -30,7 +30,7 @@ import numpy as np import tensorflow as tf import tensorflowjs as tfjs -import input_pipeline +import input_pipeline # type: ignore[import] flags.DEFINE_integer("num_epochs", 5, diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 5767f79d6..a49bc00be 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -132,22 +132,24 @@ def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *, [n_batch, n_sparse]) if 0 <= nse < 1: nse = int(np.ceil(nse * np.prod(sparse_shape))) + nse_int = int(nse) data_rng = rand_method(rng) - data_shape = (*batch_shape, nse, *dense_shape) + data_shape = (*batch_shape, nse_int, *dense_shape) data = jnp.array(data_rng(data_shape, dtype)) + int32 = np.dtype('int32') if sparse_format == 'bcoo': - index_shape = (*batch_shape, nse, n_sparse) + index_shape = (*batch_shape, nse_int, n_sparse) indices = jnp.array( - rng.randint(0, sparse_shape, size=index_shape, dtype=np.int32)) # type: ignore[arg-type] + rng.randint(0, sparse_shape, size=index_shape, dtype=int32)) return sparse.BCOO((data, indices), shape=shape) else: - index_shape = (*batch_shape, nse) + index_shape = (*batch_shape, nse_int) indptr_shape = (*batch_shape, sparse_shape[0] + 1) indices = jnp.array( - rng.randint(0, sparse_shape[1], size=index_shape, dtype=np.int32)) # type: ignore[arg-type] + rng.randint(0, sparse_shape[1], size=index_shape, dtype=int32)) indptr = jnp.sort( - rng.randint(0, nse + 1, size=indptr_shape, dtype=np.int32), axis=-1) # type: ignore[call-overload] + rng.randint(0, nse_int + 1, size=indptr_shape, dtype=int32), axis=-1) indptr = indptr.at[..., 0].set(0) return sparse.BCSR((data, indices, indptr), shape=shape) diff --git a/pyproject.toml b/pyproject.toml index eb3eb4c1f..4bcff99d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [tool.mypy] show_error_codes = true -disable_error_code = "attr-defined" +disable_error_code = "attr-defined, name-defined" no_implicit_optional = true [[tool.mypy.overrides]]