Makes it possible to lower primitives with user-defined lowering rules.

PiperOrigin-RevId: 547228102
This commit is contained in:
Juliana Franco 2023-07-11 10:23:48 -07:00 committed by jax authors
parent 17c4b57f97
commit f81a48a819
4 changed files with 57 additions and 7 deletions

View File

@ -420,6 +420,10 @@ class ModuleContext:
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]
# A mapping between primitives and user-defined LoweringRules.
# When lowering a primitive, give priority to the rule in this map over
# existing Jax rules.
override_lowering_rules: Optional[tuple[tuple[core.Primitive, LoweringRule]]]
@property
def axis_env(self) -> sharding_impls.AxisEnv:
@ -442,6 +446,8 @@ class ModuleContext:
func_dialect.FuncOp]] = None,
cached_call_jaxpr_lowerings: Optional[dict[Any,
func_dialect.FuncOp]] = None,
override_lowering_rules: Optional[
tuple[tuple[core.Primitive, LoweringRule]]] = None,
shape_poly_state = None):
assert platform is not None
self.context = context or make_ir_context()
@ -460,6 +466,7 @@ class ModuleContext:
self.cached_call_jaxpr_lowerings = ({}
if cached_call_jaxpr_lowerings is None
else cached_call_jaxpr_lowerings)
self.override_lowering_rules = override_lowering_rules
self.shape_poly_state = shape_poly_state or ShapePolyLoweringState(())
@property
@ -639,6 +646,8 @@ def lower_jaxpr_to_module(
result_names: Optional[Sequence[Optional[str]]] = None,
num_replicas: int = 1,
num_partitions: int = 1,
override_lowering_rules: Optional[
tuple[tuple[core.Primitive, LoweringRule]]] = None,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
@ -691,6 +700,7 @@ def lower_jaxpr_to_module(
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks,
override_lowering_rules=override_lowering_rules,
shape_poly_state=ShapePolyLoweringState(dim_vars))
with ctx.context, ir.Location.unknown(ctx.context):
# Remove module name characters that XLA would alter. This ensures that
@ -1135,6 +1145,13 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
assert node is not None
env[v] = tuple(node)
def get_lowering(primitive: core.Primitive) -> Optional[LoweringRule]:
if ctx.override_lowering_rules is None:
return None
for p, rule in ctx.override_lowering_rules:
if primitive is p:
return rule
return None
env: dict[core.Var, tuple[ir.Value, ...]] = {}
@ -1153,7 +1170,10 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
ctx.name_stack)
with source_info_util.user_context(eqn.source_info.traceback), loc:
if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
override_rule = get_lowering(eqn.primitive)
if override_rule is not None:
rule = override_rule
elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
rule = xla_fallback_lowering(eqn.primitive)

View File

@ -1870,7 +1870,7 @@ class SemanticallyEqualShardings:
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
da_object, lowering_platform,
donated_invars, name_stack):
donated_invars, name_stack, override_lowering_rules):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
@ -1940,7 +1940,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=nreps,
num_partitions=num_partitions)
num_partitions=num_partitions,
override_lowering_rules=override_lowering_rules)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@ -1998,6 +1999,8 @@ def lower_sharding_computation(
always_lower: bool,
devices_from_context: Optional[Sequence[xc.Device]] = None,
lowering_platform: Optional[str],
override_lowering_rules: Optional[
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None,
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
@ -2084,7 +2087,7 @@ def lower_sharding_computation(
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, da_object, lowering_platform,
donated_invars, name_stack)
donated_invars, name_stack, override_lowering_rules)
# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way

View File

@ -351,6 +351,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
def lower(*args, **kwargs):
_experimental_lowering_platform = kwargs.pop(
'_experimental_lowering_platform', None)
_experimental_override_lowering_rules = kwargs.pop(
'_experimental_override_lowering_rules', None)
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params_fn(*args, **kwargs)
resource_env = params['resource_env']
@ -362,7 +364,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
params['keep_unused'], params['inline'], always_lower=True,
lowering_platform=_experimental_lowering_platform)
lowering_platform=_experimental_lowering_platform,
override_lowering_rules=_experimental_override_lowering_rules)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
@ -1268,7 +1271,9 @@ def _pjit_lower_cached(
inline: bool,
always_lower: bool,
*,
lowering_platform: Optional[str]):
lowering_platform: Optional[str],
override_lowering_rules: Optional[
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None):
in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
@ -1299,7 +1304,9 @@ def _pjit_lower_cached(
keep_unused=keep_unused, inline=inline, always_lower=always_lower,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
lowering_platform=lowering_platform)
lowering_platform=lowering_platform,
override_lowering_rules=override_lowering_rules,
)
def pjit_staging_rule(trace, *args, **params):

View File

@ -10086,5 +10086,25 @@ class DeprecationsTest(jtu.JaxTestCase):
with self.assertWarns(DeprecationWarning):
self.assertIs(jax.flatten_fun_nokwargs, jax.api_util.flatten_fun_nokwargs)
class OverrideLoweringTest(jtu.JaxTestCase):
  """Tests the `_experimental_override_lowering_rules` argument to `lower`."""

  def test_sharding_constraint_as_noop(self):
    # A function whose default lowering emits a sharding custom_call.
    def f(x):
      return jax.lax.with_sharding_constraint(
          x, jax.sharding.SingleDeviceSharding(jax.devices()[0]))

    # Override rule: ignore the constraint and return the operand unchanged.
    def wsc_as_noop(ctx, operand, *args, **kwargs):
      del ctx, args, kwargs
      return [operand]

    # (primitive, rule) pairs; these take priority over the built-in rules.
    rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),)
    lowered_ir = (
        jax.jit(f)
        .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
               _experimental_override_lowering_rules=rules).as_text())
    # With the no-op override in effect, the sharding-constraint
    # custom_call must not appear in the lowered StableHLO text.
    self.assertNotIn("stablehlo.custom_call", lowered_ir)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())