mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Replace donation_vector's logic with donation_vector_with_in_tree
which is now deleted
PiperOrigin-RevId: 627556267
This commit is contained in:
parent
8842c0bc91
commit
8239674dab
@ -513,7 +513,7 @@ def xla_computation(fun: Callable,
|
||||
f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
|
||||
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
|
||||
if donate_argnums:
|
||||
donated_invars = donation_vector(donate_argnums, (), dyn_args, kwargs)
|
||||
donated_invars = donation_vector(donate_argnums, (), in_tree)
|
||||
else:
|
||||
donated_invars = (False,) * len(args_flat)
|
||||
|
||||
@ -1635,7 +1635,7 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
|
||||
args, in_tree = tree_flatten((dyn_args, kwargs))
|
||||
|
||||
if donate_tuple and not config.debug_nans.value:
|
||||
donated_invars = donation_vector(donate_tuple, (), dyn_args, kwargs)
|
||||
donated_invars = donation_vector(donate_tuple, (), in_tree)
|
||||
else:
|
||||
donated_invars = (False,) * len(args)
|
||||
try:
|
||||
|
@ -27,7 +27,7 @@ from jax._src import dtypes
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.tree_util import (
|
||||
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
|
||||
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
|
||||
treedef_children, generate_key_paths, keystr, broadcast_prefix,
|
||||
prefix_errors)
|
||||
from jax._src.tree_util import _replace_nones
|
||||
@ -319,20 +319,8 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def donation_vector_with_in_tree(donate_argnums, donate_argnames, in_tree
|
||||
) -> tuple[bool, ...]:
|
||||
res: list[bool] = []
|
||||
args_tree, kwargs_tree = treedef_children(in_tree)
|
||||
for i, arg in enumerate(args_tree.children()):
|
||||
donate = bool(i in donate_argnums)
|
||||
res.extend((donate,) * arg.num_leaves)
|
||||
for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore
|
||||
donate = key in donate_argnames
|
||||
res.extend((donate,) * val.num_leaves)
|
||||
return tuple(res)
|
||||
|
||||
|
||||
def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool, ...]:
|
||||
def donation_vector(donate_argnums, donate_argnames, in_tree,
|
||||
kws: bool = True) -> tuple[bool, ...]:
|
||||
"""Returns a tuple with a boolean value for each leaf in args and kwargs.
|
||||
|
||||
What if a user specifies donate_argnums but calls the function with kwargs
|
||||
@ -346,12 +334,17 @@ def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool
|
||||
kwargs specified are donated.
|
||||
"""
|
||||
res: list[bool] = []
|
||||
for i, arg in enumerate(args):
|
||||
if kws:
|
||||
args_tree, kwargs_tree = treedef_children(in_tree)
|
||||
else:
|
||||
args_tree, kwargs_tree = in_tree, None
|
||||
for i, arg in enumerate(args_tree.children()):
|
||||
donate = bool(i in donate_argnums)
|
||||
res.extend((donate,) * tree_structure(arg).num_leaves)
|
||||
for key, val in kwargs.items():
|
||||
res.extend((donate,) * arg.num_leaves)
|
||||
if kwargs_tree is not None:
|
||||
for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore
|
||||
donate = key in donate_argnames
|
||||
res.extend((donate,) * tree_structure(val).num_leaves)
|
||||
res.extend((donate,) * val.num_leaves)
|
||||
return tuple(res)
|
||||
|
||||
def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]:
|
||||
|
@ -534,7 +534,7 @@ def xmap(fun: Callable,
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
if donate_argnums:
|
||||
donated_invars = donation_vector(donate_argnums, (), args, {})
|
||||
donated_invars = donation_vector(donate_argnums, (), in_tree, kws=False)
|
||||
else:
|
||||
donated_invars = (False,) * len(args_flat)
|
||||
in_axes_flat = _flatten_axes("xmap in_axes", in_tree, in_axes, tupled_args=True)
|
||||
|
@ -47,9 +47,9 @@ from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.api_util import (
|
||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||
donation_vector_with_in_tree, shaped_abstractify, check_callable,
|
||||
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
|
||||
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
|
||||
hoist_obj_attrs, resolve_argnums)
|
||||
hoist_obj_attrs)
|
||||
from jax._src.errors import JAXTypeError
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -556,8 +556,7 @@ def _infer_params(jit_info, args, kwargs):
|
||||
flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args)
|
||||
|
||||
if (donate_argnums or donate_argnames) and not config.debug_nans.value:
|
||||
donated_invars = donation_vector_with_in_tree(
|
||||
donate_argnums, donate_argnames, in_tree)
|
||||
donated_invars = donation_vector(donate_argnums, donate_argnames, in_tree)
|
||||
else:
|
||||
donated_invars = (False,) * len(explicit_args)
|
||||
del donate_argnums, donate_argnames
|
||||
|
@ -43,7 +43,8 @@ class ApiUtilTest(jtu.JaxTestCase):
|
||||
expected += (False,)
|
||||
self.assertEqual(
|
||||
expected,
|
||||
api_util.donation_vector(donate_argnums, (), args, kwargs))
|
||||
api_util.donation_vector(donate_argnums, (),
|
||||
jax.tree.structure((args, kwargs))))
|
||||
|
||||
@parameterized.parameters(
|
||||
((0,), (0,)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user