mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Simplify promote_shapes.
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to. Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
This commit is contained in:
parent
f1cfd99fe8
commit
52fa165d75
@ -838,6 +838,8 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array:
|
||||
See Also:
|
||||
jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
|
||||
"""
|
||||
if len(sizes) == 0:
|
||||
return asarray(operand)
|
||||
dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
|
||||
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
|
||||
|
||||
@ -872,9 +874,12 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
|
||||
operand, *dyn_shape, shape=tuple(static_shape),
|
||||
broadcast_dimensions=tuple(broadcast_dimensions))
|
||||
|
||||
def broadcast_to_rank(x: Array, rank: int) -> Array:
|
||||
def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
|
||||
"""Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
|
||||
return broadcast(x, (1,) * (rank - x.ndim))
|
||||
ndim = np.ndim(x)
|
||||
if ndim == rank:
|
||||
return asarray(x)
|
||||
return broadcast(x, (1,) * (rank - ndim))
|
||||
|
||||
def reshape(operand: ArrayLike, new_sizes: Shape,
|
||||
dimensions: Sequence[int] | None = None) -> Array:
|
||||
|
@ -248,8 +248,7 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
|
||||
if config.numpy_rank_promotion.value != "allow":
|
||||
_rank_promotion_warning_or_error(fun_name, shapes)
|
||||
result_rank = len(lax.broadcast_shapes(*shapes))
|
||||
return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
|
||||
for arg, shp in zip(args, shapes)]
|
||||
return [lax.broadcast_to_rank(arg, result_rank) for arg in args]
|
||||
|
||||
|
||||
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
|
||||
|
Loading…
x
Reference in New Issue
Block a user