mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fixed mypy type errors for numpy 1.20
Revert also previous changes that pinned numpy to 1.19. One of the changes in numpy 1.20 is to add more type annotations. However, this sometimes make mypy give errors. A common example is numpy.take, which with the new type annotation does not appear to mypy as indexable. Another change is that np.int and np.bool are deprecated. One should use np.bool_ or np.int_, or the built-ins bool and int.
This commit is contained in:
parent
92a0e695ed
commit
f105517ea2
9
.github/workflows/ci-build.yaml
vendored
9
.github/workflows/ci-build.yaml
vendored
@ -58,6 +58,15 @@ jobs:
|
||||
enable-omnistaging: 1
|
||||
package-overrides: "none"
|
||||
num_generated_cases: 25
|
||||
# TODO: re-enable this for numpy 1.20
|
||||
# - name-prefix: "with numpy-dispatch"
|
||||
# python-version: 3.7
|
||||
# os: ubuntu-latest
|
||||
# enable-x64: 1
|
||||
# enable-omnistaging: 1
|
||||
# # Test experimental NumPy dispatch
|
||||
# package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
|
||||
# num_generated_cases: 10
|
||||
- name-prefix: "with internal numpy"
|
||||
python-version: 3.6
|
||||
os: ubuntu-latest
|
||||
|
@ -10,6 +10,10 @@ Change Log
|
||||
jaxlib 0.1.61 (Unreleased)
|
||||
--------------------------
|
||||
|
||||
.. AFTER RELEASING THE NEXT JAXLIB, please remove the numpy<1.20 requirements in
|
||||
.. test-requirements.txt. It is there only to match the previously released
|
||||
.. jaxlib 0.1.60
|
||||
|
||||
* Bug fixes:
|
||||
|
||||
|
||||
@ -31,7 +35,7 @@ jax 0.2.10 (Unreleased)
|
||||
made to match the semantics of NumPy 1.20.0.
|
||||
* Several `jax.numpy` functions no longer accept tuples or lists in place
|
||||
of array arguments: :func:`jax.numpy.pad`, :func`jax.numpy.ravel`,
|
||||
:func:`jax.numpy.repeat`.
|
||||
:func:`jax.numpy.repeat`, :func:`jax.numpy.reshape`.
|
||||
In general, `jax.numpy` functions should be used with scalars or array arguments.
|
||||
|
||||
jaxlib 0.1.60 (Febuary 3 2021)
|
||||
|
@ -591,10 +591,10 @@ def conv_general_dilated(
|
||||
rhs_dilation = (1,) * (rhs.ndim - 2)
|
||||
if isinstance(padding, str):
|
||||
lhs_perm, rhs_perm, _ = dnums
|
||||
rhs_shape = np.take(rhs.shape, rhs_perm)[2:]
|
||||
rhs_shape = np.take(rhs.shape, rhs_perm)[2:] # type: ignore[index]
|
||||
effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
|
||||
padding = padtype_to_pads(
|
||||
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,
|
||||
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, # type: ignore[index]
|
||||
window_strides, padding)
|
||||
return conv_general_dilated_p.bind(
|
||||
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
|
||||
@ -1515,7 +1515,7 @@ def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
|
||||
shape = tuple(map(int, shape))
|
||||
axes = tuple(map(int, axes))
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
base_shape = tuple(np.take(shape, axes))
|
||||
base_shape = tuple(np.take(shape, axes)) # type: ignore[arg-type]
|
||||
if config.omnistaging_enabled:
|
||||
iotas = [broadcasted_iota(np.uint32, base_shape, i)
|
||||
for i in range(len(base_shape))]
|
||||
@ -1716,7 +1716,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
|
||||
raise ValueError('No 4+ dimensional dimension_number defaults.')
|
||||
dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
|
||||
k_shape = np.take(rhs.shape, dn.rhs_spec)
|
||||
k_sdims = k_shape[2:]
|
||||
k_sdims = k_shape[2:] # type: ignore[index]
|
||||
# Calculate correct output shape given padding and strides.
|
||||
pads: Union[str, Sequence[Tuple[int, int]]]
|
||||
if padding in {'SAME', 'VALID'}:
|
||||
@ -2734,7 +2734,7 @@ def _conv_general_dilated_shape_rule(
|
||||
rhs_trans = _dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
|
||||
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
|
||||
batch_group_count)
|
||||
return tuple(np.take(out_trans, np.argsort(out_perm)))
|
||||
return tuple(np.take(out_trans, np.argsort(out_perm))) # type: ignore[arg-type]
|
||||
|
||||
def _conv_general_dilated_dtype_rule(
|
||||
lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
@ -3132,7 +3132,7 @@ def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
|
||||
else:
|
||||
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
|
||||
dims = ((ans_y, y_kept), (ans_batch, y_batch))
|
||||
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
|
||||
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) # type: ignore[arg-type]
|
||||
out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
|
||||
return transpose(dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type),
|
||||
tuple(out_axes))
|
||||
@ -3733,7 +3733,7 @@ transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
|
||||
'transpose')
|
||||
transpose_p.def_impl(_transpose_impl)
|
||||
ad.deflinear2(transpose_p,
|
||||
lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
|
||||
lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) # type: ignore[arg-type]
|
||||
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
|
||||
masking.masking_rules[transpose_p] = _transpose_masking_rule
|
||||
|
||||
|
@ -1993,7 +1993,7 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
if axis is None:
|
||||
normalizer = size(a)
|
||||
else:
|
||||
normalizer = np.prod(np.take(shape(a), axis))
|
||||
normalizer = np.prod(np.take(shape(a), axis)) # type: ignore
|
||||
if dtype is None:
|
||||
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
|
||||
dtype = float_
|
||||
@ -2074,7 +2074,7 @@ def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
if axis is None:
|
||||
normalizer = size(a)
|
||||
else:
|
||||
normalizer = np.prod(np.take(shape(a), axis))
|
||||
normalizer = np.prod(np.take(shape(a), axis)) # type: ignore
|
||||
normalizer = normalizer - ddof
|
||||
|
||||
result = sum(centered, axis, keepdims=keepdims)
|
||||
|
@ -56,7 +56,7 @@ def PRNGKey(seed: int) -> jnp.ndarray:
|
||||
|
||||
# Explicitly cast to int64 for JIT invariance of behavior on large ints.
|
||||
if isinstance(seed, int):
|
||||
seed = np.int64(seed)
|
||||
seed = np.int64(seed) # type: ignore[assignment]
|
||||
# Converting to jnp.array may truncate bits when jax_enable_x64=False, but this
|
||||
# is necessary for the sake of JIT invariance of the result for such values.
|
||||
seed = jnp.asarray(seed)
|
||||
@ -634,7 +634,7 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray:
|
||||
_check_shape("normal", shape)
|
||||
lo = np.nextafter(np.array(-1., dtype), 0., dtype=dtype)
|
||||
hi = np.array(1., dtype)
|
||||
u = uniform(key, shape, dtype, lo, hi)
|
||||
u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type]
|
||||
return np.array(np.sqrt(2), dtype) * lax.erf_inv(u)
|
||||
|
||||
|
||||
@ -764,7 +764,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
|
||||
|
||||
|
||||
def bernoulli(key: jnp.ndarray,
|
||||
p: jnp.ndarray = np.float32(0.5),
|
||||
p: jnp.ndarray = np.float32(0.5), # type: ignore[assignment]
|
||||
shape: Optional[Sequence[int]] = None) -> jnp.ndarray:
|
||||
"""Sample Bernoulli random values with given shape and mean.
|
||||
|
||||
|
@ -76,4 +76,4 @@ for t in dtypes.python_scalar_dtypes:
|
||||
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
|
||||
ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
|
||||
|
||||
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
|
||||
core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) # type: ignore
|
||||
|
@ -46,8 +46,8 @@ _bfloat16_dtype = np.dtype(bfloat16)
|
||||
# Default types.
|
||||
|
||||
bool_ = np.bool_
|
||||
int_ = np.int64
|
||||
float_ = np.float64
|
||||
int_: np.dtype = np.int64 # type: ignore
|
||||
float_: np.dtype = np.float64 # type: ignore
|
||||
complex_ = np.complex128
|
||||
|
||||
# TODO(phawkins): change the above defaults to:
|
||||
@ -218,7 +218,7 @@ _jax_types = [
|
||||
np.dtype('float64'),
|
||||
np.dtype('complex64'),
|
||||
np.dtype('complex128'),
|
||||
] + _weak_types
|
||||
] + _weak_types # type: ignore[operator]
|
||||
|
||||
def _jax_type(value):
|
||||
"""Return the jax type for a value or type."""
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
import itertools
|
||||
import numpy as np
|
||||
from typing import Any, Callable, Optional, Sequence
|
||||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
@ -35,8 +35,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
self,
|
||||
description: str,
|
||||
*,
|
||||
devices: Sequence[str] = ("cpu", "gpu", "tpu"),
|
||||
dtypes: Sequence[DType] = (),
|
||||
devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
|
||||
dtypes: Union[DType, Sequence[DType]] = (),
|
||||
enabled: bool = True,
|
||||
# jax2tf specific
|
||||
modes=("eager", "graph", "compiled"),
|
||||
|
@ -293,8 +293,8 @@ class Limitation:
|
||||
description: str,
|
||||
*,
|
||||
enabled: bool = True,
|
||||
devices: Sequence[str] = ("cpu", "gpu", "tpu"),
|
||||
dtypes: Sequence[DType] = (),
|
||||
devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
|
||||
dtypes: Union[DType, Sequence[DType]] = (),
|
||||
skip_run: bool = False,
|
||||
):
|
||||
"""Args:
|
||||
|
@ -326,7 +326,7 @@ def spec_to_indices(shape: Tuple[int, ...],
|
||||
int, a slice object with step=1, or a tuple thereof, to be treated as an
|
||||
index into the full logical array.
|
||||
"""
|
||||
return tuple(spec.indices(shape).flat)
|
||||
return tuple(spec.indices(shape).flat) # type: ignore
|
||||
|
||||
|
||||
### util
|
||||
|
1
mypy.ini
1
mypy.ini
@ -1,4 +1,5 @@
|
||||
[mypy]
|
||||
show_error_codes=True
|
||||
|
||||
[mypy-absl.*]
|
||||
ignore_missing_imports = True
|
||||
|
2
setup.py
2
setup.py
@ -29,7 +29,7 @@ setup(
|
||||
package_data={'jax': ['py.typed']},
|
||||
python_requires='>=3.6',
|
||||
install_requires=[
|
||||
'numpy >=1.12,<1.20',
|
||||
'numpy >=1.12',
|
||||
'absl-py',
|
||||
'opt_einsum',
|
||||
],
|
||||
|
@ -870,7 +870,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
def testNestedDifferentResources(self):
|
||||
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
|
||||
def f(x):
|
||||
with mesh(np.empty((), dtype=np.object), ()):
|
||||
with mesh(np.empty((), dtype=np.object_), ()):
|
||||
@partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
|
||||
def h(x):
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user