Merge pull request #16199 from gnecula:poly_state

PiperOrigin-RevId: 536932763

Commit: a45fbef807
@@ -371,6 +371,21 @@ AxisContext = Union[
     sharding_impls.ShardingContext,
 ]
 
+class ShapePolyLoweringState:
+  # The names of the dimension variables, sorted by name. This is the order in
+  # which they are passed to the IR functions that need them. This is only
+  # used for native serialization with polymorphic shapes when
+  # --jax_dynamic_shapes is off.
+  dim_vars: Sequence[str]
+  # Whether the module uses dimension variables, either in its inputs or
+  # from an inner call to a polymorphic Exported.
+  uses_dim_vars: bool
+
+  def __init__(self, dim_vars: Sequence[str]):
+    self.dim_vars = dim_vars
+    self.uses_dim_vars = (len(dim_vars) > 0)
+
+
 @dataclasses.dataclass
 class ModuleContext:
   """Module-wide context information for MLIR lowering."""
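A minimal standalone sketch of the class added above (the variable names "b" and "h" are hypothetical, not taken from this change): `uses_dim_vars` starts as a simple function of whether any dimension variable names were supplied, and an inner call to a polymorphic Exported can later flip it to True.

    from typing import Sequence

    class ShapePolyLoweringState:
      """Illustrative copy of the class introduced in this change."""
      dim_vars: Sequence[str]
      uses_dim_vars: bool

      def __init__(self, dim_vars: Sequence[str]):
        self.dim_vars = dim_vars
        self.uses_dim_vars = (len(dim_vars) > 0)

    poly = ShapePolyLoweringState(("b", "h"))   # two symbolic dimensions
    assert poly.uses_dim_vars
    static = ShapePolyLoweringState(())         # fully static lowering
    assert not static.uses_dim_vars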
@@ -385,11 +400,8 @@ class ModuleContext:
   keepalives: List[Any]
   channel_iterator: Iterator[int]
   host_callbacks: List[Any]
-  # The names of the dimension variables, sorted by name. This is the order in
-  # which they are passed to the IR functions that need them. This is only
-  # used for native serialization with polymorphic shapes when
-  # --jax_dynamic_shapes is off.
-  dim_vars: Sequence[str]
+  # Keep state for the lowering of shape polymorphism
+  shape_poly_state: ShapePolyLoweringState
 
   # Cached primitive lowerings.
   cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
@@ -417,7 +429,7 @@ class ModuleContext:
                                                  func_dialect.FuncOp]] = None,
       cached_call_jaxpr_lowerings: Optional[Dict[Any,
                                                  func_dialect.FuncOp]] = None,
-      dim_vars: Sequence[str] = ()):
+      shape_poly_state = None):
     assert platform is not None
     self.context = context or make_ir_context()
     self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
@@ -435,7 +447,7 @@ class ModuleContext:
     self.cached_call_jaxpr_lowerings = ({}
                                         if cached_call_jaxpr_lowerings is None
                                         else cached_call_jaxpr_lowerings)
-    self.dim_vars = dim_vars
+    self.shape_poly_state = shape_poly_state or ShapePolyLoweringState(())
 
   @property
   def backend(self) -> xb.XlaBackend:
@@ -466,7 +478,7 @@ class LoweringRuleContext:
   tokens_out: Optional[TokenSet]  # Mutable store for output containers
   axis_size_env: Optional[Dict[core.Var, ir.Value]] = None  # Dynamic axis sizes
   dim_var_values: Sequence[ir.Value] = ()  # The values for the dimension variables
-                                           # in same order as module_context.dim_vars
+                                           # in same order as module_context.shape_poly_state.dim_vars
 
   def set_tokens_out(self, tokens_out: TokenSet):
     assert self.tokens_out is None, 'Should only set `tokens_out` once.'
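A note on the constructor change above: when no state is passed, a fresh empty ShapePolyLoweringState is built per call rather than sharing a mutable default argument. A tiny hedged sketch of that idiom, reusing the sketch class shown earlier (the function name is hypothetical):

    def make_module_context_state(shape_poly_state=None):
      # `or` builds a new empty state for each call, so callers never share
      # one mutable default instance.
      return shape_poly_state or ShapePolyLoweringState(())

    assert not make_module_context_state().uses_dim_vars
    assert make_module_context_state(ShapePolyLoweringState(("b",))).uses_dim_vars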
@@ -535,9 +547,9 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
   else:
     ctx = ctx.replace(
         primitive="eval_dynamic_shape",
-        avals_in=[core.dim_value_aval()] * len(ctx.module_context.dim_vars))
+        avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars))
     res = lower_fun(
-        partial(core.evaluate_shape, shape, ctx.module_context.dim_vars),
+        partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
         multiple_results=True)(ctx, *ctx.dim_var_values)
     return util.flatten(res)  # type: ignore
 
@@ -546,6 +558,7 @@ class LoweringResult(NamedTuple):
   module: ir.Module
   keepalive: Optional[Any]
   host_callbacks: List[Any]
+  shape_poly_state: ShapePolyLoweringState
 
 
 _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
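Because LoweringResult is a NamedTuple, the new field also appears in positional unpacking. A self-contained sketch (the values are stand-ins, not real IR objects):

    from typing import Any, List, NamedTuple, Optional

    class LoweringResultSketch(NamedTuple):
      # Stand-ins for ir.Module / ShapePolyLoweringState, for illustration only.
      module: Any
      keepalive: Optional[Any]
      host_callbacks: List[Any]
      shape_poly_state: Any

    res = LoweringResultSketch("module", None, [], "poly-state")
    # Positional unpacking must now account for the fourth field:
    module, keepalive, host_callbacks, shape_poly_state = res
    assert shape_poly_state == "poly-state"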
@@ -628,7 +641,8 @@ def lower_jaxpr_to_module(
                        if result_shardings is not None else result_shardings)
 
   ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
-                      keepalives, channel_iter, host_callbacks, dim_vars=dim_vars)
+                      keepalives, channel_iter, host_callbacks,
+                      shape_poly_state=ShapePolyLoweringState(dim_vars))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
@@ -658,7 +672,8 @@ def lower_jaxpr_to_module(
     raise ValueError(
         f"Cannot lower jaxpr with verifier errors: {module_string}") from e
 
-  return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks)
+  return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
+                        ctx.shape_poly_state)
 
 def module_to_string(module: ir.Module) -> str:
   output = io.StringIO()
@@ -805,7 +820,7 @@ def lower_jaxpr_to_fun(
       aval = core.ShapedArray((), np.dtype(np.bool_))
     return aval_to_ir_types(aval)
 
-  num_dim_vars = len(ctx.dim_vars)
+  num_dim_vars = len(ctx.shape_poly_state.dim_vars)
   dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
   dim_var_types = map(aval_to_types, dim_var_avals)
 
@@ -1006,7 +1021,7 @@ def _to_physical_op_sharding(
 def _emit_lowering_rule_as_fun(lowering_rule,
                                ctx: LoweringRuleContext) -> func_dialect.FuncOp:
   """Emits the contents of a lowering rule as a private function."""
-  num_dim_vars = len(ctx.module_context.dim_vars)
+  num_dim_vars = len(ctx.module_context.shape_poly_state.dim_vars)
   # TODO(necula) maybe only pass the dim_vars if they are needed?
   dim_var_types = map(aval_to_ir_types, [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars)
 
@@ -1049,7 +1064,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
   Assumes that an MLIR context, location, and insertion point are set.
 
   dim_var_values: the list of dimension variables values in the current
-    IR function, in the order of ctx.dim_vars.
+    IR function, in the order of ctx.shape_poly_state.dim_vars.
   """
   assert ctx.platform != "gpu"
   def read(v: core.Atom) -> Sequence[ir.Value]:
@@ -1075,7 +1090,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
   assert len(args) == len(jaxpr.invars), (jaxpr, args)
   assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
   assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
-  assert len(ctx.dim_vars) == len(dim_var_values), (ctx.dim_vars, dim_var_values)
+  assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
   map(write, jaxpr.constvars, consts)
   map(write, jaxpr.invars, args)
   for eqn in jaxpr.eqns:
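The ordering contract from the docstring and assertion above can be illustrated in plain Python (the names and values are hypothetical stand-ins for dimension variables and their ir.Value counterparts):

    dim_vars = ("b", "h")                         # ctx.shape_poly_state.dim_vars
    dim_var_values = ("value_for_b", "value_for_h")
    assert len(dim_vars) == len(dim_var_values)
    env = dict(zip(dim_vars, dim_var_values))     # name -> value, same order
    assert env["h"] == "value_for_h"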
@@ -1945,7 +1945,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
   return (lowering_result.module, lowering_result.keepalive,
           lowering_result.host_callbacks, unordered_effects, ordered_effects,
-          nreps, tuple_args)
+          nreps, tuple_args, lowering_result.shape_poly_state)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -2080,7 +2080,7 @@ def lower_sharding_computation(
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
-   nreps, tuple_args) = _cached_lowering_to_hlo(
+   nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
       semantic_out_shardings, da_object, lowering_platform,
       donated_invars, name_stack)
@@ -2111,7 +2111,8 @@ def lower_sharding_computation(
       device_assignment=da_object,
       committed=committed,
       pmap_nreps=nreps,
-      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
+      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
+      shape_poly_state=shape_poly_state)
 
 
 def _to_logical_sharding(
@@ -2285,7 +2286,8 @@ def lower_mesh_computation(
       backend=backend,
       device_assignment=_create_da_object(tuple(mesh.devices.flat)),
       committed=True,
-      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
+      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
+      shape_poly_state=lowering_result.shape_poly_state)
 
 class MeshComputation(stages.XlaLowering):
   _hlo: Optional[ir.Module]
@@ -2617,8 +2619,10 @@ class UnloadedMeshExecutable:
       committed: bool,
       pmap_nreps: int = 1,
       jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None,
+      shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None,
       compiler_options=None
   ) -> MeshExecutable:
+    del shape_poly_state
     compiler_options_keys = tuple(
         compiler_options.keys()) if compiler_options is not None else None
     compiler_options_values = tuple(
@@ -76,6 +76,9 @@ class Exported:
     module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
       must be passed to the module. The other arguments have been dropped
       because they are not used. Same length as `in_shardings`.
+    module_uses_dim_vars: whether the `mlir_module_serialized` uses shape
+      polymorphic dimension variables. This may be from `in_avals` but also
+      from inner calls of Exported modules.
     strict_checks: whether the module was serialized with the following safety
       checking: (A) the lowered computation can only be executed on a platform
       for which it was lowered; (B) the serialized computation contains only
@@ -101,6 +104,7 @@ class Exported:
   mlir_module_serialized: bytes
   xla_call_module_version: int
   module_kept_var_idx: Tuple[int, ...]
+  module_uses_dim_vars: bool
 
   _get_vjp: Optional[Callable[["Exported"], "Exported"]]
 
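A hedged sketch of how a caller might consume the new field (the helper name and the namedtuple stand-in are hypothetical, not part of the change): static modules can be invoked as-is, while modules that use dimension variables need shape refinement first, e.g. via XlaCallModule.

    from collections import namedtuple

    ExportedSketch = namedtuple("ExportedSketch", ["module_uses_dim_vars"])

    def can_call_directly(exported) -> bool:
      # Static modules can be invoked as-is; polymorphic modules need shape
      # refinement (e.g. via XlaCallModule) before execution.
      return not exported.module_uses_dim_vars

    assert can_call_directly(ExportedSketch(module_uses_dim_vars=False))
    assert not can_call_directly(ExportedSketch(module_uses_dim_vars=True))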
@@ -264,10 +268,9 @@ def export(fun_jax: Callable,
     else:
       # For pmap
       module_kept_var_idx = tuple(range(len(args_avals_flat)))
-
-    if not all(
-        core.is_constant_shape(a.shape) for a in args_avals_flat
-    ) or lowering.compile_args.get("ordered_effects", []):
+    shape_poly_state = lowering.compile_args["shape_poly_state"]
+    if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
+        or lowering.compile_args.get("ordered_effects", [])):
       # All arguments are kept if we have dimension variables.
       assert len(module_kept_var_idx) == len(args_avals_flat)
       mlir_module = _wrap_main_func(
@@ -334,6 +337,7 @@ def export(fun_jax: Callable,
       strict_checks=strict_checks,
       mlir_module_serialized=mlir_module_serialized,
       module_kept_var_idx=module_kept_var_idx,
+      module_uses_dim_vars=shape_poly_state.uses_dim_vars,
      xla_call_module_version=xla_call_module_version,
      _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))
 
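The condition in the first hunk above decides whether all arguments are kept. A self-contained sketch under stated assumptions (the namedtuple and the int-vs-string test stand in for real avals and core.is_constant_shape): all arguments are kept whenever any input shape is non-constant or ordered effects are present.

    from collections import namedtuple

    Aval = namedtuple("Aval", ["shape"])  # stand-in for an abstract value

    def keeps_all_arguments(avals, ordered_effects) -> bool:
      # Stand-in for core.is_constant_shape: a shape is constant when every
      # dimension is a plain int (symbolic dimensions appear as strings here).
      all_constant = all(isinstance(d, int) for a in avals for d in a.shape)
      return (not all_constant) or bool(ordered_effects)

    assert not keeps_all_arguments([Aval((3, 4, 12))], ordered_effects=[])
    assert keeps_all_arguments([Aval((3, "b", 12))], ordered_effects=[])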
@@ -387,7 +391,6 @@ def _wrap_main_func(
   Returns the wrapped module.
   """
   dim_vars = shape_poly.all_dim_vars(args_avals_flat)
-
   # Make a new module, do not mutate the "module" because it may be cached
   context = mlir.make_ir_context()
   with context, ir.Location.unknown(context):
@@ -512,7 +515,7 @@ def _check_lowering(lowering) -> None:
       "spmd_lowering", "auto_spmd_lowering",
       "tuple_args", "ordered_effects", "unordered_effects",
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
-      "device_assignment", "jaxpr_debug_info"]
+      "device_assignment", "jaxpr_debug_info", "shape_poly_state"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
@@ -538,6 +541,7 @@ def _check_lowering(lowering) -> None:
       # used on all platforms for callbacks. Not supported yet.
       ("keepalive", lambda v: not v, "empty"),
       ("pmap_nreps", lambda v: v == 1, "1"),
+      ("shape_poly_state", lambda v: True, "N/A"),
   ):
     if compile_arg in lowering.compile_args:
       if not check_value(lowering.compile_args[compile_arg]):
@@ -810,6 +814,9 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
         f"The exported function '{exported.fun_name}' was lowered for "
         f"platform '{exported.lowering_platform}' but it is used "
         f"on '{platform}'.")
+  if any(not core.is_constant_shape(a.shape) for a in exported.in_avals):
+    ctx.module_context.shape_poly_state.uses_dim_vars = True
+
   submodule = ir.Module.parse(exported.mlir_module)
   symtab = ir.SymbolTable(submodule.operation)
   # The called function may have been exported with polymorphic shapes and called
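The (name, predicate, expected) checks extended above follow a simple table-driven pattern. A standalone sketch with hypothetical compile_args, showing why the new entry accepts any value:

    compile_args = {"keepalive": [], "pmap_nreps": 1, "shape_poly_state": object()}
    checks = (
        ("keepalive", lambda v: not v, "empty"),
        ("pmap_nreps", lambda v: v == 1, "1"),
        ("shape_poly_state", lambda v: True, "N/A"),  # accepted unconditionally
    )
    for name, check_value, expected in checks:
      if name in compile_args and not check_value(compile_args[name]):
        raise NotImplementedError(f"compile_args[{name}] must be {expected}")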
@@ -275,6 +275,8 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
       mlir_module_serialized=data.mlir_module_serialized,
       xla_call_module_version=data.xla_call_module_version,
       module_kept_var_idx=tuple(range(len(in_avals))),
+      module_uses_dim_vars=any(not core.is_constant_shape(a.shape)
+                               for a in in_avals),
       _get_vjp=_get_vjp)
 
   # We use pjit in case there are shardings in the exported module.
@@ -264,6 +264,8 @@ class JaxExportTest(jtu.JaxTestCase):
     inner_exp = jax_export.export(inner)(
         jax_export.poly_spec(inner_x.shape, inner_x.dtype, inner_poly_spec))
 
+    self.assertEqual(inner_exp.module_uses_dim_vars,
+                     (inner_poly_spec != "3,4,12"))
     outer_x = np.arange(np.prod(outer_x_shape),
                         dtype=np.float32).reshape(outer_x_shape)  # outer_x : f32[3,4,12]
     def outer(x):  # x: outer_poly_spec
@@ -278,12 +280,17 @@ class JaxExportTest(jtu.JaxTestCase):
     # Call it after exporting again, with polymorphic shapes
     outer_exp = jax_export.export(outer)(
         jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
-    # TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
-    # until we create the python bindings to invoke shape refinement.
-    if jax2tf is not None:
-      res2 = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
-      # res2 = jax_export.call_exported(exp2)(x2)
-      self.assertAllClose(2. * inner(outer_x), res2)
+    self.assertEqual(outer_exp.module_uses_dim_vars,
+                     (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
+    if not outer_exp.module_uses_dim_vars:
+      res = jax_export.call_exported(outer_exp)(outer_x)
+      self.assertAllClose(2. * inner(outer_x), res)
+    else:
+      # TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
+      # until we create the python bindings to invoke shape refinement.
+      if jax2tf is not None:
+        res = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
+        self.assertAllClose(2. * inner(outer_x), res)
 
 
 if __name__ == "__main__":