mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix test failures and update changelog.
Use dtypes.issubdtype to test for subtyping otherwise we mishandle bfloat16 dtypes. Don't pass an empty list to concatenate() when converting a shape to a value. Forbid empty lists as arguments to lax.concatenate().
This commit is contained in:
parent
b7e3129d19
commit
1bcedd58cb
@ -13,6 +13,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
commits](https://github.com/google/jax/compare/jax-v0.2.25...main).
|
commits](https://github.com/google/jax/compare/jax-v0.2.25...main).
|
||||||
|
|
||||||
## jaxlib 0.1.74 (Unreleased)
|
## jaxlib 0.1.74 (Unreleased)
|
||||||
|
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via
|
||||||
|
the host, which is usually slower.
|
||||||
|
* Added experimental MLIR Python bindings for use by JAX.
|
||||||
|
|
||||||
## jax 0.2.25 (Nov 10, 2021)
|
## jax 0.2.25 (Nov 10, 2021)
|
||||||
* [GitHub
|
* [GitHub
|
||||||
|
@ -525,6 +525,8 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
|
|||||||
Returns:
|
Returns:
|
||||||
An array containing the concatenation.
|
An array containing the concatenation.
|
||||||
"""
|
"""
|
||||||
|
if len(operands) == 0:
|
||||||
|
raise ValueError("concatenate requires a non-empty sequences of arrays")
|
||||||
return concatenate_p.bind(*operands, dimension=dimension)
|
return concatenate_p.bind(*operands, dimension=dimension)
|
||||||
|
|
||||||
|
|
||||||
@ -4775,6 +4777,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
|
|||||||
|
|
||||||
def _shape_as_value(shape):
|
def _shape_as_value(shape):
|
||||||
"""Converts a shape that may contain Poly values into a JAX value."""
|
"""Converts a shape that may contain Poly values into a JAX value."""
|
||||||
|
if len(shape) == 0:
|
||||||
|
return full((0,), np.array(0, np.int64))
|
||||||
dims = [
|
dims = [
|
||||||
expand_dims(convert_element_type(core.dimension_as_value(d), np.int64),
|
expand_dims(convert_element_type(core.dimension_as_value(d), np.int64),
|
||||||
(0,))
|
(0,))
|
||||||
@ -6605,10 +6609,10 @@ def _sort_lt_comparator(*operands, num_keys=1):
|
|||||||
x_keys, y_keys = [], []
|
x_keys, y_keys = [], []
|
||||||
for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]):
|
for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]):
|
||||||
assert x.dtype == y.dtype, (x.dtype, y.dtype)
|
assert x.dtype == y.dtype, (x.dtype, y.dtype)
|
||||||
if np.issubdtype(x.dtype, np.complexfloating):
|
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||||
x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))])
|
x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))])
|
||||||
y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))])
|
y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))])
|
||||||
elif np.issubdtype(x.dtype, np.floating):
|
elif dtypes.issubdtype(x.dtype, np.floating):
|
||||||
x_keys.append(_float_to_int_for_sort(x))
|
x_keys.append(_float_to_int_for_sort(x))
|
||||||
y_keys.append(_float_to_int_for_sort(y))
|
y_keys.append(_float_to_int_for_sort(y))
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user