Merge pull request #5573 from gnecula:numpy_1_20
PiperOrigin-RevId: 355831312
commit 55907e2d11
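In short: this change prepares JAX for NumPy 1.20. It removes the numpy<1.20
pin from setup.py, adds narrowly scoped `# type: ignore[...]` comments where
NumPy 1.20's bundled type stubs reject existing code, disables the
experimental numpy-dispatch CI job until it supports 1.20, and documents the
jax.numpy functions that no longer accept lists or tuples in place of arrays.

A minimal sketch (illustrative, not part of this commit) of the stub friction
behind most of the `type: ignore` comments below: passing plain tuples to
np.take is fine at runtime, but mypy with NumPy >= 1.20 stubs can flag it as
an arg-type error.

    import numpy as np

    shape = (8, 3, 32, 32)  # hypothetical static shape
    axes = (2, 3)           # hypothetical axes to select
    # Runtime-correct; mypy with NumPy 1.20 stubs may reject the tuple
    # arguments, hence the narrowed suppression used throughout this commit.
    spatial = tuple(np.take(shape, axes))  # type: ignore[arg-type]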
.github/workflows/ci-build.yaml (vendored)
@@ -58,6 +58,15 @@ jobs:
             enable-omnistaging: 1
             package-overrides: "none"
             num_generated_cases: 25
+          # TODO: re-enable this for numpy 1.20
+          # - name-prefix: "with numpy-dispatch"
+          #   python-version: 3.7
+          #   os: ubuntu-latest
+          #   enable-x64: 1
+          #   enable-omnistaging: 1
+          #   # Test experimental NumPy dispatch
+          #   package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
+          #   num_generated_cases: 10
           - name-prefix: "with internal numpy"
             python-version: 3.6
             os: ubuntu-latest
@@ -10,6 +10,10 @@ Change Log
 jaxlib 0.1.61 (Unreleased)
 --------------------------
 
+.. AFTER RELEASING THE NEXT JAXLIB, please remove the numpy<1.20 requirements in
+.. test-requirements.txt. It is there only to match the previously released
+.. jaxlib 0.1.60
+
 * Bug fixes:
 
@@ -31,7 +35,7 @@ jax 0.2.10 (Unreleased)
   made to match the semantics of NumPy 1.20.0.
 * Several `jax.numpy` functions no longer accept tuples or lists in place
   of array arguments: :func:`jax.numpy.pad`, :func:`jax.numpy.ravel`,
-  :func:`jax.numpy.repeat`.
+  :func:`jax.numpy.repeat`, :func:`jax.numpy.reshape`.
   In general, `jax.numpy` functions should be used with scalars or array arguments.
 
 jaxlib 0.1.60 (February 3 2021)
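Illustration (not from the diff) of the behavior change listed above: list or
tuple arguments to the named functions must now be converted explicitly.

    import jax.numpy as jnp

    x = jnp.asarray([1, 2, 3, 4])  # convert the list first
    y = jnp.reshape(x, (2, 2))     # OK: array argument
    # jnp.reshape([1, 2, 3, 4], (2, 2))  # now raises a TypeError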
@@ -591,10 +591,10 @@ def conv_general_dilated(
     rhs_dilation = (1,) * (rhs.ndim - 2)
   if isinstance(padding, str):
     lhs_perm, rhs_perm, _ = dnums
-    rhs_shape = np.take(rhs.shape, rhs_perm)[2:]
+    rhs_shape = np.take(rhs.shape, rhs_perm)[2:]  # type: ignore[index]
     effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
     padding = padtype_to_pads(
-        np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,
+        np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,  # type: ignore[index]
         window_strides, padding)
   return conv_general_dilated_p.bind(
       lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
@@ -1515,7 +1515,7 @@ def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
   shape = tuple(map(int, shape))
   axes = tuple(map(int, axes))
   dtype = dtypes.canonicalize_dtype(dtype)
-  base_shape = tuple(np.take(shape, axes))
+  base_shape = tuple(np.take(shape, axes))  # type: ignore[arg-type]
   if config.omnistaging_enabled:
     iotas = [broadcasted_iota(np.uint32, base_shape, i)
              for i in range(len(base_shape))]
@@ -1716,7 +1716,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
     raise ValueError('No 4+ dimensional dimension_number defaults.')
   dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
   k_shape = np.take(rhs.shape, dn.rhs_spec)
-  k_sdims = k_shape[2:]
+  k_sdims = k_shape[2:]  # type: ignore[index]
   # Calculate correct output shape given padding and strides.
   pads: Union[str, Sequence[Tuple[int, int]]]
   if padding in {'SAME', 'VALID'}:
@@ -2734,7 +2734,7 @@ def _conv_general_dilated_shape_rule(
   rhs_trans = _dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
   out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
                                batch_group_count)
-  return tuple(np.take(out_trans, np.argsort(out_perm)))
+  return tuple(np.take(out_trans, np.argsort(out_perm)))  # type: ignore[arg-type]
 
 def _conv_general_dilated_dtype_rule(
     lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
@@ -3132,7 +3132,7 @@ def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
   else:
     ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
     dims = ((ans_y, y_kept), (ans_batch, y_batch))
-  x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
+  x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))  # type: ignore[arg-type]
   out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
   return transpose(dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type),
                    tuple(out_axes))
@@ -3733,7 +3733,7 @@ transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
                                  'transpose')
 transpose_p.def_impl(_transpose_impl)
 ad.deflinear2(transpose_p,
-              lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
+              lambda t, _, permutation: [transpose(t, np.argsort(permutation))])  # type: ignore[arg-type]
 batching.primitive_batchers[transpose_p] = _transpose_batch_rule
 masking.masking_rules[transpose_p] = _transpose_masking_rule
 
@@ -1993,7 +1993,7 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
   if axis is None:
     normalizer = size(a)
   else:
-    normalizer = np.prod(np.take(shape(a), axis))
+    normalizer = np.prod(np.take(shape(a), axis))  # type: ignore
   if dtype is None:
     if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
       dtype = float_
@@ -2074,7 +2074,7 @@ def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
   if axis is None:
     normalizer = size(a)
   else:
-    normalizer = np.prod(np.take(shape(a), axis))
+    normalizer = np.prod(np.take(shape(a), axis))  # type: ignore
   normalizer = normalizer - ddof
 
   result = sum(centered, axis, keepdims=keepdims)
@@ -57,7 +57,7 @@ def PRNGKey(seed: int) -> jnp.ndarray:
 
   # Explicitly cast to int64 for JIT invariance of behavior on large ints.
   if isinstance(seed, int):
-    seed = np.int64(seed)
+    seed = np.int64(seed)  # type: ignore[assignment]
   # Converting to jnp.array may truncate bits when jax_enable_x64=False, but this
   # is necessary for the sake of JIT invariance of the result for such values.
   seed = jnp.asarray(seed)
@@ -643,7 +643,7 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray:
   _check_shape("normal", shape)
   lo = np.nextafter(np.array(-1., dtype), 0., dtype=dtype)
   hi = np.array(1., dtype)
-  u = uniform(key, shape, dtype, lo, hi)
+  u = uniform(key, shape, dtype, lo, hi)  # type: ignore[arg-type]
   return np.array(np.sqrt(2), dtype) * lax.erf_inv(u)
 
 
@@ -773,7 +773,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
 
 
 def bernoulli(key: jnp.ndarray,
-              p: jnp.ndarray = np.float32(0.5),
+              p: jnp.ndarray = np.float32(0.5),  # type: ignore[assignment]
               shape: Optional[Union[Sequence[int], NamedShape]] = None) -> jnp.ndarray:
   """Sample Bernoulli random values with given shape and mean.
 
@@ -76,4 +76,4 @@ for t in dtypes.python_scalar_dtypes:
   core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
   ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
 
-core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
+core.literalable_types.update(dtypes.python_scalar_dtypes.keys())  # type: ignore
@@ -46,8 +46,8 @@ _bfloat16_dtype = np.dtype(bfloat16)
 # Default types.
 
 bool_ = np.bool_
-int_ = np.int64
-float_ = np.float64
+int_: np.dtype = np.int64  # type: ignore
+float_: np.dtype = np.float64  # type: ignore
 complex_ = np.complex128
 
 # TODO(phawkins): change the above defaults to:
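Aside (our reading, not stated in the diff): np.int64 is a scalar type rather
than a np.dtype instance, so annotating the module defaults as np.dtype needs
the ignore even though NumPy converts freely between the two forms.

    import numpy as np

    int_: np.dtype = np.int64  # type: ignore
    # NumPy accepts either spelling wherever a dtype is expected:
    assert np.dtype(int_) == np.dtype('int64')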
@@ -218,7 +218,7 @@ _jax_types = [
   np.dtype('float64'),
   np.dtype('complex64'),
   np.dtype('complex128'),
-] + _weak_types
+] + _weak_types  # type: ignore[operator]
 
 def _jax_type(value):
   """Return the jax type for a value or type."""
@@ -15,7 +15,7 @@
 
 import itertools
 import numpy as np
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence, Union
 
 from jax import dtypes
 from jax import lax
@@ -35,8 +35,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       self,
       description: str,
       *,
-      devices: Sequence[str] = ("cpu", "gpu", "tpu"),
-      dtypes: Sequence[DType] = (),
+      devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
+      dtypes: Union[DType, Sequence[DType]] = (),
       enabled: bool = True,
       # jax2tf specific
       modes=("eager", "graph", "compiled"),
@@ -293,8 +293,8 @@ class Limitation:
       description: str,
       *,
       enabled: bool = True,
-      devices: Sequence[str] = ("cpu", "gpu", "tpu"),
-      dtypes: Sequence[DType] = (),
+      devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
+      dtypes: Union[DType, Sequence[DType]] = (),
       skip_run: bool = False,
   ):
     """Args:
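The two constructor changes above widen `devices` and `dtypes` to accept
either a single item or a sequence. A minimal, hypothetical sketch (helper
name is illustrative, not from this commit) of the normalization such a
signature typically pairs with:

    from typing import Sequence, Tuple, Union

    def _ensure_str_tuple(x: Union[str, Sequence[str]]) -> Tuple[str, ...]:
      """Wrap a lone string; pass other sequences through as tuples."""
      return (x,) if isinstance(x, str) else tuple(x)

    assert _ensure_str_tuple("cpu") == ("cpu",)
    assert _ensure_str_tuple(("cpu", "gpu")) == ("cpu", "gpu")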
@@ -326,7 +326,7 @@ def spec_to_indices(shape: Tuple[int, ...],
     int, a slice object with step=1, or a tuple thereof, to be treated as an
     index into the full logical array.
   """
-  return tuple(spec.indices(shape).flat)
+  return tuple(spec.indices(shape).flat)  # type: ignore
 
 
 ### util
mypy.ini
@@ -1,4 +1,5 @@
 [mypy]
+show_error_codes=True
 
 [mypy-absl.*]
 ignore_missing_imports = True
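Context (an inference about intent, not stated in the diff): with
show_error_codes=True, mypy appends the error code in brackets, e.g.
"error: Incompatible types in assignment (...)  [assignment]", which is what
lets the suppressions in this commit target a single code instead of
silencing every error on the line.

    import numpy as np

    seed: int = 42
    # Runs fine, but mypy flags this since np.int64 is not int; the
    # bracketed code enables the narrow suppression, as in PRNGKey above:
    seed = np.int64(seed)  # type: ignore[assignment]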
setup.py
@@ -29,7 +29,7 @@ setup(
     package_data={'jax': ['py.typed']},
     python_requires='>=3.6',
     install_requires=[
-        'numpy >=1.12,<1.20',
+        'numpy >=1.12',
         'absl-py',
         'opt_einsum',
     ],
@@ -973,7 +973,7 @@ class XMapErrorTest(jtu.JaxTestCase):
   def testNestedDifferentResources(self):
     @partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
     def f(x):
-      with mesh(np.empty((), dtype=np.object), ()):
+      with mesh(np.empty((), dtype=np.object_), ()):
         @partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
         def h(x):
           return x
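Background (a NumPy fact, not in the diff): NumPy 1.20 deprecates the
builtin-shadowing aliases such as np.object and np.int; the underscored dtype
names are the supported spellings, hence the one-character test fix above.

    import numpy as np

    # np.object warns under NumPy 1.20+ (and is removed in later releases);
    # np.object_ is the supported spelling:
    arr = np.empty((), dtype=np.object_)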