From 52fa165d75746b50708b973b19ada7e54f4f361f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 24 Jul 2024 11:00:55 -0400 Subject: [PATCH] Simplify promote_shapes. We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to. Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct. --- jax/_src/lax/lax.py | 9 +++++++-- jax/_src/numpy/util.py | 3 +-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f9fc12ca0..2dbc58c24 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -838,6 +838,8 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array: See Also: jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape. """ + if len(sizes) == 0: + return asarray(operand) dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand))) return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims) @@ -872,9 +874,12 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, operand, *dyn_shape, shape=tuple(static_shape), broadcast_dimensions=tuple(broadcast_dimensions)) -def broadcast_to_rank(x: Array, rank: int) -> Array: +def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.""" - return broadcast(x, (1,) * (rank - x.ndim)) + ndim = np.ndim(x) + if ndim == rank: + return asarray(x) + return broadcast(x, (1,) * (rank - ndim)) def reshape(operand: ArrayLike, new_sizes: Shape, dimensions: Sequence[int] | None = None) -> Array: diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 71f8be9f4..09ff99cb4 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -248,8 +248,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: if config.numpy_rank_promotion.value != "allow": _rank_promotion_warning_or_error(fun_name, shapes) result_rank = len(lax.broadcast_shapes(*shapes)) - return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp) - for arg, shp in zip(args, shapes)] + return [lax.broadcast_to_rank(arg, result_rank) for arg in args] def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):