mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15329 from jakevdp:padfunc-protocol-2
PiperOrigin-RevId: 520793934
This commit is contained in:
commit
248ffc2ca2
@ -32,7 +32,7 @@ import operator
|
||||
import types
|
||||
from typing import (
|
||||
overload, Any, Callable, Dict, FrozenSet, List, Literal,
|
||||
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union)
|
||||
NamedTuple, Optional, Protocol, Sequence, Tuple, TypeVar, Union)
|
||||
from textwrap import dedent as _dedent
|
||||
import warnings
|
||||
|
||||
@ -1349,8 +1349,11 @@ def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = None,
|
||||
|
||||
PadValueLike = Union[T, Sequence[T], Sequence[Sequence[T]]]
|
||||
PadValue = Tuple[Tuple[T, T], ...]
|
||||
# TODO(jakevdp): make this a protocol
|
||||
PadStatFunc = Callable[..., Array]
|
||||
|
||||
class PadStatFunc(Protocol):
|
||||
def __call__(self, array: ArrayLike, /, *,
|
||||
axis: Optional[int] = None,
|
||||
keepdims: bool = False) -> Array: ...
|
||||
|
||||
|
||||
def _broadcast_to_pairs(nvals: PadValueLike, nd: int, name: str) -> PadValue:
|
||||
@ -1597,8 +1600,7 @@ def _pad_func(array: Array, pad_width: PadValue[int], func: Callable[..., Any],
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(1, 2, 4, 5, 6))
|
||||
def _pad(array: ArrayLike, pad_width: PadValueLike[int],
|
||||
mode: Union[str, PadStatFunc],
|
||||
def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
|
||||
constant_values: ArrayLike, stat_length: PadValueLike[int],
|
||||
end_values: PadValueLike[ArrayLike], reflect_type: str):
|
||||
array = asarray(array)
|
||||
@ -1608,7 +1610,11 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int],
|
||||
return array
|
||||
|
||||
stat_funcs: Dict[str, PadStatFunc] = {
|
||||
"maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": reductions.median}
|
||||
"maximum": reductions.amax,
|
||||
"minimum": reductions.amin,
|
||||
"mean": reductions.mean,
|
||||
"median": reductions.median
|
||||
}
|
||||
|
||||
pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
|
||||
pad_width_arr = np.array(pad_width)
|
||||
|
Loading…
x
Reference in New Issue
Block a user