Merge pull request #10402 from froystig:aot-jit-avoid-trivial

PiperOrigin-RevId: 443533232
This commit is contained in:
jax authors 2022-04-21 18:13:10 -07:00
commit 5013bd2e3a
4 changed files with 18 additions and 6 deletions

View File

@ -566,7 +566,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
else:
arg_specs = []
computation = dispatch.lower_xla_callable(flat_fun, device, backend, name,
donated_invars,
donated_invars, True,
*arg_specs_and_device)
return stages.Lowered.from_flat_info(
computation, in_tree, arg_specs, donate_argnums, out_tree())

View File

@ -194,7 +194,7 @@ xla.xla_call_p.def_impl(_xla_call_impl)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
return lower_xla_callable(fun, device, backend, name, donated_invars,
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
*arg_specs).compile().unsafe_call
_xla_callable = lu.cache(_xla_callable_uncached)
@ -214,7 +214,7 @@ def log_elapsed_time(fmt: str):
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
donated_invars, always_lower: bool, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
@ -253,7 +253,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if not jaxpr.eqns:
if not jaxpr.eqns and not always_lower:
return XlaComputation(
name, None, True, None, None, jaxpr=jaxpr, consts=consts, device=device,
in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)

View File

@ -716,7 +716,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
tiling_method=tiling_method, in_is_global=in_is_global)
else:
return dispatch.lower_xla_callable(
f, None, backend, name, donated_invars, *((a, None) for a in in_avals))
f, None, backend, name, donated_invars, False,
*[(a, None) for a in in_avals])
class EvaluationPlan(NamedTuple):
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""

View File

@ -824,10 +824,11 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f_exe = self.jit(f).lower(1., 1.).compile()
self.assertAllClose(f_exe(1., 1.), 1.)
@jtu.skip_on_devices("cpu") # no donation on cpu, so this would warn
def test_jit_lower_donate_argnums_available(self):
def f(*args):
x, *_ = args
return x
return x + 4.
f_low = self.jit(f, donate_argnums=(0,)).lower(1., 1.)
f_com = f_low.compile()
f_low.donate_argnums == f_com.donate_argnums == (0,)
@ -852,6 +853,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f = self.jit(lambda x: x + 4).lower(1.).compile()
self.assertIsNotNone(f.compiler_ir())
def test_jit_lower_trivial_compiler_ir(self):
f = self.jit(lambda x: x).lower(1.)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
def test_jit_lower_trivial_compile_compiler_ir(self):
f = self.jit(lambda x: x).lower(1.).compile()
self.assertIsNotNone(f.compiler_ir())
def test_jit_lower_compile_executable(self):
f = self.jit(lambda x: x + 4).lower(1.).compile()
self.assertIsNotNone(f.runtime_executable())