[shape_poly] Add partial support for call_exported with polymorphic shapes

Until now the jax_export.call_exported did not allow calling functions
that were exported with polymorphic shapes. We now add that support,
including resolving the dimension variables of the called function
in terms of the shapes at the call site (which themselves may include
dimension variables), and then computing the output shape of the
called function.

The support is partial in that we can export a JAX function that
calls an exported polymorphic function, but we cannot invoke it.
This is because we do not yet have access to the shape refinement
machinery that XlaCallModule uses. For now, we use XlaCallModule
for invoking exported that includes shape polymorphism.
This commit is contained in:
George Necula 2023-04-26 09:11:04 +02:00
parent 7833528765
commit 46a258ba17
2 changed files with 76 additions and 14 deletions

View File

@ -755,15 +755,19 @@ call_exported_p.multiple_results = True
def _call_exported_abstract_eval(*in_avals: core.AbstractValue,
exported: Exported) -> Tuple[core.AbstractValue, ...]:
exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals)
if exported_dim_vars:
raise NotImplementedError("call_exported for exported with polymorphic shapes")
assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure
# Must express the exported_dim_vars in terms of the shapes in in_avals.
_ = shape_poly.unify_avals_with_args(
exported_dim_values = shape_poly.unify_avals_with_args(
exported.in_avals, exported_dim_vars, *in_avals, # type: ignore
use_static_dimension_size=True,
args_kwargs_tree=exported.in_tree)
return exported.out_avals
return tuple(
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
*exported_dim_values),
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
named_shape=out_aval.named_shape)
for out_aval in exported.out_avals)
call_exported_p.def_abstract_eval(_call_exported_abstract_eval)
@ -783,16 +787,32 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
f"on '{platform}'.")
submodule = ir.Module.parse(exported.mlir_module)
symtab = ir.SymbolTable(submodule.operation)
# The called function may have been exported with polymorphic shapes and called
# now with more refined shapes. We insert hlo.ConvertOp to ensure the module
# is valid.
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
new_ir_type = mlir.aval_to_ir_type(new_aval)
if x.type != new_ir_type:
return mlir.convert_hlo(ctx, x, x_aval, new_aval)
else:
return x
callee_result_types = symtab["main"].type.results
# TODO: maybe cache multiple calls
fn = mlir.merge_mlir_modules(ctx.module_context.module,
f"call_exported_{exported.fun_name}",
submodule)
kept_args = [a for i, a in enumerate(args) if i in exported.module_kept_var_idx]
kept_args = [
convert_shape(a, a_aval, exported_in_aval)
for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals))
if i in exported.module_kept_var_idx]
call = func_dialect.CallOp(callee_result_types,
ir.FlatSymbolRefAttr.get(fn),
kept_args)
return call.results
# The ctx.avals_out already contain the abstract values refined by
# _call_exported_abstract_eval.
return tuple(convert_shape(out, out_aval, refined_out_aval)
for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
for _p in ("cpu", "tpu", "cuda", "rocm"):

View File

@ -23,6 +23,11 @@ from jax import tree_util
from jax import numpy as jnp
from jax.config import config
from jax.experimental.jax2tf import jax_export
try:
from jax.experimental.jax2tf import jax2tf # TODO: temporary
except ImportError:
jax2tf = None
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
@ -208,14 +213,51 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
jax_export.call_exported(exp_f2)(a))
def test_call_poly_error(self):
a = np.arange(4, dtype=np.float32)
exp_f1 = jax_export.export(jnp.sin)(
jax_export.poly_spec(a.shape, a.dtype, "b, ...")
)
with self.assertRaisesRegex(NotImplementedError,
"call_exported for exported with polymorphic shapes"):
jax_export.call_exported(exp_f1)(a)
# An inner function is exported with polymorphic shapes inner_poly_spec, and
# is called from an outer function, that is exported with outer_poly_spec.
@parameterized.named_parameters(
dict(testcase_name=f"inner={inner_poly_spec}_outer={outer_poly_spec}",
inner_poly_spec=inner_poly_spec, outer_poly_spec=outer_poly_spec,
expect_error=expect_error)
for inner_poly_spec, outer_poly_spec, expect_error in (
("3,a,a+b", "3,4,12", None),
("3,a,a+b", "3,4,c", None),
("3,a,a+b", "3,c,c", r"Dimension variable.*b.*must have.* >= 1. Found value 0"),
("3,a,a+b", "c,4,12", r"Shape mismatch for args\[0\] in dimension 0"),
("3,a,a+b", "3,c+4,12", None), # TODO: This should be an error, c = 0
("3,4,3*a", "3,4,12", None),
("3,4,5*a", "3,4,12", r"Dimension variable 'a' must have integer value >= 1. Found value 2.4"),
# ("3,a,a", "3,a,a", None), # TODO: wrong error. It should be shape mismatch
# ("3,4,5*a", "3,4,c", None), # TODO: wrong error. It should be "not divisible by 5"
))
def test_poly(self, inner_poly_spec="3,a,a+b",
outer_poly_spec="3,4,12", expect_error=None):
# Polymorphic export called with static or polymorphic shapes
def inner(x): # x: export_poly_spec
return jnp.reshape(x, (x.shape[0] * x.shape[1], x.shape[2]))
x1 = np.arange(3 * 4 * 6, dtype=np.float32).reshape((3, 4, 6)) # x1 : f32[3,4,6]
exp1 = jax_export.export(inner)(jax_export.poly_spec(x1.shape, x1.dtype, inner_poly_spec))
x2 = np.concatenate([x1, x1], axis=2) # x2: f32[3,4,12]
def outer(x): # x: call_poly_spec
# Use an addition to test that the shapes are refined properly for the
# result of the call_exported.
return jax_export.call_exported(exp1)(x) + inner(x)
with contextlib.ExitStack() as stack:
if expect_error is not None:
stack.push(self.assertRaisesRegex(ValueError, expect_error))
# Call it after exporting again, with polymorphic shapes
exp2 = jax_export.export(outer)(
jax_export.poly_spec(x2.shape, x2.dtype, outer_poly_spec))
# TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
# until we create the python bindings to invoke shape refinement.
if jax2tf is not None:
res2 = jax2tf._run_exported_as_tf([x2], exp2)[0].numpy()
# res2 = jax_export.call_exported(exp2)(x2)
self.assertAllClose(2. * inner(x2), res2)
if __name__ == "__main__":