jnp.ufunc: add __hash__ method and jit methods by default

This allows the JIT cache to work properly with ufunc methods, because bound
methods are created with a new ID each time they are accessed.
Jake VanderPlas 2023-08-14 09:51:50 -07:00
parent 619377ebc1
commit b3a02e1b62
2 changed files with 78 additions and 9 deletions
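For context, a minimal sketch of the failure mode this commit addresses (not part of the diff; the `Scale` class is hypothetical): `jax.jit` keys its compilation cache on the hash and equality of static arguments, and Python's default `__hash__` is identity-based, so a freshly constructed (or freshly bound) object never matches a previous cache entry. Content-based `__hash__`/`__eq__` lets equivalent objects share one compilation:

```python
from functools import partial

import jax
import jax.numpy as jnp


class Scale:
  """Toy class whose jitted method marks `self` as static."""
  def __init__(self, factor):
    self.factor = factor

  @partial(jax.jit, static_argnums=0)
  def apply(self, x):
    return self.factor * x

  # Content-based hash/eq: equivalent instances map to one cache entry.
  def __hash__(self):
    return hash(self.factor)

  def __eq__(self, other):
    return isinstance(other, Scale) and self.factor == other.factor


Scale(2.0).apply(jnp.arange(3))  # traces and compiles
Scale(2.0).apply(jnp.arange(3))  # cache hit: new object, but equal hash/eq
```

This is the same pattern the commit applies to `jnp.ufunc`: hash over the immutable static properties, and jit the methods with `self` marked static.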


@@ -38,13 +38,37 @@ class ufunc:
   This is a class for LAX-backed implementations of numpy ufuncs.
   """
   def __init__(self, func, /, nin, nout, *, name=None, nargs=None, identity=None):
     # TODO(jakevdp): validate the signature of func via eval_shape.
+    # We want ufunc instances to work properly when marked as static,
+    # and for this reason it's important that their properties not be
+    # mutated. We prevent this by storing them in a dunder attribute,
+    # and accessing them via read-only properties.
     self.__name__ = name or func.__name__
-    self._call = vectorize(func)
-    self.nin = operator.index(nin)
-    self.nout = operator.index(nout)
-    self.nargs = nargs or self.nin
-    self.identity = identity
+    self.__static_props = {
+      'func': func,
+      'call': vectorize(func),
+      'nin': operator.index(nin),
+      'nout': operator.index(nout),
+      'nargs': operator.index(nargs or nin),
+      'identity': identity
+    }
+
+  _func = property(lambda self: self.__static_props['func'])
+  _call = property(lambda self: self.__static_props['call'])
+  nin = property(lambda self: self.__static_props['nin'])
+  nout = property(lambda self: self.__static_props['nout'])
+  nargs = property(lambda self: self.__static_props['nargs'])
+  identity = property(lambda self: self.__static_props['identity'])
+
+  def __hash__(self):
+    # Do not include _call, because it is computed from _func.
+    return hash((self._func, self.__name__, self.identity,
+                 self.nin, self.nout, self.nargs))
+
+  def __eq__(self, other):
+    # Do not include _call, because it is computed from _func.
+    return isinstance(other, ufunc) and (
+      (self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) ==
+      (other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs))
 
   def __repr__(self):
     return f"<jnp.ufunc '{self.__name__}'>"
@@ -57,6 +81,7 @@ class ufunc:
     return self._call(*args, **kwargs)
 
   @_wraps(np.ufunc.reduce, module="numpy.ufunc")
+  @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
   def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None, where=None):
     if self.nin != 2:
       raise ValueError("reduce only supported for binary ufuncs")
@@ -120,6 +145,7 @@ class ufunc:
     return result
 
   @_wraps(np.ufunc.accumulate, module="numpy.ufunc")
+  @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
   def accumulate(self, a, axis=0, dtype=None, out=None):
     if self.nin != 2:
       raise ValueError("accumulate only supported for binary ufuncs")
@@ -150,6 +176,7 @@ class ufunc:
     return _moveaxis(result, 0, axis)
 
   @_wraps(np.ufunc.at, module="numpy.ufunc")
+  @partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
   def at(self, a, indices, b=None, /, *, inplace=True):
     if inplace:
       raise NotImplementedError(_AT_INPLACE_WARNING)
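Because JAX arrays are immutable, `at` rejects NumPy's default `inplace=True` and instead returns an updated copy when `inplace=False`. A sketch, assuming the NumPy semantics that repeated indices apply the operation repeatedly:

```python
import jax.numpy as jnp

add = jnp.frompyfunc(lambda x, y: x + y, nin=2, nout=1, identity=0)
x = jnp.zeros(4)
print(add.at(x, jnp.array([0, 1, 1]), 10.0, inplace=False))
# [10. 20.  0.  0.] -- index 1 is updated twice; x itself is unchanged
```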
@@ -184,6 +211,7 @@ class ufunc:
     return carry[1]
 
   @_wraps(np.ufunc.reduceat, module="numpy.ufunc")
+  @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
   def reduceat(self, a, indices, axis=0, dtype=None, out=None):
     if self.nin != 2:
       raise ValueError("reduceat only supported for binary ufuncs")
@@ -220,6 +248,7 @@ class ufunc:
     return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
 
   @_wraps(np.ufunc.outer, module="numpy.ufunc")
+  @partial(jax.jit, static_argnums=[0])
   def outer(self, A, B, /, **kwargs):
     if self.nin != 2:
       raise ValueError("outer only supported for binary ufuncs")


@@ -33,14 +33,26 @@ def scalar_add(x, y):
   return x + y
 
 
+def scalar_div(x, y):
+  assert np.shape(x) == np.shape(y) == ()
+  return x / y
+
+
 def scalar_mul(x, y):
   assert np.shape(x) == np.shape(y) == ()
   return x * y
 
 
+def scalar_sub(x, y):
+  assert np.shape(x) == np.shape(y) == ()
+  return x - y
+
+
 SCALAR_FUNCS = [
   {'func': scalar_add, 'nin': 2, 'nout': 1, 'identity': 0},
+  {'func': scalar_div, 'nin': 2, 'nout': 1, 'identity': None},
   {'func': scalar_mul, 'nin': 2, 'nout': 1, 'identity': 1},
+  {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
 ]
 
 broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
@@ -54,6 +66,34 @@ def cast_outputs(fun):
 
 
 class LaxNumpyUfuncTests(jtu.JaxTestCase):
+
+  @jtu.sample_product(SCALAR_FUNCS)
+  def test_ufunc_properties(self, func, nin, nout, identity):
+    jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
+    self.assertEqual(jnp_fun.identity, identity)
+    self.assertEqual(jnp_fun.nin, nin)
+    self.assertEqual(jnp_fun.nout, nout)
+    self.assertEqual(jnp_fun.nargs, nin)
+
+  @jtu.sample_product(SCALAR_FUNCS)
+  def test_ufunc_properties_readonly(self, func, nin, nout, identity):
+    jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
+    for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']:
+      getattr(jnp_fun, attr)  # no error on attribute access.
+      with self.assertRaises(AttributeError):
+        setattr(jnp_fun, attr, None)  # error when trying to mutate.
+
+  @jtu.sample_product(SCALAR_FUNCS)
+  def test_ufunc_hash(self, func, nin, nout, identity):
+    jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
+    jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity)
+    self.assertEqual(jnp_fun, jnp_fun_2)
+    self.assertEqual(hash(jnp_fun), hash(jnp_fun_2))
+
+    other_fun = jnp.frompyfunc(jnp.add, nin=2, nout=1, identity=0)
+    self.assertNotEqual(jnp_fun, other_fun)
+    # Note: don't test hash for non-equality, because it may collide.
 
   @jtu.sample_product(
     SCALAR_FUNCS,
     lhs_shape=broadcast_compatible_shapes,
@@ -124,7 +164,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
     args_maker = lambda: [rng(shape, dtype)]
 
     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
-    self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False)  # TODO(jakevdp): why the cache misses?
+    self._CompileAndCheck(jnp_fun, args_maker)
 
   @jtu.sample_product(
     SCALAR_FUNCS,
@@ -146,7 +186,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
     args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)]
 
     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
-    self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False)  # TODO(jakevdp): why the cache misses?
+    self._CompileAndCheck(jnp_fun, args_maker)
 
   @jtu.sample_product(
     SCALAR_FUNCS,
@@ -167,7 +207,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
     args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]
 
     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
-    self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False)  # TODO(jakevdp): why the cache misses?
+    self._CompileAndCheck(jnp_fun, args_maker)
 
 
 if __name__ == "__main__":