[typing] use protocol for cumulative reductions

This commit is contained in:
Jake VanderPlas 2023-03-30 15:43:43 -07:00
parent 69c9660aab
commit 6d006b5994

View File

@ -16,7 +16,8 @@ import builtins
from functools import partial
import math
import operator
from typing import overload, Any, Callable, Literal, Optional, Sequence, Tuple, Union
from typing import (
overload, Any, Callable, Literal, Optional, Protocol, Sequence, Tuple, Union)
import warnings
import numpy as np
@ -634,9 +635,13 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None =
return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
# TODO(jakevdp): use a protocol here for better typing?
class CumulativeReduction(Protocol):
def __call__(self, a: ArrayLike, axis: Axis = None,
dtype: DTypeLike = None, out: None = None) -> Array: ...
def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
fill_nan: bool = False, fill_value: ArrayLike = 0) -> Callable[..., Array]:
fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction:
@_wraps(np_reduction, skip_params=['out'])
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: DTypeLike = None, out: None = None) -> Array: