avoid unnecessary broadcasting in jax.numpy.vectorize

This commit is contained in:
Giacomo Petrillo 2023-03-04 18:00:30 +01:00
parent ad8c39ad7c
commit 95a5b4e48a
2 changed files with 36 additions and 22 deletions

View File

@ -278,36 +278,41 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
excluded_func, dim_sizes, output_core_dims, error_context)
# Rather than broadcasting all arguments to full broadcast shapes, prefer
# expanding dimensions using vmap when possible. By pushing broadcasting
# expanding dimensions using vmap. By pushing broadcasting
# into vmap, we can make use of more efficient batching rules for
# primitives where only some arguments are batched (e.g., for
# lax_linalg.triangular_solve).
# lax_linalg.triangular_solve), and avoid instantiating large broadcasted
# arrays.
vec_args = []
vmap_counts = []
squeezed_args = []
rev_filled_shapes = []
for arg, core_dims in zip(args, input_core_dims):
# Explicitly broadcast the dimensions already found on each argument,
# because these dimensions might be of size 1, which vmap doesn't
# handle.
# TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
# doing all vectorization with vmap? This *might* be a little more
# efficient but would require more careful book-keeping.
core_shape = tuple(dim_sizes[dim] for dim in core_dims)
full_shape = broadcast_shape + core_shape
vec_shape = full_shape[-arg.ndim:] if arg.ndim else ()
noncore_shape = arg.shape[:arg.ndim - len(core_dims)]
vec_arg = jnp.broadcast_to(arg, vec_shape)
vec_args.append(vec_arg)
pad_ndim = len(broadcast_shape) - len(noncore_shape)
filled_shape = pad_ndim * (1,) + noncore_shape
rev_filled_shapes.append(filled_shape[::-1])
vmap_count = len(vec_shape) - len(core_shape)
vmap_counts.append(vmap_count)
squeeze_indices = tuple(i for i, size in enumerate(noncore_shape) if size == 1)
squeezed_arg = jnp.squeeze(arg, axis=squeeze_indices)
squeezed_args.append(squeezed_arg)
vectorized_func = checked_func
while any(vmap_counts):
in_axes = tuple(0 if c > 0 else None for c in vmap_counts)
vmap_counts = [max(c - 1, 0) for c in vmap_counts]
vectorized_func = api.vmap(vectorized_func, in_axes)
return vectorized_func(*vec_args)
dims_to_expand = []
for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)):
in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)
if all(axis is None for axis in in_axes):
dims_to_expand.append(len(broadcast_shape) - 1 - negdim)
else:
vectorized_func = api.vmap(vectorized_func, in_axes)
result = vectorized_func(*squeezed_args)
if not dims_to_expand:
return result
elif isinstance(result, tuple):
return tuple(jnp.expand_dims(r, axis=dims_to_expand) for r in result)
else:
return jnp.expand_dims(result, axis=dims_to_expand)
return wrapped

View File

@ -32,6 +32,7 @@ class VectorizeTest(jtu.JaxTestCase):
for left_shape, right_shape, result_shape in [
((2, 3), (3, 4), (2, 4)),
((2, 3), (1, 3, 4), (1, 2, 4)),
((1, 2, 3), (1, 3, 4), (1, 2, 4)),
((5, 2, 3), (1, 3, 4), (5, 2, 4)),
((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
]
@ -216,6 +217,14 @@ class VectorizeTest(jtu.JaxTestCase):
ValueError, r"inconsistent size for core dimension 'n'"):
f(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
def test_expand_dims_multiple_outputs_no_signature(self):
  # With no signature given, vectorize a function that returns a 2-tuple and
  # call it on a length-1 array: the result must still be a tuple, and each
  # entry must match the input.
  x = jnp.arange(1)
  duplicate = jnp.vectorize(lambda x: (x, x))
  outputs = duplicate(x)
  self.assertAllClose(outputs[0], x)
  self.assertAllClose(outputs[1], x)
  self.assertIsInstance(outputs, tuple)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())