mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[shape_poly] Enable calling from JAX modules that use shape polymorphism.
If a JAX function uses shape polymorphism the resulting StableHLO contains dynamic shapes and it is not directly compileable. However, in such modules the dynamic shapes depend only on the input shapes, and in JAX jit the input shapes are static. So, we run a shape refinement pass over the module to resolve the dynamic shapes prior to compilation. PiperOrigin-RevId: 538275268
This commit is contained in:
parent
47b8e55451
commit
c46fa27b52
@ -186,6 +186,12 @@ def batched_device_put(aval: core.ShapedArray,
|
||||
return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore
|
||||
|
||||
|
||||
def refine_shape_polymorphism(module: ir.Module) -> ir.Module:
|
||||
# In order to avoid depending on jax2tf/jax_export.py we will monkey patch
|
||||
# this from jax_export to refine the polymorphic shapes in the module.
|
||||
raise NotImplementedError("Compiling modules with shape polymorphism")
|
||||
|
||||
|
||||
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
|
||||
# from the input ShardingSpec, rather than the indices. However, this would
|
||||
# require duplicating the ordering logic of spec_to_indices, which is more
|
||||
@ -2624,7 +2630,8 @@ class UnloadedMeshExecutable:
|
||||
shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None,
|
||||
compiler_options=None
|
||||
) -> MeshExecutable:
|
||||
del shape_poly_state
|
||||
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
|
||||
hlo = refine_shape_polymorphism(hlo)
|
||||
compiler_options_keys = tuple(
|
||||
compiler_options.keys()) if compiler_options is not None else None
|
||||
compiler_options_values = tuple(
|
||||
|
@ -24,6 +24,7 @@ from absl import logging
|
||||
|
||||
import jax
|
||||
from jax import sharding
|
||||
from jax.lib import xla_client as xc
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -37,6 +38,7 @@ from jax._src.lib.mlir.dialects import stablehlo
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
@ -815,7 +817,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
f"The exported function '{exported.fun_name}' was lowered for "
|
||||
f"platform '{exported.lowering_platform}' but it is used "
|
||||
f"on '{platform}'.")
|
||||
if any(not core.is_constant_shape(a.shape) for a in exported.in_avals):
|
||||
if exported.module_uses_dim_vars:
|
||||
ctx.module_context.shape_poly_state.uses_dim_vars = True
|
||||
|
||||
submodule = ir.Module.parse(exported.mlir_module)
|
||||
@ -852,3 +854,22 @@ for _p in ("cpu", "tpu", "cuda", "rocm"):
|
||||
mlir.register_lowering(call_exported_p,
|
||||
functools.partial(_call_exported_lowering, platform=_p),
|
||||
platform=_p)
|
||||
|
||||
|
||||
def _refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
|
||||
"""Refine the polymorphic shapes inside a module.
|
||||
|
||||
Given a module with static input shapes, but using dynamic shapes due to
|
||||
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
|
||||
"""
|
||||
if xc.mlir_api_version < 50:
|
||||
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
|
||||
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
mlir.module_to_bytecode(module)
|
||||
)
|
||||
context = mlir.make_ir_context()
|
||||
with context:
|
||||
return ir.Module.parse(refined_module_str)
|
||||
|
||||
pxla.refine_shape_polymorphism = _refine_polymorphic_shapes
|
||||
|
@ -14,6 +14,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from typing import List
|
||||
import unittest
|
||||
|
||||
@ -28,6 +29,7 @@ try:
|
||||
except ImportError:
|
||||
jax2tf = None # type: ignore
|
||||
|
||||
from jax.lib import xla_client as xc
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
@ -234,19 +236,25 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
|
||||
expect_error=(
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'b' must have integer value >= 1. "
|
||||
r"Found 0 when solving a \+ b == args\[0\].shape\[2\]")),
|
||||
r"Found 0 when solving a + b == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
|
||||
expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"), # TODO: This should be an error, c = 0
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12",
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'c' must have integer value >= 1. "
|
||||
r"Found 0 when solving c + 4 == args[0].shape[1]")),
|
||||
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
|
||||
expect_error=(
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'a' must have integer value >= 1. "
|
||||
r"Non-zero remainder 2 for factor 5 when solving 5\*a == args\[0\].shape\[2\]")),
|
||||
r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")),
|
||||
# dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"), # TODO: there should be an error 5*a != c == 12
|
||||
# dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"), # TODO: this should be a dynamic error
|
||||
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
|
||||
expect_error=re.escape(
|
||||
r"Found inconsistency 12 != 4 when solving "
|
||||
r"a == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a",
|
||||
expect_error=r"Rank mismatch for args\[0\]"),
|
||||
dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d",
|
||||
@ -283,7 +291,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
|
||||
self.assertEqual(outer_exp.module_uses_dim_vars,
|
||||
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
|
||||
if not outer_exp.module_uses_dim_vars:
|
||||
# TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
|
||||
if not outer_exp.module_uses_dim_vars or xc.mlir_api_version >= 50:
|
||||
res = jax_export.call_exported(outer_exp)(outer_x)
|
||||
self.assertAllClose(2. * inner(outer_x), res)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user