mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
pre-commit: update mypy to most recent version
This commit is contained in:
parent
0860c24767
commit
9d1f3b4dd2
@ -27,12 +27,12 @@ repos:
|
||||
- id: flake8
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: 'v0.982'
|
||||
rev: 'v1.4.1'
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: (jax/|tests/typing_test\.py)
|
||||
exclude: jax/_src/basearray.py # Use pyi instead
|
||||
additional_dependencies: [types-requests==2.29.0, jaxlib==0.4.7, ml_dtypes==0.2.0, numpy==1.24.3, scipy==1.10.1]
|
||||
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.13, ml_dtypes==0.2.0, numpy==1.24.3, scipy==1.10.1]
|
||||
args: [--config=pyproject.toml]
|
||||
|
||||
- repo: https://github.com/mwouts/jupytext
|
||||
|
@ -30,7 +30,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflowjs as tfjs
|
||||
|
||||
import input_pipeline
|
||||
import input_pipeline # type: ignore[import]
|
||||
|
||||
|
||||
flags.DEFINE_integer("num_epochs", 5,
|
||||
|
@ -132,22 +132,24 @@ def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *,
|
||||
[n_batch, n_sparse])
|
||||
if 0 <= nse < 1:
|
||||
nse = int(np.ceil(nse * np.prod(sparse_shape)))
|
||||
nse_int = int(nse)
|
||||
data_rng = rand_method(rng)
|
||||
data_shape = (*batch_shape, nse, *dense_shape)
|
||||
data_shape = (*batch_shape, nse_int, *dense_shape)
|
||||
data = jnp.array(data_rng(data_shape, dtype))
|
||||
|
||||
int32 = np.dtype('int32')
|
||||
if sparse_format == 'bcoo':
|
||||
index_shape = (*batch_shape, nse, n_sparse)
|
||||
index_shape = (*batch_shape, nse_int, n_sparse)
|
||||
indices = jnp.array(
|
||||
rng.randint(0, sparse_shape, size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
|
||||
rng.randint(0, sparse_shape, size=index_shape, dtype=int32))
|
||||
return sparse.BCOO((data, indices), shape=shape)
|
||||
else:
|
||||
index_shape = (*batch_shape, nse)
|
||||
index_shape = (*batch_shape, nse_int)
|
||||
indptr_shape = (*batch_shape, sparse_shape[0] + 1)
|
||||
indices = jnp.array(
|
||||
rng.randint(0, sparse_shape[1], size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
|
||||
rng.randint(0, sparse_shape[1], size=index_shape, dtype=int32))
|
||||
indptr = jnp.sort(
|
||||
rng.randint(0, nse + 1, size=indptr_shape, dtype=np.int32), axis=-1) # type: ignore[call-overload]
|
||||
rng.randint(0, nse_int + 1, size=indptr_shape, dtype=int32), axis=-1)
|
||||
indptr = indptr.at[..., 0].set(0)
|
||||
return sparse.BCSR((data, indices, indptr), shape=shape)
|
||||
|
||||
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.mypy]
|
||||
show_error_codes = true
|
||||
disable_error_code = "attr-defined"
|
||||
disable_error_code = "attr-defined, name-defined"
|
||||
no_implicit_optional = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user