[MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.

This change is in preparation for removing the XLA translation rules for many primitives. However, even after the switch to MHLO, collective and initial_style primitives still need to be tagged, so those registrations must survive independently of the translation-rule registry.

PiperOrigin-RevId: 441474701
This commit is contained in:
Peter Hawkins 2022-04-13 07:25:51 -07:00 committed by jax authors
parent ad8e6ada4e
commit cb4abe754a
6 changed files with 45 additions and 44 deletions

View File

@ -231,10 +231,10 @@ custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
xla.register_translation(custom_vmap_p,
xla.lower_fun(custom_vmap_impl, new_style=True,
multiple_results=True),
initial_style=True)
multiple_results=True))
mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
custom_vmap_impl, multiple_results=True))

View File

@ -397,11 +397,11 @@ def _custom_jvp_call_jaxpr_vmap(
return batched_outs, out_dims
batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
xla.register_initial_style_primitive(custom_jvp_call_jaxpr_p)
xla.register_translation(
custom_jvp_call_jaxpr_p,
xla.lower_fun(_custom_jvp_call_jaxpr_impl, new_style=True,
multiple_results=True),
initial_style=True)
multiple_results=True))
# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
@ -768,11 +768,11 @@ def _custom_vjp_call_jaxpr_vmap(
return batched_outs, out_dims
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
xla.register_translation(
custom_vjp_call_jaxpr_p,
xla.lower_fun(_custom_vjp_call_jaxpr_impl, new_style=True,
multiple_results=True),
initial_style=True)
multiple_results=True))
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
xla.register_translation(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
@ -1164,10 +1164,10 @@ linear_call_p.multiple_results = True
linear_call_p.def_impl(_linear_call_impl)
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
xla.register_initial_style_primitive(linear_call_p)
xla.register_translation(linear_call_p,
xla.lower_fun(_linear_call_impl, new_style=True,
multiple_results=True),
initial_style=True)
multiple_results=True))
mlir.register_lowering(linear_call_p, mlir.lower_fun(
_linear_call_impl, multiple_results=True))

View File

@ -225,8 +225,8 @@ ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
custom_transpose_p,
mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
xla.register_initial_style_primitive(custom_transpose_p)
xla.register_translation(
custom_transpose_p,
xla.lower_fun(
custom_transpose_lowering, new_style=True, multiple_results=True),
initial_style=True)
custom_transpose_lowering, new_style=True, multiple_results=True))

View File

@ -620,8 +620,8 @@ while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_translation(while_p, _while_loop_translation_rule,
initial_style=True)
xla.register_initial_style_primitive(while_p)
xla.register_translation(while_p, _while_loop_translation_rule)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = \
@ -1342,7 +1342,8 @@ ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_translation(cond_p, _cond_translation_rule, initial_style=True)
xla.register_initial_style_primitive(cond_p)
xla.register_translation(cond_p, _cond_translation_rule)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
@ -2132,9 +2133,9 @@ scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.register_initial_style_primitive(scan_p)
xla.register_translation(scan_p, xla.lower_fun(_scan_impl, new_style=True,
multiple_results=True),
initial_style=True)
multiple_results=True))
mlir.register_lowering(scan_p,
mlir.lower_fun(_scan_impl, multiple_results=True))
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
@ -2692,10 +2693,10 @@ linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
xla.register_translation(
linear_solve_p, xla.lower_fun(_custom_linear_solve_impl, new_style=True,
multiple_results=True),
initial_style=True)
multiple_results=True))
mlir.register_lowering(
linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
multiple_results=True))

View File

