pre-commit: update mypy to most recent version

This commit is contained in:
Jake VanderPlas 2023-07-12 10:41:51 -07:00
parent 0860c24767
commit 9d1f3b4dd2
4 changed files with 12 additions and 10 deletions

View File

@ -27,12 +27,12 @@ repos:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.982'
rev: 'v1.4.1'
hooks:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py # Use pyi instead
additional_dependencies: [types-requests==2.29.0, jaxlib==0.4.7, ml_dtypes==0.2.0, numpy==1.24.3, scipy==1.10.1]
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.13, ml_dtypes==0.2.0, numpy==1.24.3, scipy==1.10.1]
args: [--config=pyproject.toml]
- repo: https://github.com/mwouts/jupytext

View File

@ -30,7 +30,7 @@ import numpy as np
import tensorflow as tf
import tensorflowjs as tfjs
import input_pipeline
import input_pipeline # type: ignore[import]
flags.DEFINE_integer("num_epochs", 5,

View File

@ -132,22 +132,24 @@ def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *,
[n_batch, n_sparse])
if 0 <= nse < 1:
nse = int(np.ceil(nse * np.prod(sparse_shape)))
nse_int = int(nse)
data_rng = rand_method(rng)
data_shape = (*batch_shape, nse, *dense_shape)
data_shape = (*batch_shape, nse_int, *dense_shape)
data = jnp.array(data_rng(data_shape, dtype))
int32 = np.dtype('int32')
if sparse_format == 'bcoo':
index_shape = (*batch_shape, nse, n_sparse)
index_shape = (*batch_shape, nse_int, n_sparse)
indices = jnp.array(
rng.randint(0, sparse_shape, size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
rng.randint(0, sparse_shape, size=index_shape, dtype=int32))
return sparse.BCOO((data, indices), shape=shape)
else:
index_shape = (*batch_shape, nse)
index_shape = (*batch_shape, nse_int)
indptr_shape = (*batch_shape, sparse_shape[0] + 1)
indices = jnp.array(
rng.randint(0, sparse_shape[1], size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
rng.randint(0, sparse_shape[1], size=index_shape, dtype=int32))
indptr = jnp.sort(
rng.randint(0, nse + 1, size=indptr_shape, dtype=np.int32), axis=-1) # type: ignore[call-overload]
rng.randint(0, nse_int + 1, size=indptr_shape, dtype=int32), axis=-1)
indptr = indptr.at[..., 0].set(0)
return sparse.BCSR((data, indices, indptr), shape=shape)

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[tool.mypy]
show_error_codes = true
disable_error_code = "attr-defined"
disable_error_code = "attr-defined, name-defined"
no_implicit_optional = true
[[tool.mypy.overrides]]