Merge pull request #10823 from LenaMartens:changelist/450914215

PiperOrigin-RevId: 451427821
This commit is contained in:
jax authors 2022-05-27 10:39:24 -07:00
commit 094e706498
2 changed files with 43 additions and 22 deletions


@ -578,18 +578,28 @@ Buffer donation
(This feature is implemented only for TPU and GPU.)
When JAX executes a computation it uses buffers on the device for all inputs and outputs.
If you know that one of the inputs is not needed after the computation, and if it
matches the shape and element type of one of the outputs, you can specify that you
want the corresponding input buffer to be donated to hold an output. This will reduce
the memory required for the execution by the size of the donated buffer.
If you have something like the following pattern, you can use buffer donation::
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)
You can think of this as a way to do a memory-efficient functional update
on your immutable JAX arrays. Within the boundaries of a computation XLA can
make this optimization for you, but at the jit/pmap boundary you need to
guarantee to XLA that you will not use the donated input buffer after calling
the donating function.
You achieve this by using the `donate_argnums` parameter to the functions :func:`jax.jit`,
:func:`jax.pjit`, and :func:`jax.pmap`. This parameter is a sequence of indices (0 based) into
the positional argument list::
def add(x, y):
  return x + y
x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
@ -597,11 +607,16 @@ the positional argument list::
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)
Note that this currently does not work when calling your function with keyword arguments!
The following code will not donate any buffers::
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)
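Passing the same arguments positionally, as in the pattern shown earlier, does donate their buffers; a minimal one-line contrast (same hypothetical ``update_fn``)::

  params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)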
If an argument whose buffer is donated is a pytree, then all the buffers
for its components are donated::
def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]
xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
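# (A minimal sketch of the donating call the comment above describes, assuming
# `add_ones` is jitted with donate_argnums=0 so every leaf buffer of `xs` is donated.)
xs = jax.jit(add_ones, donate_argnums=0)(xs)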


@ -221,7 +221,6 @@ def _infer_argnums_and_argnames(
return argnums, argnames
def jit(
fun: Callable,
*,
@ -277,15 +276,19 @@ def jit(
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
donate_argnums: Specify which positional argument buffers are "donated" to
the computation. It is safe to donate argument buffers if you no longer
need them once the computation has finished. In some cases XLA can make
use of donated buffers to reduce the amount of memory needed to perform a
computation, for example recycling one of your input buffers to store a
result. You should not reuse buffers that you donate to a computation; JAX
will raise an error if you try to. By default, no argument buffers are
donated.
Note that donate_argnums only works for positional arguments; keyword
arguments will not be donated.
For more details on buffer donation see the
[FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
inline: Specify whether this function should be inlined into enclosing
jaxprs (rather than being represented as an application of the xla_call
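A minimal, hedged sketch of the keyword-argument caveat the updated ``donate_argnums`` docstring describes (the function ``scale`` and its arguments are illustrative, and donation only takes effect on TPU/GPU)::

  import jax
  import jax.numpy as jnp

  def scale(x, factor):
    return x * factor

  f = jax.jit(scale, donate_argnums=(0,))

  x1 = jnp.ones((4,))
  x2 = jnp.ones((4,))
  y = f(x1, 2.0)           # passed positionally: the buffer of `x1` is donated
  z = f(x=x2, factor=2.0)  # passed by keyword: no buffers are donated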
@ -1676,15 +1679,18 @@ def pmap(
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
axis_size: Optional; the size of the mapped axis.
donate_argnums: Specify which positional argument buffers are "donated" to
the computation. It is safe to donate argument buffers if you no longer need
them once the computation has finished. In some cases XLA can make use of
donated buffers to reduce the amount of memory needed to perform a
computation, for example recycling one of your input buffers to store a
result. You should not reuse buffers that you donate to a computation; JAX
will raise an error if you try to.
Note that donate_argnums only works for positional arguments; keyword
arguments will not be donated.
For more details on buffer donation see the
[FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
the partitioned values span multiple processes. The global cross-process
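A minimal, hedged sketch of the pmap donation pattern the docstring above describes (``update_fn`` and the array shapes are illustrative; on CPU the donation is ignored with a warning)::

  import jax
  import jax.numpy as jnp

  def update_fn(params, state):
    # Illustrative per-device step whose outputs match the shapes and
    # dtypes of the donated inputs.
    return params - 0.1 * state, state + 1.0

  n = jax.local_device_count()
  params = jnp.ones((n, 4))   # leading axis is the mapped device axis
  state = jnp.zeros((n, 4))

  # Both positional arguments are donated; their buffers may be reused for
  # the outputs, so the old `params` and `state` must not be read afterwards.
  params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)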