mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix test failures and update changelog.
Use dtypes.issubdtype to test for subtyping otherwise we mishandle bfloat16 dtypes. Don't pass an empty list to concatenate() when converting a shape to a value. Forbid empty lists as arguments to lax.concatenate().
This commit is contained in:
parent
b7e3129d19
commit
1bcedd58cb
@ -13,6 +13,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
commits](https://github.com/google/jax/compare/jax-v0.2.25...main).
|
||||
|
||||
## jaxlib 0.1.74 (Unreleased)
|
||||
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via
|
||||
the host, which is usually slower.
|
||||
* Added experimental MLIR Python bindings for use by JAX.
|
||||
|
||||
## jax 0.2.25 (Nov 10, 2021)
|
||||
* [GitHub
|
||||
|
@ -525,6 +525,8 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
|
||||
Returns:
|
||||
An array containing the concatenation.
|
||||
"""
|
||||
if len(operands) == 0:
|
||||
raise ValueError("concatenate requires a non-empty sequences of arrays")
|
||||
return concatenate_p.bind(*operands, dimension=dimension)
|
||||
|
||||
|
||||
@ -4775,6 +4777,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
|
||||
|
||||
def _shape_as_value(shape):
|
||||
"""Converts a shape that may contain Poly values into a JAX value."""
|
||||
if len(shape) == 0:
|
||||
return full((0,), np.array(0, np.int64))
|
||||
dims = [
|
||||
expand_dims(convert_element_type(core.dimension_as_value(d), np.int64),
|
||||
(0,))
|
||||
@ -6605,10 +6609,10 @@ def _sort_lt_comparator(*operands, num_keys=1):
|
||||
x_keys, y_keys = [], []
|
||||
for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]):
|
||||
assert x.dtype == y.dtype, (x.dtype, y.dtype)
|
||||
if np.issubdtype(x.dtype, np.complexfloating):
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))])
|
||||
y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))])
|
||||
elif np.issubdtype(x.dtype, np.floating):
|
||||
elif dtypes.issubdtype(x.dtype, np.floating):
|
||||
x_keys.append(_float_to_int_for_sort(x))
|
||||
y_keys.append(_float_to_int_for_sort(y))
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user