[shape_poly] Enable calling exported modules that use shape polymorphism from JAX.

If a JAX function uses shape polymorphism, the resulting StableHLO contains
dynamic shapes and is not directly compilable. However, in such modules the
dynamic shapes depend only on the input shapes, and under jax.jit the input
shapes are static. We therefore run a shape-refinement pass over the module
to resolve the dynamic shapes prior to compilation.

PiperOrigin-RevId: 538275268
George Necula 2023-06-06 13:26:35 -07:00 committed by jax authors
parent 47b8e55451
commit c46fa27b52
3 changed files with 46 additions and 9 deletions

jax/_src/interpreters/pxla.py

@@ -186,6 +186,12 @@ def batched_device_put(aval: core.ShapedArray,
   return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
 
+def refine_shape_polymorphism(module: ir.Module) -> ir.Module:
+  # In order to avoid depending on jax2tf/jax_export.py we will monkey patch
+  # this from jax_export to refine the polymorphic shapes in the module.
+  raise NotImplementedError("Compiling modules with shape polymorphism")
+
+
 # NOTE(skye): we could refactor to generate _multi_slice parameters directly
 # from the input ShardingSpec, rather than the indices. However, this would
 # require duplicating the ordering logic of spec_to_indices, which is more
@@ -2624,7 +2630,8 @@ class UnloadedMeshExecutable:
       shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None,
       compiler_options=None
   ) -> MeshExecutable:
-    del shape_poly_state
+    if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
+      hlo = refine_shape_polymorphism(hlo)
     compiler_options_keys = tuple(
         compiler_options.keys()) if compiler_options is not None else None
     compiler_options_values = tuple(
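
The refine_shape_polymorphism stub above is intentionally unimplemented:
pxla cannot import jax2tf/jax_export.py without creating an import cycle, so
jax_export overwrites the attribute when it is imported (see the assignment
at the end of the next file). A minimal sketch of this dependency-inversion
pattern, with hypothetical module names:

    # low_level.py: define a seam that raises until a higher layer patches it.
    def refine(module):
      raise NotImplementedError("refinement lives in the higher-level package")

    # high_level.py: replace the seam at import time, avoiding the circular import.
    import low_level

    def _real_refine(module):
      return module  # placeholder for actual refinement logic

    low_level.refine = _real_refine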

jax/experimental/jax2tf/jax_export.py

@@ -24,6 +24,7 @@ from absl import logging
 
 import jax
 from jax import sharding
+from jax.lib import xla_client as xc
 from jax._src import core
 from jax._src import dispatch
@@ -37,6 +38,7 @@ from jax._src.lib.mlir.dialects import stablehlo
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir.dialects import func as func_dialect
+from jax._src.lib import xla_extension
 from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge as xb
@@ -815,7 +817,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
         f"The exported function '{exported.fun_name}' was lowered for "
         f"platform '{exported.lowering_platform}' but it is used "
         f"on '{platform}'.")
-  if any(not core.is_constant_shape(a.shape) for a in exported.in_avals):
+  if exported.module_uses_dim_vars:
     ctx.module_context.shape_poly_state.uses_dim_vars = True
   submodule = ir.Module.parse(exported.mlir_module)
@@ -852,3 +854,22 @@ for _p in ("cpu", "tpu", "cuda", "rocm"):
   mlir.register_lowering(call_exported_p,
                          functools.partial(_call_exported_lowering, platform=_p),
                          platform=_p)
+
+
+def _refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
+  """Refine the polymorphic shapes inside a module.
+
+  Given a module with static input shapes, but using dynamic shapes due to
+  shape polymorphism, run shape refinement to resolve all the dynamic shapes.
+  """
+  if xc.mlir_api_version < 50:
+    raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
+  refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
+      mlir.module_to_bytecode(module)
+  )
+  context = mlir.make_ir_context()
+  with context:
+    return ir.Module.parse(refined_module_str)
+
+
+pxla.refine_shape_polymorphism = _refine_polymorphic_shapes
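
For readability, here is the same helper restated outside the diff with its
imports spelled out (module paths assumed from this diff;
xla_extension.mlir.refine_polymorphic_shapes requires jaxlib >= 0.4.12, i.e.
mlir_api_version >= 50). The module is round-tripped through StableHLO
bytecode because the refinement pass runs in the C++ extension:

    from jax._src.interpreters import mlir  # path assumed
    from jax._src.lib import xla_extension
    from jax._src.lib.mlir import ir

    def refine(module: ir.Module) -> ir.Module:
      # Serialize the module to bytecode for the C++ refinement pass...
      bytecode = mlir.module_to_bytecode(module)
      # ...resolve the dynamic shapes given the (static) input shapes...
      refined = xla_extension.mlir.refine_polymorphic_shapes(bytecode)
      # ...and re-parse the refined module in a fresh MLIR context.
      with mlir.make_ir_context():
        return ir.Module.parse(refined)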

jax/experimental/jax2tf/tests/jax_export_test.py

@@ -14,6 +14,7 @@
 import contextlib
 import logging
 import math
+import re
 from typing import List
 import unittest
@@ -28,6 +29,7 @@ try:
 except ImportError:
   jax2tf = None  # type: ignore
 
+from jax.lib import xla_client as xc
 from jax._src import core
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
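
The test changes below replace hand-escaped regular expressions with
re.escape, so expected error messages can be written exactly as emitted. A
quick stdlib illustration (standard Python behavior, not from this commit):

    import re

    msg = r"Found 0 when solving a + b == args[0].shape[2]"
    # re.escape escapes regex metacharacters such as '+', '[' and ']', so the
    # literal message can be used directly as a pattern.
    assert re.search(re.escape(msg), "Error: " + msg)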
@@ -234,19 +236,25 @@ class JaxExportTest(jtu.JaxTestCase):
       dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
       dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c"),
       dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
-           expect_error=(
+           expect_error=re.escape(
               r"Dimension variable 'b' must have integer value >= 1. "
-              r"Found 0 when solving a \+ b == args\[0\].shape\[2\]")),
+              r"Found 0 when solving a + b == args[0].shape[2]")),
       dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
            expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"),
-      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"),  # TODO: This should be an error, c = 0
+      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12",
+           expect_error=re.escape(
+              r"Dimension variable 'c' must have integer value >= 1. "
+              r"Found 0 when solving c + 4 == args[0].shape[1]")),
       dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
       dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
-           expect_error=(
+           expect_error=re.escape(
               r"Dimension variable 'a' must have integer value >= 1. "
-              r"Non-zero remainder 2 for factor 5 when solving 5\*a == args\[0\].shape\[2\]")),
+              r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")),
       # dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"),  # TODO: there should be an error 5*a != c == 12
-      # dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"),  # TODO: this should be a dynamic error
+      dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
+           expect_error=re.escape(
+              r"Found inconsistency 12 != 4 when solving "
+              r"a == args[0].shape[2]")),
       dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a",
            expect_error=r"Rank mismatch for args\[0\]"),
       dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d",
@@ -283,7 +291,8 @@ class JaxExportTest(jtu.JaxTestCase):
         jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
     self.assertEqual(outer_exp.module_uses_dim_vars,
                      (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
-    if not outer_exp.module_uses_dim_vars:
+    # TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
+    if not outer_exp.module_uses_dim_vars or xc.mlir_api_version >= 50:
       res = jax_export.call_exported(outer_exp)(outer_x)
       self.assertAllClose(2. * inner(outer_x), res)
     else: