Fix test failures and update changelog.

Use dtypes.issubdtype to test for subtyping otherwise we mishandle bfloat16 dtypes.
Don't pass an empty list to concatenate() when converting a shape to a value.
Forbid empty lists as arguments to lax.concatenate().
This commit is contained in:
Peter Hawkins 2021-11-16 17:36:28 -05:00
parent b7e3129d19
commit 1bcedd58cb
2 changed files with 9 additions and 2 deletions

View File

@ -13,6 +13,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
commits](https://github.com/google/jax/compare/jax-v0.2.25...main).
## jaxlib 0.1.74 (Unreleased)
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via
the host, which is usually slower.
* Added experimental MLIR Python bindings for use by JAX.
## jax 0.2.25 (Nov 10, 2021)
* [GitHub

View File

@ -525,6 +525,8 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
Returns:
An array containing the concatenation.
"""
if len(operands) == 0:
raise ValueError("concatenate requires a non-empty sequences of arrays")
return concatenate_p.bind(*operands, dimension=dimension)
@ -4775,6 +4777,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
def _shape_as_value(shape):
"""Converts a shape that may contain Poly values into a JAX value."""
if len(shape) == 0:
return full((0,), np.array(0, np.int64))
dims = [
expand_dims(convert_element_type(core.dimension_as_value(d), np.int64),
(0,))
@ -6605,10 +6609,10 @@ def _sort_lt_comparator(*operands, num_keys=1):
x_keys, y_keys = [], []
for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]):
assert x.dtype == y.dtype, (x.dtype, y.dtype)
if np.issubdtype(x.dtype, np.complexfloating):
if dtypes.issubdtype(x.dtype, np.complexfloating):
x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))])
y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))])
elif np.issubdtype(x.dtype, np.floating):
elif dtypes.issubdtype(x.dtype, np.floating):
x_keys.append(_float_to_int_for_sort(x))
y_keys.append(_float_to_int_for_sort(y))
else: