mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Move _reduce_window
docstring to public func lax.reduce_window
.
This commit is contained in:
parent
640cb009f1
commit
6e9a34f791
@ -54,10 +54,6 @@ def _reduce_window(
|
||||
base_dilation: Sequence[int] | None = None,
|
||||
window_dilation: Sequence[int] | None = None,
|
||||
):
|
||||
"""Wraps XLA's `ReduceWindowWithGeneralPadding
|
||||
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
|
||||
operator.
|
||||
"""
|
||||
flat_operands, operand_tree = tree_util.tree_flatten(operand)
|
||||
flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
|
||||
if operand_tree != init_value_tree:
|
||||
@ -123,6 +119,10 @@ def reduce_window(
|
||||
base_dilation: Sequence[int] | None = None,
|
||||
window_dilation: Sequence[int] | None = None,
|
||||
) -> Array:
|
||||
"""Wraps XLA's `ReduceWindowWithGeneralPadding
|
||||
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
|
||||
operator.
|
||||
"""
|
||||
return _reduce_window(
|
||||
operand,
|
||||
init_value,
|
||||
|
Loading…
x
Reference in New Issue
Block a user