Merge pull request #15329 from jakevdp:padfunc-protocol-2

PiperOrigin-RevId: 520793934
This commit is contained in:
jax authors 2023-03-30 18:19:43 -07:00
commit 248ffc2ca2

View File

@ -32,7 +32,7 @@ import operator
import types
from typing import (
overload, Any, Callable, Dict, FrozenSet, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union)
NamedTuple, Optional, Protocol, Sequence, Tuple, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings
@ -1349,8 +1349,11 @@ def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = None,
PadValueLike = Union[T, Sequence[T], Sequence[Sequence[T]]]
PadValue = Tuple[Tuple[T, T], ...]
# TODO(jakevdp): make this a protocol
PadStatFunc = Callable[..., Array]
class PadStatFunc(Protocol):
def __call__(self, array: ArrayLike, /, *,
axis: Optional[int] = None,
keepdims: bool = False) -> Array: ...
def _broadcast_to_pairs(nvals: PadValueLike, nd: int, name: str) -> PadValue:
@ -1597,8 +1600,7 @@ def _pad_func(array: Array, pad_width: PadValue[int], func: Callable[..., Any],
@partial(jit, static_argnums=(1, 2, 4, 5, 6))
def _pad(array: ArrayLike, pad_width: PadValueLike[int],
mode: Union[str, PadStatFunc],
def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
constant_values: ArrayLike, stat_length: PadValueLike[int],
end_values: PadValueLike[ArrayLike], reflect_type: str):
array = asarray(array)
@ -1608,7 +1610,11 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int],
return array
stat_funcs: Dict[str, PadStatFunc] = {
"maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": reductions.median}
"maximum": reductions.amax,
"minimum": reductions.amin,
"mean": reductions.mean,
"median": reductions.median
}
pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
pad_width_arr = np.array(pad_width)