mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Annotate vmap
This commit is contained in:
parent
b3130b7d65
commit
7697616b85
@ -1384,7 +1384,11 @@ def _split(x, indices, axis):
|
||||
return x.split(indices, axis)
|
||||
|
||||
|
||||
def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None, axis_size=None) -> F:
|
||||
def vmap(fun: F,
|
||||
in_axes: Union[int, Sequence[Any]] = 0,
|
||||
out_axes: Any = 0,
|
||||
axis_name: Optional[Hashable] = None,
|
||||
axis_size: Optional[int] = None) -> F:
|
||||
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.
|
||||
|
||||
Args:
|
||||
|
Loading…
x
Reference in New Issue
Block a user