Merge pull request #5573 from gnecula:numpy_1_20

PiperOrigin-RevId: 355831312
This commit is contained in:
jax authors 2021-02-05 05:29:36 -08:00
commit 55907e2d11
13 changed files with 39 additions and 25 deletions

View File

@ -58,6 +58,15 @@ jobs:
enable-omnistaging: 1
package-overrides: "none"
num_generated_cases: 25
# TODO: re-enable this for numpy 1.20
# - name-prefix: "with numpy-dispatch"
# python-version: 3.7
# os: ubuntu-latest
# enable-x64: 1
# enable-omnistaging: 1
# # Test experimental NumPy dispatch
# package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
# num_generated_cases: 10
- name-prefix: "with internal numpy"
python-version: 3.6
os: ubuntu-latest

View File

@ -10,6 +10,10 @@ Change Log
jaxlib 0.1.61 (Unreleased)
--------------------------
.. AFTER RELEASING THE NEXT JAXLIB, please remove the numpy<1.20 requirements in
.. test-requirements.txt. It is there only to match the previously released
.. jaxlib 0.1.60
* Bug fixes:
@ -31,7 +35,7 @@ jax 0.2.10 (Unreleased)
made to match the semantics of NumPy 1.20.0.
* Several `jax.numpy` functions no longer accept tuples or lists in place
of array arguments: :func:`jax.numpy.pad`, :func`jax.numpy.ravel`,
:func:`jax.numpy.repeat`.
:func:`jax.numpy.repeat`, :func:`jax.numpy.reshape`.
In general, `jax.numpy` functions should be used with scalars or array arguments.
jaxlib 0.1.60 (Febuary 3 2021)

View File

@ -591,10 +591,10 @@ def conv_general_dilated(
rhs_dilation = (1,) * (rhs.ndim - 2)
if isinstance(padding, str):
lhs_perm, rhs_perm, _ = dnums
rhs_shape = np.take(rhs.shape, rhs_perm)[2:]
rhs_shape = np.take(rhs.shape, rhs_perm)[2:] # type: ignore[index]
effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
padding = padtype_to_pads(
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, # type: ignore[index]
window_strides, padding)
return conv_general_dilated_p.bind(
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
@ -1515,7 +1515,7 @@ def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
shape = tuple(map(int, shape))
axes = tuple(map(int, axes))
dtype = dtypes.canonicalize_dtype(dtype)
base_shape = tuple(np.take(shape, axes))
base_shape = tuple(np.take(shape, axes)) # type: ignore[arg-type]
if config.omnistaging_enabled:
iotas = [broadcasted_iota(np.uint32, base_shape, i)
for i in range(len(base_shape))]
@ -1716,7 +1716,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
raise ValueError('No 4+ dimensional dimension_number defaults.')
dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
k_shape = np.take(rhs.shape, dn.rhs_spec)
k_sdims = k_shape[2:]
k_sdims = k_shape[2:] # type: ignore[index]
# Calculate correct output shape given padding and strides.
pads: Union[str, Sequence[Tuple[int, int]]]
if padding in {'SAME', 'VALID'}:
@ -2734,7 +2734,7 @@ def _conv_general_dilated_shape_rule(
rhs_trans = _dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
batch_group_count)
return tuple(np.take(out_trans, np.argsort(out_perm)))
return tuple(np.take(out_trans, np.argsort(out_perm))) # type: ignore[arg-type]
def _conv_general_dilated_dtype_rule(
lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
@ -3132,7 +3132,7 @@ def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
else:
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
dims = ((ans_y, y_kept), (ans_batch, y_batch))
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) # type: ignore[arg-type]
out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
return transpose(dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type),
tuple(out_axes))
@ -3733,7 +3733,7 @@ transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
'transpose')
transpose_p.def_impl(_transpose_impl)
ad.deflinear2(transpose_p,
lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) # type: ignore[arg-type]
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule

View File

