mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Makes it possible to lower primitives with user-defined lowering rules.
PiperOrigin-RevId: 547228102
This commit is contained in:
parent
17c4b57f97
commit
f81a48a819
@ -420,6 +420,10 @@ class ModuleContext:
|
||||
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
|
||||
cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]
|
||||
|
||||
# A mapping between primitives and user-defined LoweringRules.
|
||||
# When lowering a primitive, give priorioty to the rule in this map over
|
||||
# existing Jax rules.
|
||||
override_lowering_rules: Optional[tuple[tuple[core.Primitive, LoweringRule]]]
|
||||
|
||||
@property
|
||||
def axis_env(self) -> sharding_impls.AxisEnv:
|
||||
@ -442,6 +446,8 @@ class ModuleContext:
|
||||
func_dialect.FuncOp]] = None,
|
||||
cached_call_jaxpr_lowerings: Optional[dict[Any,
|
||||
func_dialect.FuncOp]] = None,
|
||||
override_lowering_rules: Optional[
|
||||
tuple[tuple[core.Primitive, LoweringRule]]] = None,
|
||||
shape_poly_state = None):
|
||||
assert platform is not None
|
||||
self.context = context or make_ir_context()
|
||||
@ -460,6 +466,7 @@ class ModuleContext:
|
||||
self.cached_call_jaxpr_lowerings = ({}
|
||||
if cached_call_jaxpr_lowerings is None
|
||||
else cached_call_jaxpr_lowerings)
|
||||
self.override_lowering_rules = override_lowering_rules
|
||||
self.shape_poly_state = shape_poly_state or ShapePolyLoweringState(())
|
||||
|
||||
@property
|
||||
@ -639,6 +646,8 @@ def lower_jaxpr_to_module(
|
||||
result_names: Optional[Sequence[Optional[str]]] = None,
|
||||
num_replicas: int = 1,
|
||||
num_partitions: int = 1,
|
||||
override_lowering_rules: Optional[
|
||||
tuple[tuple[core.Primitive, LoweringRule]]] = None,
|
||||
) -> LoweringResult:
|
||||
"""Lowers a top-level jaxpr to an MLIR module.
|
||||
|
||||
@ -691,6 +700,7 @@ def lower_jaxpr_to_module(
|
||||
|
||||
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
|
||||
keepalives, channel_iter, host_callbacks,
|
||||
override_lowering_rules=override_lowering_rules,
|
||||
shape_poly_state=ShapePolyLoweringState(dim_vars))
|
||||
with ctx.context, ir.Location.unknown(ctx.context):
|
||||
# Remove module name characters that XLA would alter. This ensures that
|
||||
@ -1135,6 +1145,13 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
assert node is not None
|
||||
env[v] = tuple(node)
|
||||
|
||||
def get_lowering(primitive: core.Primitive) -> Optional[LoweringRule]:
|
||||
if ctx.override_lowering_rules is None:
|
||||
return None
|
||||
for p, rule in ctx.override_lowering_rules:
|
||||
if primitive is p:
|
||||
return rule
|
||||
return None
|
||||
|
||||
env: dict[core.Var, tuple[ir.Value, ...]] = {}
|
||||
|
||||
@ -1153,7 +1170,10 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
|
||||
ctx.name_stack)
|
||||
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
||||
if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
|
||||
override_rule = get_lowering(eqn.primitive)
|
||||
if override_rule is not None:
|
||||
rule = override_rule
|
||||
elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
|
||||
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
|
||||
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
|
||||
rule = xla_fallback_lowering(eqn.primitive)
|
||||
|
@ -1870,7 +1870,7 @@ class SemanticallyEqualShardings:
|
||||
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
semantic_in_shardings, semantic_out_shardings,
|
||||
da_object, lowering_platform,
|
||||
donated_invars, name_stack):
|
||||
donated_invars, name_stack, override_lowering_rules):
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
in_shardings = semantic_in_shardings.shardings
|
||||
out_shardings = semantic_out_shardings.shardings
|
||||
@ -1940,7 +1940,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=nreps,
|
||||
num_partitions=num_partitions)
|
||||
num_partitions=num_partitions,
|
||||
override_lowering_rules=override_lowering_rules)
|
||||
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
@ -1998,6 +1999,8 @@ def lower_sharding_computation(
|
||||
always_lower: bool,
|
||||
devices_from_context: Optional[Sequence[xc.Device]] = None,
|
||||
lowering_platform: Optional[str],
|
||||
override_lowering_rules: Optional[
|
||||
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None,
|
||||
) -> MeshComputation:
|
||||
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
|
||||
|
||||
@ -2084,7 +2087,7 @@ def lower_sharding_computation(
|
||||
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
|
||||
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
|
||||
semantic_out_shardings, da_object, lowering_platform,
|
||||
donated_invars, name_stack)
|
||||
donated_invars, name_stack, override_lowering_rules)
|
||||
|
||||
# backend and device_assignment is passed through to MeshExecutable because
|
||||
# if keep_unused=False and all in_shardings are pruned, then there is no way
|
||||
|
@ -351,6 +351,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
def lower(*args, **kwargs):
|
||||
_experimental_lowering_platform = kwargs.pop(
|
||||
'_experimental_lowering_platform', None)
|
||||
_experimental_override_lowering_rules = kwargs.pop(
|
||||
'_experimental_override_lowering_rules', None)
|
||||
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
|
||||
donate_argnums) = infer_params_fn(*args, **kwargs)
|
||||
resource_env = params['resource_env']
|
||||
@ -362,7 +364,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
params['jaxpr'], in_shardings, params['out_shardings'],
|
||||
params['resource_env'], params['donated_invars'], params['name'],
|
||||
params['keep_unused'], params['inline'], always_lower=True,
|
||||
lowering_platform=_experimental_lowering_platform)
|
||||
lowering_platform=_experimental_lowering_platform,
|
||||
override_lowering_rules=_experimental_override_lowering_rules)
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
api_name = 'jit' if params['resource_env'] is None else 'pjit'
|
||||
@ -1268,7 +1271,9 @@ def _pjit_lower_cached(
|
||||
inline: bool,
|
||||
always_lower: bool,
|
||||
*,
|
||||
lowering_platform: Optional[str]):
|
||||
lowering_platform: Optional[str],
|
||||
override_lowering_rules: Optional[
|
||||
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None):
|
||||
in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
|
||||
tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
|
||||
out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
|
||||
@ -1299,7 +1304,9 @@ def _pjit_lower_cached(
|
||||
keep_unused=keep_unused, inline=inline, always_lower=always_lower,
|
||||
devices_from_context=(
|
||||
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
|
||||
lowering_platform=lowering_platform)
|
||||
lowering_platform=lowering_platform,
|
||||
override_lowering_rules=override_lowering_rules,
|
||||
)
|
||||
|
||||
|
||||
def pjit_staging_rule(trace, *args, **params):
|
||||
|
@ -10086,5 +10086,25 @@ class DeprecationsTest(jtu.JaxTestCase):
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(jax.flatten_fun_nokwargs, jax.api_util.flatten_fun_nokwargs)
|
||||
|
||||
|
||||
class OverrideLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
def test_sharding_constraint_as_noop(self):
|
||||
def f(x):
|
||||
return jax.lax.with_sharding_constraint(
|
||||
x, jax.sharding.SingleDeviceSharding(jax.devices()[0]))
|
||||
|
||||
def wsc_as_noop(ctx, operand, *args, **kwargs):
|
||||
del ctx, args, kwargs
|
||||
return [operand]
|
||||
|
||||
rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),)
|
||||
lowered_ir = (
|
||||
jax.jit(f)
|
||||
.lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
|
||||
_experimental_override_lowering_rules=rules).as_text())
|
||||
self.assertNotIn("stablehlo.custom_call", lowered_ir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user