jax.jit now works correctly if both donate_argnums and donate_argnames are specified.

Also update the docstring and changelog to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
parent f7eef2eda8
commit 89c78bf53f

CHANGELOG.md: 11 lines changed
CHANGELOG.md
@@ -8,6 +8,17 @@ Remember to align the itemized text with the first line of an item within a list

 ## jax 0.4.14

+* Changes
+  * `jax.jit` takes `donate_argnames` as an argument. Its semantics are similar
+    to `static_argnames`.
+    If neither donate_argnums nor donate_argnames is provided, no
+    arguments are donated. If donate_argnums is not provided but
+    donate_argnames is, or vice versa, JAX uses
+    `inspect.signature(fun)` to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual
+    parameters listed in either donate_argnums or donate_argnames will
+    be donated.
+
 * Deprecations
   * Python 3.8 support has been dropped as per
     https://jax.readthedocs.io/en/latest/deprecation.html
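For illustration, a minimal sketch of the "both specified" case described in this entry, mirroring the new test_specify_donate_argnums_and_argnames test added in this commit (argument names are illustrative):

    import jax
    import jax.numpy as jnp
    from functools import partial

    # Donate the first positional argument by index and the other two by name.
    @partial(jax.jit, donate_argnums=0, donate_argnames=('inp2', 'inp3'))
    def f(inp1, inp2, inp3):
      return inp1 * 2, inp2 * 2, inp3 * 2

    out = f(jnp.ones((2, 5)), inp2=jnp.ones((2, 5)), inp3=jnp.ones((2, 5)))
    # All three input buffers are eligible for donation on backends that support it.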
@@ -238,8 +238,16 @@ def jit(
       result. You should not reuse buffers that you donate to a computation, JAX
       will raise an error if you try to. By default, no argument buffers are
       donated.
-      Note that donate_argnums only work for positional arguments, and keyword
-      arguments will not be donated.
+
+      If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
+      arguments are donated. If ``donate_argnums`` is not provided but
+      ``donate_argnames`` is, or vice versa, JAX uses
+      :code:`inspect.signature(fun)` to find any positional arguments that
+      correspond to ``donate_argnames``
+      (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
+      provided, ``inspect.signature`` is not used, and only actual
+      parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
+      be donated.

       For more details on buffer donation see the
       `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
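A short, hedged illustration of the signature-based resolution described above: `donate_argnums` names a positional parameter, the caller passes it by keyword, and the buffer is still donated because JAX infers the corresponding name from `inspect.signature(fun)` (mirrors the new test_donate_argnums_with_kwargs test in this commit):

    import jax
    import jax.numpy as jnp
    from functools import partial

    @partial(jax.jit, donate_argnums=0)
    def double(inp):
      return inp * 2

    x = jnp.ones((2, 5))
    y = double(inp=x)  # passed by keyword; the buffer backing x may still be donated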
@@ -297,7 +305,7 @@ def jit(
     >>> g(jnp.arange(4), 3)
     Array([ 0, 1, 256, 6561], dtype=int32)
   """
-  (in_shardings, out_shardings, donate_argnums, static_argnums,
+  (in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
    static_argnames) = pjit.pre_infer_params(
        fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
        static_argnums, static_argnames, device, backend, abstracted_axes)
@@ -307,8 +315,9 @@
         fun=fun, in_shardings=in_shardings,
         out_shardings=out_shardings, static_argnums=static_argnums,
         static_argnames=static_argnames, donate_argnums=donate_argnums,
-        device=device, backend=backend, keep_unused=keep_unused,
-        inline=inline, resource_env=None, abstracted_axes=abstracted_axes)
+        donate_argnames=donate_argnames, device=device, backend=backend,
+        keep_unused=keep_unused, inline=inline, resource_env=None,
+        abstracted_axes=abstracted_axes)
     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)

   has_explicit_sharding = pjit._pjit_explicit_sharding(
@@ -544,7 +553,7 @@ def xla_computation(fun: Callable,
   f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
   args_flat, in_tree = tree_flatten((dyn_args, kwargs))
   if donate_argnums:
-    donated_invars = donation_vector(donate_argnums, dyn_args, kwargs)
+    donated_invars = donation_vector(donate_argnums, (), dyn_args, kwargs)
   else:
     donated_invars = (False,) * len(args_flat)
@@ -1657,7 +1666,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   args, in_tree = tree_flatten((dyn_args, kwargs))

   if donate_tuple and not config.jax_debug_nans:
-    donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)
+    donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs)
   else:
     donated_invars = (False,) * len(args)
   try:
@@ -332,15 +332,25 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
     yield ans


-def donation_vector(donate_argnums, args, kwargs) -> tuple[bool, ...]:
-  """Returns a tuple with a boolean value for each leaf in args."""
+def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool, ...]:
+  """Returns a tuple with a boolean value for each leaf in args and kwargs.
+
+  What if a user specifies donate_argnums but calls the function with kwargs,
+  or vice versa? In that case `resolve_argnums` uses the function's signature
+  to compute the counterpart (donate_argnames or donate_argnums, respectively),
+  so both donate_argnums and donate_argnames are available by the time this
+  function is called. This allows JAX to donate kwargs when only
+  donate_argnums is specified, and vice versa.
+
+  When both donate_argnums and donate_argnames are specified, only the args and
+  kwargs actually listed are donated.
+  """
   res: list[bool] = []
   for i, arg in enumerate(args):
     donate = bool(i in donate_argnums)
     res.extend((donate,) * tree_structure(arg).num_leaves)
-  num_args = len(args)
-  for i, val in enumerate(kwargs.values()):
-    donate = bool(i + num_args in donate_argnums)
+  for key, val in kwargs.items():
+    donate = key in donate_argnames
     res.extend((donate,) * tree_structure(val).num_leaves)
   return tuple(res)
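To make the updated logic concrete, here is a standalone sketch of the same donation mask using the public `jax.tree_util` API instead of the internal `tree_structure` helper (illustrative only; the toy inputs are made up):

    import jax

    def donation_mask(donate_argnums, donate_argnames, args, kwargs):
      # One boolean per pytree leaf, in flattening order: positional args first,
      # then keyword args in dict order.
      res = []
      for i, arg in enumerate(args):
        donate = i in donate_argnums
        res.extend([donate] * jax.tree_util.tree_structure(arg).num_leaves)
      for key, val in kwargs.items():
        donate = key in donate_argnames
        res.extend([donate] * jax.tree_util.tree_structure(val).num_leaves)
      return tuple(res)

    donation_mask((0,), ('b',), ({'w': 1, 'x': 2},), {'a': 3, 'b': 4})
    # -> (True, True, False, True)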
@@ -483,7 +493,6 @@ def infer_argnums_and_argnames(
   if argnums is not None and argnames is not None:
     argnums = _ensure_index_tuple(argnums)
     argnames = _ensure_str_tuple(argnames)
-
     return argnums, argnames

   parameters = sig.parameters
@@ -506,7 +515,7 @@ def infer_argnums_and_argnames(

 def resolve_argnums(
     fun, donate_argnums, donate_argnames, static_argnums, static_argnames
-) -> tuple[tuple[int, ...], tuple[int, ...], tuple[str, ...]]:
+) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]:
   try:
     sig = inspect.signature(fun)
   except ValueError as e:
@@ -528,11 +537,6 @@ def resolve_argnums(
   # names and vice-versa.
   static_argnums, static_argnames = infer_argnums_and_argnames(
       sig, static_argnums, static_argnames)
-  if donate_argnums is not None and donate_argnames is not None:
-    raise NotImplementedError(
-        "Currently only specifying either donate_argnums or donate_argnames "
-        "is allowed. Please file a feature request at "
-        "https://github.com/google/jax/issues.")
   donate_argnums, donate_argnames = infer_argnums_and_argnames(
       sig, donate_argnums, donate_argnames)

@@ -543,10 +547,17 @@ def resolve_argnums(
     validate_argnames(sig, donate_argnames, "donate_argnames")

   # Compensate for static argnums absorbing args
-  # TODO(yashkatariya): Maybe add static_argnames support too here for cases
-  # when nums cannot be inferred from names.
+  assert_no_intersection(static_argnames, donate_argnames)
   donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
-  return donate_argnums, static_argnums, static_argnames
+  return donate_argnums, donate_argnames, static_argnums, static_argnames


+def assert_no_intersection(static_argnames, donate_argnames):
+  out = set(static_argnames).intersection(set(donate_argnames))
+  if out:
+    raise ValueError(
+        "static_argnames and donate_argnames cannot intersect. Argument names "
+        f"{out} appear in both static_argnames and donate_argnames")
+
+
 def _dtype(x):
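A hedged sketch of the argnums-to-argnames inference that `infer_argnums_and_argnames` relies on, reduced to the positional-or-keyword case (the function and names below are made up for illustration):

    import inspect

    def names_for_argnums(fun, argnums):
      # Map positional indices to parameter names via the function signature.
      params = list(inspect.signature(fun).parameters)
      return tuple(params[i] for i in argnums if i < len(params))

    def f(x, y, z):
      return x + y + z

    names_for_argnums(f, (0, 2))  # -> ('x', 'z')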
@@ -524,7 +524,7 @@ def xmap(fun: Callable,
   args_flat, in_tree = tree_flatten(args)
   fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
   if donate_argnums:
-    donated_invars = donation_vector(donate_argnums, args, {})
+    donated_invars = donation_vector(donate_argnums, (), args, {})
   else:
     donated_invars = (False,) * len(args_flat)
   in_axes_flat = _flatten_axes("xmap in_axes", in_tree, in_axes, tupled_args=True)
@@ -332,11 +332,11 @@ def pre_infer_params(fun, in_shardings, out_shardings,
   in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
   out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings')

-  donate_argnums, static_argnums, static_argnames = resolve_argnums(
+  donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums(
       fun, donate_argnums, donate_argnames, static_argnums, static_argnames)

-  return (in_shardings, out_shardings, donate_argnums, static_argnums,
-          static_argnames)
+  return (in_shardings, out_shardings, donate_argnums, donate_argnames,
+          static_argnums, static_argnames)


 def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
@@ -406,6 +406,7 @@ class PjitInfo(NamedTuple):
   static_argnums: tuple[int, ...]
   static_argnames: tuple[str, ...]
   donate_argnums: tuple[int, ...]
+  donate_argnames: tuple[str, ...]
   device: Optional[xc.Device]
   backend: Optional[str]
   keep_unused: bool
@@ -416,7 +417,7 @@ class PjitInfo(NamedTuple):

 def common_infer_params(pjit_info_args, *args, **kwargs):
   (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
-   donate_argnums, device, backend, keep_unused, inline,
+   donate_argnums, donate_argnames, device, backend, keep_unused, inline,
    resource_env, abstracted_axes) = pjit_info_args

   if (kwargs and user_in_shardings is not None and
@@ -457,11 +458,9 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
     dyn_kwargs = {}
   del kwargs

-  if donate_argnums and not config.jax_debug_nans:
-    # TODO(yashkatariya): Maybe thread donate_argnames to calculate
-    # donation_vector. Currently donate_argnames is normalized into
-    # donate_argnums just like static_argnames.
-    donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
+  if (donate_argnums or donate_argnames) and not config.jax_debug_nans:
+    donated_invars = donation_vector(
+        donate_argnums, donate_argnames, dyn_args, dyn_kwargs)
   else:
     donated_invars = (False,) * len(explicit_args)

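As the guard above shows, donation is skipped entirely when NaN debugging is enabled; a hedged two-line illustration:

    import jax

    jax.config.update("jax_debug_nans", True)  # donated_invars then falls back to all-False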
@@ -719,14 +718,27 @@ def pjit(
       comment on ``static_argnums`` for details. If not
       provided but ``static_argnums`` is set, the default is based on calling
       ``inspect.signature(fun)`` to find corresponding named arguments.
-    donate_argnums: Specify which argument buffers are "donated" to the computation.
-      It is safe to donate argument buffers if you no longer need them once the
-      computation has finished. In some cases XLA can make use of donated
-      buffers to reduce the amount of memory needed to perform a computation,
-      for example recycling one of your input buffers to store a result. You
-      should not reuse buffers that you donate to a computation, JAX will raise
-      an error if you try to.
-      For more details on buffer donation see the `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
+    donate_argnums: Specify which positional argument buffers are "donated" to
+      the computation. It is safe to donate argument buffers if you no longer
+      need them once the computation has finished. In some cases XLA can make
+      use of donated buffers to reduce the amount of memory needed to perform a
+      computation, for example recycling one of your input buffers to store a
+      result. You should not reuse buffers that you donate to a computation, JAX
+      will raise an error if you try to. By default, no argument buffers are
+      donated.
+
+      If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
+      arguments are donated. If ``donate_argnums`` is not provided but
+      ``donate_argnames`` is, or vice versa, JAX uses
+      :code:`inspect.signature(fun)` to find any positional arguments that
+      correspond to ``donate_argnames``
+      (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
+      provided, ``inspect.signature`` is not used, and only actual
+      parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
+      be donated.
+
+      For more details on buffer donation see the
+      `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
     donate_argnames: An optional string or collection of strings specifying
       which named arguments are donated to the computation. See the
       comment on ``donate_argnums`` for details. If not
@@ -770,7 +782,7 @@ def pjit(
   in_shardings, out_shardings = _resolve_axis_resources_and_shardings_arg(
       in_shardings, out_shardings, in_axis_resources, out_axis_resources)

-  (in_shardings, out_shardings, donate_argnums, static_argnums,
+  (in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
    static_argnames) = pre_infer_params(
        fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
        static_argnums, static_argnames, device, backend, abstracted_axes)
@@ -782,8 +794,8 @@ def pjit(
         fun=fun, in_shardings=in_shardings,
         out_shardings=out_shardings, static_argnums=static_argnums,
         static_argnames=static_argnames, donate_argnums=donate_argnums,
-        device=device, backend=backend, keep_unused=keep_unused,
-        inline=inline, resource_env=resource_env,
+        donate_argnames=donate_argnames, device=device, backend=backend,
+        keep_unused=keep_unused, inline=inline, resource_env=resource_env,
         abstracted_axes=abstracted_axes)
     return common_infer_params(pjit_info_args, *args, **kwargs)

@@ -1448,9 +1460,9 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
       name=name,
       keep_unused=keep_unused,
       inline=inline)
-  resovled_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
+  resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
       vals_in, vals_out, axes_out)
-  return vals_out, resovled_axes_out
+  return vals_out, resolved_axes_out

 batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
 batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
@@ -497,11 +497,41 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     print(x_copy)  # doesn't crash

+  def test_specify_donate_argnums_and_argnames(self):
+    @partial(jax.jit, donate_argnums=0, donate_argnames=('inp2', 'inp3'))
+    def f(inp1, inp2, inp3):
+      return inp1 * 2, inp2 * 2, inp3 * 2
+
+    x = jnp.ones((2, 5)) * 4
+    y = jnp.ones((2, 5)) * 2
+    z = jnp.ones((2, 5))
+
+    f(x, inp2=y, inp3=z)
+    self.assertDeleted(x)
+    self.assertDeleted(y)
+    self.assertDeleted(z)
+
+  def test_donate_argnames_with_args(self):
+    @partial(jax.jit, donate_argnames='inp1')
+    def f(inp1):
+      return inp1 * 2
+
+    x = jax.device_put(jnp.ones((2, 5)) * 4, jax.devices()[0])
+    f(x)
+    self.assertDeleted(x)
+
+  def test_donate_argnums_with_kwargs(self):
+    @partial(jax.jit, donate_argnums=0)
+    def f(inp1):
+      return inp1 * 2
+
+    x = jax.device_put(jnp.ones((2, 5)) * 4, jax.devices()[0])
+    f(inp1=x)
+    self.assertDeleted(x)
+
   def test_intersecting_static_and_donate_argnames(self):
     with self.assertRaisesRegex(
-        NotImplementedError,
-        "Currently only specifying either donate_argnums or donate_argnames is "
-        "allowed"):
-      jax.jit(lambda x: x, donate_argnums=0, donate_argnames='x')
+        ValueError, "static_argnames and donate_argnames cannot intersect"):
+      jax.jit(lambda x: x, static_argnames='x', donate_argnames='x')

   def test_jit_global_cache(self):
     def f(x):
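Outside the test harness (which uses `assertDeleted`), donation can be observed directly on the array; a hedged sketch, noting that some backends may decline to donate and only emit a warning:

    import jax
    import jax.numpy as jnp
    from functools import partial

    @partial(jax.jit, donate_argnames='inp')
    def double(inp):
      return inp * 2

    x = jnp.ones((2, 5))
    y = double(x)
    print(x.is_deleted())  # typically True when the buffer was actually donated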
@@ -42,7 +42,8 @@ class ApiUtilTest(jtu.JaxTestCase):
     if kwargs:
       expected += (False,)
     self.assertEqual(
-        expected, api_util.donation_vector(donate_argnums, args, kwargs))
+        expected,
+        api_util.donation_vector(donate_argnums, (), args, kwargs))

   @parameterized.parameters(
       ((0,), (0,)),
@@ -50,7 +50,7 @@ from jax._src.sharding_impls import (
     SingleDeviceSharding, parse_flatten_op_sharding)
 import jax._src.pjit as pjit_lib
 from jax._src.pjit import pjit, pjit_p
-from jax._src import mesh
+from jax._src import mesh as mesh_lib
 from jax._src.interpreters import pxla
 from jax.interpreters import mlir
 from jax._src import xla_bridge
@@ -248,20 +248,20 @@ class PJitTest(jtu.BufferDonationTestCase):
   def testDifferentNestedMesh(self):
     with jtu.create_global_mesh((2, 1), ("x", "y")) as m1:
       with jtu.create_global_mesh((2, 2), ("a", "b")) as m2:
-        self.assertEqual(mesh.thread_resources.env.physical_mesh, m2)
-      self.assertEqual(mesh.thread_resources.env.physical_mesh, m1)
-    self.assertEqual(mesh.thread_resources.env.physical_mesh,
-                     mesh.EMPTY_ENV.physical_mesh)
+        self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m2)
+      self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m1)
+    self.assertEqual(mesh_lib.thread_resources.env.physical_mesh,
+                     mesh_lib.EMPTY_ENV.physical_mesh)

   def testSameNestedMesh(self):
     mesh = jtu.create_global_mesh((2, 1), ("a", "b"))
-    thread_resources = jax._src.mesh.thread_resources
+    thread_resources = mesh_lib.thread_resources
     with mesh as m1:
       with mesh as m2:
         self.assertEqual(thread_resources.env.physical_mesh, m2)
       self.assertEqual(thread_resources.env.physical_mesh, m1)
     self.assertEqual(thread_resources.env.physical_mesh,
-                     jax._src.mesh.EMPTY_ENV.physical_mesh)
+                     mesh_lib.EMPTY_ENV.physical_mesh)

   def testMeshDecorator(self):
     x = jnp.arange(8)
|
Loading…
x
Reference in New Issue
Block a user