@ -1993,7 +1993,7 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
if axis is None:
normalizer = size(a)
else:
normalizer = np.prod(np.take(shape(a), axis))
normalizer = np.prod(np.take(shape(a), axis)) # type: ignore
if dtype is None:
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
dtype = float_
@ -2074,7 +2074,7 @@ def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
if axis is None:
normalizer = size(a)
else:
normalizer = np.prod(np.take(shape(a), axis))
normalizer = np.prod(np.take(shape(a), axis)) # type: ignore
normalizer = normalizer - ddof
result = sum(centered, axis, keepdims=keepdims)

View File

@ -57,7 +57,7 @@ def PRNGKey(seed: int) -> jnp.ndarray:
# Explicitly cast to int64 for JIT invariance of behavior on large ints.
if isinstance(seed, int):
seed = np.int64(seed)
seed = np.int64(seed) # type: ignore[assignment]
# Converting to jnp.array may truncate bits when jax_enable_x64=False, but this
# is necessary for the sake of JIT invariance of the result for such values.
seed = jnp.asarray(seed)
@ -643,7 +643,7 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray:
_check_shape("normal", shape)
lo = np.nextafter(np.array(-1., dtype), 0., dtype=dtype)
hi = np.array(1., dtype)
u = uniform(key, shape, dtype, lo, hi)
u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type]
return np.array(np.sqrt(2), dtype) * lax.erf_inv(u)
@ -773,7 +773,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
def bernoulli(key: jnp.ndarray,
p: jnp.ndarray = np.float32(0.5),
p: jnp.ndarray = np.float32(0.5), # type: ignore[assignment]
shape: Optional[Union[Sequence[int], NamedShape]] = None) -> jnp.ndarray:
"""Sample Bernoulli random values with given shape and mean.

View File

@ -76,4 +76,4 @@ for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) # type: ignore

View File

@ -46,8 +46,8 @@ _bfloat16_dtype = np.dtype(bfloat16)
# Default types.
bool_ = np.bool_
int_ = np.int64
float_ = np.float64
int_: np.dtype = np.int64 # type: ignore
float_: np.dtype = np.float64 # type: ignore
complex_ = np.complex128
# TODO(phawkins): change the above defaults to:
@ -218,7 +218,7 @@ _jax_types = [
np.dtype('float64'),
np.dtype('complex64'),
np.dtype('complex128'),
] + _weak_types
] + _weak_types # type: ignore[operator]
def _jax_type(value):
"""Return the jax type for a value or type."""

View File

@ -15,7 +15,7 @@
import itertools
import numpy as np
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Optional, Sequence, Union
from jax import dtypes
from jax import lax
@ -35,8 +35,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
self,
description: str,
*,
devices: Sequence[str] = ("cpu", "gpu", "tpu"),
dtypes: Sequence[DType] = (),
devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
dtypes: Union[DType, Sequence[DType]] = (),
enabled: bool = True,
# jax2tf specific
modes=("eager", "graph", "compiled"),

View File

@ -293,8 +293,8 @@ class Limitation:
description: str,
*,
enabled: bool = True,
devices: Sequence[str] = ("cpu", "gpu", "tpu"),
dtypes: Sequence[DType] = (),
devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
dtypes: Union[DType, Sequence[DType]] = (),
skip_run: bool = False,
):
"""Args:

View File

@ -326,7 +326,7 @@ def spec_to_indices(shape: Tuple[int, ...],
int, a slice object with step=1, or a tuple thereof, to be treated as an
index into the full logical array.
"""
return tuple(spec.indices(shape).flat)
return tuple(spec.indices(shape).flat) # type: ignore
### util

View File

@ -1,4 +1,5 @@
[mypy]
show_error_codes=True
[mypy-absl.*]
ignore_missing_imports = True

View File

@ -29,7 +29,7 @@ setup(
package_data={'jax': ['py.typed']},
python_requires='>=3.6',
install_requires=[
'numpy >=1.12,<1.20',
'numpy >=1.12',
'absl-py',
'opt_einsum',
],

View File

@ -973,7 +973,7 @@ class XMapErrorTest(jtu.JaxTestCase):
def testNestedDifferentResources(self):
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
def f(x):
with mesh(np.empty((), dtype=np.object), ()):
with mesh(np.empty((), dtype=np.object_), ()):
@partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
def h(x):
return x