mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10402 from froystig:aot-jit-avoid-trivial
PiperOrigin-RevId: 443533232
This commit is contained in:
commit
5013bd2e3a
@ -566,7 +566,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
else:
|
||||
arg_specs = []
|
||||
computation = dispatch.lower_xla_callable(flat_fun, device, backend, name,
|
||||
donated_invars,
|
||||
donated_invars, True,
|
||||
*arg_specs_and_device)
|
||||
return stages.Lowered.from_flat_info(
|
||||
computation, in_tree, arg_specs, donate_argnums, out_tree())
|
||||
|
@ -194,7 +194,7 @@ xla.xla_call_p.def_impl(_xla_call_impl)
|
||||
|
||||
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
|
||||
donated_invars, *arg_specs):
|
||||
return lower_xla_callable(fun, device, backend, name, donated_invars,
|
||||
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
|
||||
*arg_specs).compile().unsafe_call
|
||||
|
||||
_xla_callable = lu.cache(_xla_callable_uncached)
|
||||
@ -214,7 +214,7 @@ def log_elapsed_time(fmt: str):
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
donated_invars, *arg_specs):
|
||||
donated_invars, always_lower: bool, *arg_specs):
|
||||
if device is not None and backend is not None:
|
||||
raise ValueError("can't specify both a device and a backend for jit, "
|
||||
"got device={} and backend={}".format(device, backend))
|
||||
@ -253,7 +253,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
# Computations that only produce constants and/or only rearrange their inputs,
|
||||
# which are often produced from partial evaluation, don't need compilation,
|
||||
# and don't need to evaluate their arguments.
|
||||
if not jaxpr.eqns:
|
||||
if not jaxpr.eqns and not always_lower:
|
||||
return XlaComputation(
|
||||
name, None, True, None, None, jaxpr=jaxpr, consts=consts, device=device,
|
||||
in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)
|
||||
|
@ -716,7 +716,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
tiling_method=tiling_method, in_is_global=in_is_global)
|
||||
else:
|
||||
return dispatch.lower_xla_callable(
|
||||
f, None, backend, name, donated_invars, *((a, None) for a in in_avals))
|
||||
f, None, backend, name, donated_invars, False,
|
||||
*[(a, None) for a in in_avals])
|
||||
|
||||
class EvaluationPlan(NamedTuple):
|
||||
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
|
||||
|
@ -824,10 +824,11 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
f_exe = self.jit(f).lower(1., 1.).compile()
|
||||
self.assertAllClose(f_exe(1., 1.), 1.)
|
||||
|
||||
@jtu.skip_on_devices("cpu") # no donation on cpu, so this would warn
|
||||
def test_jit_lower_donate_argnums_available(self):
|
||||
def f(*args):
|
||||
x, *_ = args
|
||||
return x
|
||||
return x + 4.
|
||||
f_low = self.jit(f, donate_argnums=(0,)).lower(1., 1.)
|
||||
f_com = f_low.compile()
|
||||
f_low.donate_argnums == f_com.donate_argnums == (0,)
|
||||
@ -852,6 +853,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
f = self.jit(lambda x: x + 4).lower(1.).compile()
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
|
||||
def test_jit_lower_trivial_compiler_ir(self):
|
||||
f = self.jit(lambda x: x).lower(1.)
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
|
||||
def test_jit_lower_trivial_compile_compiler_ir(self):
|
||||
f = self.jit(lambda x: x).lower(1.).compile()
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
|
||||
def test_jit_lower_compile_executable(self):
|
||||
f = self.jit(lambda x: x + 4).lower(1.).compile()
|
||||
self.assertIsNotNone(f.runtime_executable())
|
||||
|
Loading…
x
Reference in New Issue
Block a user