Merge pull request #8241 from froystig:xla-computation-in-avals

PiperOrigin-RevId: 403500602
This commit is contained in:
jax authors 2021-10-15 17:08:57 -07:00
commit 40103f6a71
2 changed files with 28 additions and 6 deletions

View File

@ -751,9 +751,10 @@ class XlaComputation:
class XlaCompiledComputation:
def __init__(self, xla_executable, in_avals, unsafe_call):
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call):
self._xla_executable = xla_executable
self.in_avals = in_avals
self._kept_var_idx = kept_var_idx
self.unsafe_call = unsafe_call
@staticmethod
@ -777,8 +778,9 @@ class XlaCompiledComputation:
buffer_counts = (None if len(out_avals) == 1 else
[len(aval_to_xla_shapes(aval)) for aval in out_avals])
execute = _execute_compiled if nreps == 1 else _execute_replicated
return XlaCompiledComputation(compiled, in_avals, partial(execute,
name, compiled, buffer_counts, result_handlers, kept_var_idx))
unsafe_call = partial(execute, name, compiled, buffer_counts,
result_handlers, kept_var_idx)
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
def is_trivial(self):
return self._xla_executable == None
@ -792,17 +794,23 @@ class XlaCompiledComputation:
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
kept_var_idx) -> 'XlaCompiledComputation':
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
return XlaCompiledComputation(None, in_avals, partial(_execute_trivial,
jaxpr, device, consts, out_avals, result_handlers, kept_var_idx))
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
out_avals, result_handlers, kept_var_idx)
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)
def call(self, *args):
arg_specs = unsafe_map(arg_spec, args)
arg_avals = [spec[0] for spec in arg_specs]
arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
if i in self._kept_var_idx]
check_arg_avals_for_call(self.in_avals, arg_avals)
return self.unsafe_call(*args)
def check_arg_avals_for_call(ref_avals, arg_avals):
if len(ref_avals) != len(arg_avals):
raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs "
f"but called with {len(arg_avals)}")
for ref_aval, arg_aval in zip(ref_avals, arg_avals):
if not core.typematch(ref_aval, arg_aval):
ref_avals_fmt = ', '.join(str(a) for a in ref_avals)

View File

@ -760,6 +760,20 @@ class CPPJitTest(jtu.BufferDonationTestCase):
"called with:\n.*int32.*",
lambda: f_exe(x_i32))
def test_jit_lower_compile_multi_arg(self):
def f(*args):
x, *_ = args
return jnp.sqrt(x ** 2) + 1.
f_exe = self.jit(f).lower(1., 1.).compile()
self.assertAllClose(f_exe(1., 1.), 2.)
def test_jit_lower_compile_trivial_multi_arg(self):
def f(*args):
x, *_ = args
return x
f_exe = self.jit(f).lower(1., 1.).compile()
self.assertAllClose(f_exe(1., 1.), 1.)
class PythonJitTest(CPPJitTest):