Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00

Merge pull request #17827 from gnecula:lowering_params

PiperOrigin-RevId: 569392664

This commit is contained in: commit f94bbc18ac
@@ -582,7 +582,8 @@ def xla_computation(fun: Callable,
                                wrap_name(fun_name, "xla_computation")),
         donated_args=donated_invars,
         arg_shardings=None,
-        result_shardings=None)
+        result_shardings=None,
+        lowering_parameters=mlir.LoweringParameters())
     built = xc._xla.mlir.mlir_module_to_xla_computation(
         mlir.module_to_string(lowering_result.module),
         use_tuple_args=tuple_args,
@@ -1904,8 +1905,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
   Returns:
     A ``Lowered`` instance representing the post-map lowering.
   """
-  _experimental_lowering_platform = kwargs.pop(
-      '_experimental_lowering_platform', None)
+  lowering_parameters = kwargs.pop(
+      '_experimental_lowering_parameters', mlir.LoweringParameters())
   p = _prepare_pmap(
       fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
       devices, backend, axis_size, args, kwargs)
@@ -1920,7 +1921,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
       donated_invars=p.donated_invars,
       is_explicit_global_axis_size=p.is_explicit_global_axis_size,
       avals=abstract_args,
-      lowering_platform=_experimental_lowering_platform)
+      lowering_parameters=lowering_parameters)
   return stages.Lowered.from_flat_info(
       computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())

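For orientation, here is a minimal sketch of how a caller might exercise the renamed private keyword on the pmap path at this revision. The keyword name and its default come from the hunks above; the toy function, input, and printing are illustrative only and not part of the change.

import jax
import jax.numpy as jnp
from jax._src.interpreters import mlir

# A trivial function to lower; any pmappable function would do.
pmapped = jax.pmap(lambda x: x * 2)
x = jnp.arange(jax.local_device_count(), dtype=jnp.float32)

# The private kwarg `_experimental_lowering_parameters` replaces the old
# `_experimental_lowering_platform`; its default is an empty LoweringParameters().
lowered = pmapped.lower(
    x, _experimental_lowering_parameters=mlir.LoweringParameters())
print(lowered.as_text()[:120])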
@@ -149,7 +149,7 @@ def xla_primitive_callable(
   computation = sharded_lowering(
       flat_fun, prim.name, donated_invars, keep_unused=False,
       inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
-      lowering_platform=None)
+      lowering_parameters=mlir.LoweringParameters())
   compiled = computation.compile()
   if xla_extension_version >= 192:
     if config.jax_disable_jit:
@@ -169,7 +169,8 @@ def xla_primitive_callable(
 def sharded_lowering(
     fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
     keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
-    in_shardings: Sequence[Sharding | None], lowering_platform: str | None
+    in_shardings: Sequence[Sharding | None],
+    lowering_parameters: mlir.LoweringParameters
 ) -> pxla.MeshComputation:
   in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]

@@ -179,7 +180,8 @@ def sharded_lowering(
   return pxla.lower_sharding_computation(
       fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
       in_avals, keep_unused=keep_unused, inline=inline,
-      devices_from_context=None, lowering_platform=lowering_platform)
+      devices_from_context=None,
+      lowering_parameters=lowering_parameters)


 def simple_impl(prim):
@@ -390,12 +390,6 @@ AxisContext = Union[
 ]

 class ShapePolyLoweringState:
-  # The current lowering platforms, a non-empty tuple containing some of
-  # 'cpu', 'cuda', 'rocm', 'tpu'.
-  # TODO: this state should be in ModuleContext, but since for now
-  # multi-platform lowering is implemented only for jax_export, like shape
-  # polymorphism, we keep it here.
-  lowering_platforms: tuple[str, ...]
   # The names of the dimension variables, sorted by name. This is the order in
   # which they are passed to the IR functions that need them. This is only
   # used for native serialization with polymorphic shapes when
@@ -410,17 +404,48 @@ class ShapePolyLoweringState:
   # from an inner call to a polymorphic Exported.
   uses_dim_vars: bool

-  def __init__(self, dim_vars: tuple[str, ...],
-               lowering_platforms: tuple[str, ...]):
-    self.lowering_platforms = lowering_platforms
+  # If the first dimension variable is a platform index argument
+  has_platform_index_argument: bool
+
+  def __init__(self,
+               dim_vars: tuple[str, ...],
+               lowering_platforms: tuple[str, ...] | None):
     self.uses_dim_vars = (len(dim_vars) > 0)
-    if len(lowering_platforms) > 1:
+    if lowering_platforms is not None and len(lowering_platforms) > 1:
       dim_vars = ("platform_index_",) + tuple(dim_vars)
+      self.has_platform_index_argument = True
+    else:
+      self.has_platform_index_argument = False
     self.dim_vars = dim_vars


+@dataclasses.dataclass(frozen=True)
+class LoweringParameters:
+  # A mapping between primitives and user-defined LoweringRules.
+  # When lowering a primitive, give priority to the rule in this map over
+  # existing Jax rules.
+  override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None
+
+  # The current lowering platforms, a non-empty tuple containing some of
+  # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
+  # doing multi-platform lowering, otherwise it can specify cross-platform
+  # lowering. The value None specifies the default lowering platform.
+  # This is used only in export and jax2tf.
+  platforms: tuple[str, ...] | None = None
+
   @property
-  def has_platform_index_argument(self):
-    return len(self.lowering_platforms) > 1
+  def override_platform(self) -> str | None:
+    """Overrides the lowering platform for cross-platform lowering.
+
+    One of 'cpu', 'cuda', 'rocm', 'tpu'.
+    If None, use the default JAX mechanisms to pick the lowering platform.
+    This is currently used for export and jax2tf.
+    """
+    if self.platforms is not None:
+      return self.platforms[0]
+    else:
+      return None


 @dataclasses.dataclass
 class ModuleContext:
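A small sketch of how the new dataclass behaves, based only on the definition above. The import path `jax._src.interpreters.mlir` is the module being patched; at this revision the class is also re-exported from `jax.interpreters.mlir` (see the re-export hunk below).

from jax._src.interpreters import mlir

# Default: no platform override and no rule overrides.
default_params = mlir.LoweringParameters()
assert default_params.platforms is None
assert default_params.override_platform is None

# Cross-platform lowering: the first entry of `platforms` wins.
cross = mlir.LoweringParameters(platforms=("tpu",))
assert cross.override_platform == "tpu"

# Multi-platform lowering (export/jax2tf only): several platforms at once;
# override_platform still reports the first one.
multi = mlir.LoweringParameters(platforms=("cpu", "tpu"))
assert multi.override_platform == "cpu"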
@@ -443,10 +468,7 @@ class ModuleContext:
   cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
   cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]

-  # A mapping between primitives and user-defined LoweringRules.
-  # When lowering a primitive, give priority to the rule in this map over
-  # existing Jax rules.
-  override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None
+  lowering_parameters: LoweringParameters

   @property
   def axis_env(self) -> sharding_impls.AxisEnv:
@@ -454,6 +476,7 @@ class ModuleContext:

   def __init__(
       self,
+      *,
       backend_or_name: str | xb.XlaBackend | None,
       platform: str,
       axis_context: AxisContext,
@@ -461,6 +484,7 @@ class ModuleContext:
       keepalives: list[Any],
       channel_iterator: Iterator[int],
       host_callbacks: list[Any],
+      lowering_parameters: LoweringParameters,
       context: ir.Context | None = None,
       module: ir.Module | None = None,
       ip: ir.InsertionPoint | None = None,
@@ -469,8 +493,6 @@ class ModuleContext:
                                        func_dialect.FuncOp]) = None,
       cached_call_jaxpr_lowerings: None | (dict[Any,
                                            func_dialect.FuncOp]) = None,
-      override_lowering_rules: None | (
-          tuple[tuple[core.Primitive, LoweringRule]]) = None,
       shape_poly_state = None):
     assert platform is not None
     self.context = context or make_ir_context()
@@ -489,9 +511,9 @@ class ModuleContext:
     self.cached_call_jaxpr_lowerings = ({}
                                         if cached_call_jaxpr_lowerings is None
                                         else cached_call_jaxpr_lowerings)
-    self.override_lowering_rules = override_lowering_rules
     self.shape_poly_state = shape_poly_state or ShapePolyLoweringState((),
                                                                        (platform,))
+    self.lowering_parameters = lowering_parameters

   @property
   def backend(self) -> xb.XlaBackend:
@@ -664,6 +686,7 @@ def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,
+    *,
     ordered_effects: list[core.Effect],
     backend_or_name: str | xb.XlaBackend | None,
     platform: str | tuple[str, ...],
@@ -678,24 +701,19 @@ def lower_jaxpr_to_module(
     num_replicas: int = 1,
     num_partitions: int = 1,
     all_default_mem_kind: bool = True,
-    override_lowering_rules: None | (
-        tuple[tuple[core.Primitive, LoweringRule]]) = None,
+    lowering_parameters: LoweringParameters,
 ) -> LoweringResult:
   """Lowers a top-level jaxpr to an MLIR module.

   Handles the quirks of the argument/return value passing conventions of the
   runtime.
   """
-  # TODO(necula): for now we receive the tuple of lowering platforms through
-  # the `platform` arg. For now we lower only for the first specified platform
-  # TODO(necula): change to "platforms" here and elsewhere.
-  if isinstance(platform, str):
-    platforms = (platform,)
-  else:
-    platforms = tuple(platform)  # type: ignore
-    platform = platform[0]
+  if lowering_parameters.platforms is not None:
+    # Only for multi-platform lowering
+    # TODO(necula): for now we lower only for the first platform
+    platform = lowering_parameters.platforms[0]

-  platform = xb.canonicalize_platform(platform)
+  platform = xb.canonicalize_platform(platform)  # type: ignore
   if not xb.is_known_platform(platform):
     raise ValueError(f"Unknown platform {platform}")
   input_output_aliases = None
@@ -750,11 +768,16 @@ def lower_jaxpr_to_module(
       map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
       if result_shardings is not None else result_shardings)

-  ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
-                      keepalives, channel_iter, host_callbacks,
-                      override_lowering_rules=override_lowering_rules,
-                      shape_poly_state=ShapePolyLoweringState(dim_vars,
-                                                              platforms))
+  ctx = ModuleContext(backend_or_name=backend_or_name,
+                      platform=platform, axis_context=axis_context,
+                      name_stack=name_stack,
+                      keepalives=keepalives,
+                      channel_iterator=channel_iter,
+                      host_callbacks=host_callbacks,
+                      lowering_parameters=lowering_parameters,
+                      shape_poly_state=ShapePolyLoweringState(
+                          dim_vars,
+                          lowering_parameters.platforms))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
@@ -1292,9 +1315,9 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       env[v] = tuple(node)

   def get_lowering(primitive: core.Primitive) -> LoweringRule | None:
-    if ctx.override_lowering_rules is None:
+    if ctx.lowering_parameters.override_lowering_rules is None:
       return None
-    for p, rule in ctx.override_lowering_rules:
+    for p, rule in ctx.lowering_parameters.override_lowering_rules:
       if primitive is p:
         return rule
     return None
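The lookup above is a linear scan keyed by primitive identity, now read from `ctx.lowering_parameters`. Below is a standalone sketch of the same pattern; the `Primitive` stand-in class, `sin_p`, and `noop_rule` are illustrative placeholders, not JAX objects.

from typing import Callable, Optional, Sequence, Tuple

# Stand-ins for core.Primitive and LoweringRule, for illustration only.
class Primitive:
  def __init__(self, name: str):
    self.name = name

LoweringRule = Callable[..., Sequence]

def get_override(
    overrides: Optional[Tuple[Tuple[Primitive, LoweringRule], ...]],
    primitive: Primitive) -> Optional[LoweringRule]:
  # Mirrors get_lowering: None means "no overrides configured"; otherwise the
  # first pair whose primitive is the same object wins.
  if overrides is None:
    return None
  for p, rule in overrides:
    if primitive is p:
      return rule
  return None

sin_p = Primitive("sin")
noop_rule = lambda ctx, *args, **kw: list(args)  # hypothetical rule
assert get_override(((sin_p, noop_rule),), sin_p) is noop_rule
assert get_override(None, sin_p) is None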
@@ -2187,7 +2210,8 @@ def build_xla_computation_helper(
       backend_or_name=backend_or_name, ordered_effects=[],
       name_stack=source_info_util.NameStack(),
       donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
-      axis_context=axis_context, platform=platform)
+      axis_context=axis_context, platform=platform,
+      lowering_parameters=LoweringParameters())
   return xc._xla.mlir.mlir_module_to_xla_computation(
       module_to_string(lowering_result.module), use_tuple_args=False,
       return_tuple=False)
@@ -560,7 +560,8 @@ def parallel_callable(fun: lu.WrappedFun,
   pmap_computation = lower_parallel_callable(
       fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, out_axes_thunk, donated_invars,
-      is_explicit_global_axis_size, avals, lowering_platform=None)
+      is_explicit_global_axis_size, avals,
+      lowering_parameters=mlir.LoweringParameters())
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])

@@ -661,7 +662,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
-    lowering_platform: str | None):
+    lowering_parameters: mlir.LoweringParameters):
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
   # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -755,18 +756,19 @@ def lower_parallel_callable(
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
-      ordered_effects,
-      backend,
-      lowering_platform or backend.platform,
-      sharding_impls.ReplicaAxisContext(axis_env),
-      name_stack,
-      donated_invars,
+      ordered_effects=ordered_effects,
+      backend_or_name=backend,
+      platform=lowering_parameters.override_platform or backend.platform,
+      axis_context=sharding_impls.ReplicaAxisContext(axis_env),
+      name_stack=name_stack,
+      donated_args=donated_invars,
       replicated_args=replicated_args,
       arg_shardings=None,
       result_shardings=None,
       arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
       result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
-      num_replicas=replicas.num_global_replicas)
+      num_replicas=replicas.num_global_replicas,
+      lowering_parameters=lowering_parameters)
   return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
                          shards=shards, tuple_args=tuple_args,
                          unordered_effects=unordered_effects,
@@ -1784,9 +1786,9 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 @weakref_lru_cache
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
-                            da_object, lowering_platform,
+                            da_object,
                             donated_invars, name_stack, all_default_mem_kind,
-                            override_lowering_rules):
+                            lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1848,13 +1850,13 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
-      ordered_effects,
-      backend,
+      ordered_effects=ordered_effects,
+      backend_or_name=backend,
       # Optionally, override the lowering platform
-      lowering_platform or backend.platform,
-      axis_ctx,
-      name_stack,
-      donated_invars,
+      platform=lowering_parameters.override_platform or backend.platform,
+      axis_context=axis_ctx,
+      name_stack=name_stack,
+      donated_args=donated_invars,
       replicated_args=replicated_args,
       arg_shardings=in_mlir_shardings,
       result_shardings=out_mlir_shardings,
@@ -1863,7 +1865,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       num_replicas=nreps,
       num_partitions=num_partitions,
       all_default_mem_kind=all_default_mem_kind,
-      override_lowering_rules=override_lowering_rules)
+      lowering_parameters=lowering_parameters)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@@ -1969,9 +1971,7 @@ def lower_sharding_computation(
     keep_unused: bool,
     inline: bool,
     devices_from_context: Sequence[xc.Device] | None = None,
-    lowering_platform: str | None,
-    override_lowering_rules: None | (
-        tuple[tuple[core.Primitive, mlir.LoweringRule]]) = None,
+    lowering_parameters: mlir.LoweringParameters,
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.

@@ -2048,8 +2048,9 @@ def lower_sharding_computation(
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-       semantic_out_shardings, da_object, lowering_platform,
-       donated_invars, name_stack, all_default_mem_kind, override_lowering_rules)
+       semantic_out_shardings, da_object,
+       donated_invars, name_stack, all_default_mem_kind,
+       lowering_parameters=lowering_parameters)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2108,7 +2109,7 @@ def lower_mesh_computation(
     spmd_lowering: bool,
     global_in_avals: Sequence[core.ShapedArray],
     tiling_method: TilingMethod | None,
-    lowering_platform: str | None) -> MeshComputation:
+    lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
@@ -2216,19 +2217,20 @@ def lower_mesh_computation(
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
-      ordered_effects,
-      backend,
-      lowering_platform or backend.platform,
-      axis_ctx,
-      name_stack,
-      donated_invars,
+      ordered_effects=ordered_effects,
+      backend_or_name=backend,
+      platform=lowering_parameters.platforms or backend.platform,
+      axis_context=axis_ctx,
+      name_stack=name_stack,
+      donated_args=donated_invars,
       replicated_args=replicated_args,
       arg_shardings=in_partitions,
       result_shardings=out_partitions,
       arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
       result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
       num_replicas=num_replicas,
-      num_partitions=num_partitions)
+      num_partitions=num_partitions,
+      lowering_parameters=lowering_parameters)

   return MeshComputation(
       str(name_stack),
@@ -605,8 +605,8 @@ def xmap(fun: Callable,

   @decorate_serial
   def lower(*args, **kwargs):
-    _experimental_lowering_platform = kwargs.pop(
-        '_experimental_lowering_platform', None)
+    lowering_parameters = kwargs.pop(
+        '_experimental_lowering_platform', mlir.LoweringParameters())
     fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
     avals_flat = [shaped_abstractify(arg) for arg in args_flat]
     computation = make_xmap_callable(
@@ -614,7 +614,7 @@ def xmap(fun: Callable,
         params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
         params['resource_env'], params['backend'], params['spmd_in_axes'],
         params['spmd_out_axes_thunk'],
-        _experimental_lowering_platform, *avals_flat)
+        lowering_parameters, *avals_flat)

     in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
     in_avals = in_tree.unflatten(avals_flat)
@@ -633,7 +633,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
       fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
       axis_resources, resource_env, backend,
       spmd_in_axes, spmd_out_axes_thunk,
-      None, *in_avals).compile().unsafe_call
+      mlir.LoweringParameters(), *in_avals).compile().unsafe_call
   distributed_debug_log(("Running xmapped function", name),
                         ("python function", fun.f),
                         ("mesh", resource_env.physical_mesh),
@@ -646,7 +646,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
                        in_axes, out_axes_thunk, donated_invars,
                        global_axis_sizes, axis_resources, resource_env, backend,
                        spmd_in_axes, spmd_out_axes_thunk,
-                       lowering_platform: Optional[str],
+                       lowering_parameters: mlir.LoweringParameters,
                        *in_avals):
   plan = EvaluationPlan.from_axis_resources(
       axis_resources, resource_env, global_axis_sizes)
@@ -700,11 +700,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
         in_shardings, out_shardings, donated_invars,
         use_spmd_lowering, in_avals,
         tiling_method=tiling_method,
-        lowering_platform=lowering_platform)
+        lowering_parameters=lowering_parameters)
   else:
     return dispatch.sharded_lowering(
         f, name, donated_invars, True, False, in_avals, (None,) * len(in_avals),
-        lowering_platform=lowering_platform)
+        lowering_parameters=lowering_parameters)


 class EvaluationPlan(NamedTuple):
@@ -325,10 +325,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,

   @api_boundary
   def lower(*args, **kwargs):
-    _experimental_lowering_platform = kwargs.pop(
-        '_experimental_lowering_platform', None)
-    _experimental_override_lowering_rules = kwargs.pop(
-        '_experimental_override_lowering_rules', None)
+    lowering_parameters = kwargs.pop(
+        '_experimental_lowering_parameters', mlir.LoweringParameters())
     (args_flat, flat_global_in_avals, params, in_tree, out_tree,
      donated_invars) = infer_params_fn(*args, **kwargs)
     resource_env = params['resource_env']
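From the user side, the two old private keywords collapse into one. A hedged before/after sketch at this revision; the cross-platform 'tpu' target is illustrative and assumes cross-lowering to TPU is supported for this function on the installed backend.

import jax
import jax.numpy as jnp
from jax._src.interpreters import mlir

f = jax.jit(lambda x: jnp.sin(x) + 1.0)
x = jnp.arange(4.0)

# Before this change (old private kwargs):
#   f.lower(x, _experimental_lowering_platform='tpu',
#           _experimental_override_lowering_rules=None)
# After this change, both settings ride on one LoweringParameters object:
lowered = f.lower(
    x,
    _experimental_lowering_parameters=mlir.LoweringParameters(
        platforms=('tpu',)))
print(lowered.as_text()[:120])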
@@ -340,8 +338,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
           params['jaxpr'], in_shardings, params['out_shardings'],
           params['resource_env'], params['donated_invars'], params['name'],
           params['keep_unused'], params['inline'],
-          lowering_platform=_experimental_lowering_platform,
-          override_lowering_rules=_experimental_override_lowering_rules)
+          lowering_parameters=lowering_parameters)
     except pxla.DeviceAssignmentMismatchError as e:
       fails, = e.args
       api_name = 'jit' if params['resource_env'] is None else 'pjit'
@@ -1131,7 +1128,7 @@ def _pjit_call_impl_python(
   compiled = _pjit_lower(
       jaxpr, in_shardings, out_shardings, resource_env,
       donated_invars, name, keep_unused, inline,
-      lowering_platform=None).compile()
+      lowering_parameters=mlir.LoweringParameters()).compile()
   _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
   # This check is expensive so only do it if enable_checks is on.
   if compiled._auto_spmd_lowering and config.jax_enable_checks:
@@ -1273,9 +1270,7 @@ def _pjit_lower_cached(
     keep_unused: bool,
     inline: bool,
     *,
-    lowering_platform: Optional[str],
-    override_lowering_rules: Optional[
-        tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None):
+    lowering_parameters: mlir.LoweringParameters):
   in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
       tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
   out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
@@ -1298,7 +1293,7 @@ def _pjit_lower_cached(
         jaxpr, api_name, name, mesh,
         in_shardings, out_shardings, donated_invars,
         True, jaxpr.in_avals, tiling_method=None,
-        lowering_platform=lowering_platform)
+        lowering_parameters=lowering_parameters)
   else:
     return pxla.lower_sharding_computation(
         jaxpr, api_name, name, in_shardings, out_shardings,
@@ -1306,8 +1301,7 @@ def _pjit_lower_cached(
         keep_unused=keep_unused, inline=inline,
         devices_from_context=(
             None if mesh is None or mesh.empty else list(mesh.devices.flat)),
-        lowering_platform=lowering_platform,
-        override_lowering_rules=override_lowering_rules,
+        lowering_parameters=lowering_parameters,
     )

@@ -421,7 +421,9 @@ def export(fun_jax: Callable,
     shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
     lowered = wrapped_fun_jax.lower(
         *args_specs, **kwargs_specs,
-        _experimental_lowering_platform=lowering_platforms)
+        _experimental_lowering_parameters=mlir.LoweringParameters(
+            platforms=lowering_platforms,
+        ))

     lowering = lowered._lowering  # type: ignore
     _check_lowering(lowering)
@@ -601,9 +603,12 @@ def _wrap_main_func(
   entry_block = new_main_op.add_entry_block()
   with ir.InsertionPoint(entry_block):
     module_context = mlir.ModuleContext(
-        "cpu", "cpu", sharding_impls.ShardingContext([]),
-        source_info_util.new_name_stack(),
-        [], itertools.count(1), [], module=wrapped_module, context=context)
+        backend_or_name="cpu", platform="cpu",
+        axis_context=sharding_impls.ShardingContext([]),
+        name_stack=source_info_util.new_name_stack(),
+        keepalives=[], channel_iterator=itertools.count(1),
+        host_callbacks=[], module=wrapped_module, context=context,
+        lowering_parameters=mlir.LoweringParameters())
     ctx = mlir.LoweringRuleContext(
         module_context=module_context, primitive=None,
         avals_in=args_avals_flat, avals_out=None,
@@ -16,6 +16,7 @@ from jax._src.interpreters.mlir import (
   AxisContext as AxisContext,
   ConstantHandler as ConstantHandler,
   DEVICE_TO_DEVICE_TYPE as DEVICE_TO_DEVICE_TYPE,
+  LoweringParameters as LoweringParameters,
   LoweringResult as LoweringResult,
   LoweringRule as LoweringRule,
   LoweringRuleContext as LoweringRuleContext,
@@ -10347,7 +10347,8 @@ class OverrideLoweringTest(jtu.JaxTestCase):
     lowered_ir = (
         jax.jit(f)
        .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
-               _experimental_override_lowering_rules=rules).as_text())
+               _experimental_lowering_parameters=mlir.LoweringParameters(
+                   override_lowering_rules=rules)).as_text())
     self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)

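The test hunk only shows how `rules` is passed; its definition is outside the hunk. Below is a sketch of the expected shape of such a rules tuple. The specific primitive (`sharding_constraint_p`, assumed to live in jax._src.pjit at this revision) and the no-op rule are assumptions for illustration, not the test's actual code.

import jax
from jax._src import pjit as pjit_lib  # assumed location of sharding_constraint_p

def _wsc_as_noop(ctx, operand, *args, **kwargs):
  # An override LoweringRule has the usual (ctx, *args, **params) signature and
  # returns the lowered results; this one just forwards the operand, so the
  # Sharding custom call never appears in the emitted StableHLO.
  return [operand]

# A tuple of (primitive, rule) pairs, as expected by override_lowering_rules.
rules = ((pjit_lib.sharding_constraint_p, _wsc_as_noop),)

def f(x):
  return jax.lax.with_sharding_constraint(
      x, jax.sharding.SingleDeviceSharding(jax.devices()[0]))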
@@ -570,7 +570,7 @@ class JaxExportTest(jtu.JaxTestCase):
     x = np.arange(5, dtype=np.float32)
     # TODO: use a function with different behavior for different platforms
     exp = export.export(jnp.sin,
-                       lowering_platforms=('cpu', 'tpu'))(x)
+                        lowering_platforms=('cpu', 'tpu'))(x)
     self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
     module_str = str(exp.mlir_module())
     platform_index = re.findall(