mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8241 from froystig:xla-computation-in-avals
PiperOrigin-RevId: 403500602
This commit is contained in:
commit
40103f6a71
@ -751,9 +751,10 @@ class XlaComputation:
|
||||
|
||||
|
||||
class XlaCompiledComputation:
|
||||
def __init__(self, xla_executable, in_avals, unsafe_call):
|
||||
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call):
|
||||
self._xla_executable = xla_executable
|
||||
self.in_avals = in_avals
|
||||
self._kept_var_idx = kept_var_idx
|
||||
self.unsafe_call = unsafe_call
|
||||
|
||||
@staticmethod
|
||||
@ -777,8 +778,9 @@ class XlaCompiledComputation:
|
||||
buffer_counts = (None if len(out_avals) == 1 else
|
||||
[len(aval_to_xla_shapes(aval)) for aval in out_avals])
|
||||
execute = _execute_compiled if nreps == 1 else _execute_replicated
|
||||
return XlaCompiledComputation(compiled, in_avals, partial(execute,
|
||||
name, compiled, buffer_counts, result_handlers, kept_var_idx))
|
||||
unsafe_call = partial(execute, name, compiled, buffer_counts,
|
||||
result_handlers, kept_var_idx)
|
||||
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
|
||||
|
||||
def is_trivial(self):
|
||||
return self._xla_executable == None
|
||||
@ -792,17 +794,23 @@ class XlaCompiledComputation:
|
||||
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
|
||||
kept_var_idx) -> 'XlaCompiledComputation':
|
||||
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
|
||||
return XlaCompiledComputation(None, in_avals, partial(_execute_trivial,
|
||||
jaxpr, device, consts, out_avals, result_handlers, kept_var_idx))
|
||||
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
|
||||
out_avals, result_handlers, kept_var_idx)
|
||||
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)
|
||||
|
||||
def call(self, *args):
|
||||
arg_specs = unsafe_map(arg_spec, args)
|
||||
arg_avals = [spec[0] for spec in arg_specs]
|
||||
arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
|
||||
if i in self._kept_var_idx]
|
||||
check_arg_avals_for_call(self.in_avals, arg_avals)
|
||||
return self.unsafe_call(*args)
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals):
|
||||
if len(ref_avals) != len(arg_avals):
|
||||
raise TypeError(
|
||||
f"Computation compiled for {len(ref_avals)} inputs "
|
||||
f"but called with {len(arg_avals)}")
|
||||
for ref_aval, arg_aval in zip(ref_avals, arg_avals):
|
||||
if not core.typematch(ref_aval, arg_aval):
|
||||
ref_avals_fmt = ', '.join(str(a) for a in ref_avals)
|
||||
|
@ -760,6 +760,20 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
"called with:\n.*int32.*",
|
||||
lambda: f_exe(x_i32))
|
||||
|
||||
def test_jit_lower_compile_multi_arg(self):
|
||||
def f(*args):
|
||||
x, *_ = args
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
f_exe = self.jit(f).lower(1., 1.).compile()
|
||||
self.assertAllClose(f_exe(1., 1.), 2.)
|
||||
|
||||
def test_jit_lower_compile_trivial_multi_arg(self):
|
||||
def f(*args):
|
||||
x, *_ = args
|
||||
return x
|
||||
f_exe = self.jit(f).lower(1., 1.).compile()
|
||||
self.assertAllClose(f_exe(1., 1.), 1.)
|
||||
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user