Better Mosaic lowering for dynamic shapes: extend an interpreter to shape_poly dim exprs and lower them alongside the graph when we are in a dynamic export regime.
PiperOrigin-RevId: 738171437
commit 01a110c4c9 (parent 0fb59747f0)
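For orientation, a minimal sketch (not part of this commit) of the dynamic export regime the change targets: symbolic dimensions created via jax.export survive into the lowered module as dynamic (`?`) dims, and Mosaic must now lower the associated shape_poly dim exprs alongside the compute graph.

import jax
import jax.numpy as jnp
from jax import export

def double(x):
  return x * 2

# "b" is a symbolic dimension variable whose value is unknown until call time.
b, = export.symbolic_shape("b")
x_spec = jax.ShapeDtypeStruct((b, 128), jnp.float32)

# The exported StableHLO carries tensor<?x128xf32> arguments; the dim expr
# "b" must be recoverable by the consumer at specialization time.
exported = export.export(jax.jit(double))(x_spec)
print(exported.mlir_module())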
@@ -35,6 +35,7 @@ from jax._src import linear_util as lu
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util
+from jax._src.export._export import export
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.state import discharge as state_discharge
@@ -1165,14 +1166,16 @@ jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule


 def lower_as_mlir(
-    f, *args, dynamic_shapes=False, device=None, **kwargs
+    f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs
 ) -> mlir.ir.Module:
   with pallas_export_experimental(dynamic_shapes):
-    lowered = jax.jit(f, device=device).lower(*args, **kwargs)
-    stablehlo = lowered.compiler_ir(dialect="stablehlo")
+    f = jax.jit(f, device=device, static_argnames=static_argnames)
+    exported = export(f, platforms=["tpu"])(*args, **kwargs)
+    stablehlo = exported.mlir_module()

   return stablehlo  # type: ignore[return-value]


 _out_shape_to_aval_mapping: dict[
     type[Any], Callable[[Any], jax_core.AbstractValue]
 ] = {}
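A hedged sketch of how the reworked entry point can be driven; the kernel, the wrapper, and the final call are illustrative, not from this diff.

import jax
import jax.numpy as jnp
from jax import export
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
  return pl.pallas_call(
      add_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)

# With dynamic_shapes=True the export path above keeps symbolic dims as
# tensor<?x...> instead of burning them in as constants.
b, = export.symbolic_shape("b")
x = jax.ShapeDtypeStruct((b, 128), jnp.float32)
# module = lower_as_mlir(add, x, x, dynamic_shapes=True)  # returns StableHLO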
@@ -40,6 +40,7 @@ from jax._src import source_info_util
 from jax._src import state
 from jax._src import traceback_util
 from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
+from jax._src.export._export import export
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import lax as lax_internal
@@ -89,6 +90,11 @@ BOOL_MEMREF_TYPE = np.dtype('int32')
 # The value interpreted as a dynamic dimension by MLIR.
 MLIR_DYNAMIC = -9223372036854775808

+# TODO(mvoz): Find a way to make this a contract we can share with the
+# export specialization step in XLA export.
+DIM_UPPER_BOUND = np.iinfo(np.int32).max
+DIM_LOWER_BOUND = -128
+
 partial = functools.partial
 map, unsafe_map = safe_map, map  # pylint: disable=redefined-builtin
 zip, unsafe_zip = safe_zip, zip  # pylint: disable=redefined-builtin
@@ -102,17 +108,49 @@ class MeshContext:

 # Note - On Export Placeholders
 #
-# Mosaic uses vector IR, which does not have a concept of dynamic
-# dimensions. We need to come up with a way to represent dynamic dimensions in
-# vector IR, and so we use placeholders, which are later replaced during
-# specialization.
+# Since the vector dialect used by Mosaic does not support dynamic shapes,
+# we replace all top-level symbolic dimensions with placeholder
+# constants (between max(int32) - 128 and max(int32)) and we keep a
+# mapping from the placeholder constants to SHLO functions that encode
+# the symbolic dimension expression, as a function of the dimension
+# variables.
+#
+# The calling convention of the produced MLIR module is the same as a
+# regular Mosaic module, except we add two new attributes to the custom call
+# *per* intermediary placeholder dimension.
+#
+# The attributes are:
+#
+# tpu.dynamic_dimension_mapping_arg_name_<placeholder>
+# tpu.dynamic_dimension_mapping_module_<placeholder>
+#
+# The first attribute is a comma-separated list of the dimension variables
+# that are used to compute the symbolic dimension expression for the
+# placeholder. The second attribute is the MLIR module that contains the
+# SHLO functions that compute the symbolic dimension expression for the
+# placeholder.
 class LoweringDynamicShapeEnv:
-  dim_expr_to_placeholder: dict[Any, ir.Value] = {}
+  dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {}
+  placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {}

   def to_placeholder(self, dim_expr: Any) -> ir.Value:
     if jax_core.is_constant_dim(dim_expr):
       # Avoid ints; these are not dynamic.
       return dim_expr
     if dim_expr not in self.dim_expr_to_placeholder:
-      next_val = np.iinfo(np.int32).max - len(self.dim_expr_to_placeholder)
+      next_val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
+      if next_val < DIM_LOWER_BOUND:
+        # In practice, even with the largest of programs, we rarely see
+        # anything even close to this limit. It is arbitrary, and can be
+        # safely increased if needed.
+        raise ValueError(
+            "Too many dynamic shapes in the input. Mosaic currently only"
+            " supports up to 128 dynamic dimension values."
+        )
       self.dim_expr_to_placeholder[dim_expr] = next_val
+      # Reverse mapping - this is consumed to generate a table that is either
+      # input<>placeholder or intermediary computation<>placeholder.
+      self.placeholder_to_dim_expr[next_val] = dim_expr
     return self.dim_expr_to_placeholder[dim_expr]
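To make the placeholder bookkeeping concrete, a simplified standalone model of the env above; dim exprs are plain strings here rather than shape_poly._DimExpr objects, and the MLIR types are dropped.

import numpy as np

DIM_UPPER_BOUND = np.iinfo(np.int32).max

class ToyDynamicShapeEnv:
  def __init__(self):
    self.dim_expr_to_placeholder: dict[str, int] = {}
    self.placeholder_to_dim_expr: dict[int, str] = {}

  def to_placeholder(self, dim_expr) -> int:
    if isinstance(dim_expr, int):
      return dim_expr  # Constant dims pass through untouched.
    if dim_expr not in self.dim_expr_to_placeholder:
      # Each new symbolic expression gets the next constant below max(int32).
      val = DIM_UPPER_BOUND - len(self.dim_expr_to_placeholder)
      self.dim_expr_to_placeholder[dim_expr] = val
      self.placeholder_to_dim_expr[val] = dim_expr
    return self.dim_expr_to_placeholder[dim_expr]

env = ToyDynamicShapeEnv()
assert env.to_placeholder("b") == 2147483647        # first placeholder
assert env.to_placeholder("b * 128") == 2147483646  # second, distinct expr
assert env.to_placeholder("b") == 2147483647        # stable on reuse
assert env.to_placeholder(128) == 128               # constants unchanged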
@@ -622,6 +660,7 @@ def lower_jaxpr_to_module(
         "Pallas TPU requires a libTPU version that's at most a month old"
     )
   debug_info = jaxpr.debug_info
+  _mosaic_lowering_dynamic_shape_env = None
   if dynamic_shape_replacement_enabled:
     _mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
@@ -663,10 +702,12 @@ def lower_jaxpr_to_module(
       for_verification=for_verification,
       forward_compatible=lowering_context.is_forward_compat(),
       dynamic_shape_replacement_fn=dynamic_shape_replacement_fn,
+      dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled,
   )
   m.body.append(func_op)
   sym_tab.insert(func_op)
   window_params = []
+  static_grid = None
   grid = mosaic_grid_mapping.grid
   if grid:
     for i, bm in enumerate(grid_mapping.block_mappings):
@@ -738,7 +779,6 @@ def lower_jaxpr_to_module(
     ]
     static_grid = dynamic_shape_replacement_fn(static_grid)
     func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)

   func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
       ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
   func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
@@ -746,6 +786,60 @@
   func_op.attributes["dimension_semantics"] = (
       mosaic_grid_mapping.get_dimension_semantics()
   )
+  if dynamic_shape_replacement_enabled:
+    if _mosaic_lowering_dynamic_shape_env is None:
+      raise ValueError(
+          "Dynamic shape env is None, invariant violated. Unreachable?"
+      )
+
+    # Now we can use jax to compute the dynamic shape graph
+
+    if static_grid is not None:
+      grid_vars = [
+          _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.get(g, g)
+          for g in static_grid
+      ]
+    else:
+      grid_vars = []
+
+    invars = [invar.aval for invar in jaxpr.invars]
+    # Faux shape for grid, just to get the avals
+    invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32))
+    args_dimvars = shape_poly.all_dim_vars(invars)
+
+    # This is dimexpr var -> placeholder value for when we jit the dim expr
+    env: dict[str, int] = {}
+    for aval in args_dimvars:
+      env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval)
+
+    for (
+        placeholder,
+        dim_expr,
+    ) in _mosaic_lowering_dynamic_shape_env.placeholder_to_dim_expr.items():
+      top_level_names = list(env.keys())
+      if dim_expr not in top_level_names:
+        jitted_eval = jax.jit(
+            jax_core.evaluate_shape,
+            static_argnames=(
+                "shape",
+                "dim_vars",
+            ),
+            keep_unused=True,
+        )
+        stablehlo = export(
+            jitted_eval, platforms=[str(jax.devices()[0].platform)]
+        )(
+            (dim_expr,), tuple(args_dimvars), *(env[v] for v in args_dimvars)
+        ).mlir_module()
+        arg_name = args_dimvars
+        # See Note - On Export Placeholders for more details.
+        m.operation.attributes[
+            "tpu.dynamic_dimension_mapping_module_" + str(placeholder)
+        ] = ir.StringAttr.get(str(stablehlo))
+        arg_name_str = ",".join(arg_name)
+        m.operation.attributes[
+            "tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder)
+        ] = ir.StringAttr.get(arg_name_str)
   return m, mosaic_grid_mapping.get_extra_args()
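The inner export above compiles each non-trivial dim expr into a tiny SHLO function of the dimension variables. Roughly, as a standalone sketch using the same internal APIs (jax._src names are internal and subject to change; the placeholder int here is illustrative):

import jax
from jax import export
from jax._src import core as jax_core

b, = export.symbolic_shape("b")
dim_expr = 2 * b  # a shape_poly expression to encode as SHLO

jitted_eval = jax.jit(
    jax_core.evaluate_shape,
    static_argnames=("shape", "dim_vars"),
    keep_unused=True,
)
# The placeholder int stands in for the runtime value of "b" during tracing.
shlo = export.export(jitted_eval)(
    (dim_expr,), ("b",), 2147483647
).mlir_module()
print(shlo)  # a module whose main computes 2*b from the b argument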
@@ -828,6 +922,7 @@ def lower_jaxpr_to_func(
     dynamic_shape_replacement_fn: (
         Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None
     ) = None,
+    dynamic_shape_replacement_enabled: bool = False,
 ) -> func.FuncOp:
   num_grid = len(mosaic_grid_mapping.grid_types)
   num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
@@ -874,6 +969,12 @@
   )
   body_func.__name__ = name
   body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
+  if dynamic_shape_replacement_enabled:
+    # Skip verification for dynamic shape replacement - you can potentially
+    # produce IR like, e.g., add(x[placeholder_0, placeholder_1], y[128, 128]),
+    # which is not valid, but we don't care since we'll run the verifier again
+    # after the dynamic shape replacement pass.
+    return body.func_op
   try:
     body.func_op.verify()
   except ir.MLIRError as e:
@@ -3851,3 +3952,15 @@ def _platform_index_lowering(


 lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
+
+
+def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim):
+  placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0]
+  return ir_constant(
+      placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
+  )
+
+
+import jax._src.export.shape_poly as shape_poly
+
+lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering
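For context, dim_as_value_p fires when a kernel consumes a symbolic dimension's value as data rather than just as an extent. A hedged sketch of a kernel that would exercise the new rule (names are illustrative, not from this diff):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_kernel(x_ref, o_ref):
  # Under symbolic export, x_ref.shape[0] is a shape_poly dim expr; using it
  # in arithmetic routes through dim_as_value_p, which the rule above lowers
  # to an int32 placeholder constant.
  o_ref[...] = x_ref[...] * x_ref.shape[0]

def scale(x):
  return pl.pallas_call(
      scale_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x)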
@@ -2501,7 +2501,8 @@ class SymbolicPallasTest(PallasBaseTest):
     )
     assert exported_module is not None
     self.assertIn(
-        "tensor<?x?xf32>, %arg6: tensor<?x?xf32>, %arg7: tensor<?x?xf32>",
+        "%arg0: tensor<?x?xf32> loc(unknown), %arg1: tensor<?x?xf32>"
+        " loc(unknown), %arg2: tensor<?x?xf32>",
         str(exported_module),
     )
     x = jax.ShapeDtypeStruct((128, 1024), jax.numpy.float32)
@@ -2512,7 +2513,7 @@ class SymbolicPallasTest(PallasBaseTest):
     )
     assert exported_module is not None
     self.assertIn(
-        "@sym_matmul(%arg0: tensor<128x1024xf32>, %arg1: tensor<1024x512xf32>",
+        "call @sym_matmul(%arg0, %arg1)",
         str(exported_module),
     )
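The updated assertion matches a call-site line rather than the full entry signature. A hedged reconstruction of the surrounding test pattern; sym_matmul and the lower_as_mlir wiring live outside this hunk, so the calls are shown commented:

import jax
import jax.numpy as jnp

x = jax.ShapeDtypeStruct((128, 1024), jnp.float32)
y = jax.ShapeDtypeStruct((1024, 512), jnp.float32)  # shape from the old assertion
# exported_module = lower_as_mlir(sym_matmul, x, y)
# With static shapes, the module's main now calls into the named kernel func:
# assert "call @sym_matmul(%arg0, %arg1)" in str(exported_module)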