mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[typing] use protocol for cumulative reductions
This commit is contained in:
parent
69c9660aab
commit
6d006b5994
@ -16,7 +16,8 @@ import builtins
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from typing import overload, Any, Callable, Literal, Optional, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
overload, Any, Callable, Literal, Optional, Protocol, Sequence, Tuple, Union)
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -634,9 +635,13 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None =
|
||||
return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
|
||||
|
||||
|
||||
# TODO(jakevdp): use a protocol here for better typing?
|
||||
class CumulativeReduction(Protocol):
|
||||
def __call__(self, a: ArrayLike, axis: Axis = None,
|
||||
dtype: DTypeLike = None, out: None = None) -> Array: ...
|
||||
|
||||
|
||||
def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
|
||||
fill_nan: bool = False, fill_value: ArrayLike = 0) -> Callable[..., Array]:
|
||||
fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction:
|
||||
@_wraps(np_reduction, skip_params=['out'])
|
||||
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
|
||||
dtype: DTypeLike = None, out: None = None) -> Array:
|
||||
|
Loading…
x
Reference in New Issue
Block a user