add donated_invars to xla.XlaComputation

Co-authored-by: Brennan Saeta <saeta@google.com>
This commit is contained in:
Matthew Johnson 2021-11-16 11:21:27 -08:00
parent 476ca94379
commit 5d35b8a119
6 changed files with 51 additions and 17 deletions

View File

@ -488,17 +488,21 @@ class Lowered:
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = ['in_tree', 'out_tree', '_lowering', '_no_kwargs']
__slots__ = ['in_tree', 'out_tree', 'donate_argnums', '_lowering',
'_no_kwargs']
in_tree: PyTreeDef
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[xla.XlaComputation, pxla.MeshComputation]
_no_kwargs: bool
def __init__(self, lowering, in_tree, out_tree, no_kwargs=False):
def __init__(self, lowering, in_tree, out_tree, donate_argnums,
no_kwargs=False):
self._lowering = lowering
self.in_tree = in_tree
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
def _xla_computation(self):
@ -508,7 +512,8 @@ class Lowered:
def compile(self) -> 'Compiled':
return Compiled(
self._lowering.compile(), self.in_tree, self.out_tree, self._no_kwargs)
self._lowering.compile(), self.in_tree, self.out_tree,
self.donate_argnums, self._no_kwargs)
class Compiled:
@ -519,17 +524,21 @@ class Compiled:
common API for querying properties of compiled computations across
JAX's various compilation paths and backends.
"""
__slots__ = ['in_tree', 'out_tree', '_executable', '_no_kwargs']
__slots__ = ['in_tree', 'out_tree', 'donate_argnums', '_executable',
'_no_kwargs']
in_tree: PyTreeDef
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_executable: Union[xla.XlaCompiledComputation, pxla.MeshExecutable]
_no_kwargs: bool
def __init__(self, executable, in_tree, out_tree, no_kwargs=False):
def __init__(self, executable, in_tree, out_tree, donate_argnums,
no_kwargs=False):
self._executable = executable
self.in_tree = in_tree
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
def _xla_executable(self):
@ -589,7 +598,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
arg_specs = unsafe_map(arg_spec, args_flat)
computation = xla.lower_xla_callable(
flat_fun, device, backend, name, donated_invars, *arg_specs)
return Lowered(computation, in_tree, out_tree())
return Lowered(computation, in_tree, out_tree(), donate_argnums)
return lower

View File

@ -229,23 +229,24 @@ def pjit(fun: Callable,
donated_invars=donated_invars,
name=flat_fun.__name__,
positional_semantics=maps._positional_semantics)
return args_flat, params, in_tree, out_tree()
return args_flat, params, in_tree, out_tree(), donate_argnums
@wraps(fun)
def wrapped(*args, **kwargs):
for arg in tree_leaves(args):
_check_arg(arg)
args_flat, params, _, out_tree = infer_params(*args, **kwargs)
args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
out = pjit_p.bind(*args_flat, **params)
return tree_unflatten(out_tree, out)
def lower(*args, **kwargs):
args_flat, params, in_tree, out_tree = infer_params(*args, **kwargs)
args_flat, params, in_tree, out_tree, donate_argnums = \
infer_params(*args, **kwargs)
lowering = _pjit_lower(
params['jaxpr'], params['in_axis_resources'],
params['out_axis_resources'], params['resource_env'],
params['donated_invars'], params['name'], maps._positional_semantics)
return Lowered(lowering, in_tree, out_tree, no_kwargs=True)
return Lowered(lowering, in_tree, out_tree, donate_argnums, no_kwargs=True)
wrapped.lower = lower
return wrapped

View File

@ -1723,15 +1723,16 @@ def lower_mesh_computation(
built = c.Build(out_tuple)
return MeshComputation(
built, mesh, local_in_untiled_avals,
built, donated_invars, mesh, local_in_untiled_avals,
local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
in_axes, out_axes, spmd_lowering, tuple_args)
class MeshComputation:
def __init__(self, hlo, *compile_args):
def __init__(self, hlo, donated_invars, *compile_args):
self._executable = None
self._hlo = hlo
self._donated_invars = donated_invars
self.compile_args = compile_args
def hlo(self):

View File

@ -800,7 +800,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
# and don't need to evaluate their arguments.
if not jaxpr.eqns:
return XlaComputation(
name, None, True, jaxpr, consts, device, abstract_args, out_avals,
name, None, True, None, jaxpr, consts, device, abstract_args, out_avals,
kept_var_idx)
if not _on_exit:
@ -850,8 +850,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
", ".join(unused_donations)))
built = c.build(output)
return XlaComputation(
name, built, False, nreps, device, backend, tuple_args, abstract_args,
out_avals, kept_var_idx)
name, built, False, donated_invars, nreps, device, backend, tuple_args,
abstract_args, out_avals, kept_var_idx)
def compile_or_get_cached(backend, computation, compile_options):
@ -875,11 +875,14 @@ class XlaComputation:
name: str
_is_trivial: bool
_executable: Optional['XlaCompiledComputation']
_donated_invars: Optional[Sequence[bool]]
def __init__(self, name: str, hlo, is_trivial: bool, *compile_args):
def __init__(self, name: str, hlo, is_trivial: bool,
donated_invars: Optional[Sequence[bool]], *compile_args):
self.name = name
self._hlo = hlo
self._is_trivial = is_trivial
self._donated_invars = donated_invars
self._executable = None
self.compile_args = compile_args

View File

@ -780,6 +780,14 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f_exe = self.jit(f).lower(1., 1.).compile()
self.assertAllClose(f_exe(1., 1.), 1.)
def test_jit_lower_donate_argnums_available(self):
def f(*args):
x, *_ = args
return x
f_low = self.jit(f, donate_argnums=(0,)).lower(1., 1.)
f_com = f_low.compile()
f_low.donate_argnums == f_com.donate_argnums == (0,)
class PythonJitTest(CPPJitTest):

View File

@ -385,7 +385,19 @@ class PJitTest(jtu.BufferDonationTestCase):
def testLowerWithDuckTyping(self):
x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
# Make sure this doesn't crash
pjit(lambda x: x + 4, in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
pjit(lambda x: x + 4,
in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
@jtu.with_mesh([('x', 2)])
def testLowerDonateArgnumsAvailable(self):
x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
def f(*args):
x, *_ = args
return x
f_low = pjit(f, donate_argnums=(0,),
in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
f_com = f_low.compile()
f_low.donate_argnums == f_com.donate_argnums == (0,)
def testInfeed(self):
devices = np.array(jax.local_devices())