add donated_invars to xla.XlaComputation
Co-authored-by: Brennan Saeta <saeta@google.com>
commit 5d35b8a119
parent 476ca94379
@@ -488,17 +488,21 @@ class Lowered:
   querying properties of lowered computations across JAX's various
   lowering paths (``jit``, ``pmap``, etc.).
   """
-  __slots__ = ['in_tree', 'out_tree', '_lowering', '_no_kwargs']
+  __slots__ = ['in_tree', 'out_tree', 'donate_argnums', '_lowering',
+               '_no_kwargs']
 
   in_tree: PyTreeDef
   out_tree: PyTreeDef
+  donate_argnums: Tuple[int, ...]
   _lowering: Union[xla.XlaComputation, pxla.MeshComputation]
   _no_kwargs: bool
 
-  def __init__(self, lowering, in_tree, out_tree, no_kwargs=False):
+  def __init__(self, lowering, in_tree, out_tree, donate_argnums,
+               no_kwargs=False):
     self._lowering = lowering
     self.in_tree = in_tree
     self.out_tree = out_tree
+    self.donate_argnums = donate_argnums
     self._no_kwargs = no_kwargs
 
   def _xla_computation(self):
@@ -508,7 +512,8 @@ class Lowered:
 
   def compile(self) -> 'Compiled':
     return Compiled(
-        self._lowering.compile(), self.in_tree, self.out_tree, self._no_kwargs)
+        self._lowering.compile(), self.in_tree, self.out_tree,
+        self.donate_argnums, self._no_kwargs)
 
 
 class Compiled:
@@ -519,17 +524,21 @@ class Compiled:
   common API for querying properties of compiled computations across
   JAX's various compilation paths and backends.
   """
-  __slots__ = ['in_tree', 'out_tree', '_executable', '_no_kwargs']
+  __slots__ = ['in_tree', 'out_tree', 'donate_argnums', '_executable',
+               '_no_kwargs']
 
   in_tree: PyTreeDef
   out_tree: PyTreeDef
+  donate_argnums: Tuple[int, ...]
   _executable: Union[xla.XlaCompiledComputation, pxla.MeshExecutable]
   _no_kwargs: bool
 
-  def __init__(self, executable, in_tree, out_tree, no_kwargs=False):
+  def __init__(self, executable, in_tree, out_tree, donate_argnums,
+               no_kwargs=False):
     self._executable = executable
     self.in_tree = in_tree
     self.out_tree = out_tree
+    self.donate_argnums = donate_argnums
     self._no_kwargs = no_kwargs
 
   def _xla_executable(self):
@@ -589,7 +598,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
     arg_specs = unsafe_map(arg_spec, args_flat)
     computation = xla.lower_xla_callable(
         flat_fun, device, backend, name, donated_invars, *arg_specs)
-    return Lowered(computation, in_tree, out_tree())
+    return Lowered(computation, in_tree, out_tree(), donate_argnums)
 
   return lower
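
With the jit lowering path above passing `donate_argnums` through to `Lowered`, donation metadata becomes inspectable at both staging steps. A minimal usage sketch, mirroring the test added later in this commit (the function `f` is illustrative):

    import jax

    def f(x, y):
      return x  # returns x, so donating x's buffer is valid

    lowered = jax.jit(f, donate_argnums=(0,)).lower(1., 1.)
    compiled = lowered.compile()
    assert lowered.donate_argnums == compiled.donate_argnums == (0,)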
@@ -229,23 +229,24 @@ def pjit(fun: Callable,
         donated_invars=donated_invars,
         name=flat_fun.__name__,
         positional_semantics=maps._positional_semantics)
-    return args_flat, params, in_tree, out_tree()
+    return args_flat, params, in_tree, out_tree(), donate_argnums
 
   @wraps(fun)
   def wrapped(*args, **kwargs):
     for arg in tree_leaves(args):
       _check_arg(arg)
-    args_flat, params, _, out_tree = infer_params(*args, **kwargs)
+    args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
     out = pjit_p.bind(*args_flat, **params)
     return tree_unflatten(out_tree, out)
 
   def lower(*args, **kwargs):
-    args_flat, params, in_tree, out_tree = infer_params(*args, **kwargs)
+    args_flat, params, in_tree, out_tree, donate_argnums = \
+        infer_params(*args, **kwargs)
     lowering = _pjit_lower(
         params['jaxpr'], params['in_axis_resources'],
         params['out_axis_resources'], params['resource_env'],
         params['donated_invars'], params['name'], maps._positional_semantics)
-    return Lowered(lowering, in_tree, out_tree, no_kwargs=True)
+    return Lowered(lowering, in_tree, out_tree, donate_argnums, no_kwargs=True)
 
   wrapped.lower = lower
   return wrapped
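
The pjit path exposes the same metadata once `infer_params` returns `donate_argnums` as well. A sketch under stated assumptions: at least two devices are available, and the import paths are those of roughly this era of JAX (they have since moved):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.maps import mesh
    from jax.experimental import PartitionSpec as P
    from jax.experimental.pjit import pjit

    x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
    # Lower under a one-axis mesh named 'x', donating the first argument.
    with mesh(np.array(jax.devices()[:2]), ('x',)):
      f_low = pjit(lambda a: a + 4, donate_argnums=(0,),
                   in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
      assert f_low.donate_argnums == f_low.compile().donate_argnums == (0,)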
@@ -1723,15 +1723,16 @@ def lower_mesh_computation(
 
   built = c.Build(out_tuple)
   return MeshComputation(
-      built, mesh, local_in_untiled_avals,
+      built, donated_invars, mesh, local_in_untiled_avals,
       local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
       in_axes, out_axes, spmd_lowering, tuple_args)
 
 
 class MeshComputation:
-  def __init__(self, hlo, *compile_args):
+  def __init__(self, hlo, donated_invars, *compile_args):
     self._executable = None
     self._hlo = hlo
+    self._donated_invars = donated_invars
     self.compile_args = compile_args
 
   def hlo(self):
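
Two related notions are being threaded here: `donate_argnums` is the user-facing tuple of positional-argument indices, while `donated_invars` is the per-flattened-input boolean mask computed during lowering. Roughly, as an illustrative sketch (`donation_mask` is a hypothetical stand-in, not the helper JAX actually uses):

    from jax.tree_util import tree_leaves

    def donation_mask(donate_argnums, args):
      # One boolean per flattened argument leaf: True if that leaf's
      # buffer may be donated to the computation.
      return tuple(i in donate_argnums
                   for i, arg in enumerate(args)
                   for _ in tree_leaves(arg))

    # Two arguments; the first, a pytree with two leaves, is donated.
    assert donation_mask((0,), ({'a': 1, 'b': 2}, 3.0)) == (True, True, False)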
@@ -800,7 +800,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   # and don't need to evaluate their arguments.
   if not jaxpr.eqns:
     return XlaComputation(
-        name, None, True, jaxpr, consts, device, abstract_args, out_avals,
+        name, None, True, None, jaxpr, consts, device, abstract_args, out_avals,
         kept_var_idx)
 
   if not _on_exit:
@@ -850,8 +850,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                   ", ".join(unused_donations)))
   built = c.build(output)
   return XlaComputation(
-      name, built, False, nreps, device, backend, tuple_args, abstract_args,
-      out_avals, kept_var_idx)
+      name, built, False, donated_invars, nreps, device, backend, tuple_args,
+      abstract_args, out_avals, kept_var_idx)
 
 
 def compile_or_get_cached(backend, computation, compile_options):
@@ -875,11 +875,14 @@ class XlaComputation:
   name: str
   _is_trivial: bool
   _executable: Optional['XlaCompiledComputation']
+  _donated_invars: Optional[Sequence[bool]]
 
-  def __init__(self, name: str, hlo, is_trivial: bool, *compile_args):
+  def __init__(self, name: str, hlo, is_trivial: bool,
+               donated_invars: Optional[Sequence[bool]], *compile_args):
     self.name = name
     self._hlo = hlo
     self._is_trivial = is_trivial
+    self._donated_invars = donated_invars
     self._executable = None
     self.compile_args = compile_args
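
`_donated_invars` is `Optional` because the trivial path earlier in `lower_xla_callable` (a jaxpr with no equations) builds no XLA executable to donate into, so it passes `None`; the non-trivial path passes one boolean per input. A hypothetical consumer, not part of this commit, might branch on it like so:

    from typing import Optional, Sequence

    def describe_donation(donated_invars: Optional[Sequence[bool]]) -> str:
      # None marks a trivial computation, where donation does not apply.
      if donated_invars is None:
        return "trivial computation; no donation"
      return f"{sum(donated_invars)} of {len(donated_invars)} inputs donated"

    assert describe_donation(None) == "trivial computation; no donation"
    assert describe_donation([True, False]) == "1 of 2 inputs donated"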
@@ -780,6 +780,14 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     f_exe = self.jit(f).lower(1., 1.).compile()
     self.assertAllClose(f_exe(1., 1.), 1.)
 
+  def test_jit_lower_donate_argnums_available(self):
+    def f(*args):
+      x, *_ = args
+      return x
+    f_low = self.jit(f, donate_argnums=(0,)).lower(1., 1.)
+    f_com = f_low.compile()
+    self.assertTrue(f_low.donate_argnums == f_com.donate_argnums == (0,))
+
 
 class PythonJitTest(CPPJitTest):
 
@@ -385,7 +385,19 @@ class PJitTest(jtu.BufferDonationTestCase):
   def testLowerWithDuckTyping(self):
     x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
     # Make sure this doesn't crash
-    pjit(lambda x: x + 4, in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
+    pjit(lambda x: x + 4,
+         in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
+
+  @jtu.with_mesh([('x', 2)])
+  def testLowerDonateArgnumsAvailable(self):
+    x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
+    def f(*args):
+      x, *_ = args
+      return x
+    f_low = pjit(f, donate_argnums=(0,),
+                 in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
+    f_com = f_low.compile()
+    self.assertTrue(f_low.donate_argnums == f_com.donate_argnums == (0,))
 
   def testInfeed(self):
     devices = np.array(jax.local_devices())