jax.jit now works correctly if both donate_argnums and donate_argnames are specified.

Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
This commit is contained in:
Yash Katariya 2023-07-14 14:27:29 -07:00 committed by jax authors
parent f7eef2eda8
commit 89c78bf53f
8 changed files with 131 additions and 57 deletions

View File

@ -8,6 +8,17 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.14
* Changes
* `jax.jit` takes `donate_argnames` as an argument. Its semantics are similar
to `static_argnames`.
If neither donate_argnums nor donate_argnames is provided, no
arguments are donated. If donate_argnums is not provided but
donate_argnames is, or vice versa, JAX uses
`inspect.signature(fun)` to find any positional arguments that
correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual
parameters listed in either donate_argnums or donate_argnames will
be donated.
* Deprecations
* Python 3.8 support has been dropped as per
https://jax.readthedocs.io/en/latest/deprecation.html

View File

@ -238,8 +238,16 @@ def jit(
result. You should not reuse buffers that you donate to a computation, JAX
will raise an error if you try to. By default, no argument buffers are
donated.
Note that donate_argnums only work for positional arguments, and keyword
arguments will not be donated.
If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
arguments are donated. If ``donate_argnums`` is not provided but
``donate_argnames`` is, or vice versa, JAX uses
:code:`inspect.signature(fun)` to find any positional arguments that
correspond to ``donate_argnames``
(or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
provided, ``inspect.signature`` is not used, and only actual
parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
be donated.
For more details on buffer donation see the
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
@ -297,7 +305,7 @@ def jit(
>>> g(jnp.arange(4), 3)
Array([ 0, 1, 256, 6561], dtype=int32)
"""
(in_shardings, out_shardings, donate_argnums, static_argnums,
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
static_argnames) = pjit.pre_infer_params(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes)
@ -307,8 +315,9 @@ def jit(
fun=fun, in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
device=device, backend=backend, keep_unused=keep_unused,
inline=inline, resource_env=None, abstracted_axes=abstracted_axes)
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline, resource_env=None,
abstracted_axes=abstracted_axes)
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
has_explicit_sharding = pjit._pjit_explicit_sharding(
@ -544,7 +553,7 @@ def xla_computation(fun: Callable,
f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
if donate_argnums:
donated_invars = donation_vector(donate_argnums, dyn_args, kwargs)
donated_invars = donation_vector(donate_argnums, (), dyn_args, kwargs)
else:
donated_invars = (False,) * len(args_flat)
@ -1657,7 +1666,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
args, in_tree = tree_flatten((dyn_args, kwargs))
if donate_tuple and not config.jax_debug_nans:
donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)
donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs)
else:
donated_invars = (False,) * len(args)
try:

View File

@ -332,15 +332,25 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
yield ans
def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool, ...]:
  """Returns a tuple with a boolean value for each leaf in args and kwargs.

  A positional argument's leaves are donated when its index is in
  ``donate_argnums``; a keyword argument's leaves are donated when its name
  is in ``donate_argnames``.

  What if a user specifies donate_argnums but calls the function with kwargs
  or vice-versa? In that case, in `resolve_argnums` using the signature of the
  function, the counterpart (donate_argnames or donate_argnums respectively) is
  calculated so when this function is called both donate_argnums and
  donate_argnames are available. This allows JAX to donate kwargs when only
  donate_argnums is specified and vice-versa.

  When both donate_argnums and donate_argnames are specified, only the args and
  kwargs specified are donated.
  """
  res: list[bool] = []
  for i, arg in enumerate(args):
    # One flag per pytree leaf, so nested containers donate all their leaves.
    donate = bool(i in donate_argnums)
    res.extend((donate,) * tree_structure(arg).num_leaves)
  for key, val in kwargs.items():
    # Kwargs are matched by name, independent of position.
    donate = key in donate_argnames
    res.extend((donate,) * tree_structure(val).num_leaves)
  return tuple(res)
@ -483,7 +493,6 @@ def infer_argnums_and_argnames(
if argnums is not None and argnames is not None:
argnums = _ensure_index_tuple(argnums)
argnames = _ensure_str_tuple(argnames)
return argnums, argnames
parameters = sig.parameters
@ -506,7 +515,7 @@ def infer_argnums_and_argnames(
def resolve_argnums(
fun, donate_argnums, donate_argnames, static_argnums, static_argnames
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[str, ...]]:
) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]:
try:
sig = inspect.signature(fun)
except ValueError as e:
@ -528,11 +537,6 @@ def resolve_argnums(
# names and vice-versa.
static_argnums, static_argnames = infer_argnums_and_argnames(
sig, static_argnums, static_argnames)
if donate_argnums is not None and donate_argnames is not None:
raise NotImplementedError(
"Currently only specifying either donate_argnums or donate_argnames "
"is allowed. Please file a feature request at "
"https://github.com/google/jax/issues.")
donate_argnums, donate_argnames = infer_argnums_and_argnames(
sig, donate_argnums, donate_argnames)
@ -543,10 +547,17 @@ def resolve_argnums(
validate_argnames(sig, donate_argnames, "donate_argnames")
# Compensate for static argnums absorbing args
# TODO(yashkatariya): Maybe add static_argnames support too here for cases
# when nums cannot be inferred from names.
assert_no_intersection(static_argnames, donate_argnames)
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
return donate_argnums, static_argnums, static_argnames
return donate_argnums, donate_argnames, static_argnums, static_argnames
def assert_no_intersection(static_argnames, donate_argnames):
  """Raises ValueError if any argument name is both static and donated.

  A static argument is baked into the compiled computation while a donated
  one hands its buffer to it, so the same name cannot be both.
  """
  overlap = set(static_argnames) & set(donate_argnames)
  if overlap:
    raise ValueError(
        "static_argnames and donate_argnames cannot intersect. Argument names "
        f"{overlap} appear in both static_argnames and donate_argnames")
def _dtype(x):

View File

@ -524,7 +524,7 @@ def xmap(fun: Callable,
args_flat, in_tree = tree_flatten(args)
fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
if donate_argnums:
donated_invars = donation_vector(donate_argnums, args, {})
donated_invars = donation_vector(donate_argnums, (), args, {})
else:
donated_invars = (False,) * len(args_flat)
in_axes_flat = _flatten_axes("xmap in_axes", in_tree, in_axes, tupled_args=True)

View File

@ -332,11 +332,11 @@ def pre_infer_params(fun, in_shardings, out_shardings,
in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings')
donate_argnums, static_argnums, static_argnames = resolve_argnums(
donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums(
fun, donate_argnums, donate_argnames, static_argnums, static_argnames)
return (in_shardings, out_shardings, donate_argnums, static_argnums,
static_argnames)
return (in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames)
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
@ -406,6 +406,7 @@ class PjitInfo(NamedTuple):
static_argnums: tuple[int, ...]
static_argnames: tuple[str, ...]
donate_argnums: tuple[int, ...]
donate_argnames: tuple[str, ...]
device: Optional[xc.Device]
backend: Optional[str]
keep_unused: bool
@ -416,7 +417,7 @@ class PjitInfo(NamedTuple):
def common_infer_params(pjit_info_args, *args, **kwargs):
(fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
donate_argnums, device, backend, keep_unused, inline,
donate_argnums, donate_argnames, device, backend, keep_unused, inline,
resource_env, abstracted_axes) = pjit_info_args
if (kwargs and user_in_shardings is not None and
@ -457,11 +458,9 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
dyn_kwargs = {}
del kwargs
if donate_argnums and not config.jax_debug_nans:
# TODO(yashkatariya): Maybe thread donate_argnames to calculate
# donation_vector. Currently donate_argnames is normalized into
# donate_argnums just like static_argnames.
donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
if (donate_argnums or donate_argnames) and not config.jax_debug_nans:
donated_invars = donation_vector(
donate_argnums, donate_argnames, dyn_args, dyn_kwargs)
else:
donated_invars = (False,) * len(explicit_args)
@ -719,14 +718,27 @@ def pjit(
comment on ``static_argnums`` for details. If not
provided but ``static_argnums`` is set, the default is based on calling
``inspect.signature(fun)`` to find corresponding named arguments.
donate_argnums: Specify which argument buffers are "donated" to the computation.
It is safe to donate argument buffers if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
buffers to reduce the amount of memory needed to perform a computation,
for example recycling one of your input buffers to store a result. You
should not reuse buffers that you donate to a computation, JAX will raise
an error if you try to.
For more details on buffer donation see the `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
donate_argnums: Specify which positional argument buffers are "donated" to
the computation. It is safe to donate argument buffers if you no longer
need them once the computation has finished. In some cases XLA can make
use of donated buffers to reduce the amount of memory needed to perform a
computation, for example recycling one of your input buffers to store a
result. You should not reuse buffers that you donate to a computation, JAX
will raise an error if you try to. By default, no argument buffers are
donated.
If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
arguments are donated. If ``donate_argnums`` is not provided but
``donate_argnames`` is, or vice versa, JAX uses
:code:`inspect.signature(fun)` to find any positional arguments that
correspond to ``donate_argnames``
(or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
provided, ``inspect.signature`` is not used, and only actual
parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
be donated.
For more details on buffer donation see the
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
donate_argnames: An optional string or collection of strings specifying
which named arguments are donated to the computation. See the
comment on ``donate_argnums`` for details. If not
@ -770,7 +782,7 @@ def pjit(
in_shardings, out_shardings = _resolve_axis_resources_and_shardings_arg(
in_shardings, out_shardings, in_axis_resources, out_axis_resources)
(in_shardings, out_shardings, donate_argnums, static_argnums,
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
static_argnames) = pre_infer_params(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes)
@ -782,8 +794,8 @@ def pjit(
fun=fun, in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
device=device, backend=backend, keep_unused=keep_unused,
inline=inline, resource_env=resource_env,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline, resource_env=resource_env,
abstracted_axes=abstracted_axes)
return common_infer_params(pjit_info_args, *args, **kwargs)
@ -1448,9 +1460,9 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
name=name,
keep_unused=keep_unused,
inline=inline)
resovled_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
vals_in, vals_out, axes_out)
return vals_out, resovled_axes_out
return vals_out, resolved_axes_out
batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)

View File

@ -497,11 +497,41 @@ class CPPJitTest(jtu.BufferDonationTestCase):
print(x_copy) # doesn't crash
def test_specify_donate_argnums_and_argnames(self):
  # When both donate_argnums and donate_argnames are given, arg 0 is donated
  # via its index and inp2/inp3 via their names — all three buffers go.
  @partial(jax.jit, donate_argnums=0, donate_argnames=('inp2', 'inp3'))
  def f(inp1, inp2, inp3):
    return inp1 * 2, inp2 * 2, inp3 * 2

  arrays = [jnp.ones((2, 5)) * m for m in (4, 2, 1)]
  f(arrays[0], inp2=arrays[1], inp3=arrays[2])
  for arr in arrays:
    self.assertDeleted(arr)
def test_donate_argnames_with_args(self):
  # donate_argnames alone: JAX infers the positional index from the
  # function signature, so a positionally-passed argument is still donated.
  @partial(jax.jit, donate_argnames='inp1')
  def f(inp1):
    return inp1 * 2

  arr = jax.device_put(jnp.ones((2, 5)) * 4, jax.devices()[0])
  f(arr)
  self.assertDeleted(arr)
def test_donate_argnums_with_kwargs(self):
  # donate_argnums alone: JAX infers the parameter name from the signature,
  # so the argument is donated even when passed as a keyword.
  @partial(jax.jit, donate_argnums=0)
  def f(inp1):
    return inp1 * 2

  arr = jax.device_put(jnp.ones((2, 5)) * 4, jax.devices()[0])
  f(inp1=arr)
  self.assertDeleted(arr)
def test_intersecting_static_and_donate_argnames(self):
with self.assertRaisesRegex(
NotImplementedError,
"Currently only specifying either donate_argnums or donate_argnames is "
"allowed"):
jax.jit(lambda x: x, donate_argnums=0, donate_argnames='x')
ValueError, "static_argnames and donate_argnames cannot intersect"):
jax.jit(lambda x: x, static_argnames='x', donate_argnames='x')
def test_jit_global_cache(self):
def f(x):

View File

@ -42,7 +42,8 @@ class ApiUtilTest(jtu.JaxTestCase):
if kwargs:
expected += (False,)
self.assertEqual(
expected, api_util.donation_vector(donate_argnums, args, kwargs))
expected,
api_util.donation_vector(donate_argnums, (), args, kwargs))
@parameterized.parameters(
((0,), (0,)),

View File

@ -50,7 +50,7 @@ from jax._src.sharding_impls import (
SingleDeviceSharding, parse_flatten_op_sharding)
import jax._src.pjit as pjit_lib
from jax._src.pjit import pjit, pjit_p
from jax._src import mesh
from jax._src import mesh as mesh_lib
from jax._src.interpreters import pxla
from jax.interpreters import mlir
from jax._src import xla_bridge
@ -248,20 +248,20 @@ class PJitTest(jtu.BufferDonationTestCase):
def testDifferentNestedMesh(self):
with jtu.create_global_mesh((2, 1), ("x", "y")) as m1:
with jtu.create_global_mesh((2, 2), ("a", "b")) as m2:
self.assertEqual(mesh.thread_resources.env.physical_mesh, m2)
self.assertEqual(mesh.thread_resources.env.physical_mesh, m1)
self.assertEqual(mesh.thread_resources.env.physical_mesh,
mesh.EMPTY_ENV.physical_mesh)
self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m2)
self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m1)
self.assertEqual(mesh_lib.thread_resources.env.physical_mesh,
mesh_lib.EMPTY_ENV.physical_mesh)
def testSameNestedMesh(self):
mesh = jtu.create_global_mesh((2, 1), ("a", "b"))
thread_resources = jax._src.mesh.thread_resources
thread_resources = mesh_lib.thread_resources
with mesh as m1:
with mesh as m2:
self.assertEqual(thread_resources.env.physical_mesh, m2)
self.assertEqual(thread_resources.env.physical_mesh, m1)
self.assertEqual(thread_resources.env.physical_mesh,
jax._src.mesh.EMPTY_ENV.physical_mesh)
mesh_lib.EMPTY_ENV.physical_mesh)
def testMeshDecorator(self):
x = jnp.arange(8)