From 01a110c4c9b4d19f897f073d65d84a319387ccb6 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 18 Mar 2025 15:50:27 -0700
Subject: [PATCH] Better Mosaic lowering for dynamic shapes: extend an
 interpreter to shape_poly dim exprs and lower them alongside the graph when
 we are in a dynamic export regime.

PiperOrigin-RevId: 738171437
---
 jax/_src/pallas/core.py            |   9 +-
 jax/_src/pallas/mosaic/lowering.py | 127 +++++++++++++++++++++++++++--
 tests/pallas/pallas_test.py        |   5 +-
 3 files changed, 129 insertions(+), 12 deletions(-)

diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 5342a6946..206c2a73f 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -35,6 +35,7 @@ from jax._src import linear_util as lu
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util
+from jax._src.export._export import export
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.state import discharge as state_discharge
@@ -1165,14 +1166,16 @@ jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule


 def lower_as_mlir(
-    f, *args, dynamic_shapes=False, device=None, **kwargs
+    f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs
 ) -> mlir.ir.Module:
   with pallas_export_experimental(dynamic_shapes):
-    lowered = jax.jit(f, device=device).lower(*args, **kwargs)
-    stablehlo = lowered.compiler_ir(dialect="stablehlo")
+    f = jax.jit(f, device=device, static_argnames=static_argnames)
+    exported = export(f, platforms=["tpu"])(*args, **kwargs)
+    stablehlo = exported.mlir_module()

   return stablehlo  # type: ignore[return-value]

+
 _out_shape_to_aval_mapping: dict[
     type[Any], Callable[[Any], jax_core.AbstractValue]
 ] = {}
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 4efb2b276..10b9de748 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -40,6 +40,7 @@ from jax._src import source_info_util
 from jax._src import state
 from jax._src import traceback_util
 from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
+from jax._src.export._export import export
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import lax as lax_internal
@@ -89,6 +90,11 @@ BOOL_MEMREF_TYPE = np.dtype('int32')
 # The value interpreted as a dynamic dimension by MLIR.
 MLIR_DYNAMIC = -9223372036854775808

+# TODO(mvoz): Find a way to make this a contract we can share with the
+# export specialization step in XLA export.
+DIM_UPPER_BOUND = np.iinfo(np.int32).max
+DIM_LOWER_BOUND = np.iinfo(np.int32).max - 128
+
 partial = functools.partial
 map, unsafe_map = safe_map, map  # pylint: disable=redefined-builtin
 zip, unsafe_zip = safe_zip, zip  # pylint: disable=redefined-builtin
@@ -102,17 +108,49 @@ class MeshContext:

 # Note - On Export Placeholders
 #
-# Mosaic uses vector IR, which does not have a concept of dynamic
-# dimensions. We need to come up with a way to represent dynamic dimensions in
-# vector IR, and so we use placeholders, which are later replaced during
-# specialization.
+# Since the vector dialect used by Mosaic does not support dynamic shapes,
+# we replace all top-level symbolic dimensions with placeholder
+# constants (between max(int32) - 128 and max(int32)) and keep a
+# mapping from the placeholder constants to SHLO functions that encode
+# the symbolic dimension expression as a function of the dimension
+# variables.
+#
+# The calling convention of the produced MLIR module is the same as for a
+# regular Mosaic module, except that we add two new attributes to the custom
+# call *per* intermediary placeholder dimension.
+#
+# The attributes are:
+#
+#   tpu.dynamic_dimension_mapping_arg_name_<placeholder>
+#   tpu.dynamic_dimension_mapping_module_<placeholder>
+#
+# The first attribute is a comma-separated list of the dimension variables
+# that are used to compute the symbolic dimension expression for the
+# placeholder. The second attribute is the MLIR module that contains the
+# SHLO functions that compute the symbolic dimension expression for the
+# placeholder.
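+#
+# For illustration only (hypothetical dim var and placeholder values): a
+# kernel whose shapes use the dimension variable `b` and the intermediary
+# expression `b + 128` might map them to the placeholder constants
+# 2147483647 and 2147483646. For the intermediary expression, the lowered
+# module would then carry:
+#
+#   tpu.dynamic_dimension_mapping_arg_name_2147483646 = "b"
+#   tpu.dynamic_dimension_mapping_module_2147483646 = <serialized SHLO
+#       module whose main function computes b + 128 from b>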


 class LoweringDynamicShapeEnv:
-  dim_expr_to_placeholder: dict[Any, ir.Value] = {}
+  dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {}
+  placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {}

   def to_placeholder(self, dim_expr: Any) -> ir.Value:
+    if jax_core.is_constant_dim(dim_expr):
+      # Avoid ints; constant dims are not dynamic.
+      return dim_expr
     if dim_expr not in self.dim_expr_to_placeholder:
-      next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder)
+      next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
+      if next_val < DIM_LOWER_BOUND:
+        # In practice, even with the largest of programs, we rarely see
+        # anything even close to this limit. It is arbitrary, and can be
+        # safely increased if needed.
+        raise ValueError(
+            "Too many dynamic shapes in the input. Mosaic currently only"
+            " supports up to 128 dynamic dimension values."
+        )
       self.dim_expr_to_placeholder[dim_expr] = next_val
+      # Reverse mapping - this is consumed to generate a table that is either
+      # input<>placeholder or intermediary computation<>placeholder.
+      self.placeholder_to_dim_expr[next_val] = dim_expr
     return self.dim_expr_to_placeholder[dim_expr]
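+
+# Continuing the hypothetical example from the note above: if lowering first
+# encounters the dimension variable `b` and then the derived expression
+# `b + 128`, to_placeholder assigns DIM_UPPER_BOUND (2147483647) to `b` and
+# DIM_UPPER_BOUND - 1 (2147483646) to `b + 128`. placeholder_to_dim_expr
+# inverts the mapping so the later specialization step can recover the dim
+# expr behind each placeholder constant it finds in the lowered IR.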
@@ -622,6 +660,7 @@ def lower_jaxpr_to_module(
         "Pallas TPU requires a libTPU version that's at most a month old"
     )
   debug_info = jaxpr.debug_info
+  _mosaic_lowering_dynamic_shape_env = None
   if dynamic_shape_replacement_enabled:
     _mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()

@@ -663,10 +702,12 @@ def lower_jaxpr_to_module(
       for_verification=for_verification,
       forward_compatible=lowering_context.is_forward_compat(),
       dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
+      dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled,
   )
   m.body.append(func_op)
   sym_tab.insert(func_op)
   window_params = []
+  static_grid = None
   grid = mosaic_grid_mapping.grid
   if grid:
     for i, bm in enumerate(grid_mapping.block_mappings):
@@ -738,7 +779,6 @@ def lower_jaxpr_to_module(
       ]
       static_grid = dynamic_shape_replacement_fn(static_grid)
     func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
-
   func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
       ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
   func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
@@ -746,6 +786,60 @@ def lower_jaxpr_to_module(
   func_op.attributes["dimension_semantics"] = (
       mosaic_grid_mapping.get_dimension_semantics()
   )
+  if dynamic_shape_replacement_enabled:
+    if _mosaic_lowering_dynamic_shape_env is None:
+      raise ValueError(
+          "Dynamic shape env is None; invariant violated (unreachable)."
+      )
+
+    # Now we can use JAX to compute the dynamic shape graph.
+
+    if static_grid is not None:
+      grid_vars = [
+          _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.get(g, g)
+          for g in static_grid
+      ]
+    else:
+      grid_vars = []
+
+    invars = [invar.aval for invar in jaxpr.invars]
+    # Faux shape for the grid, just to get the avals.
+    invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32))
+    args_dimvars = shape_poly.all_dim_vars(invars)

+    # Dim expr variable name -> placeholder value, for when we jit the dim expr.
+    env: dict[str, int] = {}
+    for aval in args_dimvars:
+      env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval)
+
+    for (
+        placeholder,
+        dim_expr,
+    ) in _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.items():
+      top_level_names = list(env.keys())
+      if dim_expr not in top_level_names:
+        jitted_eval = jax.jit(
+            jax_core.evaluate_shape,
+            static_argnames=(
+                "shape",
+                "dim_vars",
+            ),
+            keep_unused=True,
+        )
+        stablehlo = export(
+            jitted_eval, platforms=[str(jax.devices()[0].platform)]
+        )(
+            (dim_expr,), tuple(args_dimvars), *(env[v] for v in args_dimvars)
+        ).mlir_module()
+        arg_name = args_dimvars
+        # See Note - On Export Placeholders for more details.
+        m.operation.attributes[
+            "tpu.dynamic_dimension_mapping_module_" + str(placeholder)
+        ] = ir.StringAttr.get(str(stablehlo))
+        arg_name_str = ",".join(arg_name)
+        m.operation.attributes[
+            "tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder)
+        ] = ir.StringAttr.get(arg_name_str)
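+        # Continuing the hypothetical example from the note above: for
+        # placeholder 2147483646 (`b + 128`), main takes every dimension
+        # variable as an i32 scalar argument (hence keep_unused=True),
+        # returns b + 128, and the arg-name attribute is the string "b".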

   return m, mosaic_grid_mapping.get_extra_args()


@@ -828,6 +922,7 @@ def lower_jaxpr_to_func(
     dynamic_shape_replacement_fn: (
         Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
     ) = None,
+    dynamic_shape_replacement_enabled: bool = False,
 ) -> func.FuncOp:
   num_grid = len(mosaic_grid_mapping.grid_types)
   num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
@@ -874,6 +969,12 @@ def lower_jaxpr_to_func(
       )
   body_func.__name__ = name
   body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
+  if dynamic_shape_replacement_enabled:
+    # Skip verification under dynamic shape replacement: the lowering can
+    # temporarily produce invalid IR, e.g. add(x[placeholder_0, placeholder_1],
+    # y[128, 128]). That is fine, because the verifier runs again after the
+    # dynamic shape replacement pass.
+    return body.func_op
   try:
     body.func_op.verify()
   except ir.MLIRError as e:
@@ -3851,3 +3952,15 @@ def _platform_index_lowering(


 lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
+
+
+def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim):
+  placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0]
+  return ir_constant(
+      placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
+  )
+
+
+import jax._src.export.shape_poly as shape_poly
+
+lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 78c34d404..745c30ba9 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -2501,7 +2501,8 @@ class SymbolicPallasTest(PallasBaseTest):
     )
     assert exported_module is not None
     self.assertIn(
-        "tensor<i32>, %arg6: tensor<i32>, %arg7: tensor<i32>",
+        "%arg0: tensor<i32> loc(unknown), %arg1: tensor<i32>"
+        " loc(unknown), %arg2: tensor<i32>",
         str(exported_module),
     )
     x = jax.ShapeDtypeStruct((128, 1024), jax.numpy.float32)
@@ -2512,7 +2513,7 @@ class SymbolicPallasTest(PallasBaseTest):
     )
     assert exported_module is not None
     self.assertIn(
-        "@sym_matmul(%arg0: tensor<128x1024xf32>, %arg1: tensor<1024x512xf32>",
+        "call @sym_matmul(%arg0, %arg1)",
         str(exported_module),
     )
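
Usage sketch (editor's note, not part of the patch): the snippet below mirrors
the SymbolicPallasTest flow above. The kernel and shapes are hypothetical, and
`lower_as_mlir` is the private helper patched in jax/_src/pallas/core.py, so
this is a sketch of the intended entry point rather than a supported API.

    import jax
    import jax.numpy as jnp
    from jax import export as jax_export
    from jax._src.pallas import core as pallas_core
    from jax.experimental import pallas as pl

    def add_one_kernel(x_ref, o_ref):
      # Element-wise kernel; block shapes are resolved by Pallas.
      o_ref[...] = x_ref[...] + 1.0

    def add_one(x):
      return pl.pallas_call(
          add_one_kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      )(x)

    # A symbolic leading dimension `b`. With dynamic_shapes=True, the lowering
    # replaces `b` with a placeholder constant and attaches the
    # tpu.dynamic_dimension_mapping_* attributes described in the patch.
    b, = jax_export.symbolic_shape("b")
    x = jax.ShapeDtypeStruct((b, 128), jnp.float32)
    module = pallas_core.lower_as_mlir(add_one, x, dynamic_shapes=True)
    print(module)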