mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
jnp.ufunc: add fast paths for add/prod reductions
This commit is contained in:
parent
f407298f90
commit
cb7c7ad942
@ -21,9 +21,11 @@ np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
|
||||
"""
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import jax
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
|
||||
from jax._src.numpy.reductions import _moveaxis
|
||||
from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where
|
||||
@ -32,6 +34,40 @@ from jax._src.util import canonicalize_axis
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> Optional[jax.core.Primitive]:
  """Return the single primitive that ``fun(*args)`` traces to, if any.

  ``fun`` is traced with ``jax.make_jaxpr``. A primitive is returned only when
  the resulting jaxpr contains exactly one equation whose input and output
  variables are exactly the jaxpr's own inputs and outputs, in the same order.
  A lone ``pjit`` equation whose in/out shardings are all unspecified is
  unwrapped and the same check is applied to its inner jaxpr; a ``pjit``
  equation with any specified sharding is returned as the ``pjit`` primitive
  itself.

  Args:
    fun: the callable to trace.
    *args: example arguments forwarded to the traced function.

  Returns:
    The matching primitive, or None when tracing raises, when the jaxpr does
    not consist of exactly one equation, or when that equation's variables do
    not line up with the function's inputs and outputs.
  """
  try:
    jaxpr = jax.make_jaxpr(fun)(*args)
  except Exception:
    # A function that cannot be traced simply has no fast path. Catching
    # Exception (rather than using a bare except) lets KeyboardInterrupt and
    # SystemExit propagate instead of being silently discarded.
    return None

  while len(jaxpr.eqns) == 1:
    eqn = jaxpr.eqns[0]
    if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
      # Arguments are dropped, duplicated or reordered, so the function is not
      # equivalent to a direct call of the primitive.
      return None
    is_plain_pjit = (
        eqn.primitive == jax._src.pjit.pjit_p and
        all(jax._src.pjit.is_unspecified(sharding) for sharding in
            (*eqn.params['in_shardings'], *eqn.params['out_shardings'])))
    if not is_plain_pjit:
      return eqn.primitive
    # Look through the jit wrapper and re-examine the wrapped computation.
    jaxpr = eqn.params['jaxpr']
  return None
|
||||
|
||||
|
||||
# Fast-path table consulted by ufunc.reduce: when the wrapped function traces
# to one of these binary primitives, the mapped reduction is called in place
# of the scan-based fallback.
_primitive_reducers = {lax_internal.add_p: reductions.sum,
                       lax_internal.mul_p: reductions.prod}


# Fast-path table consulted by ufunc.accumulate: same lookup as above, mapping
# to the cumulative counterpart of each reduction.
_primitive_accumulators = {lax_internal.add_p: reductions.cumsum,
                           lax_internal.mul_p: reductions.cumprod}
|
||||
|
||||
|
||||
class ufunc:
|
||||
"""Functions that operate element-by-element on whole arrays.
|
||||
|
||||
@ -99,7 +135,9 @@ class ufunc:
|
||||
"so to use a where mask one has to specify 'initial'.")
|
||||
if lax_internal._dtype(where) != bool:
|
||||
raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
|
||||
return self._reduce_via_scan(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
|
||||
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
|
||||
reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
|
||||
return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
|
||||
|
||||
def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None, where=None):
|
||||
assert self.nin == 2 and self.nout == 1
|
||||
@ -167,7 +205,9 @@ class ufunc:
|
||||
raise ValueError("accumulate only supported for functions returning a single value")
|
||||
if out is not None:
|
||||
raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
|
||||
return self._accumulate_via_scan(a, axis=axis, dtype=dtype)
|
||||
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
|
||||
accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
|
||||
return accumulator(a, axis=axis, dtype=dtype)
|
||||
|
||||
def _accumulate_via_scan(self, arr, axis=0, dtype=None):
|
||||
assert self.nin == 2 and self.nout == 1
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.numpy.ufunc_api import get_if_single_primitive
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -55,6 +56,19 @@ SCALAR_FUNCS = [
|
||||
{'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
|
||||
]
|
||||
|
||||
# Functions expected to hit the fast path, together with the primitive that
# the resulting reduce / accumulate call should trace to.
FASTPATH_FUNCS = [
  dict(func=jnp.add, nin=2, nout=1, identity=0,
       reducer=jax.lax.reduce_sum_p, accumulator=jax.lax.cumsum_p),
  dict(func=jnp.multiply, nin=2, nout=1, identity=1,
       reducer=jax.lax.reduce_prod_p, accumulator=jax.lax.cumprod_p),
]

# Functions built from add/multiply that must NOT hit the fast path, because
# their traced equation does not take the function inputs unchanged.
NON_FASTPATH_FUNCS = [
  # Ignores its second argument.
  dict(func=lambda a, b: jnp.add(a, a), nin=2, nout=1, identity=0),
  # Passes its arguments in swapped order.
  dict(func=lambda a, b: jnp.multiply(b, a), nin=2, nout=1, identity=1),
  # Swapped order again, hidden behind nested jit wrappers.
  dict(func=jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), nin=2, nout=1, identity=1),
]

broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
nonscalar_shapes = [(3,), (4,), (4, 3)]
|
||||
|
||||
@ -180,6 +194,44 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
  FASTPATH_FUNCS,
  [dict(shape=shape, axis=axis)
   for shape in nonscalar_shapes
   for axis in range(-len(shape), len(shape))],
  dtype=jtu.dtypes.floating,
)
def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
  """Check that reduce() of a fast-path function traces to `reducer` alone."""
  del accumulator  # unused
  if (nin, nout) != (2, 1):
    self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
  sampler = jtu.rand_default(self.rng())
  operand = sampler(shape, dtype)
  wrapped = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
  reduce_along_axis = partial(wrapped.reduce, axis=axis)
  # The whole reduction must trace to exactly the expected primitive.
  self.assertEqual(get_if_single_primitive(reduce_along_axis, operand), reducer)
|
||||
|
||||
@jtu.sample_product(
  NON_FASTPATH_FUNCS,
  [dict(shape=shape, axis=axis)
   for shape in nonscalar_shapes
   for axis in range(-len(shape), len(shape))],
  dtype=jtu.dtypes.floating,
)
def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype):
  """Check that neither reduce() nor accumulate() collapses to one primitive."""
  if (nin, nout) != (2, 1):
    self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
  sampler = jtu.rand_default(self.rng())
  args = (sampler(shape, dtype),)

  _ = func(0, 0)  # function should not error.

  # Both entry points are checked the same way; a fresh ufunc wrapper is
  # built for each of them.
  for method_name in ("reduce", "accumulate"):
    wrapped = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
    traced_fun = partial(getattr(wrapped, method_name), axis=axis)
    self.assertIsNone(get_if_single_primitive(traced_fun, *args))
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
@ -199,6 +251,22 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
  FASTPATH_FUNCS,
  [{'shape': shape, 'axis': axis}
   for shape in nonscalar_shapes
   for axis in range(-len(shape), len(shape))],
  dtype=jtu.dtypes.floating,
)
def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
  """Check that accumulate() of a fast-path function traces to `accumulator` alone.

  `reducer` is supplied by the shared FASTPATH_FUNCS parameterization but is
  not relevant to accumulate, so it is discarded.
  """
  del reducer  # unused
  if (nin, nout) != (2, 1):
    # This test exercises accumulate, so the skip reason names accumulate
    # (it previously said "reduce", copied from the reduce test).
    self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
  rng = jtu.rand_default(self.rng())
  args = (rng(shape, dtype),)
  jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
  # The whole accumulation must trace to exactly the expected primitive.
  self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator)
|
||||
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
shape=nonscalar_shapes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user