Better Mosaic lowering for dynamic shapes: extend an interpreter to shape_poly dimension expressions (dimexprs) and lower them alongside the graph when we are in a dynamic export regime.

PiperOrigin-RevId: 738171437
jax authors 2025-03-18 15:50:27 -07:00
parent 0fb59747f0
commit 01a110c4c9
3 changed files with 129 additions and 12 deletions


@@ -35,6 +35,7 @@ from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.export._export import export
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
@@ -1165,14 +1166,16 @@ jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule
def lower_as_mlir(
f, *args, dynamic_shapes=False, device=None, **kwargs
f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs
) -> mlir.ir.Module:
with pallas_export_experimental(dynamic_shapes):
lowered = jax.jit(f, device=device).lower(*args, **kwargs)
stablehlo = lowered.compiler_ir(dialect="stablehlo")
f = jax.jit(f, device=device, static_argnames=static_argnames)
exported = export(f, platforms=["tpu"])(*args, **kwargs)
stablehlo = exported.mlir_module()
return stablehlo # type: ignore[return-value]
_out_shape_to_aval_mapping: dict[
type[Any], Callable[[Any], jax_core.AbstractValue]
] = {}
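For context, here is a minimal sketch of the export-based path that lower_as_mlir now goes through, using the public jax.export API with a symbolic shape (the function and names below are illustrative, not part of this change):

import jax
import jax.numpy as jnp
from jax import export

def double(x):  # stand-in for a Pallas-calling function
  return x * 2

# "b" is a dimension variable, so the exported StableHLO keeps a dynamic
# (?) leading dimension instead of a concrete size.
b, = export.symbolic_shape("b")
x_spec = jax.ShapeDtypeStruct((b, 128), jnp.float32)

exported = export.export(jax.jit(double), platforms=["tpu"])(x_spec)
print(exported.mlir_module())  # StableHLO text with tensor<?x128xf32> arguments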


@@ -40,6 +40,7 @@ from jax._src import source_info_util
from jax._src import state
from jax._src import traceback_util
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
from jax._src.export._export import export
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
@@ -89,6 +90,11 @@ BOOL_MEMREF_TYPE = np.dtype('int32')
# The value interpreted as a dynamic dimension by MLIR.
MLIR_DYNAMIC = -9223372036854775808
# TODO(mvoz): Find a way to make this a contract we can share with the
# export specialization step in XLA export.
DIM_UPPER_BOUND = np.iinfo(np.int32).max
DIM_LOWER_BOUND = -128
partial = functools.partial
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin
@@ -102,17 +108,49 @@ class MeshContext:
# Note - On Export Placeholders
#
# Mosaic uses vector IR, which does not have a concept of dynamic
# dimensions. We need to come up with a way to represent dynamic dimensions in
# vector IR, and so we use placeholders, which are later replaced during
# specialization.
# Since the vector dialect used by Mosaic does not support dynamic shapes,
# we replace all top-level symbolic dimensions with placeholder
# constants (between max(int32) - 128 and max(int32)) and we keep a
# mapping from the placeholder constants to SHLO functions that encode
# the symbolic dimension expression, as a function of the dimension
# variables.
#
# The calling convention of the produced MLIR module is the same as a
# regular Mosaic module, except we add two new attributes to the custom call
# *per* intermediary placeholder dimension.
#
# The attributes are:
#
# tpu.dynamic_dimension_mapping_arg_name_<placeholder>
# tpu.dynamic_dimension_mapping_module_<placeholder>
#
# The first attribute is a comma-separated list of the dimension variables
# that are used to compute the symbolic dimension expression for the
# placeholder. The second attribute is the MLIR module that contains the
# SHLO functions that compute the symbolic dimension expression for the
# placeholder.
class LoweringDynamicShapeEnv:
dim_expr_to_placeholder: dict[Any, ir.Value] = {}
dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {}
placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {}
def to_placeholder(self, dim_expr: Any) -> ir.Value:
if jax_core.is_constant_dim(dim_expr):
# avoid ints, these are not dynamic
return dim_expr
if dim_expr not in self.dim_expr_to_placeholder:
next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder)
next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
if next_val < DIM_LOWER_BOUND:
# In practice, even with the largest of programs, we rarely see
# anything even close to this limit. It is arbitrary, and can be safely
# increased if needed.
raise ValueError(
"Too many dynamic shapes in the input. Mosaic currently only"
" supports up to 128 dynamic dimension values."
)
self.dim_expr_to_placeholder[dim_expr] = next_val
# Reverse mapping - this is consumed to generate a table that is either
# input<>placeholder or intermediary computation<>placeholder.
self.placeholder_to_dim_expr[next_val] = dim_expr
return self.dim_expr_to_placeholder[dim_expr]
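To illustrate the allocation scheme above with a standalone sketch (not the actual class): distinct symbolic expressions are assigned consecutive values counting down from max(int32), and repeated expressions reuse their placeholder.

import numpy as np

placeholders: dict[str, int] = {}

def to_placeholder(expr: str) -> int:
  # Allocate a new placeholder for an unseen expression, otherwise reuse it.
  if expr not in placeholders:
    placeholders[expr] = np.iinfo(np.int32).max - len(placeholders)
  return placeholders[expr]

to_placeholder("b")        # 2147483647
to_placeholder("b * 128")  # 2147483646
to_placeholder("b")        # 2147483647 again; the mapping is stable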
@@ -622,6 +660,7 @@ def lower_jaxpr_to_module(
"Pallas TPU requires a libTPU version that's at most a month old"
)
debug_info = jaxpr.debug_info
_mosaic_lowering_dynamic_shape_env = None
if dynamic_shape_replacement_enabled:
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
@@ -663,10 +702,12 @@ def lower_jaxpr_to_module(
for_verification=for_verification,
forward_compatible=lowering_context.is_forward_compat(),
dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled,
)
m.body.append(func_op)
sym_tab.insert(func_op)
window_params = []
static_grid = None
grid = mosaic_grid_mapping.grid
if grid:
for i, bm in enumerate(grid_mapping.block_mappings):
@@ -738,7 +779,6 @@
]
static_grid = dynamic_shape_replacement_fn(static_grid)
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
@@ -746,6 +786,60 @@
func_op.attributes["dimension_semantics"] = (
mosaic_grid_mapping.get_dimension_semantics()
)
if dynamic_shape_replacement_enabled:
if _mosaic_lowering_dynamic_shape_env is None:
raise ValueError(
"Dynamic shape env is None, invariant violated. Unreachable?"
)
# Now we can use jax to compute the dynamic shape graph
if static_grid is not None:
grid_vars = [
_mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.get(g, g)
for g in static_grid
]
else:
grid_vars = []
invars = [invar.aval for invar in jaxpr.invars]
# Faux shape for grid, just to get the avals
invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32))
args_dimvars = shape_poly.all_dim_vars(invars)
# This is dimexpr var -> placeholder value for when we jit the dim expr
env: dict[str, int] = {}
for aval in args_dimvars:
env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval)
for (
placeholder,
dim_expr,
) in _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.items():
top_level_names = list(env.keys())
if dim_expr not in top_level_names:
jitted_eval = jax.jit(
jax_core.evaluate_shape,
static_argnames=(
"shape",
"dim_vars",
),
keep_unused=True,
)
stablehlo = export(
jitted_eval, platforms=[str(jax.devices()[0].platform)]
)(
(dim_expr,), tuple(args_dimvars), *(env[v] for v in args_dimvars)
).mlir_module()
arg_name = args_dimvars
# See Note - On Export Placeholders for more details.
m.operation.attributes[
"tpu.dynamic_dimension_mapping_module_" + str(placeholder)
] = ir.StringAttr.get(str(stablehlo))
arg_name_str = ",".join(arg_name)
m.operation.attributes[
"tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder)
] = ir.StringAttr.get(arg_name_str)
return m, mosaic_grid_mapping.get_extra_args()
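Concretely, if the placeholder 2147483646 stands in for the expression b * 128, the two attributes added above pair the dimension-variable names ("b") with a small exported module that evaluates the expression at concrete values of those variables. An illustrative sketch of what that module computes (not the emitted IR itself):

# tpu.dynamic_dimension_mapping_arg_name_2147483646 = "b"
# tpu.dynamic_dimension_mapping_module_2147483646 = a module equivalent to:
def placeholder_2147483646(b: int) -> int:
  # The symbolic expression "b * 128" evaluated at a concrete value of b.
  return b * 128

# At specialization time, the export step can substitute
# placeholder_2147483646(b) for every occurrence of the constant 2147483646.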
@@ -828,6 +922,7 @@ def lower_jaxpr_to_func(
dynamic_shape_replacement_fn: (
Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
) = None,
dynamic_shape_replacement_enabled: bool = False,
) -> func.FuncOp:
num_grid = len(mosaic_grid_mapping.grid_types)
num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
@@ -874,6 +969,12 @@
)
body_func.__name__ = name
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
if dynamic_shape_replacement_enabled:
# Skip verification for dynamic shape replacement - you can potentially
# produce IR such as add(x[placeholder_0, placeholder_1], y[128, 128]),
# which is not valid, but we don't care since we'll run the verifier again
# after the dynamic shape replacement pass.
return body.func_op
try:
body.func_op.verify()
except ir.MLIRError as e:
@@ -3851,3 +3952,15 @@ def _platform_index_lowering(
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim):
placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0]
return ir_constant(
placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
)
import jax._src.export.shape_poly as shape_poly
lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering
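For context, dim_as_value_p is the primitive that shape_poly introduces when a symbolic dimension is used as a run-time value; under the rule above it lowers to the i32 placeholder constant rather than the statically unknown real size. A hypothetical kernel where this might come up:

def kernel(x_ref, o_ref):
  # With a symbolic leading dimension, using the dimension as a value
  # (rather than as a shape) may go through dim_as_value_p and thus lower
  # to the placeholder constant under the rule above.
  o_ref[...] = x_ref[...] + x_ref.shape[0]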


@@ -2501,7 +2501,8 @@ class SymbolicPallasTest(PallasBaseTest):
)
assert exported_module is not None
self.assertIn(
"tensor<?x?xf32>, %arg6: tensor<?x?xf32>, %arg7: tensor<?x?xf32>",
"%arg0: tensor<?x?xf32> loc(unknown), %arg1: tensor<?x?xf32>"
" loc(unknown), %arg2: tensor<?x?xf32>",
str(exported_module),
)
x = jax.ShapeDtypeStruct((128, 1024), jax.numpy.float32)
@@ -2512,7 +2513,7 @@ class SymbolicPallasTest(PallasBaseTest):
)
assert exported_module is not None
self.assertIn(
"@sym_matmul(%arg0: tensor<128x1024xf32>, %arg1: tensor<1024x512xf32>",
"call @sym_matmul(%arg0, %arg1)",
str(exported_module),
)