mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jit(f).lower(...) works w/ duck typed shape/dtype
This commit is contained in:
parent
2bda894a30
commit
05708aef2b
@ -321,9 +321,6 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
|
||||
else:
|
||||
donated_invars = (False,) * len(args_flat)
|
||||
|
||||
for arg in args_flat:
|
||||
_check_arg(arg)
|
||||
|
||||
return f, in_tree, args_flat, donated_invars
|
||||
|
||||
|
||||
@ -351,6 +348,8 @@ def _python_jit(
|
||||
return fun(*args, **kwargs)
|
||||
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
|
||||
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
|
||||
for arg in args_flat:
|
||||
_check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
|
||||
out_flat = xla.xla_call(
|
||||
flat_fun, *args_flat,
|
||||
@ -412,6 +411,8 @@ def _cpp_jit(
|
||||
# work/code that is redundant between C++ and Python. We can try that later.
|
||||
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
|
||||
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
|
||||
for arg in args_flat:
|
||||
_check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
|
||||
out_flat = xla.xla_call(
|
||||
flat_fun, *args_flat,
|
||||
@ -561,6 +562,15 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
# If the function we returned from ``jit`` were a class instance,
|
||||
# this might naturally be a method, with ``fun`` as a ``self`` and
|
||||
# all the other arguments stored as attributes.
|
||||
|
||||
def arg_spec(x):
|
||||
# like xla.arg_spec but duck-types on x.shape and x.dtype
|
||||
aval = shaped_abstractify(x)
|
||||
try:
|
||||
return aval, x._device
|
||||
except:
|
||||
return aval, None
|
||||
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs) -> Lowered:
|
||||
"""Lower this function for the given arguments.
|
||||
@ -576,7 +586,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
|
||||
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
|
||||
name = flat_fun.__name__
|
||||
arg_specs = unsafe_map(xla.arg_spec, args_flat)
|
||||
arg_specs = unsafe_map(arg_spec, args_flat)
|
||||
computation = xla.lower_xla_callable(
|
||||
flat_fun, device, backend, name, donated_invars, *arg_specs)
|
||||
return Lowered(computation, in_tree, out_tree())
|
||||
|
@ -723,6 +723,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
f_exe = f_low.compile()
|
||||
self.assertAllClose(f_exe(1.), 2.)
|
||||
|
||||
def test_jit_lower_duck_typing(self):
|
||||
f_jit = self.jit(lambda x: 2 * x)
|
||||
f_low = f_jit.lower(jax.ShapeDtypeStruct((), 'float32')) # doesn't crash
|
||||
f_exe = f_low.compile()
|
||||
self.assertAllClose(f_exe(jnp.float32(1.)), jnp.float32(2.))
|
||||
|
||||
def test_jit_lower_compile_in_tree_mismatch(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
|
@ -372,7 +372,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
xla._translations[pjit_p] = rule
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testLowerWithAbstractArgs(self):
|
||||
def testLowerWithDuckTyping(self):
|
||||
x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
|
||||
# Make sure this doesn't crash
|
||||
pjit(lambda x: x + 4, in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user