From cb4abe754a44a6ffe5508413957b9d6de394d192 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 13 Apr 2022 07:25:51 -0700
Subject: [PATCH] [MHLO] Separate registrations for collective and
 initial_style primitives from the XLA translation rule registration.

Change in preparation for removing XLA translation rules for many
primitives. However, even after the MHLO switch we still need to tag
collective and initial_style primitives.

PiperOrigin-RevId: 441474701
---
 jax/_src/custom_batching.py    |  4 ++--
 jax/_src/custom_derivatives.py | 12 ++++++------
 jax/_src/custom_transpose.py   |  4 ++--
 jax/_src/lax/control_flow.py   | 15 ++++++++-------
 jax/_src/lax/parallel.py       | 40 ++++++++++++++++++++--------------------
 jax/interpreters/xla.py        | 14 +++++++-------
 6 files changed, 45 insertions(+), 44 deletions(-)

diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py
index 6cd7e69dc..21c34ca8d 100644
--- a/jax/_src/custom_batching.py
+++ b/jax/_src/custom_batching.py
@@ -231,10 +231,10 @@ custom_vmap_p.def_impl(custom_vmap_impl)
 custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
 batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
 ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
+xla.register_initial_style_primitive(custom_vmap_p)
 xla.register_translation(custom_vmap_p,
                          xla.lower_fun(custom_vmap_impl, new_style=True,
-                                       multiple_results=True),
-                         initial_style=True)
+                                       multiple_results=True))
 mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
     custom_vmap_impl, multiple_results=True))
 
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index f4c01b555..87e489cc0 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -397,11 +397,11 @@ def _custom_jvp_call_jaxpr_vmap(
   return batched_outs, out_dims
 batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
 
+xla.register_initial_style_primitive(custom_jvp_call_jaxpr_p)
 xla.register_translation(
     custom_jvp_call_jaxpr_p,
     xla.lower_fun(_custom_jvp_call_jaxpr_impl, new_style=True,
-                  multiple_results=True),
-    initial_style=True)
+                  multiple_results=True))
 
 # If a (multi)linear function is defined with a custom jvp, then
 # custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
@@ -768,11 +768,11 @@ def _custom_vjp_call_jaxpr_vmap(
   return batched_outs, out_dims
 batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
 
+xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
 xla.register_translation(
     custom_vjp_call_jaxpr_p,
     xla.lower_fun(_custom_vjp_call_jaxpr_impl, new_style=True,
-                  multiple_results=True),
-    initial_style=True)
+                  multiple_results=True))
 
 batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
 xla.register_translation(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
@@ -1164,10 +1164,10 @@ linear_call_p.multiple_results = True
 linear_call_p.def_impl(_linear_call_impl)
 linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
 ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
+xla.register_initial_style_primitive(linear_call_p)
 xla.register_translation(linear_call_p,
                          xla.lower_fun(_linear_call_impl, new_style=True,
-                                       multiple_results=True),
-                         initial_style=True)
+                                       multiple_results=True))
 mlir.register_lowering(linear_call_p, mlir.lower_fun(
     _linear_call_impl, multiple_results=True))
 
diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py
index a79831970..b7f47b486 100644
--- a/jax/_src/custom_transpose.py
+++ b/jax/_src/custom_transpose.py
@@ -225,8 +225,8 @@ ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
 mlir.register_lowering(
     custom_transpose_p,
     mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
+xla.register_initial_style_primitive(custom_transpose_p)
 xla.register_translation(
     custom_transpose_p, xla.lower_fun(
-        custom_transpose_lowering, new_style=True, multiple_results=True),
-    initial_style=True)
+        custom_transpose_lowering, new_style=True, multiple_results=True))
 
diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py
index 3a40b7ba0..9e5a0598f 100644
--- a/jax/_src/lax/control_flow.py
+++ b/jax/_src/lax/control_flow.py
@@ -620,8 +620,8 @@ while_p.def_impl(partial(xla.apply_primitive, while_p))
 while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
 ad.primitive_jvps[while_p] = _while_loop_jvp
 pe.custom_partial_eval_rules[while_p] = _while_partial_eval
-xla.register_translation(while_p, _while_loop_translation_rule,
-                         initial_style=True)
+xla.register_initial_style_primitive(while_p)
+xla.register_translation(while_p, _while_loop_translation_rule)
 ad.primitive_transposes[while_p] = _while_transpose_error
 batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
 pe.partial_eval_jaxpr_custom_rules[while_p] = \
@@ -1342,7 +1342,8 @@ ad.primitive_jvps[cond_p] = _cond_jvp
 ad.reducing_transposes[cond_p] = _cond_transpose
 pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
 batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
-xla.register_translation(cond_p, _cond_translation_rule, initial_style=True)
+xla.register_initial_style_primitive(cond_p)
+xla.register_translation(cond_p, _cond_translation_rule)
 core.custom_typechecks[cond_p] = _cond_typecheck
 pe.partial_eval_jaxpr_custom_rules[cond_p] = \
     partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
@@ -2132,9 +2133,9 @@ scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
 ad.primitive_jvps[scan_p] = _scan_jvp
 ad.reducing_transposes[scan_p] = _scan_transpose
 pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
+xla.register_initial_style_primitive(scan_p)
 xla.register_translation(scan_p, xla.lower_fun(_scan_impl, new_style=True,
-                                               multiple_results=True),
-                         initial_style=True)
+                                               multiple_results=True))
 mlir.register_lowering(scan_p,
                        mlir.lower_fun(_scan_impl, multiple_results=True))
 batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
@@ -2692,10 +2693,10 @@ linear_solve_p.multiple_results = True
 linear_solve_p.def_impl(_custom_linear_solve_impl)
 linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
 ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
+xla.register_initial_style_primitive(linear_solve_p)
 xla.register_translation(
     linear_solve_p, xla.lower_fun(_custom_linear_solve_impl, new_style=True,
-                                  multiple_results=True),
-    initial_style=True)
+                                  multiple_results=True))
 mlir.register_lowering(
     linear_solve_p,
     mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True))
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 262ec44e6..7a9308968 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -783,9 +783,9 @@ psum_p = core.AxisPrimitive('psum')
 psum_p.multiple_results = True
 psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
 psum_p.def_abstract_eval(_allreduce_abstract_eval)
+xla.register_collective_primitive(psum_p)
 xla.register_translation(
-    psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum),
-    is_collective=True)
+    psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum))
 mlir.register_lowering(
     psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
 ad.deflinear2(psum_p, _psum_transpose_rule)
@@ -822,9 +822,9 @@ pmax_p = core.AxisPrimitive('pmax')
 pmax_p.multiple_results = True
 pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
 pmax_p.def_abstract_eval(_allreduce_abstract_eval)
+xla.register_collective_primitive(pmax_p)
 xla.register_translation(
-    pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max),
-    is_collective=True)
+    pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max))
 mlir.register_lowering(
     pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
 pxla.multi_host_supported_collectives.add(pmax_p)
@@ -838,9 +838,9 @@ pmin_p = core.AxisPrimitive('pmin')
 pmin_p.multiple_results = True
 pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
 pmin_p.def_abstract_eval(_allreduce_abstract_eval)
+xla.register_collective_primitive(pmin_p)
 xla.register_translation(
-    pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min),
-    is_collective=True)
+    pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min))
 mlir.register_lowering(
     pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
 pxla.multi_host_supported_collectives.add(pmin_p)
@@ -910,8 +910,8 @@ def _collective_batcher(prim, args, dims, **params):
 ppermute_p = core.AxisPrimitive('ppermute')
 ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
 ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
-xla.register_translation(ppermute_p, _ppermute_translation_rule,
-                         is_collective=True)
+xla.register_collective_primitive(ppermute_p)
+xla.register_translation(ppermute_p, _ppermute_translation_rule)
 mlir.register_lowering(ppermute_p, _ppermute_lowering)
 pxla.multi_host_supported_collectives.add(ppermute_p)
 batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
@@ -1102,8 +1102,8 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_
 
 all_to_all_p = core.AxisPrimitive('all_to_all')
 all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
-xla.register_translation(all_to_all_p, _all_to_all_translation_rule,
-                         is_collective=True)
+xla.register_collective_primitive(all_to_all_p)
+xla.register_translation(all_to_all_p, _all_to_all_translation_rule)
 mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
 ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
 pxla.multi_host_supported_collectives.add(all_to_all_p)
@@ -1323,8 +1323,8 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
 all_gather_p = core.AxisPrimitive('all_gather')
 all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
 all_gather_p.def_impl(_all_gather_impl)
-xla.register_translation(all_gather_p, _all_gather_translation_rule,
-                         is_collective=True)
+xla.register_collective_primitive(all_gather_p)
+xla.register_translation(all_gather_p, _all_gather_translation_rule)
 mlir.register_lowering(all_gather_p, _all_gather_lowering)
 ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
 pxla.multi_host_supported_collectives.add(all_gather_p)
@@ -1462,10 +1462,10 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
 
 reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
 reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
+xla.register_collective_primitive(reduce_scatter_p)
 xla.register_translation(
     reduce_scatter_p,
-    partial(_reduce_scatter_translation_rule, lax.add_p, psum),
-    is_collective=True)
+    partial(_reduce_scatter_translation_rule, lax.add_p, psum))
 mlir.register_lowering(
     reduce_scatter_p,
     partial(_reduce_scatter_lowering, lax.add_p, psum))
@@ -1590,8 +1590,8 @@ def _axis_index_abstract_eval(*, axis_name):
   return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
 
 axis_index_p = core.Primitive('axis_index')
-xla.register_translation(axis_index_p, _axis_index_translation_rule,
-                         is_collective=True)
+xla.register_collective_primitive(axis_index_p)
+xla.register_translation(axis_index_p, _axis_index_translation_rule)
 mlir.register_lowering(axis_index_p, _axis_index_lowering)
 axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
 pxla.multi_host_supported_collectives.add(axis_index_p)
@@ -1683,10 +1683,10 @@ def _pdot_lowering(x, y, *, axis_name, pos_contract, pos_batch, precision):
                               precision=precision, preferred_element_type=None)
   return psum(local_out, axis_name) if axis_name is not None else local_out
 
+xla.register_collective_primitive(pdot_p)
 xla.register_translation(
     pdot_p,
-    xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True),
-    is_collective=True)
+    xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True))
 mlir.register_lowering(
     pdot_p,
     mlir.lower_fun(_pdot_lowering, multiple_results=False))
@@ -1785,8 +1785,8 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a
 pgather_p = core.AxisPrimitive('pgather')
 pgather_p.def_impl(_pgather_impl)
 pgather_p.def_abstract_eval(_pgather_abstract_eval)
-xla.register_translation(pgather_p, _pgather_parallel_translation,
-                         is_collective=True)
+xla.register_collective_primitive(pgather_p)
+xla.register_translation(pgather_p, _pgather_parallel_translation)
 mlir.register_lowering(pgather_p, _pgather_parallel_lowering)
 # TODO: Transpose? That requires adding pscatter...
 batching.primitive_batchers[pgather_p] = _pgather_batcher
diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py
index 2c74a9ccf..9ea7e8b7f 100644
--- a/jax/interpreters/xla.py
+++ b/jax/interpreters/xla.py
@@ -870,17 +870,17 @@ _backend_specific_translations = defaultdict(dict)
 _collective_primitives: Set[core.Primitive] = set()
 _initial_style_primitives: Set[core.Primitive] = set()
 
+def register_initial_style_primitive(prim: core.Primitive):
+  _initial_style_primitives.add(prim)
+
+def register_collective_primitive(prim: core.Primitive):
+  _collective_primitives.add(prim)
+
 def register_translation(prim: core.Primitive, rule: TranslationRule, *,
-                         platform: Optional[str] = None,
-                         is_collective: bool = False,
-                         initial_style: bool = False) -> None:
+                         platform: Optional[str] = None) -> None:
   ts = (_translations if platform is None
         else _backend_specific_translations[platform])
   ts[prim] = rule
-  if is_collective:
-    _collective_primitives.add(prim)
-  if initial_style:
-    _initial_style_primitives.add(prim)
 
 # As a temporary backward compatibility measure, we use an adapter class to
 # convert from the old styles of translation rules to the newer ones.
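
Usage note: a minimal sketch of the registration pattern after this change,
assuming the private jax.interpreters.xla API at this revision. The tagging
helpers (xla.register_initial_style_primitive, xla.register_collective_primitive)
and the xla.register_translation / xla.lower_fun signatures are taken from the
patch above; my_call_p and _my_call_impl are hypothetical names used only for
illustration.

from jax import core
from jax.interpreters import xla

my_call_p = core.Primitive('my_call')  # hypothetical initial-style primitive
my_call_p.multiple_results = True

def _my_call_impl(*args):
  # Placeholder implementation: return the operands unchanged.
  return list(args)

my_call_p.def_impl(_my_call_impl)

# Before this patch, the tag rode along with the translation rule:
#   xla.register_translation(my_call_p,
#                            xla.lower_fun(_my_call_impl, new_style=True,
#                                          multiple_results=True),
#                            initial_style=True)
# After this patch, tagging and translation registration are separate calls;
# collective primitives use xla.register_collective_primitive the same way.
xla.register_initial_style_primitive(my_call_p)
xla.register_translation(my_call_p,
                         xla.lower_fun(_my_call_impl, new_style=True,
                                       multiple_results=True))

Keeping the tags in module-level sets means they can still be queried after
the XLA translation rules themselves are removed in the MHLO migration, which
is the point of the separation described in the commit message.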