jnp.ufunc: add fast paths for add/prod reductions

This commit is contained in:
Jake VanderPlas 2023-08-28 08:30:23 -07:00
parent f407298f90
commit cb7c7ad942
2 changed files with 110 additions and 2 deletions

View File

@ -21,9 +21,11 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
"""
from functools import partial
import operator
from typing import Any, Callable, Optional
import jax
from jax._src.lax import lax as lax_internal
from jax._src.numpy import reductions
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
from jax._src.numpy.reductions import _moveaxis
from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where
@ -32,6 +34,40 @@ from jax._src.util import canonicalize_axis
import numpy as np
def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> Optional[jax.core.Primitive]:
"""
If fun(*args) lowers to a single primitive with inputs and outputs matching
function inputs and outputs, return that primitive. Otherwise return None.
"""
try:
jaxpr = jax.make_jaxpr(fun)(*args)
except:
return None
while len(jaxpr.eqns) == 1:
eqn = jaxpr.eqns[0]
if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
return None
elif (eqn.primitive == jax._src.pjit.pjit_p and
all(jax._src.pjit.is_unspecified(sharding) for sharding in
(*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
jaxpr = jaxpr.eqns[0].params['jaxpr']
else:
return jaxpr.eqns[0].primitive
return None
_primitive_reducers = {
lax_internal.add_p: reductions.sum,
lax_internal.mul_p: reductions.prod,
}
_primitive_accumulators = {
lax_internal.add_p: reductions.cumsum,
lax_internal.mul_p: reductions.cumprod,
}
class ufunc:
"""Functions that operate element-by-element on whole arrays.
@ -99,7 +135,9 @@ class ufunc:
"so to use a where mask one has to specify 'initial'.")
if lax_internal._dtype(where) != bool:
raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
return self._reduce_via_scan(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None, where=None):
assert self.nin == 2 and self.nout == 1
@ -167,7 +205,9 @@ class ufunc:
raise ValueError("accumulate only supported for functions returning a single value")
if out is not None:
raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
return self._accumulate_via_scan(a, axis=axis, dtype=dtype)
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
return accumulator(a, axis=axis, dtype=dtype)
def _accumulate_via_scan(self, arr, axis=0, dtype=None):
assert self.nin == 2 and self.nout == 1

View File

@ -22,6 +22,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.numpy.ufunc_api import get_if_single_primitive
from jax import config
config.parse_flags_with_absl()
@ -55,6 +56,19 @@ SCALAR_FUNCS = [
{'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
]
FASTPATH_FUNCS = [
{'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0,
'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p},
{'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1,
'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p},
]
NON_FASTPATH_FUNCS = [
{'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0},
{'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1},
{'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1},
]
broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
nonscalar_shapes = [(3,), (4,), (4, 3)]
@ -180,6 +194,44 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
FASTPATH_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
del accumulator # unused
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer)
@jtu.sample_product(
NON_FASTPATH_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
_ = func(0, 0) # function should not error.
reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
self.assertIsNone(get_if_single_primitive(reduce_fun, *args))
accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
self.assertIsNone(get_if_single_primitive(accum_fun, *args))
@jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
@ -199,6 +251,22 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
FASTPATH_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
del reducer # unused
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator)
@jtu.sample_product(
SCALAR_FUNCS,
shape=nonscalar_shapes,