diff --git a/CHANGELOG.md b/CHANGELOG.md index b7a8f9d62..4fbb13c97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. commits](https://github.com/google/jax/compare/jax-v0.2.25...main). ## jaxlib 0.1.74 (Unreleased) +* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via + the host, which is usually slower. +* Added experimental MLIR Python bindings for use by JAX. ## jax 0.2.25 (Nov 10, 2021) * [GitHub diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8881eb7ef..aa8814117 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -525,6 +525,8 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array: Returns: An array containing the concatenation. """ + if len(operands) == 0: + raise ValueError("concatenate requires a non-empty sequences of arrays") return concatenate_p.bind(*operands, dimension=dimension) @@ -4775,6 +4777,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, def _shape_as_value(shape): """Converts a shape that may contain Poly values into a JAX value.""" + if len(shape) == 0: + return full((0,), np.array(0, np.int64)) dims = [ expand_dims(convert_element_type(core.dimension_as_value(d), np.int64), (0,)) @@ -6605,10 +6609,10 @@ def _sort_lt_comparator(*operands, num_keys=1): x_keys, y_keys = [], [] for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]): assert x.dtype == y.dtype, (x.dtype, y.dtype) - if np.issubdtype(x.dtype, np.complexfloating): + if dtypes.issubdtype(x.dtype, np.complexfloating): x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))]) y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))]) - elif np.issubdtype(x.dtype, np.floating): + elif dtypes.issubdtype(x.dtype, np.floating): x_keys.append(_float_to_int_for_sort(x)) y_keys.append(_float_to_int_for_sort(y)) else: