From c46fa27b521b54687fdf86886c5769f0b0220c7c Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 6 Jun 2023 13:26:35 -0700 Subject: [PATCH] [shape_poly] Enable calling from JAX modules that use shape polymorphism. If a JAX function uses shape polymorphism the resulting StableHLO contains dynamic shapes and it is not directly compileable. However, in such modules the dynamic shapes depend only on the input shapes, and in JAX jit the input shapes are static. So, we run a shape refinement pass over the module to resolve the dynamic shapes prior to compilation. PiperOrigin-RevId: 538275268 --- jax/_src/interpreters/pxla.py | 9 +++++++- jax/experimental/jax2tf/jax_export.py | 23 ++++++++++++++++++- .../jax2tf/tests/jax_export_test.py | 23 +++++++++++++------ 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 69a6f40ee..67ebd2408 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -186,6 +186,12 @@ def batched_device_put(aval: core.ShapedArray, return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore +def refine_shape_polymorphism(module: ir.Module) -> ir.Module: + # In order to avoid depending on jax2tf/jax_export.py we will monkey patch + # this from jax_export to refine the polymorphic shapes in the module. + raise NotImplementedError("Compiling modules with shape polymorphism") + + # NOTE(skye): we could refactor to generate _multi_slice parameters directly # from the input ShardingSpec, rather than the indices. However, this would # require duplicating the ordering logic of spec_to_indices, which is more @@ -2624,7 +2630,8 @@ class UnloadedMeshExecutable: shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None, compiler_options=None ) -> MeshExecutable: - del shape_poly_state + if shape_poly_state is not None and shape_poly_state.uses_dim_vars: + hlo = refine_shape_polymorphism(hlo) compiler_options_keys = tuple( compiler_options.keys()) if compiler_options is not None else None compiler_options_values = tuple( diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 7ceee4dff..fb4ea2583 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -24,6 +24,7 @@ from absl import logging import jax from jax import sharding +from jax.lib import xla_client as xc from jax._src import core from jax._src import dispatch @@ -37,6 +38,7 @@ from jax._src.lib.mlir.dialects import stablehlo from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src.lib import xla_extension from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb @@ -815,7 +817,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, f"The exported function '{exported.fun_name}' was lowered for " f"platform '{exported.lowering_platform}' but it is used " f"on '{platform}'.") - if any(not core.is_constant_shape(a.shape) for a in exported.in_avals): + if exported.module_uses_dim_vars: ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module) @@ -852,3 +854,22 @@ for _p in ("cpu", "tpu", "cuda", "rocm"): mlir.register_lowering(call_exported_p, functools.partial(_call_exported_lowering, platform=_p), platform=_p) + + +def _refine_polymorphic_shapes(module: ir.Module) -> ir.Module: + """Refine the polymorphic shapes inside a module. + + Given a module with static input shapes, but using dynamic shapes due to + shape polymorphism, run shape refinement to resolve all the dynamic shapes. + """ + if xc.mlir_api_version < 50: + raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") + + refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( + mlir.module_to_bytecode(module) + ) + context = mlir.make_ir_context() + with context: + return ir.Module.parse(refined_module_str) + +pxla.refine_shape_polymorphism = _refine_polymorphic_shapes diff --git a/jax/experimental/jax2tf/tests/jax_export_test.py b/jax/experimental/jax2tf/tests/jax_export_test.py index 384e606fa..976bbde15 100644 --- a/jax/experimental/jax2tf/tests/jax_export_test.py +++ b/jax/experimental/jax2tf/tests/jax_export_test.py @@ -14,6 +14,7 @@ import contextlib import logging import math +import re from typing import List import unittest @@ -28,6 +29,7 @@ try: except ImportError: jax2tf = None # type: ignore +from jax.lib import xla_client as xc from jax._src import core from jax._src import test_util as jtu from jax._src import xla_bridge as xb @@ -234,19 +236,25 @@ class JaxExportTest(jtu.JaxTestCase): dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"), dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c"), dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c", - expect_error=( + expect_error=re.escape( r"Dimension variable 'b' must have integer value >= 1. " - r"Found 0 when solving a \+ b == args\[0\].shape\[2\]")), + r"Found 0 when solving a + b == args[0].shape[2]")), dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12", expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"), - dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"), # TODO: This should be an error, c = 0 + dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12", + expect_error=re.escape( + r"Dimension variable 'c' must have integer value >= 1. " + r"Found 0 when solving c + 4 == args[0].shape[1]")), dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"), dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12", - expect_error=( + expect_error=re.escape( r"Dimension variable 'a' must have integer value >= 1. " - r"Non-zero remainder 2 for factor 5 when solving 5\*a == args\[0\].shape\[2\]")), + r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")), # dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"), # TODO: there should be an error 5*a != c == 12 - # dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"), # TODO: this should be a dynamic error + dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a", + expect_error=re.escape( + r"Found inconsistency 12 != 4 when solving " + r"a == args[0].shape[2]")), dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a", expect_error=r"Rank mismatch for args\[0\]"), dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d", @@ -283,7 +291,8 @@ class JaxExportTest(jtu.JaxTestCase): jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec)) self.assertEqual(outer_exp.module_uses_dim_vars, (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12")) - if not outer_exp.module_uses_dim_vars: + # TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version + if not outer_exp.module_uses_dim_vars or xc.mlir_api_version >= 50: res = jax_export.call_exported(outer_exp)(outer_x) self.assertAllClose(2. * inner(outer_x), res) else: