Fixed mypy type errors for numpy 1.20

Revert also previous changes that pinned numpy to 1.19.

One of the changes in numpy 1.20 is to add more type annotations.
However, this sometimes make mypy give errors. A common example is
numpy.take, which with the new type annotation does not appear to
mypy as indexable.

Another change is that np.int and np.bool are deprecated. One
should use np.bool_ or np.int_, or the built-ins bool and int.
This commit is contained in:
George Necula 2021-01-31 15:34:20 +02:00
parent 92a0e695ed
commit f105517ea2
13 changed files with 39 additions and 25 deletions

View File

@ -58,6 +58,15 @@ jobs:
enable-omnistaging: 1
package-overrides: "none"
num_generated_cases: 25
# TODO: re-enable this for numpy 1.20
# - name-prefix: "with numpy-dispatch"
# python-version: 3.7
# os: ubuntu-latest
# enable-x64: 1
# enable-omnistaging: 1
# # Test experimental NumPy dispatch
# package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
# num_generated_cases: 10
- name-prefix: "with internal numpy"
python-version: 3.6
os: ubuntu-latest

View File

@ -10,6 +10,10 @@ Change Log
jaxlib 0.1.61 (Unreleased)
--------------------------
.. AFTER RELEASING THE NEXT JAXLIB, please remove the numpy<1.20 requirements in
.. test-requirements.txt. It is there only to match the previously released
.. jaxlib 0.1.60
* Bug fixes:
@ -31,7 +35,7 @@ jax 0.2.10 (Unreleased)
made to match the semantics of NumPy 1.20.0.
* Several `jax.numpy` functions no longer accept tuples or lists in place
of array arguments: :func:`jax.numpy.pad`, :func`jax.numpy.ravel`,
:func:`jax.numpy.repeat`.
:func:`jax.numpy.repeat`, :func:`jax.numpy.reshape`.
In general, `jax.numpy` functions should be used with scalars or array arguments.
jaxlib 0.1.60 (Febuary 3 2021)

View File

@ -591,10 +591,10 @@ def conv_general_dilated(
rhs_dilation = (1,) * (rhs.ndim - 2)
if isinstance(padding, str):
lhs_perm, rhs_perm, _ = dnums
rhs_shape = np.take(rhs.shape, rhs_perm)[2:]
rhs_shape = np.take(rhs.shape, rhs_perm)[2:] # type: ignore[index]
effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
padding = padtype_to_pads(
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, # type: ignore[index]
window_strides, padding)
return conv_general_dilated_p.bind(
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
@ -1515,7 +1515,7 @@ def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
shape = tuple(map(int, shape))
axes = tuple(map(int, axes))
dtype = dtypes.canonicalize_dtype(dtype)
base_shape = tuple(np.take(shape, axes))
base_shape = tuple(np.take(shape, axes)) # type: ignore[arg-type]
if config.omnistaging_enabled:
iotas = [broadcasted_iota(np.uint32, base_shape, i)
for i in range(len(base_shape))]
@ -1716,7 +1716,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
raise ValueError('No 4+ dimensional dimension_number defaults.')
dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
k_shape = np.take(rhs.shape, dn.rhs_spec)
k_sdims = k_shape[2:]
k_sdims = k_shape[2:] # type: ignore[index]
# Calculate correct output shape given padding and strides.
pads: Union[str, Sequence[Tuple[int, int]]]
if padding in {'SAME', 'VALID'}:
@ -2734,7 +2734,7 @@ def _conv_general_dilated_shape_rule(
rhs_trans = _dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
batch_group_count)
return tuple(np.take(out_trans, np.argsort(out_perm)))
return tuple(np.take(out_trans, np.argsort(out_perm))) # type: ignore[arg-type]
def _conv_general_dilated_dtype_rule(
lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
@ -3132,7 +3132,7 @@ def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
else:
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
dims = ((ans_y, y_kept), (ans_batch, y_batch))
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) # type: ignore[arg-type]
out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
return transpose(dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type),
tuple(out_axes))
@ -3733,7 +3733,7 @@ transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
'transpose')
transpose_p.def_impl(_transpose_impl)
ad.deflinear2(transpose_p,
lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) # type: ignore[arg-type]
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule

View File