@ -783,9 +783,9 @@ psum_p = core.AxisPrimitive('psum')
psum_p.multiple_results = True
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
psum_p.def_abstract_eval(_allreduce_abstract_eval)
xla.register_collective_primitive(psum_p)
xla.register_translation(
psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum),
is_collective=True)
psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum))
mlir.register_lowering(
psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
ad.deflinear2(psum_p, _psum_transpose_rule)
@ -822,9 +822,9 @@ pmax_p = core.AxisPrimitive('pmax')
pmax_p.multiple_results = True
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
xla.register_collective_primitive(pmax_p)
xla.register_translation(
pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max),
is_collective=True)
pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max))
mlir.register_lowering(
pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
pxla.multi_host_supported_collectives.add(pmax_p)
@ -838,9 +838,9 @@ pmin_p = core.AxisPrimitive('pmin')
pmin_p.multiple_results = True
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
xla.register_collective_primitive(pmin_p)
xla.register_translation(
pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min),
is_collective=True)
pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min))
mlir.register_lowering(
pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
pxla.multi_host_supported_collectives.add(pmin_p)
@ -910,8 +910,8 @@ def _collective_batcher(prim, args, dims, **params):
ppermute_p = core.AxisPrimitive('ppermute')
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
xla.register_translation(ppermute_p, _ppermute_translation_rule,
is_collective=True)
xla.register_collective_primitive(ppermute_p)
xla.register_translation(ppermute_p, _ppermute_translation_rule)
mlir.register_lowering(ppermute_p, _ppermute_lowering)
pxla.multi_host_supported_collectives.add(ppermute_p)
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
@ -1102,8 +1102,8 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_
all_to_all_p = core.AxisPrimitive('all_to_all')
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
xla.register_translation(all_to_all_p, _all_to_all_translation_rule,
is_collective=True)
xla.register_collective_primitive(all_to_all_p)
xla.register_translation(all_to_all_p, _all_to_all_translation_rule)
mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
pxla.multi_host_supported_collectives.add(all_to_all_p)
@ -1323,8 +1323,8 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
all_gather_p = core.AxisPrimitive('all_gather')
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
all_gather_p.def_impl(_all_gather_impl)
xla.register_translation(all_gather_p, _all_gather_translation_rule,
is_collective=True)
xla.register_collective_primitive(all_gather_p)
xla.register_translation(all_gather_p, _all_gather_translation_rule)
mlir.register_lowering(all_gather_p, _all_gather_lowering)
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
pxla.multi_host_supported_collectives.add(all_gather_p)
@ -1462,10 +1462,10 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
xla.register_collective_primitive(reduce_scatter_p)
xla.register_translation(
reduce_scatter_p,
partial(_reduce_scatter_translation_rule, lax.add_p, psum),
is_collective=True)
partial(_reduce_scatter_translation_rule, lax.add_p, psum))
mlir.register_lowering(
reduce_scatter_p,
partial(_reduce_scatter_lowering, lax.add_p, psum))
@ -1590,8 +1590,8 @@ def _axis_index_abstract_eval(*, axis_name):
return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
axis_index_p = core.Primitive('axis_index')
xla.register_translation(axis_index_p, _axis_index_translation_rule,
is_collective=True)
xla.register_collective_primitive(axis_index_p)
xla.register_translation(axis_index_p, _axis_index_translation_rule)
mlir.register_lowering(axis_index_p, _axis_index_lowering)
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
pxla.multi_host_supported_collectives.add(axis_index_p)
@ -1683,10 +1683,10 @@ def _pdot_lowering(x, y, *, axis_name, pos_contract, pos_batch, precision):
precision=precision, preferred_element_type=None)
return psum(local_out, axis_name) if axis_name is not None else local_out
xla.register_collective_primitive(pdot_p)
xla.register_translation(
pdot_p,
xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True),
is_collective=True)
xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True))
mlir.register_lowering(
pdot_p,
mlir.lower_fun(_pdot_lowering, multiple_results=False))
@ -1785,8 +1785,8 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a
pgather_p = core.AxisPrimitive('pgather')
pgather_p.def_impl(_pgather_impl)
pgather_p.def_abstract_eval(_pgather_abstract_eval)
xla.register_translation(pgather_p, _pgather_parallel_translation,
is_collective=True)
xla.register_collective_primitive(pgather_p)
xla.register_translation(pgather_p, _pgather_parallel_translation)
mlir.register_lowering(pgather_p, _pgather_parallel_lowering)
# TODO: Transpose? That requires adding pscatter...
batching.primitive_batchers[pgather_p] = _pgather_batcher

View File

@ -870,17 +870,17 @@ _backend_specific_translations = defaultdict(dict)
# Registries of primitives that need special tagging during lowering.
# NOTE(review): populated via the two register_* helpers below; consumers of
# these sets live elsewhere in this module — confirm before relying on them.
_collective_primitives: Set[core.Primitive] = set()
_initial_style_primitives: Set[core.Primitive] = set()


def register_collective_primitive(prim: core.Primitive):
  """Tags `prim` as a collective primitive."""
  _collective_primitives.add(prim)


def register_initial_style_primitive(prim: core.Primitive):
  """Tags `prim` as an initial-style primitive."""
  _initial_style_primitives.add(prim)
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
platform: Optional[str] = None,
is_collective: bool = False,
initial_style: bool = False) -> None:
platform: Optional[str] = None) -> None:
ts = (_translations if platform is None
else _backend_specific_translations[platform])
ts[prim] = rule
if is_collective:
_collective_primitives.add(prim)
if initial_style:
_initial_style_primitives.add(prim)
# As a temporary backward compatibility measure, we use an adapter class to
# convert from the old styles of translation rules to the newer ones.