mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
avoid unnecessary broadcasting in jax.numpy.vectorize
This commit is contained in:
parent
ad8c39ad7c
commit
95a5b4e48a
@ -278,36 +278,41 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
|
||||
excluded_func, dim_sizes, output_core_dims, error_context)
|
||||
|
||||
# Rather than broadcasting all arguments to full broadcast shapes, prefer
|
||||
# expanding dimensions using vmap when possible. By pushing broadcasting
|
||||
# expanding dimensions using vmap. By pushing broadcasting
|
||||
# into vmap, we can make use of more efficient batching rules for
|
||||
# primitives where only some arguments are batched (e.g., for
|
||||
# lax_linalg.triangular_solve).
|
||||
# lax_linalg.triangular_solve), and avoid instantiating large broadcasted
|
||||
# arrays.
|
||||
|
||||
vec_args = []
|
||||
vmap_counts = []
|
||||
squeezed_args = []
|
||||
rev_filled_shapes = []
|
||||
|
||||
for arg, core_dims in zip(args, input_core_dims):
|
||||
# Explicitly broadcast the dimensions already found on each argument,
|
||||
# because these dimensiosns might be of size 1, which vmap doesn't
|
||||
# handle.
|
||||
# TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
|
||||
# doing all vectorization with vmap? This *might* be a little more
|
||||
# efficient but would require more careful book-keeping.
|
||||
core_shape = tuple(dim_sizes[dim] for dim in core_dims)
|
||||
full_shape = broadcast_shape + core_shape
|
||||
vec_shape = full_shape[-arg.ndim:] if arg.ndim else ()
|
||||
noncore_shape = arg.shape[:arg.ndim - len(core_dims)]
|
||||
|
||||
vec_arg = jnp.broadcast_to(arg, vec_shape)
|
||||
vec_args.append(vec_arg)
|
||||
pad_ndim = len(broadcast_shape) - len(noncore_shape)
|
||||
filled_shape = pad_ndim * (1,) + noncore_shape
|
||||
rev_filled_shapes.append(filled_shape[::-1])
|
||||
|
||||
vmap_count = len(vec_shape) - len(core_shape)
|
||||
vmap_counts.append(vmap_count)
|
||||
squeeze_indices = tuple(i for i, size in enumerate(noncore_shape) if size == 1)
|
||||
squeezed_arg = jnp.squeeze(arg, axis=squeeze_indices)
|
||||
squeezed_args.append(squeezed_arg)
|
||||
|
||||
vectorized_func = checked_func
|
||||
while any(vmap_counts):
|
||||
in_axes = tuple(0 if c > 0 else None for c in vmap_counts)
|
||||
vmap_counts = [max(c - 1, 0) for c in vmap_counts]
|
||||
vectorized_func = api.vmap(vectorized_func, in_axes)
|
||||
return vectorized_func(*vec_args)
|
||||
dims_to_expand = []
|
||||
for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)):
|
||||
in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)
|
||||
if all(axis is None for axis in in_axes):
|
||||
dims_to_expand.append(len(broadcast_shape) - 1 - negdim)
|
||||
else:
|
||||
vectorized_func = api.vmap(vectorized_func, in_axes)
|
||||
result = vectorized_func(*squeezed_args)
|
||||
|
||||
if not dims_to_expand:
|
||||
return result
|
||||
elif isinstance(result, tuple):
|
||||
return tuple(jnp.expand_dims(r, axis=dims_to_expand) for r in result)
|
||||
else:
|
||||
return jnp.expand_dims(result, axis=dims_to_expand)
|
||||
|
||||
return wrapped
|
||||
|
@ -32,6 +32,7 @@ class VectorizeTest(jtu.JaxTestCase):
|
||||
for left_shape, right_shape, result_shape in [
|
||||
((2, 3), (3, 4), (2, 4)),
|
||||
((2, 3), (1, 3, 4), (1, 2, 4)),
|
||||
((1, 2, 3), (1, 3, 4), (1, 2, 4)),
|
||||
((5, 2, 3), (1, 3, 4), (5, 2, 4)),
|
||||
((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
|
||||
]
|
||||
@ -216,6 +217,14 @@ class VectorizeTest(jtu.JaxTestCase):
|
||||
ValueError, r"inconsistent size for core dimension 'n'"):
|
||||
f(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
|
||||
|
||||
def test_expand_dims_multiple_outputs_no_signature(self):
|
||||
f = jnp.vectorize(lambda x: (x, x))
|
||||
x = jnp.arange(1)
|
||||
xx = f(x)
|
||||
self.assertAllClose(xx[0], x)
|
||||
self.assertAllClose(xx[1], x)
|
||||
self.assertIsInstance(xx, tuple)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user