jit(f).lower(...) works w/ duck typed shape/dtype

This commit is contained in:
Matthew Johnson 2021-10-27 20:27:09 -07:00
parent 2bda894a30
commit 05708aef2b
3 changed files with 21 additions and 5 deletions

View File

@ -321,9 +321,6 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
else:
donated_invars = (False,) * len(args_flat)
for arg in args_flat:
_check_arg(arg)
return f, in_tree, args_flat, donated_invars
@ -351,6 +348,8 @@ def _python_jit(
return fun(*args, **kwargs)
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
out_flat = xla.xla_call(
flat_fun, *args_flat,
@ -412,6 +411,8 @@ def _cpp_jit(
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
out_flat = xla.xla_call(
flat_fun, *args_flat,
@ -561,6 +562,15 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
# If the function we returned from ``jit`` were a class instance,
# this might naturally be a method, with ``fun`` as a ``self`` and
# all the other arguments stored as attributes.
def arg_spec(x):
# like xla.arg_spec but duck-types on x.shape and x.dtype
aval = shaped_abstractify(x)
try:
return aval, x._device
except:
return aval, None
@api_boundary
def lower(*args, **kwargs) -> Lowered:
"""Lower this function for the given arguments.
@ -576,7 +586,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
name = flat_fun.__name__
arg_specs = unsafe_map(xla.arg_spec, args_flat)
arg_specs = unsafe_map(arg_spec, args_flat)
computation = xla.lower_xla_callable(
flat_fun, device, backend, name, donated_invars, *arg_specs)
return Lowered(computation, in_tree, out_tree())

View File

@ -723,6 +723,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f_exe = f_low.compile()
self.assertAllClose(f_exe(1.), 2.)
def test_jit_lower_duck_typing(self):
f_jit = self.jit(lambda x: 2 * x)
f_low = f_jit.lower(jax.ShapeDtypeStruct((), 'float32')) # doesn't crash
f_exe = f_low.compile()
self.assertAllClose(f_exe(jnp.float32(1.)), jnp.float32(2.))
def test_jit_lower_compile_in_tree_mismatch(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.

View File

@ -372,7 +372,7 @@ class PJitTest(jtu.BufferDonationTestCase):
xla._translations[pjit_p] = rule
@jtu.with_mesh([('x', 2)])
def testLowerWithAbstractArgs(self):
def testLowerWithDuckTyping(self):
x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
# Make sure this doesn't crash
pjit(lambda x: x + 4, in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)