[docs] donate_argnums FAQ link to rst format

This commit is contained in:
Jiho Lee 2023-01-10 18:11:08 +09:00
parent b5be92e449
commit 41b9c5e8cd
3 changed files with 4 additions and 4 deletions

View File

@ -228,7 +228,7 @@ def jit(
arguments will not be donated.
For more details on buffer donation see the
[FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
inline: Specify whether this function should be inlined into enclosing
jaxprs (rather than being represented as an application of the xla_call
@ -1774,7 +1774,7 @@ def pmap(
arguments will not be donated.
For more details on buffer donation see the
[FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
the partitioned values span multiple processes. The global cross-process

View File

@ -283,7 +283,7 @@ def pjit(
for example recycling one of your input buffers to store a result. You
should not reuse buffers that you donate to a computation, JAX will raise
an error if you try to.
For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
For more details on buffer donation see the `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
keep_unused: If `False` (the default), arguments that JAX determines to be
unused by `fun` *may* be dropped from resulting compiled XLA executables.
Such arguments will not be transferred to the device nor provided to the

View File

@ -371,7 +371,7 @@ def xmap(fun: Callable,
should not reuse buffers that you donate to a computation, JAX will raise
an error if you try to.
For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
For more details on buffer donation see the `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.