Fix typing annotations for jax.named_call

PiperOrigin-RevId: 582235119
This commit is contained in:
Etienne Pot 2023-11-14 01:47:04 -08:00 committed by jax authors
parent cfce0802f8
commit 7cf66dfe4b

View File

@ -2835,10 +2835,10 @@ def eval_shape(fun: Callable, *args, **kwargs):
def named_call(
fun: Callable[..., Any],
fun: F,
*,
name: str | None = None,
) -> Callable[..., Any]:
) -> F:
"""Adds a user specified name to a function when staging out JAX computations.
When staging out computations for just-in-time compilation to XLA (or other
@ -2867,6 +2867,7 @@ def named_call(
return source_info_util.extend_name_stack(name)(fun)
@contextmanager
def named_scope(
name: str,