[shape_poly] Add a polymorphic shape refinement MLIR pass accessible to JAX Python.

At the moment we can run the StableHLO module lowered by jax2tf
with polymorphic shapes only with jax2tf, because the tf.XlaCallModule op has the
necessary shape refinement logic (which is necessary to legalize
the StableHLO module with dynamic shapes to MHLO). Here we
expose the shape refinement MLIR transformation to JAX Python.

For now this is used only in a test in jax_export_test.py.

PiperOrigin-RevId: 537485288
This commit is contained in:
George Necula 2023-06-02 21:48:45 -07:00 committed by jax authors
parent 886185831f
commit ec8b855fa1

View File

@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import math
from typing import List
import unittest
from absl.testing import absltest, parameterized
import jax
from jax import tree_util
from jax import numpy as jnp
from jax import tree_util
from jax.config import config
from jax.experimental.jax2tf import jax_export
try:
@ -292,6 +293,82 @@ class JaxExportTest(jtu.JaxTestCase):
res = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
self.assertAllClose(2. * inner(outer_x), res)
def test_call_poly(self):
a_shape = (3, 4)
a = np.arange(math.prod(a_shape), dtype=np.float32).reshape(a_shape)
def f_inner(x): # x: f32[w, h]
return jnp.reshape(x, (-1,))
exp_inner = jax_export.export(f_inner)(
jax_export.poly_spec(a.shape, a.dtype, "w, h")
)
# There are dynamic shapes in the exported module
self.assertIn("?x", exp_inner.mlir_module)
self.assertIn("stablehlo.dynamic_reshape", exp_inner.mlir_module)
# Add a wrapper "main" func with static shapes
# TODO(necula): We will add this functionality to jax_export.
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax.lib import xla_client as xc
from jax._src.lib import xla_extension
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
wrapped_module = ir.Module.parse(exp_inner.mlir_module)
symbol_table = ir.SymbolTable(wrapped_module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
# Use static shapes
new_main_input_types = [
mlir.aval_to_ir_type(core.ShapedArray((3, 4), np.float32))
]
orig_output_types = orig_main.type.results
new_main_ftype = ir.FunctionType.get(
new_main_input_types, orig_output_types
)
new_main_op = func_dialect.FuncOp(
"main",
new_main_ftype,
ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
)
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
orig_main_args: List[ir.Value] = []
for new_arg, orig_arg_type in zip(
new_main_op.arguments, orig_main.type.inputs
):
orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
call = func_dialect.CallOp(
orig_output_types,
ir.FlatSymbolRefAttr.get(orig_main_name),
orig_main_args,
)
func_dialect.ReturnOp(call.results)
symbol_table.set_symbol_name(new_main_op, "main")
# TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
if xc.mlir_api_version >= 50:
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
mlir.module_to_bytecode(wrapped_module)
)
context = mlir.make_ir_context()
with context:
refined_module = ir.Module.parse(refined_module_str)
logging.info("Postprocessed module %s", str(refined_module))
self.assertNotIn("?x", str(refined_module))
self.assertNotIn("stablehlo.dynamic_reshape", str(refined_module))
self.assertIn("stablehlo.reshape", str(refined_module))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())