Move _reduce_window docstring to public func lax.reduce_window.

This commit is contained in:
David Boetius 2025-01-09 13:31:48 +01:00 committed by GitHub
parent 640cb009f1
commit 6e9a34f791
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -54,10 +54,6 @@ def _reduce_window(
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None,
):
"""Wraps XLA's `ReduceWindowWithGeneralPadding
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator.
"""
flat_operands, operand_tree = tree_util.tree_flatten(operand)
flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
if operand_tree != init_value_tree:
@ -123,6 +119,10 @@ def reduce_window(
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None,
) -> Array:
"""Wraps XLA's `ReduceWindowWithGeneralPadding
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator.
"""
return _reduce_window(
operand,
init_value,