diff --git a/jax/_src/api.py b/jax/_src/api.py index 521f6a00c..01d755e95 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -513,7 +513,7 @@ def xla_computation(fun: Callable, f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False) args_flat, in_tree = tree_flatten((dyn_args, kwargs)) if donate_argnums: - donated_invars = donation_vector(donate_argnums, (), dyn_args, kwargs) + donated_invars = donation_vector(donate_argnums, (), in_tree) else: donated_invars = (False,) * len(args_flat) @@ -1635,7 +1635,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, args, in_tree = tree_flatten((dyn_args, kwargs)) if donate_tuple and not config.debug_nans.value: - donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs) + donated_invars = donation_vector(donate_tuple, (), in_tree) else: donated_invars = (False,) * len(args) try: diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 539240430..657f1e113 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -27,7 +27,7 @@ from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ShapedArray from jax._src.tree_util import ( - PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure, + PyTreeDef, tree_flatten, tree_unflatten, tree_map, treedef_children, generate_key_paths, keystr, broadcast_prefix, prefix_errors) from jax._src.tree_util import _replace_nones @@ -319,20 +319,8 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): @lru_cache(maxsize=4096) -def donation_vector_with_in_tree(donate_argnums, donate_argnames, in_tree - ) -> tuple[bool, ...]: - res: list[bool] = [] - args_tree, kwargs_tree = treedef_children(in_tree) - for i, arg in enumerate(args_tree.children()): - donate = bool(i in donate_argnums) - res.extend((donate,) * arg.num_leaves) - for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore - donate = key in donate_argnames - res.extend((donate,) * val.num_leaves) - return tuple(res) - - -def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool, ...]: +def donation_vector(donate_argnums, donate_argnames, in_tree, + kws: bool = True) -> tuple[bool, ...]: """Returns a tuple with a boolean value for each leaf in args and kwargs. What if a user specifies donate_argnums but calls the function with kwargs @@ -346,12 +334,17 @@ def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool kwargs specified are donated. """ res: list[bool] = [] - for i, arg in enumerate(args): + if kws: + args_tree, kwargs_tree = treedef_children(in_tree) + else: + args_tree, kwargs_tree = in_tree, None + for i, arg in enumerate(args_tree.children()): donate = bool(i in donate_argnums) - res.extend((donate,) * tree_structure(arg).num_leaves) - for key, val in kwargs.items(): - donate = key in donate_argnames - res.extend((donate,) * tree_structure(val).num_leaves) + res.extend((donate,) * arg.num_leaves) + if kwargs_tree is not None: + for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore + donate = key in donate_argnames + res.extend((donate,) * val.num_leaves) return tuple(res) def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]: diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 4e37153b4..7936a47b8 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -534,7 +534,7 @@ def xmap(fun: Callable, args_flat, in_tree = tree_flatten(args) fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) if donate_argnums: - donated_invars = donation_vector(donate_argnums, (), args, {}) + donated_invars = donation_vector(donate_argnums, (), in_tree, kws=False) else: donated_invars = (False,) * len(args_flat) in_axes_flat = _flatten_axes("xmap in_axes", in_tree, in_axes, tupled_args=True) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 9a8f37891..6b97073e5 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -47,9 +47,9 @@ from jax._src import util from jax._src import xla_bridge as xb from jax._src.api_util import ( argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, - donation_vector_with_in_tree, shaped_abstractify, check_callable, + donation_vector, shaped_abstractify, check_callable, resolve_argnums, argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, - hoist_obj_attrs, resolve_argnums) + hoist_obj_attrs) from jax._src.errors import JAXTypeError from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec @@ -556,8 +556,7 @@ def _infer_params(jit_info, args, kwargs): flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args) if (donate_argnums or donate_argnames) and not config.debug_nans.value: - donated_invars = donation_vector_with_in_tree( - donate_argnums, donate_argnames, in_tree) + donated_invars = donation_vector(donate_argnums, donate_argnames, in_tree) else: donated_invars = (False,) * len(explicit_args) del donate_argnums, donate_argnames diff --git a/tests/api_util_test.py b/tests/api_util_test.py index f78b5948f..46bed8c86 100644 --- a/tests/api_util_test.py +++ b/tests/api_util_test.py @@ -43,7 +43,8 @@ class ApiUtilTest(jtu.JaxTestCase): expected += (False,) self.assertEqual( expected, - api_util.donation_vector(donate_argnums, (), args, kwargs)) + api_util.donation_vector(donate_argnums, (), + jax.tree.structure((args, kwargs)))) @parameterized.parameters( ((0,), (0,)),