@ -1993,7 +1993,7 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
if axis is None:
normalizer = size(a)
else:
normalizer = np.prod(np.take(shape(a), axis))
normalizer = np.prod(np.take(shape(a), axis)) # type: ignore
if dtype is None:
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
dtype = float_
@ -2074,7 +2074,7 @@ def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
if axis is None:
normalizer = size(a)
else:
normalizer = np.prod(np.take(shape(a), axis))
normalizer = np.prod(np.take(shape(a), axis)) # type: ignore
normalizer = normalizer - ddof
result = sum(centered, axis, keepdims=keepdims)

View File

@ -56,7 +56,7 @@ def PRNGKey(seed: int) -> jnp.ndarray:
# Explicitly cast to int64 for JIT invariance of behavior on large ints.
if isinstance(seed, int):
seed = np.int64(seed)
seed = np.int64(seed) # type: ignore[assignment]
# Converting to jnp.array may truncate bits when jax_enable_x64=False, but this
# is necessary for the sake of JIT invariance of the result for such values.
seed = jnp.asarray(seed)
@ -634,7 +634,7 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray:
_check_shape("normal", shape)
lo = np.nextafter(np.array(-1., dtype), 0., dtype=dtype)
hi = np.array(1., dtype)
u = uniform(key, shape, dtype, lo, hi)
u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type]
return np.array(np.sqrt(2), dtype) * lax.erf_inv(u)
@ -764,7 +764,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
def bernoulli(key: jnp.ndarray,
p: jnp.ndarray = np.float32(0.5),
p: jnp.ndarray = np.float32(0.5), # type: ignore[assignment]
shape: Optional[Sequence[int]] = None) -> jnp.ndarray:
"""Sample Bernoulli random values with given shape and mean.

View File

@ -76,4 +76,4 @@ for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) # type: ignore

View File

@ -46,8 +46,8 @@ _bfloat16_dtype = np.dtype(bfloat16)
# Default types.
bool_ = np.bool_
int_ = np.int64
float_ = np.float64
int_: np.dtype = np.int64 # type: ignore
float_: np.dtype = np.float64 # type: ignore
complex_ = np.complex128
# TODO(phawkins): change the above defaults to:
@ -218,7 +218,7 @@ _jax_types = [
np.dtype('float64'),
np.dtype('complex64'),
np.dtype('complex128'),
] + _weak_types
] + _weak_types # type: ignore[operator]
def _jax_type(value):
"""Return the jax type for a value or type."""

View File

@ -15,7 +15,7 @@
import itertools
import numpy as np
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Optional, Sequence, Union
from jax import dtypes
from jax import lax
@ -35,8 +35,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
self,
description: str,
*,
devices: Sequence[str] = ("cpu", "gpu", "tpu"),
dtypes: Sequence[DType] = (),
devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
dtypes: Union[DType, Sequence[DType]] = (),
enabled: bool = True,
# jax2tf specific
modes=("eager", "graph", "compiled"),

View File

@ -293,8 +293,8 @@ class Limitation:
description: str,
*,
enabled: bool = True,
devices: Sequence[str] = ("cpu", "gpu", "tpu"),
dtypes: Sequence[DType] = (),
devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
dtypes: Union[DType, Sequence[DType]] = (),
skip_run: bool = False,
):
"""Args:

View File

@ -326,7 +326,7 @@ def spec_to_indices(shape: Tuple[int, ...],
int, a slice object with step=1, or a tuple thereof, to be treated as an
index into the full logical array.
"""
return tuple(spec.indices(shape).flat)
return tuple(spec.indices(shape).flat) # type: ignore
### util

View File

@ -1,4 +1,5 @@
[mypy]
show_error_codes=True
[mypy-absl.*]
ignore_missing_imports = True

View File

@ -29,7 +29,7 @@ setup(
package_data={'jax': ['py.typed']},
python_requires='>=3.6',
install_requires=[
'numpy >=1.12,<1.20',
'numpy >=1.12',
'absl-py',
'opt_einsum',
],

View File

@ -870,7 +870,7 @@ class XMapErrorTest(jtu.JaxTestCase):
def testNestedDifferentResources(self):
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
def f(x):
with mesh(np.empty((), dtype=np.object), ()):
with mesh(np.empty((), dtype=np.object_), ()):
@partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
def h(x):
return x