mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Add partial support for call_exported with polymorphic shapes
Until now the jax_export.call_exported did not allow calling functions that were exported with polymorphic shapes. We now add that support, including resolving the dimension variables of the called function in terms of the shapes at the call site (which themselves may include dimension variables), and then computing the output shape of the called function. The support is partial in that we can export a JAX function that calls an exported polymorphic function, but we cannot invoke it. This is because we do not yet have access to the shape refinement machinery that XlaCallModule uses. For now, we use XlaCallModule for invoking exported that includes shape polymorphism.
This commit is contained in:
parent
7833528765
commit
46a258ba17
@ -755,15 +755,19 @@ call_exported_p.multiple_results = True
|
||||
def _call_exported_abstract_eval(*in_avals: core.AbstractValue,
|
||||
exported: Exported) -> Tuple[core.AbstractValue, ...]:
|
||||
exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals)
|
||||
if exported_dim_vars:
|
||||
raise NotImplementedError("call_exported for exported with polymorphic shapes")
|
||||
assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure
|
||||
# Must express the exported_dim_vars in terms of the shapes in in_avals.
|
||||
_ = shape_poly.unify_avals_with_args(
|
||||
exported_dim_values = shape_poly.unify_avals_with_args(
|
||||
exported.in_avals, exported_dim_vars, *in_avals, # type: ignore
|
||||
use_static_dimension_size=True,
|
||||
args_kwargs_tree=exported.in_tree)
|
||||
return exported.out_avals
|
||||
|
||||
return tuple(
|
||||
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
|
||||
*exported_dim_values),
|
||||
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
|
||||
named_shape=out_aval.named_shape)
|
||||
for out_aval in exported.out_avals)
|
||||
|
||||
|
||||
call_exported_p.def_abstract_eval(_call_exported_abstract_eval)
|
||||
@ -783,16 +787,32 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
f"on '{platform}'.")
|
||||
submodule = ir.Module.parse(exported.mlir_module)
|
||||
symtab = ir.SymbolTable(submodule.operation)
|
||||
# The called function may have been exported with polymorphic shapes and called
|
||||
# now with more refined shapes. We insert hlo.ConvertOp to ensure the module
|
||||
# is valid.
|
||||
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
|
||||
new_ir_type = mlir.aval_to_ir_type(new_aval)
|
||||
if x.type != new_ir_type:
|
||||
return mlir.convert_hlo(ctx, x, x_aval, new_aval)
|
||||
else:
|
||||
return x
|
||||
|
||||
callee_result_types = symtab["main"].type.results
|
||||
# TODO: maybe cache multiple calls
|
||||
fn = mlir.merge_mlir_modules(ctx.module_context.module,
|
||||
f"call_exported_{exported.fun_name}",
|
||||
submodule)
|
||||
kept_args = [a for i, a in enumerate(args) if i in exported.module_kept_var_idx]
|
||||
kept_args = [
|
||||
convert_shape(a, a_aval, exported_in_aval)
|
||||
for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals))
|
||||
if i in exported.module_kept_var_idx]
|
||||
call = func_dialect.CallOp(callee_result_types,
|
||||
ir.FlatSymbolRefAttr.get(fn),
|
||||
kept_args)
|
||||
return call.results
|
||||
# The ctx.avals_out already contain the abstract values refined by
|
||||
# _call_exported_abstract_eval.
|
||||
return tuple(convert_shape(out, out_aval, refined_out_aval)
|
||||
for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
|
||||
|
||||
|
||||
for _p in ("cpu", "tpu", "cuda", "rocm"):
|
||||
|
@ -23,6 +23,11 @@ from jax import tree_util
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
try:
|
||||
from jax.experimental.jax2tf import jax2tf # TODO: temporary
|
||||
except ImportError:
|
||||
jax2tf = None
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
@ -208,14 +213,51 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
|
||||
jax_export.call_exported(exp_f2)(a))
|
||||
|
||||
def test_call_poly_error(self):
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
exp_f1 = jax_export.export(jnp.sin)(
|
||||
jax_export.poly_spec(a.shape, a.dtype, "b, ...")
|
||||
)
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"call_exported for exported with polymorphic shapes"):
|
||||
jax_export.call_exported(exp_f1)(a)
|
||||
# An inner function is exported with polymorphic shapes inner_poly_spec, and
|
||||
# is called from an outer function, that is exported with outer_poly_spec.
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"inner={inner_poly_spec}_outer={outer_poly_spec}",
|
||||
inner_poly_spec=inner_poly_spec, outer_poly_spec=outer_poly_spec,
|
||||
expect_error=expect_error)
|
||||
for inner_poly_spec, outer_poly_spec, expect_error in (
|
||||
("3,a,a+b", "3,4,12", None),
|
||||
("3,a,a+b", "3,4,c", None),
|
||||
("3,a,a+b", "3,c,c", r"Dimension variable.*b.*must have.* >= 1. Found value 0"),
|
||||
("3,a,a+b", "c,4,12", r"Shape mismatch for args\[0\] in dimension 0"),
|
||||
("3,a,a+b", "3,c+4,12", None), # TODO: This should be an error, c = 0
|
||||
("3,4,3*a", "3,4,12", None),
|
||||
("3,4,5*a", "3,4,12", r"Dimension variable 'a' must have integer value >= 1. Found value 2.4"),
|
||||
# ("3,a,a", "3,a,a", None), # TODO: wrong error. It should be shape mismatch
|
||||
# ("3,4,5*a", "3,4,c", None), # TODO: wrong error. It should be "not divisible by 5"
|
||||
))
|
||||
def test_poly(self, inner_poly_spec="3,a,a+b",
|
||||
outer_poly_spec="3,4,12", expect_error=None):
|
||||
# Polymorphic export called with static or polymorphic shapes
|
||||
def inner(x): # x: export_poly_spec
|
||||
return jnp.reshape(x, (x.shape[0] * x.shape[1], x.shape[2]))
|
||||
|
||||
x1 = np.arange(3 * 4 * 6, dtype=np.float32).reshape((3, 4, 6)) # x1 : f32[3,4,6]
|
||||
exp1 = jax_export.export(inner)(jax_export.poly_spec(x1.shape, x1.dtype, inner_poly_spec))
|
||||
|
||||
x2 = np.concatenate([x1, x1], axis=2) # x2: f32[3,4,12]
|
||||
def outer(x): # x: call_poly_spec
|
||||
# Use an addition to test that the shapes are refined properly for the
|
||||
# result of the call_exported.
|
||||
return jax_export.call_exported(exp1)(x) + inner(x)
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error is not None:
|
||||
stack.push(self.assertRaisesRegex(ValueError, expect_error))
|
||||
|
||||
# Call it after exporting again, with polymorphic shapes
|
||||
exp2 = jax_export.export(outer)(
|
||||
jax_export.poly_spec(x2.shape, x2.dtype, outer_poly_spec))
|
||||
# TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
|
||||
# until we create the python bindings to invoke shape refinement.
|
||||
if jax2tf is not None:
|
||||
res2 = jax2tf._run_exported_as_tf([x2], exp2)[0].numpy()
|
||||
# res2 = jax_export.call_exported(exp2)(x2)
|
||||
self.assertAllClose(2. * inner(x2), res2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user