Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00

Merge pull request #10823 from LenaMartens:changelist/450914215

PiperOrigin-RevId: 451427821

This commit is contained in: commit 094e706498

docs/faq.rst
@@ -578,18 +578,28 @@ Buffer donation

 (This feature is implemented only for TPU and GPU.)

-When JAX executes a computation it reserves buffers on the device for all inputs and outputs.
+When JAX executes a computation it uses buffers on the device for all inputs and outputs.
 If you know that one of the inputs is not needed after the computation, and if it
 matches the shape and element type of one of the outputs, you can specify that you
 want the corresponding input buffer to be donated to hold an output. This will reduce
 the memory required for the execution by the size of the donated buffer.

 If you have something like the following pattern, you can use buffer donation::

   params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

 You can think of this as a way to do a memory-efficient functional update
 on your immutable JAX arrays. Within the boundaries of a computation XLA can
 make this optimization for you, but at the jit/pmap boundary you need to
 guarantee to XLA that you will not use the donated input buffer after calling
 the donating function.

 You achieve this by using the `donate_argnums` parameter to the functions :func:`jax.jit`,
 :func:`jax.pjit`, and :func:`jax.pmap`. This parameter is a sequence of indices (0 based) into
 the positional argument list::

   def add(x, y):
     return x + y

   x = jax.device_put(np.ones((2, 3)))
   y = jax.device_put(np.ones((2, 3)))
@@ -597,11 +607,16 @@ the positional argument list::

   # the same shape and type as `y`, so it will share its buffer.
   z = jax.jit(add, donate_argnums=(1,))(x, y)

+Note that this currently does not work when calling your function with keyword arguments!
+The following code will not donate any buffers::
+
+  params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)
+
 If an argument whose buffer is donated is a pytree, then all the buffers
 for its components are donated::

   def add_ones(xs: List[Array]):
     return [x + 1 for x in xs]

   xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
   # Execute `add_ones` with donation of all the buffers for `xs`.
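The FAQ text added in this hunk can be exercised end to end. Below is a minimal runnable sketch (not part of this commit); on CPU, where donation is unimplemented per the note above, JAX emits a warning and skips the donation, but the result is unchanged:

```python
import jax
import numpy as np

def add(x, y):
    return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))

# Donate the buffer of argument 1 (`y`): it matches the output's shape
# and dtype, so on TPU/GPU the result reuses y's memory. Reusing `y`
# afterwards would then raise an error on those backends.
z = jax.jit(add, donate_argnums=(1,))(x, y)
```

Passing the arguments positionally is essential here: as the added note says, a keyword-argument call would silently skip the donation.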
@@ -221,7 +221,6 @@ def _infer_argnums_and_argnames(

   return argnums, argnames


 def jit(
     fun: Callable,
     *,
@@ -277,15 +276,19 @@ def jit(

     backend: This is an experimental feature and the API is likely to change.
       Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
       ``'tpu'``.
-    donate_argnums: Specify which argument buffers are "donated" to the computation.
-      It is safe to donate argument buffers if you no longer need them once the
-      computation has finished. In some cases XLA can make use of donated
-      buffers to reduce the amount of memory needed to perform a computation,
-      for example recycling one of your input buffers to store a result. You
-      should not reuse buffers that you donate to a computation, JAX will raise
-      an error if you try to. By default, no argument buffers are donated.
+    donate_argnums: Specify which positional argument buffers are "donated" to
+      the computation. It is safe to donate argument buffers if you no longer
+      need them once the computation has finished. In some cases XLA can make
+      use of donated buffers to reduce the amount of memory needed to perform a
+      computation, for example recycling one of your input buffers to store a
+      result. You should not reuse buffers that you donate to a computation, JAX
+      will raise an error if you try to. By default, no argument buffers are
+      donated.
+      Note that donate_argnums only works for positional arguments, and keyword
+      arguments will not be donated.

-      For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
+      For more details on buffer donation see the
+      [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).

     inline: Specify whether this function should be inlined into enclosing
       jaxprs (rather than being represented as an application of the xla_call
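The pytree behavior described in the FAQ text applies to the ``jit`` docstring updated in this hunk as well: donating one positional argument donates every array leaf inside it. A small sketch (not part of this commit; on CPU the donation is skipped with a warning):

```python
import jax
import numpy as np

def add_ones(xs):
    return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]

# donate_argnums=(0,) donates the whole pytree at position 0, i.e. the
# buffers of both arrays in the list. Each output leaf matches the shape
# and dtype of an input leaf, so every buffer can be reused.
ys = jax.jit(add_ones, donate_argnums=(0,))(xs)
```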
@@ -1676,15 +1679,18 @@ def pmap(

     backend: This is an experimental feature and the API is likely to change.
       Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
     axis_size: Optional; the size of the mapped axis.
-    donate_argnums: Specify which argument buffers are "donated" to the computation.
-      It is safe to donate argument buffers if you no longer need them once the
-      computation has finished. In some cases XLA can make use of donated
-      buffers to reduce the amount of memory needed to perform a computation,
-      for example recycling one of your input buffers to store a result. You
-      should not reuse buffers that you donate to a computation, JAX will raise
-      an error if you try to.
+    donate_argnums: Specify which positional argument buffers are "donated" to
+      the computation. It is safe to donate argument buffers if you no longer need
+      them once the computation has finished. In some cases XLA can make use of
+      donated buffers to reduce the amount of memory needed to perform a
+      computation, for example recycling one of your input buffers to store a
+      result. You should not reuse buffers that you donate to a computation, JAX
+      will raise an error if you try to.
+      Note that donate_argnums only works for positional arguments, and keyword
+      arguments will not be donated.

-      For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
+      For more details on buffer donation see the
+      [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).

     global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
       the partitioned values span multiple processes. The global cross-process
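The ``pmap`` docstring updated in this hunk covers the same memory-efficient functional-update pattern quoted in the FAQ. A runnable sketch of that pattern (not part of this commit; ``update_fn`` is a toy stand-in, and on CPU the donation is skipped with a warning):

```python
import jax
import numpy as np

# pmap maps over the leading axis, one slice per local device.
n = jax.local_device_count()

def update_fn(params, state):
    # A toy update whose outputs match the inputs' shapes and dtypes,
    # so the donated input buffers can be reused for the outputs.
    return params - 0.1 * state, state

params = np.ones((n, 3), np.float32)
state = np.ones((n, 3), np.float32)

# Donate both positional inputs; the old params/state buffers must not
# be used again after this call.
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)
```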