From 7b6a88c9aa2221a47ac03de15d7b00a359fc4775 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Sep 2023 15:50:15 -0700 Subject: [PATCH] [typing] better annotations for jnp.ufunc --- jax/_src/numpy/ufunc_api.py | 72 +++++++++++++++++++++++++------------ jax/_src/util.py | 4 +-- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index a0296322c..0ef7aeba8 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -24,13 +24,14 @@ import operator from typing import Any, Callable, Optional import jax +from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.lax import lax as lax_internal from jax._src.numpy import reductions from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take from jax._src.numpy.reductions import _moveaxis from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where from jax._src.numpy.vectorize import vectorize -from jax._src.util import canonicalize_axis +from jax._src.util import canonicalize_axis, set_module import numpy as np @@ -56,28 +57,35 @@ def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> Optional[jax return None -_primitive_reducers = { +_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = { lax_internal.add_p: reductions.sum, lax_internal.mul_p: reductions.prod, } -_primitive_accumulators = { +_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = { lax_internal.add_p: reductions.cumsum, lax_internal.mul_p: reductions.cumprod, } +@set_module('jax.numpy') class ufunc: """Functions that operate element-by-element on whole arrays. This is a class for LAX-backed implementations of numpy ufuncs. """ - def __init__(self, func, /, nin, nout, *, name=None, nargs=None, identity=None): + def __init__(self, func: Callable[..., Any], /, + nin: int, nout: int, *, + name: Optional[str] = None, + nargs: Optional[int] = None, + identity: Any = None, update_doc=False): # We want ufunc instances to work properly when marked as static, # and for this reason it's important that their properties not be # mutated. We prevent this by storing them in a dunder attribute, # and accessing them via read-only properties. + if update_doc: + self.__doc__ = func.__doc__ self.__name__ = name or func.__name__ self.__static_props = { 'func': func, @@ -95,21 +103,23 @@ class ufunc: nargs = property(lambda self: self.__static_props['nargs']) identity = property(lambda self: self.__static_props['identity']) - def __hash__(self): + def __hash__(self) -> int: # Do not include _call, because it is computed from _func. return hash((self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: # Do not include _call, because it is computed from _func. return isinstance(other, ufunc) and ( (self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) == (other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs)) - def __repr__(self): + def __repr__(self) -> str: return f"" - def __call__(self, *args, out=None, where=None, **kwargs): + def __call__(self, *args: ArrayLike, + out: None = None, where: None = None, + **kwargs: Any) -> Any: if out is not None: raise NotImplementedError(f"out argument of {self}") if where is not None: @@ -118,7 +128,9 @@ class ufunc: @_wraps(np.ufunc.reduce, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) - def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None, where=None): + def reduce(self, a: ArrayLike, axis: int = 0, dtype: Optional[DTypeLike] = None, + out: None = None, keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: check_arraylike(f"{self.__name__}.reduce", a) if self.nin != 2: raise ValueError("reduce only supported for binary ufuncs") @@ -136,10 +148,15 @@ class ufunc: if lax_internal._dtype(where) != bool: raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - reducer = _primitive_reducers.get(primitive, self._reduce_via_scan) + if primitive is None: + reducer = self._reduce_via_scan + else: + reducer = _primitive_reducers.get(primitive, self._reduce_via_scan) return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None, where=None): + def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: Optional[DTypeLike] = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: assert self.nin == 2 and self.nout == 1 arr = lax_internal.asarray(arr) if initial is None: @@ -182,6 +199,7 @@ class ufunc: else: return _where(where[i], self._call(val, arr[i].astype(dtype)), val) + start_value: ArrayLike if initial is None: start_index = 1 start_value = arr[0] @@ -198,7 +216,8 @@ class ufunc: @_wraps(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) - def accumulate(self, a, axis=0, dtype=None, out=None): + def accumulate(self, a: ArrayLike, axis: int = 0, dtype: Optional[DTypeLike] = None, + out: None = None) -> Array: if self.nin != 2: raise ValueError("accumulate only supported for binary ufuncs") if self.nout != 1: @@ -206,10 +225,14 @@ class ufunc: if out is not None: raise NotImplementedError(f"out argument of {self.__name__}.accumulate()") primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan) + if primitive is None: + accumulator = self._accumulate_via_scan + else: + accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan) return accumulator(a, axis=axis, dtype=dtype) - def _accumulate_via_scan(self, arr, axis=0, dtype=None): + def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, + dtype: Optional[DTypeLike] = None) -> Array: assert self.nin == 2 and self.nout == 1 check_arraylike(f"{self.__name__}.accumulate", arr) arr = lax_internal.asarray(arr) @@ -231,7 +254,8 @@ class ufunc: @_wraps(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) - def at(self, a, indices, b=None, /, *, inplace=True): + def at(self, a: ArrayLike, indices: Any, b: Optional[ArrayLike] = None, /, *, + inplace: bool = True) -> Array: if inplace: raise NotImplementedError(_AT_INPLACE_WARNING) if b is None: @@ -239,7 +263,7 @@ class ufunc: else: return self._at_via_scan(a, indices, b) - def _at_via_scan(self, a, indices, *args): + def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: check_arraylike(f"{self.__name__}.at", a, *args) dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype a = lax_internal.asarray(a).astype(dtype) @@ -266,7 +290,8 @@ class ufunc: @_wraps(np.ufunc.reduceat, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) - def reduceat(self, a, indices, axis=0, dtype=None, out=None): + def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, + dtype: Optional[DTypeLike] = None, out: None = None) -> Array: if self.nin != 2: raise ValueError("reduceat only supported for binary ufuncs") if self.nout != 1: @@ -275,7 +300,8 @@ class ufunc: raise NotImplementedError(f"out argument of {self.__name__}.reduceat()") return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype) - def _reduceat_via_scan(self, a, indices, axis=0, dtype=None): + def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, + dtype: Optional[DTypeLike] = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) a = lax_internal.asarray(a) idx_tuple = _eliminate_deprecated_list_indexing(indices) @@ -292,7 +318,7 @@ class ufunc: axis = canonicalize_axis(axis, a.ndim) out = take(a, indices, axis=axis) ind = jax.lax.expand_dims(append(indices, a.shape[axis]), - np.delete(np.arange(out.ndim), axis)) + list(np.delete(np.arange(out.ndim), axis))) ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): @@ -303,7 +329,7 @@ class ufunc: @_wraps(np.ufunc.outer, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0]) - def outer(self, A, B, /, **kwargs): + def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array: if self.nin != 2: raise ValueError("outer only supported for binary ufuncs") if self.nout != 1: @@ -314,7 +340,8 @@ class ufunc: return result.reshape(*np.shape(A), *np.shape(B)) -def frompyfunc(func, /, nin, nout, *, identity=None): +def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, + *, identity: Any = None) -> ufunc: """Create a JAX ufunc from an arbitrary JAX-compatible scalar function. Args: @@ -326,5 +353,4 @@ def frompyfunc(func, /, nin, nout, *, identity=None): Returns: wrapped : jax.numpy.ufunc wrapper of func. """ - # TODO(jakevdp): use functools.wraps or similar to wrap the docstring? - return ufunc(func, nin, nout, identity=identity) + return ufunc(func, nin, nout, identity=identity, update_doc=True) diff --git a/jax/_src/util.py b/jax/_src/util.py index 3f5a6843e..ae5a55fa5 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -527,8 +527,8 @@ def _original_func(f): return f -def set_module(module): - def wrapper(func): +def set_module(module: str) -> Callable[[T], T]: + def wrapper(func: T) -> T: if module is not None: func.__module__ = module return func