diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 69a6f40ee..67ebd2408 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -186,6 +186,12 @@ def batched_device_put(aval: core.ShapedArray,
   return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
 
 
+def refine_shape_polymorphism(module: ir.Module) -> ir.Module:
+  # In order to avoid depending on jax2tf/jax_export.py we will monkey patch
+  # this from jax_export to refine the polymorphic shapes in the module.
+  raise NotImplementedError("Compiling modules with shape polymorphism")
+
+
 # NOTE(skye): we could refactor to generate _multi_slice parameters directly
 # from the input ShardingSpec, rather than the indices. However, this would
 # require duplicating the ordering logic of spec_to_indices, which is more
@@ -2624,7 +2630,8 @@ class UnloadedMeshExecutable:
       shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None,
       compiler_options=None
   ) -> MeshExecutable:
-    del shape_poly_state
+    if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
+      hlo = refine_shape_polymorphism(hlo)
     compiler_options_keys = tuple(
         compiler_options.keys()) if compiler_options is not None else None
     compiler_options_values = tuple(
diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py
index 7ceee4dff..fb4ea2583 100644
--- a/jax/experimental/jax2tf/jax_export.py
+++ b/jax/experimental/jax2tf/jax_export.py
@@ -24,6 +24,7 @@ from absl import logging
 
 import jax
 from jax import sharding
+from jax.lib import xla_client as xc
 
 from jax._src import core
 from jax._src import dispatch
@@ -37,6 +38,7 @@ from jax._src.lib.mlir.dialects import stablehlo
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir.dialects import func as func_dialect
+from jax._src.lib import xla_extension
 from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge as xb
@@ -815,7 +817,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
         f"The exported function '{exported.fun_name}' was lowered for "
         f"platform '{exported.lowering_platform}' but it is used "
         f"on '{platform}'.")
-  if any(not core.is_constant_shape(a.shape) for a in exported.in_avals):
+  if exported.module_uses_dim_vars:
     ctx.module_context.shape_poly_state.uses_dim_vars = True
 
   submodule = ir.Module.parse(exported.mlir_module)
@@ -852,3 +854,22 @@ for _p in ("cpu", "tpu", "cuda", "rocm"):
   mlir.register_lowering(call_exported_p,
                          functools.partial(_call_exported_lowering, platform=_p),
                          platform=_p)
+
+
+def _refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
+  """Refine the polymorphic shapes inside a module.
+
+  Given a module with static input shapes, but using dynamic shapes due to
+  shape polymorphism, run shape refinement to resolve all the dynamic shapes.
+  """
+  if xc.mlir_api_version < 50:
+    raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
+
+  refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
+      mlir.module_to_bytecode(module)
+  )
+  context = mlir.make_ir_context()
+  with context:
+    return ir.Module.parse(refined_module_str)
+
+pxla.refine_shape_polymorphism = _refine_polymorphic_shapes
diff --git a/jax/experimental/jax2tf/tests/jax_export_test.py b/jax/experimental/jax2tf/tests/jax_export_test.py
index 384e606fa..976bbde15 100644
--- a/jax/experimental/jax2tf/tests/jax_export_test.py
+++ b/jax/experimental/jax2tf/tests/jax_export_test.py
@@ -14,6 +14,7 @@
 import contextlib
 import logging
 import math
+import re
 from typing import List
 import unittest
 
@@ -28,6 +29,7 @@ try:
 except ImportError:
   jax2tf = None  # type: ignore
 
+from jax.lib import xla_client as xc
 from jax._src import core
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
@@ -234,19 +236,25 @@ class JaxExportTest(jtu.JaxTestCase):
       dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
       dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c"),
       dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
-           expect_error=(
+           expect_error=re.escape(
              r"Dimension variable 'b' must have integer value >= 1. "
-             r"Found 0 when solving a \+ b == args\[0\].shape\[2\]")),
+             r"Found 0 when solving a + b == args[0].shape[2]")),
       dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
            expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"),
-      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"),  # TODO: This should be an error, c = 0
+      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12",
+           expect_error=re.escape(
+             r"Dimension variable 'c' must have integer value >= 1. "
+             r"Found 0 when solving c + 4 == args[0].shape[1]")),
       dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
       dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
-           expect_error=(
+           expect_error=re.escape(
              r"Dimension variable 'a' must have integer value >= 1. "
-             r"Non-zero remainder 2 for factor 5 when solving 5\*a == args\[0\].shape\[2\]")),
+             r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")),
       # dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"),  # TODO: there should be an error 5*a != c == 12
-      # dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"),  # TODO: this should be a dynamic error
+      dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
+           expect_error=re.escape(
+             r"Found inconsistency 12 != 4 when solving "
+             r"a == args[0].shape[2]")),
       dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a",
            expect_error=r"Rank mismatch for args\[0\]"),
       dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d",
@@ -283,7 +291,8 @@ class JaxExportTest(jtu.JaxTestCase):
         jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
     self.assertEqual(outer_exp.module_uses_dim_vars,
                      (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
-    if not outer_exp.module_uses_dim_vars:
+    # TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
+    if not outer_exp.module_uses_dim_vars or xc.mlir_api_version >= 50:
       res = jax_export.call_exported(outer_exp)(outer_x)
       self.assertAllClose(2. * inner(outer_x), res)
     else: