Replace donation_vector's logic with donation_vector_with_in_tree which is now deleted

PiperOrigin-RevId: 627556267
This commit is contained in:
Yash Katariya 2024-04-23 17:37:52 -07:00 committed by jax authors
parent 8842c0bc91
commit 8239674dab
5 changed files with 21 additions and 28 deletions

View File

@ -513,7 +513,7 @@ def xla_computation(fun: Callable,
f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
if donate_argnums:
donated_invars = donation_vector(donate_argnums, (), dyn_args, kwargs)
donated_invars = donation_vector(donate_argnums, (), in_tree)
else:
donated_invars = (False,) * len(args_flat)
@ -1635,7 +1635,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
args, in_tree = tree_flatten((dyn_args, kwargs))
if donate_tuple and not config.debug_nans.value:
donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs)
donated_invars = donation_vector(donate_tuple, (), in_tree)
else:
donated_invars = (False,) * len(args)
try:

View File

@ -27,7 +27,7 @@ from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
treedef_children, generate_key_paths, keystr, broadcast_prefix,
prefix_errors)
from jax._src.tree_util import _replace_nones
@ -319,20 +319,8 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
@lru_cache(maxsize=4096)
def donation_vector_with_in_tree(donate_argnums, donate_argnames, in_tree
) -> tuple[bool, ...]:
res: list[bool] = []
args_tree, kwargs_tree = treedef_children(in_tree)
for i, arg in enumerate(args_tree.children()):
donate = bool(i in donate_argnums)
res.extend((donate,) * arg.num_leaves)
for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore
donate = key in donate_argnames
res.extend((donate,) * val.num_leaves)
return tuple(res)
def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool, ...]:
def donation_vector(donate_argnums, donate_argnames, in_tree,
kws: bool = True) -> tuple[bool, ...]:
"""Returns a tuple with a boolean value for each leaf in args and kwargs.
What if a user specifies donate_argnums but calls the function with kwargs
@ -346,12 +334,17 @@ def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool
kwargs specified are donated.
"""
res: list[bool] = []
for i, arg in enumerate(args):
if kws:
args_tree, kwargs_tree = treedef_children(in_tree)
else:
args_tree, kwargs_tree = in_tree, None
for i, arg in enumerate(args_tree.children()):
donate = bool(i in donate_argnums)
res.extend((donate,) * tree_structure(arg).num_leaves)
for key, val in kwargs.items():
res.extend((donate,) * arg.num_leaves)
if kwargs_tree is not None:
for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore
donate = key in donate_argnames
res.extend((donate,) * tree_structure(val).num_leaves)
res.extend((donate,) * val.num_leaves)
return tuple(res)
def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]:

View File

@ -534,7 +534,7 @@ def xmap(fun: Callable,
args_flat, in_tree = tree_flatten(args)
fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
if donate_argnums:
donated_invars = donation_vector(donate_argnums, (), args, {})
donated_invars = donation_vector(donate_argnums, (), in_tree, kws=False)
else:
donated_invars = (False,) * len(args_flat)
in_axes_flat = _flatten_axes("xmap in_axes", in_tree, in_axes, tupled_args=True)

View File

@ -47,9 +47,9 @@ from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector_with_in_tree, shaped_abstractify, check_callable,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
hoist_obj_attrs, resolve_argnums)
hoist_obj_attrs)
from jax._src.errors import JAXTypeError
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
@ -556,8 +556,7 @@ def _infer_params(jit_info, args, kwargs):
flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args)
if (donate_argnums or donate_argnames) and not config.debug_nans.value:
donated_invars = donation_vector_with_in_tree(
donate_argnums, donate_argnames, in_tree)
donated_invars = donation_vector(donate_argnums, donate_argnames, in_tree)
else:
donated_invars = (False,) * len(explicit_args)
del donate_argnums, donate_argnames

View File

@ -43,7 +43,8 @@ class ApiUtilTest(jtu.JaxTestCase):
expected += (False,)
self.assertEqual(
expected,
api_util.donation_vector(donate_argnums, (), args, kwargs))
api_util.donation_vector(donate_argnums, (),
jax.tree.structure((args, kwargs))))
@parameterized.parameters(
((0,), (0,)),