mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[shape_poly] Add a polymorphic shape refinement MLIR pass accessible to JAX Python.
At the moment we can run the StableHLO module lowered by jax2tf with polymorphic shapes only with jax2tf, because the tf.XlaCallModule op has the necessary shape refinement logic (which is necessary to legalize the StableHLO module with dynamic shapes to MHLO). Here we expose the shape refinement MLIR transformation to JAX Python. For now this is used only in a test in jax_export_test.py. PiperOrigin-RevId: 537485288
This commit is contained in:
parent
886185831f
commit
ec8b855fa1
@ -12,14 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax.config import config
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
try:
|
||||
@ -292,6 +293,82 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
res = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
|
||||
self.assertAllClose(2. * inner(outer_x), res)
|
||||
|
||||
def test_call_poly(self):
|
||||
a_shape = (3, 4)
|
||||
a = np.arange(math.prod(a_shape), dtype=np.float32).reshape(a_shape)
|
||||
|
||||
def f_inner(x): # x: f32[w, h]
|
||||
return jnp.reshape(x, (-1,))
|
||||
|
||||
exp_inner = jax_export.export(f_inner)(
|
||||
jax_export.poly_spec(a.shape, a.dtype, "w, h")
|
||||
)
|
||||
|
||||
# There are dynamic shapes in the exported module
|
||||
self.assertIn("?x", exp_inner.mlir_module)
|
||||
self.assertIn("stablehlo.dynamic_reshape", exp_inner.mlir_module)
|
||||
|
||||
# Add a wrapper "main" func with static shapes
|
||||
# TODO(necula): We will add this functionality to jax_export.
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
context = mlir.make_ir_context()
|
||||
with context, ir.Location.unknown(context):
|
||||
wrapped_module = ir.Module.parse(exp_inner.mlir_module)
|
||||
symbol_table = ir.SymbolTable(wrapped_module.operation)
|
||||
orig_main = symbol_table["main"]
|
||||
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
||||
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
|
||||
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
|
||||
# Use static shapes
|
||||
new_main_input_types = [
|
||||
mlir.aval_to_ir_type(core.ShapedArray((3, 4), np.float32))
|
||||
]
|
||||
orig_output_types = orig_main.type.results
|
||||
new_main_ftype = ir.FunctionType.get(
|
||||
new_main_input_types, orig_output_types
|
||||
)
|
||||
new_main_op = func_dialect.FuncOp(
|
||||
"main",
|
||||
new_main_ftype,
|
||||
ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
|
||||
)
|
||||
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
||||
symbol_table.insert(new_main_op)
|
||||
entry_block = new_main_op.add_entry_block()
|
||||
with ir.InsertionPoint(entry_block):
|
||||
orig_main_args: List[ir.Value] = []
|
||||
for new_arg, orig_arg_type in zip(
|
||||
new_main_op.arguments, orig_main.type.inputs
|
||||
):
|
||||
orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
|
||||
call = func_dialect.CallOp(
|
||||
orig_output_types,
|
||||
ir.FlatSymbolRefAttr.get(orig_main_name),
|
||||
orig_main_args,
|
||||
)
|
||||
func_dialect.ReturnOp(call.results)
|
||||
symbol_table.set_symbol_name(new_main_op, "main")
|
||||
|
||||
# TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
|
||||
if xc.mlir_api_version >= 50:
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
mlir.module_to_bytecode(wrapped_module)
|
||||
)
|
||||
context = mlir.make_ir_context()
|
||||
with context:
|
||||
refined_module = ir.Module.parse(refined_module_str)
|
||||
|
||||
logging.info("Postprocessed module %s", str(refined_module))
|
||||
self.assertNotIn("?x", str(refined_module))
|
||||
self.assertNotIn("stablehlo.dynamic_reshape", str(refined_module))
|
||||
self.assertIn("stablehlo.reshape", str(refined_module))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user