Merge pull request #17827 from gnecula:lowering_params

PiperOrigin-RevId: 569392664
This commit is contained in:
jax authors 2023-09-28 23:08:28 -07:00
commit f94bbc18ac
10 changed files with 132 additions and 102 deletions

View File

@ -582,7 +582,8 @@ def xla_computation(fun: Callable,
wrap_name(fun_name, "xla_computation")),
donated_args=donated_invars,
arg_shardings=None,
result_shardings=None)
result_shardings=None,
lowering_parameters=mlir.LoweringParameters())
built = xc._xla.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(lowering_result.module),
use_tuple_args=tuple_args,
@ -1904,8 +1905,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
Returns:
A ``Lowered`` instance representing the post-map lowering.
"""
_experimental_lowering_platform = kwargs.pop(
'_experimental_lowering_platform', None)
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
devices, backend, axis_size, args, kwargs)
@ -1920,7 +1921,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
donated_invars=p.donated_invars,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_platform=_experimental_lowering_platform)
lowering_parameters=lowering_parameters)
return stages.Lowered.from_flat_info(
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())

View File

@ -149,7 +149,7 @@ def xla_primitive_callable(
computation = sharded_lowering(
flat_fun, prim.name, donated_invars, keep_unused=False,
inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
lowering_platform=None)
lowering_parameters=mlir.LoweringParameters())
compiled = computation.compile()
if xla_extension_version >= 192:
if config.jax_disable_jit:
@ -169,7 +169,8 @@ def xla_primitive_callable(
def sharded_lowering(
fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
in_shardings: Sequence[Sharding | None], lowering_platform: str | None
in_shardings: Sequence[Sharding | None],
lowering_parameters: mlir.LoweringParameters
) -> pxla.MeshComputation:
in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]
@ -179,7 +180,8 @@ def sharded_lowering(
return pxla.lower_sharding_computation(
fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
in_avals, keep_unused=keep_unused, inline=inline,
devices_from_context=None, lowering_platform=lowering_platform)
devices_from_context=None,
lowering_parameters=lowering_parameters)
def simple_impl(prim):

View File

@ -390,12 +390,6 @@ AxisContext = Union[
]
class ShapePolyLoweringState:
# The current lowering platforms, a non-empty tuple containing some of
# 'cpu', 'cuda', 'rocm', 'tpu'.
# TODO: this state should be in ModuleContext, but since for now
# multi-platform lowering is implemented only for jax_export, like shape
# polymorphism, we keep it here.
lowering_platforms: tuple[str, ...]
# The names of the dimension variables, sorted by name. This is the order in
# which they are passed to the IR functions that need them. This is only
# used for native serialization with polymorphic shapes when
@ -410,17 +404,48 @@ class ShapePolyLoweringState:
# from an inner call to a polymorphic Exported.
uses_dim_vars: bool
def __init__(self, dim_vars: tuple[str, ...],
lowering_platforms: tuple[str, ...]):
self.lowering_platforms = lowering_platforms
# If the first dimension variable is a platform index argument
has_platform_index_argument: bool
def __init__(self,
dim_vars: tuple[str, ...],
lowering_platforms: tuple[str, ...] | None):
self.uses_dim_vars = (len(dim_vars) > 0)
if len(lowering_platforms) > 1:
if lowering_platforms is not None and len(lowering_platforms) > 1:
dim_vars = ("platform_index_",) + tuple(dim_vars)
self.has_platform_index_argument = True
else:
self.has_platform_index_argument = False
self.dim_vars = dim_vars
@dataclasses.dataclass(frozen=True)
class LoweringParameters:
# A mapping between primitives and user-defined LoweringRules.
# When lowering a primitive, give priorioty to the rule in this map over
# existing Jax rules.
override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None
# The current lowering platforms, a non-empty tuple containing some of
# 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
# doing multi-platform lowering, otherwise it can specify cross-platform
# lowering. The value None specify default lowering platform.
# This is used only in export and jax2tf.
platforms: tuple[str, ...] | None = None
@property
def has_platform_index_argument(self):
return len(self.lowering_platforms) > 1
def override_platform(self) -> str | None:
"""Overrides the lowering platform for cross-platform lowering.
One of 'cpu', 'cuda', 'rocm', 'tpu'.
If None, use the default JAX mechanisms to pick the lowering platform.
This is currently used for export and jax2tf.
"""
if self.platforms is not None:
return self.platforms[0]
else:
return None
@dataclasses.dataclass
class ModuleContext:
@ -443,10 +468,7 @@ class ModuleContext:
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]
# A mapping between primitives and user-defined LoweringRules.
# When lowering a primitive, give priorioty to the rule in this map over
# existing Jax rules.
override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None
lowering_parameters: LoweringParameters
@property
def axis_env(self) -> sharding_impls.AxisEnv:
@ -454,6 +476,7 @@ class ModuleContext:
def __init__(
self,
*,
backend_or_name: str | xb.XlaBackend | None,
platform: str,
axis_context: AxisContext,
@ -461,6 +484,7 @@ class ModuleContext:
keepalives: list[Any],
channel_iterator: Iterator[int],
host_callbacks: list[Any],
lowering_parameters: LoweringParameters,
context: ir.Context | None = None,
module: ir.Module | None = None,
ip: ir.InsertionPoint | None = None,
@ -469,8 +493,6 @@ class ModuleContext:
func_dialect.FuncOp]) = None,
cached_call_jaxpr_lowerings: None | (dict[Any,
func_dialect.FuncOp]) = None,
override_lowering_rules: None | (
tuple[tuple[core.Primitive, LoweringRule]]) = None,
shape_poly_state = None):
assert platform is not None
self.context = context or make_ir_context()
@ -489,9 +511,9 @@ class ModuleContext:
self.cached_call_jaxpr_lowerings = ({}
if cached_call_jaxpr_lowerings is None
else cached_call_jaxpr_lowerings)
self.override_lowering_rules = override_lowering_rules
self.shape_poly_state = shape_poly_state or ShapePolyLoweringState((),
(platform,))
self.lowering_parameters = lowering_parameters
@property
def backend(self) -> xb.XlaBackend:
@ -664,6 +686,7 @@ def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
*,
ordered_effects: list[core.Effect],
backend_or_name: str | xb.XlaBackend | None,
platform: str | tuple[str, ...],
@ -678,24 +701,19 @@ def lower_jaxpr_to_module(
num_replicas: int = 1,
num_partitions: int = 1,
all_default_mem_kind: bool = True,
override_lowering_rules: None | (
tuple[tuple[core.Primitive, LoweringRule]]) = None,
lowering_parameters: LoweringParameters,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
# TODO(necula): for now we receive the tuple of lowering platforms through
# the `platform` arg. For now we lower only for the first specified platform
# TODO(necula): change to "platforms" here and elsewhere.
if isinstance(platform, str):
platforms = (platform,)
else:
platforms = tuple(platform) # type: ignore
platform = platform[0]
if lowering_parameters.platforms is not None:
# Only for multi-platform lowering
# TODO(necula): for now we lower only for the first platform
platform = lowering_parameters.platforms[0]
platform = xb.canonicalize_platform(platform)
platform = xb.canonicalize_platform(platform) # type: ignore
if not xb.is_known_platform(platform):
raise ValueError(f"Unknown platform {platform}")
input_output_aliases = None
@ -750,11 +768,16 @@ def lower_jaxpr_to_module(
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks,
override_lowering_rules=override_lowering_rules,
shape_poly_state=ShapePolyLoweringState(dim_vars,
platforms))
ctx = ModuleContext(backend_or_name=backend_or_name,
platform=platform, axis_context=axis_context,
name_stack=name_stack,
keepalives=keepalives,
channel_iterator=channel_iter,
host_callbacks=host_callbacks,
lowering_parameters=lowering_parameters,
shape_poly_state=ShapePolyLoweringState(
dim_vars,
lowering_parameters.platforms))
with ctx.context, ir.Location.unknown(ctx.context):
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
@ -1292,9 +1315,9 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
env[v] = tuple(node)
def get_lowering(primitive: core.Primitive) -> LoweringRule | None:
if ctx.override_lowering_rules is None:
if ctx.lowering_parameters.override_lowering_rules is None:
return None
for p, rule in ctx.override_lowering_rules:
for p, rule in ctx.lowering_parameters.override_lowering_rules:
if primitive is p:
return rule
return None
@ -2187,7 +2210,8 @@ def build_xla_computation_helper(
backend_or_name=backend_or_name, ordered_effects=[],
name_stack=source_info_util.NameStack(),
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
axis_context=axis_context, platform=platform)
axis_context=axis_context, platform=platform,
lowering_parameters=LoweringParameters())
return xc._xla.mlir.mlir_module_to_xla_computation(
module_to_string(lowering_result.module), use_tuple_args=False,
return_tuple=False)

View File

@ -560,7 +560,8 @@ def parallel_callable(fun: lu.WrappedFun,
pmap_computation = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, avals, lowering_platform=None)
is_explicit_global_axis_size, avals,
lowering_parameters=mlir.LoweringParameters())
pmap_executable = pmap_computation.compile()
return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@ -661,7 +662,7 @@ def lower_parallel_callable(
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue],
*,
lowering_platform: str | None):
lowering_parameters: mlir.LoweringParameters):
# Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@ -755,18 +756,19 @@ def lower_parallel_callable(
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
lowering_platform or backend.platform,
sharding_impls.ReplicaAxisContext(axis_env),
name_stack,
donated_invars,
ordered_effects=ordered_effects,
backend_or_name=backend,
platform=lowering_parameters.override_platform or backend.platform,
axis_context=sharding_impls.ReplicaAxisContext(axis_env),
name_stack=name_stack,
donated_args=donated_invars,
replicated_args=replicated_args,
arg_shardings=None,
result_shardings=None,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=replicas.num_global_replicas)
num_replicas=replicas.num_global_replicas,
lowering_parameters=lowering_parameters)
return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
shards=shards, tuple_args=tuple_args,
unordered_effects=unordered_effects,
@ -1784,9 +1786,9 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
da_object, lowering_platform,
da_object,
donated_invars, name_stack, all_default_mem_kind,
override_lowering_rules):
lowering_parameters: mlir.LoweringParameters):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
@ -1848,13 +1850,13 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
ordered_effects=ordered_effects,
backend_or_name=backend,
# Optionally, override the lowering platform
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,
platform=lowering_parameters.override_platform or backend.platform,
axis_context=axis_ctx,
name_stack=name_stack,
donated_args=donated_invars,
replicated_args=replicated_args,
arg_shardings=in_mlir_shardings,
result_shardings=out_mlir_shardings,
@ -1863,7 +1865,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
num_replicas=nreps,
num_partitions=num_partitions,
all_default_mem_kind=all_default_mem_kind,
override_lowering_rules=override_lowering_rules)
lowering_parameters=lowering_parameters)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@ -1969,9 +1971,7 @@ def lower_sharding_computation(
keep_unused: bool,
inline: bool,
devices_from_context: Sequence[xc.Device] | None = None,
lowering_platform: str | None,
override_lowering_rules: None | (
tuple[tuple[core.Primitive, mlir.LoweringRule]]) = None,
lowering_parameters: mlir.LoweringParameters,
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
@ -2048,8 +2048,9 @@ def lower_sharding_computation(
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, da_object, lowering_platform,
donated_invars, name_stack, all_default_mem_kind, override_lowering_rules)
semantic_out_shardings, da_object,
donated_invars, name_stack, all_default_mem_kind,
lowering_parameters=lowering_parameters)
# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way
@ -2108,7 +2109,7 @@ def lower_mesh_computation(
spmd_lowering: bool,
global_in_avals: Sequence[core.ShapedArray],
tiling_method: TilingMethod | None,
lowering_platform: str | None) -> MeshComputation:
lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
@ -2216,19 +2217,20 @@ def lower_mesh_computation(
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,
ordered_effects=ordered_effects,
backend_or_name=backend,
platform=lowering_parameters.platforms or backend.platform,
axis_context=axis_ctx,
name_stack=name_stack,
donated_args=donated_invars,
replicated_args=replicated_args,
arg_shardings=in_partitions,
result_shardings=out_partitions,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=num_replicas,
num_partitions=num_partitions)
num_partitions=num_partitions,
lowering_parameters=lowering_parameters)
return MeshComputation(
str(name_stack),

View File

@ -605,8 +605,8 @@ def xmap(fun: Callable,
@decorate_serial
def lower(*args, **kwargs):
_experimental_lowering_platform = kwargs.pop(
'_experimental_lowering_platform', None)
lowering_parameters = kwargs.pop(
'_experimental_lowering_platform', mlir.LoweringParameters())
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
computation = make_xmap_callable(
@ -614,7 +614,7 @@ def xmap(fun: Callable,
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
params['resource_env'], params['backend'], params['spmd_in_axes'],
params['spmd_out_axes_thunk'],
_experimental_lowering_platform, *avals_flat)
lowering_parameters, *avals_flat)
in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
in_avals = in_tree.unflatten(avals_flat)
@ -633,7 +633,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk,
None, *in_avals).compile().unsafe_call
mlir.LoweringParameters(), *in_avals).compile().unsafe_call
distributed_debug_log(("Running xmapped function", name),
("python function", fun.f),
("mesh", resource_env.physical_mesh),
@ -646,7 +646,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
in_axes, out_axes_thunk, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk,
lowering_platform: Optional[str],
lowering_parameters: mlir.LoweringParameters,
*in_avals):
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes)
@ -700,11 +700,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
in_shardings, out_shardings, donated_invars,
use_spmd_lowering, in_avals,
tiling_method=tiling_method,
lowering_platform=lowering_platform)
lowering_parameters=lowering_parameters)
else:
return dispatch.sharded_lowering(
f, name, donated_invars, True, False, in_avals, (None,) * len(in_avals),
lowering_platform=lowering_platform)
lowering_parameters=lowering_parameters)
class EvaluationPlan(NamedTuple):

View File

@ -325,10 +325,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
@api_boundary
def lower(*args, **kwargs):
_experimental_lowering_platform = kwargs.pop(
'_experimental_lowering_platform', None)
_experimental_override_lowering_rules = kwargs.pop(
'_experimental_override_lowering_rules', None)
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donated_invars) = infer_params_fn(*args, **kwargs)
resource_env = params['resource_env']
@ -340,8 +338,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
params['keep_unused'], params['inline'],
lowering_platform=_experimental_lowering_platform,
override_lowering_rules=_experimental_override_lowering_rules)
lowering_parameters=lowering_parameters)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
@ -1131,7 +1128,7 @@ def _pjit_call_impl_python(
compiled = _pjit_lower(
jaxpr, in_shardings, out_shardings, resource_env,
donated_invars, name, keep_unused, inline,
lowering_platform=None).compile()
lowering_parameters=mlir.LoweringParameters()).compile()
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.jax_enable_checks:
@ -1273,9 +1270,7 @@ def _pjit_lower_cached(
keep_unused: bool,
inline: bool,
*,
lowering_platform: Optional[str],
override_lowering_rules: Optional[
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None):
lowering_parameters: mlir.LoweringParameters):
in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
@ -1298,7 +1293,7 @@ def _pjit_lower_cached(
jaxpr, api_name, name, mesh,
in_shardings, out_shardings, donated_invars,
True, jaxpr.in_avals, tiling_method=None,
lowering_platform=lowering_platform)
lowering_parameters=lowering_parameters)
else:
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
@ -1306,8 +1301,7 @@ def _pjit_lower_cached(
keep_unused=keep_unused, inline=inline,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
lowering_platform=lowering_platform,
override_lowering_rules=override_lowering_rules,
lowering_parameters=lowering_parameters,
)

View File

@ -421,7 +421,9 @@ def export(fun_jax: Callable,
shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_platform=lowering_platforms)
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=lowering_platforms,
))
lowering = lowered._lowering # type: ignore
_check_lowering(lowering)
@ -601,9 +603,12 @@ def _wrap_main_func(
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
module_context = mlir.ModuleContext(
"cpu", "cpu", sharding_impls.ShardingContext([]),
source_info_util.new_name_stack(),
[], itertools.count(1), [], module=wrapped_module, context=context)
backend_or_name="cpu", platform="cpu",
axis_context=sharding_impls.ShardingContext([]),
name_stack=source_info_util.new_name_stack(),
keepalives=[], channel_iterator=itertools.count(1),
host_callbacks=[], module=wrapped_module, context=context,
lowering_parameters=mlir.LoweringParameters())
ctx = mlir.LoweringRuleContext(
module_context=module_context, primitive=None,
avals_in=args_avals_flat, avals_out=None,

View File

@ -16,6 +16,7 @@ from jax._src.interpreters.mlir import (
AxisContext as AxisContext,
ConstantHandler as ConstantHandler,
DEVICE_TO_DEVICE_TYPE as DEVICE_TO_DEVICE_TYPE,
LoweringParameters as LoweringParameters,
LoweringResult as LoweringResult,
LoweringRule as LoweringRule,
LoweringRuleContext as LoweringRuleContext,

View File

@ -10347,7 +10347,8 @@ class OverrideLoweringTest(jtu.JaxTestCase):
lowered_ir = (
jax.jit(f)
.lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
_experimental_override_lowering_rules=rules).as_text())
_experimental_lowering_parameters=mlir.LoweringParameters(
override_lowering_rules=rules)).as_text())
self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)

View File

@ -570,7 +570,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(5, dtype=np.float32)
# TODO: use a function with different behavior for different platforms
exp = export.export(jnp.sin,
lowering_platforms=('cpu', 'tpu'))(x)
lowering_platforms=('cpu', 'tpu'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
module_str = str(exp.mlir_module())
platform_index = re.findall(