mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix typing annotations for jax.named_call
PiperOrigin-RevId: 582235119
This commit is contained in:
parent
cfce0802f8
commit
7cf66dfe4b
@ -2835,10 +2835,10 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
|
||||
|
||||
def named_call(
|
||||
fun: Callable[..., Any],
|
||||
fun: F,
|
||||
*,
|
||||
name: str | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
) -> F:
|
||||
"""Adds a user specified name to a function when staging out JAX computations.
|
||||
|
||||
When staging out computations for just-in-time compilation to XLA (or other
|
||||
@ -2867,6 +2867,7 @@ def named_call(
|
||||
|
||||
return source_info_util.extend_name_stack(name)(fun)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def named_scope(
|
||||
name: str,
|
||||
|
Loading…
x
Reference in New Issue
Block a user