diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index f80dbc008..6c366fb8f 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -420,6 +420,10 @@ class ModuleContext:
   cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
   cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]
+  # A mapping between primitives and user-defined LoweringRules.
+  # When lowering a primitive, give priority to the rule in this map over
+  # existing JAX rules.
+  override_lowering_rules: Optional[tuple[tuple[core.Primitive, LoweringRule]]]

   @property
   def axis_env(self) -> sharding_impls.AxisEnv:
@@ -442,6 +446,8 @@ class ModuleContext:
                                                 func_dialect.FuncOp]] = None,
       cached_call_jaxpr_lowerings: Optional[dict[Any,
                                                  func_dialect.FuncOp]] = None,
+      override_lowering_rules: Optional[
+          tuple[tuple[core.Primitive, LoweringRule]]] = None,
       shape_poly_state = None):
     assert platform is not None
     self.context = context or make_ir_context()
@@ -460,6 +466,7 @@ class ModuleContext:
     self.cached_call_jaxpr_lowerings = ({}
                                         if cached_call_jaxpr_lowerings is None
                                         else cached_call_jaxpr_lowerings)
+    self.override_lowering_rules = override_lowering_rules
     self.shape_poly_state = shape_poly_state or ShapePolyLoweringState(())

   @property
@@ -639,6 +646,8 @@ def lower_jaxpr_to_module(
     result_names: Optional[Sequence[Optional[str]]] = None,
     num_replicas: int = 1,
     num_partitions: int = 1,
+    override_lowering_rules: Optional[
+        tuple[tuple[core.Primitive, LoweringRule]]] = None,
 ) -> LoweringResult:
   """Lowers a top-level jaxpr to an MLIR module.

@@ -691,6 +700,7 @@ def lower_jaxpr_to_module(
   ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
                       keepalives, channel_iter, host_callbacks,
+                      override_lowering_rules=override_lowering_rules,
                       shape_poly_state=ShapePolyLoweringState(dim_vars))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
@@ -1135,6 +1145,13 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
     assert node is not None
     env[v] = tuple(node)

+  def get_lowering(primitive: core.Primitive) -> Optional[LoweringRule]:
+    if ctx.override_lowering_rules is None:
+      return None
+    for p, rule in ctx.override_lowering_rules:
+      if primitive is p:
+        return rule
+    return None

   env: dict[core.Var, tuple[ir.Value, ...]] = {}
@@ -1153,7 +1170,10 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
     loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
                                    ctx.name_stack)
     with source_info_util.user_context(eqn.source_info.traceback), loc:
-      if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
+      override_rule = get_lowering(eqn.primitive)
+      if override_rule is not None:
+        rule = override_rule
+      elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
         rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
       elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
         rule = xla_fallback_lowering(eqn.primitive)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 5fec7eeb8..f27890750 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1870,7 +1870,7 @@ class SemanticallyEqualShardings:
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
                             da_object, lowering_platform,
-                            donated_invars, name_stack):
+                            donated_invars, name_stack, override_lowering_rules):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1940,7 +1940,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
         result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
         num_replicas=nreps,
-        num_partitions=num_partitions)
+        num_partitions=num_partitions,
+        override_lowering_rules=override_lowering_rules)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@@ -1998,6 +1999,8 @@ def lower_sharding_computation(
     always_lower: bool,
     devices_from_context: Optional[Sequence[xc.Device]] = None,
     lowering_platform: Optional[str],
+    override_lowering_rules: Optional[
+        tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None,
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
@@ -2084,7 +2087,7 @@ def lower_sharding_computation(
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
        semantic_out_shardings, da_object, lowering_platform,
-       donated_invars, name_stack)
+       donated_invars, name_stack, override_lowering_rules)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index e0946da26..c282db5df 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -351,6 +351,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
   def lower(*args, **kwargs):
     _experimental_lowering_platform = kwargs.pop(
         '_experimental_lowering_platform', None)
+    _experimental_override_lowering_rules = kwargs.pop(
+        '_experimental_override_lowering_rules', None)
     (args_flat, flat_global_in_avals, params, in_tree, out_tree,
      donate_argnums) = infer_params_fn(*args, **kwargs)
     resource_env = params['resource_env']
@@ -362,7 +364,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
           params['jaxpr'], in_shardings, params['out_shardings'],
           params['resource_env'], params['donated_invars'], params['name'],
           params['keep_unused'], params['inline'], always_lower=True,
-          lowering_platform=_experimental_lowering_platform)
+          lowering_platform=_experimental_lowering_platform,
+          override_lowering_rules=_experimental_override_lowering_rules)
     except pxla.DeviceAssignmentMismatchError as e:
       fails, = e.args
       api_name = 'jit' if params['resource_env'] is None else 'pjit'
@@ -1268,7 +1271,9 @@ def _pjit_lower_cached(
     inline: bool,
     always_lower: bool,
     *,
-    lowering_platform: Optional[str]):
+    lowering_platform: Optional[str],
+    override_lowering_rules: Optional[
+        tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None):
   in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
       tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
   out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
@@ -1299,7 +1304,9 @@ def _pjit_lower_cached(
       keep_unused=keep_unused, inline=inline, always_lower=always_lower,
       devices_from_context=(
           None if mesh is None or mesh.empty else list(mesh.devices.flat)),
-      lowering_platform=lowering_platform)
+      lowering_platform=lowering_platform,
+      override_lowering_rules=override_lowering_rules,
+  )


 def pjit_staging_rule(trace, *args, **params):
diff --git a/tests/api_test.py b/tests/api_test.py
index c78a4de8d..18acaea36 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -10086,5 +10086,25 @@ class DeprecationsTest(jtu.JaxTestCase):
     with self.assertWarns(DeprecationWarning):
       self.assertIs(jax.flatten_fun_nokwargs, jax.api_util.flatten_fun_nokwargs)

+
+class OverrideLoweringTest(jtu.JaxTestCase):
+
+  def test_sharding_constraint_as_noop(self):
+    def f(x):
+      return jax.lax.with_sharding_constraint(
+          x, jax.sharding.SingleDeviceSharding(jax.devices()[0]))
+
+    def wsc_as_noop(ctx, operand, *args, **kwargs):
+      del ctx, args, kwargs
+      return [operand]
+
+    rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),)
+    lowered_ir = (
+        jax.jit(f)
+        .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
+               _experimental_override_lowering_rules=rules).as_text())
+    self.assertNotIn("stablehlo.custom_call", lowered_ir)
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
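---

Usage sketch (not part of the patch): the new _experimental_override_lowering_rules
keyword to .lower() takes a tuple of (primitive, rule) pairs; when one of those
primitives is lowered, its paired rule takes priority over JAX's built-in and
platform-specific lowering rules. The standalone script below mirrors the test
added in tests/api_test.py, lowering with_sharding_constraint as a no-op so the
sharding custom call never reaches the StableHLO output. All names come from the
patch itself; the only assumption is a jax build with this change applied.

import jax
import jax.numpy as jnp

def f(x):
  # Normally lowers to a stablehlo.custom_call for the sharding constraint.
  return jax.lax.with_sharding_constraint(
      x, jax.sharding.SingleDeviceSharding(jax.devices()[0]))

def wsc_as_noop(ctx, operand, *args, **kwargs):
  # An override rule uses the same signature as a built-in LoweringRule
  # (the lowering rule context, then the operands). Forwarding the operand
  # unchanged elides the constraint entirely.
  del ctx, args, kwargs
  return [operand]

rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),)
lowered_ir = jax.jit(f).lower(
    jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
    _experimental_override_lowering_rules=rules).as_text()
assert "stablehlo.custom_call" not in lowered_ir

Note on the design: because jaxpr_subcomp consults get_lowering before
_platform_specific_lowerings and the XLA fallback, an override applies on
every platform being lowered for, not just one.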