Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering by adding a new LoweringParameters object that can be used to specify a cross-platform lowering target, or even multiple lowering platforms. However, we kept ModuleContext.platform in place because some lowering rules still referenced it. Now we replace ModuleContext.platform with ModuleContext.platforms, which removes the redundancy, simplifies the code, and makes it clearer that lowering rules should not simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
parent 468f66671b
commit edbe49fb2a
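For lowering rules that have not yet been adapted to multi-platform lowering, the diff below repeatedly applies one pattern: read ctx.module_context.platforms, reject the case len(platforms) > 1 explicitly, and otherwise use platforms[0]. Below is a minimal, self-contained sketch of that pattern, not actual JAX code: the SimpleModuleContext/SimpleRuleContext stand-ins and my_prim_lowering are hypothetical names used only for illustration.

# Sketch of the rule-adaptation pattern introduced by this commit
# (hypothetical stand-ins, not JAX APIs).
from dataclasses import dataclass
from typing import Sequence


@dataclass
class SimpleModuleContext:
    # Mirrors the new field: always a sequence of platforms, never one string.
    platforms: Sequence[str]


@dataclass
class SimpleRuleContext:
    module_context: SimpleModuleContext


def my_prim_lowering(ctx: SimpleRuleContext, *args):
    platforms = ctx.module_context.platforms
    if len(platforms) > 1:
        # A rule not yet adapted to multi-platform lowering refuses it
        # explicitly rather than silently picking one platform.
        raise NotImplementedError("multi-platform lowering for my_prim")
    platform = platforms[0]
    return f"lowered my_prim for {platform}", args


if __name__ == "__main__":
    ctx = SimpleRuleContext(SimpleModuleContext(platforms=("cpu",)))
    print(my_prim_lowering(ctx, 1, 2))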
@@ -565,7 +565,7 @@ def xla_computation(fun: Callable,
           core.ClosedJaxpr(jaxpr, consts),
           ordered_effects=ordered_effects,
           backend_or_name=backend,
-          platform=platform,
+          platforms=[platform],
           axis_context=sharding_impls.ReplicaAxisContext(axis_env_),
           name_stack=source_info_util.new_name_stack(
               wrap_name(fun_name, "xla_computation")),

@@ -385,7 +385,7 @@ def inspect_sharding_partition(shapes, arg_shardings, result_shape,
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
   closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
   trivial_comp = mlir.build_xla_computation_helper(closed_jaxpr,
-      name="tmp_xla_computation", platform=module_context.platform,
+      name="tmp_xla_computation", platforms=module_context.platforms,
       backend_or_name=module_context.backend_or_name,
       axis_context=module_context.axis_context)
   # The trivial computation built here has a dummy tuple as the result,

@@ -554,6 +554,6 @@ def _common_device_put_lowering(ctx, x, *, device, src):
       device.memory_kind is not None):
     raise NotImplementedError(
         "Passing memory_kind to device_put via Shardings is not supported on"
-        f" platform {ctx.module_context.platform}")
+        f" platforms {ctx.module_context.platforms}")
   return [x]
 mlir.register_lowering(device_put_p, _common_device_put_lowering)

@@ -430,8 +430,7 @@ class LoweringParameters:
   # The current lowering platforms, a non-empty tuple containing some of
   # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
   # doing multi-platform lowering, otherwise it can specify cross-platform
-  # lowering. The value None specifies default lowering platform, for the
-  # platform specified by `ModuleContext.platform`.
+  # lowering. The value None specifies the default lowering platform.
   # This is used only in export and jax2tf.
   platforms: tuple[str, ...] | None = None

@@ -454,23 +453,6 @@ class LoweringParameters:
   # native execution (and we can remove this parameter).
   replace_tokens_with_dummy: bool = True

-  @property
-  def override_platform(self) -> str | None:
-    """Overrides the lowering platform for cross-platform lowering.
-
-    One of 'cpu', 'cuda', 'rocm', 'tpu'.
-    If None, use the default JAX mechanisms to pick the lowering platform.
-    This is currently used for export and jax2tf.
-    """
-    if self.platforms is not None:
-      return self.platforms[0]
-    else:
-      return None
-
-  @property
-  def is_multi_platform(self) -> bool:
-    return self.platforms is not None and len(self.platforms) > 1
-

 @dataclasses.dataclass
 class ModuleContext:

@@ -480,7 +462,7 @@ class ModuleContext:
   ip: ir.InsertionPoint
   symbol_table: ir.SymbolTable
   backend_or_name: str | xb.XlaBackend | None
-  platform: str
+  platforms: Sequence[str]
   axis_context: AxisContext
   name_stack: source_info_util.NameStack
   keepalives: list[Any]

@@ -503,7 +485,7 @@ class ModuleContext:
       self,
       *,
       backend_or_name: str | xb.XlaBackend | None,
-      platform: str,
+      platforms: Sequence[str],
       axis_context: AxisContext,
       name_stack: source_info_util.NameStack,
       keepalives: list[Any],

@@ -519,13 +501,13 @@ class ModuleContext:
       cached_call_jaxpr_lowerings: None | (dict[Any,
                                            func_dialect.FuncOp]) = None,
       shape_poly_state = None):
-    assert platform is not None
+
     self.context = context or make_ir_context()
     self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
     self.ip = ip or ir.InsertionPoint(self.module.body)
     self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
     self.backend_or_name = backend_or_name
-    self.platform = platform
+    self.platforms = platforms
     self.axis_context = axis_context
     self.name_stack = name_stack
     self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None

@@ -536,12 +518,18 @@ class ModuleContext:
     self.cached_call_jaxpr_lowerings = ({}
                                         if cached_call_jaxpr_lowerings is None
                                         else cached_call_jaxpr_lowerings)
-    self.shape_poly_state = shape_poly_state or ShapePolyLoweringState((),
-                                                                       (platform,))
+    self.shape_poly_state = (
+        shape_poly_state or ShapePolyLoweringState((), tuple(platforms)))
     self.lowering_parameters = lowering_parameters

   @property
   def backend(self) -> xb.XlaBackend:
+    # TODO(necula): clean the use of backend and backend_or_name vs. platforms
+    if len(self.platforms) > 1:
+      raise NotImplementedError(
+          "accessing .backend in multi-lowering setting. This can occur when "
+          "lowering a primitive that has not been adapted to multi-platform "
+          "lowering")
     if self.backend_or_name is None or isinstance(self.backend_or_name, str):
       return xb.get_backend(self.backend_or_name)
     return self.backend_or_name

@@ -722,7 +710,7 @@ def lower_jaxpr_to_module(
     *,
     ordered_effects: list[core.Effect],
     backend_or_name: str | xb.XlaBackend | None,
-    platform: str,
+    platforms: Sequence[str],
     axis_context: AxisContext,
     name_stack: source_info_util.NameStack,
     donated_args: Sequence[bool],

@@ -741,13 +729,7 @@ def lower_jaxpr_to_module(
   Handles the quirks of the argument/return value passing conventions of the
   runtime.
   """
-  platforms: tuple[str, ...]
-  platform = xb.canonicalize_platform(platform)
-  if lowering_parameters.is_multi_platform:
-    platforms = tuple(map(xb.canonicalize_platform,
-                          lowering_parameters.platforms))  # type: ignore
-  else:
-    platforms = (platform,)
+  platforms = tuple(map(xb.canonicalize_platform, platforms))

   input_output_aliases = None
   in_avals = (jaxpr.in_avals if arg_shardings is None else

@@ -809,7 +791,7 @@ def lower_jaxpr_to_module(
                      if result_shardings is not None else result_shardings)

   ctx = ModuleContext(backend_or_name=backend_or_name,
-                      platform=platform, axis_context=axis_context,
+                      platforms=platforms, axis_context=axis_context,
                       name_stack=name_stack,
                       keepalives=keepalives,
                       channel_iterator=channel_iter,

@@ -1349,7 +1331,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
     dim_var_values: the list of dimension variables values in the current
       IR function, in the order of ctx.shape_poly_state.dim_vars.
   """
-  assert ctx.platform != "gpu"
+  assert "gpu" not in ctx.platforms
   def read(v: core.Atom) -> Sequence[ir.Value]:
     if type(v) is core.Literal:
       return ir_constants(xla.canonicalize_dtype(v.val))

@@ -1393,14 +1375,15 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
                                   ctx.name_stack)
     with source_info_util.user_context(eqn.source_info.traceback), loc:
       override_rule = get_override_lowering_rule(eqn.primitive)
-      if not ctx.lowering_parameters.is_multi_platform:
+      if len(ctx.platforms) == 1:
         # Classic, single-platform lowering
         # TODO(necula): unify the code paths when multi-platform is finished
+        platform = ctx.platforms[0]
         if override_rule is not None:
           rule = override_rule
-        elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
-          rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
-        elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
+        elif eqn.primitive in _platform_specific_lowerings[platform]:
+          rule = _platform_specific_lowerings[platform][eqn.primitive]
+        elif eqn.primitive in xla._backend_specific_translations[platform]:
           rule = xla_fallback_lowering(eqn.primitive)
         elif eqn.primitive in _lowerings:
           rule = _lowerings[eqn.primitive]

@@ -1409,7 +1392,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
         else:
           raise NotImplementedError(
               f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
-              f"found for platform {ctx.platform}")
+              f"found for platform {platform}")
       else:
         rules: list[MultiPlatformLoweringRule]
         # See mlir.lower_multi_platform for the `rules` format

@@ -1418,7 +1401,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
         else:
           # First the platform-specific rules
           rules = []
-          for p in ctx.lowering_parameters.platforms:  # type: ignore
+          for p in ctx.platforms:
             if eqn.primitive in _platform_specific_lowerings[p]:
               rules.append(
                   ([p], _platform_specific_lowerings[p][eqn.primitive]))

@@ -1449,7 +1432,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
         rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)

       rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
-      if not ctx.lowering_parameters.is_multi_platform:
+      if len(ctx.platforms) == 1:
         # Classic, single-platform lowering
         ans = rule(rule_ctx, *rule_inputs, **eqn.params)
       else:

@@ -1528,12 +1511,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
     rule_args: the args of the lowering rules.
     rule_kwargs: the kwargs of the lowering rules.
   """
-  platforms: Sequence[str]
-  if ctx.module_context.lowering_parameters.is_multi_platform:
-    assert ctx.module_context.lowering_parameters.platforms is not None
-    platforms = ctx.module_context.lowering_parameters.platforms
-  else:
-    platforms = (ctx.module_context.platform,)
+  platforms: Sequence[str] = ctx.module_context.platforms
   platforms_with_specific_rules: Sequence[str] = util.flatten(
       [ps for ps, _ in rules if ps is not None])
   platforms_with_default_rule = [p for p in platforms

@@ -1681,11 +1659,16 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
   return func_op


-def check_backend_matches(inner_backend, outer_backend):
+def check_backend_matches(inner_backend: Optional[str],
+                          lowering_platforms: Sequence[str]):
   # For nested calls, the outermost call sets the backend for all inner calls;
   # it's an error if the inner call has a conflicting explicit backend spec.
   if inner_backend is None:
     return
+  outer_backend, *more_lowering_platforms = lowering_platforms
+  if more_lowering_platforms:
+    raise NotImplementedError(
+        "Multi-platform lowering when a backend= parameter is specified")
   if (inner_backend != outer_backend and
       outer_backend not in xb.expand_platform_alias(inner_backend)):
     raise ValueError(

@@ -1693,13 +1676,15 @@ def check_backend_matches(inner_backend, outer_backend):
         f"inner-jit backend specification {inner_backend}.")


-def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
+def _call_lowering(fn_name, stack_name, call_jaxpr, backend,
+                   ctx: ModuleContext, avals_in,
                    avals_out, tokens_in, *args,
                    dim_var_values: Sequence[ir.Value],
                    arg_names=None, result_names=None):
   del stack_name, avals_in
   if isinstance(call_jaxpr, core.Jaxpr):
     call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
-  check_backend_matches(backend, ctx.platform)
+  check_backend_matches(backend, ctx.platforms)
   effects = list(tokens_in.effects())
   output_types = map(aval_to_ir_types, avals_out)
   output_types = [token_type()] * len(effects) + output_types

@@ -1717,7 +1702,8 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
   tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
   return out_nodes, tokens_out

-def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
+def core_call_lowering(ctx: LoweringRuleContext,
+                       *args, name, backend=None, call_jaxpr):
   out_nodes, tokens = _call_lowering(
       name, name, call_jaxpr, backend, ctx.module_context,
       ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,

@@ -2137,8 +2123,11 @@ def xla_fallback_lowering(prim: core.Primitive):
       raise NotImplementedError(
           f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623")

+    if len(module_ctx.platforms) > 1:
+      raise NotImplementedError(
+          "fallback lowering not implemented for multi-platform lowering")
     xla_computation = xla.primitive_subcomputation(
-        module_ctx.platform, axis_env, prim, ctx.avals_in,
+        module_ctx.platforms[0], axis_env, prim, ctx.avals_in,
         ctx.avals_out, **params)
     xla_module = xla_computation_to_mlir_module(xla_computation)
     callee_name = merge_mlir_modules(

@@ -2301,7 +2290,9 @@ def emit_python_callback(
     result_layouts: Sequence[Sequence[int] | None] | None = None,
 ) -> tuple[Sequence[ir.Value], Any, Any]:
   """Emits MLIR that calls back to a provided Python function."""
-  platform = ctx.module_context.platform
+  if len(ctx.module_context.platforms) > 1:
+    raise NotImplementedError("multi-platform lowering for python_callback")
+  platform = ctx.module_context.platforms[0]
   if platform not in {"cpu", "cuda", "rocm", "tpu"}:
     raise ValueError(
         f"`EmitPythonCallback` not supported on {platform} backend.")

@@ -2423,7 +2414,8 @@ def emit_python_callback(
   return results, token, ifrt_callback

 def build_xla_computation_helper(
-    closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str,
+    closed_jaxpr: core.ClosedJaxpr, *, name: str,
+    platforms: Sequence[str],
     backend_or_name: str, axis_context: AxisContext) -> xc.XlaComputation:
   """Helper to generate pmap-style XLA computations for custom partitioners."""
   if closed_jaxpr.effects:

@@ -2432,7 +2424,7 @@ def build_xla_computation_helper(
       backend_or_name=backend_or_name, ordered_effects=[],
       name_stack=source_info_util.NameStack(),
       donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
-      axis_context=axis_context, platform=platform,
+      axis_context=axis_context, platforms=platforms,
       lowering_parameters=LoweringParameters())
   return xc._xla.mlir.mlir_module_to_xla_computation(
       module_to_string(lowering_result.module), use_tuple_args=False,

@@ -758,7 +758,7 @@ def lower_parallel_callable(
         closed_jaxpr,
         ordered_effects=ordered_effects,
         backend_or_name=backend,
-        platform=lowering_parameters.override_platform or backend.platform,
+        platforms=lowering_parameters.platforms or (backend.platform,),
         axis_context=sharding_impls.ReplicaAxisContext(axis_env),
         name_stack=name_stack,
         donated_args=donated_invars,

@@ -1362,7 +1362,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
                    call_jaxpr, backend=None, in_axes, out_axes,
                    donated_invars, is_explicit_global_axis_size):
   del donated_invars  # Unused.
-  mlir.check_backend_matches(backend, ctx.module_context.platform)
+  mlir.check_backend_matches(backend, ctx.module_context.platforms)
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
   if ctx.module_context.axis_env.names and devices is not None:

@@ -1835,7 +1835,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         ordered_effects=ordered_effects,
         backend_or_name=backend,
         # Optionally, override the lowering platform
-        platform=lowering_parameters.override_platform or backend.platform,
+        platforms=lowering_parameters.platforms or (backend.platform,),
         axis_context=axis_ctx,
         name_stack=name_stack,
         donated_args=donated_invars,

@@ -2201,7 +2201,7 @@ def lower_mesh_computation(
       closed_jaxpr,
       ordered_effects=ordered_effects,
       backend_or_name=backend,
-      platform=lowering_parameters.override_platform or backend.platform,
+      platforms=lowering_parameters.platforms or (backend.platform,),
       axis_context=axis_ctx,
       name_stack=name_stack,
       donated_args=donated_invars,

@@ -725,12 +725,7 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups):
           for arg, named_shape in zip(args, named_shapes)]

 def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
-  # TODO(necula): clean this up when we have module_context.platforms
-  if ctx.module_context.lowering_parameters.is_multi_platform:
-    for_tpu = ("tpu" in ctx.module_context.lowering_parameters.platforms)
-  else:
-    for_tpu = (ctx.module_context.platform == "tpu")
-  if axis_index_groups is not None and for_tpu:
+  if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
     len_0 = len(axis_index_groups[0])
     if any(len(g) != len_0 for g in axis_index_groups):
       raise ValueError("axis_index_groups must all be the same size for TPU lowering")

@@ -1305,7 +1305,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
                                 global_axis_sizes,
                                 spmd_in_axes, spmd_out_axes,
                                 axis_resources, resource_env, backend):
-  mlir.check_backend_matches(backend, ctx.module_context.platform)
+  mlir.check_backend_matches(backend, ctx.module_context.platforms)
   del backend, donated_invars
   # The only way for any of those two assertions to be violated is when xmap
   # is using the SPMD lowering, but then this rule shouldn't even trigger.
   assert spmd_in_axes is None and spmd_out_axes is None

@@ -1381,7 +1382,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
                              donated_invars, global_axis_sizes, spmd_in_axes,
                              spmd_out_axes, axis_resources,
                              resource_env, backend):
-  mlir.check_backend_matches(backend, ctx.module_context.platform)
+  mlir.check_backend_matches(backend, ctx.module_context.platforms)
   del backend, donated_invars
   plan = EvaluationPlan.from_axis_resources(
       axis_resources, resource_env, global_axis_sizes)

@@ -1447,9 +1449,10 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
                                     donated_invars, global_axis_sizes, spmd_in_axes,
                                     spmd_out_axes, axis_resources,
                                     resource_env, backend):
   del donated_invars
   assert spmd_in_axes is None and spmd_out_axes is None
   # This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule.
-  mlir.check_backend_matches(backend, ctx.module_context.platform)
+  mlir.check_backend_matches(backend, ctx.module_context.platforms)
   plan = EvaluationPlan.from_axis_resources(
       axis_resources, resource_env, global_axis_sizes)
   manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))

@@ -1503,7 +1503,9 @@ def pallas_call_lowering(
       **compiler_params
   )
   num_warps = compiler_params.get("num_warps", 4)
-  if ctx.module_context.platform == 'rocm':
+  if len(ctx.module_context.platforms) > 1:
+    raise NotImplementedError("multi-platform lowering for Pallas kernels")
+  if ctx.module_context.platforms[0] == 'rocm':
     num_stages = compiler_params.get("num_stages", 1)
   else:
     num_stages = compiler_params.get("num_stages", 3)

@@ -1521,7 +1523,7 @@ def pallas_call_lowering(
       debug=debug,
   )
   #Triton returns a tuple for ROCm. We just want file path to be passed
-  if ctx.module_context.platform == 'rocm':
+  if ctx.module_context.platforms[0] == 'rocm':
     compilation_result.ptx = compilation_result.ptx[1]

   if debug:

@@ -181,7 +181,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
   built = mlir.build_xla_computation_helper(
       closed_jaxpr,
       name="tmp_xla_computation",
-      platform=module_context.platform,
+      platforms=module_context.platforms,
       backend_or_name=module_context.backend_or_name,
       axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
   )

@@ -689,7 +689,7 @@ def _wrap_main_func(
   with ir.InsertionPoint(entry_block):
     # Make a context just for lowering the dimension value computations
     module_context = mlir.ModuleContext(
-        backend_or_name="cpu", platform="cpu",
+        backend_or_name="cpu", platforms=["cpu"],
         axis_context=sharding_impls.ShardingContext([]),
         name_stack=source_info_util.new_name_stack(),
         keepalives=[], channel_iterator=itertools.count(1),

@@ -1179,11 +1179,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
   submodule_args = []
   # All the platforms for the current lowering must be among the platforms
   # for which the callee was lowered.
-  if ctx.module_context.lowering_parameters.is_multi_platform:
-    assert ctx.module_context.lowering_parameters.platforms is not None
-    lowering_platforms = ctx.module_context.lowering_parameters.platforms
-  else:
-    lowering_platforms = (ctx.module_context.platform,)
+  lowering_platforms = ctx.module_context.platforms

   callee_lowering_platform_index: list[int] = []
   for platform in lowering_platforms:

@@ -1101,7 +1101,9 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
                            flat_results_aval=(),
                            **params):
   """MLIR Lowering for `CustomCall`-based HCB."""
-  platform = ctx.module_context.platform
+  if len(ctx.module_context.platforms) > 1:
+    raise NotImplementedError("multi-platform lowering for host_callback")
+  platform = ctx.module_context.platforms[0]
   use_outfeed = _use_outfeed(platform)
   if use_outfeed:
     # Fall back to XLA path if we are using the outfeed

@@ -1118,7 +1120,7 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
   else:
     # TODO(necula): It seems that on CPU, with custom call, the device_index
     # does not work, and the callback is always run on device_index=0
-    if (device_index != 0 and ctx.module_context.platform == "cpu"):
+    if (device_index != 0 and "cpu" in ctx.module_context.platforms):
       raise ValueError(
           "The device_index feature on CPU works only when using outfeed.")
   # We expect the current tokens at the end, inserted by _rewrite_jaxpr.