Annotate vmap

This commit is contained in:
Neil Girdhar 2022-03-24 19:06:12 -04:00
parent b3130b7d65
commit 7697616b85

View File

@ -1384,7 +1384,11 @@ def _split(x, indices, axis):
return x.split(indices, axis)
def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None, axis_size=None) -> F:
def vmap(fun: F,
in_axes: Union[int, Sequence[Any]] = 0,
out_axes: Any = 0,
axis_name: Optional[Hashable] = None,
axis_size: Optional[int] = None) -> F:
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.
Args: