Simplify promote_shapes.

We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.

Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
This commit is contained in:
Peter Hawkins 2024-07-24 11:00:55 -04:00
parent f1cfd99fe8
commit 52fa165d75
2 changed files with 8 additions and 4 deletions

View File

@@ -838,6 +838,8 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array:
See Also:
jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
"""
if len(sizes) == 0:
return asarray(operand)
dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
@@ -872,9 +874,12 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
operand, *dyn_shape, shape=tuple(static_shape),
broadcast_dimensions=tuple(broadcast_dimensions))
def broadcast_to_rank(x: Array, rank: int) -> Array:
def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
"""Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
return broadcast(x, (1,) * (rank - x.ndim))
ndim = np.ndim(x)
if ndim == rank:
return asarray(x)
return broadcast(x, (1,) * (rank - ndim))
def reshape(operand: ArrayLike, new_sizes: Shape,
dimensions: Sequence[int] | None = None) -> Array:

View File

@@ -248,8 +248,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
if config.numpy_rank_promotion.value != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
result_rank = len(lax.broadcast_shapes(*shapes))
return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
for arg, shp in zip(args, shapes)]
return [lax.broadcast_to_rank(arg, result_rank) for arg in args]
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):